"""Tests for the MOE layers.

Run `pytest tests/kernels/test_moe.py`.
"""

import pytest
import torch

from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.models.mixtral import MixtralMoE


def torch_moe(a, w1, w2, topk_weight, topk_ids):
    """Straightforward PyTorch MoE, used as a reference for the fused kernel."""
    B, D = a.shape
    # Duplicate each token once per selected expert: [B, D] -> [B * topk, D].
    a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D)
    out = torch.zeros(B * topk_ids.shape[1],
                      w2.shape[1],
                      dtype=a.dtype,
                      device=a.device)
    topk_ids = topk_ids.view(-1)
    topk_weight = topk_weight.view(-1)
    # Apply each expert only to the tokens that were routed to it.
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    # Weight each expert's output by its routing weight and sum per token.
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1)).sum(dim=1)

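# Sweep token count (m), intermediate size (n), hidden size (k), number of
# experts (e), and experts selected per token (topk) across fp16 and bf16.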
@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("e", [8, 64])
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fused_moe(
    m: int,
    n: int,
    k: int,
    e: int,
    topk: int,
    dtype: torch.dtype,
):
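    # a: [m, k] token activations. w1: [e, 2 * n, k] stacks each expert's
    # gate and up projections (split apart inside SiluAndMul); w2: [e, k, n]
    # is the down projection back to the hidden size.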
    a = torch.randn((m, k), device='cuda', dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10
    w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10

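    # Top-k routing: softmax over per-token expert scores, then keep the topk
    # largest weights and their expert indices.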
    score = torch.randn((m, e), device='cuda', dtype=dtype)
    score = torch.softmax(score, dim=-1)
    topk_weight, topk_ids = torch.topk(score, topk)

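    # The fused Triton kernel and the plain PyTorch reference should agree to
    # within mixed-precision rounding error.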
    triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
    torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_mixtral_moe(dtype: torch.dtype):
    """Make sure our Mixtral MoE implementation agrees with Hugging Face's."""

    # Instantiate vLLM's and Hugging Face's MoE blocks
    config = MixtralConfig()
    hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
    vllm_moe = MixtralMoE(
        num_experts=config.num_local_experts,
        top_k=config.num_experts_per_tok,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        params_dtype=dtype,
        tp_size=1,
    )

    # Load the weights. The vLLM block fuses each expert's w1 (gate) and
    # w3 (up) projections into a single tensor, so they are concatenated
    # along the output dimension before being copied in.
    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
    for i in range(config.num_local_experts):
        weights = (hf_moe.experts[i].w1.weight.data,
                   hf_moe.experts[i].w3.weight.data)
        vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
        vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data

    # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
    inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")

    # Run forward passes for both MoE blocks
    hf_states, _ = hf_moe.forward(inputs)
    vllm_states = vllm_moe.forward(inputs)

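    # Per-dtype tolerances: bf16 carries fewer mantissa bits than fp16/fp32,
    # so it gets a looser bound.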
    mixtral_moe_tol = {
        torch.float32: 1e-3,
        torch.float16: 1e-3,
        torch.bfloat16: 1e-2,
    }

    assert torch.allclose(hf_states,
                          vllm_states,
                          rtol=mixtral_moe_tol[dtype],
                          atol=mixtral_moe_tol[dtype])