
Commit 93e1a26

[Refactor] Unify JIT/Customization/AOT mode (#748)
This PR implements #706 to unify the codebase for (1) JIT compilation of default attention, (2) JIT compilation of customized attention, and (3) AOT compilation of default attention, and adds support for customized attention in batch prefill/decode (both fa2 and fa3 templates). More specifically:

1. All template files are stored in standalone Jinja files instead of embedded Python strings.
2. All attention modes share the same codebase. Default attentions are instantiated as special forms of customized attention where the additional parameters are hard-coded.
3. The names of optional additional tensor parameters should start with `maybe_`.
4. For the FA3 template, additional parameters are set in an `AdditionalParams` structure that is passed to MainloopParams, so that we can avoid passing the entire kernel parameter class, many of whose members duplicate MainloopParams and EpilogueParams.
5. Customized batch prefill/decode examples are added and tested.
6. We change the argument order of the PyTorch bindings to unify the customized attention and default attention interfaces. The APIs exposed to users are unchanged.

cc @hyhieu @merrymercy for visibility.

## Milestones

* [x] JIT default attention
* [x] JIT customized attention
* [x] AOT default attention
* [x] Check all unittests.
* [ ] C++ tests/benchmarks

## C++/Python Interface of PyTorch Bindings

### Single Decode/Prefill Attention Kernels

#### Decode C++ interface:

```cpp
#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_alibi_slopes, \
  float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta

void single_decode_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
                                 at::Tensor o, unsigned int layout,
                                 int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
```

#### Decode python interface:

```python
def single_decode_with_kv_cache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                                tmp: torch.Tensor, o: torch.Tensor, layout: int,
                                window_left: int, *args, cuda_stream: int = 0) -> None:
    pass
```

For default attention, `*args` is expanded to `maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta`.

#### Prefill C++ interface:

```cpp
#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_custom_mask, std::optional<at::Tensor> maybe_alibi_slopes, \
  float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta

void single_prefill_with_kv_cache(at::Tensor q, at::Tensor k, at::Tensor v, at::Tensor tmp,
                                  at::Tensor o, std::optional<at::Tensor> maybe_lse,
                                  unsigned int mask_mode_code, unsigned int layout,
                                  int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
```

#### Prefill python interface:

```python
def single_prefill_with_kv_cache(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                                 tmp: torch.Tensor, o: torch.Tensor,
                                 maybe_lse: Optional[torch.Tensor], mask_mode_code: int,
                                 layout: int, window_left: int, *args,
                                 cuda_stream: int = 0) -> None:
    pass
```

For default attention, `*args` is expanded to `maybe_custom_mask, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta` for the fa2 template, and `logits_soft_cap, sm_scale` for the fa3 template.
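To make the calling convention concrete, here is a rough sketch of how a caller might invoke the single prefill binding with the default fa2 attention arguments. The tensor shapes, the scratch-buffer size, and the `_kernels` module handle are illustrative assumptions, not part of this PR:

```python
import torch

# Hypothetical single-request prefill, NHD layout (shapes are illustrative).
qo_len, kv_len, num_heads, head_dim = 128, 512, 32, 128
q = torch.randn(qo_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(kv_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
tmp = torch.empty(16 * 1024 * 1024, dtype=torch.uint8, device="cuda")  # scratch buffer
o = torch.empty_like(q)

# For default fa2 attention, the trailing *args are
# (maybe_custom_mask, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta).
default_args = (None, None, 0.0, head_dim ** -0.5, 1.0, 1e4)

# `_kernels` stands in for whatever module object the JIT/AOT loader returns.
_kernels.single_prefill_with_kv_cache(
    q, k, v, tmp, o,
    None,  # maybe_lse
    0,     # mask_mode_code: no mask
    0,     # layout: NHD
    -1,    # window_left: sliding window disabled
    *default_args,
    cuda_stream=torch.cuda.current_stream().cuda_stream,
)
```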
### Batch Decode/Prefill Attention Kernels

#### Decode C++ interface:

```cpp
#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_alibi_slopes, \
  float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta

std::vector<int64_t> BatchDecodeWithPagedKVCachePlan(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    at::Tensor page_locked_int_workspace_buffer, at::Tensor indptr, unsigned int batch_size,
    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
    bool enable_cuda_graph, bool use_logits_soft_cap, unsigned int head_dim,
    at::Tensor empty_q_data, at::Tensor empty_kv_data, int64_t cuda_stream);

void BatchDecodeWithPagedKVCacheRun(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
    at::Tensor paged_v_cache, at::Tensor paged_kv_indptr, at::Tensor paged_kv_indices,
    at::Tensor paged_kv_last_page_len, at::Tensor o, std::optional<at::Tensor> maybe_lse,
    unsigned int kv_layout_code, int window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
```

#### Decode python interface:

```python
def batch_decode_with_paged_kv_cache_plan(
    float_workspace_buffer: torch.Tensor,
    int_workspace_buffer: torch.Tensor,
    page_locked_int_workspace_buffer: torch.Tensor,
    indptr: torch.Tensor,
    batch_size: int,
    num_qo_heads: int,
    num_kv_heads: int,
    page_size: int,
    enable_cuda_graph: bool,
    use_logits_soft_cap: bool,
    head_dim: int,
    empty_q_data: torch.Tensor,
    empty_kv_data: torch.Tensor,
    cuda_stream: int) -> List[int]:
    pass

def batch_decode_with_paged_kv_cache_run(
    float_workspace_buffer: torch.Tensor,
    int_workspace_buffer: torch.Tensor,
    plan_info_vec: List[int],
    q: torch.Tensor,
    paged_k_cache: torch.Tensor,
    paged_v_cache: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    o: torch.Tensor,
    maybe_lse: Optional[torch.Tensor],
    kv_layout_code: int,
    window_left: int,
    *args,
    cuda_stream: int) -> None:
    pass
```

For default attention, `*args` is expanded to `maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta`.
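Below is a minimal sketch of how the plan/run split fits together for batch decode, assuming a small hypothetical batch. Buffer sizes, tensor shapes, and the `_kernels` module handle are illustrative assumptions, not prescribed by this PR:

```python
import torch

# Hypothetical paged-KV batch: 4 requests, 12 KV pages total, page_size 16.
batch_size, page_size, num_qo_heads, num_kv_heads, head_dim = 4, 16, 32, 8, 128
device = "cuda"
float_workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device)
int_workspace = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
pinned_int_workspace = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, pin_memory=True)

paged_kv_indptr = torch.tensor([0, 3, 7, 8, 12], dtype=torch.int32, device=device)
paged_kv_indices = torch.arange(12, dtype=torch.int32, device=device)
paged_kv_last_page_len = torch.tensor([5, 16, 3, 10], dtype=torch.int32, device=device)

q = torch.randn(batch_size, num_qo_heads, head_dim, dtype=torch.float16, device=device)
paged_k_cache = torch.randn(12, page_size, num_kv_heads, head_dim,
                            dtype=torch.float16, device=device)
paged_v_cache = torch.randn_like(paged_k_cache)
o = torch.empty_like(q)
stream = torch.cuda.current_stream().cuda_stream

# Plan once per batch layout; plan_info_vec is opaque scheduling metadata.
# `_kernels` stands in for the compiled PyTorch-binding module.
plan_info_vec = _kernels.batch_decode_with_paged_kv_cache_plan(
    float_workspace, int_workspace, pinned_int_workspace, paged_kv_indptr,
    batch_size, num_qo_heads, num_kv_heads, page_size,
    False, False, head_dim,                    # enable_cuda_graph, use_logits_soft_cap
    torch.empty(0, dtype=torch.float16),       # empty_q_data (conveys the query dtype)
    torch.empty(0, dtype=torch.float16),       # empty_kv_data (conveys the KV dtype)
    stream,
)

# Run forwards plan_info_vec; for default attention the trailing *args are
# maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta.
_kernels.batch_decode_with_paged_kv_cache_run(
    float_workspace, int_workspace, plan_info_vec,
    q, paged_k_cache, paged_v_cache,
    paged_kv_indptr, paged_kv_indices, paged_kv_last_page_len,
    o, None,                                   # maybe_lse
    0, -1,                                     # kv_layout_code (NHD), window_left
    None, 0.0, head_dim ** -0.5, 1.0, 1e4,     # *args for default attention
    cuda_stream=stream,
)
```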
#### Prefill C++ interface:

```cpp
#define AOT_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_custom_mask, std::optional<at::Tensor> maybe_mask_indptr, std::optional<at::Tensor> maybe_alibi_slopes, \
  float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta

std::vector<int64_t> BatchPrefillWithKVCachePlan(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    at::Tensor page_locked_int_workspace_buffer, at::Tensor qo_indptr, at::Tensor kv_indptr,
    at::Tensor kv_len_arr, unsigned total_num_rows, unsigned int batch_size,
    unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int page_size,
    bool enable_cuda_graph, unsigned int head_dim, bool causal, int64_t cuda_stream);

void BatchPrefillWithRaggedKVCacheRun(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor k, at::Tensor v,
    at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor o,
    std::optional<at::Tensor> maybe_lse, unsigned int mask_mode_code, unsigned int layout,
    int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);

void BatchPrefillWithPagedKVCacheRun(
    at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer,
    std::vector<int64_t> plan_info_vec, at::Tensor q, at::Tensor paged_k_cache,
    at::Tensor paged_v_cache, at::Tensor qo_indptr, at::Tensor paged_kv_indptr,
    at::Tensor paged_kv_indices, at::Tensor paged_kv_last_page_len, at::Tensor o,
    std::optional<at::Tensor> maybe_lse, unsigned int mask_mode_code, unsigned int layout,
    int32_t window_left ADDITIONAL_FUNC_PARAMS, int64_t cuda_stream);
```

#### Prefill python interface:

```python
def batch_prefill_with_kv_cache_plan(
    float_workspace_buffer: torch.Tensor,
    int_workspace_buffer: torch.Tensor,
    page_locked_int_workspace_buffer: torch.Tensor,
    qo_indptr: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_len_arr: torch.Tensor,
    total_num_rows: int,
    batch_size: int,
    num_qo_heads: int,
    num_kv_heads: int,
    page_size: int,
    enable_cuda_graph: bool,
    head_dim: int,
    causal: bool,
    cuda_stream: int) -> List[int]:
    pass

def batch_prefill_with_ragged_kv_cache_jit_run(
    float_workspace_buffer: torch.Tensor,
    int_workspace_buffer: torch.Tensor,
    plan_info_vec: List[int],
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    qo_indptr: torch.Tensor,
    kv_indptr: torch.Tensor,
    o: torch.Tensor,
    maybe_lse: Optional[torch.Tensor],
    mask_mode_code: int,
    layout: int,
    window_left: int,
    *args,
    cuda_stream: int) -> None:
    pass

def batch_prefill_with_paged_kv_cache_jit_run(
    float_workspace_buffer: torch.Tensor,
    int_workspace_buffer: torch.Tensor,
    plan_info_vec: List[int],
    q: torch.Tensor,
    paged_k_cache: torch.Tensor,
    paged_v_cache: torch.Tensor,
    qo_indptr: torch.Tensor,
    paged_kv_indptr: torch.Tensor,
    paged_kv_indices: torch.Tensor,
    paged_kv_last_page_len: torch.Tensor,
    o: torch.Tensor,
    maybe_lse: Optional[torch.Tensor],
    mask_mode_code: int,
    layout: int,
    window_left: int,
    *args,
    cuda_stream: int) -> None:
    pass
```

For default attention, `*args` is expanded to `maybe_custom_mask, maybe_mask_indptr, maybe_alibi_slopes, logits_soft_cap, sm_scale, rope_scale, rope_theta` for the fa2 template, and `logits_soft_cap, sm_scale` for the fa3 template.
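Since the fa2 and fa3 templates expect different trailing arguments, a wrapper has to assemble `*args` accordingly. The helper below is a hypothetical illustration of that dispatch; the function name and default values are not part of this PR:

```python
from typing import Optional, Tuple
import torch

def make_prefill_extra_args(
    use_fa3: bool,
    sm_scale: float,
    logits_soft_cap: float = 0.0,
    maybe_custom_mask: Optional[torch.Tensor] = None,
    maybe_mask_indptr: Optional[torch.Tensor] = None,
    maybe_alibi_slopes: Optional[torch.Tensor] = None,
    rope_scale: float = 1.0,
    rope_theta: float = 1e4,
) -> Tuple:
    """Build the trailing *args for the batch prefill run bindings (default attention)."""
    if use_fa3:
        # fa3 template: only the two scalars are additional parameters.
        return (logits_soft_cap, sm_scale)
    # fa2 template: optional tensors first (None when unused), then the four scalars.
    return (
        maybe_custom_mask,
        maybe_mask_indptr,
        maybe_alibi_slopes,
        logits_soft_cap,
        sm_scale,
        rope_scale,
        rope_theta,
    )

# e.g. run_fn(..., window_left, *make_prefill_extra_args(use_fa3=False, sm_scale=128 ** -0.5),
#             cuda_stream=0)
```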
1 parent 4e8eb18 commit 93e1a26

115 files changed, +5223 -6196 lines changed


CMakeLists.txt

+17-23
```diff
@@ -66,7 +66,7 @@ flashinfer_option(FLASHINFER_TVM_SOURCE_DIR
 flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256)
 flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0
                   1 2)
-flashinfer_option(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS
+flashinfer_option(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS
                   "QK reductions to enable" "false" "true")
 flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2)
 
@@ -126,34 +126,28 @@ endif(FLASHINFER_ENABLE_BF16)
 # generate kernel inst
 set(HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS})
 set(POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES})
-set(ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS})
+set(USE_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS})
 set(MASK_MODES ${FLASHINFER_GEN_MASK_MODES})
 
 # log options
 message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
 message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
-message(
-  STATUS "FLASHINFER_ALLOW_FP16_QK_REDUCTIONS=${ALLOW_FP16_QK_REDUCTIONS}")
+message(STATUS "FLASHINFER_USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS}")
 message(STATUS "FLASHINFER_MASK_MODES=${MASK_MODES}")
 
 file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
 
 set(AOT_GENERATE_COMMAND
-    ${Python3_EXECUTABLE}
-    -m aot_build_utils.generate
-    --path ${PROJECT_SOURCE_DIR}/src/generated
-    --head_dims ${HEAD_DIMS}
-    --pos_encoding_modes ${POS_ENCODING_MODES}
-    --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS}
-    --mask_modes ${MASK_MODES}
-    --enable_f16 ${FLASHINFER_ENABLE_F16}
-    --enable_bf16 ${FLASHINFER_ENABLE_BF16}
-    --enable_fp8_e4m3 ${FLASHINFER_ENABLE_FP8_E4M3}
-    --enable_fp8_e5m2 ${FLASHINFER_ENABLE_FP8_E5M2})
-
-execute_process(
-  COMMAND ${AOT_GENERATE_COMMAND}
-  WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
+    ${Python3_EXECUTABLE} -m aot_build_utils.generate --path
+    ${PROJECT_SOURCE_DIR}/src/generated --head_dims ${HEAD_DIMS}
+    --pos_encoding_modes ${POS_ENCODING_MODES} --use_fp16_qk_reductions
+    ${USE_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} --enable_f16
+    ${FLASHINFER_ENABLE_F16} --enable_bf16 ${FLASHINFER_ENABLE_BF16}
+    --enable_fp8_e4m3 ${FLASHINFER_ENABLE_FP8_E4M3} --enable_fp8_e5m2
+    ${FLASHINFER_ENABLE_FP8_E5M2})
+
+execute_process(COMMAND ${AOT_GENERATE_COMMAND}
+                WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
 
 file(GLOB_RECURSE FLASHINFER_GENERATORS
      ${PROJECT_SOURCE_DIR}/aot_build_utils/*.py)
@@ -175,13 +169,13 @@ add_custom_target(dispatch_inc DEPENDS ${DISPATCH_INC_FILE})
 
 add_library(decode_kernels STATIC ${DECODE_KERNELS_SRCS})
 target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
-target_compile_options(decode_kernels PRIVATE
-  -Xcompiler=-fPIC --fatbin-options -compress-all)
+target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options
+                                              -compress-all)
 
 add_library(prefill_kernels STATIC ${PREFILL_KERNELS_SRCS})
 target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
-target_compile_options(prefill_kernels PRIVATE
-  -Xcompiler=-fPIC --fatbin-options -compress-all)
+target_compile_options(prefill_kernels PRIVATE -Xcompiler=-fPIC
+                                               --fatbin-options -compress-all)
 
 if(FLASHINFER_DECODE)
   message(STATUS "Compile single decode kernel benchmarks.")
```
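Outside of CMake, the same generation step can be driven by invoking the module directly. The sketch below mirrors the AOT_GENERATE_COMMAND above with illustrative option values (run from the repository root); the values are examples, not project defaults:

```python
import subprocess
import sys

# Mirrors the CMake AOT_GENERATE_COMMAND, using the renamed
# --use_fp16_qk_reductions and --enable_f16 flags.
subprocess.run(
    [
        sys.executable, "-m", "aot_build_utils.generate",
        "--path", "src/generated",
        "--head_dims", "64", "128", "256",
        "--pos_encoding_modes", "0",
        "--use_fp16_qk_reductions", "false",
        "--mask_modes", "0", "1", "2",
        "--enable_f16", "true",
        "--enable_bf16", "true",
        "--enable_fp8_e4m3", "true",
        "--enable_fp8_e5m2", "true",
    ],
    check=True,
)
```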

aot_build_utils/generate.py

+22-16
```diff
@@ -20,6 +20,7 @@
 from typing import List
 
 from . import (
+    generate_aot_default_additional_params_header,
     generate_batch_paged_decode_inst,
     generate_batch_paged_prefill_inst,
     generate_batch_ragged_prefill_inst,
@@ -38,7 +39,7 @@ def write_if_different(path: Path, content: str) -> None:
     path: Path = args.path
     head_dims: List[int] = args.head_dims
     pos_encoding_modes: List[int] = args.pos_encoding_modes
-    allow_fp16_qk_reductions: List[int] = args.allow_fp16_qk_reductions
+    use_fp16_qk_reductions: List[int] = args.use_fp16_qk_reductions
     mask_modes: List[int] = args.mask_modes
     enable_f16: bool = args.enable_f16
     enable_bf16: bool = args.enable_bf16
@@ -54,12 +55,17 @@ def write_if_different(path: Path, content: str) -> None:
             argparse.Namespace(
                 head_dims=head_dims,
                 pos_encoding_modes=pos_encoding_modes,
-                allow_fp16_qk_reductions=allow_fp16_qk_reductions,
+                use_fp16_qk_reductions=use_fp16_qk_reductions,
                 mask_modes=mask_modes,
             )
         ),
     )
 
+    write_if_different(
+        path / "aot_default_additional_params.h",
+        generate_aot_default_additional_params_header.get_aot_default_additional_params_header_str(),
+    )
+
     idtypes = ["i32"]
     prefill_dtypes = []
     decode_dtypes = []
@@ -150,22 +156,22 @@ def write_if_different(path: Path, content: str) -> None:
     for (
         head_dim,
         pos_encoding_mode,
-        allow_fp16_qk_reduction,
+        use_fp16_qk_reduction,
         mask_mode,
     ) in product(
         head_dims,
         pos_encoding_modes,
-        allow_fp16_qk_reductions,
+        use_fp16_qk_reductions,
         mask_modes,
     ):
         for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
             product(prefill_dtypes, fp8_dtypes)
         ):
-            fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
+            fname = f"single_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}.cu"
             content = generate_single_prefill_inst.get_cu_file_str(
                 head_dim,
                 pos_encoding_mode,
-                allow_fp16_qk_reduction,
+                use_fp16_qk_reduction,
                 mask_mode,
                 dtype_q,  # dtype_q
                 dtype_kv,  # dtype_kv
@@ -184,7 +190,7 @@ def write_if_different(path: Path, content: str) -> None:
                 f"posenc_{pos_encoding_mode}_"
                 f"use_swa_{use_sliding_window}_"
                 f"use_logits_cap_{use_logits_soft_cap}_"
-                f"f16qk_{bool(allow_fp16_qk_reduction)}"
+                f"f16qk_{bool(use_fp16_qk_reduction)}"
             )
             write_if_different(path / fname, content)
 
@@ -193,24 +199,24 @@ def write_if_different(path: Path, content: str) -> None:
     for (
         head_dim,
         pos_encoding_mode,
-        allow_fp16_qk_reduction,
+        use_fp16_qk_reduction,
         mask_mode,
         idtype,
     ) in product(
         head_dims,
         pos_encoding_modes,
-        allow_fp16_qk_reductions,
+        use_fp16_qk_reductions,
         mask_modes,
         idtypes,
    ):
         for dtype_q, dtype_kv in list(zip(prefill_dtypes, prefill_dtypes)) + list(
             product(prefill_dtypes, fp8_dtypes)
         ):
-            fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
+            fname = f"batch_paged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
             content = generate_batch_paged_prefill_inst.get_cu_file_str(
                 head_dim,
                 pos_encoding_mode,
-                allow_fp16_qk_reduction,
+                use_fp16_qk_reduction,
                 mask_mode,
                 dtype_q,  # dtype_q
                 dtype_kv,  # dtype_kv
@@ -219,11 +225,11 @@ def write_if_different(path: Path, content: str) -> None:
             )
             write_if_different(path / fname, content)
 
-            fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{allow_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
+            fname = f"batch_ragged_prefill_head_{head_dim}_posenc_{pos_encoding_mode}_fp16qkred_{use_fp16_qk_reduction}_mask_{mask_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_q}_idtype_{idtype}.cu"
             content = generate_batch_ragged_prefill_inst.get_cu_file_str(
                 head_dim,
                 pos_encoding_mode,
-                allow_fp16_qk_reduction,
+                use_fp16_qk_reduction,
                 mask_mode,
                 dtype_q,  # dtype_q
                 dtype_kv,  # dtype_kv
@@ -246,7 +252,7 @@ def write_if_different(path: Path, content: str) -> None:
                 f"posenc_{pos_encoding_mode}_"
                 f"use_swa_{sliding_window}_"
                 f"use_logits_cap_{logits_soft_cap}_"
-                f"f16qk_{bool(allow_fp16_qk_reduction)}"
+                f"f16qk_{bool(use_fp16_qk_reduction)}"
             )
 
     return (
@@ -273,7 +279,7 @@ def write_if_different(path: Path, content: str) -> None:
         help="Position encoding modes",
     )
     parser.add_argument(
-        "--allow_fp16_qk_reductions",
+        "--use_fp16_qk_reductions",
         type=lambda x: x if isinstance(x, int) else int(x.lower() == "true"),
         required=True,
         nargs="+",
@@ -287,7 +293,7 @@ def write_if_different(path: Path, content: str) -> None:
         help="Mask modes",
     )
     parser.add_argument(
-        "--enable_fp16",
+        "--enable_f16",
        type=lambda x: x if isinstance(x, int) else x.lower() == "true",
        required=True,
        nargs="+",
```
aot_build_utils/generate_aot_default_additional_params_header.py

+148 (new file)

```python
"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""


def generate_macro_entry(
    macro_prefix,
    additional_tensor_names,
    additional_tensor_dtypes,
    additional_scalar_names,
    additional_scalar_dtypes,
    is_sm90_template: bool = False,
) -> str:
    # NOTE(Zihao): mostly copy-paste from generate_additional_params in flashinfer.jit.attention.py
    additional_func_params = "".join(
        [
            (
                f", std::optional<at::Tensor> {var}"
                if var.startswith("maybe")
                else f", at::Tensor {var}"
            )
            for var in additional_tensor_names
        ]
        + [
            f", {dtype} {var}"
            for dtype, var in zip(additional_scalar_dtypes, additional_scalar_names)
        ]
    )
    if is_sm90_template:
        additional_params_setter = " \\\n".join(
            [
                (
                    f"params.additional_params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;"
                    if var.startswith("maybe")
                    else f"params.additional_params.{var} = static_cast<{dtype}*>({var}.data_ptr());"
                )
                for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names)
            ]
            + [
                f"params.additional_params.{var} = {var};"
                for var in additional_scalar_names
            ]
        )
    else:
        additional_params_setter = " \\\n".join(
            [
                (
                    f"params.{var} = {var} ? static_cast<{dtype}*>({var}->data_ptr()): nullptr;"
                    if var.startswith("maybe")
                    else f"params.{var} = static_cast<{dtype}*>({var}.data_ptr());"
                )
                for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names)
            ]
            + [f"params.{var} = {var};" for var in additional_scalar_names]
        )
    return f"""#define {macro_prefix}_ADDITIONAL_FUNC_PARAMS {additional_func_params}

#define {macro_prefix}_ADDITIONAL_PARAMS_SETTER {additional_params_setter}

"""


def get_aot_default_additional_params_header_str() -> str:
    ret = ""

    ret += generate_macro_entry(
        "SINGLE_DECODE",
        ["maybe_alibi_slopes"],  # additional_tensor_names
        ["float"],  # additional_tensor_dtypes
        [
            "logits_soft_cap",
            "sm_scale",
            "rope_rcp_scale",
            "rope_rcp_theta",
        ],  # additional_scalar_names
        ["float", "float", "float", "float"],  # additional_scalar_dtypes
    )

    ret += generate_macro_entry(
        "SINGLE_PREFILL",
        ["maybe_custom_mask", "maybe_alibi_slopes"],
        ["uint8_t", "float"],
        [
            "logits_soft_cap",
            "sm_scale",
            "rope_rcp_scale",
            "rope_rcp_theta",
        ],
        ["float", "float", "float", "float"],
    )

    ret += generate_macro_entry(
        "SINGLE_PREFILL_SM90",
        [],
        [],
        ["logits_soft_cap", "sm_scale"],
        ["float", "float"],
        is_sm90_template=True,
    )

    ret += generate_macro_entry(
        "BATCH_DECODE",
        ["maybe_alibi_slopes"],  # additional_tensor_names
        ["float"],  # additional_tensor_dtypes
        [
            "logits_soft_cap",
            "sm_scale",
            "rope_rcp_scale",
            "rope_rcp_theta",
        ],  # additional_scalar_names
        ["float", "float", "float", "float"],  # additional_scalar_dtypes
    )

    ret += generate_macro_entry(
        "BATCH_PREFILL",
        ["maybe_custom_mask", "maybe_mask_indptr", "maybe_alibi_slopes"],
        ["uint8_t", "int32_t", "float"],
        [
            "logits_soft_cap",
            "sm_scale",
            "rope_rcp_scale",
            "rope_rcp_theta",
        ],
        ["float", "float", "float", "float"],
    )

    ret += generate_macro_entry(
        "BATCH_PREFILL_SM90",
        [],
        [],
        ["logits_soft_cap", "sm_scale"],
        ["float", "float"],
        is_sm90_template=True,
    )

    return ret
```
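A small usage sketch of the generator above (assuming the module lives at `aot_build_utils/generate_aot_default_additional_params_header.py`, as the import in `generate.py` suggests); the output shown in the comments is paraphrased from the f-string template, so exact whitespace may differ:

```python
from aot_build_utils import generate_aot_default_additional_params_header as gen

entry = gen.generate_macro_entry(
    "SINGLE_DECODE",
    ["maybe_alibi_slopes"],
    ["float"],
    ["logits_soft_cap", "sm_scale", "rope_rcp_scale", "rope_rcp_theta"],
    ["float", "float", "float", "float"],
)
print(entry)
# Roughly:
#   #define SINGLE_DECODE_ADDITIONAL_FUNC_PARAMS , std::optional<at::Tensor> maybe_alibi_slopes,
#       float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta
#
#   #define SINGLE_DECODE_ADDITIONAL_PARAMS_SETTER \
#       params.maybe_alibi_slopes = maybe_alibi_slopes ? ... : nullptr; \
#       params.logits_soft_cap = logits_soft_cap; ...
```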
