File tree 1 file changed +10
-0
lines changed
onnxruntime/core/graph/contrib_ops
1 file changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -259,6 +259,16 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext&
259
259
*output_shape.add_dim () = query_dims[1 ];
260
260
*output_shape.add_dim () = query_dims[2 ];
261
261
updateOutputShape (ctx, 0 , output_shape);
262
+ } else {
263
+ ONNX_NAMESPACE::TensorShapeProto output_shape;
264
+ int64_t num_heads = getAttribute (ctx, " num_heads" , 0 );
265
+ int64_t kv_num_heads = getAttribute (ctx, " kv_num_heads" , 0 );
266
+ int64_t hidden_size = query_dims[2 ].dim_value ();
267
+ int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads);
268
+ *output_shape.add_dim () = query_dims[0 ];
269
+ *output_shape.add_dim () = query_dims[1 ];
270
+ output_shape.add_dim ()->set_dim_value (head_size * num_heads);
271
+ updateOutputShape (ctx, 0 , output_shape);
262
272
}
263
273
}
264
274
You can’t perform that action at this time.
0 commit comments