Skip to content

Commit 71f7af7

Browse files
authored
[Runtime] Use preferred host memory (pinned memory) in KV cache (#17036)
This PR updates the PagedKVCache with the pinned memory support, which can reduce the copy overhead between CPU and GPU. This PR also bumps FlashInfer version, which now supports * specifying kernels to build via cmake, * pinned memory as host memory. We also update CMakeLists.txt and config.cmake to include the FlashInfer compile options. Prior to this PR, the kernels being built is hardcoded in FlashInfer header files.
1 parent 8bdd54b commit 71f7af7

File tree

5 files changed

+205
-98
lines changed

5 files changed

+205
-98
lines changed

3rdparty/flashinfer

Submodule flashinfer updated 55 files

CMakeLists.txt

+3-3
Original file line numberDiff line numberDiff line change
@@ -960,13 +960,13 @@ option(USE_FLASHINFER "Build TVM with FlashInfer" OFF)
960960
if (USE_FLASHINFER STREQUAL "ON")
961961
message(STATUS "Build with FlashInfer")
962962
set(FLASHINFER_TVM_BINDING ON)
963-
set(FLASHINFER_TVM_HOME ${PROJECT_SOURCE_DIR})
964-
set(FLASHINFER_ENABLE_FP8 OFF)
965-
set(FLASHINFER_ENABLE_BF16 OFF)
963+
set(FLASHINFER_TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR})
966964
set(FLASHINFER_PREFILL OFF)
967965
set(FLASHINFER_DECODE OFF)
968966
set(FLASHINFER_PAGE OFF)
969967
set(FLASHINFER_CASCADE OFF)
968+
set(FLASHINFER_SAMPLING OFF)
969+
set(FLASHINFER_NORM OFF)
970970
add_subdirectory(3rdparty/flashinfer)
971971
else ()
972972
message(STATUS "Build without FlashInfer")

cmake/config.cmake

+13
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,19 @@ set(USE_GTEST AUTO)
444444
# Need to have USE_CUDA=ON
445445
set(USE_CUTLASS OFF)
446446

447+
# Whether to enable FlashInfer or not.
448+
set(USE_FLASHINFER OFF)
449+
# Options for FlashInfer kernel compilation.
450+
set(FLASHINFER_ENABLE_FP8 OFF)
451+
set(FLASHINFER_ENABLE_BF16 OFF)
452+
set(FLASHINFER_GEN_GROUP_SIZES 1 4 6 8)
453+
set(FLASHINFER_GEN_PAGE_SIZES 16)
454+
set(FLASHINFER_GEN_HEAD_DIMS 128)
455+
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
456+
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1)
457+
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false")
458+
set(FLASHINFER_GEN_CASUALS "false" "true")
459+
447460
# Enable to show a summary of TVM options
448461
set(SUMMARIZE OFF)
449462

include/tvm/runtime/ndarray.h

+17
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,23 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
534534
return true;
535535
}
536536

537+
/*!
538+
* \brief Get the preferred host device from the input device.
539+
* - For CUDA and ROCm, CUDAHost and ROCMHost will be returned for pinned memory,
540+
* since pinned memory reduces copy overhead.
541+
* - For other devices, CPU is returned as a fallback.
542+
*/
543+
inline Device GetPreferredHostDevice(Device device) {
544+
if (device.device_type == DLDeviceType::kDLCUDA) {
545+
return Device{DLDeviceType::kDLCUDAHost, 0};
546+
} else if (device.device_type == DLDeviceType::kDLROCM) {
547+
return Device{DLDeviceType::kDLROCMHost, 0};
548+
} else {
549+
// Fallback to CPU.
550+
return Device{DLDeviceType::kDLCPU, 0};
551+
}
552+
}
553+
537554
} // namespace runtime
538555
} // namespace tvm
539556

0 commit comments

Comments
 (0)