Skip to content

Commit 1f20052

Browse files
mgoinsimon-moHandH1998
authored andcommitted
[Model] [Quantization] Support deepseek_v3 w8a8 fp8 block-wise quantization (vllm-project#11523)
Signed-off-by: mgoin <[email protected]> Signed-off-by: simon-mo <[email protected]> Signed-off-by: simon-mo <[email protected]> Co-authored-by: simon-mo <[email protected]> Co-authored-by: simon-mo <[email protected]> Co-authored-by: HandH1998 <[email protected]>
1 parent 98e2fc1 commit 1f20052

File tree

8 files changed

+931
-70
lines changed

8 files changed

+931
-70
lines changed

tests/kernels/test_block_fp8.py

Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
# Adapted from https://github.com/sgl-project/sglang/pull/2575
2+
import itertools
3+
4+
import pytest
5+
import torch
6+
7+
from vllm.model_executor.layers.activation import SiluAndMul
8+
from vllm.model_executor.layers.fused_moe import fused_moe
9+
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
10+
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
11+
from vllm.platforms import current_platform
12+
13+
if current_platform.get_device_capability() < (9, 0):
14+
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
15+
allow_module_level=True)
16+
17+
# Test configurations
18+
DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
19+
NUM_TOKENS = [7, 83, 2048]
20+
D = [512, 4096, 5120, 13824]
21+
GROUP_SIZE = [64, 128, 256, 512]
22+
M = [1, 7, 83, 512, 2048]
23+
N = [128, 512, 1024, 4096, 7748, 13824]
24+
K = [256, 4096, 5120, 3884, 13824]
25+
# Deepseek-V3's intermediate size 18432, so N is 18432*2/8=4608 at TP8
26+
# and its hidden size is 7168.
27+
M_moe = [1, 7, 83, 512, 2048]
28+
N_moe = [4608] # [128, 4608, 13824]
29+
K_moe = [7168] # [256, 7168, 13824]
30+
BLOCK_SIZE = [[128, 128]]
31+
E = [256] # [8, 24, 128, 256]
32+
TOP_KS = [1] # [1, 2, 6]
33+
OUT_DTYPES = [torch.bfloat16] # [torch.float32, torch.half, torch.bfloat16]
34+
SEEDS = [0]
35+
36+
37+
def native_per_token_group_quant_fp8(x,
38+
group_size,
39+
eps=1e-10,
40+
dtype=torch.float8_e4m3fn):
41+
"""Function to perform per-token-group quantization on an input tensor
42+
`x` using native torch."""
43+
assert x.shape[-1] % group_size == 0, ("the last dimension of `x` cannot "
44+
"be divisible by `group_size`")
45+
assert x.is_contiguous(), "`x` is not contiguous"
46+
47+
finfo = torch.finfo(dtype)
48+
fp8_min = finfo.min
49+
fp8_max = finfo.max
50+
51+
x_ = x.reshape(x.numel() // group_size, group_size)
52+
amax = x_.abs().max(dim=-1,
53+
keepdim=True)[0].clamp(min=eps).to(torch.float32)
54+
x_s = amax / fp8_max
55+
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
56+
x_q = x_q.reshape(x.shape)
57+
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size, ))
58+
59+
return x_q, x_s
60+
61+
62+
def native_w8a8_block_fp8_matmul(A,
63+
B,
64+
As,
65+
Bs,
66+
block_size,
67+
output_dtype=torch.float16):
68+
"""Matrix multiplication with block-wise quantization using native torch."""
69+
A = A.to(torch.float32)
70+
B = B.to(torch.float32)
71+
assert A.shape[-1] == B.shape[-1]
72+
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
73+
assert len(block_size) == 2
74+
block_n, block_k = block_size[0], block_size[1]
75+
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
76+
assert A.shape[:-1] == As.shape[:-1]
77+
78+
M = A.numel() // A.shape[-1]
79+
N, K = B.shape
80+
origin_C_shape = A.shape[:-1] + (N, )
81+
A = A.reshape(M, A.shape[-1])
82+
As = As.reshape(M, As.shape[-1])
83+
n_tiles = (N + block_n - 1) // block_n
84+
k_tiles = (K + block_k - 1) // block_k
85+
assert n_tiles == Bs.shape[0]
86+
assert k_tiles == Bs.shape[1]
87+
88+
C_shape = (M, N)
89+
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
90+
91+
A_tiles = [
92+
A[:, i * block_k:min((i + 1) * block_k, K)] for i in range(k_tiles)
93+
]
94+
B_tiles = [[
95+
B[j * block_n:min((j + 1) * block_n, N),
96+
i * block_k:min((i + 1) * block_k, K), ] for i in range(k_tiles)
97+
] for j in range(n_tiles)]
98+
C_tiles = [
99+
C[:, j * block_n:min((j + 1) * block_n, N)] for j in range(n_tiles)
100+
]
101+
As_tiles = [As[:, i:i + 1] for i in range(k_tiles)]
102+
103+
for i in range(k_tiles):
104+
for j in range(n_tiles):
105+
a = A_tiles[i]
106+
b = B_tiles[j][i]
107+
c = C_tiles[j]
108+
s = As_tiles[i] * Bs[j][i]
109+
c[:, :] += torch.matmul(a, b.t()) * s
110+
111+
C = C.reshape(origin_C_shape).to(output_dtype)
112+
return C
113+
114+
115+
def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
116+
"""Fused moe with block-wise quantization using native torch."""
117+
B, D = a.shape
118+
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
119+
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
120+
score = torch.softmax(score, dim=-1, dtype=torch.float32)
121+
topk_weight, topk_ids = torch.topk(score, topk)
122+
topk_weight = topk_weight.view(-1)
123+
topk_ids = topk_ids.view(-1)
124+
125+
_, block_k = block_shape[0], block_shape[1]
126+
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
127+
a_q = a_q.to(torch.float32)
128+
for i in range(w1.shape[0]):
129+
mask = topk_ids == i
130+
if mask.sum():
131+
inter_out = native_w8a8_block_fp8_matmul(a_q[mask],
132+
w1[i],
133+
a_s[mask],
134+
w1_s[i],
135+
block_shape,
136+
output_dtype=a.dtype)
137+
act_out = SiluAndMul().forward_native(inter_out)
138+
act_out_q, act_out_s = native_per_token_group_quant_fp8(
139+
act_out, block_k)
140+
act_out = act_out.to(torch.float32)
141+
out[mask] = native_w8a8_block_fp8_matmul(act_out_q,
142+
w2[i],
143+
act_out_s,
144+
w2_s[i],
145+
block_shape,
146+
output_dtype=a.dtype)
147+
return (out.view(B, -1, w2.shape[1]) *
148+
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
149+
150+
151+
# Skip all tests if CUDA is not available
152+
pytest.importorskip("torch.cuda")
153+
154+
155+
@pytest.fixture(autouse=True)
156+
def setup_cuda():
157+
torch.set_default_device("cuda")
158+
159+
160+
@pytest.mark.parametrize("num_tokens,d,dtype,group_size,seed",
161+
itertools.product(NUM_TOKENS, D, DTYPES, GROUP_SIZE,
162+
SEEDS))
163+
@torch.inference_mode()
164+
def test_per_token_group_quant_fp8(num_tokens, d, dtype, group_size, seed):
165+
torch.manual_seed(seed)
166+
x = torch.rand(num_tokens, d, dtype=dtype)
167+
168+
ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
169+
out, scale = per_token_group_quant_fp8(x, group_size)
170+
171+
assert torch.allclose(out.to(torch.float32),
172+
ref_out.to(torch.float32),
173+
rtol=0.15)
174+
assert torch.allclose(scale, ref_scale)
175+
176+
177+
@pytest.mark.parametrize("M,N,K,block_size,out_dtype,seed",
178+
itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES,
179+
SEEDS))
180+
@torch.inference_mode()
181+
def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):
182+
torch.manual_seed(seed)
183+
factor_for_scale = 1e-2
184+
fp8_info = torch.finfo(torch.float8_e4m3fn)
185+
fp8_max, fp8_min = fp8_info.max, fp8_info.min
186+
187+
A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
188+
A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
189+
190+
B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
191+
B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
192+
193+
block_n, block_k = block_size[0], block_size[1]
194+
n_tiles = (N + block_n - 1) // block_n
195+
k_tiles = (K + block_k - 1) // block_k
196+
197+
As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale
198+
Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale
199+
200+
ref_out = native_w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size,
201+
out_dtype)
202+
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)
203+
204+
rel_diff = (torch.mean(
205+
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
206+
torch.mean(torch.abs(ref_out.to(torch.float32))))
207+
assert rel_diff < 0.001
208+
209+
210+
@pytest.mark.parametrize("M,N,K,E,topk,block_size,dtype,seed",
211+
itertools.product(M_moe, N_moe, K_moe, E, TOP_KS,
212+
BLOCK_SIZE, DTYPES, SEEDS))
213+
@torch.inference_mode()
214+
def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
215+
torch.manual_seed(seed)
216+
factor_for_scale = 1e-2
217+
fp8_info = torch.finfo(torch.float8_e4m3fn)
218+
fp8_max, fp8_min = fp8_info.max, fp8_info.min
219+
220+
a = torch.randn((M, K), dtype=dtype) / 10
221+
222+
w1_bf16 = (torch.rand(
223+
(E, 2 * N, K), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
224+
w1 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
225+
del w1_bf16
226+
227+
w2_bf16 = (torch.rand((E, K, N), dtype=torch.bfloat16) - 0.5) * 2 * fp8_max
228+
w2 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
229+
del w2_bf16
230+
231+
block_n, block_k = block_size[0], block_size[1]
232+
n_tiles_w1 = (2 * N + block_n - 1) // block_n
233+
n_tiles_w2 = (K + block_n - 1) // block_n
234+
k_tiles_w1 = (K + block_k - 1) // block_k
235+
k_tiles_w2 = (N + block_k - 1) // block_k
236+
237+
w1_s = torch.rand(
238+
(E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) * factor_for_scale
239+
w2_s = torch.rand(
240+
(E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) * factor_for_scale
241+
242+
score = torch.randn((M, E), dtype=dtype)
243+
244+
out = fused_moe(
245+
a,
246+
w1,
247+
w2,
248+
score,
249+
topk,
250+
renormalize=False,
251+
use_fp8_w8a8=True,
252+
w1_scale=w1_s,
253+
w2_scale=w2_s,
254+
block_shape=block_size,
255+
)
256+
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
257+
block_size)
258+
259+
print(f"{out.sum()=}")
260+
print(f"{ref_out.sum()=}")
261+
262+
rel_diff = (torch.mean(
263+
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
264+
torch.mean(torch.abs(ref_out.to(torch.float32))))
265+
assert rel_diff < 0.03

vllm/config.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ class ModelConfig:
161161
override default pooling config for the pooling model.
162162
logits_processor_pattern: Optional regex pattern specifying valid
163163
logits processor qualified names that can be passed with the
164-
`logits_processors` extra completion argument. Defaults to None,
164+
`logits_processors` extra completion argument. Defaults to None,
165165
which allows no processors.
166166
generation_config: Configuration parameter file for generation.
167167
"""
@@ -364,7 +364,7 @@ def __init__(self,
364364
def maybe_pull_model_tokenizer_for_s3(self, model: str,
365365
tokenizer: str) -> None:
366366
"""
367-
Pull the model config or tokenizer to a temporary
367+
Pull the model config or tokenizer to a temporary
368368
directory in case of S3.
369369
370370
Args:
@@ -866,14 +866,14 @@ def try_get_generation_config(self) -> Dict[str, Any]:
866866

867867
def get_diff_sampling_param(self) -> Dict[str, Any]:
868868
"""
869-
This method returns a dictionary containing the parameters
870-
that differ from the default sampling parameters, but only
871-
if `generation_config` is set. If `generation_config` is not
869+
This method returns a dictionary containing the parameters
870+
that differ from the default sampling parameters, but only
871+
if `generation_config` is set. If `generation_config` is not
872872
set, an empty dictionary is returned.
873873
874874
Returns:
875-
Dict[str, Any]: A dictionary with the differing sampling
876-
parameters if `generation_config` is set, otherwise an
875+
Dict[str, Any]: A dictionary with the differing sampling
876+
parameters if `generation_config` is set, otherwise an
877877
empty dictionary.
878878
"""
879879
if self.generation_config is None:

0 commit comments

Comments
 (0)