Skip to content

Commit a483757

Browse files
authored
vulkan: use aligned loads for flash attention mask (#12853)
Rewrite the stride logic for the mask tensor in the flash attention (FA) shader so that the stride is forced to be aligned, which allows the shader to use more efficient loads.
1 parent e59ea53 commit a483757

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

+7-4
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ void main() {
201201
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
202202
uint32_t k_stride = p.nb11;
203203
uint32_t v_stride = p.nb21;
204+
// When using grouped query attention, all rows use the same mask (stride 0).
205+
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
206+
// that prevents the compiler from folding the "&" through the select
207+
// and breaking the alignment detection.
208+
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
204209
// hint to the compiler that strides are aligned for the aligned variant of the shader
205210
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
206211
{
@@ -209,6 +214,7 @@ void main() {
209214
k_stride &= ~7;
210215
v_stride &= ~7;
211216
#endif
217+
m_stride &= ~7;
212218
}
213219
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
214220
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@@ -261,10 +267,7 @@ void main() {
261267
if (p.mask != 0) {
262268
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
263269
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
264-
// When using grouped query attention, all rows use the same mask.
265-
if (p.gqa_ratio > 1) {
266-
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
267-
}
270+
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
268271

269272
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
270273

0 commit comments

Comments
 (0)