Seq2seq generation with prefix #3739


Conversation

@yjernite yjernite commented Apr 10, 2020

This PR introduces two small changes in the way model.generate() works.

Previously, the function took an input_ids argument whose behavior differed between the seq2seq and language-model settings: in language modeling, input_ids could be used to provide a prefix for the generation, while in seq2seq, input_ids represented the encoder input and the generation prefix was automatically initialized to a batch with a single time step filled with the [BOS] token.

Conceptually, this feels a little awkward, as a language model and the decoder of a seq2seq model should really behave similarly (the latter just has added conditioning). And more importantly, there was no way to provide both the encoder input_ids and a generation prefix in the seq2seq model.

I've added a prefix_ids argument to fix that. The model will still default to using input_ids as a prefix in the language model setting so as not to break current use cases, but otherwise the model works with prefix_ids and initializes it similarly for the LM and seq2seq settings.
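As a rough usage sketch of what this enables on this branch (checkpoint name, prompt text, and the exact prefix handling are illustrative, not taken from the PR):

from transformers import BartForConditionalGeneration, BartTokenizer

# Hypothetical example: condition the encoder on a document and seed the decoder with a prefix.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")  # example checkpoint
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

input_ids = tokenizer.encode("Some long article to summarize.", return_tensors="pt")
prefix_ids = tokenizer.encode("The article argues that", add_special_tokens=False, return_tensors="pt")

# input_ids feeds the encoder; prefix_ids seeds the decoder instead of a single [BOS] step.
summary_ids = model.generate(input_ids, prefix_ids=prefix_ids, max_length=60)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))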

The second, smaller change is the initialization of the past variable in generate_beam_search and generate_no_beam_search: it is now initialized to the form it will have in later generation steps, so we can dispense with the first-step checks in the prepare_inputs_for_generation functions in modeling_t5.py and modeling_bart.py.

(Next time I'll do two separate PRs as suggested by @sshleifer :) )

@yjernite yjernite changed the title from "continue from encoder_decoder_generation branch" to "Seq2seq generation with prefix" on Apr 10, 2020
@yjernite yjernite requested a review from thomwolf April 10, 2020 21:16
codecov-io commented Apr 10, 2020

Codecov Report

Merging #3739 into master will decrease coverage by 0.01%.
The diff coverage is 95.74%.


@@            Coverage Diff             @@
##           master    #3739      +/-   ##
==========================================
- Coverage   78.27%   78.26%   -0.02%     
==========================================
  Files         104      104              
  Lines       17835    17843       +8     
==========================================
+ Hits        13960    13964       +4     
- Misses       3875     3879       +4     
Impacted Files Coverage Δ
src/transformers/modeling_utils.py 91.81% <95.45%> (-0.09%) ⬇️
src/transformers/modeling_tf_utils.py 92.89% <95.65%> (-0.08%) ⬇️
src/transformers/modeling_bart.py 96.48% <100.00%> (-0.02%) ⬇️
src/transformers/modeling_t5.py 82.77% <100.00%> (-0.44%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

-past = encoder_outputs  # defined for encoder-decoder models, None for decoder-only models
+past = (
+    (encoder_outputs, None) if encoder_outputs is not None else None
+)  # defined for encoder-decoder models, None for decoder-only models
Contributor
I like that idea here; it removes the unnecessary if/else statements in the prepare_inputs_for_generation functions.

@@ -601,6 +603,14 @@ def generate(
"Please use another model class (e.g. `TFOpenAIGPTLMHeadModel`, `TFXLNetLMHeadModel`, `TFGPT2LMHeadModel`, `TFCTRLLMHeadModel`, `TFT5ForConditionalGeneration`, `TFTransfoXLLMHeadModel`)"
)

if self.config.is_encoder_decoder:
bos_token_id = bos_token_id if bos_token_id is not None else self.config.decoder_start_token_id
Contributor
@patrickvonplaten patrickvonplaten Apr 11, 2020
I don't understand why we get rid of decoder_start_token_id but still use it here as config.decoder_start_token_id.
Also, I'm not sure this works in general: Bart, for example, has a decoder_start_token_id that differs from its bos_token_id.

Also please always check the hard-coded integration tests of at least Bart, T5 and GPT2 with your changes:
RUN_SLOW=1 pytest tests/test_modeling_bart.py
RUN_SLOW=1 pytest tests/test_modeling_gpt2.py
RUN_SLOW=1 pytest tests/test_modeling_t5.py
RUN_SLOW=1 pytest tests/test_modeling_tf_gpt2.py
RUN_SLOW=1 pytest tests/test_modeling_tf_t5.py
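For reference, a quick way to see the mismatch mentioned above (checkpoint name is just an example; the printed values are from memory and worth verifying locally):

from transformers import BartConfig

# Bart summarization checkpoints use a decoder start token that is not the BOS token.
config = BartConfig.from_pretrained("facebook/bart-large-cnn")
print(config.bos_token_id)            # typically 0
print(config.decoder_start_token_id)  # typically 2 (EOS), i.e. not the BOS token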

Language model: alias for prefix_ids, the sequence used as a prompt for the generation.
Seq2seq model: the sequence input to the encoder model.

prefix_ids: (`optional`) `tf.Tensor` of `dtype=tf.int32` of shape `(batch_size, sequence_length)`
Contributor
I think prefix_ids should never overwrite input_ids. For me it would be OK to add an optional decoder_input_ids variable that can be used only for encoder-decoder models, when the user wants to encode the input_ids and then start generation from a specific decoder_input_ids prefix. Initially I didn't see many use cases for this, but maybe there are??
Also, it might be good to keep the same wording we used for the models:
If there is only one stack (GPT2), never name anything with decoder_. If it is an encoder-decoder model, the normal variables input_ids and attention_mask are used for the encoder, and the decoder uses the same naming with decoder_ prepended to the variable names.
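A minimal sketch of the signature that naming convention would suggest (illustrative only, not the actual method on either branch):

def generate(
    self,
    input_ids=None,          # prompt for single-stack models, encoder input for encoder-decoder models
    attention_mask=None,
    decoder_input_ids=None,  # optional decoder-side prefix, encoder-decoder models only
    max_length=None,
    **model_specific_kwargs,
):
    # GPT2-style models keep their usual argument names; encoder-decoder models reuse them
    # for the encoder and prepend decoder_ for the decoder, mirroring forward().
    ...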


 if input_ids is not None:
-    batch_size = shape_list(input_ids)[0]  # overriden by the input batch_size
+    batch_size = input_ids.shape[0]  # overriden by the input batch_size
Contributor
I think we should not use tensor.shape in TF. The shape_list function was made so that shape lookups behave correctly in both eager and non-eager mode: #3063
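For context, the helper is roughly along these lines (paraphrased from memory, assuming TF 2.x; see #3063 for the actual implementation):

import tensorflow as tf

def shape_list(x):
    # Return static dimensions where they are known and dynamic tf.shape() entries otherwise,
    # so the same code works both eagerly and inside compiled graphs.
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]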

@@ -656,14 +665,16 @@ def generate(
bad_words_ids is None or isinstance(bad_words_ids, list) and isinstance(bad_words_ids[0], list)
), "`bad_words_ids` is either `None` or a list of lists of tokens that should not be generated"

-if input_ids is None:
+if prefix_ids is None:
Contributor
I don't like prefix_ids overwriting input_ids as mentioned before. Why should we do that?

@@ -1007,6 +1010,7 @@ def _generate_beam_search(

while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
print(model_inputs)
Contributor
make sure to remove this at a later stage ;-)

-    encoder_outputs, decoder_cached_states = past, None
-else:
-    encoder_outputs, decoder_cached_states = past
+encoder_outputs, decoder_cached_states = past
Contributor
Like this change!
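For context, a sketch of what this change allows downstream (an assumed shape of the hook, not the actual modeling_bart.py code): once past always has the tuple form, prepare_inputs_for_generation can unpack it unconditionally.

def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
    # past is (encoder_outputs, None) at the first step and
    # (encoder_outputs, decoder_cached_states) afterwards, so no first-step special case.
    encoder_outputs, decoder_cached_states = past
    return {
        "input_ids": None,  # the encoder has already been run
        "encoder_outputs": encoder_outputs,
        "decoder_cached_states": decoder_cached_states,
        "decoder_input_ids": decoder_input_ids,
        "attention_mask": attention_mask,
    }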

@patrickvonplaten
Contributor

In general, I think we have to be careful with the distinction between the different special token ids.
I can see why decoder_start_token_id looks weird at first glance, but in #3225 and #3140 we decided to add it to keep Bart's good performance on summarization.

I don't really see the need to overwrite input_ids with prefix_ids - do we have to do this?
I would be OK with adding an optional decoder_input_ids that would be used for encoder-decoder models only.

There are quite a few hidden hacks in generate() (like the force_token_id fn) that look quite strange. If we replace / delete them, we should always check that the hard-coded integration tests don't fail (running the tests with RUN_SLOW=1 as mentioned above).

@yjernite yjernite force-pushed the generate_separate_decoder_input_ids branch from 3e82158 to 0de2191 on April 13, 2020 13:35
stale bot commented Jul 8, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Jul 8, 2020
@yjernite yjernite closed this Jul 8, 2020