perf: MLA decode kernel implemented by CuTe targeted to SM80 #844
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Hi @yzh119 , this is a follow up of #766, an interesting idea came to my mind today, can't help to change few lines to verify this idea.

We can use asymmetric warp config to solve the register file size limit issue, the solution is simply to use 8 warps for the output mma stage, and keep other parts unchanged, because the limitation is on the reg num per cuda thread not the whole SM, there is 64K 32b registers per SM which is enough for the f32 output of 64 heads.
So we now have 4 warps for the att mma stage, 2 warps for the softmax stage, 8 warps for output mma stage, and 4 warps for data load stage, the diagram is updated below:
After the change, output mma stage needs more computation, the benchmark drops a little as expected, but still looks good:

It seems the performance of this CuTe implementation is slightly better than the current FA2 implementation according to #814

So I think this CuTe implementation still has its value, consider such interesting scheduling design and better performance, maybe we can regard it as an ad hoc implementation for (decode only /128 q-heads / SM80) case, and JIT logic can accommodate this kernel.