Refactoring and bug fixing beam search generate #3135
Conversation
# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
    break

cur_len = cur_len + 1
think it's always good to have cur_len += 1 as the last statement in the loop!
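For illustration, a toy sketch of the loop shape being suggested (variable names borrowed from the snippet above, toy values, not the library code): all per-step work and stopping checks come first, and the counter bump is the loop's final statement, so every check inside the body sees a consistent cur_len.

```python
import torch

max_length = 5
cur_len = 1
# 1 = sentence still generating, 0 = finished (toy values)
unfinished_sents = torch.ones(2, dtype=torch.long)

while cur_len < max_length:
    # ... per-step work: sample next tokens, append them to input_ids ...
    if cur_len == 3:  # pretend both sentences hit </s> at step 3
        unfinished_sents[:] = 0

    # stop when there is a </s> in each sentence
    if unfinished_sents.max() == 0:
        break

    cur_len = cur_len + 1  # last statement in the loop
```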
@@ -996,6 +997,9 @@ def _generate_beam_search(
    # Compute next scores
    next_scores = torch.gather(_scores, -1, next_tokens)  # (batch_size, num_beams * 2)

    # sort the sampled vector to make sure that the first num_beams samples are the best
I added this small sort function for the following reason: we are sampling 2 * num_beams samples per batch_idx and always take the first num_beams of those samples. I think the first num_beams samples that we take should then also correspond to the best num_beams of the 2 * num_beams sampled (e.g. the best three out of six when num_beams = 3).
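For context, a minimal sketch of the fix under discussion (shapes and names assumed to mirror the snippet above, not a verbatim excerpt): after drawing 2 * num_beams candidates per batch item, both scores and tokens are sorted so that the first num_beams entries are also the best of the draw.

```python
import torch

batch_size, num_beams, vocab_size = 2, 3, 11
# accumulated log-probs, flattened over beams: (batch_size, num_beams * vocab_size)
_scores = torch.randn(batch_size, num_beams * vocab_size).log_softmax(-1)

# sampling path: draw 2 * num_beams candidate tokens per batch item
probs = _scores.softmax(-1)
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
next_scores = torch.gather(_scores, -1, next_tokens)               # (batch_size, num_beams * 2)

# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices)
```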
I'm fine with that, it goes with the beam search philosophy indeed (a bit pushed to the extreme when you're sampling anyway)
    if eos_token_ids is not None and token_id.item() in eos_token_ids:
        generated_hyps[batch_idx].add(
-           input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
+           input_ids[effective_beam_id].clone(), score.item(),
cur_len parameter is useless here: input_ids already has exactly cur_len columns at this point, so the :cur_len slice selects the whole row anyway.
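A toy check of that claim (assumed setup, not the library code): during generation input_ids holds one row per (batch_idx, beam_id) pair and already has exactly cur_len columns, so the slice is a no-op.

```python
import torch

batch_size, num_beams, cur_len = 2, 3, 4
# one row per (batch_idx, beam_id) pair, cur_len tokens generated so far
input_ids = torch.arange(batch_size * num_beams * cur_len).view(
    batch_size * num_beams, cur_len
)

batch_idx, beam_id = 1, 2
effective_beam_id = batch_idx * num_beams + beam_id
assert torch.equal(
    input_ids[effective_beam_id, :cur_len],  # old indexing
    input_ids[effective_beam_id],            # new indexing
)
```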
# stop when we are done with each sentence
if all(done):
    break

# update current length
cur_len = cur_len + 1
think it's always good to have cur_len += 1 as the last statement in the loop!
generated_hyps[batch_idx].add(
    input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
)
# test that beam scores match previously calculated scores if not eos and batch_idx not done
This refactoring was taken from bart's generate() function (@sshleifer). I think it's much cleaner: it shows clearly that at the end the current best 3 (first 3) open hypotheses are added to the generated hypotheses. Also, we take the final scores from the variable beam_scores here instead of next_scores (as before), which is the "correct" variable to take the scores from, since it holds the most recently updated accumulated scores. An assert statement verifying that the beam_scores are correctly calculated is also added, making sure that this logic will not be broken by future changes. @thomwolf @LysandreJik
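A rough sketch of that finalization step (variable names assumed from the diff above; per the code comment, the real code also skips the assert when an EOS token was just sampled, a guard omitted here for brevity): for every batch item that is not done, the first num_beams open hypotheses are added with their accumulated beam_scores, after asserting those scores agree with next_scores.

```python
import torch

def finalize_open_hypotheses(generated_hyps, input_ids, beam_scores,
                             next_scores, done, batch_size, num_beams):
    for batch_idx in range(batch_size):
        if done[batch_idx]:
            continue

        # beam_scores holds the running accumulated scores, so for still-open
        # hypotheses it must agree with the freshly gathered next_scores
        assert torch.all(
            next_scores[batch_idx, :num_beams]
            == beam_scores.view(batch_size, num_beams)[batch_idx]
        )

        # add the current best num_beams (i.e. first num_beams) open
        # hypotheses to the generated hypotheses
        for beam_id in range(num_beams):
            effective_beam_id = batch_idx * num_beams + beam_id
            generated_hyps[batch_idx].add(
                input_ids[effective_beam_id].clone(),
                beam_scores[effective_beam_id].item(),
            )
```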
Good to merge for me
Perfect!
This PR cleans the beam_search decoding part of language generation. It simplifies the code and fixes a small bug for do_sample=True (see comments in code).
It was also tested on all language generation slow tests.
Future PR