13 | 13 | # limitations under the License.
14 | 14 |
15 | 15 | import itertools
   | 16 | +import logging
16 | 17 | import threading
17 | 18 | from collections.abc import Iterable, Mapping
18 | 19 | from itertools import chain
19 | 20 |
20 | 21 | import torch
   | 22 | +import torch.distributed as dist
21 | 23 | from torch.cuda._utils import _get_device_index
22 | 24 | from torch.nn import DataParallel
23 | 25 | from torch.nn.parallel import DistributedDataParallel
24 | 26 | from torch.nn.parallel._functions import Gather
25 | 27 |
26 | 28 | from pytorch_lightning.core.step_result import Result
   | 29 | +from pytorch_lightning.utilities import DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE
27 | 30 | from pytorch_lightning.utilities.warning_utils import WarningCache
28 | 31 |
29 | 32 |
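Note: the `DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE` flag imported above is defined elsewhere in `pytorch_lightning.utilities` and is not part of this hunk. A minimal sketch of how such a guard can be expressed, assuming it is simply a version gate on the PyTorch release (1.7.0) that introduced the `join()` and `_rebuild_buckets()` APIs; the real definition may differ:

```python
# Hypothetical definition, for illustration only; the actual constant lives in
# pytorch_lightning/utilities and may be written differently.
from distutils.version import LooseVersion

import torch

DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE = LooseVersion(torch.__version__) >= LooseVersion("1.7.0")
```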
@@ -161,7 +164,30 @@ def parallel_apply(self, replicas, inputs, kwargs):
161 | 164 |         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
162 | 165 |
163 | 166 |     def forward(self, *inputs, **kwargs):  # pragma: no-cover
164 |     | -        self._sync_params()
    | 167 | +        # TODO: Update uneven inputs code path when PyTorch 1.8 is released.
    | 168 | +        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
    | 169 | +            ones = torch.ones(
    | 170 | +                1, device=self.device
    | 171 | +            )
    | 172 | +            work = dist.all_reduce(ones, group=self.process_group, async_op=True)
    | 173 | +            self.reducer._set_forward_pass_work_handle(
    | 174 | +                work, self.ddp_join_divide_by_initial_world_size
    | 175 | +            )
    | 176 | +
    | 177 | +        # Calling _rebuild_buckets before the forward computation may
    | 178 | +        # allocate new buckets before deallocating the old ones inside
    | 179 | +        # _rebuild_buckets. To save peak memory usage, call _rebuild_buckets
    | 180 | +        # before peak memory usage increases during the forward
    | 181 | +        # computation.
    | 182 | +        # This should be called only once during the whole training period.
    | 183 | +        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.reducer._rebuild_buckets():
    | 184 | +            logging.info("Reducer buckets have been rebuilt in this iteration.")
    | 185 | +
    | 186 | +        if self.require_forward_param_sync:
    | 187 | +            self._sync_params()
    | 188 | +        if DDP_JOIN_AND_REBUILD_BUCKETS_AVAILABLE and self.ddp_join_enabled:
    | 189 | +            # Notify joined ranks whether they should sync in the backward pass or not.
    | 190 | +            self._check_global_requires_backward_grad_sync(is_joined_rank=False)
165 | 191 |         self.reducer_reset_hooks()
166 | 192 |         fx_called: str = ''
167 | 193 |
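For context on the new code path: the `all_reduce` of a ones tensor in `forward()` is the per-iteration "this rank is still training" handshake used by PyTorch's uneven-inputs support, so ranks that have already exhausted their data and are waiting inside `DistributedDataParallel.join()` can answer it instead of deadlocking the collective. A minimal usage sketch (not part of this PR; the process-group setup, model, and batch counts are illustrative only), assuming PyTorch >= 1.7:

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def train(rank: int, world_size: int) -> None:
    # Assumes the launcher has exported MASTER_ADDR / MASTER_PORT for rendezvous.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DistributedDataParallel(torch.nn.Linear(4, 1))

    # Ranks deliberately see different numbers of batches. Inside join(),
    # ranks that run out of data keep answering the forward-pass all_reduce
    # shown above, so the remaining active ranks never block on them.
    num_batches = 5 + rank
    with model.join():
        for _ in range(num_batches):
            loss = model(torch.randn(2, 4)).sum()
            loss.backward()

    dist.destroy_process_group()
```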