Commit 1664252

add optimization_barrier_ (#3493) to avoid fusion of full parameter reconstruction with subsequent freeing
1 parent 2c684a8 commit 1664252

File changed: torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
Lines changed: 58 additions & 6 deletions
@@ -121,6 +121,10 @@ class XlaFullyShardedDataParallel(nn.Module):
     execute_sharding_on_init (bool, Optional):
       if ``True``, immediately execute the parameter sharding via
       `xm.mark_step` to free up the memory of the full parameters.
+    optimization_barrier_on_output (bool, Optional):
+      if ``True``, apply `xm.optimization_barrier_` on the FSDP module's
+      outputs. This avoids fusion (by the XLA compiler) with subsequent
+      computation after the FSDP module and could save additional memory.
     use_all_gather_via_all_reduce (bool, Optional):
       if ``True``, use PyTorch XLA 1.10's all_gather implementation,
       which performs all_gather via padding and all_reduce and avoids
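
For orientation, a minimal usage sketch of the option added in this commit. The import path and the toy model/data are assumptions for illustration, not part of the diff:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

device = xm.xla_device()
# Wrap a toy module; with this commit, optimization_barrier_on_output
# defaults to True, so the barrier on the module outputs is applied
# automatically and the resharded parameters are not kept alive by fusion.
model = FSDP(
    nn.Linear(1024, 1024).to(device),
    reshard_after_forward=True,
    optimization_barrier_on_output=True)
loss = model(torch.randn(8, 1024, device=device)).sum()
loss.backward()
xm.mark_step()  # execute the lazily traced graph on the XLA device
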
@@ -139,7 +143,8 @@ def __init__(
       reshard_after_forward: bool = True,
       flatten_parameters: bool = True,
       execute_sharding_on_init: bool = True,
-      use_all_gather_via_all_reduce: bool = True,
+      optimization_barrier_on_output: bool = True,
+      use_all_gather_via_all_reduce: bool = False,
       mark_step_on_freeing: bool = False,
   ):
     if isinstance(module, XlaFullyShardedDataParallel):
@@ -165,6 +170,7 @@ def __init__(
     self.world_size = xm.xrt_world_size()
     self.reshard_after_forward = self._orig_reshard_after_forward = reshard_after_forward
     self.flatten_parameters = flatten_parameters
+    self.optimization_barrier_on_output = optimization_barrier_on_output
     if use_all_gather_via_all_reduce:
       self.all_gather_op = all_gather_via_all_reduce
     else:
@@ -591,6 +597,16 @@ def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
     if self.reshard_after_forward:
       self._free_full_params()

+    if self.optimization_barrier_on_output:
+      # Apply the XLA compiler optimization barrier to avoid fusion with
+      # subsequent computation. This potentially saves additional memory
+      # since fusion sometimes results in higher memory consumption.
+      def _apply_barrier(t):
+        xm.optimization_barrier_([t])
+        return t
+
+      outputs = apply_to_tensors(_apply_barrier, outputs)
+
     # Register pre-backward hooks to all-gather the params for the backward
     # pass (if output's grad was needed). This won't register anything if
     # we are in eval mode.
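
A self-contained sketch of the same barrier-on-outputs pattern, with a hypothetical _barrier_outputs tree walk standing in for the module's apply_to_tensors helper; the nested outputs below are only an example:

import torch
import torch_xla.core.xla_model as xm

def _barrier_outputs(outputs):
  # Hypothetical helper: walk (possibly nested) forward outputs and apply the
  # in-place XLA optimization barrier to every tensor leaf, keeping structure.
  if isinstance(outputs, torch.Tensor):
    xm.optimization_barrier_([outputs])
    return outputs
  if isinstance(outputs, (list, tuple)):
    return type(outputs)(_barrier_outputs(o) for o in outputs)
  if isinstance(outputs, dict):
    return {k: _barrier_outputs(v) for k, v in outputs.items()}
  return outputs  # non-tensor leaves pass through unchanged

device = xm.xla_device()
outputs = (torch.randn(4, 8, device=device),
           {"logits": torch.randn(4, 10, device=device)})
outputs = _barrier_outputs(outputs)  # downstream ops cannot be fused across this point
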
@@ -865,12 +881,32 @@ def _rebuild_full_params(self) -> None:
     """
     if self.has_full_params:
       return
+    p_list, p_shard_list, p_data_list, p_shared_data_list = [], [], [], []
     for p, p_shard in zip(self.full_params, self.sharded_params):
       if not p._has_full_param:
         # gather full parameter from shards
-        p_padded = self.all_gather_op(p_shard).flatten().detach()
-        p.data = p_padded[:p_shard._orig_size.numel()].view(p_shard._orig_size)
-        p._has_full_param = True
+        p_padded = self.all_gather_op(p_shard.data).flatten().detach()
+        p_data = p_padded[:p_shard._orig_size.numel()].view(p_shard._orig_size)
+        p_list.append(p)
+        p_shard_list.append(p_shard)
+        p_data_list.append(p_data)
+        p_shared_data_list.append(p_shard.data)
+
+    if len(p_data_list) + len(p_shared_data_list) > 0:
+      # Apply the XLA compiler optimization barrier to avoid fusion of the
+      # full parameter reconstruction with other computation.
+      # Otherwise, the XLA compiler might fuse `_rebuild_full_params` in the
+      # forward pass with any `_rebuild_full_params` in the backward pass
+      # through common subexpression elimination (CSE) and keep the full
+      # parameters (not freeing them and rebuilding them later, essentially
+      # changing `reshard_after_forward` to ``False`` and using more memory).
+      xm.optimization_barrier_(p_data_list + p_shared_data_list)
+    for p, p_shard, p_data, p_shard_data in zip(p_list, p_shard_list,
+                                                p_data_list,
+                                                p_shared_data_list):
+      p.data = p_data
+      p_shard.data = p_shard_data
+      p._has_full_param = True
     self.has_full_params = True

   @torch.no_grad()
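
The rebuilt full tensors and the shard data are passed to a single xm.optimization_barrier_ call before any parameter is repointed at them. A simplified standalone sketch of that collect/barrier/assign pattern, with random tensors standing in for the all_gather results:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
params = [torch.nn.Parameter(torch.empty(0, device=device)) for _ in range(3)]
gathered = [torch.randn(16, 16, device=device) for _ in range(3)]  # stand-ins for all_gather output

# 1) Collect every tensor involved in the reconstruction.
barrier_inputs = list(gathered)
# 2) One barrier over the whole group isolates the rebuild from the rest of
#    the graph, so CSE cannot merge it with a matching rebuild elsewhere.
xm.optimization_barrier_(barrier_inputs)
# 3) Only then point the parameters at the gathered storage.
for p, g in zip(params, gathered):
  p.data = g
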
@@ -879,11 +915,27 @@ def _free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
     if params is None:
       params = self.full_params
     self.has_full_params = False
+    p_list, p_data_list = [], []
     for p in params:
       if p._has_full_param:
         # free the original full parameter
-        p.data = self._dummy_data_placeholder
-        p._has_full_param = False
+        p_data = self._dummy_data_placeholder
+        p_list.append(p)
+        p_data_list.append(p_data)
+
+    if len(p_data_list) > 0:
+      # Apply the XLA compiler optimization barrier to avoid fusion of the
+      # full parameter freeing with other computation.
+      # Otherwise, the XLA compiler might fuse `_free_full_params` in the
+      # forward pass with any `_free_full_params` in the backward pass
+      # through common subexpression elimination (CSE) and keep the full
+      # parameters (not freeing them and rebuilding them later, essentially
+      # changing `reshard_after_forward` to ``False`` and using more memory).
+      xm.optimization_barrier_(p_data_list)
+    for p, p_data in zip(p_list, p_data_list):
+      p.data = p_data
+      p._has_full_param = False
+
     # immediately execute the parameter freeing as a workaround to undesired XLA fusion
     # see https://github.com/pytorch/xla/issues/3455#issuecomment-1085448513 for details
     # TODO (ronghanghu): remove when https://github.com/pytorch/xla/issues/3455 is resolved
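
The trailing comment refers to the existing mark_step_on_freeing fallback, which forces immediate execution right after freeing instead of relying on the barrier. A hedged sketch of that fallback as a standalone helper (the function name and arguments are illustrative, not part of the diff):

import torch_xla.core.xla_model as xm

def free_full_params_with_mark_step(params, placeholder, mark_step_on_freeing=True):
  # Point each full parameter at a tiny placeholder so its storage can be
  # released, mirroring _free_full_params above.
  for p in params:
    p.data = placeholder
  if mark_step_on_freeing:
    # Older workaround: cut the lazy trace here and execute the pending graph
    # immediately, so the freeing cannot be fused with later computation.
    xm.mark_step()
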
