Seq2seq generation with prefix #3739
Changes from all commits: 77c7883, 31a0da8, b62f3fc, 29750fc, 0de2191
@@ -458,6 +458,7 @@ def _do_output_past(self, outputs):
def generate(
self,
input_ids=None,
prefix_ids=None,
max_length=None,
min_length=None,
do_sample=None,
@@ -488,15 +489,21 @@ def generate(

Parameters:

input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
input_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32 of shape `(batch_size, sequence_length)`
Language model: alias for prefix_ids, the sequence used as a prompt for the generation.
Seq2seq model: the sequence input to the encoder model.

prefix_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32 of shape `(batch_size, sequence_length)`
Review comment: I think
The sequence used as a prompt for the generation. If `None` the method initializes
it as an empty `torch.LongTensor` of shape `(1,)`.
Language model: If `None`, uses input_ids.

max_length: (`optional`) int
The max length of the sequence to be generated. Between 1 and infinity. Default to 20.

min_length: (`optional`) int
The min length of the sequence to be generated. Between 0 and infinity. Default to 0.

do_sample: (`optional`) bool
If set to `False` greedy decoding is used. Otherwise sampling is used. Defaults to `False` as defined in `configuration_utils.PretrainedConfig`.
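To make the `input_ids` / `prefix_ids` split above concrete, here is a short usage sketch. The `prefix_ids` keyword is the argument proposed by this PR (it does not exist in the released `generate()` API), and the model choice, prompt and generation settings are arbitrary:

```python
import tensorflow as tf
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

# (1, sequence_length) prompt for a decoder-only language model
prompt = tokenizer.encode("The weather today is", return_tensors="tf")

# Under this PR, `prefix_ids` seeds the decoder; for a plain language model
# `input_ids` is just an alias for it, so either keyword would work here.
generated = model.generate(prefix_ids=prompt, max_length=20, do_sample=False)

# For a seq2seq model (e.g. TFT5ForConditionalGeneration), `input_ids` would
# instead be the encoder input and `prefix_ids` an optional decoder prompt.
print(tokenizer.decode(generated[0].numpy().tolist()))
```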
@@ -601,6 +608,17 @@ def generate(
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
)

if self.config.is_encoder_decoder:
bos_token_id = decoder_start_token_id if decoder_start_token_id is not None else bos_token_id
bos_token_id = bos_token_id if bos_token_id is not None else self.config.decoder_start_token_id
Review comment: I don't understand why we get rid of
Also please always check the hard-coded integration tests of at least Bart, T5 and GPT2 with your changes:
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id

assert (
bos_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
else:
prefix_ids = prefix_ids if prefix_ids is not None else input_ids

max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
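The review note above asks that the hard-coded generation tests for Bart, T5 and GPT2 still pass. A minimal sketch of that style of test follows; the class name, prompt and length are illustrative only, and the repository's real tests compare against fixed reference token ids rather than just checking determinism:

```python
import unittest

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel


class TFGenerationIntegrationSketch(unittest.TestCase):
    def test_gpt2_greedy_generation_is_deterministic(self):
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = TFGPT2LMHeadModel.from_pretrained("gpt2")

        input_ids = tokenizer.encode("Today is a nice day and", return_tensors="tf")
        output_a = model.generate(input_ids, max_length=12, do_sample=False)
        output_b = model.generate(input_ids, max_length=12, do_sample=False)

        # The real integration tests additionally pin the exact expected token
        # ids so refactors cannot silently change greedy outputs; those
        # checkpoint-specific reference ids are omitted from this sketch.
        self.assertLessEqual(output_a.shape[1], 12)
        self.assertListEqual(output_a.numpy().tolist(), output_b.numpy().tolist())


if __name__ == "__main__":
    unittest.main()
```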
@@ -621,12 +639,11 @@ def generate(
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
)

if input_ids is not None:
batch_size = shape_list(input_ids)[0]  # overriden by the input batch_size
batch_size = input_ids.shape[0]  # overriden by the input batch_size
Review comment: I think we should not use
elif prefix_ids is not None:
batch_size = prefix_ids.shape[0]
else:
batch_size = 1
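The truncated review comment above concerns the switch from `shape_list(input_ids)[0]` to `input_ids.shape[0]`. A small sketch of why a `shape_list`-style helper matters in TensorFlow: inside a traced `tf.function` the static batch dimension can be `None`, while a helper that falls back to the dynamic shape still returns something usable. The helper below approximates `transformers`' `shape_list` rather than quoting it verbatim:

```python
import tensorflow as tf


def shape_list_sketch(x):
    """Return the shape as a list, using static dims where known and
    falling back to dynamic tf.shape entries otherwise."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]


@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])
def batch_size_from(ids):
    # Inside the traced function ids.shape[0] is None (unknown batch size),
    # whereas the helper falls back to a dynamic value that can be used.
    return shape_list_sketch(ids)[0]


print(batch_size_from(tf.ones((3, 5), dtype=tf.int32)))  # tf.Tensor(3, ...)
```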
@@ -639,7 +656,7 @@ def generate(
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
assert 0 <= top_p <= 1, "`top_p` should be between 0 and 1."
assert repetition_penalty >= 1.0, "`repetition_penalty` should be >= 1."
assert input_ids is not None or (
assert prefix_ids is not None or (
isinstance(bos_token_id, int) and bos_token_id >= 0
), "If input_ids is not defined, `bos_token_id` should be a positive integer."
assert pad_token_id is None or (
@@ -656,14 +673,16 @@ def generate(
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

if input_ids is None:
if prefix_ids is None:
Review comment: I don't like
assert isinstance(bos_token_id, int) and bos_token_id >= 0, (
"you should either supply a context to complete as `input_ids` input "
"or a `bos_token_id` (integer >= 0) as a first token to start the generation."
)
input_ids = tf.fill((batch_size, 1), bos_token_id)
prefix_ids = tf.fill((batch_size, 1), bos_token_id)
if input_ids is None:
input_ids = tf.fill((batch_size, 1), bos_token_id)
else:
assert len(shape_list(input_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."
assert len(shape_list(prefix_ids)) == 2, "Input prompt should be of shape (batch_size, sequence length)."

# not allow to duplicate outputs when greedy decoding
if do_sample is False:
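When the caller supplies neither `input_ids` nor `prefix_ids`, the branch above starts every sequence from a single `bos_token_id`. A standalone illustration of that `tf.fill` initialization (the batch size and token id are arbitrary values for the sketch):

```python
import tensorflow as tf

batch_size, bos_token_id = 2, 0  # arbitrary illustration values

# One BOS token per sequence, shape (batch_size, 1), which satisfies the
# later "(batch_size, sequence length)" shape assertion.
prefix_ids = tf.fill((batch_size, 1), bos_token_id)
print(prefix_ids.shape)    # (2, 1)
print(prefix_ids.numpy())  # each row holds a single BOS id
```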
@@ -693,7 +712,7 @@ def generate(
pad_token_id = eos_token_id

# current position and vocab size
cur_len = shape_list(input_ids)[1]
cur_len = shape_list(prefix_ids)[1]
vocab_size = self.config.vocab_size

# set effective batch size and effective batch multiplier according to do_sample
@@ -707,45 +726,41 @@ def generate(
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
input_ids_len = shape_list(input_ids)[-1]
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
prefix_ids_len = shape_list(prefix_ids)[-1]
prefix_ids = tf.broadcast_to(
tf.expand_dims(prefix_ids, 1), (batch_size, effective_batch_mult * num_beams, prefix_ids_len)
)
attention_mask = tf.broadcast_to(
tf.expand_dims(attention_mask, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
input_ids = tf.reshape(
input_ids, (effective_batch_size * num_beams, input_ids_len)
prefix_ids = tf.reshape(
prefix_ids, (effective_batch_size * num_beams, prefix_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask = tf.reshape(
attention_mask, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)
if self.config.is_encoder_decoder:
input_ids = tf.broadcast_to(
tf.expand_dims(input_ids, 1), (batch_size, effective_batch_mult * num_beams, input_ids_len)
)
input_ids = tf.reshape(
input_ids, (effective_batch_size * num_beams, input_ids_len)
) # shape: (batch_size * num_return_sequences * num_beams, cur_len)

# If encoder-decoder, get encoder outputs
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id

assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

# get encoder and store encoder outputs
encoder = self.get_encoder()

encoder_outputs = encoder(input_ids, attention_mask=attention_mask)

# create empty decoder_input_ids
input_ids = tf.ones((effective_batch_size * num_beams, 1), dtype=tf.int32,) * decoder_start_token_id
cur_len = 1

else:
encoder_outputs = None
cur_len = shape_list(input_ids)[-1]

if num_beams > 1:
output = self._generate_beam_search(
input_ids,
input_ids=prefix_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
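The expansion block at the top of this hunk repeats every prompt `effective_batch_mult * num_beams` times before decoding starts. A self-contained illustration of that expand_dims / broadcast_to / reshape pattern with toy sizes (the values and sizes are arbitrary):

```python
import tensorflow as tf

batch_size, seq_len = 2, 3
num_beams, effective_batch_mult = 2, 1  # i.e. num_return_sequences == 1

prefix_ids = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int32)  # (2, 3)

expanded = tf.broadcast_to(
    tf.expand_dims(prefix_ids, 1),  # (batch_size, 1, seq_len)
    (batch_size, effective_batch_mult * num_beams, seq_len),
)
# Each prompt is now repeated once per beam, then flattened back to 2-D:
expanded = tf.reshape(expanded, (batch_size * effective_batch_mult * num_beams, seq_len))
print(expanded.numpy())
# [[1 2 3]
#  [1 2 3]
#  [4 5 6]
#  [4 5 6]]
```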
@@ -771,7 +786,7 @@ def generate(
)
else:
output = self._generate_no_beam_search(
input_ids,
input_ids=prefix_ids,
cur_len=cur_len,
max_length=max_length,
min_length=min_length,
@@ -1007,6 +1022,7 @@ def _generate_beam_search(

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
print(model_inputs)
Review comment: make sure to remove this at a later stage ;-)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
Review comment: Like this change!