[Build/CI] Fix CUDA 11.8 build #17679

Merged

merged 29 commits on May 22, 2025
Changes from 28 commits
29 commits
fe02909
[Build/CI] Disable moe_permute_unpermute kernels on CUDA 11.8
tlrmchlsmth May 5, 2025
2700c8d
Unblock 11.8 build for testing
tlrmchlsmth May 6, 2025
961c060
naming
tlrmchlsmth May 6, 2025
c6d8786
Merge branch 'main' into fix_permute_build
tlrmchlsmth May 12, 2025
d009766
revert
tlrmchlsmth May 12, 2025
865f611
try to fix 11.8 flashinfer issues
tlrmchlsmth May 15, 2025
add6cc3
fixup?
tlrmchlsmth May 15, 2025
9716bf4
potential fix for 11.8
LucasWilkinson May 15, 2025
5fb0487
Merge remote-tracking branch 'nm/lwilkinson/fix-11.8-build' into fix_…
tlrmchlsmth May 16, 2025
ece9ced
Merge branch 'main' into fix_permute_build
tlrmchlsmth May 16, 2025
b4c526c
revert cuda.py changes
tlrmchlsmth May 16, 2025
c691b4a
[Build/CI] Disable moe_permute_unpermute kernels on CUDA 11.8
tlrmchlsmth May 5, 2025
4225c43
Unblock 11.8 build for testing
tlrmchlsmth May 6, 2025
e7d92a2
naming
tlrmchlsmth May 6, 2025
d87632d
revert
tlrmchlsmth May 12, 2025
9172dbb
try to fix 11.8 flashinfer issues
tlrmchlsmth May 15, 2025
b50fadc
fixup?
tlrmchlsmth May 15, 2025
a00364f
potential fix for 11.8
LucasWilkinson May 15, 2025
69c3687
revert cuda.py changes
tlrmchlsmth May 16, 2025
f6aab1b
Merge branch 'main' into fix_permute_build
tlrmchlsmth May 20, 2025
27bfbbc
Merge branch 'fix_permute_build' of https://github.com/vllm-project/v…
tlrmchlsmth May 20, 2025
1218067
Fixup
tlrmchlsmth May 21, 2025
dd74a07
respect flashinfer_use_aot arg
tlrmchlsmth May 21, 2025
f156e06
FLASHINFER_ENABLE_SM90=0 env
tlrmchlsmth May 21, 2025
dce957f
dont install dev dependency for 11.8
LucasWilkinson May 22, 2025
729310b
add missing binding
LucasWilkinson May 22, 2025
02614b8
Merge remote-tracking branch 'origin/main' into fix_permute_build
LucasWilkinson May 22, 2025
d6d9fa8
remove turning off AOT
LucasWilkinson May 22, 2025
3bbecc6
fix registration
LucasWilkinson May 22, 2025
6 changes: 5 additions & 1 deletion CMakeLists.txt
@@ -30,7 +30,11 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")

# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0")
else()
set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0")
endif()

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201")
44 changes: 43 additions & 1 deletion csrc/moe/moe_permute_unpermute_op.cu
@@ -5,6 +5,9 @@
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"

// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)

void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
@@ -127,7 +130,46 @@ void moe_unpermute(
});
}

#else

void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indicies,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
  TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}

void moe_unpermute(const torch::Tensor& input,
const torch::Tensor& topk_weights, torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indicies,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_unpermute is not supported on CUDA < 12.0");
Collaborator:

Does this mean we don't want to support these models, or would we like to add some fallback logic?

Collaborator (Author):

AFAICT, these functions aren't called anywhere yet, outside of tests. @CalebDu is this right?

These kernels are an optimization added in #14568. IIUC, this would be used in conjunction with the CUTLASS kernels, which also need CUDA 12.0. Otherwise we can fall back to the Triton MoE kernels (see the fallback sketch after this file's diff).

Contributor:

Yes, the customized permute/unpermute kernels are only called in tests right now. But I have no idea why these kernels are incompatible with CUDA 11.8.

}

#endif

bool moe_permute_unpermute_supported() {
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
return true;
#else
return false;
#endif
}

TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}
m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
}
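
To connect this with the review thread above: the new moe_permute_unpermute_supported() op gives Python-side callers a way to detect a CUDA 11.8 build and route around the stubbed kernels, for example by falling back to the Triton MoE path. The sketch below is illustrative only — select_moe_path and the two callables it chooses between are hypothetical names, not vLLM APIs; only moe_permute_unpermute_supported comes from this PR's diff.

# Minimal sketch, assuming only the helper added in this PR. The two
# expert-execution callables are placeholders supplied by the caller.
from typing import Callable

import torch

from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
    moe_permute_unpermute_supported)


def select_moe_path(
    permuted_path: Callable[..., torch.Tensor],
    triton_fallback: Callable[..., torch.Tensor],
) -> Callable[..., torch.Tensor]:
    """Pick an expert-execution path based on how the extension was built."""
    if moe_permute_unpermute_supported():
        # CUDA >= 12.0 build: moe_permute / moe_unpermute are real kernels,
        # so a permute-based (e.g. CUTLASS grouped-GEMM) path is usable.
        return permuted_path
    # CUDA 11.8 build: the ops are compiled as TORCH_CHECK(false) stubs,
    # so use the Triton fused-MoE kernels instead.
    return triton_fallback

A caller would pass its permute-based implementation and its Triton implementation and invoke whichever callable comes back; the updated test further down uses the same check simply to skip.
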
@@ -1,6 +1,9 @@

#include "moe_permute_unpermute_kernel.h"

// moe_permute kernels require at least CUDA 12.0
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)

// CubKeyValueSorter definition begin
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
@@ -131,9 +134,6 @@ __global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
auto lidx = tidx & 31;
auto widx = tidx >> 5;
auto warp_count = (blockDim.x + 31) >> 5;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
@@ -226,4 +226,6 @@ void getMIndices(int64_t* expert_first_token_offset,
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}
}

#endif
1 change: 1 addition & 0 deletions csrc/moe/torch_bindings.cpp
@@ -77,6 +77,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()");
m.def("moe_permute_unpermute_supported() -> bool");
// conditionally compiled so impl registration is in source file

#endif
2 changes: 1 addition & 1 deletion csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -123,7 +123,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
}

bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
// CUTLASS groped FP8 kernels need at least CUDA 12.3
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
// and SM90 (Hopper)

#if defined CUDA_VERSION
16 changes: 11 additions & 5 deletions docker/Dockerfile
@@ -263,8 +263,11 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \
else \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \
fi && \
export FLASHINFER_ENABLE_AOT=1; \
fi; \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
if [ "$CUDA_MAJOR" -lt 12 ]; then \
export FLASHINFER_ENABLE_SM90=0; \
fi; \
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \
fi
COPY examples examples
@@ -275,7 +278,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
. /etc/environment && \
uv pip list

# Although we build Flashinfer with AOT mode, there's still
# Even when we build Flashinfer with AOT mode, there's still
# some issues w.r.t. JIT compilation. Therefore we need to
# install build dependencies for JIT compilation.
# TODO: Remove this once FlashInfer AOT wheel is fixed
@@ -303,8 +306,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/[email protected]"

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install --system -r requirements/dev.txt
RUN --mount=type=cache,target=/root/.cache/uv \
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
if [ "$CUDA_MAJOR" -ge 12 ]; then \
uv pip install --system -r requirements/dev.txt; \
fi

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
4 changes: 3 additions & 1 deletion tests/kernels/moe/test_moe_permute_unpermute.py
@@ -13,7 +13,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
moe_permute, moe_permute_unpermute_supported, moe_unpermute)
from vllm.platforms import current_platform

NUM_EXPERTS = [16, 64]
@@ -167,6 +167,8 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
n_expert: int, ep_size: int, dtype: torch.dtype,
align_block_size: Optional[int]):
if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.")
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size)
expert_map = None
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
@@ -182,3 +182,7 @@ def moe_unpermute(
expert_first_token_offset, n_expert,
n_local_expert, topk, hidden_states)
return hidden_states


def moe_permute_unpermute_supported():
return torch.ops._moe_C.moe_permute_unpermute_supported()
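
For reference, the helper above simply forwards to the op registered in torch_bindings.cpp. A minimal sketch of querying it directly, assuming a CUDA build of vLLM where the compiled extension is importable as vllm._moe_C:

# Prints False on a CUDA 11.8 build and True on CUDA >= 12.0 builds.
import torch

import vllm._moe_C  # noqa: F401  # loading the extension registers the _moe_C ops

print(torch.ops._moe_C.moe_permute_unpermute_supported())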