# SPDX-License-Identifier: Apache-2.0

# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py
import itertools

import pytest
import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    w8a8_block_int8_matmul)
from vllm.platforms import current_platform

from .utils_block import native_w8a8_block_matmul

if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA compute capability 7.0 or higher",
                allow_module_level=True)


# For test
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch.

    It converts the tensor values into int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    """
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    # Use float32 for scale calculation for stability
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_.to(torch.float32) / x_s).round().clamp(
        min=int8_min, max=int8_max).to(dtype)  # Round before clamping
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s


# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using
    native torch."""
    B, D = a.shape
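    # Replicate each token row `topk` times so the rows line up one-to-one
    # with the flattened top-k expert assignments computed below.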
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

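    # Quantize the activations per token group (group size = block_k), then
    # run each expert's two block-quantized matmuls on its assigned rows.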
    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_int8(
                act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
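    # Weight each expert's output by its routing probability and sum over
    # the top-k experts selected for every token.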
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


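# Test parameter grids swept by the kernel and fused MoE tests below.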
DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222]
N = [128, 1024]
K = [256, 4096]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]


@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default device to CUDA for all tests in this module."""
    torch.set_default_device("cuda")


@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
                         itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    int8_info = torch.iinfo(torch.int8)
    int8_max, int8_min = int8_info.max, int8_info.min

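    # Random activations and weights spanning the int8 range, quantized to
    # int8 by clamping and casting.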
    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * int8_max
    A_int8 = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max
    B_int8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

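    # One scale per (token, K-group) for A and one per (N-block, K-block)
    # tile for B.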
    As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

    ref_out = native_w8a8_block_matmul(A_int8, B_int8, As, Bs, block_size,
                                       out_dtype)
    out = w8a8_block_int8_matmul(A_int8, B_int8, As, Bs, block_size, out_dtype)

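    # Mean relative error between the Triton kernel and the native reference.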
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.001


@pytest.mark.parametrize(
    "M, N, K, E, topk, block_size, dtype, seed",
    itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Tests the fused_moe kernel with W8A8 INT8 block quantization against a
    native torch reference."""
    torch.manual_seed(seed)
    # Use a smaller factor for scale initialization to prevent large values
    # and overflow, especially when the output dtype is float16.
    factor_for_scale = 1e-2
    int8_info = torch.iinfo(torch.int8)
    int8_max, int8_min = int8_info.max, int8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

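    # w1 fuses the gate and up projections (2 * N output rows per expert);
    # w2 is the down projection back to the hidden size K.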
    w1_fp32 = (torch.rand(
        (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
    w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
    w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = (2 * N + block_n - 1) // block_n
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    k_tiles_w2 = (N + block_k - 1) // block_k

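    # One float32 scale per (expert, N-block, K-block) tile of each weight.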
    w1_s = (torch.rand(
        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
    w2_s = (torch.rand(
        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)

    score = torch.randn((M, E), dtype=dtype)

    # Set the context to avoid lots of warning spam.
    vllm_config = VllmConfig()
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_int8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
        ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                            block_size)

    # Check results
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.06