This repository was archived by the owner on Sep 28, 2022. It is now read-only.

Commit e5308dd

four4fish authored and Raalsky committed
Only import PostLocalSGD related modules when it's needed (Lightning-AI#10359)
* Only import PostLocalSGD related modules when it's needed
* Only import PostLocalSGD related modules when it's needed
* Only import PostLocalSGD related modules when it's needed
1 parent adb2e83 commit e5308dd

File tree

  • pytorch_lightning/plugins/training_type/ddp.py

1 file changed: +12 −14 lines


pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 12 additions & 14 deletions
@@ -63,21 +63,14 @@
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 
-if _TORCH_GREATER_EQUAL_1_10:
-    if not _IS_WINDOWS:
-        from torch.distributed.optim import DistributedOptimizer
-    from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
-
 if _FAIRSCALE_AVAILABLE:
     from fairscale.optim import OSS
 if _HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig
     from hydra.utils import get_original_cwd, to_absolute_path
 if _TORCH_GREATER_EQUAL_1_8:
     from pytorch_lightning.utilities.distributed import register_ddp_comm_hook
-if _TORCH_GREATER_EQUAL_1_10:
-    import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
-    import torch.distributed.algorithms.model_averaging.averagers as averagers
+
 
 log = logging.getLogger(__name__)
 
@@ -312,19 +305,24 @@ def _register_ddp_hooks(self) -> None:
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )
 
-        if (
-            _TORCH_GREATER_EQUAL_1_10
-            and isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState)
-            and self.lightning_module.trainer.state.fn == TrainerFn.FITTING
-        ):
-            self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
+        if _TORCH_GREATER_EQUAL_1_10 and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
+            import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD
+
+            if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState):
+                self._reinit_optimizers_with_post_localSGD(self._ddp_comm_state.start_localSGD_iter)
 
     def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
         optimizers = self.lightning_module.trainer.optimizers
         if self._model_averaging_period is None:
             raise ValueError(
                 "Post-localSGD algorithm is used, but model averaging period is not provided to DDP plugin."
             )
+        if _TORCH_GREATER_EQUAL_1_10:
+            if not _IS_WINDOWS:
+                from torch.distributed.optim import DistributedOptimizer
+            import torch.distributed.algorithms.model_averaging.averagers as averagers
+            from torch.distributed.optim import PostLocalSGDOptimizer, ZeroRedundancyOptimizer
+
         averager = averagers.PeriodicModelAverager(period=self._model_averaging_period, warmup_steps=warmup_steps)
         for x, optimizer in enumerate(optimizers):
             if isinstance(optimizer, LightningOptimizer):
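
The change is the lazy-import pattern: the post-localSGD imports move from module scope in ddp.py into the two code paths that actually use them, so importing the DDP plugin no longer loads torch.distributed.algorithms.* (or trips over symbols that are guarded on some platforms, such as DistributedOptimizer behind the not _IS_WINDOWS check above). A minimal sketch of the same pattern, with an illustrative function name rather than the plugin's real internals:

def build_post_local_sgd_averager(period: int, warmup_steps: int):
    """Create a PeriodicModelAverager, importing post-localSGD helpers lazily."""
    # Deferred import: torch.distributed.algorithms.* is loaded only when this
    # function runs, so merely importing the defining module stays cheap and
    # does not break on torch versions where the submodule is unavailable.
    import torch.distributed.algorithms.model_averaging.averagers as averagers

    return averagers.PeriodicModelAverager(period=period, warmup_steps=warmup_steps)

The trade-off is that an ImportError would now surface at the first call into the post-localSGD path instead of at import time; in the diff both call sites are already guarded by _TORCH_GREATER_EQUAL_1_10, so behaviour on supported setups is unchanged.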
