
Commit aa2cd2c

tlrmchlsmth and mgoin authored
[Bugfix] Disable w16a16 2of4 sparse CompressedTensors24 (#12417)
Signed-off-by: Tyler Michael Smith <[email protected]> Co-authored-by: mgoin <[email protected]>
1 parent 9ddc352 commit aa2cd2c
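
Context for the fix: "w16a16" means both weights and activations stay in 16-bit floating point (fp16/bf16), i.e. the layer is 2:4 sparse but not quantized. The new test for that path (test_cutlass_sparse_gemm below) is marked skipped because the sparse CUTLASS kernel produces bad output, and the CompressedTensors24 scheme stops selecting it. A purely illustrative sketch of such a guard follows; the function and parameter names are assumptions, not the code from this commit.

# Hypothetical sketch only -- not the code from this commit. It illustrates the
# idea of refusing the 2:4 sparse CUTLASS path when neither weights nor
# activations are quantized (w16a16), so the layer falls back to a dense kernel.
def supports_sparse_cutlass_24(weight_quant, input_quant) -> bool:
    if weight_quant is None and input_quant is None:
        # Sparse-only 16-bit GEMM: the CUTLASS 2:4 kernel currently produces
        # incorrect results here, so report the scheme as unsupported.
        return False
    return True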

File tree

6 files changed (+263, -169 lines)


tests/kernels/test_cutlass.py

Lines changed: 3 additions & 25 deletions
@@ -2,7 +2,7 @@
 
 Run `pytest tests/kernels/test_cutlass.py`.
 """
-from typing import Optional, Type
+from typing import Type
 
 import pytest
 import torch
@@ -11,6 +11,8 @@
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
 
+from .utils import baseline_scaled_mm, to_fp8, to_int8
+
 MNK_FACTORS = [
     (1, 256, 128),
     (1, 16384, 1024),
@@ -41,34 +43,10 @@
 capability = capability[0] * 10 + capability[1]
 
 
-def to_fp8(tensor: torch.Tensor):
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    return torch.round(tensor.clamp(
-        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
-
-
-def to_int8(tensor: torch.Tensor):
-    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
-
-
 def rand_int8(shape: tuple, device: str = "cuda"):
     return to_int8(torch.rand(shape, device=device) * 255 - 128)
 
 
-def baseline_scaled_mm(a: torch.Tensor,
-                       b: torch.Tensor,
-                       scale_a: torch.Tensor,
-                       scale_b: torch.Tensor,
-                       out_dtype: Type[torch.dtype],
-                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    output = (scale_a * (scale_b * (torch.mm(
-        a.to(dtype=torch.float32), b.to(dtype=torch.float32))))).to(out_dtype)
-    if bias is not None:
-        output = output + bias
-
-    return output
-
-
 def cutlass_fp8_gemm_helper(m: int,
                             n: int,
                             k: int,
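
The helpers deleted above (to_fp8, to_int8, baseline_scaled_mm) are not gone: they are now imported from tests/kernels/utils.py via the new `from .utils import ...` line, so this file and the new sparse tests can share them. For readers unfamiliar with what that baseline computes, here is a minimal CPU-only illustration of the same arithmetic (per-tensor scales applied to an fp32 matmul of the quantized operands); it is independent of the vLLM code and only mirrors the reference math:

import torch

# Quantize small fp32 matrices to int8 with per-tensor scales, multiply in
# fp32, then rescale -- the same reference arithmetic as baseline_scaled_mm.
a_fp32 = torch.randn(4, 8)
b_fp32 = torch.randn(8, 2)

scale_a = a_fp32.abs().max() / 127.0
scale_b = b_fp32.abs().max() / 127.0

a_int8 = torch.round(a_fp32 / scale_a).clamp(-128, 127).to(torch.int8)
b_int8 = torch.round(b_fp32 / scale_b).clamp(-128, 127).to(torch.int8)

# Dequantized product: scale_a * scale_b * (A_int8 @ B_int8), computed in fp32.
ref = (scale_a * scale_b) * torch.mm(a_int8.float(), b_int8.float())

# Should be close to the unquantized product, up to int8 rounding error.
print((ref - a_fp32 @ b_fp32).abs().max())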
tests/kernels/test_semi_structured.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
+"""Tests for sparse cutlass kernels
+
+Run `pytest tests/kernels/test_semi_structured.py`.
+"""
+from typing import Tuple, Type
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    sparse_cutlass_supported)
+from vllm.platforms import current_platform
+
+from .utils import baseline_scaled_mm, to_fp8, to_int8
+
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
+
+capability = current_platform.get_device_capability()
+capability = capability[0] * 10 + capability[1]
+
+
+def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.to(dtype=torch.bfloat16)
+
+
+def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor.to(dtype=torch.float16)
+
+
+def prune_to_2_4(tensor):
+    # Reshape tensor to [N, 4] where N is number of groups of 4
+    original_shape = tensor.shape
+    reshaped = tensor.reshape(-1, 4)
+
+    # Get indices of top 2 absolute values in each group of 4
+    _, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
+
+    # Create binary mask
+    mask = torch.zeros_like(reshaped)
+    mask.scatter_(dim=1,
+                  index=indices,
+                  src=torch.ones_like(indices, dtype=mask.dtype))
+
+    # Apply mask and reshape back
+    pruned = reshaped * mask
+
+    # Turn all -0.0 to 0.0
+    pruned[pruned == -0.0] = 0.0
+
+    return pruned.reshape(original_shape)
+
+
+def make_rand_sparse_tensors(
+        dtype: torch.dtype, m: int, n: int, k: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    a = torch.randn((m, k), device='cuda') * 5
+    b = torch.randn((n, k), device='cuda').t() * 5
+
+    b = prune_to_2_4(b.t()).t()
+
+    if dtype == torch.int8:
+        a, b = to_int8(a), to_int8(b)
+    elif dtype == torch.float8_e4m3fn:
+        a, b = to_fp8(a), to_fp8(b)
+    elif dtype == torch.float16:
+        a, b = to_fp16(a), to_fp16(b)
+    elif dtype == torch.bfloat16:
+        a, b = to_bf16(a), to_bf16(b)
+    else:
+        raise ValueError("unsupported dtype")
+
+    b_compressed, e = ops.cutlass_sparse_compress(b.t())
+
+    # Compressed B, Metadata, Original A, B
+    return b_compressed, e, a, b
+
+
+@pytest.mark.skipif(not sparse_cutlass_supported(),
+                    reason="Sparse CUTLASS is not supported on this GPU type.")
+# Test working with a subset of A and B for sparse matmul
+def test_cutlass_sparse_subset():
+
+    big_m = 1024
+    m, n, k = 512, 512, 512
+
+    # Create tensors
+    b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn,
+                                                     big_m, n, k)
+    a = whole_a[0:m, 0:k]
+    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
+    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
+
+    out = ops.cutlass_scaled_sparse_mm(a,
+                                       b_comp,
+                                       e,
+                                       scale_a,
+                                       scale_b,
+                                       out_dtype=torch.bfloat16)
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
+
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
+
+
+MNK_FACTORS = [
+    (1, 256, 128),
+    (1, 16384, 1024),
+    (1, 24576, 512),
+    (16, 256, 512),
+    (16, 16384, 128),
+    (16, 24576, 4096),
+    (32, 8192, 4096),
+    (32, 16384, 4096),
+    (33, 1024, 1024),
+    (33, 8192, 128),
+    (64, 2048, 512),
+    (64, 16384, 1024),
+    (100, 8192, 512),
+    (128, 32768, 4096),
+    (256, 4096, 4096),
+    (512, 256, 1024),
+    (512, 8192, 4096),
+    (512, 16384, 128),
+    (512, 24576, 128),
+]
+
+
+# Test working with a subset of A and B for sparse matmul
+@pytest.mark.skip(reason="2of4 sparse w16a16 CUTLASS produces bad output.")
+@pytest.mark.skipif(not sparse_cutlass_supported(),
+                    reason="Sparse CUTLASS is not supported on this GPU type.")
+@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: Type[torch.dtype]):
+
+    # Create tensors
+    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
+    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
+    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
+
+    out = ops.cutlass_scaled_sparse_mm(a,
+                                       b_comp,
+                                       e,
+                                       scale_a,
+                                       scale_b,
+                                       out_dtype=dtype)
+    baseline = F.linear(a, b.T)
+
+    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1e-2)
+
+
+@pytest.mark.skipif(not sparse_cutlass_supported(),
+                    reason="Sparse CUTLASS is not supported on this GPU type.")
+@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
+@pytest.mark.skipif(not current_platform.has_device_capability(89),
+                    reason="FP8 is not supported on this GPU type.")
+def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int):
+
+    # Create tensors
+    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
+    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
+    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
+
+    out = ops.cutlass_scaled_sparse_mm(a,
+                                       b_comp,
+                                       e,
+                                       scale_a,
+                                       scale_b,
+                                       out_dtype=torch.bfloat16)
+
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
+
+    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
+
+
+@pytest.mark.skipif(not sparse_cutlass_supported(),
+                    reason="Sparse CUTLASS is not supported on this GPU type.")
+@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
+@pytest.mark.parametrize("per_act_token", [True, False])
+@pytest.mark.parametrize("per_out_ch", [True, False])
+@pytest.mark.parametrize("use_bias", [True, False])
+def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
+                                  per_out_ch: bool, use_bias: bool):
+
+    # Create tensors
+    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
+    scale_a = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
+    scale_b = (torch.randn((1, 1), device="cuda", dtype=torch.float32))
+
+    out = ops.cutlass_scaled_sparse_mm(a,
+                                       b_comp,
+                                       e,
+                                       scale_a,
+                                       scale_b,
+                                       out_dtype=torch.bfloat16)
+
+    baseline = baseline_scaled_mm(a,
+                                  b,
+                                  scale_a,
+                                  scale_b,
+                                  out_dtype=torch.bfloat16)
+
+    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
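
A note on the "2of4" (2:4) pattern these tests exercise: the sparse CUTLASS kernels require that at most two of every four consecutive values along the reduction dimension are non-zero, which is what prune_to_2_4 enforces before ops.cutlass_sparse_compress packs B. A small CPU-only illustration of the same magnitude-based pruning rule, independent of the test file:

import torch

# Keep the two largest-magnitude values in every group of four consecutive
# elements -- the 2:4 structured-sparsity pattern required by the kernels.
w = torch.tensor([[0.9, 0.1, 0.4, 0.05, -2.0, 0.3, 0.0, 1.1]])
groups = w.reshape(-1, 4)
_, idx = torch.topk(groups.abs(), k=2, dim=1)      # top-2 indices per group
mask = torch.zeros_like(groups).scatter_(1, idx, 1.0)
pruned = (groups * mask).reshape(w.shape)
print(pruned)  # approx. [[0.9, 0.0, 0.4, 0.0, -2.0, 0.0, 0.0, 1.1]]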

0 commit comments
