Skip to content

Commit d2d9b5b

Browse files
authored
fix output shape inference packed gqa (#19374)
### Description fix output shape inference packed gqa
1 parent d120104 commit d2d9b5b

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

onnxruntime/core/graph/contrib_ops/bert_defs.cc

+10
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,16 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
259259
*output_shape.add_dim() = query_dims[1];
260260
*output_shape.add_dim() = query_dims[2];
261261
updateOutputShape(ctx, 0, output_shape);
262+
} else {
263+
ONNX_NAMESPACE::TensorShapeProto output_shape;
264+
int64_t num_heads = getAttribute(ctx, "num_heads", 0);
265+
int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);
266+
int64_t hidden_size = query_dims[2].dim_value();
267+
int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads);
268+
*output_shape.add_dim() = query_dims[0];
269+
*output_shape.add_dim() = query_dims[1];
270+
output_shape.add_dim()->set_dim_value(head_size * num_heads);
271+
updateOutputShape(ctx, 0, output_shape);
262272
}
263273
}
264274

0 commit comments

Comments
 (0)