# Adapted from https://github.com/sgl-project/sglang/pull/2575
import itertools

import pytest
import torch

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform

if current_platform.get_device_capability() < (9, 0):
    pytest.skip("FP8 Triton requires compute capability 9.0 or higher",
                allow_module_level=True)

# Test configurations
DTYPES = [torch.bfloat16]  # [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
GROUP_SIZE = [64, 128, 256, 512]
M = [1, 7, 83, 512, 2048]
N = [128, 512, 1024, 4096, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824]
# DeepSeek-V3's intermediate size is 18432, so N is 18432 * 2 / 8 = 4608 at
# TP8, and its hidden size is 7168.
M_moe = [1, 7, 83, 512, 2048]
N_moe = [4608]  # [128, 4608, 13824]
K_moe = [7168]  # [256, 7168, 13824]
BLOCK_SIZE = [[128, 128]]
E = [256]  # [8, 24, 128, 256]
TOP_KS = [1]  # [1, 2, 6]
OUT_DTYPES = [torch.bfloat16]  # [torch.float32, torch.half, torch.bfloat16]
SEEDS = [0]


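# Both the native reference below and the Triton kernel quantize each
# contiguous group of `group_size` elements along the last dimension with its
# own float32 scale (the group's amax divided by the FP8 e4m3 maximum); the
# group is then divided by that scale and clamped to the representable range.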
def native_per_token_group_quant_fp8(x,
                                     group_size,
                                     eps=1e-10,
                                     dtype=torch.float8_e4m3fn):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch."""
    assert x.shape[-1] % group_size == 0, ("the last dimension of `x` must "
                                           "be divisible by `group_size`")
    assert x.is_contiguous(), "`x` is not contiguous"

    finfo = torch.finfo(dtype)
    fp8_min = finfo.min
    fp8_max = finfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / fp8_max
    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s


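# Reference blockwise W8A8 GEMM: `As` carries one scale per (token, K-tile)
# with shape [M, k_tiles], and `Bs` one scale per (block_n, block_k) tile of
# `B` with shape [n_tiles, k_tiles]. Each tile product is dequantized with
# As[:, i] * Bs[j, i] and accumulated in float32 before the final cast.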
def native_w8a8_block_fp8_matmul(A,
                                 B,
                                 As,
                                 Bs,
                                 block_size,
                                 output_dtype=torch.float16):
    """Matrix multiplication with block-wise quantization using native
    torch."""
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N, )
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C_shape = (M, N)
    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)

    A_tiles = [
        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
    ]
    B_tiles = [[
        B[j * block_n:min((j + 1) * block_n, N),
          i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
    ] for j in range(n_tiles)]
    C_tiles = [
        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
    ]
    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]

    for i in range(k_tiles):
        for j in range(n_tiles):
            a = A_tiles[i]
            b = B_tiles[j][i]
            c = C_tiles[j]
            s = As_tiles[i] * Bs[j][i]
            c[:, :] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C


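# Reference MoE: tokens are routed with softmax + top-k, and for each expert
# the selected tokens go through w1 (gate/up, 2 * N rows), SiluAndMul, a
# re-quantization of the activation, and w2 (back to the hidden size), before
# being combined with the top-k routing weights.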
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """Fused MoE with block-wise quantization using native torch."""
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
    a_q = a_q.to(torch.float32)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_fp8_matmul(a_q[mask],
                                                     w1[i],
                                                     a_s[mask],
                                                     w1_s[i],
                                                     block_shape,
                                                     output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_fp8(
                act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_fp8_matmul(act_out_q,
                                                     w2[i],
                                                     act_out_s,
                                                     w2_s[i],
                                                     block_shape,
                                                     output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


# Skip all tests if CUDA is not available
if not torch.cuda.is_available():
    pytest.skip("CUDA is not available", allow_module_level=True)


@pytest.fixture(autouse=True)
def setup_cuda():
    torch.set_default_device("cuda")


@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed",
                         itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE,
                                           SEEDS))
@torch.inference_mode()
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
    torch.manual_seed(seed)
    x = torch.rand(num_tokens, d, dtype=dtype)

    ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
    out, scale = per_token_group_quant_fp8(x, group_size)

    assert torch.allclose(out.to(torch.float32),
                          ref_out.to(torch.float32),
                          rtol=0.15)
    assert torch.allclose(scale, ref_scale)


@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
                         itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES,
                                           SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
    B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

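    # As holds one per-token-group scale per K-tile of each row of A
    # (shape [M, k_tiles]); Bs holds one scale per (block_n, block_k) tile
    # of B (shape [n_tiles, k_tiles]). Both are scaled by `factor_for_scale`.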
    As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

    ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                           out_dtype)
    out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.001


@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed",
                         itertools.product(M_moe, N_moe, K_moe, E, TOP_KS,
                                           BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

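    # w1 stacks the gate and up projections (2 * N rows) and w2 projects the
    # N-dim activation back to the hidden size K, which is why the blockwise
    # scale tensors below have (n_tiles_w1, k_tiles_w1) and
    # (n_tiles_w2, k_tiles_w2) tiles per expert.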
    w1_bf16 = (torch.rand(
        (E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
    w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    del w1_bf16

    w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
    w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
    del w2_bf16

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = (2 * N + block_n - 1) // block_n
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1_s = torch.rand(
        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
    w2_s = torch.rand(
        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale

    score = torch.randn((M, E), dtype=dtype)

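    # Compare the fused Triton MoE kernel against the native torch reference;
    # the 3% mean relative error budget presumably allows for FP8 rounding
    # and accumulation-order differences.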
    out = fused_moe(
        a,
        w1,
        w2,
        score,
        topk,
        renormalize=False,
        use_fp8_w8a8=True,
        w1_scale=w1_s,
        w2_scale=w2_s,
        block_shape=block_size,
    )
    ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                       block_size)

    print(f"{out.sum()=}")
    print(f"{ref_out.sum()=}")

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.03