@@ -125,6 +125,12 @@ class XlaFullyShardedDataParallel(nn.Module):
       if ``True``, use PyTorch XLA 1.10's all_gather implementation,
       which performs all_gather via padding and all_reduce and avoids
       the GRPC error (see https://github.com/pytorch/xla/issues/3423).
+    mark_step_on_freeing (bool, Optional):
+      if ``True``, call `xm.mark_step` upon freeing full parameters.
+      This is a temporary and inefficient workaround to avoid XLA compiler
+      fusion that breaks parameter freeing in nested FSDP. It is useful
+      only when ``reshard_after_forward`` is ``True``. See details in
+      https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513.
   """
 
   def __init__(
@@ -134,6 +140,7 @@ def __init__(
       flatten_parameters: bool = True,
       execute_sharding_on_init: bool = True,
       use_all_gather_via_all_reduce: bool = True,
+      mark_step_on_freeing: bool = False,
   ):
     is_forward_defined = (
         hasattr(module, "forward") and hasattr(module.forward, "__func__") and
@@ -232,8 +239,16 @@ def __init__(
 
     if execute_sharding_on_init:
       # Execute the parameter sharding immediately and free up the memory
-      xm.mark_step()
       gc.collect()
+      xm.mark_step()
+
+      # TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
+      # This is a temporary workaround before we have a mature solution to
+      # avoid undesired fusion with the XLA compiler optimization barrier (see
+      # https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513
+      # for details). This workaround notably increases the execution time and
+      # may trigger more compilation, so we need a permanent solution to #3455.
+      self._mark_step_on_freeing = mark_step_on_freeing
 
   def _get_gradient_predivide_factor(self, world_size: int) -> float:
     factor: int = 1
@@ -864,6 +879,11 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
         # free the original full parameter
         p.data = self._dummy_data_placeholder
         p._has_full_param = False
+    # immediately execute the parameter freeing as a workaround to undesired XLA fusion
+    # see https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513 for details
+    # TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
+    if self._mark_step_on_freeing:
+      xm.mark_step()
 
   def assert_state(self, state: Union[TrainingState,
                                       List[TrainingState]]) -> None:
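
For context on how the new flag is meant to be used, here is a minimal usage sketch. It assumes the class is importable as `torch_xla.distributed.fsdp.XlaFullyShardedDataParallel`; the toy model, optimizer, and training step are illustrative assumptions and not part of this change.

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
# Assumed import path for the FSDP class touched in this diff.
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()

# Nested wrapping: the inner FSDP instance reshards after forward (the case the
# docstring says the flag is useful for), which is where freeing full parameters
# can get fused away by the XLA compiler (pytorch/xla issue #3455). The new flag
# forces an xm.mark_step() right after freeing as a stopgap.
inner = FSDP(nn.Linear(1024, 1024).to(device), mark_step_on_freeing=True)
model = FSDP(
    nn.Sequential(inner, nn.Linear(1024, 10).to(device)),
    mark_step_on_freeing=True)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(8, 1024, device=device)
loss = model(x).sum()
loss.backward()
optimizer.step()
xm.mark_step()  # flush the pending graph, as in any XLA training loop
```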