@@ -3,8 +3,8 @@ project(flashinfer CUDA CXX)
3
3
4
4
include (cmake/utils/Utils.cmake)
5
5
6
- set (CMAKE_CXX_STANDARD 17 )
7
- set (CMAKE_CUDA_STANDARD 17 )
6
+ set (CMAKE_CXX_STANDARD 20 )
7
+ set (CMAKE_CUDA_STANDARD 20 )
8
8
9
9
if (EXISTS ${CMAKE_BINARY_DIR} /config.cmake)
10
10
include (${CMAKE_BINARY_DIR} /config.cmake)
@@ -63,7 +63,7 @@ flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256)
63
63
flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0
64
64
1 2)
65
65
flashinfer_option(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS
66
- "QK reductions to enable" "false" "true" )
66
+ "QK reductions to enable" OFF )
67
67
flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2)
68
68
69
69
if (DEFINED FLASHINFER_CUDA_ARCHITECTURES)
@@ -125,25 +125,77 @@ set(POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES})
125
125
set (USE_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS} )
126
126
set (MASK_MODES ${FLASHINFER_GEN_MASK_MODES} )
127
127
128
+ set (SM90_ALLOWED_HEAD_DIMS "64,64" "128,128" "256,256" "192,128" )
129
+ set (HEAD_DIMS_SM90 "" )
130
+
131
+ foreach (DIM_VAL ${HEAD_DIMS} )
132
+ string (CONCAT TUPLE_VAL "${DIM_VAL} " "," "${DIM_VAL} " )
133
+ list (FIND SM90_ALLOWED_HEAD_DIMS ${TUPLE_VAL} RESULT)
134
+ if (NOT ${RESULT} EQUAL -1)
135
+ list (APPEND HEAD_DIMS_SM90 ${TUPLE_VAL} )
136
+ endif (NOT ${RESULT} EQUAL -1)
137
+ endforeach (DIM_VAL)
138
+
139
+ foreach (TUPLE_VAL ${SM90_ALLOWED_HEAD_DIMS} )
140
+ string (REPLACE "," ";" HEAD_DIMS_LIST ${TUPLE_VAL} )
141
+ list (GET HEAD_DIMS_LIST 0 K)
142
+ list (GET HEAD_DIMS_LIST 1 V)
143
+ if (NOT K EQUAL V)
144
+ list (APPEND HEAD_DIMS_SM90 ${TUPLE_VAL} )
145
+ endif (NOT K EQUAL V)
146
+ endforeach (TUPLE_VAL)
147
+
148
+ list (REMOVE_DUPLICATES HEAD_DIMS_SM90)
149
+
128
150
# log options
129
151
message (STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS} " )
130
152
message (STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES} " )
131
153
message (STATUS "FLASHINFER_USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS} " )
132
154
message (STATUS "FLASHINFER_MASK_MODES=${MASK_MODES} " )
133
155
156
+ # Log SM90_ALLOWED_HEAD_DIMS and HEAD_DIMS_SM90
157
+ message (STATUS "SM90_ALLOWED_HEAD_DIMS=${SM90_ALLOWED_HEAD_DIMS} " )
158
+ message (STATUS "HEAD_DIMS_SM90=${HEAD_DIMS_SM90} " )
159
+
160
+ set (GENERATED_SOURCE_DIR ${PROJECT_SOURCE_DIR} /src/generated )
134
161
file (MAKE_DIRECTORY ${PROJECT_SOURCE_DIR} /src/generated )
135
162
163
+ if (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
164
+ # ----------------------------- Dependencies -------------------------------#
165
+ include (FetchContent)
166
+
167
+ set (BOOST_ENABLE_CMAKE ON )
168
+ FetchContent_Declare(boost_math
169
+ GIT_REPOSITORY https://github.com/boostorg/math.git)
170
+ FetchContent_MakeAvailable(boost_math)
171
+ # --------------------------------------------------------------------------#
172
+ set (USE_FP16_QK_REDUCTIONS "true" )
173
+ message (STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS} " )
174
+ else (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
175
+ set (USE_FP16_QK_REDUCTIONS "false" )
176
+ message (STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS} " )
177
+ endif (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
178
+
136
179
set (AOT_GENERATE_COMMAND
137
180
${Python3_EXECUTABLE} -m aot_build_utils.generate --path
138
- ${PROJECT_SOURCE_DIR} /src/ generated --head_dims ${HEAD_DIMS}
139
- --pos_encoding_modes ${POS_ENCODING_MODES} --use_fp16_qk_reductions
140
- ${USE_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} --enable_f16
141
- ${FLASHINFER_ENABLE_F16} --enable_bf16 ${FLASHINFER_ENABLE_BF16}
142
- --enable_fp8_e4m3 ${FLASHINFER_ENABLE_FP8_E4M3} --enable_fp8_e5m2
181
+ ${GENERATED_SOURCE_DIR} --head_dims ${HEAD_DIMS} --pos_encoding_modes
182
+ ${POS_ENCODING_MODES} --use_fp16_qk_reductions ${USE_FP16_QK_REDUCTIONS}
183
+ --mask_modes ${MASK_MODES} --enable_f16 ${FLASHINFER_ENABLE_F16}
184
+ --enable_bf16 ${FLASHINFER_ENABLE_BF16} --enable_fp8_e4m3
185
+ ${FLASHINFER_ENABLE_FP8_E4M3} --enable_fp8_e5m2
143
186
${FLASHINFER_ENABLE_FP8_E5M2} )
144
187
188
+ set (AOT_GENERATE_DISPATCH_INC_COMMAND
189
+ ${Python3_EXECUTABLE} -m aot_build_utils.generate_dispatch_inc --path
190
+ "${GENERATED_SOURCE_DIR} /dispatch.inc" --head_dims ${HEAD_DIMS}
191
+ --head_dims_sm90 ${HEAD_DIMS_SM90} --pos_encoding_modes
192
+ ${POS_ENCODING_MODES} --use_fp16_qk_reductions ${USE_FP16_QK_REDUCTIONS}
193
+ --mask_modes ${MASK_MODES} )
194
+
145
195
execute_process (COMMAND ${AOT_GENERATE_COMMAND}
146
196
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} )
197
+ execute_process (COMMAND ${AOT_GENERATE_DISPATCH_INC_COMMAND}
198
+ WORKING_DIRECTORY ${PROJECT_SOURCE_DIR} )
147
199
148
200
file (GLOB_RECURSE FLASHINFER_GENERATORS
149
201
${PROJECT_SOURCE_DIR} /aot_build_utils/*.py)
@@ -157,21 +209,33 @@ file(GLOB_RECURSE DISPATCH_INC_FILE
157
209
add_custom_command (
158
210
OUTPUT ${DECODE_KERNELS_SRCS} ${PREFILL_KERNELS_SRCS} ${DISPATCH_INC_FILE}
159
211
COMMAND ${AOT_GENERATE_COMMAND}
212
+ COMMAND ${AOT_GENERATE_DISPATCH_INC_COMMAND}
160
213
DEPENDS ${FLASHINFER_GENERATORS}
161
214
WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
162
215
COMMENT "Generating kernel sources"
163
216
VERBATIM )
164
217
add_custom_target (dispatch_inc DEPENDS ${DISPATCH_INC_FILE} )
165
218
219
+ string (CONCAT CXX_FLAGS "-fpic " "-fPIC " )
220
+
221
+ string (CONCAT NVCC_FLAGS "-O3 " "--threads=1 " "-Xfatbin=-compress-all "
222
+ "-use_fast_math " "--expt-relaxed-constexpr " )
223
+
224
+ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_FLAGS} " )
225
+ set (CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS} " )
226
+
166
227
add_library (decode_kernels STATIC ${DECODE_KERNELS_SRCS} )
167
228
target_include_directories (decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR} )
168
- target_compile_options (decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options
169
- -compress-all )
229
+ if (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
230
+ target_link_libraries (decode_kernels PRIVATE Boost::math)
231
+ endif (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
170
232
171
233
add_library (prefill_kernels STATIC ${PREFILL_KERNELS_SRCS} )
172
234
target_include_directories (prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR} )
173
- target_compile_options (prefill_kernels PRIVATE -Xcompiler=-fPIC
174
- --fatbin-options -compress-all )
235
+ if (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
236
+ add_definitions (-DFP16_QK_REDUCTION_SUPPORTED)
237
+ target_link_libraries (prefill_kernels PRIVATE Boost::math)
238
+ endif (FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
175
239
176
240
if (FLASHINFER_DECODE)
177
241
message (STATUS "Compile single decode kernel benchmarks." )
0 commit comments