Refactoring and bug fixing beam search generate #3135


Merged

Conversation


@patrickvonplaten patrickvonplaten commented Mar 5, 2020

This PR cleans up the beam_search decoding part of language generation. It simplifies the code and fixes a small bug for do_sample=True (see comments in the code).
It was also tested against all of the language generation slow tests.

Future PR

# stop when there is a </s> in each sentence, or if we exceed the maximum length
if unfinished_sents.max() == 0:
break

cur_len = cur_len + 1
@patrickvonplaten patrickvonplaten (Contributor, Author) Mar 5, 2020

I think it's always good to have cur_len += 1 as the last statement in the loop!
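The stopping logic above can be sketched in isolation. This is a minimal sketch, not the actual transformers code: `fake_next_token` is a made-up stand-in for the model's real sampling step, and the mask arithmetic mirrors the `unfinished_sents` bookkeeping in the quoted loop.

```python
import torch

# Minimal sketch of tracking unfinished sequences with a 0/1 mask and
# stopping once every sequence has produced EOS (or max_length is hit).
def generate_until_eos(batch_size, eos_token_id, max_length, fake_next_token):
    unfinished_sents = torch.ones(batch_size, dtype=torch.long)
    tokens = []
    cur_len = 0
    while cur_len < max_length:
        next_token = fake_next_token(cur_len)  # (batch_size,)
        # already-finished sentences keep emitting EOS
        next_token = next_token * unfinished_sents + eos_token_id * (1 - unfinished_sents)
        tokens.append(next_token)
        # mark sentences that just produced EOS as finished
        unfinished_sents = unfinished_sents * (next_token != eos_token_id).long()
        # stop when there is an EOS in each sentence, or if we exceed max length
        if unfinished_sents.max() == 0:
            break
        cur_len = cur_len + 1  # kept as the last statement in the loop
    return torch.stack(tokens, dim=1)
```

Keeping `cur_len = cur_len + 1` last means the length counter is only advanced for iterations that actually continue, which is the point of the comment above.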

@@ -996,6 +997,9 @@ def _generate_beam_search(
# Compute next scores
next_scores = torch.gather(_scores, -1, next_tokens) # (batch_size, num_beams * 2)

# sort the sampled vector to make sure that the first num_beams samples are the best
@patrickvonplaten patrickvonplaten (Contributor, Author) Mar 5, 2020

I added this small sort function for the following reason: we sample 2 * num_beams candidates per batch_idx and always take the first num_beams of them (the first three with num_beams=3). The first samples we take should therefore also correspond to the best three out of the six sampled.

@thomwolf @LysandreJik
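The sampling-plus-sort step being discussed can be sketched as follows. This is an illustrative sketch, not the PR's exact code: the scores are random dummies, and the shapes follow the quoted diff hunk (`next_tokens` and `next_scores` of shape `(batch_size, num_beams * 2)`).

```python
import torch

# Sketch: draw 2 * num_beams candidates per batch item, then sort them by
# score so that the first num_beams entries are the best of the sampled set.
batch_size, num_beams, vocab_size = 2, 3, 11
torch.manual_seed(0)
_scores = torch.log_softmax(torch.randn(batch_size, num_beams * vocab_size), dim=-1)

probs = _scores.exp()
next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)  # (batch_size, num_beams * 2)
next_scores = torch.gather(_scores, -1, next_tokens)               # (batch_size, num_beams * 2)

# sort the sampled vector to make sure that the first num_beams samples are the best
next_scores, next_scores_indices = torch.sort(next_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, next_scores_indices)
```

After the sort, taking `next_tokens[:, :num_beams]` is guaranteed to pick the highest-scoring of the sampled candidates, which is the behavior the comment argues for.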

Member

I'm fine with that, it goes with the beam search philosophy indeed (a bit pushed to the extreme when you're sampling anyway)

  if eos_token_ids is not None and token_id.item() in eos_token_ids:
      generated_hyps[batch_idx].add(
-         input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item(),
+         input_ids[effective_beam_id].clone(), score.item(),
Contributor Author

The cur_len parameter is useless here.
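The index simplification in that hunk can be shown in a tiny sketch. This is illustrative only: `input_ids` is a dummy tensor standing in for the real decoded sequences, and `effective_beam_id` is computed once instead of re-deriving the row index at every use site.

```python
import torch

# Dummy decoded sequences: one row per (batch item, beam) pair.
batch_size, num_beams, seq_len = 2, 3, 5
input_ids = torch.arange(batch_size * num_beams * seq_len).view(batch_size * num_beams, seq_len)

batch_idx, beam_id = 1, 2
# Compute the flat row index once and reuse it everywhere.
effective_beam_id = batch_idx * num_beams + beam_id

# Both indexing forms select the same row of input_ids.
assert torch.equal(input_ids[effective_beam_id],
                   input_ids[batch_idx * num_beams + beam_id])
```

Dropping the `:cur_len` slice stores the full row for the hypothesis, which is why the comment calls the parameter useless at this call site.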

# stop when we are done with each sentence
if all(done):
break

# update current length
cur_len = cur_len + 1
Contributor Author

I think it's always good to have cur_len += 1 as the last statement in the loop!

generated_hyps[batch_idx].add(
input_ids[batch_idx * num_beams + beam_id, :cur_len].clone(), score.item()
)
# test that beam scores match previously calculated scores if not eos and batch_idx not done
Contributor Author

This refactoring was taken from bart's generate() function (@sshleifer). I think it's much cleaner because it shows clearly that, at the end, the current best 3 (first 3) open hypotheses are added to the generated hypotheses. Also, we now take the final scores from the variable beam_scores instead of next_scores (as before), which is the "correct" variable to take the scores from, since it holds the most recently updated accumulated scores.
An assert statement verifying that the beam_scores are calculated correctly was also added, making sure this logic will not be broken by future changes. @thomwolf @LysandreJik
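The finalization step described here can be sketched as follows. This is modeled on, not copied from, the PR: `input_ids`, `beam_scores`, and `generated_hyps` are simplified stand-ins, and real code would use the hypothesis container's add() method rather than a plain list.

```python
import torch

# Sketch: for every batch item that is not yet done, add its num_beams
# currently open hypotheses using the accumulated beam_scores (not the
# per-step next_scores) as the final scores.
batch_size, num_beams, cur_len = 2, 3, 4
input_ids = torch.arange(batch_size * num_beams * cur_len).view(batch_size * num_beams, cur_len)
beam_scores = torch.randn(batch_size * num_beams)
done = [False, True]
generated_hyps = {i: [] for i in range(batch_size)}

for batch_idx in range(batch_size):
    if done[batch_idx]:
        continue  # finished batch items already hold their best hypotheses
    for beam_id in range(num_beams):
        effective_beam_id = batch_idx * num_beams + beam_id
        final_score = beam_scores[effective_beam_id].item()  # accumulated score
        final_tokens = input_ids[effective_beam_id]
        generated_hyps[batch_idx].append((final_score, final_tokens))
```

Using `beam_scores` here matters because it is the running accumulated log-probability of each open beam, whereas `next_scores` only reflects the last decoding step.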

@patrickvonplaten patrickvonplaten changed the title [WIP] refactoring and bug fixing beam search generate Refactoring and bug fixing beam search generate Mar 5, 2020
@patrickvonplaten
Contributor Author

Good to merge for me

@thomwolf thomwolf (Member) left a comment

Perfect!

@thomwolf thomwolf merged commit bbabbc1 into huggingface:master Mar 5, 2020
@patrickvonplaten patrickvonplaten deleted the refactor_beam_search_generate branch March 5, 2020 22:16