Commit 3fb7fea

awaelchli and rohitgr7 authored and committed
Standardize model attribute access in training type plugins (#11072)
1 parent dbfdb0a commit 3fb7fea

11 files changed, with 20 additions and 20 deletions

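Context for the change: throughout these plugins the private attribute self._model is replaced with the public self.model accessor. This is behavior-preserving because the base TrainingTypePlugin exposes the wrapped module through a property/setter pair; the setter signature is visible in the training_type_plugin.py hunk below. The snippet that follows is a minimal sketch of that accessor pattern, not code copied from the repository, and the getter body and comments are assumptions.

from typing import Optional

from torch.nn import Module


class TrainingTypePlugin:
    """Minimal sketch: public `model` property backed by the private `_model` attribute."""

    def __init__(self) -> None:
        self._model: Optional[Module] = None  # the (possibly wrapped) module being trained

    @property
    def model(self) -> Optional[Module]:
        # Public read access; callers no longer need to reach into the private attribute.
        return self._model

    @model.setter
    def model(self, new_model: Optional[Module]) -> None:
        # Setter signature matches the one shown in the training_type_plugin.py diff below.
        self._model = new_model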

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 2 additions & 2 deletions

@@ -273,7 +273,7 @@ def _register_ddp_hooks(self) -> None:
             _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device
         ):
             register_ddp_comm_hook(
-                model=self._model,
+                model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
                 ddp_comm_hook=self._ddp_comm_hook,
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
@@ -330,7 +330,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 2 additions & 2 deletions

@@ -203,15 +203,15 @@ def _register_ddp_hooks(self) -> None:
         # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
         if _TORCH_GREATER_EQUAL_1_8 and self.on_gpu and self._is_single_process_single_device:
             register_ddp_comm_hook(
-                model=self._model,
+                model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
                 ddp_comm_hook=self._ddp_comm_hook,
                 ddp_comm_wrapper=self._ddp_comm_wrapper,
             )

     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
-        self._model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self):

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 2 additions & 2 deletions

@@ -398,9 +398,9 @@ def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
-        self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
+        self.model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
-        return self._model, [optimizer]
+        return self.model, [optimizer]

     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 2 additions & 2 deletions

@@ -65,7 +65,7 @@ def world_size(self) -> int:
     def setup(self, trainer: "pl.Trainer") -> None:
         # model needs to be moved to the device before it is wrapped
         self.model_to_device()
-        self._model = self._setup_model(LightningParallelModule(self._model))
+        self.model = self._setup_model(LightningParallelModule(self.model))
         super().setup(trainer)

     def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
@@ -107,7 +107,7 @@ def root_device(self):
         return self.parallel_devices[0]

     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)

     def barrier(self, *args, **kwargs):
         pass

pytorch_lightning/plugins/training_type/parallel.py

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ def on_tpu(self) -> bool:

     @property
     def lightning_module(self) -> Optional["pl.LightningModule"]:
-        return unwrap_lightning_module(self._model) if self._model is not None else None
+        return unwrap_lightning_module(self.model) if self.model is not None else None

     @property
     def global_rank(self) -> int:

pytorch_lightning/plugins/training_type/sharded.py

Lines changed: 2 additions & 2 deletions

@@ -45,7 +45,7 @@ def configure_ddp(self) -> None:
         # For multi-node training, enabling bucketing will improve performance.
         self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0

-        self._model, optimizers = self._setup_model_and_optimizers(
+        self.model, optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
@@ -107,7 +107,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
                 "`DDPShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
+        return unwrap_lightning_module_sharded(self.model) if self.model is not None else None

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass

pytorch_lightning/plugins/training_type/sharded_spawn.py

Lines changed: 2 additions & 2 deletions

@@ -41,7 +41,7 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):

     def configure_ddp(self) -> None:
         trainer = self.lightning_module.trainer
-        self._model, optimizers = self._setup_model_and_optimizers(
+        self.model, optimizers = self._setup_model_and_optimizers(
             model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
@@ -106,7 +106,7 @@ def lightning_module(self) -> Optional["pl.LightningModule"]:
                 "`DDPSpawnShardedPlugin` requires `fairscale` to be installed."
                 " Install it by running `pip install fairscale`."
             )
-        return unwrap_lightning_module_sharded(self._model) if self._model is not None else None
+        return unwrap_lightning_module_sharded(self.model) if self.model is not None else None

     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         pass

pytorch_lightning/plugins/training_type/single_device.py

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ def root_device(self) -> torch.device:
         return self.device

     def model_to_device(self) -> None:
-        self._model.to(self.root_device)
+        self.model.to(self.root_device)

     def setup(self, trainer: "pl.Trainer") -> None:
         self.model_to_device()

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 1 deletion

@@ -132,7 +132,7 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
             set_shared_parameters(self.model.module, shared_params)

         self.setup_optimizers(trainer)
-        self.precision_plugin.connect(self._model, None, None)
+        self.precision_plugin.connect(self.model, None, None)

     def setup(self, trainer: "pl.Trainer") -> None:
         self.start_method = "fork"

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 1 addition & 1 deletion

@@ -293,7 +293,7 @@ def model(self, new_model: Optional[Module]) -> None:
     @property
     def lightning_module(self) -> Optional["pl.LightningModule"]:
         """Returns the pure LightningModule without potential wrappers."""
-        return unwrap_lightning_module(self._model) if self._model is not None else None
+        return unwrap_lightning_module(self.model) if self.model is not None else None

     def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         torch.cuda.empty_cache()
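As the hunk above shows, lightning_module returns the pure LightningModule, while model may still be a wrapper such as LightningDistributedModule or DistributedDataParallel. The sketch below illustrates the unwrapping idea only; it assumes wrappers expose the inner module via a .module attribute, which the real unwrap_lightning_module handles more precisely.

def unwrap_lightning_module(wrapped_model):
    # Hypothetical sketch: peel off wrapper layers (e.g. DistributedDataParallel,
    # LightningDistributedModule) until the bare LightningModule remains.
    model = wrapped_model
    while hasattr(model, "module"):
        model = model.module
    return model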

tests/plugins/test_ddp_plugin_with_comm_hook.py

Lines changed: 4 additions & 4 deletions

@@ -40,7 +40,7 @@ def test_ddp_fp16_compress_comm_hook(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
     expected_comm_hook = default.fp16_compress_hook.__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -63,7 +63,7 @@ def test_ddp_sgd_comm_hook(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
     expected_comm_hook = powerSGD.powerSGD_hook.__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -87,7 +87,7 @@ def test_ddp_fp16_compress_wrap_sgd_comm_hook(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
     expected_comm_hook = default.fp16_compress_wrapper(powerSGD.powerSGD_hook).__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
@@ -132,7 +132,7 @@ def test_ddp_post_local_sgd_comm_hook(tmpdir):
         sync_batchnorm=True,
     )
     trainer.fit(model)
-    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
+    trainer_comm_hook = trainer.accelerator.training_type_plugin.model.get_ddp_logging_data().comm_hook
     expected_comm_hook = post_localSGD.post_localSGD_hook.__qualname__
     assert trainer_comm_hook == expected_comm_hook
     assert trainer.state.finished, f"Training failed with {trainer.state}"
