Fixing flex attention for torch=2.6.0 #37285

Merged: 6 commits, Apr 7, 2025
Changes from 2 commits
11 changes: 9 additions & 2 deletions src/transformers/integrations/flex_attention.py
@@ -11,6 +11,7 @@
year = {2024}
}
"""

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
@@ -31,6 +32,7 @@
import torch

from ..utils import is_torch_flex_attn_available
from ..utils.import_utils import _torch_version


if is_torch_flex_attn_available():
@@ -63,8 +65,13 @@ def __init__(self):
"""
Initialize or update the singleton instance.
"""
if self._is_flex_compiled is False:
self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
if not self._is_flex_compiled:
if _torch_version == "2.6.0":
Contributor:

Could you add a comment here to describe why this is the case, e.g. by linking to the issue with a small description?

It might also be nice to use/create something like

def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
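
A rough sketch of what such a helper could look like, assuming the check is built on packaging and importlib.metadata; this is illustrative only, and the accept_dev handling below is an assumption rather than what necessarily lands in transformers.utils.import_utils:

# Hypothetical sketch of the suggested helper, not the transformers implementation.
import importlib.metadata

from packaging import version


def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False) -> bool:
    """Return True if the installed torch version is >= library_version."""
    installed = version.parse(importlib.metadata.version("torch"))
    if accept_dev:
        # Compare the base version so dev builds (e.g. 2.7.0.dev20250310) can pass the check.
        installed = version.parse(installed.base_version)
    return installed >= version.parse(library_version)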

Collaborator:

I tested on my end; this should fix training!

+                self._compiled_flex_attention = torch.compile(
+                    flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
+                )
+            else:
+                self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
             self._is_flex_compiled = True
 
     def __call__(self):
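
For context, here is a standalone sketch of the compilation branch this diff adds, with the WrappedFlexAttention singleton bookkeeping stripped out. The torch.__version__-based check is a stand-in for the _torch_version value imported from transformers, and it assumes a torch build that ships flex_attention (2.5+):

# Minimal sketch of the version-gated compilation added by this PR (singleton logic omitted).
import torch
from torch.nn.attention.flex_attention import flex_attention

torch_version = torch.__version__.split("+")[0]  # e.g. "2.6.0" from "2.6.0+cu124"

if torch_version == "2.6.0":
    # torch 2.6.0 is special-cased to compile without CUDA graphs;
    # per the PR discussion, this fixes flex attention during training.
    compiled_flex_attention = torch.compile(
        flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
    )
else:
    compiled_flex_attention = torch.compile(flex_attention, dynamic=False)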