@@ -144,6 +144,17 @@ def compile_friendly_flex_attention(
     )
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
 def flex_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
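For reference (not part of the commit): a minimal sketch of the layout change the new repeat_kv helper performs, using made-up sizes of 2 KV heads expanded to match 8 query heads.

import torch

key = torch.randn(1, 2, 16, 64)  # (batch, num_key_value_heads=2, seqlen, head_dim)
# Each KV head becomes 4 consecutive copies, yielding 8 heads that line up with 8 query heads.
expanded = key[:, :, None, :, :].expand(1, 2, 4, 16, 64).reshape(1, 8, 16, 64)
# Matches the docstring's claim that the helper is equivalent to torch.repeat_interleave on dim=1.
assert torch.equal(expanded, torch.repeat_interleave(key, repeats=4, dim=1))
print(expanded.shape)  # torch.Size([1, 8, 16, 64])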
@@ -174,13 +185,20 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
             score = score + head_mask[batch_idx][head_idx][0][0]
         return score
 
+    enable_gqa = True
+    num_local_query_heads = query.shape[1]
+    if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
+        key = repeat_kv(key, query.shape[1] // key.shape[1])
+        value = repeat_kv(value, query.shape[1] // value.shape[1])
+        enable_gqa = False
+
     attn_output, attention_weights = compile_friendly_flex_attention(
         query,
         key,
         value,
         score_mod=score_mod,
         block_mask=block_mask,
-        enable_gqa=True,
+        enable_gqa=enable_gqa,
         scale=scaling,
         # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
         # For simplification, we thus always return it as no additional computations are introduced.
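A second sketch (also not part of the commit) of what the added gate does: x & (x - 1) == 0 is the standard power-of-two test, so the commit keeps Flex Attention's built-in enable_gqa path only when the per-device query-head count is a power of two, and otherwise materializes the KV heads with repeat_kv and disables it. Non-power-of-two local head counts can occur, for instance, when attention heads are sharded across devices; the head counts below are illustrative assumptions.

def needs_manual_repeat(num_local_query_heads: int) -> bool:
    # x & (x - 1) == 0 holds exactly when x has a single bit set, i.e. is a power of two.
    return not ((num_local_query_heads & (num_local_query_heads - 1)) == 0)

print(needs_manual_repeat(8))  # False -> keep enable_gqa=True (e.g. 32 heads over 4 ranks)
print(needs_manual_repeat(6))  # True  -> fall back to repeat_kv (e.g. 24 heads over 4 ranks)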