Use RMSNorm in TransformersModel #12776
Signed-off-by: Harry Mellor <[email protected]>
If you can make the model compatible with torch.compile, then you should be able to directly get the benefit.
Do you mean |
IMO this is more complexity than it is worth. It should be in our best interest to keep TransformersModel as minimal as possible while maintaining functionality.
I agree with Kaichao that the better approach for performance will be to make TransformersModel generally work with torch.compile, which should have essentially the same impact as using the fused RMSNorm module from vLLM.
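For readers unfamiliar with the operation under discussion, here is a minimal reference sketch of the RMSNorm arithmetic in plain Python. It is not vLLM's fused kernel nor the Transformers implementation; the function name `rms_norm` and the default `eps` are assumptions chosen for illustration. A fused module or a compiled graph would perform the same computation in a single pass over the hidden vector rather than as separate elementwise steps.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm over one hidden vector (plain Python lists).

    Hypothetical helper for illustration only; `eps` default is an assumption.
    """
    if len(x) != len(weight):
        raise ValueError("x and weight must have the same length")
    # Mean of squares over the hidden dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    # Reciprocal root-mean-square, stabilised by eps.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Scale each element, then apply the learned per-channel weight.
    return [v * inv_rms * w for v, w in zip(x, weight)]


if __name__ == "__main__":
    print(rms_norm([3.0, 4.0], [1.0, 1.0]))
```

With unit weights the output has a root-mean-square of approximately 1, which is the property the normalisation is meant to provide.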
OK, with V1 enabled the benchmarks look like:
So using vLLM's RMSNorm still gives us some benefit 🤔
I've made #12785 to handle the UX portion of this PR (which shouldn't be controversial, I don't think)
This pull request has merge conflicts that must be resolved before it can be |
With the UX change merged, I'll close this in favour of keeping |
Changes:
RMSNorm class in TransformersModel
Linear layer cannot be tensor parallelised
Before and after benchmarks using the following command:
Results:
LlamaForCausalLM
TransformersModel before
TransformersModel after
This corresponds to a +2.8% performance boost for this model.
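As a rough illustration of how a figure like this is typically derived, the sketch below computes the relative change between a "before" and an "after" throughput. The helper name `percent_change` and the two throughput values are hypothetical; they are not the measurements from this PR.

```python
def percent_change(before, after):
    """Relative change of `after` versus `before`, in percent."""
    if before <= 0:
        raise ValueError("before must be positive")
    return (after - before) / before * 100.0


# Hypothetical throughput values in tokens/s (not the PR's measurements).
baseline_tok_s = 1000.0
patched_tok_s = 1028.0
print(f"{percent_change(baseline_tok_s, patched_tok_s):+.1f}%")  # prints +2.8%
```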