Skip to content

Commit 4513476

Browse files
authored
[webgpu] Fix multihead-attention for ort-web-tests (#24485)
Compute 'total_sequence_length' the same way as JSEP.
1 parent 64b5642 commit 4513476

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,8 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T
446446
WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* seqlen_k) {
447447
const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)});
448448
const int past_sequence_length = output_count > 1 ? parameters.past_sequence_length_ : 0;
449-
const int total_sequence_length = parameters.total_sequence_length_;
449+
const int total_sequence_length =
450+
parameters.is_gqa_ ? parameters.total_sequence_length_ : past_sequence_length + parameters.kv_sequence_length_;
450451

451452
const TensorShapeVector probs_dims({parameters.batch_size_, parameters.num_heads_,
452453
parameters.sequence_length_, total_sequence_length});

0 commit comments

Comments
 (0)