@@ -121,6 +121,10 @@ class XlaFullyShardedDataParallel(nn.Module):
        execute_sharding_on_init (bool, Optional):
            if ``True``, immediately execute the parameter sharding via
            `xm.mark_step` to free up the memory of the full parameters.
+        optimization_barrier_on_output (bool, Optional):
+            if ``True``, apply `xm.optimization_barrier_` on the FSDP module's
+            outputs. This avoids fusion (by the XLA compiler) with subsequent
+            computation after the FSDP module and could save additional memory.
        use_all_gather_via_all_reduce (bool, Optional):
            if ``True``, use PyTorch XLA 1.10's all_gather implementation,
            which performs all_gather via padding and all_reduce and avoids
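
For context, a minimal usage sketch of the new flag follows; it is not part of the patch, and the `torch_xla.distributed.fsdp` import path and the toy model are assumptions that may differ from this checkout.

# Hypothetical usage sketch: wrap a small module and enable the new flag.
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP  # assumed import path

device = xm.xla_device()
model = FSDP(
    nn.Linear(16, 16).to(device),
    reshard_after_forward=True,
    optimization_barrier_on_output=True,   # new flag introduced by this patch
    use_all_gather_via_all_reduce=False,   # new default in this patch
)
out = model(torch.randn(4, 16, device=device))
out.sum().backward()
xm.mark_step()  # materialize the lazily traced graph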
@@ -139,7 +143,8 @@ def __init__(
        reshard_after_forward: bool = True,
        flatten_parameters: bool = True,
        execute_sharding_on_init: bool = True,
-        use_all_gather_via_all_reduce: bool = True,
+        optimization_barrier_on_output: bool = True,
+        use_all_gather_via_all_reduce: bool = False,
        mark_step_on_freeing: bool = False,
    ):
        if isinstance(module, XlaFullyShardedDataParallel):
@@ -165,6 +170,7 @@ def __init__(
        self.world_size = xm.xrt_world_size()
        self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
        self.flatten_parameters = flatten_parameters
+        self.optimization_barrier_on_output = optimization_barrier_on_output
        if use_all_gather_via_all_reduce:
            self.all_gather_op = all_gather_via_all_reduce
        else:
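
The `all_gather_via_all_reduce` branch selected above implements the padding-plus-all_reduce scheme described in the docstring. A rough sketch of that idea, using a hypothetical helper name rather than the file's actual implementation:

# Rough sketch of all_gather built from zero-padding and a summing all_reduce.
import torch
import torch_xla.core.xla_model as xm

def all_gather_sketch(shard: torch.Tensor) -> torch.Tensor:
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    # Each rank writes its shard into its own slot of a zero buffer that
    # covers all ranks; everything outside the slot stays zero.
    buf = torch.zeros(
        world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    buf[rank * shard.numel():(rank + 1) * shard.numel()] = shard.flatten()
    # Summing the zero-padded buffers across ranks yields the concatenation
    # of all shards, i.e. the all_gather result.
    return xm.all_reduce(xm.REDUCE_SUM, buf)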
@@ -591,6 +597,16 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        if self.reshard_after_forward:
            self._free_full_params()

+        if self.optimization_barrier_on_output:
+            # Apply the XLA compiler optimization barrier to avoid fusion with
+            # subsequent computation. This potentially saves additional memory
+            # since fusion sometimes results in higher memory consumption.
+            def _apply_barrier(t):
+                xm.optimization_barrier_([t])
+                return t
+
+            outputs = apply_to_tensors(_apply_barrier, outputs)
+
        # Register pre-backward hooks to all-gather the params for the backward
        # pass (if output's grad was needed). This won't register anything if
        # we are in eval mode.
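
`apply_to_tensors` used above (imported elsewhere in this file) applies the callback to every tensor leaf of the possibly nested module outputs. A simplified stand-in, under the assumption that the real helper behaves roughly this way:

# Simplified stand-in for apply_to_tensors: map a function over every tensor
# leaf while preserving the container structure (dicts, lists, tuples, sets).
from typing import Any, Callable
import torch

def apply_to_tensors_sketch(fn: Callable, container: Any) -> Any:
    if torch.is_tensor(container):
        return fn(container)
    if isinstance(container, dict):
        return {k: apply_to_tensors_sketch(fn, v) for k, v in container.items()}
    if isinstance(container, (list, tuple, set)):
        return type(container)(apply_to_tensors_sketch(fn, x) for x in container)
    return container  # non-tensor leaves are returned unchanged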
@@ -865,12 +881,32 @@ def _rebuild_full_params(self) -> None:
        """
        if self.has_full_params:
            return
+        p_list, p_shard_list, p_data_list, p_shared_data_list = [], [], [], []
        for p, p_shard in zip(self.full_params, self.sharded_params):
            if not p._has_full_param:
                # gather full parameter from shards
-                p_padded = self.all_gather_op(p_shard).flatten().detach()
-                p.data = p_padded[:p_shard._orig_size.numel()].view(p_shard._orig_size)
-                p._has_full_param = True
+                p_padded = self.all_gather_op(p_shard.data).flatten().detach()
+                p_data = p_padded[:p_shard._orig_size.numel()].view(p_shard._orig_size)
+                p_list.append(p)
+                p_shard_list.append(p_shard)
+                p_data_list.append(p_data)
+                p_shared_data_list.append(p_shard.data)
+
+        if len(p_data_list) + len(p_shared_data_list) > 0:
+            # Apply the XLA compiler optimization barrier to avoid fusion of the
+            # full parameter reconstruction with other computation.
+            # Otherwise, the XLA compiler might fuse `_rebuild_full_params` in the
+            # forward pass with any `_rebuild_full_params` in the backward pass
+            # through common subexpression elimination (CSE) and keep the full
+            # parameters (not freeing them and rebuilding them later, essentially
+            # changing `reshard_after_forward` to `False` and using more memory).
+            xm.optimization_barrier_(p_data_list + p_shared_data_list)
+            for p, p_shard, p_data, p_shard_data in zip(p_list, p_shard_list,
+                                                        p_data_list,
+                                                        p_shared_data_list):
+                p.data = p_data
+                p_shard.data = p_shard_data
+                p._has_full_param = True
        self.has_full_params = True

    @torch.no_grad()
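
The reconstruction above works because the flat parameter was zero-padded before sharding, so the all-gathered buffer only needs to be truncated and reshaped. A small numeric illustration of that slice-and-view step, with made-up sizes:

# Illustration of the `p_padded[:numel].view(orig_size)` step with a 3x3
# parameter sharded across 2 ranks (one padding element at the end).
import torch

orig_size = torch.Size([3, 3])                     # 9 elements
world_size = 2
shard_numel = (orig_size.numel() + world_size - 1) // world_size  # 5 per shard

# After all_gather the flat buffer holds world_size * shard_numel = 10 values;
# the trailing padding element is dropped by the slice before the reshape.
p_padded = torch.arange(world_size * shard_numel, dtype=torch.float32)
p_data = p_padded[:orig_size.numel()].view(orig_size)
assert p_data.shape == orig_size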
@@ -879,11 +915,27 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        if params is None:
            params = self.full_params
        self.has_full_params = False
+        p_list, p_data_list = [], []
        for p in params:
            if p._has_full_param:
                # free the original full parameter
-                p.data = self._dummy_data_placeholder
-                p._has_full_param = False
+                p_data = self._dummy_data_placeholder
+                p_list.append(p)
+                p_data_list.append(p_data)
+
+        if len(p_data_list) > 0:
+            # Apply the XLA compiler optimization barrier to avoid fusion of the
+            # full parameter freeing with other computation.
+            # Otherwise, the XLA compiler might fuse `_free_full_params` in the
+            # forward pass with any `_free_full_params` in the backward pass
+            # through common subexpression elimination (CSE) and keep the full
+            # parameters (not freeing them and rebuilding them later, essentially
+            # changing `reshard_after_forward` to `False` and using more memory).
+            xm.optimization_barrier_(p_data_list)
+            for p, p_data in zip(p_list, p_data_list):
+                p.data = p_data
+                p._has_full_param = False
+
        # immediately execute the parameter freeing as a workaround to undesired XLA fusion
        # see https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513 for details
        # TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
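
As used throughout this patch, `xm.optimization_barrier_` takes a list of XLA tensors and applies the barrier in place, so the computation producing those tensors cannot be fused (or CSE'd) with whatever consumes them later. A standalone sketch of the call itself:

# Standalone sketch of the barrier call: it mutates the listed tensors in
# place so later ops cannot be fused across the barrier by the XLA compiler.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(8, 8, device=device)
b = torch.randn(8, 8, device=device)
c = a @ b                        # computation we want to keep separate
xm.optimization_barrier_([c])    # in-place barrier over c
d = c.sum()                      # downstream op sits on the other side of it
xm.mark_step()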