@@ -81,7 +81,7 @@ set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS})
81
81
set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS} )
82
82
set (POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES} )
83
83
set (ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS} )
84
- set (CAUSALS ${FLASHINFER_GEN_CASUALS } )
84
+ set (MASK_MODES ${FLASHINFER_GEN_MASK_MODES } )
85
85
set (DECODE_DTYPES "f16" )
86
86
set (PREFILL_DTYPES "f16" )
87
87
set (DECODE_F8_DTYPES)
@@ -104,14 +104,14 @@ message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
104
104
message (STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS} " )
105
105
message (STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES} " )
106
106
message (STATUS "FLASHINFER_ALLOW_FP16_QK_REDUCTIONS=${ALLOW_FP16_QK_REDUCTIONS} " )
107
- message (STATUS "FLASHINFER_CAUSALS =${CAUSALS } " )
107
+ message (STATUS "FLASHINFER_MASK_MODES =${MASK_MODES } " )
108
108
109
109
file (MAKE_DIRECTORY ${PROJECT_SOURCE_DIR} /src/generated )
110
110
111
111
set (dispatch_inc_file ${PROJECT_SOURCE_DIR} /src/dispatch.inc)
112
112
add_custom_command (
113
113
OUTPUT ${dispatch_inc_file}
114
- COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR} /src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --causals ${CAUSALS }
114
+ COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR} /src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES }
115
115
DEPENDS ${PROJECT_SOURCE_DIR} /python/generate_dispatch_inc.py
116
116
COMMENT "Generating additional source file ${generated_dispatch_inc} "
117
117
VERBATIM
@@ -225,9 +225,9 @@ foreach(group_size IN LISTS GROUP_SIZES)
225
225
foreach (kv_layout IN LISTS KV_LAYOUTS)
226
226
foreach (pos_encoding_mode IN LISTS POS_ENCODING_MODES)
227
227
foreach (allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
228
- foreach (causal IN LISTS CAUSALS )
228
+ foreach (mask_mode IN LISTS MASK_MODES )
229
229
foreach (dtype IN LISTS PREFILL_DTYPES)
230
- set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated /single_prefill_group_${group_size} _head_${head_dim} _layout_${kv_layout} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _causal_ ${causal } _dtypein_${dtype} _dtypeout_${dtype} .cu)
230
+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated /single_prefill_group_${group_size} _head_${head_dim} _layout_${kv_layout} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_ ${mask_mode } _dtypein_${dtype} _dtypeout_${dtype} .cu)
231
231
add_custom_command (
232
232
OUTPUT ${generated_kernel_src}
233
233
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_single_prefill_inst.py ${generated_kernel_src}
@@ -237,7 +237,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
237
237
)
238
238
list (APPEND single_prefill_kernels_src ${generated_kernel_src} )
239
239
endforeach (dtype)
240
- endforeach (causal )
240
+ endforeach (mask_mode )
241
241
endforeach (allow_fp16_qk_reduction)
242
242
endforeach (pos_encoding_mode)
243
243
endforeach (kv_layout)
@@ -251,10 +251,10 @@ foreach(group_size IN LISTS GROUP_SIZES)
251
251
foreach (kv_layout IN LISTS KV_LAYOUTS)
252
252
foreach (pos_encoding_mode IN LISTS POS_ENCODING_MODES)
253
253
foreach (allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
254
- foreach (causal IN LISTS CAUSALS )
254
+ foreach (mask_mode IN LISTS MASK_MODES )
255
255
foreach (dtype IN LISTS PREFILL_DTYPES)
256
256
foreach (idtype IN LISTS IDTYPES)
257
- set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated /batch_paged_prefill_group_${group_size} _page_${page_size} _head_${head_dim} _layout_${kv_layout} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _causal_ ${causal } _dtypein_${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
257
+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated /batch_paged_prefill_group_${group_size} _page_${page_size} _head_${head_dim} _layout_${kv_layout} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_ ${mask_mode } _dtypein_${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
258
258
add_custom_command (
259
259
OUTPUT ${generated_kernel_src}
260
260
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
@@ -265,7 +265,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
265
265
list (APPEND batch_paged_prefill_kernels_src ${generated_kernel_src} )
266
266
endforeach (idtype)
267
267
endforeach (dtype)
268
- endforeach (causal )
268
+ endforeach (mask_mode )
269
269
endforeach (allow_fp16_qk_reduction)
270
270
endforeach (pos_encoding_mode)
271
271
endforeach (kv_layout)
@@ -279,10 +279,10 @@ foreach(group_size IN LISTS GROUP_SIZES)
279
279
foreach (kv_layout IN LISTS KV_LAYOUTS)
280
280
foreach (pos_encoding_mode IN LISTS POS_ENCODING_MODES)
281
281
foreach (allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
282
- foreach (causal IN LISTS CAUSALS )
282
+ foreach (mask_mode IN LISTS MASK_MODES )
283
283
foreach (dtype IN LISTS PREFILL_DTYPES)
284
284
foreach (idtype IN LISTS IDTYPES)
285
- set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated /batch_ragged_prefill_group_${group_size} _head_${head_dim} _layout_${kv_layout} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _causal_ ${causal } _dtypein_${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
285
+ set (generated_kernel_src ${PROJECT_SOURCE_DIR} /src/generated /batch_ragged_prefill_group_${group_size} _head_${head_dim} _layout_${kv_layout} _posenc_${pos_encoding_mode} _fp16qkred_${allow_fp16_qk_reduction} _mask_ ${mask_mode } _dtypein_${dtype} _dtypeout_${dtype} _idtype_${idtype} .cu)
286
286
add_custom_command (
287
287
OUTPUT ${generated_kernel_src}
288
288
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR} /python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
@@ -293,7 +293,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
293
293
list (APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src} )
294
294
endforeach (idtype)
295
295
endforeach (dtype)
296
- endforeach (causal )
296
+ endforeach (mask_mode )
297
297
endforeach (allow_fp16_qk_reduction)
298
298
endforeach (pos_encoding_mode)
299
299
endforeach (kv_layout)
0 commit comments