1 file changed (+11 −2 lines changed): src/transformers/integrations
 import torch

 from ..utils import is_torch_flex_attn_available
+from ..utils.import_utils import _torch_version


 if is_torch_flex_attn_available():
@@ -60,8 +61,16 @@ def __init__(self):
         """
         Initialize or update the singleton instance.
         """
-        if self._is_flex_compiled is False:
-            self._compiled_flex_attention = torch.compile(flex_attention, backend="inductor")
+        if not self._is_flex_compiled:
+            # In PyTorch 2.6.0, there's a known issue with flex attention compilation which may
+            # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs",
+            # see https://github.com/pytorch/pytorch/issues/146260 for training.
+            if _torch_version == "2.6.0":
+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                )
+            else:
+                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
             self._is_flex_compiled = True

     def __call__(self):
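
For context, here is a minimal standalone sketch of the same pattern the diff implements: lazily compiling `flex_attention` exactly once and switching the compile mode when the installed PyTorch version is 2.6.0. The class and method names below (`CompiledFlexAttention`, `get`) are illustrative only and are not the actual transformers API.

# Minimal sketch of the version-gated lazy compilation pattern shown above.
# Assumes a PyTorch build where torch.nn.attention.flex_attention is available;
# all names here are illustrative, not part of transformers.
import torch
from torch import __version__ as _torch_version
from torch.nn.attention.flex_attention import flex_attention


class CompiledFlexAttention:
    """Compile flex_attention once and reuse the compiled callable."""

    _compiled = None

    @classmethod
    def get(cls):
        if cls._compiled is None:
            if _torch_version.startswith("2.6.0"):
                # Work around the PyTorch 2.6.0 flex attention compilation issue
                # (https://github.com/pytorch/pytorch/issues/146260) by avoiding cudagraphs.
                cls._compiled = torch.compile(
                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
                )
            else:
                cls._compiled = torch.compile(flex_attention, dynamic=False)
        return cls._compiled

Callers would then invoke `CompiledFlexAttention.get()(query, key, value, ...)`; repeated calls reuse the single compiled function rather than recompiling.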