Fixing flex attention for torch=2.6.0 #37285


Merged: 6 commits merged into huggingface:main on Apr 7, 2025

Conversation

SalmanMohammadi (Contributor)

What does this PR do?

This PR fixes the issue originally raised in PyTorch core (pytorch/pytorch#146260) affecting flex attention on torch 2.6.0, by compiling with mode="max-autotune-no-cudagraphs".
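For context, a rough sketch of what the change amounts to (this is not the exact diff merged in the PR; the mode kwarg and the 2.6.0 version check are taken from the description and the review thread below):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Sketch only: on torch 2.6.0 the default compile mode hits the upstream bug
# (pytorch/pytorch#146260), so compile flex_attention with
# mode="max-autotune-no-cudagraphs" there; otherwise keep the default.
if torch.__version__.startswith("2.6.0"):
    compiled_flex_attention = torch.compile(
        flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
    )
else:
    compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
```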

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker

github-actions bot marked this pull request as draft April 4, 2025 15:22

github-actions bot commented Apr 4, 2025

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.

SalmanMohammadi marked this pull request as ready for review April 4, 2025 15:22
github-actions bot requested review from MekkCyber and SunMarc April 4, 2025 15:22
-        if self._is_flex_compiled is False:
-            self._compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
+        if not self._is_flex_compiled:
+            if _torch_version == "2.6.0":
Contributor

Could you add a comment here to describe why this is the case, e.g. by linking to the issue with a small description?

It might also be nice to use/create something like:

def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False):
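A minimal sketch of what such a helper could look like (assuming the packaging library is available; the utility that eventually landed in transformers may differ):

```python
import torch
from packaging import version


def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False) -> bool:
    # Compare the installed torch version against `library_version`.
    # With accept_dev=True, a dev/pre-release build such as "2.7.0.dev..."
    # is treated as already satisfying ">= 2.7.0".
    installed = version.parse(torch.__version__)
    if accept_dev:
        installed = version.parse(installed.base_version)
    return installed >= version.parse(library_version)
```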

Collaborator

I tested on my end; this should fix training!

vasqu (Contributor) commented Apr 5, 2025

Seems like the llama4 addition changed flex compilation #37307 😄

ArthurZucker (Collaborator)
Flex was not working for decoding!


ArthurZucker added the "for patch" label (tag issues/PRs that should be included in the next patch) Apr 7, 2025
ArthurZucker merged commit 794fde7 into huggingface:main Apr 7, 2025
18 checks passed
ArthurZucker added a commit that referenced this pull request Apr 7, 2025
* adding compile kwarg for torch 2.6

* fixing dynamic

* addressing comment

* typo

* Update src/transformers/integrations/flex_attention.py

---------

Co-authored-by: Arthur <[email protected]>
cyr0930 pushed a commit to cyr0930/transformers that referenced this pull request Apr 18, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025