Commit c0b0731

Specify output dtype to torch.float32 in _foreach_norm (#727)
one less kernel

Signed-off-by: Masaki Kozuki <[email protected]>
1 parent 8002099 commit c0b0731

File tree

1 file changed, +2 -2 lines changed

torchao/float8/fsdp_utils.py (+2 -2)
@@ -59,15 +59,15 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return

     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local().to(dtype=torch.float32)
+    local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
         float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
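A minimal sketch (not part of the commit) of the "one less kernel" rationale: without the dtype argument, the inf-norms come back in the weights' dtype and the scales must be cast to float32 afterwards, which costs a separate cast kernel; passing dtype=torch.float32 makes the foreach-norm kernel write float32 directly. The toy bfloat16 tensors below are assumed stand-ins for the real DTensor-backed FSDP weights, and the dtype keyword requires a recent PyTorch where torch._foreach_norm accepts it (which this change relies on).

import math

import torch

# Toy stand-ins for the FSDP-sharded float8 linear weights; the real code
# operates on DTensor-backed weights gathered from the module.
weights = [torch.randn(16, 16, dtype=torch.bfloat16) for _ in range(3)]

# Before this commit: the inf-norms inherit the weights' dtype, so the
# scales later needed a separate cast kernel (.to(dtype=torch.float32)).
max_weights = torch._foreach_norm(weights, ord=math.inf)
assert max_weights[0].dtype is torch.bfloat16

# After: the foreach-norm kernel writes float32 directly (one less kernel).
max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)
assert max_weights[0].dtype is torch.float32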
