Skip to content

Commit c6e7fd5

Browse files
authored
Refactor turbomind (low-level abstractions) (#3423)
* low level abstraction * refactor * eliminate template * remove unused * refactor bindings * simplify lm head * refactor weight * fix tp * cublas * refactor sampling * remove unused * simplify * fix AWQ support * fix moe * fix nccl lm_head * fix * refactor data types * skip legacy ut * simplify * rename data types * refactor * refactor runtime states * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * format * remove unused * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix msvc build * fix ut & msvc build * fix ut & msvc build * fix gcc build * fix lint & ut * fix lint * fetch Catch2 when building tests * rewind msvc build * fix sampling
1 parent 64eb6d3 commit c6e7fd5

File tree

194 files changed

+6945
-23658
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

194 files changed

+6945
-23658
lines changed

.github/workflows/windows-x64-gpu.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ jobs:
5050
INPUT_CUDA_VERSION: ${{ matrix.cudaver }}
5151
- name: Build wheel
5252
run: |
53-
$env:BUILD_TEST="ON"
53+
$env:BUILD_TEST="OFF"
5454
mkdir build
5555
cd build
5656
..\builder\windows\generate.ps1
57-
cmake --build . --config Release -- /m /v:q
57+
cmake --build . --config Release -- /m /v:n
5858
if (-Not $?) {
5959
echo "build failed"
6060
exit 1

CMakeLists.txt

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,28 @@
1515
cmake_minimum_required(VERSION 3.11 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13
1616
project(TurboMind LANGUAGES CXX CUDA)
1717

18-
find_package(CUDA 10.2 REQUIRED)
18+
if (MSVC)
19+
# use standard conformant preprocessor
20+
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/Zc:preprocessor>)
21+
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=/Zc:preprocessor")
22+
endif ()
1923

2024
find_package(CUDAToolkit REQUIRED)
2125

22-
if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11")
26+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11")
2327
add_definitions("-DENABLE_BF16")
24-
message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag")
2528
endif()
2629

2730
set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
2831

2932
option(BUILD_MULTI_GPU "Build multi-gpu support" ON)
3033
option(BUILD_PY_FFI "Build python ffi" ON)
3134
option(BUILD_TEST "Build tests" OFF)
35+
option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF)
36+
option(BUILD_FAST_MATH "Build in fast math mode" ON)
3237

3338
include(FetchContent)
39+
3440
if (BUILD_TEST)
3541
FetchContent_Declare(
3642
repo-cutlass
@@ -45,6 +51,14 @@ if (BUILD_TEST)
4551

4652
set(CUTLASS_HEADER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include)
4753
set(CUTLASS_EXTENSIONS_DIR ${PROJECT_SOURCE_DIR}/src/turbomind/cutlass_extensions/include)
54+
55+
56+
FetchContent_Declare(
57+
Catch2
58+
GIT_REPOSITORY https://github.com/catchorg/Catch2.git
59+
GIT_TAG v3.8.0
60+
)
61+
FetchContent_MakeAvailable(Catch2)
4862
endif()
4963

5064
FetchContent_Declare(
@@ -56,10 +70,6 @@ set(YAML_BUILD_SHARED_LIBS OFF CACHE BOOL "Build static library of yaml-cpp")
5670
FetchContent_MakeAvailable(yaml-cpp)
5771

5872

59-
option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF)
60-
61-
option(BUILD_FAST_MATH "Build in fast math mode" ON)
62-
6373
# the environment variable
6474
# ASAN_OPTIONS=protect_shadow_gap=0,intercept_tls_get_addr=0
6575
# must be set at runtime
@@ -112,13 +122,13 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl") # -Xptxas -v
112122
# TODO: build for sm_72 & sm_87 on aarch64 platform (Jetson devices)
113123
if (NOT CMAKE_CUDA_ARCHITECTURES)
114124
set(CMAKE_CUDA_ARCHITECTURES 70-real 75-real)
115-
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11")
125+
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11")
116126
list(APPEND CMAKE_CUDA_ARCHITECTURES 80-real)
117127
endif ()
118-
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11.1")
128+
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11.1")
119129
list(APPEND CMAKE_CUDA_ARCHITECTURES 86-real)
120130
endif ()
121-
if (${CUDA_VERSION} VERSION_GREATER_EQUAL "11.8")
131+
if (${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL "11.8")
122132
list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-real)
123133
endif ()
124134
if (MSVC)
@@ -132,19 +142,23 @@ set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
132142
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
133143
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
134144
# set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall --ptxas-options=-v --resource-usage")
135-
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall -DCUDA_PTX_FP8_F2FP_ENABLED")
145+
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall")
136146

137147
set(CMAKE_CXX_STANDARD "${CXX_STD}")
138148
set(CMAKE_CXX_STANDARD_REQUIRED ON)
139149
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda")
140150
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
141-
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2FP_ENABLED")
151+
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD}")
152+
153+
string(REPLACE "-O2" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
154+
string(REPLACE "-O2" "" CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE}")
155+
string(REPLACE "-O2" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}")
156+
string(REPLACE "-O2" "" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
142157

143-
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
144-
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3")
145-
# set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose")
146-
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
147-
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED")
158+
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
159+
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -O3")
160+
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -O3")
161+
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO} -O3")
148162

149163
if(BUILD_FAST_MATH)
150164
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math")
@@ -207,13 +221,11 @@ link_directories(
207221
${COMMON_LIB_DIRS}
208222
)
209223

210-
# add_subdirectory(3rdparty)
211224
add_subdirectory(src)
212-
# add_subdirectory(examples)
213225

214-
if(BUILD_TEST)
215-
add_subdirectory(tests/csrc)
216-
endif()
226+
# if(BUILD_TEST)
227+
# add_subdirectory(tests/csrc)
228+
# endif()
217229

218230
# install python api
219231
if (BUILD_PY_FFI)

builder/windows/generate.ps1

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,5 @@ cmake .. -A x64 -T "v142,cuda=$env:CUDA_PATH" `
33
-DCMAKE_INSTALL_PREFIX=install `
44
-DBUILD_PY_FFI=ON `
55
-DBUILD_MULTI_GPU=OFF `
6-
-DCMAKE_CUDA_FLAGS="-lineinfo" `
7-
-DUSE_NVTX=ON `
6+
-DUSE_NVTX=OFF `
87
-DBUILD_TEST="$env:BUILD_TEST"

builder/windows/setup_cuda.ps1

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ if ($CUDA_VERSION_FULL -eq "12.1.0") {
2424
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_531.14_windows.exe"
2525
} elseif ($CUDA_VERSION_FULL -eq "11.8.0") {
2626
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe"
27+
} elseif ($CUDA_VERSION_FULL -eq "12.5.0") {
28+
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.85_windows.exe"
2729
} else {
2830
Write-Output "Unsupported CUDA version specified"
2931
exit 1
@@ -84,6 +86,8 @@ $msBuildExtensions = (Get-ChildItem "$src\visual_studio_integration\CUDAVisualS
8486
}
8587
}
8688

89+
$CUDA_FLAGS="-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH=1"
90+
8791
# Add to Github env
8892
Write-Output "Setting environment variables for GitHub Actions..."
8993

@@ -97,7 +101,7 @@ Write-Output "CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)=$dst" >> $env:GITHUB_ENV
97101
Write-Output "CUDA_PATH_VX_Y=CUDA_PATH_V$($CUDA_MAJOR)_$($CUDA_MINOR)" >> $env:GITHUB_ENV
98102
Write-Output "CudaToolkitDir=$dst" >> $env:GITHUB_ENV
99103
Write-Output "CMAKE_CUDA_COMPILER=$dst\bin\nvcc.exe" >> $env:GITHUB_ENV
100-
Write-Output "NVCC_APPEND_FLAGS=-allow-unsupported-compiler" >> $env:GITHUB_ENV
104+
Write-Output "NVCC_APPEND_FLAGS=$CUDA_FLAGS" >> $env:GITHUB_ENV
101105

102106
Write-Output "CUDA_VERSION=$CUDA_VERSION_FULL" >> $env:GITHUB_ENV
103107
Write-Output "Setup completed."

lmdeploy/turbomind/deploy/module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,8 @@ def pad_weight(tensor: torch.Tensor, tp: int):
319319
if output_weight is not None:
320320
tp = self.model.attn_tp_size
321321
output_weight = pad_weight(output_weight, tp=tp)
322-
self.model.save_split(output_weight, 'output.weight', split_dim=0, split_num=tp)
322+
# transpose
323+
self.model.save_split(output_weight.t(), 'output.weight', split_dim=1, split_num=tp)
323324

324325

325326
class Transformer:

lmdeploy/turbomind/turbomind.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def _from_hf(self, model_source: ModelSource, model_path: str, engine_config: Tu
241241

242242
model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir='',
243243
config=yaml.safe_dump(self.config_dict),
244-
data_type=self.config.model_config.weight_type)
244+
weight_type=self.config.model_config.weight_type)
245245

246246
# create empty weight
247247
self._create_weight(model_comm)
@@ -275,7 +275,7 @@ def _from_workspace(self, model_path: str, engine_config: TurbomindEngineConfig)
275275
weight_dir = osp.join(model_path, 'triton_models', 'weights')
276276
model_comm = _tm.AbstractTransformerModel.create_llama_model(model_dir=weight_dir,
277277
config=yaml.safe_dump(self.config_dict),
278-
data_type=self.config.weight_type)
278+
weight_type=self.config.weight_type)
279279

280280
# create weight and load params
281281
self._create_weight(model_comm)

src/turbomind/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
add_subdirectory(utils)
16+
add_subdirectory(core)
1617
add_subdirectory(kernels)
1718
add_subdirectory(layers)
1819
add_subdirectory(comm)

src/turbomind/comm/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
cmake_minimum_required(VERSION 3.8)
44

55
add_library(host_comm STATIC host_comm.cc thread_comm.cc)
6+
target_link_libraries(host_comm PRIVATE core logger)
67
set_property(TARGET host_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
78

89
add_library(device_comm STATIC device_comm.cc)
9-
target_link_libraries(device_comm PRIVATE logger)
10+
target_link_libraries(device_comm PRIVATE core logger)
1011
set_property(TARGET device_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
1112
set_property(TARGET device_comm PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
1213

@@ -21,7 +22,7 @@ if (BUILD_MULTI_GPU)
2122

2223
if (BUILD_TEST)
2324
add_executable(test_comm test_comm.cu)
24-
target_link_libraries(test_comm PRIVATE device_comm host_comm pthread nvtx_utils)
25+
target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)
2526
target_compile_options(test_comm PRIVATE -O3 -march=native -mtune=native)
2627
endif ()
2728
endif ()

src/turbomind/comm/cuda_ipc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ add_library(cuda_ipc_comm STATIC
1212
target_link_libraries(cuda_ipc_comm PRIVATE
1313
rms_norm
1414
host_comm
15+
core
16+
cuda_utils
1517
CUDA::cuda_driver
1618
logger)
1719

src/turbomind/comm/cuda_ipc/allgather.cu

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "src/turbomind/comm/cuda_ipc/device_semaphore.h"
55

66
#include "src/turbomind/kernels/core/meta.h"
7-
#include "src/turbomind/utils/Tensor.h"
87
#include "src/turbomind/utils/cuda_utils.h"
98

109
namespace turbomind::comm {
@@ -51,7 +50,7 @@ __global__ void __launch_bounds__(1024, 1) Allgather_Simple_Pull(T*
5150
void CudaIpcCommImpl::AllGather(
5251
const void* sendbuff, void* recvbuff, size_t sendcount, DataType type, int group, cudaStream_t stream)
5352
{
54-
const size_t bytesize = get_elem_size(type) * sendcount;
53+
const size_t bytesize = turbomind::byte_size(type) * sendcount;
5554

5655
const int peers = this->n_ranks(group) - 1;
5756
const int rank = this->rank(group);
@@ -165,9 +164,9 @@ void CudaIpcCommImpl::AllGather2D(const void* sendbuff,
165164
int group,
166165
cudaStream_t stream)
167166
{
168-
const size_t byte_width = get_elem_size(type) * width;
169-
const size_t byte_pitch = get_elem_size(type) * pitch;
170-
const size_t byte_stride = get_elem_size(type) * stride;
167+
const size_t byte_width = byte_size(type, width);
168+
const size_t byte_pitch = byte_size(type, pitch);
169+
const size_t byte_stride = byte_size(type, stride);
171170

172171
void* base{};
173172
size_t offset{};

src/turbomind/comm/cuda_ipc/allreduce.cu

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
#include "src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h"
77
#include "src/turbomind/comm/cuda_ipc/device_semaphore.h"
88

9+
#include "src/turbomind/core/data_type.h"
910
#include "src/turbomind/kernels/core/array_ops.h"
1011
#include "src/turbomind/kernels/core/meta.h"
11-
#include "src/turbomind/utils/Tensor.h"
1212

1313
#include "src/turbomind/utils/cuda_utils.h"
1414

@@ -423,14 +423,7 @@ void CudaIpcCommImpl::AllReduceSum(
423423
}
424424
};
425425

426-
switch (type) {
427-
case DataType::TYPE_FP16:
428-
return invoke(half{});
429-
case DataType::TYPE_BF16:
430-
return invoke(nv_bfloat16{});
431-
default:
432-
throw std::runtime_error("not implemented");
433-
}
426+
TM_DISPATCH_PRIMARY_DTYPES(type, invoke);
434427
}
435428

436429
} // namespace turbomind::comm

src/turbomind/comm/cuda_ipc/cuda_ipc_comm.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
// Copyright (c) OpenMMLab. All rights reserved.
22

33
#include <memory>
4-
#include <mutex>
5-
#include <type_traits>
4+
#include <numeric>
65
#include <vector>
76

87
#include <cuda.h>

src/turbomind/comm/cuda_ipc/cuda_ipc_comm.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
#include "src/turbomind/kernels/core/array.h"
1212

13-
#include "src/turbomind/utils/Tensor.h"
1413
#include "src/turbomind/utils/cuda_utils.h"
1514

1615
namespace turbomind::comm {

src/turbomind/comm/cuda_ipc/fused_allreduce.cu

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
#include "src/turbomind/comm/cuda_ipc/device_semaphore.h"
99
#include "src/turbomind/comm/cuda_ipc/group_sum.h"
1010

11+
#include "src/turbomind/core/data_type.h"
1112
#include "src/turbomind/kernels/core/array_ops.h"
1213
#include "src/turbomind/kernels/core/common.h"
1314
#include "src/turbomind/kernels/core/meta.h"
1415

1516
#include "src/turbomind/kernels/norm/rms_norm.h"
1617

17-
#include "src/turbomind/utils/Tensor.h"
1818
#include "src/turbomind/utils/cuda_utils.h"
1919

2020
namespace turbomind::comm {
@@ -424,7 +424,7 @@ void CudaIpcCommImpl::AllreduceResidualBiasRMSnorm(void* hidden,
424424
cudaStream_t stream)
425425
{
426426

427-
const size_t elemsize = get_elem_size(dtype);
427+
const size_t elemsize = byte_size(dtype);
428428
const size_t bytesize = elemsize * token_num * dim;
429429

430430
const int n_ranks = this->n_ranks(group);
@@ -504,19 +504,10 @@ void CudaIpcCommImpl::AllreduceResidualBiasRMSnorm(void* hidden,
504504
return false; // > 1024 vdim
505505
};
506506

507-
auto dispatch = [&] {
508-
switch (dtype) {
509-
case DataType::TYPE_FP16:
510-
return dispatch_D(half{});
511-
case DataType::TYPE_BF16:
512-
return dispatch_D(nv_bfloat16{});
513-
default:
514-
return false;
515-
}
516-
};
507+
auto dispatch = [&]() -> bool { TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D); };
517508

518509
if (bytesize > (1 << 19)) {
519-
if (auto success = dispatch()) {
510+
if (dispatch()) {
520511
return;
521512
}
522513
}

src/turbomind/comm/cuda_ipc/fused_allreduce_ex.cu

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "src/turbomind/comm/cuda_ipc/group_sum.h"
66

77
#include "src/turbomind/comm/cuda_ipc/mscclpp.h"
8+
#include "src/turbomind/core/data_type.h"
89
#include "src/turbomind/kernels/core/array_ops.h"
910
#include "src/turbomind/kernels/core/common.h"
1011
#include "src/turbomind/kernels/core/meta.h"
@@ -279,18 +280,11 @@ void CudaIpcCommImpl::AllreduceResidualBiasRMSnormEx(void* hidden,
279280
return false; // > 1024 vdim
280281
};
281282

282-
auto dispatch = [&] {
283-
switch (dtype) {
284-
case DataType::TYPE_FP16:
285-
return dispatch_D(half{});
286-
case DataType::TYPE_BF16:
287-
return dispatch_D(nv_bfloat16{});
288-
default:
289-
return false;
290-
}
283+
auto dispatch = [&]() -> bool { //
284+
TM_DISPATCH_PRIMARY_DTYPES_RET(dtype, dispatch_D);
291285
};
292286

293-
FT_CHECK(dispatch());
287+
TM_CHECK(dispatch());
294288
}
295289

296290
} // namespace turbomind::comm

src/turbomind/comm/device_comm.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ DeviceComm CreateDeviceCommunicator(const std::string& backend, int n_ranks, int
2525
}
2626
#endif
2727

28-
FT_CHECK_WITH_INFO(0, fmtstr("Unknown communication backend: %s", backend.c_str()));
28+
TM_CHECK(0) << "Unknown communication backend: " << backend;
2929
return {};
3030
}
3131

0 commit comments

Comments (0)