Skip to content

Commit ed88794

Browse files
authored
infra: open source XQA kernels (#3762)
Replace libtensorrt_llm_nvrtc_wrapper.so with its source code, which consists of two parts: (1) NVRTC glue code and (2) XQA kernel code. During the TensorRT-LLM build, the XQA kernel code is embedded as C++ arrays via gen_cpp_header.py and passed to NVRTC for JIT compilation. Signed-off-by: Ming Wei <[email protected]>
1 parent 1ada3c9 commit ed88794

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+17570
-118
lines changed

cpp/CMakeLists.txt

+2-17
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,6 @@ else()
5151
message(STATUS "NVTX is enabled")
5252
endif()
5353

54-
if(EXISTS
55-
"${CMAKE_CURRENT_SOURCE_DIR}/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/nvrtcWrapper/CMakeLists.txt"
56-
)
57-
set(BUILD_NVRTC_WRAPPER_DEFAULT ON)
58-
else()
59-
set(BUILD_NVRTC_WRAPPER_DEFAULT OFF)
60-
endif()
61-
62-
option(BUILD_NVRTC_WRAPPER "Build nvrtc wrapper from source"
63-
${BUILD_NVRTC_WRAPPER_DEFAULT})
64-
65-
if(BUILD_NVRTC_WRAPPER)
66-
message(STATUS "Building nvrtc wrapper")
67-
else()
68-
message(STATUS "Importing nvrtc wrapper")
69-
endif()
70-
7154
if(EXISTS
7255
"${CMAKE_CURRENT_SOURCE_DIR}/tensorrt_llm/kernels/internal_cutlass_kernels/CMakeLists.txt"
7356
)
@@ -154,6 +137,8 @@ set(CURAND_LIB CUDA::curand)
154137
set(CUDA_DRV_LIB CUDA::cuda_driver)
155138
set(CUDA_NVML_LIB CUDA::nvml)
156139
set(CUDA_RT_LIB CUDA::cudart_static)
140+
set(NVRTC_LIB CUDA::nvrtc_static)
141+
set(NVRTC_BUILTINS_LIB CUDA::nvrtc_builtins_static)
157142
set(CMAKE_CUDA_RUNTIME_LIBRARY Static)
158143

159144
resolve_dirs(CUDAToolkit_INCLUDE_DIRS "${CUDAToolkit_INCLUDE_DIRS}")

cpp/kernels/xqa/CMakeLists.txt

+132
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
2+
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
5+
# use this file except in compliance with the License. You may obtain a copy of
6+
# the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
12+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
13+
# License for the specific language governing permissions and limitations under
14+
# the License.
15+
cmake_minimum_required(VERSION 3.18)
project(xqa LANGUAGES CXX CUDA)

# Language levels: C++20 for host sources (required), C++17 for CUDA sources.
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)

# Device code is generated for these two architectures; the "-real" suffix asks
# CMake for real (SASS) code only.
set(CMAKE_CUDA_ARCHITECTURES 89-real 90a-real)

# Build-wide defaults: position-independent code and a compile_commands.json.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# The unit tests are opt-in; pass -DBUILD_XQA_TESTS=ON to build them.
option(BUILD_XQA_TESTS "Build XQA tests" OFF)
26+
27+
# todo: remove include_directories link_directories and link libs like
# CUDA::cuda_driver CUDA::cudart CUDA::nvrtc
find_package(CUDAToolkit REQUIRED)

include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

# Take the toolkit library directory from FindCUDAToolkit instead of appending
# "/../lib64" and "/../lib" to CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES: that
# variable is a list, so a textual suffix only lands on its last element and
# the resulting paths need not exist.
link_directories("${CUDAToolkit_LIBRARY_DIR}")
35+
36+
# Host (C++) flags. -march=haswell is an x86-only option, so it is added only
# when the target processor is x86-64; other host compilers (e.g. aarch64)
# would otherwise be handed an -march value they reject. On x86-64 the
# resulting flag string is identical to the previous unconditional one.
set(XQA_HOST_ARCH_FLAGS "")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|AMD64|amd64)$")
  set(XQA_HOST_ARCH_FLAGS "-march=haswell ")
endif()
set(CMAKE_CXX_FLAGS
    "${CMAKE_CXX_FLAGS} ${XQA_HOST_ARCH_FLAGS}-Wfatal-errors -Wreturn-type -Wall -Wextra -Wno-unknown-pragmas"
)

# Flags passed to nvcc for every configuration.
set(CMAKE_CUDA_FLAGS
    "${CMAKE_CUDA_FLAGS} -allow-unsupported-compiler --expt-relaxed-constexpr -t 0 -res-usage"
)

# Options forwarded to ptxas through -Xptxas in Release builds. Additional
# candidates that are currently left out: -Werror -v
set(CUDA_PTXAS_FLAGS "-warn-lmem-usage -warn-double-usage -warn-spills")
set(CMAKE_CUDA_FLAGS_RELEASE
    "${CMAKE_CUDA_FLAGS_RELEASE} -lineinfo -keep --use_fast_math -Xptxas='${CUDA_PTXAS_FLAGS}'"
)
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -keep")
# add_definitions(-DSPEC_DEC) set(CMAKE_CUDA_FLAGS_DEBUG
# "${CMAKE_CUDA_FLAGS_RELEASE}")
50+
51+
# Kernel source files, relative to this directory. The list is used as the
# DEPENDS of the xqa_sources.h generation commands, so editing any of these
# files triggers regeneration of the header.
set(XQA_SOURCES
    # headers
    cuda_hint.cuh
    defines.h
    ldgsts.cuh
    mha.h
    mhaUtils.cuh
    mma.cuh
    platform.h
    utils.cuh
    utils.h
    mha_stdheaders.cuh
    gmma.cuh
    gmma_impl.cuh
    barriers.cuh
    tma.h
    # kernel translation units
    mha.cu
    mha_sm90.cu)
68+
69+
# Locate a Python 3 interpreter; it provides ${Python3_EXECUTABLE}, which runs
# gen_cpp_header.py below.
find_package(Python3 REQUIRED COMPONENTS Interpreter)

# xqa_sources.h is written into the build tree by gen_cpp_header.py. The script
# is invoked from the source directory and is re-run whenever the script itself
# or any file in XQA_SOURCES changes.
set(XQA_SOURCES_H "${CMAKE_CURRENT_BINARY_DIR}/xqa_sources.h")
add_custom_command(
  OUTPUT ${XQA_SOURCES_H}
  COMMAND ${Python3_EXECUTABLE} gen_cpp_header.py -o ${XQA_SOURCES_H}
          --cuda_root ${CUDAToolkit_LIBRARY_ROOT}
  DEPENDS gen_cpp_header.py ${XQA_SOURCES}
  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
  COMMENT "Generating xqa_sources.h for XQAJIT..."
  VERBATIM)

# Named target through which the header generation can be requested.
add_custom_target(xqa_sources_h DEPENDS ${XQA_SOURCES_H})
82+
83+
if(BUILD_XQA_TESTS)
  # GoogleTest Preparation - Code block copied from
  # https://google.github.io/googletest/quickstart-cmake.html
  include(FetchContent)
  FetchContent_Declare(
    googletest
    GIT_REPOSITORY https://github.com/google/googletest.git
    GIT_TAG v1.15.2)
  include(GoogleTest)

  # Add Eigen via FetchContent
  FetchContent_Declare(
    eigen
    GIT_REPOSITORY https://gitlab.com/libeigen/eigen.git
    GIT_TAG 3.4.0)
  FetchContent_MakeAvailable(googletest eigen)

  enable_testing()
  add_executable(unitTests mha.cu mha_sm90.cu test/test.cpp
                           test/refAttention.cpp)
  target_include_directories(unitTests PUBLIC ${EIGEN3_INCLUDE_DIR})
  # Link the CUDA driver through the imported target created by
  # find_package(CUDAToolkit) rather than a bare "cuda", so the library is
  # located by CMake instead of depending on the linker search path.
  target_link_libraries(unitTests PUBLIC GTest::gtest GTest::gtest_main
                                         CUDA::cuda_driver Eigen3::Eigen)

  # NVRTC is optional for the tests: ENABLE_NVRTC tells the test sources
  # whether it was found.
  find_library(
    NVRTC_LIB nvrtc
    HINTS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/../lib
    PATH_SUFFIXES lib64 lib lib/x64)
  if(NOT NVRTC_LIB)
    message("Nvrtc not found")
    target_compile_definitions(unitTests PRIVATE ENABLE_NVRTC=0)
  else()
    target_compile_definitions(unitTests PRIVATE ENABLE_NVRTC=1)
    target_link_libraries(unitTests PUBLIC ${NVRTC_LIB})
    # Generate xqa_sources.h for nvrtc testing.
    target_include_directories(unitTests PRIVATE ${PROJECT_BINARY_DIR})
    set(GENERATED_XQA_SOURCES
        "${CMAKE_CURRENT_BINARY_DIR}/generated/xqa_sources.h")
    add_custom_command(
      OUTPUT ${GENERATED_XQA_SOURCES}
      # The output directory does not exist in a fresh build tree.
      COMMAND ${CMAKE_COMMAND} -E make_directory
              ${CMAKE_CURRENT_BINARY_DIR}/generated
      # Run the script with the interpreter CMake found instead of relying on
      # its executable bit and shebang line.
      COMMAND
        ${Python3_EXECUTABLE} gen_cpp_header.py -o ${GENERATED_XQA_SOURCES}
        --embed-cuda-headers --cuda_root
        ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/..
      DEPENDS gen_cpp_header.py ${XQA_SOURCES}
      WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
      VERBATIM)
    target_sources(unitTests PUBLIC ${GENERATED_XQA_SOURCES})
  endif()

  add_test(NAME unitTests COMMAND unitTests)
endif()

cpp/kernels/xqa/README.md

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<div align="left">
2+
3+
# XQA - A set of optimized kernels for generation-phase MQA/GQA
4+
5+
## Dependency
6+
7+
If you want to build & run unit tests, you need libgtest-dev and libeigen3-dev.
8+
9+
## Options
10+
11+
Kernel compile-time options can be found in defines.h. See code comments for details. Runtime options of unit tests can be modified in test.cpp.
12+
13+
## Build & run unit tests
14+
15+
You need to install libgtest-dev and libeigen3-dev before building. To build, use the normal cmake build steps:
16+
17+
- ```mkdir build```
18+
- ```cd build```
19+
- ```cmake .. -DCMAKE_BUILD_TYPE=Release -DBUILD_XQA_TESTS=ON```
20+
- ```cmake --build . -j```
21+
22+
To run unit tests, run `./unitTests`. There are a few runtime options that can be controlled with environment variables:
23+
24+
- XQA_ZERO_FILL: Set this to 1 to initialize input data with zeros (instead of random numbers). This is useful if you want to run perf tests quickly and skip the slow random data generation step. Note that this affects the measured performance.
25+
- XQA_USE_QGMMA: On Hopper, we try to use TMA+QGMMA kernel (mha_sm90.cu) by default if possible. To force using mha.cu, set this to 0.
26+
- XQA_NB_SUB_SEQ: The number of CUDA thread blocks used to handle one K/V head. We have reasonable default but if you want to change it manually, use this variable.
27+
28+
## Generating cubins used in TensorRT-LLM
29+
30+
Run `gen_cubin.py` in the repo workspace.

cpp/kernels/xqa/RefChecker.cuh

+96
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
4+
*
5+
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
6+
* property and proprietary rights in and to this material, related
7+
* documentation and any modifications thereto. Any use, reproduction,
8+
* disclosure or distribution of this material and related documentation
9+
* without an express license agreement from NVIDIA CORPORATION or
10+
* its affiliates is strictly prohibited.
11+
*/
12+
13+
#pragma once
14+
#include "cuda_hint.cuh"
#include "utils.cuh"
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cuda_fp16.h>
#include <filesystem>
#include <fstream>
#include <sstream>
#include <string>
#include <type_traits>
22+
23+
struct RefChecker
24+
{
25+
half q[8][32][32];
26+
half k[8][4][64][32];
27+
float qk[4][32][64];
28+
float tileRowMax[4][32];
29+
half x[4][32][64];
30+
half v[8][4][32][64];
31+
float tileRowSum[4][32];
32+
float acc1PerStep[4][32][256];
33+
half out[32][256];
34+
35+
void init()
36+
{
37+
#define INIT_MEMBER(member) initMember(member, #member)
38+
INIT_MEMBER(q);
39+
INIT_MEMBER(k);
40+
INIT_MEMBER(qk);
41+
INIT_MEMBER(tileRowMax);
42+
INIT_MEMBER(x);
43+
INIT_MEMBER(v);
44+
INIT_MEMBER(tileRowSum);
45+
INIT_MEMBER(acc1PerStep);
46+
INIT_MEMBER(out);
47+
#undef INIT_MEMBER
48+
}
49+
50+
private:
51+
template <typename T>
52+
void initMember(T& dst, char const* varName);
53+
};
54+
55+
template <typename T, size_t d0, size_t d1, size_t d2, size_t d3>
56+
std::enable_if_t<std::is_same_v<std::decay_t<T>, float> || std::is_same_v<std::decay_t<T>, half>, std::string>
57+
makeFileName(T (&dst)[d0][d1][d2][d3], char const* varName)
58+
{
59+
std::stringstream ss;
60+
ss << varName << '_' << d0 << 'x' << d1 << 'x' << d2 << 'x' << d3 << '_'
61+
<< (std::is_same_v<std::decay_t<T>, float> ? "f32" : "f16") << ".bin";
62+
return ss.str();
63+
}
64+
65+
template <typename T, size_t d0, size_t d1, size_t d2>
66+
std::enable_if_t<std::is_same_v<std::decay_t<T>, float> || std::is_same_v<std::decay_t<T>, half>, std::string>
67+
makeFileName(T (&dst)[d0][d1][d2], char const* varName)
68+
{
69+
std::stringstream ss;
70+
ss << varName << '_' << d0 << 'x' << d1 << 'x' << d2 << '_'
71+
<< (std::is_same_v<std::decay_t<T>, float> ? "f32" : "f16") << ".bin";
72+
return ss.str();
73+
}
74+
75+
template <typename T, size_t d0, size_t d1>
76+
std::enable_if_t<std::is_same_v<std::decay_t<T>, float> || std::is_same_v<std::decay_t<T>, half>, std::string>
77+
makeFileName(T (&dst)[d0][d1], char const* varName)
78+
{
79+
std::stringstream ss;
80+
ss << varName << '_' << d0 << 'x' << d1 << '_' << (std::is_same_v<std::decay_t<T>, float> ? "f32" : "f16")
81+
<< ".bin";
82+
return ss.str();
83+
}
84+
85+
template <typename T>
86+
void RefChecker::initMember(T& dst, char const* varName)
87+
{
88+
std::string const filename = makeFileName(dst, varName);
89+
printf("loading %s\n", filename.c_str());
90+
namespace fs = std::filesystem;
91+
assert(fs::exists(filename));
92+
assert(fs::file_size(filename) == sizeof(dst));
93+
std::ifstream fin(filename, std::ios::binary);
94+
fin.read(reinterpret_cast<char*>(&dst), sizeof(dst));
95+
assert(fin);
96+
}

0 commit comments

Comments
 (0)