Skip to content

Remove mps workaround for fp16 GELU, which is now supported natively #10133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 13, 2024

Conversation

skotapati
Copy link
Contributor

@skotapati skotapati commented Dec 5, 2024

What does this PR do?

Fixes MPS fp16 GELU calls being forced to cast to fp32. fp16 GELU is now supported natively in MPS so this workaround is no longer needed

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@skotapati skotapati marked this pull request as ready for review December 5, 2024 18:17
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@pcuenca
Copy link
Member

pcuenca commented Dec 5, 2024

Thanks a lot @skotapati! Do you know the minimum PyTorch version that benefits from this?

@hlky
Copy link
Contributor

hlky commented Dec 5, 2024

@pcuenca Looks like it's been fixed since 1.13 pytorch/pytorch#86218 the code was originally added #1533 and at that time the minimum supported version for PyTorch was 1.4

@pcuenca
Copy link
Member

pcuenca commented Dec 6, 2024

I tested on 1.13.1 and it still failed, but it worked on 2.0.0. I'd suggest we add a condition to check for a minimum version of PyTorch >= 2.

@skotapati
Copy link
Contributor Author

Makes sense, adding a check for torch>2.0

@skotapati skotapati force-pushed the skotapati/mps_gelu_workaround branch from 18ab1b9 to d0feeae Compare December 6, 2024 18:55
@skotapati
Copy link
Contributor Author

@pcuenca thanks for the review, added the fallback for torch < 2.0

@yiyixuxu yiyixuxu merged commit ec9bfa9 into huggingface:main Dec 13, 2024
12 checks passed
@yiyixuxu
Copy link
Collaborator

yiyixuxu commented Dec 13, 2024

thanks @skotapati @pcuenca @hlky !

sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
…10133)

* Remove mps workaround for fp16 GELU, which is now supported natively

---------

Co-authored-by: hlky <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants