
Commit 7321775

enable torch.compile for mxfp8_cublas recipe
Summary:

This PR enables `MXLinear` with the `mxfp8_cublas` recipe to use torch.compile.

The current approach is a short-term workaround until pytorch/pytorch#147873 is done. Since we can't use e8m0 in torchinductor or triton yet, we create a custom op wrapper around `torch._scaled_mm` which takes `uint8` scales and does the cast to e8m0 inside the wrapper, where torchinductor can't see it.

Test Plan:

```
// this now works (although performance is not ideal due to #1788)
python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas

// we can also uncomment the hardware check and run the unit test
pytest test/prototype/mx_formats -s -k test_linear_compile
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e5687e3
ghstack-comment-id: 2701679811
Pull Request resolved: #1841
1 parent 609f534 commit 7321775
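To make the workaround concrete, here is a condensed sketch of the pattern the diff below adds: a `torch.library.custom_op` accepts `uint8` scales and reinterprets them as `torch.float8_e8m0fnu` only inside its body, which torchinductor treats as opaque, and a fake implementation is registered so `torch.compile` can propagate shapes and dtypes without running the real kernel. The `demo::` op name and argument names here are illustrative placeholders; the actual op added by this commit is `mylib::_scaled_mm_with_uint8_scales` in `torchao/prototype/mx_formats/mx_ops.py`.

```python
import torch

# Sketch only: hide the e8m0 cast inside a custom op so the compiled graph
# only ever sees uint8 scale tensors.
@torch.library.custom_op("demo::scaled_mm_uint8_scales", mutates_args=())
def scaled_mm_uint8_scales(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale_u8: torch.Tensor,
    b_scale_u8: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    # reinterpret the uint8 scales as e8m0 inside the op body, out of
    # torchinductor's sight
    a_scale = a_scale_u8.view(torch.float8_e8m0fnu)
    b_scale = b_scale_u8.view(torch.float8_e8m0fnu)
    return torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)


@scaled_mm_uint8_scales.register_fake
def _(a, b, a_scale_u8, b_scale_u8, out_dtype):
    # shape/dtype propagation only; no real kernel runs while tracing
    m, _ = a.shape
    _, n = b.shape
    return torch.empty(m, n, dtype=out_dtype, device=a.device)
```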

File tree

2 files changed, +99 -7 lines changed


test/prototype/mx_formats/test_mx_linear.py

Lines changed: 67 additions & 4 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
+from torchao.float8.float8_utils import is_row_major
 from torchao.prototype.mx_formats.config import (
     MXLinearConfig,
     MXLinearRecipeName,
@@ -22,6 +23,7 @@
     swap_linear_with_mx_inference_linear,
     swap_linear_with_mx_linear,
 )
+from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_4,
@@ -169,11 +171,18 @@ def test_activation_checkpointing():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
 )
 @pytest.mark.parametrize(
     "recipe_name",
-    ["mxfp8_emulated", "mxfp4_emulated", "mxfp8_cutlass", "mxfp4_cutlass"],
+    [
+        "mxfp8_emulated",
+        "mxfp4_emulated",
+        "mxfp8_cublas",
+        "mxfp8_cutlass",
+        "mxfp4_cutlass",
+    ],
 )
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
@@ -190,9 +199,9 @@ def test_linear_compile(recipe_name, bias):
     if not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for MX gemms")
 
-    if bias and recipe_name in ["mxfp8_cutlass", "mxfp4_cutlass"]:
+    if bias and recipe_name in ["mxfp8_cublas", "mxfp8_cutlass", "mxfp4_cutlass"]:
         # TODO(future PR): fix this, things are clearly broken with bias=True
-        pytest.skip("this test is broken for cutlass recipes with bias=True")
+        pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
     M, K, N = 128, 256, 512
     input_shape = (M, K)
@@ -285,6 +294,60 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(),
+    reason="triton does not work yet on CUDA capability 10.0",
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MX gemms require CUDA capability 10.0",
+)
+def test_scaled_mm_wrapper():
+    # today, e8m0 isn't supported in torchinductor or triton
+    # for now, work around this by creating a wrapper around torch._scaled_mm
+    # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper
+
+    M, K, N = 128, 256, 512
+    BLOCK_SIZE = 32
+    a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
+    b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
+
+    a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+    b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
+
+    out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16)
+
+    def wrapped(a, b, a_scale, b_scale, out_dtype):
+        if is_row_major(b.stride()):
+            b = b.t().contiguous().t()
+        res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype)
+        return res
+
+    wrapped = torch.compile(wrapped)
+
+    # correct memory format of `b`
+    out2 = wrapped(
+        a,
+        b.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out2, atol=0, rtol=0)
+
+    # incorrect memory format of `b`
+    b_col_major = b.t().contiguous().t()
+    out3 = wrapped(
+        a,
+        b_col_major.t(),
+        a_scale.view(torch.uint8),
+        b_scale.view(torch.uint8),
+        out_dtype=torch.bfloat16,
+    )
+    torch.testing.assert_close(out, out3, atol=0, rtol=0)
+
+
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),

torchao/prototype/mx_formats/mx_ops.py

Lines changed: 32 additions & 3 deletions
@@ -36,6 +36,35 @@
 MX_OPS_TABLE: Dict[Any, Any] = {}
 
 
+@torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=())
+def _scaled_mm_with_uint8_scales(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scale: torch.Tensor,
+    b_scale: torch.Tensor,
+    out_dtype: torch.dtype,
+) -> torch.Tensor:
+    """
+    Until https://github.com/pytorch/pytorch/issues/147873 is done, we need to
+    work around the lack of support for `torch.float8_e8m0fnu` in
+    torchinductor. We do so by hiding the cast of scales to e8m0 inside a
+    custom op.
+    """
+    # cast back to e8m0 where torchinductor can't see it
+    a_scale = a_scale.view(torch.float8_e8m0fnu)
+    b_scale = b_scale.view(torch.float8_e8m0fnu)
+    res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype)
+    return res
+
+
+@_scaled_mm_with_uint8_scales.register_fake
+def _(a, b, a_scale, b_scale, out_dtype):
+    m, k = a.shape
+    k2, n = b.shape
+    res = torch.empty(m, n, dtype=out_dtype, device=a.device)
+    return res
+
+
 def implements(aten_ops):
     """Register aten ops to the mx op table"""
 
@@ -83,11 +112,11 @@ def mx_mm(aten_op, args, kwargs=None):
     if a._elem_dtype == torch.float8_e4m3fn:
         assert b._elem_dtype == torch.float8_e4m3fn
         if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS:
-            res = torch._scaled_mm(
+            res = _scaled_mm_with_uint8_scales(
                 a._data,
                 b._data,
-                a_scale_block.view(torch.float8_e8m0fnu),
-                b_scale_block.view(torch.float8_e8m0fnu),
+                a_scale_block,
+                b_scale_block,
                 out_dtype=torch.bfloat16,
             )
         else:
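As a usage sketch (not part of the diff), a compiled caller passes the scales as `uint8` views of the e8m0 tensors, mirroring what the updated `mx_mm` above and the new `test_scaled_mm_wrapper` test do; the shapes, block size, and all-ones scales below are illustrative, and a GPU with MX gemm support (CUDA capability 10.0) is assumed. Because the view to `uint8` happens outside the compiled region, the e8m0 dtype never appears in the traced graph.

```python
import torch

from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales

M, K, N, BLOCK_SIZE = 128, 256, 512, 32  # illustrative sizes
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)
b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu)


@torch.compile
def compiled_mm(a, b, a_scale_u8, b_scale_u8):
    # only uint8 scales flow through the traced graph; the custom op casts
    # them back to e8m0 internally, where torchinductor can't see it
    return _scaled_mm_with_uint8_scales(
        a, b, a_scale_u8, b_scale_u8, out_dtype=torch.bfloat16
    )


# b.t() hands torch._scaled_mm a column-major second operand, matching the
# layout handling in the test above
out = compiled_mm(a, b.t(), a_scale.view(torch.uint8), b_scale.view(torch.uint8))
```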
