1 file changed: +2 −2 lines changed

@@ -59,15 +59,15 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         return
 
     # inf-norm is equivalent to max(abs(w))
-    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float32)  # Partial
     amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local().to(dtype=torch.float32)
+    local_scale_tensor = scale_tensor.to_local()
     for i, float8_linear in enumerate(float8_linears):
         float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
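For context: the patch moves the float32 cast from the final per-parameter scale to the amax computation itself, so the stacked amax, the clamp/all-reduce, and the division that produces the scale all happen in float32 instead of the weight dtype, and the trailing `.to(dtype=torch.float32)` becomes unnecessary. Below is a minimal sketch of the numerical difference, assuming a single bfloat16 weight with no FSDP/DTensor sharding; `EPS`, the tensor shape, and the use of `torch.linalg.vector_norm` (rather than the private `torch._foreach_norm`) are illustrative assumptions, not values from the patch.

```python
import torch

# Assumed epsilon, standing in for the EPS clamp in the patched function.
EPS = 1e-12
# Hypothetical bfloat16 weight; in the real code this is a sharded DTensor.
w = torch.randn(4096, 4096, dtype=torch.bfloat16)

# Old order: amax and scale computed in bfloat16, cast to float32 at the end.
amax_bf16 = torch.linalg.vector_norm(w, ord=float("inf"))  # bfloat16 amax
scale_old = (torch.finfo(torch.float8_e4m3fn).max / amax_bf16.clamp(min=EPS)).to(torch.float32)

# New order: amax produced directly in float32, so the division is float32 too.
amax_fp32 = torch.linalg.vector_norm(w, ord=float("inf"), dtype=torch.float32)
scale_new = torch.finfo(torch.float8_e4m3fn).max / amax_fp32.clamp(min=EPS)

# The two scales can differ by a bfloat16 rounding step in the intermediate amax.
print(scale_old.item(), scale_new.item())
```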