Commit 7b38ba7

rohan-varma authored and SeanNaren committed
[Feat] Added uneven input support/sync with upstream DDP
1 parent c1920fe commit 7b38ba7

File tree

2 files changed: 29 additions & 1 deletion
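
The "uneven input support" in the commit title mirrors the upstream DistributedDataParallel forward path for PyTorch >= 1.7, where the join() context manager lets ranks that run out of batches early keep shadowing collectives instead of hanging the other ranks. For reference, a minimal sketch (not part of this commit) of how uneven-input training is typically driven with DDP's join(); it assumes torch.distributed has already been initialized and that each rank may see a different number of batches:

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

def train_uneven(rank, model, loader, epochs=1):
    # Wrap the model once; device_ids pins this replica to its local GPU.
    ddp_model = DDP(model.to(rank), device_ids=[rank])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    # join() keeps collectives matched: ranks whose loader is exhausted
    # shadow the all-reduces of ranks that are still training.
    with ddp_model.join():
        for _ in range(epochs):
            for inputs, targets in loader:
                optimizer.zero_grad()
                loss = F.mse_loss(ddp_model(inputs.to(rank)), targets.to(rank))
                loss.backward()
                optimizer.step()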

pytorch_lightning/overrides/data_parallel.py

Lines changed: 27 additions & 1 deletion
@@ -13,17 +13,20 @@
 # limitations under the License.
 
 import itertools
+import logging
 import threading
 from collections.abc import Iterable, Mapping
 from itertools import chain
 
 import torch
+import torch.distributed as dist
 from torch.cuda._utils import _get_device_index
 from torch.nn import DataParallel
 from torch.nn.parallel import DistributedDataParallel
 from torch.nn.parallel._functions import Gather
 
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.utilities import DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE
 from pytorch_lightning.utilities.warning_utils import WarningCache
 
 
@@ -161,7 +164,30 @@ def parallel_apply(self, replicas, inputs, kwargs):
         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
 
     def forward(self, *inputs, **kwargs):  # pragma: no-cover
-        self._sync_params()
+        # TODO: Update uneven inputs code path when PyTorch 1.8 is released.
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
+            ones = torch.ones(
+                1, device=self.device
+            )
+            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
+            self.reducer._set_forward_pass_work_handle(
+                work, self.ddp_join_divide_by_initial_world_size
+            )
+
+        # Calling _rebuild_buckets before forward computation
+        # may allocate new buckets before deallocating old buckets
+        # inside _rebuild_buckets. To save peak memory usage,
+        # call _rebuild_buckets before the peak memory usage increases
+        # during forward computation.
+        # This should be called only once during the whole training period.
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.reducer._rebuild_buckets():
+            logging.info("Reducer buckets have been rebuilt in this iteration.")
+
+        if self.require_forward_param_sync:
+            self._sync_params()
+        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
+            # Notify joined ranks whether they should sync in backwards pass or not.
+            self._check_global_requires_backward_grad_sync(is_joined_rank=False)
         self.reducer_reset_hooks()
         fx_called: str = ''
 
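
The ones/all_reduce handshake at the top of forward() is how ranks that are still training signal, each iteration, that another forward pass is in flight; ranks that have already joined answer the same collective with zeros, so every rank agrees on how many peers still contribute gradients. A rough standalone illustration of that signalling pattern (an assumption for illustration, not code from this commit), assuming a default process group has been initialized:

import torch
import torch.distributed as dist

def count_active_ranks(is_active: bool, device: torch.device) -> int:
    # Active ranks contribute 1, joined ranks contribute 0; the SUM
    # all-reduce therefore yields the number of ranks still training.
    flag = torch.ones(1, device=device) if is_active else torch.zeros(1, device=device)
    dist.all_reduce(flag)  # default reduce op is SUM over the default group
    return int(flag.item())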

pytorch_lightning/utilities/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -62,6 +62,8 @@ def _module_available(module_path: str) -> bool:
 _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
 
+DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")
+
 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
 FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps
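
The new DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE flag simply gates the join / _rebuild_buckets code path on the installed torch version. A minimal sketch of the same version-gating pattern (the print messages are illustrative, not part of the commit): compute the flag once at import time, then branch on it wherever the 1.7-only DDP APIs are touched.

from distutils.version import LooseVersion
import torch

DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")

if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE:
    # torch >= 1.7.0: the uneven-input (join) and bucket-rebuild paths can run.
    print("DDP join / _rebuild_buckets support enabled")
else:
    # Older torch: fall back to the unconditional _sync_params() path.
    print("DDP join / _rebuild_buckets support unavailable")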
