
Commit 7eb4255

[BugFix] Accuracy fix for llama4 int4 - improperly casted scales (vllm-project#16801)
Signed-off-by: Lucas Wilkinson <[email protected]>
Parent: 6a0f547

3 files changed: +6 lines, -9 lines

csrc/moe/moe_wna16.cu

Lines changed: 3 additions & 7 deletions
@@ -13,7 +13,6 @@
 template <typename scalar_t, int bit, int GROUPS>
 __global__ void moe_wna16_gemm_kernel(
     const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
-
     const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
     const uint32_t* __restrict__ qzeros,
 
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
     if (token_index / top_k >= size_m) break;
 
     num_valid_tokens = m + 1;
-    if (blockIdx.z == 0 && offset_n < size_n)
-      output[token_index * size_n + offset_n] = Dtype::int2num(0);
 
     if (expert_id != -1) {
       int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  auto options =
-      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+  output.zero_();
 
   const int num_experts = b_qweight.size(0);
   const int size_m = input.size(0);
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
   const uint32_t* b_qzeros_ptr;
   if (b_qzeros.has_value())
     b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
-  const float* topk_weights_ptr;
+  const float* topk_weights_ptr = nullptr;
   if (topk_weights.has_value())
-    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
+    topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
 
   int groups_per_block_row = BLOCK_SIZE_K / group_size;
   TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
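
The host-side changes above do two things: the wrapper now zeroes the output tensor up front (output.zero_() replaces the unused TensorOptions construction, and the per-block zeroing previously done in the kernel under blockIdx.z == 0 is removed), and topk_weights_ptr is initialized to nullptr and read through the typed data_ptr<float>() accessor, which lets PyTorch verify the tensor's dtype rather than reinterpreting whatever buffer is there. Below is a minimal PyTorch sketch of the zeroing side, under the assumption that the kernel accumulates partial results into output; the shapes and dtype are illustrative only, not taken from the kernel:

    import torch

    # Illustrative sizes; the real wrapper derives these from its arguments.
    size_m, top_k, size_n = 8, 2, 128

    # torch.empty leaves memory uninitialized, so an accumulating kernel
    # must see an all-zero buffer first -- the role of output.zero_()
    # in the C++ wrapper above.
    output = torch.empty(size_m * top_k, size_n, dtype=torch.float16)
    output.zero_()

    # Rough analogue of the typed data_ptr<float>() accessor: check the
    # dtype explicitly before handing out a raw float pointer.
    topk_weights = torch.rand(size_m, top_k, dtype=torch.float32)
    assert topk_weights.dtype == torch.float32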

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 0 deletions
@@ -422,6 +422,7 @@ def __init__(
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
 
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
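
The layer change is a one-liner: the resolved params_dtype is stored on the module so later code can read it back instead of re-deriving it. A minimal sketch of the fallback-and-store pattern, using a hypothetical stand-in class rather than vllm's actual FusedMoE:

    import torch

    class MoELayerStub:  # hypothetical stand-in, not vllm's FusedMoE
        def __init__(self, params_dtype=None):
            # Fall back to torch's global default dtype when none is given,
            # then keep the result on the module for later consumers
            # (e.g. weight-creation code that needs the parameter dtype).
            if params_dtype is None:
                params_dtype = torch.get_default_dtype()
            self.params_dtype = params_dtype

    layer = MoELayerStub()
    print(layer.params_dtype)  # torch.float32 unless the default was changed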

vllm/model_executor/models/llama4.py

Lines changed: 2 additions & 2 deletions
@@ -51,8 +51,8 @@ def custom_routing_function(
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
-        router_scores = torch.sigmoid(router_scores.float()).to(
-            hidden_states.dtype)
+        # pseudo-standard is that the router scores are floats
+        router_scores = torch.sigmoid(router_scores.float())
         return (router_scores, router_indices.to(torch.int32))
 
     def __init__(self,
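
This is the accuracy-relevant piece named in the commit title: the router scores are now kept in float32 after the sigmoid instead of being cast back to the activation dtype (typically bfloat16 for Llama 4), so downstream consumers (the MoE kernel above takes topk_weights as a float pointer) see full-precision routing weights. A small sketch of the old versus new behaviour, using torch.topk as a stand-in for vllm's fast_topk helper:

    import torch

    torch.manual_seed(0)
    gating_output = torch.randn(4, 16)   # router logits, one row per token
    act_dtype = torch.bfloat16           # stand-in for hidden_states.dtype
    topk = 1

    # torch.topk stands in for fast_topk here.
    router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)

    # Old behaviour: sigmoid in float32, then cast back to the activation dtype.
    old_scores = torch.sigmoid(router_scores.float()).to(act_dtype)

    # New behaviour: leave the scores in float32.
    new_scores = torch.sigmoid(router_scores.float())

    # The round-trip through bfloat16 loses precision in the routing scale.
    print((new_scores - old_scores.float()).abs().max())
    print(router_indices.to(torch.int32).dtype)  # torch.int32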
