
Commit 485e17e

kunal-vaishnavi authored and rachguo committed
Whisper Timestamps and Temperature (#19509)
This PR updates exporting and running the Whisper model with beam search by adding the following.

- Adds temperature as a graph input to the exported model
- Fixes the token ids by adding them as attributes to `WhisperBeamSearch` (a sketch of this and of the temperature input follows below)
- Fixes the timestamps test cases so they pass now
- Fixes a bug with invoking `torch.onnx.export`
- Cleans up the Whisper scripts and groups the arguments in `convert_to_onnx.py`
- Adds a `requirements.txt` file to specify package dependencies
- Adds `whisper-large-v3` to the list of pretrained models
- Fixes a bug with missing cross-attention KV cache inputs in the decoder subgraph

Motivation and context:

- This is a follow-up to [this PR](#19188).
- The incorrect token ids in the timestamps processor were first noticed during [this PR review](#17500 (comment)). When they were originally added in [this PR](#15853), the offsets were constant across the Whisper model sizes. When comparing the new `whisper-large-v3` variant, the English-only variants (e.g. `whisper-tiny.en`), and the original variants (e.g. `whisper-tiny`), both the values and the offsets differ. It is therefore easier to set the token ids as attributes of `WhisperBeamSearch` at export time to ensure the right values are used in the timestamps processor.
- The Hugging Face API for returning timestamps and the expected outputs from the PyTorch model have both changed.
- The fix for `torch.onnx.export` is a follow-up to [this PR review](#17179 (comment)).
- The argument grouping is a follow-up to [this PR review](#17500 (comment)).
- Specific package versions are needed to run the Whisper scripts, and the `requirements.txt` file ensures that those versions are installed.
- The `whisper-large-v3` variant has been released and should be in the list of official pretrained models.
- After the changes from [this PR](#17316), the exported model would not load in an ORT inference session because the cross-attention KV cache inputs were missing from the decoder subgraph.
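As an illustration of the first two items, here is a minimal sketch, assuming `onnx.helper` is used to build the beam-search graph; the input names, the input order, and the `token_ids` mapping are illustrative, and this is not the actual `convert_to_onnx.py` code.

```python
# Minimal sketch (not the actual convert_to_onnx.py code): variant-specific token ids
# become node attributes and temperature becomes a regular graph input.
from onnx import TensorProto, helper

def sketch_beam_search_pieces(token_ids: dict):
    # temperature is now a graph input, so it can change per run without re-exporting.
    temperature_input = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])

    node = helper.make_node(
        "WhisperBeamSearch",
        # Input names/order are illustrative; empty strings stand for omitted optional inputs.
        inputs=[
            "input_features", "max_length", "min_length", "num_beams",
            "num_return_sequences", "length_penalty", "repetition_penalty",
            "", "", "", "", "", "", "", "temperature",
        ],
        outputs=["sequences"],
        domain="com.microsoft",
        # Token ids are attached as attributes so the timestamps logits processor
        # always sees the right values for the exported checkpoint.
        translate_token_id=token_ids["translate_token_id"],
        transcribe_token_id=token_ids["transcribe_token_id"],
        start_of_lm_token_id=token_ids["start_of_lm_token_id"],
        no_speech_token_id=token_ids["no_speech_token_id"],
        no_timestamps_token_id=token_ids["no_timestamps_token_id"],
        beginning_timestamp_token_id=token_ids["beginning_timestamp_token_id"],
        # Required attributes (encoder/decoder subgraphs, eos/pad/decoder_start ids,
        # model_type, beam config, etc.) are omitted from this sketch.
    )
    return temperature_input, node
```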
1 parent ad86d13 · commit 485e17e

21 files changed, +578 -370 lines

docs/ContribOperators.md (+21, -11)
@@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : M</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : I</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : I</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr
 #### Attributes

 <dl>
+<dt><tt>beginning_timestamp_token_id</tt> : int</dt>
+<dd>The id of the first timestamp</dd>
 <dt><tt>decoder</tt> : graph (required)</dt>
 <dd>Decoder subgraph to execute in a loop.</dd>
 <dt><tt>decoder_output_cross_qk</tt> : int</dt>
 <dd>If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.</dd>
 <dt><tt>decoder_start_token_id</tt> : int</dt>
-<dd>The id of the token that indicates decoding starts.</dd>
+<dd>The id of the token that indicates decoding starts (i.e. the start of transcription token id)</dd>
 <dt><tt>early_stopping</tt> : int</dt>
 <dd>early stop or not</dd>
 <dt><tt>encoder</tt> : graph</dt>
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Must be 2 for whisper</dd>
 <dt><tt>no_repeat_ngram_size</tt> : int</dt>
 <dd>no repeat ngrams size</dd>
-<dt><tt>no_speech_token</tt> : int</dt>
+<dt><tt>no_speech_token_id</tt> : int</dt>
 <dd>The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.</dd>
+<dt><tt>no_timestamps_token_id</tt> : int</dt>
+<dd>The id of the token that indicates no timestamps</dd>
 <dt><tt>pad_token_id</tt> : int (required)</dt>
 <dd>The id of the padding token</dd>
+<dt><tt>start_of_lm_token_id</tt> : int</dt>
+<dd>The id of the token that indicates LM starts</dd>
+<dt><tt>transcribe_token_id</tt> : int</dt>
+<dd>The id of the transcribe task</dd>
+<dt><tt>translate_token_id</tt> : int</dt>
+<dd>The id of the translate task</dd>
 <dt><tt>vocab_size</tt> : int</dt>
 <dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
 </dl>
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>num_return_sequences</tt> : I</dt>
 <dd>The number of returned sequences in the batch. Shape is (1)</dd>
 <dt><tt>length_penalty</tt> (optional) : T</dt>
-<dd>Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)</dd>
+<dd>Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)</dd>
 <dt><tt>repetition_penalty</tt> (optional) : T</dt>
 <dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
 <dt><tt>vocab_mask</tt> (optional) : M</dt>
-<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
+<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)</dd>
 <dt><tt>prefix_vocab_mask</tt> (optional) : M</dt>
 <dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
 <dt><tt>attention_mask</tt> (optional) : I</dt>
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>logits_processor</tt> (optional) : I</dt>
 <dd>Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)</dd>
 <dt><tt>cross_qk_layer_head</tt> (optional) : I</dt>
-<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
+<dd>Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]</dd>
 <dt><tt>extra_decoding_ids</tt> (optional) : I</dt>
 <dd>Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.</dd>
 </dl>
@@ -5810,11 +5820,11 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>sequences_scores</tt> (optional) : T</dt>
 <dd>Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)</dd>
 <dt><tt>scores</tt> (optional) : T</dt>
-<dd>Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
+<dd>Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
 <dt><tt>cross_qk</tt> (optional) : V</dt>
-<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
+<dd>Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]</dd>
 <dt><tt>non_speech_probs</tt> (optional) : T</dt>
-<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]</dd>
+<dd>For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]</dd>
 </dl>

 #### Type Constraints
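The token ids behind the new attributes above differ across Whisper variants, which is why they are passed as attributes rather than derived from fixed offsets. Here is a hedged sketch of how the values could be looked up from the Hugging Face tokenizer at export time; the special-token spellings are assumptions and can vary between tokenizer releases.

```python
# Sketch: look up variant-specific special-token ids instead of assuming fixed offsets.
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

def special_id(token: str) -> int:
    token_id = tokenizer.convert_tokens_to_ids(token)
    # convert_tokens_to_ids falls back to the unk id for unknown tokens;
    # map that to -1, which matches the operator's "not set" default.
    return -1 if token_id == tokenizer.unk_token_id else token_id

token_ids = {
    "translate_token_id": special_id("<|translate|>"),
    "transcribe_token_id": special_id("<|transcribe|>"),
    "start_of_lm_token_id": special_id("<|startoflm|>"),
    # Some releases spell this token "<|nocaptions|>" instead of "<|nospeech|>".
    "no_speech_token_id": special_id("<|nospeech|>"),
    "no_timestamps_token_id": special_id("<|notimestamps|>"),
    "beginning_timestamp_token_id": special_id("<|0.00|>"),
}
print(token_ids)
```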

onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h (+2, -2)
@@ -134,8 +134,8 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
   TensorShape no_speech_probs_shape{parameters->batch_size};
   Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape);
   if (no_speech_probs && no_speech_probs->MutableData<T>()) {
-    ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size,
-                "no_speech_token id out of range, it is ", parameters->no_speech_token,
+    ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size,
+                "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id,
                 ", vocab_size is ", parameters->vocab_size);
     this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData<T>();
   }
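Since the exported model now takes `temperature` as a graph input and can expose `non_speech_probs` as an output, here is a minimal run-time sketch; the model path, input names, shapes, and dtypes are assumptions, so check `session.get_inputs()` and `session.get_outputs()` against the actual exported graph.

```python
# Sketch of running an exported Whisper beam-search model with the new temperature input.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("whisper_beamsearch.onnx", providers=["CPUExecutionProvider"])

feeds = {
    # Dummy log-mel features; most variants use 80 mel bins (whisper-large-v3 uses 128).
    "input_features": np.zeros((1, 80, 3000), dtype=np.float32),
    "max_length": np.array([448], dtype=np.int32),
    "min_length": np.array([0], dtype=np.int32),
    "num_beams": np.array([5], dtype=np.int32),
    "num_return_sequences": np.array([1], dtype=np.int32),
    "length_penalty": np.array([1.0], dtype=np.float32),
    "repetition_penalty": np.array([1.0], dtype=np.float32),
    "temperature": np.array([1.0], dtype=np.float32),  # new graph input from this PR
}
# Only pass names the exported graph actually declares.
graph_inputs = {i.name for i in session.get_inputs()}
feeds = {name: value for name, value in feeds.items() if name in graph_inputs}

# Fetch non_speech_probs as well when the model was exported with that output.
graph_outputs = [o.name for o in session.get_outputs()]
fetch = [name for name in ("sequences", "non_speech_probs") if name in graph_outputs]
results = session.run(fetch, feeds)
print(dict(zip(fetch, (r.shape for r in results))))
```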

onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc (+7, -1)
@@ -141,7 +141,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info)
   model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeWhisper));
   ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper);

-  no_speech_token = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token", -1LL));
+  // Token ids are defined below in the order that they appear in the tokenizer
+  translate_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("translate_token_id", -1LL));
+  transcribe_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("transcribe_token_id", -1LL));
+  start_of_lm_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("start_of_lm_token_id", -1LL));
+  no_speech_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_speech_token_id", -1LL));
+  no_timestamps_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_timestamps_token_id", -1LL));
+  beginning_timestamp_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("beginning_timestamp_token_id", -1LL));
   cross_qk_layer_head_input_id = 12;
   extra_decoding_ids_input_id = 13;
   cross_qk_output_id = 3;

onnxruntime/contrib_ops/cpu/transformers/generation_shared.h (+8, -1)
@@ -180,7 +180,14 @@ struct IGenerationParameters {
   // Parameters for whisper model
   bool decoder_output_cross_qk = false;
   gsl::span<const int32_t> extra_decoding_ids;
-  int32_t no_speech_token = -1;
+
+  // Token ids are defined below in the order that they appear in the tokenizer
+  int32_t translate_token_id = -1;
+  int32_t transcribe_token_id = -1;
+  int32_t start_of_lm_token_id = -1;
+  int32_t no_speech_token_id = -1;
+  int32_t no_timestamps_token_id = -1;
+  int32_t beginning_timestamp_token_id = -1;
   void* no_speech_probs = nullptr;

   int cross_qk_layer_head_input_id = -1;
