[Refactor] Unify JIT/Customization/AOT mode #748
Merged
This PR implements #706, unifying the codebase for (1) JIT compilation of default attention, (2) JIT compilation of customized attention, and (3) AOT compilation of default attention, and adds support for customized attention in batch prefill/decode (both the fa2 and fa3 templates).
More specifically:

- Optional additional parameters are prefixed with `maybe_`.
- Additional parameters are packed into an `AdditionalParams` structure that is passed to MainloopParams, so that we can avoid passing the entire kernel parameter class, where many of the members duplicate those of MainloopParams and EpilogueParams.

cc @hyhieu @merrymercy for visibility.
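To make the `maybe_` convention concrete, here is a minimal Python sketch; `gen_customized_module` and the parameter names `maybe_bias`/`temperature` are hypothetical stand-ins for the JIT customization entry point, not the actual flashinfer API:

```python
# Illustration of the naming convention only. `gen_customized_module`,
# `maybe_bias`, and `temperature` are placeholders, not the real API.
additional_tensor_names = ["maybe_bias"]    # optional tensor -> `maybe_` prefix
additional_tensor_dtypes = ["float16"]
additional_scalar_names = ["temperature"]   # always-present scalar -> no prefix
additional_scalar_dtypes = ["double"]


def gen_customized_module(tensor_names, tensor_dtypes, scalar_names, scalar_dtypes):
    """Stand-in for the JIT generator: conceptually it emits an AdditionalParams
    struct holding exactly these members and passes it to MainloopParams."""
    return {
        "tensors": dict(zip(tensor_names, tensor_dtypes)),
        "scalars": dict(zip(scalar_names, scalar_dtypes)),
    }


module_spec = gen_customized_module(
    additional_tensor_names, additional_tensor_dtypes,
    additional_scalar_names, additional_scalar_dtypes,
)
```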
Milestones
C++/Python Interface of PyTorch Bindings
Single Decode/Prefill Attention Kernels
Decode C++ interface:
Decode Python interface:
For default attention, `*args` is expanded to `maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta`.
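For concreteness, a minimal sketch of how the packed `*args` lines up with the list above; `run_single_decode` and the leading tensor arguments are placeholders rather than the exact binding signature:

```python
def run_single_decode(q, k, v, tmp, out, *args):
    # Placeholder for the generated single-decode binding; only the trailing
    # `*args` packing is of interest here.
    maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta = args
    return sm_scale


q = k = v = tmp = out = None          # stand-ins for the real tensors
default_args = (
    None,              # maybe_alibi_slopes (optional, hence the `maybe_` prefix)
    0.0,               # logits_soft_cap
    1.0 / 128 ** 0.5,  # sm_scale, i.e. 1/sqrt(head_dim) assuming head_dim=128
    1.0,               # rope_scale
    1e4,               # rope_theta
)
run_single_decode(q, k, v, tmp, out, *default_args)
```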
Prefill C++ interface:
Prefill Python interface:
For default attention, `*args` is expanded to `maybe_custom_mask, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta` for the fa2 template, and `logits_soft_cap, sm_scale` for the fa3 template.
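A similar sketch for the template-dependent prefill expansion; `run_single_prefill` and the surrounding tensor names are again placeholders:

```python
def run_single_prefill(q, k, v, tmp, out, *args):
    # Placeholder for the generated single-prefill binding.
    return len(args)


q = k = v = tmp = out = None
backend = "fa2"  # or "fa3"
if backend == "fa2":
    # maybe_custom_mask, maybe_alibi_slopes, logits_soft_cap, sm_scale,
    # rope_scale, rope_theta
    args = (None, None, 0.0, 1.0 / 128 ** 0.5, 1.0, 1e4)
else:
    # fa3 template: logits_soft_cap, sm_scale
    args = (0.0, 1.0 / 128 ** 0.5)
run_single_prefill(q, k, v, tmp, out, *args)
```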
Batch Decode/Prefill Attention Kernels
Decode C++ interface:
Decode Python interface:
For default attention, `*args` is expanded to `maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta`.
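A sketch of the same packing for the batch decode binding, including how a customized variant substitutes its own declared parameters for the default tail; all names here are placeholders:

```python
def run_batch_decode(plan_info, q, paged_k, paged_v, out, *args):
    # Placeholder for the generated batch-decode binding.
    return len(args)


plan_info = q = paged_k = paged_v = out = None

# Default attention: the same five-element tail as the single-decode case.
default_args = (None, 0.0, 1.0 / 128 ** 0.5, 1.0, 1e4)
run_batch_decode(plan_info, q, paged_k, paged_v, out, *default_args)

# Customized attention: `*args` instead carries whatever additional parameters
# were declared at JIT-generation time (hypothetical example values).
custom_args = (None, 1.5)  # e.g. maybe_bias, temperature
run_batch_decode(plan_info, q, paged_k, paged_v, out, *custom_args)
```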
Prefill C++ interface:
Prefill Python interface:
The `*args` is expanded to `maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta` for the fa2 template, and `logits_soft_cap, sm_scale` for the fa3 template.
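And a last sketch for the batch prefill packing; the fa2 expansion carries an extra `maybe_mask_indptr`, presumably the per-request offsets into the packed `maybe_custom_mask`. Function and tensor names are placeholders:

```python
def run_batch_prefill(plan_info, q, paged_k, paged_v, out, *args):
    # Placeholder for the generated batch-prefill binding.
    return len(args)


plan_info = q = paged_k = paged_v = out = None
backend = "fa2"  # or "fa3"
if backend == "fa2":
    # maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes,
    # logits_soft_cap, sm_scale, rope_scale, rope_theta
    args = (None, None, None, 0.0, 1.0 / 128 ** 0.5, 1.0, 1e4)
else:
    # fa3 template: logits_soft_cap, sm_scale
    args = (0.0, 1.0 / 128 ** 0.5)
run_batch_prefill(plan_info, q, paged_k, paged_v, out, *args)
```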