Commit 9a87bca

adding mark_step_on_freeing as a temp workaround to #3455
1 parent 8d2675b commit 9a87bca

File tree

1 file changed: +21 -1 lines changed

torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py

Lines changed: 21 additions & 1 deletion
@@ -125,6 +125,12 @@ class XlaFullyShardedDataParallel(nn.Module):
       if ``True``, use PyTorch XLA 1.10's all_gather implementation,
       which performs all_gather via padding and all_reduce and avoids
       the GRPC error (see https://github.com/pytorch/xla/issues/3423).
+    mark_step_on_freeing (bool, Optional):
+      if ``True``, call `xm.mark_step` upon freeing full parameters.
+      This is a temporary and inefficient workaround to avoid XLA compiler
+      fusion that breaks parameter freeing in nested FSDP. It is useful
+      only when ``reshard_after_forward`` is ``True``. See details in
+      https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513.
   """
 
   def __init__(
@@ -134,6 +140,7 @@ def __init__(
       flatten_parameters: bool = True,
       execute_sharding_on_init: bool = True,
       use_all_gather_via_all_reduce: bool = True,
+      mark_step_on_freeing: bool = False,
   ):
     is_forward_defined = (
         hasattr(module, "forward") and hasattr(module.forward, "__func__") and
@@ -232,8 +239,16 @@ def __init__(
 
     if execute_sharding_on_init:
       # Execute the parameter sharding immediately and free up the memory
-      xm.mark_step()
       gc.collect()
+      xm.mark_step()
+
+    # TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
+    # This is a temporary workaround before we have a mature solution
+    # to avoid undesired fusion with XLA compiler optimization barrier (see
+    # https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513
+    # for details). This workaround notably increases the execution time and
+    # may trigger more compilation, so we need a permanent solution to #3455.
+    self._mark_step_on_freeing = mark_step_on_freeing
 
   def _get_gradient_predivide_factor(self, world_size: int) -> float:
     factor: int = 1
@@ -864,6 +879,11 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
         # free the original full parameter
         p.data = self._dummy_data_placeholder
         p._has_full_param = False
+        # immediately execute the parameter freeing as a workaround to undesired XLA fusion
+        # see https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513 for details
+        # TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
+        if self._mark_step_on_freeing:
+          xm.mark_step()
 
   def assert_state(self, state: Union[TrainingState,
                                       List[TrainingState]]) -> None:
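
For reference, a minimal usage sketch of the new flag (not part of this commit): it assumes the import path matches the file above, that an XLA device is available, and that the layer sizes and optimizer are placeholders. Since mark_step_on_freeing defaults to False, existing callers are unaffected unless they opt in.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp.xla_fully_sharded_data_parallel import (
    XlaFullyShardedDataParallel as FSDP)

device = xm.xla_device()

# Nested FSDP: wrap an inner submodule, then the outer module. Passing
# mark_step_on_freeing=True makes FSDP call xm.mark_step() each time full
# parameters are freed, working around the fusion issue in
# https://github.com/pytorch/xla/issues/3455 at the cost of extra graph
# executions and possibly more compilations.
inner = FSDP(nn.Linear(128, 128).to(device), mark_step_on_freeing=True)
model = FSDP(
    nn.Sequential(inner, nn.Linear(128, 10).to(device)),
    mark_step_on_freeing=True)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = model(torch.randn(4, 128, device=device)).sum()
loss.backward()
optimizer.step()
xm.mark_step()  # execute the pending XLA graph for this step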
