Commit daa5566

bugfix: fix dispatch fp16 type when enable fp8 (#430)
Fix #402 (comment)
1 parent: d52f2da

File tree

1 file changed (+4, −0 lines)


python/csrc/pytorch_extension_utils.h (+4)

@@ -146,6 +146,10 @@ using namespace flashinfer;
 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \
   [&]() -> bool {                                                   \
     switch (pytorch_dtype) {                                        \
+      case at::ScalarType::Half: {                                  \
+        using c_type = nv_half;                                     \
+        return __VA_ARGS__();                                       \
+      }                                                             \
       case at::ScalarType::Float8_e4m3fn: {                        \
         using c_type = __nv_fp8_e4m3;                              \
         return __VA_ARGS__();                                      \
