+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD 3-Clause license found in the
+ # LICENSE file in the root directory of this source tree.
+
"""
This is a script to estimate the benefit from converting a `torch.nn.Linear`
layer to float8, by estimating the difference in e2e GPU kernel time between:

import torch
import torch.utils.benchmark as benchmark

- BYTES_PER_EL_FLOAT8 = 1
- BYTES_PER_EL_BF16 = 2
-
- # https://www.nvidia.com/en-us/data-center/h100/, divide by 2 because no sparsity
- H100_BF16_PEAK_TOPS = 989e12
- H100_FP8_PEAK_TOPS = 1979e12
-
- # 2.4 TB per second, custom to Meta's H100 variant
- H100_PEAK_MEM_BW_BYTES_SEC = 2.4e12
-
- # based on quick experimental observation with sample large inputs
- H100_PCT_ACHIEVABLE_GEMM_TOPS = 0.6
-
- # based on previous experience looking at pointwise triton kernels with large inputs,
- # which would hit about 2.2k GBPS on Meta's H100 variant
- H100_PCT_ACHIEVABLE_MEM_BW = 0.92
-
- # Source: run a triton kernel with a single element read/write on an H100 and
- # measure GPU time from the trace
- TRITON_KERNEL_1_ELEMENT_TIME_SEC = 0.002 * 0.001
+ from torchao.float8.roofline_utils import (
+     get_gemm_time_sympy,
+     get_float8_mem_sympy,
+ )
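# Sketch of combining the moved helpers into a per-linear roofline estimate.
# Assumed usage: the helpers are taken to keep the signatures of the local
# versions removed below; the names and shapes here are illustrative and not
# part of this diff.
import sympy

M, K, N = sympy.symbols("M K N")
bf16_gemm_time = get_gemm_time_sympy(M, K, N, torch.bfloat16)
fp8_gemm_time = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn)
fp8_cast_overhead = get_float8_mem_sympy(M, K, N)  # defaults assumed: dynamic scaling
roofline_speedup = bf16_gemm_time / (fp8_gemm_time + fp8_cast_overhead)
print(roofline_speedup.subs({M: 4096, K: 4096, N: 4096}))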


def benchmark_fn_in_sec(f, *args, **kwargs):
@@ -78,90 +68,6 @@ def benchmark_fn_in_sec(f, *args, **kwargs):
    return measurement.mean
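# Example use of the timing helper above (illustrative workload and shapes,
# not part of this diff):
if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
    bf16_mm_time_s = benchmark_fn_in_sec(torch.mm, a, b)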


- def get_tensor_memory_traffic_bytes(
-     dim0,
-     dim1,
-     scaling_type: str,
-     fuse_with_prev=False,
-     model_torch_compile_limitations=False,
- ):
-     # assumes input bf16, output f8
-     numel = dim0 * dim1
-
-     if scaling_type == "dynamic":
-         # x_bf16 = ...
-         # kernel 1:               x_bf16 -> max_abs_stage_1 -> tmp
-         # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-         # kernel 3:               x_bf16, max_abs -> to_float8 -> x_fp8
-
-         if fuse_with_prev:
-             kernel_1_rw = 0
-         else:
-             # kernel 1: read numel, write 0 (assume size(tmp) ~ 0)
-             kernel_1_rw = BYTES_PER_EL_BF16 * numel
-
-         # kernel 3: read in bf16, write twice in float8 (row-major and col-major)
-         kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel
-
-         if model_torch_compile_limitations:
-             # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-             # has an extra memory read of the input in fp8
-             # context: https://github.com/pytorch/pytorch/issues/130015
-             tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-         else:
-             tc_adjustment = 0
-
-         return kernel_1_rw + kernel_3_rw + tc_adjustment
-
-     else:
-         assert scaling_type == "delayed", "unsupported"
-         # x_bf16 = ...
-         # kernel 1:               x_bf16 -> max_abs_stage_1_and_to_float8 -> x_float8, tmp
-         # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs
-         # kernel 3 (not modeled): scale -> reciprocal -> inv_scale
-
-         if fuse_with_prev:
-             kernel_1_r = 0
-         else:
-             kernel_1_r = numel * BYTES_PER_EL_BF16
-         # write twice: once in row major, once in col-major
-         kernel_1_w = numel * BYTES_PER_EL_FLOAT8 * 2
-
-         if model_torch_compile_limitations:
-             # today, the kernel to do cast_to_fp8_row_major_and_col_major(input_bf16, ...)
-             # has an extra memory read of the input in fp8
-             # context: https://github.com/pytorch/pytorch/issues/130015
-             tc_adjustment = numel * BYTES_PER_EL_FLOAT8
-
-             # https://github.com/pytorch/pytorch/issues/128063
-             # instead of
-             #   kernel 1: x_bf16 -> max(abs(x)), x_fp8
-             #   kernel 2: not modeled
-             #   kernel 3: not modeled
-             # we get
-             #   kernel 1: x_bf16 -> max(abs(x))
-             #     reads: same as before
-             #     writes: 0
-             #   ...
-             #   kernel 4: x_bf16, scale -> x_fp8
-             #     reads: numel * BYTES_PER_EL_BF16
-             #     writes: 2 * numel * BYTES_PER_EL_FLOAT8
-             # Note that assuming worst case, this issue brings the memory
-             # traffic for delayed scaling to be equal to that of dynamic scaling.
-             tc_adjustment += (
-                 # subtract writes from kernel 1
-                 -1 * 2 * numel * BYTES_PER_EL_FLOAT8
-                 # add reads for kernel 4
-                 + numel * BYTES_PER_EL_BF16
-                 # add writes for kernel 4
-                 + 2 * numel * BYTES_PER_EL_FLOAT8
-             )
-         else:
-             tc_adjustment = 0
-
-         return kernel_1_r + kernel_1_w + tc_adjustment
-
-
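# Illustrative instance of the removed traffic model above (the shape is an
# assumption, not part of this diff): a 4096x4096 bf16 tensor cast with
# dynamic scaling, fuse_with_prev=False and no torch.compile adjustment.
numel = 4096 * 4096
kernel_1_bytes = 2 * numel              # read the bf16 input once
kernel_3_bytes = 2 * numel + 2 * numel  # read bf16, write fp8 row- and col-major
total_bytes = kernel_1_bytes + kernel_3_bytes  # 6 bytes/element, ~100.7 MB
mem_time_s = total_bytes / 2.4e12 / 0.92  # ~4.6e-5 s with the removed H100 numbers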
def get_gemm_times_cache(gemm_benchmarks_file: str):
    cache = {}
    with open(gemm_benchmarks_file, 'r') as f:
@@ -176,114 +82,6 @@ def get_gemm_times_cache(gemm_benchmarks_file: str):
    return cache


- def get_gemm_time_sympy(M, K, N, dtype):
-     gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N
-     if dtype is torch.bfloat16:
-         peak_tops = H100_BF16_PEAK_TOPS
-     elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-         peak_tops = H100_FP8_PEAK_TOPS
-     gemm_time_s = gemm_ops / peak_tops / H100_PCT_ACHIEVABLE_GEMM_TOPS
-     return gemm_time_s
-
-
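# Illustrative sanity check of the removed gemm roofline above (the shape is
# an assumption, not part of this diff): a square M = K = N = 4096 linear
# layer, using the H100 constants deleted earlier in this diff.
M = K = N = 4096
gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N  # fwd + grad_input + grad_weight, 2 FLOPs/MAC
bf16_gemm_time_s = gemm_ops / 989e12 / 0.6   # ~6.9e-4 s
fp8_gemm_time_s = gemm_ops / 1979e12 / 0.6   # ~3.5e-4 s, ~2x faster before cast overhead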
- def get_float8_mem_sympy(
-     M,
-     K,
-     N,
-     model_torch_compile_limitations: bool = False,
-     scaling_type_input: str = "dynamic",
-     scaling_type_weight: str = "dynamic",
-     scaling_type_grad_output: str = "dynamic",
- ):
-
-     assert scaling_type_input in ("dynamic", "delayed"), "unsupported"
-     assert scaling_type_weight in ("dynamic", "delayed"), "unsupported"
-     assert scaling_type_grad_output in ("dynamic", "delayed"), "unsupported"
-
-     # there are three gemms in the fwd/bwd of a linear:
-     #
-     # input @ weight_t = output
-     # MxK @ KxN => MxN
-     #
-     # grad_output @ weight = grad_input
-     # MxN @ NxK => MxK
-     #
-     # input_t @ grad_output = grad_weight
-     # KxM @ MxN => KxN
-
-     #
-     # forward - output
-     #
-     fwd_fp8_input_mem = get_tensor_memory_traffic_bytes(
-         M, K, scaling_type_input, fuse_with_prev=True,
-         model_torch_compile_limitations=model_torch_compile_limitations)
-     fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes(
-         K, N, scaling_type_weight, fuse_with_prev=False,
-         model_torch_compile_limitations=model_torch_compile_limitations)
-     fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem
-
-     #
-     # backward - grad_input
-     #
-     gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes(
-         M, N, scaling_type_grad_output, fuse_with_prev=True,
-         model_torch_compile_limitations=model_torch_compile_limitations)
-     # already casted, assuming that we save weight from fw to bw
-     # TODO: model this if FSDP float8 all-gather is on
-     # TODO: model this if we don't save weight from fw to bw, and recompute instead
-     gi_fp8_weight_mem = 0
-
-     #
-     # backward - grad_weight
-     #
-     # TODO: model this if we don't save fp8 input from fw to bw
-     gw_fp8_input_t_mem = 0  # already casted
-     # this should be always 0
-     gw_fp8_grad_output_mem = 0  # already casted
-
-     bwd_fp8_total_mem = \
-         gi_fp8_grad_output_mem + gi_fp8_weight_mem + \
-         gw_fp8_input_t_mem + gw_fp8_grad_output_mem
-     fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem
-     fp8_mem_time_s = (
-         fp8_total_mem / H100_PEAK_MEM_BW_BYTES_SEC / H100_PCT_ACHIEVABLE_MEM_BW
-     )
-
-     # Adjust final estimate for small kernel launches
-     # note that we do this adjustment here because we are assuming a minimal
-     # kernel overhead in the units of seconds, and the per-gemm-input memory
-     # estimations are in the units of bytes.
-     num_extra_kernels = 0
-     if scaling_type_input == "dynamic":
-         # second stage of max-abs reduction
-         num_extra_kernels += 1
-     elif scaling_type_input == "delayed":
-         # second stage of max-abs reduction
-         num_extra_kernels += 1
-         # reciprocal of scale
-         num_extra_kernels += 1
-     if scaling_type_weight == "dynamic":
-         # second stage of max-abs reduction
-         num_extra_kernels += 1
-     elif scaling_type_weight == "delayed":
-         # second stage of max-abs reduction
-         num_extra_kernels += 1
-         # reciprocal of scale
-         num_extra_kernels += 1
-     if scaling_type_grad_output == "dynamic":
-         # second stage of max-abs reduction
-         num_extra_kernels += 1
-     elif scaling_type_grad_output == "delayed":
-         # second stage of max-abs reduction
-         num_extra_kernels += 1
-         # reciprocal of scale
-         num_extra_kernels += 1
-
-     extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC
-
-     return fp8_mem_time_s + extra_kernel_overhead_s
-
-
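# Illustrative reading of the removed overhead adjustment above (not part of
# this diff): with the default all-"dynamic" scaling, input, weight and
# grad_output each add one un-modeled second-stage max-abs kernel.
num_extra_kernels_dynamic = 3
num_extra_kernels_delayed = 6  # delayed also adds a reciprocal-of-scale kernel per tensor
overhead_dynamic_s = num_extra_kernels_dynamic * (0.002 * 0.001)  # 6e-6 s
overhead_delayed_s = num_extra_kernels_delayed * (0.002 * 0.001)  # 1.2e-5 s
# These fixed costs are added on top of the bandwidth-limited fp8_mem_time_s estimate.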
def run(
    outfile: str,
    gemm_time_strategy: str = "benchmarks",