
Commit b09307a

f8 roofline: make utils reusable (#731)
Summary:

Moves the float8 roofline gemm and memory traffic utils to
`torchao.float8.roofline_utils`, so they can be reused in other places. For now,
I want to use this in ao_benchmarks.

Test Plan:

```
python benchmarks/float8/float8_roofline.py ~/local/tmp/test.txt --gemm_time_strategy roofline
```

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent c0b0731 commit b09307a
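As a rough illustration of the intended reuse (this snippet is not part of the commit), the relocated helpers could be consumed from another benchmark roughly as follows, assuming the signatures shown in the code removed below are unchanged by the move; the sympy symbols, example shape, and scaling settings are illustrative:

```python
# Hypothetical usage sketch, not from this PR: build a symbolic roofline
# estimate from the relocated helpers, then plug in a concrete shape.
import sympy
import torch

from torchao.float8.roofline_utils import (
    get_gemm_time_sympy,
    get_float8_mem_sympy,
)

M, K, N = sympy.symbols("M K N")

# roofline gemm time for the three linear fwd/bwd gemms, bf16 vs float8
bf16_gemm_time_s = get_gemm_time_sympy(M, K, N, torch.bfloat16)
fp8_gemm_time_s = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)

# modeled memory-traffic overhead of the float8 casts and scales
fp8_mem_time_s = get_float8_mem_sympy(
    M, K, N,
    scaling_type_input="dynamic",
    scaling_type_weight="dynamic",
    scaling_type_grad_output="dynamic",
)

# substitute an example shape to compare estimated kernel time
shape = {M: 4096, K: 4096, N: 4096}
print("bf16 gemm time (s):", bf16_gemm_time_s.subs(shape))
print("fp8 gemm + mem time (s):", (fp8_gemm_time_s + fp8_mem_time_s).subs(shape))
```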


2 files changed: +230 -212 lines changed


benchmarks/float8/float8_roofline.py

+10 -212
@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
 """
 This is a script to estimate the benefit from converting a `torch.nn.Linear`
 layer to float8, by estimating the difference in e2e GPU kernel time between:
@@ -45,26 +51,10 @@
 import torch
 import torch.utils.benchmark as benchmark
 
-BYTES_PER_EL_FLOAT8 = 1
-BYTES_PER_EL_BF16 = 2
-
-# https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
-H100_BF16_PEAK_TOPS = 989e12
-H100_FP8_PEAK_TOPS = 1979e12
-
-# 2.4 TB per second, custom to Meta's H100 variant
-H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12
-
-# based on quick experimental observation with sample large inputs
-H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6
-
-# based on previous experience looking at pointwise triton kernels with large inputs,
-# which would hit about 2.2k GBPS on Meta's H100 variant
-H100_PCT_ACHIEVABLE_MEM_BW = 0.92
-
-# Source: run a triton kernel with a single element read/write on an H100 and
-# measure GPU time from the trace
-TRITON_KERNEL_1_ELEMENT_TIME_SEC = 0.002 * 0.001
+from torchao.float8.roofline_utils import (
+    get_gemm_time_sympy,
+    get_float8_mem_sympy,
+)
 
 
 def benchmark_fn_in_sec(f, *args, **kwargs):
@@ -78,90 +68,6 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
     return measurement.mean
 
 
-def get_tensor_memory_traffic_bytes(
-    dim0,
-    dim1,
-    scaling_type: str,
-    fuse_with_prev=False,
-    model_torch_compile_limitations=False,
-):
-    # assumes input bf16, output f8
-    numel = dim0 * dim1
-
-    if scaling_type == "dynamic":
-        # x_bf16 = ...
-        # kernel 1:               x_bf16 -> max_abs_stage_1 -> tmp
-        # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-        # kernel 3:               x_bf16, max_abs -> to_float8 -> x_fp8
-
-        if fuse_with_prev:
-            kernel_1_rw = 0
-        else:
-            # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
-            kernel_1_rw = BYTES_PER_EL_BF16 * numel
-
-        # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
-        kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
-
-        if model_torch_compile_limitations:
-            # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-            # has an extra memory read of the input in fp8
-            # context: https://github.com/pytorch/pytorch/issues/130015
-            tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-        else:
-            tc_adjustment = 0
-
-        return kernel_1_rw + kernel_3_rw + tc_adjustment
-
-    else:
-        assert scaling_type == "delayed", "unsupported"
-        # x_bf16 = ...
-        # kernel 1:               x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp
-        # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-        # kernel 3 (not modeled): scale -> reciprocal -> inv_scale
-
-        if fuse_with_prev:
-            kernel_1_r = 0
-        else:
-            kernel_1_r = numel * BYTES_PER_EL_BF16
-        # write twice: once in row major, once in col-major
-        kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2
-
-        if model_torch_compile_limitations:
-            # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-            # has an extra memory read of the input in fp8
-            # context: https://github.com/pytorch/pytorch/issues/130015
-            tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-
-            # https://github.com/pytorch/pytorch/issues/128063
-            # instead of
-            #   kernel 1: x_bf16 -> max(abs(x)), x_fp8
-            #   kernel 2: not modeled
-            #   kernel 3: not modeled
-            # we get
-            #   kernel 1: x_bf16 -> max(abs(x))
-            #     reads: same as before
-            #     writes: 0
-            #   ...
-            #   kernel 4: x_bf16, scale -> x_fp8
-            #     reads: numel * BYTES_PER_EL_BF16
-            #     writes: 2 * numel * BYTES_PER_EL_FLOAT8
-            # Note that assuming worst case, this issue brings the memory
-            # traffic for delayed scaling to be equal to that of dynamic scaling.
-            tc_adjustment += (
-                # subtract writes from kernel 1
-                -1 * 2 * numel * BYTES_PER_EL_FLOAT8
-                # add reads for kernel 4
-                + numel * BYTES_PER_EL_BF16
-                # add writes for kernel 4
-                + 2 * numel * BYTES_PER_EL_FLOAT8
-            )
-        else:
-            tc_adjustment = 0
-
-        return kernel_1_r + kernel_1_w + tc_adjustment
-
-
 def get_gemm_times_cache(gemm_benchmarks_file: str):
     cache = {}
     with open(gemm_benchmarks_file, 'r') as f:
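For intuition, here is a back-of-the-envelope check (not part of the diff) of the dynamic-scaling branch removed above, using the byte-per-element constants deleted from this file and an arbitrary example shape:

```python
# Illustrative arithmetic only: mirrors the removed get_tensor_memory_traffic_bytes
# logic for scaling_type="dynamic", fuse_with_prev=False,
# model_torch_compile_limitations=False.
BYTES_PER_EL_FLOAT8 = 1
BYTES_PER_EL_BF16 = 2

dim0, dim1 = 16384, 8192  # arbitrary example shape
numel = dim0 * dim1

# kernel 1: read the bf16 input for the first max-abs reduction stage
kernel_1_rw = BYTES_PER_EL_BF16 * numel
# kernel 3: read bf16 input, write float8 twice (row-major and col-major)
kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel

# 6 bytes of traffic per element -> 768 MiB for this shape
print(kernel_1_rw + kernel_3_rw)
```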
@@ -176,114 +82,6 @@ def get_gemm_times_cache(gemm_benchmarks_file: str):
     return cache
 
 
-def get_gemm_time_sympy(M, K, N, dtype):
-    gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
-    if dtype is torch.bfloat16:
-        peak_tops = H100_BF16_PEAK_TOPS
-    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-        peak_tops = H100_FP8_PEAK_TOPS
-    gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
-    return gemm_time_s
-
-
-def get_float8_mem_sympy(
-    M,
-    K,
-    N,
-    model_torch_compile_limitations: bool = False,
-    scaling_type_input: str = "dynamic",
-    scaling_type_weight: str = "dynamic",
-    scaling_type_grad_output: str = "dynamic",
-):
-
-    assert scaling_type_input in ("dynamic", "delayed"), "unsupported"
-    assert scaling_type_weight in ("dynamic", "delayed"), "unsupported"
-    assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported"
-
-    # there are three gemms in the fwd/bwd of a linear:
-    #
-    #   input @ weight_t = output
-    #   MxK @ KxN => MxN
-    #
-    #   grad_output @ weight = grad_input
-    #   MxN @ NxK => MxK
-    #
-    #   input_t @ grad_output = grad_weight
-    #   KxM @ MxN => KxN
-
-    #
-    # forward - output
-    #
-    fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
-        M, K, scaling_type_input, fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations)
-    fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
-        K, N, scaling_type_weight, fuse_with_prev=False,
-        model_torch_compile_limitations=model_torch_compile_limitations)
-    fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
-
-    #
-    # backward - grad_input
-    #
-    gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
-        M, N, scaling_type_grad_output, fuse_with_prev=True,
-        model_torch_compile_limitations=model_torch_compile_limitations)
-    # already casted, assuming that we save weight from fw to bw
-    # TODO: model this if FSDP float8 all-gather is on
-    # TODO: model this if we don't save weight from fw to bw, and recompute instead
-    gi_fp8_weight_mem = 0
-
-    #
-    # backward - grad_weight
-    #
-    # TODO: model this if we don't save fp8 input from fw to bw
-    gw_fp8_input_t_mem = 0  # already casted
-    # this should be always 0
-    gw_fp8_grad_output_mem = 0  # already casted
-
-    bwd_fp8_total_mem = \
-        gi_fp8_grad_output_mem + gi_fp8_weight_mem + \
-        gw_fp8_input_t_mem + gw_fp8_grad_output_mem
-    fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
-    fp8_mem_time_s = (
-        fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
-    )
-
-    # Adjust final estimate for small kernel launches
-    # note that we do this adjustment here because we are assuming a minimal
-    # kernel overhead in the units of seconds, and the per-gemm-input memory
-    # estimations are in the units of bytes.
-    num_extra_kernels = 0
-    if scaling_type_input == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    elif scaling_type_input == "delayed":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-        # reciprocal of scale
-        num_extra_kernels += 1
-    if scaling_type_weight == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    elif scaling_type_weight == "delayed":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-        # reciprocal of scale
-        num_extra_kernels += 1
-    if scaling_type_grad_output == "dynamic":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-    elif scaling_type_grad_output == "delayed":
-        # second stage of max-abs reduction
-        num_extra_kernels += 1
-        # reciprocal of scale
-        num_extra_kernels += 1
-
-    extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
-
-    return fp8_mem_time_s + extra_kernel_overhead_s
-
-
 def run(
     outfile: str,
     gemm_time_strategy: str = "benchmarks",
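Similarly, a small numeric sketch (again not part of the diff) of the gemm roofline removed above, using the H100 peak-throughput constants that were deleted from this file and an arbitrary square shape:

```python
# Illustrative only: reproduces the removed get_gemm_time_sympy math for bf16
# vs float8 on the deleted H100 constants.
H100_BF16_PEAK_TOPS = 989e12
H100_FP8_PEAK_TOPS = 1979e12
H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6

M = K = N = 4096  # arbitrary example shape
# three gemms per linear fwd+bwd: output, grad_input, grad_weight
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N

bf16_gemm_time_s = gemm_ops / H100_BF16_PEAK_TOPS / H100_PCT_ACHIEVABLE_GEMM_TOPS
fp8_gemm_time_s = gemm_ops / H100_FP8_PEAK_TOPS / H100_PCT_ACHIEVABLE_GEMM_TOPS
# fp8 peak is 2x the bf16 peak here, so the fp8 gemm estimate is ~2x faster
print(bf16_gemm_time_s, fp8_gemm_time_s)
```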
