Seq2seq generation with prefix #3739

Closed
7 changes: 2 additions & 5 deletions src/transformers/modeling_bart.py
@@ -936,11 +936,8 @@ def forward(
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"

# first step, decoder_cached_states are empty
if not past[1]:
encoder_outputs, decoder_cached_states = past, None
else:
encoder_outputs, decoder_cached_states = past
encoder_outputs, decoder_cached_states = past
Contributor:
Like this change!


return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
6 changes: 1 addition & 5 deletions src/transformers/modeling_t5.py
@@ -1160,11 +1160,7 @@ def forward(
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"

# first step
if len(past) < 2:
encoder_outputs, decoder_past_key_value_states = past, None
else:
encoder_outputs, decoder_past_key_value_states = past[0], past[1]
encoder_outputs, decoder_past_key_value_states = past

return {
"decoder_input_ids": input_ids,
74 changes: 45 additions & 29 deletions src/transformers/modeling_tf_utils.py
@@ -458,6 +458,7 @@ def _do_output_past(self, outputs):
def generate(
self,
input_ids=None,
prefix_ids=None,
max_length=None,
min_length=None,
do_sample=None,
@@ -488,15 +489,21 @@ def generate(

Parameters:

input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
Language model: alias for prefix_ids, the sequence used as a prompt for the generation.
Seq2seq model: the sequence input to the encoder model.

prefix_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
Contributor:
I think prefix_ids should never overwrite input_ids. For me it would be OK to add an optional decoder_input_ids variable that can be used only for encoder-decoder models, when the user wants to encode the input_ids and then generate from a specific decoder_input_ids prompt. Initially, I didn't see many use cases for this, but maybe there are?
Also, it might be good to keep the same naming we use for the models:
If there is only one stack (GPT2), never name anything with decoder_. For an encoder-decoder model, the usual variables "input_ids, attention_mask" are used for the encoder, and the decoder uses the same names with a "decoder_" prefix.
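
A self-contained sketch (assumed behaviour in plain Python, not code from this PR; all names, including decoder_input_ids, are illustrative) of the dispatch this naming rule implies:

# Single-stack models (GPT-2): `input_ids` is the prompt the decoder continues from.
# Encoder-decoder models: `input_ids`/`attention_mask` feed the encoder, and a
# hypothetical `decoder_input_ids` would be the decoder-side prompt.
def resolve_generation_prompts(is_encoder_decoder, input_ids, decoder_input_ids=None,
                               decoder_start_token_id=None):
    if not is_encoder_decoder:
        return {"prompt_ids": input_ids}
    if decoder_input_ids is None:
        # Default: every sequence in the batch starts from the decoder start token.
        decoder_input_ids = [[decoder_start_token_id] for _ in input_ids]
    return {"encoder_input_ids": input_ids, "decoder_prompt_ids": decoder_input_ids}

print(resolve_generation_prompts(False, [[10, 11, 12]]))
print(resolve_generation_prompts(True, [[10, 11, 12]], decoder_start_token_id=0))
print(resolve_generation_prompts(True, [[10, 11, 12]], decoder_input_ids=[[0, 42]]))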

The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
Language model: If `None`, uses input_ids.

max_length: (`optional`) int
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.

min_length: (`optional`) int
The min length of the sequence to be generated. Between 0 and infinity. Default to 0.

do_sample: (`optional`) bool
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
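
For reference, an illustrative call combining the parameters documented above (a sketch only, assuming the standard gpt2 checkpoint is available; not part of the diff):

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("A prefix to continue from", return_tensors="tf")
output_ids = model.generate(
    input_ids,
    max_length=40,   # upper bound on the length of the generated sequence
    min_length=10,   # no EOS before at least this many tokens
    do_sample=True,  # sample instead of greedy decoding
)
print(tokenizer.decode(output_ids[0].numpy().tolist()))
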

@@ -601,6 +608,17 @@ def generate(
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
)

if self.config.is_encoder_decoder:
bos_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
bos_token_id = bos_token_id if bos_token_id is not None else self.config.decoder_start_token_id
Contributor (@patrickvonplaten, Apr 11, 2020):
I don't understand why we get rid of the decoder_start_token_id argument but still use it here via config.decoder_start_token_id.
Also, I'm not sure this works in general: Bart, for example, has a decoder_start_token_id that differs from its bos_token_id (a small sketch of the fallback this implies follows the test commands below).

Also please always check the hard-coded integration tests of at least Bart, T5 and GPT2 with your changes:
RUN_SLOW=1 pytest tests/test_modeling_bart.py
RUN_SLOW=1 pytest tests/test_modeling_gpt2.py
RUN_SLOW=1 pytest tests/test_modeling_t5.py
RUN_SLOW=1 pytest tests/test_modeling_tf_gpt2.py
RUN_SLOW=1 pytest tests/test_modeling_tf_t5.py
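
Picking up the Bart example above, a minimal sketch (assumed resolution order and example token ids, not this PR's code) of why a separate decoder_start_token_id matters:

def resolve_decoder_start(decoder_start_token_id, bos_token_id):
    # Prefer the dedicated decoder start token; fall back to BOS only if it is unset.
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    if bos_token_id is not None:
        return bos_token_id
    raise ValueError("decoder_start_token_id or bos_token_id has to be defined")

# Bart-like config: bos_token_id (0) and decoder_start_token_id (2) differ, so
# collapsing the two would start the decoder from the wrong token.
print(resolve_decoder_start(decoder_start_token_id=2, bos_token_id=0))      # -> 2
# Config where only a generic start token is defined:
print(resolve_decoder_start(decoder_start_token_id=None, bos_token_id=0))   # -> 0
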

bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

assert (
bos_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
else:
prefix_ids = prefix_ids if prefix_ids is not None else input_ids

max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
@@ -621,12 +639,11 @@
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)

if input_ids is not None:
batch_size = shape_list(input_ids)[0] # overriden by the input batch_size
batch_size = input_ids.shape[0] # overriden by the input batch_size
Contributor:
I think we should not use tensor.shape in TF. The shape_list function was made so that shapes behave correctly both in eager mode and in non-eager (graph) mode: #3063
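
Roughly, a shape_list-style helper (a simplified sketch of the idea referenced in #3063, not necessarily the exact library code) prefers static dimensions and falls back to the dynamic tf.shape() tensor for dimensions that are unknown while tracing:

import tensorflow as tf

def shape_list(x):
    # Static dims may be None inside a traced/graph function; patch those with
    # the dynamic shape tensor so downstream ops still get a usable value.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

print(shape_list(tf.zeros((3, 5))))  # [3, 5] -- plain Python ints in eager mode

@tf.function(input_signature=[tf.TensorSpec([None, 5], tf.float32)])
def batch_sum(t):
    # t.shape[0] is None here, but shape_list(t)[0] is a usable scalar tensor.
    return tf.cast(shape_list(t)[0], tf.float32) + tf.reduce_sum(t)

print(batch_sum(tf.ones((3, 5))))  # 3 + 15 = 18.0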

elif prefix_ids is not None:
batch_size = prefix_ids.shape[0]
else:
batch_size = 1

@@ -639,7 +656,7 @@
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert input_ids is not None or (
assert prefix_ids is not None or (
isinstance(bos_token_id, int) and bos_token_id >= 0
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
assert pad_token_id is None or (
Expand All @@ -656,14 +673,16 @@ def generate(
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

if input_ids is None:
if prefix_ids is None:
Contributor:
I don't like prefix_ids overwriting input_ids as mentioned before. Why should we do that?

assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = tf.fill((batch_size, 1), bos_token_id)
prefix_ids = tf.fill((batch_size, 1), bos_token_id)
if input_ids is None:
input_ids = tf.fill((batch_size, 1), bos_token_id)
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
assert len(shape_list(prefix_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

# not allow to duplicate outputs when greedy decoding
if do_sample is False:
@@ -693,7 +712,7 @@ def generate(
pad_token_id = eos_token_id

# current position and vocab size
cur_len = shape_list(input_ids)[1]
cur_len = shape_list(prefix_ids)[1]
vocab_size = self.config.vocab_size

# set effective batch size and effective batch multiplier according to do_sample
@@ -707,45 +726,41 @@
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = shape_list(input_ids)[-1]
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
prefix_ids_len = shape_list(prefix_ids)[-1]
prefix_ids = tf.broadcast_to(
tf.expand_dims(prefix_ids, 1), (batch_size, effective_batch_mult * num_beams, prefix_ids_len)
)
attention_mask = tf.broadcast_to(
tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
input_ids = tf.reshape(
input_ids, (effective_batch_size * num_beams, input_ids_len)
prefix_ids = tf.reshape(
prefix_ids, (effective_batch_size * num_beams, prefix_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = tf.reshape(
attention_mask, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
input_ids = tf.reshape(
input_ids, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)

# If encoder-decoder, get encoder outputs
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id

assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

# get encoder and store encoder outputs
encoder = self.get_encoder()

encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

# create empty decoder_input_ids
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
cur_len = 1

else:
encoder_outputs = None
cur_len = shape_list(input_ids)[-1]

if num_beams > 1:
output = self._generate_beam_search(
input_ids,
input_ids=prefix_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
Expand All @@ -771,7 +786,7 @@ def generate(
)
else:
output = self._generate_no_beam_search(
input_ids,
input_ids=prefix_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
Expand Down Expand Up @@ -1007,6 +1022,7 @@ def _generate_beam_search(

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
print(model_inputs)
Contributor:
make sure to remove this at a later stage ;-)

outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
