
Commit 0c3dc0c

support flex attention
1 parent c358a1b

2 files changed (+21, -7 lines)


src/transformers/integrations/flex_attention.py

Lines changed: 19 additions & 1 deletion
@@ -144,6 +144,17 @@ def compile_friendly_flex_attention(
     )
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
 def flex_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
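
The helper's docstring claims equivalence with torch.repeat_interleave on the head dimension; a quick standalone check (not part of the commit, shapes made up):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same body as the helper added above.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seqlen, head_dim)
out = repeat_kv(kv, n_rep=3)    # -> (2, 12, 16, 64), each KV head repeated 3x
assert out.shape == (2, 12, 16, 64)
assert torch.equal(out, torch.repeat_interleave(kv, repeats=3, dim=1))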
@@ -174,13 +185,20 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
             score = score + head_mask[batch_idx][head_idx][0][0]
         return score
 
+    enable_gqa = True
+    num_local_query_heads = query.shape[1]
+    if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
+        key = repeat_kv(key, query.shape[1] // key.shape[1])
+        value = repeat_kv(value, query.shape[1] // value.shape[1])
+        enable_gqa = False
+
     attn_output, attention_weights = compile_friendly_flex_attention(
         query,
         key,
         value,
         score_mod=score_mod,
         block_mask=block_mask,
-        enable_gqa=True,
+        enable_gqa=enable_gqa,
         scale=scaling,
         # Last time checked on PyTorch == 2.5.1: Flex Attention always computes the lse regardless.
         # For simplification, we thus always return it as no additional computations are introduced.
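
What the new dispatch does, as a minimal sketch (not part of the commit; shapes made up, reusing repeat_kv from the sketch above): flex attention's native enable_gqa path is kept when the per-rank query-head count is a power of two, and otherwise the KV heads are materialized up front:

import torch

query = torch.randn(1, 12, 16, 64)  # 12 query heads on this rank: not a power of two
key = torch.randn(1, 4, 16, 64)     # 4 KV heads
value = torch.randn(1, 4, 16, 64)

enable_gqa = True
num_local_query_heads = query.shape[1]
if not ((num_local_query_heads & (num_local_query_heads - 1)) == 0):
    # 12 & 11 == 8 != 0, so expand KV heads to match the query heads instead.
    key = repeat_kv(key, query.shape[1] // key.shape[1])
    value = repeat_kv(value, query.shape[1] // value.shape[1])
    enable_gqa = False

print(enable_gqa, key.shape)  # False torch.Size([1, 12, 16, 64])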

src/transformers/modeling_utils.py

Lines changed: 2 additions & 6 deletions
@@ -1939,16 +1939,12 @@ def post_init(self):
                 self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
 
         if self._tp_plan is not None and is_torch_greater_or_equal("2.3"):
-            unique_names = {re.sub(r"\d+", "*", name) for name, _ in self.named_children() if len(name) > 0}
-            for k, v in self._tp_plan.items():
+            for _, v in self._tp_plan.items():
                 if v not in SUPPORTED_TP_STYLES:
                     raise ValueError(
                         f"Unsupported tensor parallel style {v}. Supported styles are {SUPPORTED_TP_STYLES}"
                     )
-                if k not in unique_names:
-                    raise ValueError(
-                        f"Unsupported tensor parallel mapping: {k} is not part of the model"
-                    )
+
 
     def dequantize(self):
         """
