Skip to content

Commit 7304282

Browse files
authored
feat: support custom attention mask in prefill/append attention kernels (#266)
Some speculative decoding algorithms require tree attention, which could be supported via prefill/append attention kernels with custom attention mask. This PR supports this feature. Related issues: #152 # API Breaking Changes The `begin_forward` function in `BatchPrefillWithPagedKVCacheWrapper` now has an additional argument `page_size` to accommodate this new feature.
1 parent 08ab1c1 commit 7304282

22 files changed

+1048
-309
lines changed

Diff for: CMakeLists.txt

+12-12
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS})
8181
set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS})
8282
set (POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES})
8383
set (ALLOW_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS})
84-
set (CAUSALS ${FLASHINFER_GEN_CASUALS})
84+
set (MASK_MODES ${FLASHINFER_GEN_MASK_MODES})
8585
set (DECODE_DTYPES "f16")
8686
set (PREFILL_DTYPES "f16")
8787
set (DECODE_F8_DTYPES)
@@ -104,14 +104,14 @@ message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
104104
message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}")
105105
message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
106106
message(STATUS "FLASHINFER_ALLOW_FP16_QK_REDUCTIONS=${ALLOW_FP16_QK_REDUCTIONS}")
107-
message(STATUS "FLASHINFER_CAUSALS=${CAUSALS}")
107+
message(STATUS "FLASHINFER_MASK_MODES=${MASK_MODES}")
108108

109109
file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
110110

111111
set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc)
112112
add_custom_command(
113113
OUTPUT ${dispatch_inc_file}
114-
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --causals ${CAUSALS}
114+
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --group_sizes ${GROUP_SIZES} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
115115
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py
116116
COMMENT "Generating additional source file ${generated_dispatch_inc}"
117117
VERBATIM
@@ -225,9 +225,9 @@ foreach(group_size IN LISTS GROUP_SIZES)
225225
foreach(kv_layout IN LISTS KV_LAYOUTS)
226226
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
227227
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
228-
foreach(causal IN LISTS CAUSALS)
228+
foreach(mask_mode IN LISTS MASK_MODES)
229229
foreach(dtype IN LISTS PREFILL_DTYPES)
230-
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}.cu)
230+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/single_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}.cu)
231231
add_custom_command(
232232
OUTPUT ${generated_kernel_src}
233233
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_single_prefill_inst.py ${generated_kernel_src}
@@ -237,7 +237,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
237237
)
238238
list(APPEND single_prefill_kernels_src ${generated_kernel_src})
239239
endforeach(dtype)
240-
endforeach(causal)
240+
endforeach(mask_mode)
241241
endforeach(allow_fp16_qk_reduction)
242242
endforeach(pos_encoding_mode)
243243
endforeach(kv_layout)
@@ -251,10 +251,10 @@ foreach(group_size IN LISTS GROUP_SIZES)
251251
foreach(kv_layout IN LISTS KV_LAYOUTS)
252252
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
253253
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
254-
foreach(causal IN LISTS CAUSALS)
254+
foreach(mask_mode IN LISTS MASK_MODES)
255255
foreach(dtype IN LISTS PREFILL_DTYPES)
256256
foreach(idtype IN LISTS IDTYPES)
257-
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
257+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_group_${group_size}_page_${page_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
258258
add_custom_command(
259259
OUTPUT ${generated_kernel_src}
260260
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
@@ -265,7 +265,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
265265
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
266266
endforeach(idtype)
267267
endforeach(dtype)
268-
endforeach(causal)
268+
endforeach(mask_mode)
269269
endforeach(allow_fp16_qk_reduction)
270270
endforeach(pos_encoding_mode)
271271
endforeach(kv_layout)
@@ -279,10 +279,10 @@ foreach(group_size IN LISTS GROUP_SIZES)
279279
foreach(kv_layout IN LISTS KV_LAYOUTS)
280280
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
281281
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
282-
foreach(causal IN LISTS CAUSALS)
282+
foreach(mask_mode IN LISTS MASK_MODES)
283283
foreach(dtype IN LISTS PREFILL_DTYPES)
284284
foreach(idtype IN LISTS IDTYPES)
285-
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_causal_${causal}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
285+
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_ragged_prefill_group_${group_size}_head_${head_dim}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
286286
add_custom_command(
287287
OUTPUT ${generated_kernel_src}
288288
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_ragged_prefill_inst.py ${generated_kernel_src}
@@ -293,7 +293,7 @@ foreach(group_size IN LISTS GROUP_SIZES)
293293
list(APPEND batch_ragged_prefill_kernels_src ${generated_kernel_src})
294294
endforeach(idtype)
295295
endforeach(dtype)
296-
endforeach(causal)
296+
endforeach(mask_mode)
297297
endforeach(allow_fp16_qk_reduction)
298298
endforeach(pos_encoding_mode)
299299
endforeach(kv_layout)

Diff for: cmake/config.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
2424
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
2525
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
2626
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
27-
set(FLASHINFER_GEN_CASUALS "false" "true")
27+
set(FLASHINFER_GEN_MASK_MODES 0 1)
2828

2929
# Set target cuda architectures for tests/benchmarks, defaults to native.
3030
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.

Diff for: include/flashinfer/attention/mask.cuh

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_ATTENTION_MASK_CUH_
17+
#define FLASHINFER_ATTENTION_MASK_CUH_
18+
19+
namespace flashinfer {
20+
21+
enum class MaskMode {
22+
kNone = 0U, // No mask
23+
kCausal = 1U, // Causal mask
24+
kCustom = 2U, // Custom mask
25+
};
26+
27+
} // namespace flashinfer
28+
29+
#endif // FLASHINFER_ATTENTION_MASK_CUH_

0 commit comments

Comments
 (0)