 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
-if _TORCH_GREATER_EQUAL_1_10:
-    if not _IS_WINDOWS:
-        from torch.distributed.optim import DistributedOptimizer
-    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
-
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
 if _HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
     from hydra.utils import get_original_cwd, to_absolute_path
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
-if _TORCH_GREATER_EQUAL_1_10:
-    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-    import torch.distributed.algorithms.model_averaging.averagers as averagers
+
 
 log = logging.getLogger(__name__)
 
@@ -312,19 +305,24 @@ def _register_ddp_hooks(self) -> None:
             ddp_comm_wrapper=self._ddp_comm_wrapper,
         )
 
-        if (
-            _TORCH_GREATER_EQUAL_1_10
-            and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
-            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-        ):
-            self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+        if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+            if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
 
     def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         optimizers = self.lightning_module.trainer.optimizers
         if self._model_averaging_period is None:
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
             )
+        if _TORCH_GREATER_EQUAL_1_10:
+            if not _IS_WINDOWS:
+                from torch.distributed.optim import DistributedOptimizer
+            import torch.distributed.algorithms.model_averaging.averagers as averagers
+            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
         averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
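
For readers skimming the change, here is a minimal, self-contained sketch of the pattern the diff applies: the optional torch.distributed imports move from module level into the code path that actually needs them, so importing the plugin module itself cannot fail on builds or platforms where those submodules are missing. The helper name and the locally defined flags below are illustrative stand-ins only, not the actual Lightning utilities.

import platform

# Illustrative stand-ins for the availability flags referenced in the diff
# (the real ones are imported from Lightning's utility modules).
_IS_WINDOWS = platform.system() == "Windows"
try:
    import torch

    _TORCH_GREATER_EQUAL_1_10 = tuple(int(v) for v in torch.__version__.split(".")[:2]) >= (1, 10)
except ImportError:
    _TORCH_GREATER_EQUAL_1_10 = False


def _post_local_sgd_imports():
    # Hypothetical helper: defer the optional imports until post-localSGD is
    # actually requested, mirroring what the diff does inside
    # _register_ddp_hooks and _reinit_optimizers_with_post_localSGD.
    if not _TORCH_GREATER_EQUAL_1_10:
        raise RuntimeError("post-localSGD support requires torch>=1.10")
    if not _IS_WINDOWS:
        # DistributedOptimizer is unavailable on Windows, so it gets its own guard.
        from torch.distributed.optim import DistributedOptimizer  # noqa: F401
    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
    import torch.distributed.algorithms.model_averaging.averagers as averagers
    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer

    return post_localSGD, averagers, PostLocalSGDOptimizer, ZeroRedundancyOptimizer

Keeping the guards inside the functions also means the post-localSGD symbols never leak into module scope, which is why the second hunk re-imports post_localSGD locally inside _register_ddp_hooks.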