Commit e1823e5

josharian authored and iThalay committed
whisper : improve beam search candidate diversity (ggml-org#1947)
As of ggml-org#1486, whisper.cpp uses a unified KV cache with KQ masking. As a result, depending on their location in the batch, identical sequences in a batch can have slightly different outputs due to floating point rounding errors during reduction. See the discussion in ggml-org#1941 for more details.

The beam search code used "has identical sum of log probabilities" as a shorthand for "is an identical token sequence". However, per above, identical tokens do not necessarily result in identical probabilities. Instead, explicitly compare the sequences. This is linear in cost when they are identical, but the lengths are always small and the comparisons are cheap.

This increases diversity during beam search. It improves output quality for some short samples I've been working with, at no detectable performance cost. I haven't checked against larger corpora.

Fixes ggml-org#1941
1 parent 3238a84 commit e1823e5

File tree

1 file changed: +14 -1 lines changed


whisper.cpp (+14 -1)
@@ -4759,6 +4759,19 @@ static void whisper_process_logits(
 #endif
 }
 
+static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) {
+    if (a.tokens.size() != b.tokens.size()) {
+        return false;
+    }
+    // sequences are more likely to diverge at the end
+    for (int i = a.tokens.size() - 1; i >= 0; i--) {
+        if (a.tokens[i].id != b.tokens[i].id) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static whisper_token_data whisper_sample_token(
             whisper_context & ctx,
       const whisper_decoder & decoder,
@@ -5378,7 +5391,7 @@ int whisper_full_with_state(
 
                         auto & cur = beam_candidates[cur_c++];
 
-                        while (beam_candidates.size() > cur_c && beam_candidates[cur_c].sequence.sum_logprobs_all == cur.sequence.sum_logprobs_all && i > 0) {
+                        while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) {
                             ++cur_c;
                         }