[Refactor] Unify JIT/Customization/AOT mode #748

Merged
merged 12 commits into main from unify-jit-aot on Jan 23, 2025

Conversation

@yzh119 (Collaborator) commented on Jan 22, 2025

This PR implements #706, unifying the codebase for (1) JIT compilation of default attention, (2) JIT compilation of customized attention, and (3) AOT compilation of default attention, and adds support for customized attention in batch prefill/decode (both the fa2 and fa3 templates).

More specifically:

  1. All template files are stored as standalone Jinja files instead of Python strings embedded in the source.
  2. All attention modes share the same codebase. Default attention is instantiated as a special form of customized attention with the additional parameters hard-coded.
  3. The names of optional additional tensor parameters must start with maybe_ (see the sketch after this list).
  4. For the FA3 template, additional parameters are collected in an AdditionalParams structure that is passed to MainloopParams, so we avoid passing the entire kernel parameter class, many of whose members duplicate MainloopParams and EpilogueParams.
  5. Customized batch prefill/decode examples are added and tested.
  6. We change the argument order of the PyTorch bindings to unify the customized attention and default attention interfaces. The APIs exposed to users are unchanged.
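
To make items 3 and 6 concrete, here is a minimal sketch of what the Python binding of a hypothetical customized single-prefill variant could look like under the unified convention. The variant name, maybe_custom_bias, and custom_scale are illustrative assumptions, not part of this PR; only the naming rule and the argument order follow the conventions described above.

def single_prefill_with_kv_cache_custom(                # hypothetical customized variant
    q, k, v, tmp, o, maybe_lse,                         # same fixed arguments, same order as default attention
    mask_mode_code: int, layout: int, window_left: int,
    maybe_custom_bias,                                  # hypothetical optional tensor; name must start with maybe_
    custom_scale: float,                                # hypothetical scalar additional parameter
    cuda_stream: int = 0,
) -> None:
    pass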

cc @hyhieu @merrymercy for visibility.

Milestones

  • JIT default attention
  • JIT customized attention
  • AOT default attention
  • Check all unit tests.
  • C++ tests/benchmarks

C++/Python Interface of PyTorch Bindings

Single Decode/Prefill Attention Kernels

Decode C++ interface:

#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_alibi_slopes, \
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta
void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
                                 at::Tensor o, unsigned int layout, int window_left
                                 ADDITIONAL_FUNC_PARAMS,
                                 int64_t cuda_stream);

Decode Python interface:

def single_decode_with_kv_cache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tmp: torch.Tensor,
                                 o: torch.Tensor, layout: int, window_left: int, *args,
                                 cuda_stream: int = 0) -> None:
    pass

For default attention, *args is expanded to maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta.
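
As an illustration, a minimal sketch of invoking the decode binding for default attention with *args spelled out, assuming the binding described above has been loaded into scope. The shapes, dtypes, layout code, and scalar values below are assumptions for the sake of the example, not taken from this PR.

import torch

num_qo_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 1024
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(kv_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn_like(k)
tmp = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # scratch workspace
o = torch.empty_like(q)

single_decode_with_kv_cache(
    q, k, v, tmp, o,
    0,                    # layout (assumed: 0 = NHD)
    -1,                   # window_left (assumed: -1 disables sliding window)
    None,                 # *args[0]: maybe_alibi_slopes
    0.0,                  # *args[1]: logits_soft_cap
    1.0 / head_dim**0.5,  # *args[2]: sm_scale
    1.0,                  # *args[3]: rope_scale
    1e4,                  # *args[4]: rope_theta
    cuda_stream=0,
)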

Prefill C++ interface:

#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_custom_mask, std::optional<at::Tensor> maybe_alibi_slopes, \
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta
void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k,
                                  at::Tensor v, at::Tensor tmp, at::Tensor o, std::optional<at::Tensor> maybe_lse,
                                  unsigned int mask_mode_code, unsigned int layout, int32_t window_left
                                  ADDITIONAL_FUNC_PARAMS,
                                  int64_t cuda_stream);

Prefill Python interface:

def single_prefill_with_kv_cache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tmp: torch.Tensor,
                                     o: torch.Tensor, maybe_lse: Optional[torch.Tensor], mask_mode_code: int,
                                     layout: int, window_left: int, *args,
                                     cuda_stream: int = 0) -> None:
    pass

For default attention, *args is expanded to maybe_custom_mask, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta for the fa2 template, and to logits_soft_cap, sm_scale for the fa3 template.
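
A sketch of how the call differs between the two templates; only the contents of *args change, while the fixed arguments stay the same. The q, k, v, tmp, o tensors are assumed to be prepared as in the decode sketch above (with a query-length dimension added for prefill), and all constants are illustrative assumptions.

sm_scale = 1.0 / 128**0.5  # assumed head_dim = 128

# fa2 template: full set of additional parameters
fa2_args = (None, None, 0.0, sm_scale, 1.0, 1e4)
#          (maybe_custom_mask, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta)

# fa3 template: only the soft cap and the softmax scale
fa3_args = (0.0, sm_scale)
#          (logits_soft_cap, sm_scale)

single_prefill_with_kv_cache(
    q, k, v, tmp, o,
    None,   # maybe_lse
    0,      # mask_mode_code (assumed: 0 = no mask)
    0,      # layout
    -1,     # window_left
    *fa2_args,
    cuda_stream=0,
)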

Batch Decode/Prefill Attention Kernels

Decode C++ interface:

#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_alibi_slopes, \
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta
std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
    bool enable_cuda_graph, bool use_logits_soft_cap, unsigned int head_dim, at::Tensor empty_q_data,
    at::Tensor empty_kv_data, int64_t cuda_stream);

void BatchDecodeWithPagedKVCacheRun(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
    at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
    at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
    unsigned int kv_layout_code, int window_left
    ADDITIONAL_FUNC_PARAMS,
    int64_t cuda_stream);

Decode Python interface:

def batch_decode_with_paged_kv_cache_plan(
    float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor,
    page_locked_int_workspace_buffer: torch.Tensor, indptr: torch.Tensor, batch_size: int,
    num_qo_heads: int, num_kv_heads: int, page_size: int, enable_cuda_graph: bool,
    use_logits_soft_cap: bool, head_dim: int, empty_q_data: torch.Tensor,
    empty_kv_data: torch.Tensor, cuda_stream: int) -> List[int]:
    pass

def batch_decode_with_paged_kv_cache_run(
    float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor,
    plan_info_vec: List[int], q: torch.Tensor, paged_k_cache: torch.Tensor,
    paged_v_cache: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor],
    kv_layout_code: int, window_left: int, *args,
    cuda_stream: int) -> None:
    pass

For default attention, *args is expanded to maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta.
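
A sketch of the plan/run flow for default-attention batch decode, assuming the bindings above are in scope. The plan step produces an opaque plan_info vector for the current batch layout, which is then forwarded to the run step. Workspace sizes, cache layout, shapes, and scalar values are illustrative assumptions.

import torch

batch_size, num_qo_heads, num_kv_heads, head_dim, page_size = 4, 32, 8, 128, 16
num_pages = batch_size * 16  # assumed: 16 pages per request

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
paged_k_cache = torch.randn(num_pages, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
paged_v_cache = torch.randn_like(paged_k_cache)
paged_kv_indptr = torch.arange(0, num_pages + 1, 16, dtype=torch.int32, device="cuda")
paged_kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda")
paged_kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")
o = torch.empty_like(q)
empty_q_data = torch.empty(0, dtype=torch.float16)   # dtype carriers for the planner
empty_kv_data = torch.empty(0, dtype=torch.float16)

float_workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
int_workspace_buffer = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device="cuda")
pinned_int_workspace_buffer = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, pin_memory=True)

# Plan for the current batch layout; the returned plan_info is opaque.
plan_info = batch_decode_with_paged_kv_cache_plan(
    float_workspace_buffer, int_workspace_buffer, pinned_int_workspace_buffer,
    paged_kv_indptr, batch_size, num_qo_heads, num_kv_heads, page_size,
    False,                # enable_cuda_graph
    False,                # use_logits_soft_cap
    head_dim, empty_q_data, empty_kv_data,
    0)                    # cuda_stream

# Run; for default attention *args expands as documented above.
batch_decode_with_paged_kv_cache_run(
    float_workspace_buffer, int_workspace_buffer, plan_info,
    q, paged_k_cache, paged_v_cache,
    paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
    o,
    None,                 # maybe_lse
    0,                    # kv_layout_code
    -1,                   # window_left
    None,                 # maybe_alibi_slopes
    0.0,                  # logits_soft_cap
    1.0 / head_dim**0.5,  # sm_scale
    1.0,                  # rope_scale
    1e4,                  # rope_theta
    cuda_stream=0)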

Prefill C++ interface:

#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_custom_mask, std::optional<at::Tensor> maybe_mask_indptr, std::optional<at::Tensor> maybe_alibi_slopes, \
    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta
std::vector<int64_t> BatchPrefillWithKVCachePlan(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len_arr,
    unsigned total_num_rows, unsigned int batch_size, unsigned int num_qo_heads,
    unsigned int num_kv_heads, unsigned int page_size, bool enable_cuda_graph,
    unsigned int head_dim, bool causal,
    int64_t cuda_stream);

void BatchPrefillWithRaggedKVCacheRun(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
    at::Tensor qo_indptr, at::Tensor kv_indptr,
    at::Tensor o, std::optional<at::Tensor> maybe_lse,
    unsigned int mask_mode_code, unsigned int layout, int32_t window_left
    ADDITIONAL_FUNC_PARAMS,
    int64_t cuda_stream);

void BatchPrefillWithPagedKVCacheRun(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
    at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr,
    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len,
    at::Tensor o, std::optional<at::Tensor> maybe_lse,
    unsigned int mask_mode_code, unsigned int layout, int32_t window_left
    ADDITIONAL_FUNC_PARAMS,
    int64_t cuda_stream);

Prefill Python interface:

def batch_prefill_with_kv_cache_plan(
    float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor,
    page_locked_int_workspace_buffer: torch.Tensor, qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_len_arr: torch.Tensor,
    total_num_rows: int, batch_size: int, num_qo_heads: int, num_kv_heads: int, page_size: int,
    enable_cuda_graph: bool, head_dim: int, causal: bool, cuda_stream: int) -> List[int]:
    pass

def batch_prefill_with_ragged_kv_cache_jit_run(
    float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor,
    plan_info_vec: List[int], q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
    qo_indptr: torch.Tensor, kv_indptr: torch.Tensor, o: torch.Tensor,
    maybe_lse: Optional[torch.Tensor], mask_mode_code: int, layout: int, window_left: int,
    *args, cuda_stream: int) -> None:
    pass

def batch_prefill_with_paged_kv_cache_jit_run(
    float_workspace_buffer: torch.Tensor, int_workspace_buffer: torch.Tensor,
    plan_info_vec: List[int], q: torch.Tensor, paged_k_cache: torch.Tensor, paged_v_cache: torch.Tensor,
    qo_indptr: torch.Tensor, paged_kv_indptr: torch.Tensor, paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor, o: torch.Tensor, maybe_lse: Optional[torch.Tensor],
    mask_mode_code: int, layout: int, window_left: int, *args,
    cuda_stream: int) -> None:
    pass

For default attention, *args is expanded to maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta for the fa2 template, and to logits_soft_cap, sm_scale for the fa3 template.
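
Analogously, a sketch of a ragged batch prefill call with the fa2 template, reusing the workspace buffers and sizes from the decode sketch above. The qo_indptr, kv_indptr, kv_len_arr, total_num_rows, q, k, v, and o values are assumed to be prepared by the caller, and all constants are illustrative assumptions.

# fa2 *args for default attention:
#   (maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes,
#    logits_soft_cap, sm_scale, rope_scale, rope_theta)
# fa3 *args shrinks to (logits_soft_cap, sm_scale).
fa2_args = (None, None, None, 0.0, 1.0 / head_dim**0.5, 1.0, 1e4)

plan_info = batch_prefill_with_kv_cache_plan(
    float_workspace_buffer, int_workspace_buffer, pinned_int_workspace_buffer,
    qo_indptr, kv_indptr, kv_len_arr, total_num_rows, batch_size,
    num_qo_heads, num_kv_heads, page_size,
    False,            # enable_cuda_graph
    head_dim,
    True,             # causal
    0)                # cuda_stream

batch_prefill_with_ragged_kv_cache_jit_run(
    float_workspace_buffer, int_workspace_buffer, plan_info,
    q, k, v, qo_indptr, kv_indptr, o,
    None,             # maybe_lse
    0,                # mask_mode_code (assumed: 0 = no mask)
    0,                # layout
    -1,               # window_left
    *fa2_args,
    cuda_stream=0)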

@yzh119 yzh119 merged commit 93e1a26 into main Jan 23, 2025
@yzh119 yzh119 deleted the unify-jit-aot branch January 24, 2025 02:03
yzh119 added a commit that referenced this pull request on Feb 5, 2025:

We put `group_size` outside of params mainly because we observed better performance, but with recent refactors such as #748 and #776 there is no longer a need to decouple `group_size` from the rest of the parameters; this PR merges `group_size` back into the parameter class.