20
20
from typing import List
21
21
22
22
from . import (
23
+ generate_aot_default_additional_params_header ,
23
24
generate_batch_paged_decode_inst ,
24
25
generate_batch_paged_prefill_inst ,
25
26
generate_batch_ragged_prefill_inst ,
@@ -38,7 +39,7 @@ def write_if_different(path: Path, content: str) -> None:
38
39
path : Path = args .path
39
40
head_dims : List [int ] = args .head_dims
40
41
pos_encoding_modes : List [int ] = args .pos_encoding_modes
41
- allow_fp16_qk_reductions : List [int ] = args .allow_fp16_qk_reductions
42
+ use_fp16_qk_reductions : List [int ] = args .use_fp16_qk_reductions
42
43
mask_modes : List [int ] = args .mask_modes
43
44
enable_f16 : bool = args .enable_f16
44
45
enable_bf16 : bool = args .enable_bf16
@@ -54,12 +55,17 @@ def write_if_different(path: Path, content: str) -> None:
54
55
argparse .Namespace (
55
56
head_dims = head_dims ,
56
57
pos_encoding_modes = pos_encoding_modes ,
57
- allow_fp16_qk_reductions = allow_fp16_qk_reductions ,
58
+ use_fp16_qk_reductions = use_fp16_qk_reductions ,
58
59
mask_modes = mask_modes ,
59
60
)
60
61
),
61
62
)
62
63
64
+ write_if_different (
65
+ path / "aot_default_additional_params.h" ,
66
+ generate_aot_default_additional_params_header .get_aot_default_additional_params_header_str (),
67
+ )
68
+
63
69
idtypes = ["i32" ]
64
70
prefill_dtypes = []
65
71
decode_dtypes = []
@@ -150,22 +156,22 @@ def write_if_different(path: Path, content: str) -> None:
150
156
for (
151
157
head_dim ,
152
158
pos_encoding_mode ,
153
- allow_fp16_qk_reduction ,
159
+ use_fp16_qk_reduction ,
154
160
mask_mode ,
155
161
) in product (
156
162
head_dims ,
157
163
pos_encoding_modes ,
158
- allow_fp16_qk_reductions ,
164
+ use_fp16_qk_reductions ,
159
165
mask_modes ,
160
166
):
161
167
for dtype_q , dtype_kv in list (zip (prefill_dtypes , prefill_dtypes )) + list (
162
168
product (prefill_dtypes , fp8_dtypes )
163
169
):
164
- fname = f"single_prefill_head_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ allow_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } .cu"
170
+ fname = f"single_prefill_head_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } .cu"
165
171
content = generate_single_prefill_inst .get_cu_file_str (
166
172
head_dim ,
167
173
pos_encoding_mode ,
168
- allow_fp16_qk_reduction ,
174
+ use_fp16_qk_reduction ,
169
175
mask_mode ,
170
176
dtype_q , # dtype_q
171
177
dtype_kv , # dtype_kv
@@ -184,7 +190,7 @@ def write_if_different(path: Path, content: str) -> None:
184
190
f"posenc_{ pos_encoding_mode } _"
185
191
f"use_swa_{ use_sliding_window } _"
186
192
f"use_logits_cap_{ use_logits_soft_cap } _"
187
- f"f16qk_{ bool (allow_fp16_qk_reduction )} "
193
+ f"f16qk_{ bool (use_fp16_qk_reduction )} "
188
194
)
189
195
write_if_different (path / fname , content )
190
196
@@ -193,24 +199,24 @@ def write_if_different(path: Path, content: str) -> None:
193
199
for (
194
200
head_dim ,
195
201
pos_encoding_mode ,
196
- allow_fp16_qk_reduction ,
202
+ use_fp16_qk_reduction ,
197
203
mask_mode ,
198
204
idtype ,
199
205
) in product (
200
206
head_dims ,
201
207
pos_encoding_modes ,
202
- allow_fp16_qk_reductions ,
208
+ use_fp16_qk_reductions ,
203
209
mask_modes ,
204
210
idtypes ,
205
211
):
206
212
for dtype_q , dtype_kv in list (zip (prefill_dtypes , prefill_dtypes )) + list (
207
213
product (prefill_dtypes , fp8_dtypes )
208
214
):
209
- fname = f"batch_paged_prefill_head_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ allow_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } .cu"
215
+ fname = f"batch_paged_prefill_head_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } .cu"
210
216
content = generate_batch_paged_prefill_inst .get_cu_file_str (
211
217
head_dim ,
212
218
pos_encoding_mode ,
213
- allow_fp16_qk_reduction ,
219
+ use_fp16_qk_reduction ,
214
220
mask_mode ,
215
221
dtype_q , # dtype_q
216
222
dtype_kv , # dtype_kv
@@ -219,11 +225,11 @@ def write_if_different(path: Path, content: str) -> None:
219
225
)
220
226
write_if_different (path / fname , content )
221
227
222
- fname = f"batch_ragged_prefill_head_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ allow_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } .cu"
228
+ fname = f"batch_ragged_prefill_head_{ head_dim } _posenc_{ pos_encoding_mode } _fp16qkred_{ use_fp16_qk_reduction } _mask_{ mask_mode } _dtypeq_{ dtype_q } _dtypekv_{ dtype_kv } _dtypeout_{ dtype_q } _idtype_{ idtype } .cu"
223
229
content = generate_batch_ragged_prefill_inst .get_cu_file_str (
224
230
head_dim ,
225
231
pos_encoding_mode ,
226
- allow_fp16_qk_reduction ,
232
+ use_fp16_qk_reduction ,
227
233
mask_mode ,
228
234
dtype_q , # dtype_q
229
235
dtype_kv , # dtype_kv
@@ -246,7 +252,7 @@ def write_if_different(path: Path, content: str) -> None:
246
252
f"posenc_{ pos_encoding_mode } _"
247
253
f"use_swa_{ sliding_window } _"
248
254
f"use_logits_cap_{ logits_soft_cap } _"
249
- f"f16qk_{ bool (allow_fp16_qk_reduction )} "
255
+ f"f16qk_{ bool (use_fp16_qk_reduction )} "
250
256
)
251
257
252
258
return (
@@ -273,7 +279,7 @@ def write_if_different(path: Path, content: str) -> None:
273
279
help = "Position encoding modes" ,
274
280
)
275
281
parser .add_argument (
276
- "--allow_fp16_qk_reductions " ,
282
+ "--use_fp16_qk_reductions " ,
277
283
type = lambda x : x if isinstance (x , int ) else int (x .lower () == "true" ),
278
284
required = True ,
279
285
nargs = "+" ,
@@ -287,7 +293,7 @@ def write_if_different(path: Path, content: str) -> None:
287
293
help = "Mask modes" ,
288
294
)
289
295
parser .add_argument (
290
- "--enable_fp16 " ,
296
+ "--enable_f16 " ,
291
297
type = lambda x : x if isinstance (x , int ) else x .lower () == "true" ,
292
298
required = True ,
293
299
nargs = "+" ,
0 commit comments