
Commit 794fde7

Fixing flex attention for torch=2.6.0 (#37285)
* adding compile kwarg for torch 2.6
* fixing dynamic
* addressing comment
* typo
* Update src/transformers/integrations/flex_attention.py

Co-authored-by: Arthur <[email protected]>
1 parent b54c2f4 commit 794fde7


src/transformers/integrations/flex_attention.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -31,6 +31,7 @@
 import torch
 
 from ..utils import is_torch_flex_attn_available
+from ..utils.import_utils import _torch_version
 
 
 if is_torch_flex_attn_available():
@@ -60,8 +61,16 @@ def __init__(self):
         """
         Initialize or update the singleton instance.
         """
-        if self._is_flex_compiled is False:
-            self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
+        if not self._is_flex_compiled:
+            # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+            # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
+            # see https://github.com/pytorch/pytorch/issues/146260 for training
+            if _torch_version == "2.6.0":
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                )
+            else:
+                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
             self._is_flex_compiled = True
 
     def __call__(self):
```
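For context, here is a minimal standalone sketch of the same PyTorch 2.6.0 workaround outside Transformers. It assumes PyTorch 2.5+ on a CUDA device so that `torch.nn.attention.flex_attention` is importable, and it substitutes a plain `torch.__version__` check for the `_torch_version` helper used in the diff above; the tensor shapes and dtypes are illustrative only, not part of this commit.

```python
# Minimal sketch of the 2.6.0 flex attention compile workaround (assumptions:
# PyTorch >= 2.5, CUDA device available; shapes below are illustrative).
import torch
from torch.nn.attention.flex_attention import flex_attention

if torch.__version__.startswith("2.6.0"):
    # PyTorch 2.6.0 has a known flex attention compilation issue; the suggested
    # mitigation is to compile without CUDA graphs (pytorch/pytorch#146260).
    compiled_flex_attention = torch.compile(
        flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
    )
else:
    compiled_flex_attention = torch.compile(flex_attention, dynamic=False)

# Query/key/value laid out as (batch, num_heads, seq_len, head_dim).
q, k, v = (
    torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3)
)
# With no score_mod/block_mask, this is standard scaled dot-product attention.
out = compiled_flex_attention(q, k, v)
print(out.shape)  # torch.Size([1, 4, 128, 64])
```

The "max-autotune-no-cudagraphs" mode keeps Inductor's autotuning but skips CUDA graph capture, which is the mitigation suggested in the linked PyTorch issue for the 2.6.0 compilation problem.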
