Remove mps workaround for fp16 GELU, which is now supported natively #10133
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks a lot @skotapati! Do you know the minimum PyTorch version that benefits from this?
@pcuenca Looks like it's been fixed since 1.13 (pytorch/pytorch#86218). The code was originally added in #1533, and at that time the minimum supported PyTorch version was 1.4.
I tested on
Makes sense, adding a check for torch>2.0
@pcuenca Thanks for the review, added the fallback for torch < 2.0.
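For readers following along, here is a minimal sketch of the version-gated fallback discussed above. The function name and placement are illustrative (the exact location in diffusers may differ), but `is_torch_version` is the helper diffusers uses for this kind of check:

```python
import torch
import torch.nn.functional as F

from diffusers.utils import is_torch_version


def gelu(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    # Keep the fp32 fallback only where it is still needed:
    # MPS on torch < 2.0 cannot run GELU on fp16 tensors.
    if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
        return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)
    # Everywhere else (including MPS on torch >= 2.0) run GELU in the input dtype.
    return F.gelu(gate, approximate=approximate)
```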
Thanks @skotapati @pcuenca @hlky!
Remove mps workaround for fp16 GELU, which is now supported natively (#10133)

Co-authored-by: hlky <[email protected]>
What does this PR do?
Fixes MPS fp16 GELU calls being forced to cast to fp32. fp16 GELU is now supported natively on MPS, so this workaround is no longer needed.
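To illustrate what the removed workaround did, here is a sketch (not the literal diffusers diff; the function names below are hypothetical):

```python
import torch
import torch.nn.functional as F


# Before: MPS could not run GELU on fp16 tensors, so the activation was
# upcast to fp32, run through GELU, then cast back to the original dtype.
def gelu_with_mps_cast(gate: torch.Tensor) -> torch.Tensor:
    if gate.device.type == "mps":
        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
    return F.gelu(gate)


# After: fp16 GELU runs natively on MPS, so the extra casts are unnecessary.
def gelu_native(gate: torch.Tensor) -> torch.Tensor:
    return F.gelu(gate)
```

Besides removing a branch, dropping the cast avoids materializing an fp32 copy of the activation on MPS, which saves memory and two dtype conversions per call.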
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.