-
Notifications
You must be signed in to change notification settings - Fork 29.2k
Fixing flex attention for torch=2.6.0 #37285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
a026a6a
0d8ba72
f13dc1b
c636e59
48f8964
9efb0d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -31,6 +31,7 @@ | |||
import torch | ||||
|
||||
from ..utils import is_torch_flex_attn_available | ||||
from ..utils.import_utils import _torch_version | ||||
|
||||
|
||||
if is_torch_flex_attn_available(): | ||||
|
@@ -63,8 +64,16 @@ def __init__(self): | |||
""" | ||||
Initialize or update the singleton instance. | ||||
""" | ||||
if self._is_flex_compiled is False: | ||||
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False) | ||||
if not self._is_flex_compiled: | ||||
# In PyTorch 2.6.0, there's a known issue with flex attention compilation which may | ||||
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs" | ||||
# see https://github.com/pytorch/pytorch/issues/146260 | ||||
if _torch_version == "2.6.0": | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment here to describe why this is the case, e.g. by linking to the issue with a small description? Might be also nice to use/create something like
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tested on my end this should fix training! |
||||
self._compiled_flex_attention = torch.compile( | ||||
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs" | ||||
) | ||||
else: | ||||
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False) | ||||
self._is_flex_compiled = True | ||||
|
||||
def __call__(self): | ||||
|
Uh oh!
There was an error while loading. Please reload this page.