Commit 3a69560
bugfix: Fix compilation with FP16_QK_REDUCTION enabled. (#962)
As described in #806 and #936, setting the CMake build flag `FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS` to "true" causes a build failure because `cuda_fp16.h` does not support a `constexpr` cast from the `__half` type to `float`. Note that this is not only a CMake/C++ configuration problem: the same failure is triggered in the flashinfer JIT compilation path, as reported in #915. This PR fixes #806 and #936 by adding a modified version of the FP16 header from the [FP16 library](https://github.com/Maratyszcza/FP16) that provides `constexpr` versions of the conversion functions. Making the conversion functions `constexpr` relies on `std::bit_cast`, which is the reason the required standard is bumped to C++20. With these changes I am able to build the C++ API with `FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS` both ON and OFF. Fixes #936 Fixes #806
1 parent bc81a59 commit 3a69560
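
For context, the core of the fix is that bit-level reinterpretation via `std::bit_cast` is a constant expression in C++20, whereas the `__half`-to-`float` conversion in `cuda_fp16.h` is not `constexpr`. The sketch below is not the vendored FP16 header from this PR, just a minimal illustration of a `constexpr` IEEE binary16-to-binary32 conversion built on `std::bit_cast`; the function name and the simplified subnormal handling are assumptions for this example.

```cpp
#include <bit>
#include <cstdint>
#include <limits>

// Minimal sketch (not the header added in this PR): a constexpr IEEE binary16
// -> float conversion. std::bit_cast (C++20) is what makes the bit-level
// reinterpretation legal in a constant expression.
constexpr float fp16_bits_to_fp32(std::uint16_t h) {
  const std::uint32_t sign = static_cast<std::uint32_t>(h >> 15) << 31;
  const std::uint32_t exp = (h >> 10) & 0x1Fu;
  const std::uint32_t man = h & 0x3FFu;

  if (exp == 0x1Fu) {  // infinity or NaN: keep the class, copy the sign bit
    const float special = man ? std::numeric_limits<float>::quiet_NaN()
                              : std::numeric_limits<float>::infinity();
    return std::bit_cast<float>(sign | std::bit_cast<std::uint32_t>(special));
  }
  if (exp == 0) {  // zero or fp16 subnormal: magnitude is man * 2^-24
    const float mag = static_cast<float>(man) * 0x1.0p-24f;
    return std::bit_cast<float>(sign | std::bit_cast<std::uint32_t>(mag));
  }
  // Normal number: rebias the exponent (15 -> 127), widen the mantissa (10 -> 23 bits).
  return std::bit_cast<float>(sign | ((exp + 112u) << 23) | (man << 13));
}

static_assert(fp16_bits_to_fp32(0x3C00u) == 1.0f);  // 1.0 in binary16
static_assert(fp16_bits_to_fp32(0xFC00u) ==
              -std::numeric_limits<float>::infinity());  // -inf, the mask fill value
```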

File tree

3 files changed: +300 −16 lines

CMakeLists.txt (+76 −12)

```diff
@@ -3,8 +3,8 @@ project(flashinfer CUDA CXX)
 
 include(cmake/utils/Utils.cmake)
 
-set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CUDA_STANDARD 17)
+set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CUDA_STANDARD 20)
 
 if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake)
   include(${CMAKE_BINARY_DIR}/config.cmake)
@@ -63,7 +63,7 @@ flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256)
 flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0
                   1 2)
 flashinfer_option(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS
-                  "QK reductions to enable" "false" "true")
+                  "QK reductions to enable" OFF)
 flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2)
 
 if(DEFINED FLASHINFER_CUDA_ARCHITECTURES)
@@ -125,25 +125,77 @@ set(POS_ENCODING_MODES ${FLASHINFER_GEN_POS_ENCODING_MODES})
 set(USE_FP16_QK_REDUCTIONS ${FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS})
 set(MASK_MODES ${FLASHINFER_GEN_MASK_MODES})
 
+set(SM90_ALLOWED_HEAD_DIMS "64,64" "128,128" "256,256" "192,128")
+set(HEAD_DIMS_SM90 "")
+
+foreach(DIM_VAL ${HEAD_DIMS})
+  string(CONCAT TUPLE_VAL "${DIM_VAL}" "," "${DIM_VAL}")
+  list(FIND SM90_ALLOWED_HEAD_DIMS ${TUPLE_VAL} RESULT)
+  if(NOT ${RESULT} EQUAL -1)
+    list(APPEND HEAD_DIMS_SM90 ${TUPLE_VAL})
+  endif(NOT ${RESULT} EQUAL -1)
+endforeach(DIM_VAL)
+
+foreach(TUPLE_VAL ${SM90_ALLOWED_HEAD_DIMS})
+  string(REPLACE "," ";" HEAD_DIMS_LIST ${TUPLE_VAL})
+  list(GET HEAD_DIMS_LIST 0 K)
+  list(GET HEAD_DIMS_LIST 1 V)
+  if(NOT K EQUAL V)
+    list(APPEND HEAD_DIMS_SM90 ${TUPLE_VAL})
+  endif(NOT K EQUAL V)
+endforeach(TUPLE_VAL)
+
+list(REMOVE_DUPLICATES HEAD_DIMS_SM90)
+
 # log options
 message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
 message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
 message(STATUS "FLASHINFER_USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS}")
 message(STATUS "FLASHINFER_MASK_MODES=${MASK_MODES}")
 
+# Log SM90_ALLOWED_HEAD_DIMS and HEAD_DIMS_SM90
+message(STATUS "SM90_ALLOWED_HEAD_DIMS=${SM90_ALLOWED_HEAD_DIMS}")
+message(STATUS "HEAD_DIMS_SM90=${HEAD_DIMS_SM90}")
+
+set(GENERATED_SOURCE_DIR ${PROJECT_SOURCE_DIR}/src/generated)
 file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
 
+if(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
+  # ----------------------------- Dependencies -------------------------------#
+  include(FetchContent)
+
+  set(BOOST_ENABLE_CMAKE ON)
+  FetchContent_Declare(boost_math
+                       GIT_REPOSITORY https://github.com/boostorg/math.git)
+  FetchContent_MakeAvailable(boost_math)
+  # --------------------------------------------------------------------------#
+  set(USE_FP16_QK_REDUCTIONS "true")
+  message(STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS}")
+else(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
+  set(USE_FP16_QK_REDUCTIONS "false")
+  message(STATUS "USE_FP16_QK_REDUCTIONS=${USE_FP16_QK_REDUCTIONS}")
+endif(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
+
 set(AOT_GENERATE_COMMAND
     ${Python3_EXECUTABLE} -m aot_build_utils.generate --path
-    ${PROJECT_SOURCE_DIR}/src/generated --head_dims ${HEAD_DIMS}
-    --pos_encoding_modes ${POS_ENCODING_MODES} --use_fp16_qk_reductions
-    ${USE_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES} --enable_f16
-    ${FLASHINFER_ENABLE_F16} --enable_bf16 ${FLASHINFER_ENABLE_BF16}
-    --enable_fp8_e4m3 ${FLASHINFER_ENABLE_FP8_E4M3} --enable_fp8_e5m2
+    ${GENERATED_SOURCE_DIR} --head_dims ${HEAD_DIMS} --pos_encoding_modes
+    ${POS_ENCODING_MODES} --use_fp16_qk_reductions ${USE_FP16_QK_REDUCTIONS}
+    --mask_modes ${MASK_MODES} --enable_f16 ${FLASHINFER_ENABLE_F16}
+    --enable_bf16 ${FLASHINFER_ENABLE_BF16} --enable_fp8_e4m3
+    ${FLASHINFER_ENABLE_FP8_E4M3} --enable_fp8_e5m2
    ${FLASHINFER_ENABLE_FP8_E5M2})
 
+set(AOT_GENERATE_DISPATCH_INC_COMMAND
+    ${Python3_EXECUTABLE} -m aot_build_utils.generate_dispatch_inc --path
+    "${GENERATED_SOURCE_DIR}/dispatch.inc" --head_dims ${HEAD_DIMS}
+    --head_dims_sm90 ${HEAD_DIMS_SM90} --pos_encoding_modes
+    ${POS_ENCODING_MODES} --use_fp16_qk_reductions ${USE_FP16_QK_REDUCTIONS}
+    --mask_modes ${MASK_MODES})
+
 execute_process(COMMAND ${AOT_GENERATE_COMMAND}
                 WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
+execute_process(COMMAND ${AOT_GENERATE_DISPATCH_INC_COMMAND}
+                WORKING_DIRECTORY ${PROJECT_SOURCE_DIR})
 
 file(GLOB_RECURSE FLASHINFER_GENERATORS
      ${PROJECT_SOURCE_DIR}/aot_build_utils/*.py)
@@ -157,21 +209,33 @@ file(GLOB_RECURSE DISPATCH_INC_FILE
 add_custom_command(
   OUTPUT ${DECODE_KERNELS_SRCS} ${PREFILL_KERNELS_SRCS} ${DISPATCH_INC_FILE}
   COMMAND ${AOT_GENERATE_COMMAND}
+  COMMAND ${AOT_GENERATE_DISPATCH_INC_COMMAND}
   DEPENDS ${FLASHINFER_GENERATORS}
   WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}
   COMMENT "Generating kernel sources"
   VERBATIM)
 add_custom_target(dispatch_inc DEPENDS ${DISPATCH_INC_FILE})
 
+string(CONCAT CXX_FLAGS "-fpic " "-fPIC ")
+
+string(CONCAT NVCC_FLAGS "-O3 " "--threads=1 " "-Xfatbin=-compress-all "
+              "-use_fast_math " "--expt-relaxed-constexpr ")
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX_FLAGS}")
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS}")
+
 add_library(decode_kernels STATIC ${DECODE_KERNELS_SRCS})
 target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
-target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options
-                       -compress-all)
+if(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
+  target_link_libraries(decode_kernels PRIVATE Boost::math)
+endif(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
 
 add_library(prefill_kernels STATIC ${PREFILL_KERNELS_SRCS})
 target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
-target_compile_options(prefill_kernels PRIVATE -Xcompiler=-fPIC
-                       --fatbin-options -compress-all)
+if(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
+  add_definitions(-DFP16_QK_REDUCTION_SUPPORTED)
+  target_link_libraries(prefill_kernels PRIVATE Boost::math)
+endif(FLASHINFER_GEN_USE_FP16_QK_REDUCTIONS)
 
 if(FLASHINFER_DECODE)
   message(STATUS "Compile single decode kernel benchmarks.")
```

include/flashinfer/attention/prefill.cuh (+47 −4)

```diff
@@ -15,6 +15,7 @@
  */
 #ifndef FLASHINFER_PREFILL_CUH_
 #define FLASHINFER_PREFILL_CUH_
+
 #include <cooperative_groups.h>
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
@@ -23,6 +24,9 @@
 
 #include "../cp_async.cuh"
 #include "../fastdiv.cuh"
+#ifdef FP16_QK_REDUCTION_SUPPORTED
+#include "../fp16.h"
+#endif
 #include "../frag_layout_swizzle.cuh"
 #include "../math.cuh"
 #include "../mma.cuh"
@@ -33,7 +37,6 @@
 #include "cascade.cuh"
 #include "mask.cuh"
 #include "variants.cuh"
-
 namespace flashinfer {
 
 DEFINE_HAS_MEMBER(maybe_q_rope_offset)
@@ -133,9 +136,25 @@ struct KernelTraits {
 
   using SharedStorage = SharedStorageQKVO<NUM_WARPS_KV, CTA_TILE_Q, CTA_TILE_KV, HEAD_DIM_QK,
                                           HEAD_DIM_VO, DTypeQ, DTypeKV, DTypeO>;
+#ifdef FP16_QK_REDUCTION_SUPPORTED
+  template <typename DT>
+  static constexpr DT getNegInf() {
+    if constexpr (std::is_same<DT, __half>::value) {
+      return std::bit_cast<half>(fp16_ieee_from_fp32_value(-math::inf));
+    } else {
+      return static_cast<DTypeQKAccum>(-math::inf);
+    }
+  }
 
+  static constexpr DTypeQKAccum MaskFillValue =
+      AttentionVariant::use_softmax ? getNegInf<DTypeQKAccum>() : DTypeQKAccum(0.f);
+#else
+  static_assert(!std::is_same<DTypeQKAccum, __half>::value,
+                "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math "
+                "then recompile to support fp16 reduction");
   static constexpr DTypeQKAccum MaskFillValue =
       AttentionVariant::use_softmax ? DTypeQKAccum(-math::inf) : DTypeQKAccum(0.f);
+#endif
 };
 
 namespace {
@@ -672,6 +691,8 @@ __device__ __forceinline__ void logits_transform(
     const uint32_t kv_head_idx = blockIdx.z) {
   const uint32_t lane_idx = tid.x;
   uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2];
+  float logits = 0., logitsTransformed = 0.;
+
 #pragma unroll
   for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
 #pragma unroll
@@ -691,9 +712,31 @@ __device__ __forceinline__ void logits_transform(
                                2 * (lane_idx % 4) +
                            8 * (reg_id / 4) + reg_id % 2;
       const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2];
-      s_frag[mma_q][mma_kv][reg_id] =
-          variant.LogitsTransform(params, s_frag[mma_q][mma_kv][reg_id], batch_idx, q_idx, kv_idx,
-                                  qo_head_idx, kv_head_idx);
+
+#ifdef FP16_QK_REDUCTION_SUPPORTED
+      if constexpr (std::is_same<DTypeQKAccum, __half>::value) {
+        logits = std::bit_cast<float>(fp16_ieee_to_fp32_value(s_frag[mma_q][mma_kv][reg_id]));
+      } else if constexpr (!std::is_same<DTypeQKAccum, __half>::value) {
+        logits = s_frag[mma_q][mma_kv][reg_id];
+      }
+#else
+      static_assert(!std::is_same<DTypeQKAccum, __half>::value,
+                    "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math "
+                    "then recompile to support fp16 reduction");
+      logits = s_frag[mma_q][mma_kv][reg_id];
+#endif
+      logitsTransformed = variant.LogitsTransform(params, logits, batch_idx, q_idx, kv_idx,
+                                                  qo_head_idx, kv_head_idx);
+#ifdef FP16_QK_REDUCTION_SUPPORTED
+      if constexpr (std::is_same<DTypeQKAccum, __half>::value) {
+        s_frag[mma_q][mma_kv][reg_id] =
+            std::bit_cast<half>(fp16_ieee_from_fp32_value(logitsTransformed));
+      } else if constexpr (!std::is_same<DTypeQKAccum, __half>::value) {
+        s_frag[mma_q][mma_kv][reg_id] = logitsTransformed;
+      }
+#else
+      s_frag[mma_q][mma_kv][reg_id] = logitsTransformed;
+#endif
     }
   }
 }
```