@@ -201,6 +201,11 @@ void main() {
201
201
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
202
202
uint32_t k_stride = p.nb11;
203
203
uint32_t v_stride = p.nb21;
204
+ // When using grouped query attention, all rows use the same mask (stride 0).
205
+ // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
206
+ // that prevents the compiler from folding the "&" through the select
207
+ // and breaking the alignment detection.
208
+ uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
204
209
// hint to the compiler that strides are aligned for the aligned variant of the shader
205
210
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
206
211
{
@@ -209,6 +214,7 @@ void main() {
209
214
k_stride &= ~7;
210
215
v_stride &= ~7;
211
216
#endif
217
+ m_stride &= ~7;
212
218
}
213
219
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
214
220
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@@ -261,10 +267,7 @@ void main() {
261
267
if (p.mask != 0) {
262
268
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
263
269
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
264
- // When using grouped query attention, all rows use the same mask.
265
- if (p.gqa_ratio > 1) {
266
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
267
- }
270
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
268
271
269
272
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
270
273
0 commit comments