Commit 18fdd8c

enable mxfp8_cublas recipe in roofline script
Summary: Enables us to see roofline vs actual performance of this recipe.

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/20250305_test --mx_recipe_name mxfp8_cublas
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 1e5a431
ghstack-comment-id: 2701761490
Pull Request resolved: #1843

1 parent c509315

File tree: 2 files changed (+14 lines, -3 lines)

benchmarks/float8/float8_roofline.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -184,8 +184,11 @@ def get_gemm_times(
     elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"):
         scale_a = torch.ones(M, 1, device=device)
         scale_b = torch.ones(1, N, device=device)
+    elif mx_recipe_name == "mxfp8_cublas":
+        scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
     else:
-        assert False, "TODO add mx gemm here"
+        assert False, "TODO add cutlass mx gemm here"
 
     def do_matmul(A, B):
         return torch._scaled_mm(
```
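For context, in the mxfp8 recipes each operand carries one `float8_e8m0fnu` scale per 1x32 block along the reduction dimension K, which is why the new branch builds scales of shape `(M, K // 32)` and `(N, K // 32)`. Below is a minimal sketch of the kind of call `do_matmul` then issues, assuming a recent PyTorch build and a GPU whose cuBLAS supports mxfp8 scaled matmuls; the shapes and the `bfloat16` output dtype are illustrative choices, not taken from the script:

```python
import torch

# Illustrative shapes; the roofline script sweeps these.
M, K, N = 1024, 2048, 4096
device = "cuda"

# float8 operands; torch._scaled_mm expects the second operand
# transposed (column-major).
A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
B = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

# One e8m0 scale per 1x32 block along K, matching the shapes in the diff.
scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16)
```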

torchao/testing/float8/roofline_utils.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -165,7 +165,11 @@ def get_tensor_memory_traffic_ovhd_s(
             assert False, "unsupported"
 
     else:
-        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
+        assert mx_recipe_name in (
+            "mxfp8_emulated",
+            "mxfp8_cutlass",
+            "mxfp8_cublas",
+        ), "unsupported"
 
     if tensor_role == "weight":
         # x_bf16 = ...
@@ -219,7 +223,11 @@ def get_individual_gemm_time_sympy(
     num_writes = M * N
 
     if mx_recipe_name is not None:
-        assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported"
+        assert mx_recipe_name in (
+            "mxfp8_emulated",
+            "mxfp8_cutlass",
+            "mxfp8_cublas",
+        ), "unsupported"
         assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "unsupported"
         # adjust reads for MX scaling
         block_size = 32
```
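The `block_size = 32` context line shows how the model accounts for scale traffic: every 32 float8 elements (one byte each) along K share a single one-byte e8m0 scale, so scales add roughly 1/32, about 3%, to operand reads. Here is a back-of-the-envelope sketch of that accounting; the arithmetic below is an illustration, not the exact sympy expressions from `roofline_utils.py`:

```python
# Rough per-gemm byte counts under MX block scaling; illustrative only.
M, K, N = 4096, 4096, 4096
block_size = 32

bytes_per_elem = 1   # float8_e4m3fn / float8_e5m2 operands
bytes_per_scale = 1  # float8_e8m0fnu scales

operand_reads = (M * K + K * N) * bytes_per_elem
scale_reads = (M * K + K * N) // block_size * bytes_per_scale
output_writes = M * N * 2  # assuming a bf16 output

print(f"extra reads from scales: {scale_reads / operand_reads:.2%}")  # ~3.12%
```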
