
Commit 0d46532

mgoin authored and lk-chen committed
[Kernel] Support W8A8 channel-wise weights and per-token activations in triton fused_moe_kernel (vllm-project#16366)
Signed-off-by: mgoin <[email protected]>
1 parent 4d3a83b commit 0d46532
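
For context, the W8A8 scheme named in the title pairs one INT8 scale per output channel of each weight with one scale per token of the activations. The sketch below is only an illustration of that scale layout in plain PyTorch; the helper names are hypothetical and it is not the Triton kernel or vLLM API touched by this commit.

# Illustrative only: per-token activation scales and per-channel weight scales
# for an int8 W8A8 matmul, written in plain PyTorch (not the fused_moe kernel).
import torch


def quant_per_token(x: torch.Tensor):
    # One scale per token (row of x): s has shape (M, 1).
    s = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 127.0
    q = (x / s).round().clamp(-128, 127).to(torch.int8)
    return q, s


def quant_per_channel(w: torch.Tensor):
    # One scale per output channel (row of w): s has shape (N, 1).
    s = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 127.0
    q = (w / s).round().clamp(-128, 127).to(torch.int8)
    return q, s


x = torch.randn(4, 64)    # (M, K) activations
w = torch.randn(128, 64)  # (N, K) weights
x_q, x_s = quant_per_token(x)
w_q, w_s = quant_per_channel(w)
# Accumulate in float for simplicity, then undo both scales.
y = (x_q.float() @ w_q.float().t()) * x_s * w_s.t()
ref = x @ w.t()
print((y - ref).abs().mean() / ref.abs().mean())  # small relative error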

File tree

7 files changed: +1229 −158 lines

tests/kernels/test_block_fp8.py
Lines changed: 19 additions & 73 deletions

@@ -18,6 +18,8 @@
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 
+from .utils_block import native_w8a8_block_matmul
+
 dg_available = False
 try:
     import deep_gemm
@@ -75,61 +77,6 @@ def native_per_token_group_quant_fp8(x,
     return x_q, x_s
 
 
-def native_w8a8_block_fp8_matmul(A,
-                                 B,
-                                 As,
-                                 Bs,
-                                 block_size,
-                                 output_dtype=torch.float16):
-    """Matrix multiplication with block-wise quantization using native torch."""
-    A = A.to(torch.float32)
-    B = B.to(torch.float32)
-    assert A.shape[-1] == B.shape[-1]
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
-    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
-    assert A.shape[:-1] == As.shape[:-1]
-
-    M = A.numel() // A.shape[-1]
-    N, K = B.shape
-    origin_C_shape = A.shape[:-1] + (N, )
-    A = A.reshape(M, A.shape[-1])
-    As = As.reshape(M, As.shape[-1])
-    n_tiles = (N + block_n - 1) // block_n
-    k_tiles = (K + block_k - 1) // block_k
-    assert n_tiles == Bs.shape[0]
-    assert k_tiles == Bs.shape[1]
-
-    C_shape = (M, N)
-    C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
-
-    A_tiles = [
-        A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
-    ]
-    B_tiles = [[
-        B[
-            j * block_n:min((j + 1) * block_n, N),
-            i * block_k:min((i + 1) * block_k, K),
-        ] for i in range(k_tiles)
-    ] for j in range(n_tiles)]
-    C_tiles = [
-        C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
-    ]
-    As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
-
-    for i in range(k_tiles):
-        for j in range(n_tiles):
-            a = A_tiles[i]
-            b = B_tiles[j][i]
-            c = C_tiles[j]
-            s = As_tiles[i] * Bs[j][i]
-            c[:, :] += torch.matmul(a, b.t()) * s
-
-    C = C.reshape(origin_C_shape).to(output_dtype)
-    return C
-
-
 def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
     """Fused moe with block-wise quantization using native torch."""
     B, D = a.shape
@@ -146,22 +93,22 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
-            inter_out = native_w8a8_block_fp8_matmul(a_q[mask],
-                                                     w1[i],
-                                                     a_s[mask],
-                                                     w1_s[i],
-                                                     block_shape,
-                                                     output_dtype=a.dtype)
+            inter_out = native_w8a8_block_matmul(a_q[mask],
+                                                 w1[i],
+                                                 a_s[mask],
+                                                 w1_s[i],
+                                                 block_shape,
+                                                 output_dtype=a.dtype)
             act_out = SiluAndMul().forward_native(inter_out)
             act_out_q, act_out_s = native_per_token_group_quant_fp8(
                 act_out, block_k)
             act_out = act_out.to(torch.float32)
-            out[mask] = native_w8a8_block_fp8_matmul(act_out_q,
-                                                     w2[i],
-                                                     act_out_s,
-                                                     w2_s[i],
-                                                     block_shape,
-                                                     output_dtype=a.dtype)
+            out[mask] = native_w8a8_block_matmul(act_out_q,
+                                                 w2[i],
+                                                 act_out_s,
+                                                 w2_s[i],
+                                                 block_shape,
+                                                 output_dtype=a.dtype)
     return (out.view(B, -1, w2.shape[1]) *
             topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 
@@ -215,8 +162,8 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
     As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
     Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
 
-    ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
-                                           out_dtype)
+    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
+                                       out_dtype)
     out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
 
     rel_diff = (torch.mean(
@@ -239,8 +186,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     fp8_info = torch.finfo(torch.float8_e4m3fn)
     fp8_max, fp8_min = fp8_info.max, fp8_info.min
 
-    vllm_config = VllmConfig()
-
     a = torch.randn((M, K), dtype=dtype) / 10
 
     w1_bf16 = (torch.rand(
@@ -266,6 +211,7 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     score = torch.randn((M, E), dtype=dtype)
 
     # Set the context to avoid lots of warning spam.
+    vllm_config = VllmConfig()
     with set_current_vllm_config(vllm_config):
         out = fused_moe(
             a,
@@ -334,8 +280,8 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     As = As_fp8.to(torch.float32)
     Bs = Bs_fp8.to(torch.float32)
 
-    ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
-                                           out_dtype)
+    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
+                                       out_dtype)
 
     # Transpose earlier so that the testing will not trigger transposing kernels
     As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8)
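
The deleted reference implementation is not dropped: per the new import at the top of the file it now lives in a shared tests/kernels/utils_block.py module under the name native_w8a8_block_matmul, so both the FP8 tests and the new INT8 tests can reuse it. A minimal usage sketch, assuming the moved helper keeps the signature and shape contract of the deleted function above (the import path is the assumed on-disk location of that module):

# Minimal sketch of calling the shared reference matmul with block-wise scales.
# Shapes follow the assertions in the deleted function:
#   A: (M, K), B: (N, K), As: (M, ceil(K/block_k)), Bs: (ceil(N/block_n), ceil(K/block_k))
import torch

from tests.kernels.utils_block import native_w8a8_block_matmul  # assumed path

M, N, K = 4, 256, 512
block_size = [128, 128]  # [block_n, block_k]
A = torch.randn(M, K).to(torch.float8_e4m3fn)
B = torch.randn(N, K).to(torch.float8_e4m3fn)
As = torch.rand(M, K // block_size[1]) * 1e-2
Bs = torch.rand(N // block_size[0], K // block_size[1]) * 1e-2

out = native_w8a8_block_matmul(A, B, As, Bs, block_size,
                               output_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([4, 256])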

tests/kernels/test_block_int8.py (new file)
Lines changed: 199 additions & 0 deletions

@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0

# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_block_int8.py
import itertools

import pytest
import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.quantization.utils.int8_utils import (
    w8a8_block_int8_matmul)
from vllm.platforms import current_platform

from .utils_block import native_w8a8_block_matmul

if current_platform.get_device_capability() < (7, 0):
    pytest.skip("INT8 Triton requires CUDA 7.0 or higher",
                allow_module_level=True)


# For test
def native_per_token_group_quant_int8(x,
                                      group_size,
                                      eps=1e-10,
                                      dtype=torch.int8):
    """Function to perform per-token-group quantization on an input tensor
    `x` using native torch.

    It converts the tensor values into int8 values and returns the
    quantized tensor along with the scaling factor used for quantization.
    """
    assert (x.shape[-1] % group_size == 0
            ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    iinfo = torch.iinfo(dtype)
    int8_min = iinfo.min
    int8_max = iinfo.max

    x_ = x.reshape(x.numel() // group_size, group_size)
    # Use float32 for scale calculation for stability
    amax = x_.abs().max(dim=-1,
                        keepdim=True)[0].clamp(min=eps).to(torch.float32)
    x_s = amax / int8_max
    x_q = (x_.to(torch.float32) / x_s).round().clamp(
        min=int8_min, max=int8_max).to(dtype)  # Round before clamping
    x_q = x_q.reshape(x.shape)
    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))

    return x_q, x_s


# For test
def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
    """This function performs fused moe with block-wise quantization using
    native torch."""
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)

    _, block_k = block_shape[0], block_shape[1]
    a_q, a_s = native_per_token_group_quant_int8(a, block_k)
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            inter_out = native_w8a8_block_matmul(a_q[mask],
                                                 w1[i],
                                                 a_s[mask],
                                                 w1_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
            act_out = SiluAndMul().forward_native(inter_out)
            act_out_q, act_out_s = native_per_token_group_quant_int8(
                act_out, block_k)
            act_out = act_out.to(torch.float32)
            out[mask] = native_w8a8_block_matmul(act_out_q,
                                                 w2[i],
                                                 act_out_s,
                                                 w2_s[i],
                                                 block_shape,
                                                 output_dtype=a.dtype)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)


DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222]
N = [128, 1024]
K = [256, 4096]
E = [8, 24]
TOP_KS = [2, 6]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]


@pytest.fixture(autouse=True, scope="module")
def setup_cuda():
    """Sets the default CUDA device for all tests in this module."""
    torch.set_default_device("cuda")


@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
                         itertools.product(M, N, K, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_matmul(M, N, K, block_size, out_dtype, seed):
    torch.manual_seed(seed)
    factor_for_scale = 1e-2
    int8_info = torch.iinfo(torch.int8)
    int8_max, int8_min = int8_info.max, int8_info.min

    A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * int8_max
    A_fp8 = A_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)

    B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * int8_max
    B_fp8 = B_fp32.clamp(min=int8_min, max=int8_max).to(torch.float8_e4m3fn)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k

    As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
    Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale

    ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
                                       out_dtype)
    out = w8a8_block_int8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.001


@pytest.mark.parametrize(
    "M, N, K, E, topk, block_size, dtype, seed",
    itertools.product(M, N, K, E, TOP_KS, BLOCK_SIZE, DTYPES, SEEDS))
@torch.inference_mode()
def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
    """Tests the fused_moe kernel with W8A8 INT8 block quantization against a
    native torch reference."""
    torch.manual_seed(seed)
    # Use a smaller factor for scale initialization to prevent large
    # values/overflow especially when output dtype might be float16
    factor_for_scale = 1e-2
    int8_info = torch.iinfo(torch.int8)
    int8_max, int8_min = int8_info.max, int8_info.min

    a = torch.randn((M, K), dtype=dtype) / 10

    w1_fp32 = (torch.rand(
        (E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max
    w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max
    w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8)

    block_n, block_k = block_size[0], block_size[1]
    n_tiles_w1 = (2 * N + block_n - 1) // block_n
    n_tiles_w2 = (K + block_n - 1) // block_n
    k_tiles_w1 = (K + block_k - 1) // block_k
    k_tiles_w2 = (N + block_k - 1) // block_k

    w1_s = (torch.rand(
        (E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale)
    w2_s = (torch.rand(
        (E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale)

    score = torch.randn((M, E), dtype=dtype)

    # Set the context to avoid lots of warning spam.
    vllm_config = VllmConfig()
    with set_current_vllm_config(vllm_config):
        out = fused_moe(
            a,
            w1,
            w2,
            score,
            topk,
            renormalize=False,
            use_int8_w8a8=True,
            w1_scale=w1_s,
            w2_scale=w2_s,
            block_shape=block_size,
        )
    ref_out = torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk,
                                        block_size)

    # Check results
    rel_diff = (torch.mean(
        torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
                torch.mean(torch.abs(ref_out.to(torch.float32))))
    assert rel_diff < 0.06
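
Note that native_per_token_group_quant_int8 above collapses to plain per-token quantization when group_size equals the hidden dimension, which is the activation granularity named in the commit title. A small round-trip check, assuming the new test module is importable as tests.kernels.test_block_int8 (it imports vllm at module scope and skips on GPUs older than compute capability 7.0):

import torch

# Assumed import path for the reference quantizer defined in the new test file.
from tests.kernels.test_block_int8 import native_per_token_group_quant_int8

x = torch.randn(8, 256, dtype=torch.float32)
# group_size == hidden size -> one scale per token.
x_q, x_s = native_per_token_group_quant_int8(x, group_size=256)
print(x_q.dtype, x_q.shape, x_s.shape)  # torch.int8, (8, 256), (8, 1)

# Dequantize and compare: the error stays within one quantization step per element.
x_dq = x_q.to(torch.float32) * x_s
print((x_dq - x).abs().max())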
