Seq2seq generation with prefix #3739
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master    #3739      +/-   ##
==========================================
- Coverage   78.27%   78.26%   -0.02%
==========================================
  Files         104      104
  Lines       17835    17843       +8
==========================================
+ Hits        13960    13964       +4
- Misses       3875     3879       +4
-            past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models
+            past = (
+                (encoder_outputs, None) if encoder_outputs is not None else None
+            )  # defined for encoder-decoder models, None for decoder-only models
I like that idea here - that removes the unnecessary if/else statements in the `prepare_inputs_for_generation` functions.
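For reference, a minimal sketch of what such a hook could then look like; the argument names and returned dict keys here are assumptions for illustration, not the library's verbatim code:

```python
# Illustrative sketch only: assumes generate() always hands over `past` as a
# (encoder_outputs, decoder_cached_states) tuple, even on the first step,
# so the hook never needs a first-step special case.
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
    assert past is not None, "past has to be defined for encoder-decoder generation"
    encoder_outputs, decoder_cached_states = past  # same unpacking at every step
    return {
        "decoder_input_ids": input_ids,
        "encoder_outputs": encoder_outputs,
        "decoder_cached_states": decoder_cached_states,
        "attention_mask": attention_mask,
    }
```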
@@ -601,6 +603,14 @@ def generate(
                "Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
            )

+        if self.config.is_encoder_decoder:
+            bos_token_id = bos_token_id if bos_token_id is not None else self.config.decoder_start_token_id
I don't understand why we get rid of `decoder_start_token_id` but still use it here as `config.decoder_start_token_id`. Also I'm not sure if this works in general. Bart e.g. has a different `decoder_start_token_id` than `bos_token_id`.
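To make the Bart point concrete, here is a quick check one can run (a minimal sketch; verify the printed ids against the actual config rather than taking them from this comment):

```python
from transformers import BartConfig

config = BartConfig.from_pretrained("facebook/bart-large")
# For Bart these are expected to differ: bos_token_id marks the beginning-of-sequence
# token, while decoder_start_token_id is the token the decoder actually starts from.
print(config.bos_token_id, config.decoder_start_token_id)
```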
Also please always check the hard-coded integration tests of at least Bart, T5 and GPT2 with your changes:
RUN_SLOW=1 pytest tests/test_modeling_bart.py
RUN_SLOW=1 pytest tests/test_modeling_gpt2.py
RUN_SLOW=1 pytest tests/test_modeling_t5.py
RUN_SLOW=1 pytest tests/test_modeling_tf_gpt2.py
RUN_SLOW=1 pytest tests/test_modeling_tf_t5.py
                Language model: alias for prefix_ids, the sequence used as a prompt for the generation.
                Seq2seq model: the sequence input to the encoder model.

            prefix_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
I think `prefix_ids` should never overwrite `input_ids`. For me it would be ok to add an optional `decoder_input_ids` variable that can be used only for encoder-decoder models, when the user wants to encode the `input_ids` and then generate from something specific, the `decoder_input_ids`. Initially, I didn't see many use cases for this, but maybe there are??

Also it might be good to keep the same wording we used for the models: if there is only "one" model (GPT2), never name anything with `decoder_`. If it is an encoder-decoder model, the normal variables `input_ids`, `attention_mask` are used for the encoder, and the decoder uses the same naming and prepends a `decoder_` prefix to the variable names (see the sketch below).
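A small usage sketch of that convention; note that the `decoder_input_ids` argument to `generate()` is what this comment proposes, not an existing parameter, and the tensors are placeholders:

```python
# Decoder-only model (GPT-2): plain names only, nothing is prefixed with "decoder_".
gpt2_output = gpt2_model.generate(
    input_ids=prompt_ids,
    attention_mask=prompt_mask,
)

# Encoder-decoder model (T5/Bart): the plain names feed the encoder, and the
# decoder-side equivalents carry the "decoder_" prefix.
t5_output = t5_model.generate(
    input_ids=source_ids,             # encoder input
    attention_mask=source_mask,       # encoder attention mask
    decoder_input_ids=target_prefix,  # hypothetical: decoder-side generation prefix
)
```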
         if input_ids is not None:
-            batch_size = shape_list(input_ids)[0]  # overriden by the input batch_size
+            batch_size = input_ids.shape[0]  # overriden by the input batch_size
I think we should not use `tensor.shape` in TF. The function `shape_list` was made so that `shape` behaves correctly in both eager and non-eager mode: #3063
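For context, the idea behind such a helper is roughly the following (a simplified sketch of the pattern, not necessarily the exact implementation; see #3063):

```python
import tensorflow as tf

def shape_list(x):
    """Return the shape as a list, falling back to dynamic dimensions wherever the
    static shape is unknown, so the same code works in eager and graph mode."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]

# In graph mode, input_ids.shape[0] can be None, while shape_list(input_ids)[0]
# still returns a usable (possibly dynamic) batch size.
```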
@@ -656,14 +665,16 @@ def generate(
            bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
        ), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

-        if input_ids is None:
+        if prefix_ids is None:
I don't like `prefix_ids` overwriting `input_ids` as mentioned before. Why should we do that?
@@ -1007,6 +1010,7 @@ def _generate_beam_search(

        while cur_len < max_length:
            model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
+            print(model_inputs)
make sure to remove this at a later stage ;-)
-                encoder_outputs, decoder_cached_states = past, None
-            else:
-                encoder_outputs, decoder_cached_states = past
+            encoder_outputs, decoder_cached_states = past
Like this change!
In general, I think we have to be careful with a distinction between the different special token ids. I don't really see the need to overwrite `input_ids`. There are quite a few hidden hacks in ...
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
This PR introduces two small changes in the way model.generate() works.
Previously, the function took in an input_ids argument which had different behaviors in the seq2seq and language model settings: in language modeling, input_ids could be used to provide a prefix for the generation, while in seq2seq, input_ids represented the encoder input and the generation prefix was automatically initialized to a batch with one time step filled with the [BOS] token.
Conceptually, this feels a little awkward, as a language model and the decoder of a seq2seq model should really behave similarly (the latter just has added conditioning). And more importantly, there was no way to provide both the encoder input_ids and a generation prefix in the seq2seq model.
I've added a prefix_ids argument to fix that. The model will still default to using input_ids as a prefix in the language model setting so as not to break current use cases, but otherwise the model works with prefix_ids and initializes it similarly for the LM and seq2seq settings.
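As a concrete (hypothetical) usage sketch of what this enables, with the prefix_ids argument taken from this PR rather than the released API, and T5 used as the example model:

```python
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# Seq2seq: input_ids feed the encoder, prefix_ids seed the decoder generation.
input_ids = tokenizer.encode("translate English to German: The house is small.", return_tensors="tf")
prefix_ids = tokenizer.encode("Das Haus", return_tensors="tf")  # force the output to start here
outputs = model.generate(input_ids=input_ids, prefix_ids=prefix_ids, max_length=40)

# Language model: passing only input_ids keeps the old behavior, where input_ids
# doubles as the generation prefix.
```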
The second smaller change is the initialization of the past variable in generate_beam_search and generate_no_beam_search: it is now initialized to the form it will have in later generation steps, so we can dispense with the first-step tests in the prepare_inputs_for_generation functions in modeling_t5.py and modeling_bart.py.
(Next time I'll do two separate PR's as suggested by @sshleifer :) )