Skip to content

Commit 6e588da

Browse files
[Build/CI] Fix CUDA 11.8 build (#17679)
Signed-off-by: Tyler Michael Smith <[email protected]> Signed-off-by: Lucas Wilkinson <[email protected]> Signed-off-by: Tyler Michael Smith <[email protected]> Co-authored-by: Lucas Wilkinson <[email protected]>
1 parent f8d2cc5 commit 6e588da

File tree

9 files changed

+78
-15
lines changed

9 files changed

+78
-15
lines changed

CMakeLists.txt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
3030
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
3131

3232
# Supported NVIDIA architectures.
33-
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
33+
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0)
34+
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
35+
else()
36+
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
37+
endif()
3438

3539
# Supported AMD GPU architectures.
3640
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")

csrc/moe/moe_ops.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,6 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
2828
torch::Tensor num_tokens_post_pad, int64_t top_k,
2929
int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
3030
int64_t BLOCK_SIZE_K, int64_t bit);
31-
#endif
31+
#endif
32+
33+
bool moe_permute_unpermute_supported();

csrc/moe/moe_permute_unpermute_op.cu

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
#include "permute_unpermute_kernels/dispatch.h"
66
#include "core/registration.h"
77

8+
// moe_permute kernels require at least CUDA 12.0
9+
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
10+
811
void moe_permute(
912
const torch::Tensor& input, // [n_token, hidden]
1013
const torch::Tensor& topk_weights, //[n_token, topk]
@@ -127,7 +130,45 @@ void moe_unpermute(
127130
});
128131
}
129132

133+
#else
134+
135+
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
136+
torch::Tensor& topk_ids,
137+
const torch::Tensor& token_expert_indicies,
138+
const std::optional<torch::Tensor>& expert_map,
139+
int64_t n_expert, int64_t n_local_expert, int64_t topk,
140+
const std::optional<int64_t>& align_block_size,
141+
torch::Tensor& permuted_input,
142+
torch::Tensor& expert_first_token_offset,
143+
torch::Tensor& src_row_id2dst_row_id_map,
144+
torch::Tensor& m_indices) {
145+
  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
146+
}
147+
148+
void moe_unpermute(const torch::Tensor& input,
149+
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
150+
const torch::Tensor& token_expert_indicies,
151+
const std::optional<torch::Tensor>& expert_map,
152+
int64_t n_expert, int64_t n_local_expert, int64_t topk,
153+
const std::optional<int64_t>& align_block_size,
154+
torch::Tensor& permuted_input,
155+
torch::Tensor& expert_first_token_offset,
156+
torch::Tensor& src_row_id2dst_row_id_map,
157+
torch::Tensor& m_indices) {
158+
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
159+
}
160+
161+
#endif
162+
163+
bool moe_permute_unpermute_supported() {
164+
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
165+
return true;
166+
#else
167+
return false;
168+
#endif
169+
}
170+
130171
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
131172
m.impl("moe_permute", &moe_permute);
132173
m.impl("moe_unpermute", &moe_unpermute);
133-
}
174+
}

csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11

22
#include "moe_permute_unpermute_kernel.h"
33

4+
// moe_permute kernels require at least CUDA 12.0
5+
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
6+
47
// CubKeyValueSorter definition begin
58
CubKeyValueSorter::CubKeyValueSorter()
69
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
@@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
131134
int num_experts) {
132135
auto tidx = threadIdx.x;
133136
auto bidx = blockIdx.x;
134-
auto lidx = tidx & 31;
135-
auto widx = tidx >> 5;
136-
auto warp_count = (blockDim.x + 31) >> 5;
137137
auto offset = bidx * blockDim.x;
138138
auto bound = min(offset + blockDim.x, size);
139139
extern __shared__ int smem_expert_map[];
@@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset,
226226
expert_first_token_offset, align_expert_first_token_offset, m_indices,
227227
num_local_expert, align_block_size);
228228
}
229-
}
229+
}
230+
231+
#endif

csrc/moe/torch_bindings.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
7777
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
7878
"expert_first_token_offset, int n_expert, int n_local_expert,int "
7979
"topk, Tensor! hidden_states)->()");
80-
// conditionally compiled so impl registration is in source file
80+
81+
m.def("moe_permute_unpermute_supported() -> bool");
82+
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
8183

8284
#endif
8385
}

csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
123123
}
124124

125125
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
126-
// CUTLASS groped FP8 kernels need at least CUDA 12.3
126+
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
127127
// and SM90 (Hopper)
128128

129129
#if defined CUDA_VERSION

docker/Dockerfile

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,11 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
263263
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \
264264
else \
265265
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \
266-
fi && \
267-
export FLASHINFER_ENABLE_AOT=1; \
266+
fi; \
267+
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
268+
if [ "$CUDA_MAJOR" -lt 12 ]; then \
269+
export FLASHINFER_ENABLE_SM90=0; \
270+
fi; \
268271
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \
269272
fi
270273
COPY examples examples
@@ -275,7 +278,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
275278
. /etc/environment && \
276279
uv pip list
277280

278-
# Although we build Flashinfer with AOT mode, there's still
281+
# Even when we build Flashinfer with AOT mode, there's still
279282
# some issues w.r.t. JIT compilation. Therefore we need to
280283
# install build dependencies for JIT compilation.
281284
# TODO: Remove this once FlashInfer AOT wheel is fixed
@@ -303,8 +306,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
303306
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/[email protected]"
304307

305308
# install development dependencies (for testing)
306-
RUN --mount=type=cache,target=/root/.cache/uv \
307-
uv pip install --system -r requirements/dev.txt
309+
RUN --mount=type=cache,target=/root/.cache/uv \
310+
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
311+
if [ "$CUDA_MAJOR" -ge 12 ]; then \
312+
uv pip install --system -r requirements/dev.txt; \
313+
fi
308314

309315
# install development dependencies (for testing)
310316
RUN --mount=type=cache,target=/root/.cache/uv \

tests/kernels/moe/test_moe_permute_unpermute.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
1414
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
1515
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
16-
moe_permute, moe_unpermute)
16+
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
1717
from vllm.platforms import current_platform
1818

1919
NUM_EXPERTS = [16, 64]
@@ -167,6 +167,8 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
167167
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
168168
n_expert: int, ep_size: int, dtype: torch.dtype,
169169
align_block_size: Optional[int]):
170+
if not moe_permute_unpermute_supported():
171+
pytest.skip("moe_permute_unpermute is not supported on this platform.")
170172
fill_invalid_expert = 0
171173
ep_rank = np.random.randint(0, ep_size)
172174
expert_map = None

vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,3 +182,7 @@ def moe_unpermute(
182182
expert_first_token_offset, n_expert,
183183
n_local_expert, topk, hidden_states)
184184
return hidden_states
185+
186+
187+
def moe_permute_unpermute_supported():
188+
return torch.ops._moe_C.moe_permute_unpermute_supported()

0 commit comments

Comments
 (0)