-
Notifications
You must be signed in to change notification settings - Fork 3k
[Triton Kernel] Add varlen segment mean triton kernel #10369
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thanks for your contribution! |
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (17.74%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## develop #10369 +/- ##
===========================================
- Coverage 48.99% 48.97% -0.02%
===========================================
Files 765 766 +1
Lines 125974 126036 +62
===========================================
+ Hits 61720 61731 +11
- Misses 64254 64305 +51 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
PR types
New features
PR changes
Others
Description
Add segment mean triton kernel. Implemented for segment mean operation for varlen qkv.
For example, the k tensor is: [total_seqlen, num_head, head_dim], where total_seqlen = seqlen 1 + seqlen 2 + ... + seqlen n.
So the segment mean triton kernel will do mean operation along the seqlen dim.
It will finally generate a
[bsz, num_head, head_dim]
shape-like result, as the result of mean value of each seqlen segment.