
Commit a002f87

four4fish and tchaton authored
[2/n] Directly call TrainingTypePlugin APIs instead of going through the Accelerator (#9901)
Co-authored-by: tchaton <[email protected]>
1 parent 1569869 commit a002f87

File tree: 13 files changed (+317 -44 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -291,6 +291,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Updated error message for interactive incompatible plugins ([#9896](https://github.com/PyTorchLightning/pytorch-lightning/pull/9896))
 
 
+- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))
+
+
 ### Deprecated
 
 - Deprecated trainer argument `terminate_on_nan` in favour of `detect_anomaly`([#9175](https://github.com/PyTorchLightning/pytorch-lightning/pull/9175))
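
The pattern throughout this commit is the same: call sites that used to go through `trainer.accelerator` now call the `TrainingTypePlugin` API directly via `trainer.training_type_plugin`. A minimal before/after sketch, using method names that appear in the diffs below (the wrapper function and checkpoint dict are illustrative, not part of the commit):

    # Illustrative sketch only; shows the access-path change, not Lightning internals.
    def save_model_weights(trainer, filepath: str) -> None:
        # Before: routed through the Accelerator (now a deprecated pass-through)
        #     state_dict = trainer.accelerator.lightning_module_state_dict()
        # After: call the TrainingTypePlugin API directly
        state_dict = trainer.training_type_plugin.lightning_module_state_dict()
        trainer.training_type_plugin.save_checkpoint({"state_dict": state_dict}, filepath)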

pytorch_lightning/accelerators/accelerator.py

Lines changed: 222 additions & 15 deletions
Large diffs are not rendered by default.
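
The full diff for this file is not rendered, but the deprecation test added in tests/deprecated_api/test_remove_1-6.py (below) indicates that these Accelerator members now forward to the training type plugin while emitting a deprecation warning. A hedged sketch of that pass-through pattern, assuming the `rank_zero_deprecation` helper; the exact wording and the set of affected members may differ from the real file:

    # Sketch only: one representative pass-through; accelerator.py deprecates many members this way.
    from pytorch_lightning.utilities import rank_zero_deprecation


    class Accelerator:  # excerpt-style illustration, not the full class
        def lightning_module_state_dict(self):
            rank_zero_deprecation(
                "`Accelerator.lightning_module_state_dict` is deprecated in v1.5 and will be removed in v1.6."
                " Call `training_type_plugin.lightning_module_state_dict` directly."
            )
            return self.training_type_plugin.lightning_module_state_dict()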

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
         void(*args, **kwargs)
 
         dataloader_idx: int = self.current_dataloader_idx
-        dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
+        dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
         dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
         dl_max_batches = self._max_batches[dataloader_idx]

pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def on_run_start(self) -> None:
     def advance(self, *args: Any, **kwargs: Any) -> None:
         """Predicts one entire dataloader."""
         void(*args, **kwargs)
-        dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
+        dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
         dataloader_iter = enumerate(dataloader)
         dl_max_batches = self.max_batches[self.current_dataloader_idx]

pytorch_lightning/loops/fit_loop.py

Lines changed: 2 additions & 2 deletions
@@ -204,7 +204,7 @@ def on_advance_start(self) -> None:
 
     def advance(self) -> None:
         """Runs one whole epoch."""
-        dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
+        dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
         data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)
 
         with self.trainer.profiler.profile("run_training_epoch"):

@@ -234,7 +234,7 @@ def on_run_end(self) -> None:
         self.trainer.call_hook("on_train_end")
 
         # give accelerators a chance to finish
-        self.trainer.accelerator.on_train_end()
+        self.trainer.training_type_plugin.on_train_end()
 
     def teardown(self) -> None:
         self.epoch_loop.teardown()

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
         lightning_module._current_fx_name = "training_step"
         with self.trainer.profiler.profile("training_step"):
             training_step_output = self.trainer.accelerator.training_step(step_kwargs)
-            self.trainer.accelerator.post_training_step()
+            self.trainer.training_type_plugin.post_training_step()
 
         del step_kwargs

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 1 addition & 1 deletion
@@ -450,7 +450,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
         lightning_module._current_fx_name = "training_step"
         with self.trainer.profiler.profile("training_step"):
             training_step_output = self.trainer.accelerator.training_step(step_kwargs)
-            self.trainer.accelerator.post_training_step()
+            self.trainer.training_type_plugin.post_training_step()
 
         del step_kwargs

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 2 additions & 2 deletions
@@ -465,7 +465,7 @@ def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
             weights_only: saving model weights only
         """
         _checkpoint = self.dump_checkpoint(weights_only)
-        self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)
+        self.trainer.training_type_plugin.save_checkpoint(_checkpoint, filepath)
 
     def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
         metrics = (

@@ -478,7 +478,7 @@ def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
             metric.persistent(True)
             metric.sync()
 
-        state_dict = self.trainer.accelerator.lightning_module_state_dict()
+        state_dict = self.trainer.training_type_plugin.lightning_module_state_dict()
 
         for metric in metrics:
             # sync can be a no-op (e.g. on cpu) so `unsync` would raise a user error exception if we don't check

pytorch_lightning/trainer/trainer.py

Lines changed: 19 additions & 15 deletions
@@ -1020,13 +1020,13 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self.callback_connector.attach_model_logging_functions(model)
 
         # attach model to the training type plugin
-        self.accelerator.connect(model)
+        self.training_type_plugin.connect(model)
 
         # hook
         self.data_connector.prepare_data()
         self.callback_connector._attach_model_callbacks()
 
-        if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch:
+        if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
             self._load_checkpoint_weights()
 
         # ----------------------------

@@ -1037,7 +1037,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
 
         # check if we should delay restoring checkpoint till later
-        if not self.accelerator.restore_checkpoint_after_pre_dispatch:
+        if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
             self.checkpoint_connector.resume_start()
             self._restore_modules_and_callbacks()

@@ -1055,9 +1055,9 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         |                                               ||
         {self._dispatch}                                ||
         |                                               ||  LIGHTNING
-        {self.accelerator.start_training}               ||
-        or {self.accelerator.start_evaluating}          ||
-        or {self.accelerator.start_predicting}          ||  FLOW
+        {self.training_type_plugin.start_training}      ||
+        or {self.training_type_plugin.start_evaluating} ||
+        or {self.training_type_plugin.start_predicting} ||  FLOW
         |                                               ||
         {self.run_stage}                                ||
         |                                               ||  DIRECTION

@@ -1087,7 +1087,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         # plugin will setup fitting (e.g. ddp will launch child processes)
         self._pre_dispatch()
 
-        if self.accelerator.restore_checkpoint_after_pre_dispatch:
+        if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
             if self._ckpt_path:
                 self._load_checkpoint_weights()

@@ -1119,7 +1119,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self.state.status = TrainerStatus.FINISHED
         self.state.stage = None
 
-        return self.accelerator.results
+        return self.training_type_plugin.results
 
     def _pre_dispatch(self):
         self.accelerator.pre_dispatch(self)

@@ -1173,11 +1173,11 @@ def _post_dispatch(self):
 
     def _dispatch(self):
         if self.evaluating:
-            self.accelerator.start_evaluating(self)
+            self.training_type_plugin.start_evaluating(self)
         elif self.predicting:
-            self.accelerator.start_predicting(self)
+            self.training_type_plugin.start_predicting(self)
         else:
-            self.accelerator.start_training(self)
+            self.training_type_plugin.start_training(self)
 
     def run_stage(self):
         self.accelerator.dispatch(self)

@@ -1509,22 +1509,26 @@ def precision_plugin(self) -> PrecisionPlugin:
 
     @property
     def global_rank(self) -> int:
-        return self.accelerator.training_type_plugin.global_rank
+        return self.training_type_plugin.global_rank
 
     @property
     def local_rank(self) -> int:
         # some training types define a local rank
-        return getattr(self.accelerator.training_type_plugin, "local_rank", 0)
+        return getattr(self.training_type_plugin, "local_rank", 0)
 
     @property
     def node_rank(self) -> int:
         # some training types define a local rank
-        return getattr(self.accelerator.training_type_plugin, "node_rank", 0)
+        return getattr(self.training_type_plugin, "node_rank", 0)
 
     @property
     def world_size(self) -> int:
         # some training types define a world size
-        return getattr(self.accelerator.training_type_plugin, "world_size", 1)
+        return getattr(self.training_type_plugin, "world_size", 1)
+
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.training_type_plugin.should_rank_save_checkpoint
 
     @property
     def _distrib_type(self) -> DistributedType:
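
The last hunk above also adds a `should_rank_save_checkpoint` property alongside the flattened rank accessors. A small usage sketch of the resulting access paths (the bare `Trainer()` construction is illustrative; on a default single-process setup both assertions are trivially true):

    from pytorch_lightning import Trainer

    trainer = Trainer()  # illustrative; any accelerator/plugin combination applies

    # Previously read through the accelerator:
    #     trainer.accelerator.training_type_plugin.global_rank
    # Now the Trainer properties delegate to the training type plugin directly:
    assert trainer.global_rank == trainer.training_type_plugin.global_rank
    assert trainer.world_size == getattr(trainer.training_type_plugin, "world_size", 1)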

tests/accelerators/test_cpu.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ def test_restore_checkpoint_after_pre_dispatch_default():
     """Assert default for restore_checkpoint_after_pre_dispatch is False."""
     plugin = SingleDevicePlugin(torch.device("cpu"))
     accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
-    assert not accelerator.restore_checkpoint_after_pre_dispatch
+    assert not accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch
     assert not plugin.restore_checkpoint_after_pre_dispatch
 
 

@@ -77,7 +77,7 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
     plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
     accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())
 
-    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
+    assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
     assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
 
     trainer = Trainer(

tests/deprecated_api/test_remove_1-6.py

Lines changed: 60 additions & 1 deletion
@@ -330,7 +330,7 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
         from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin  # noqa: F401
 
 
-def test_v1_6_0_deprecated_accelerator_collective():
+def test_v1_6_0_deprecated_accelerator_pass_through_functions():
     from pytorch_lightning.plugins.precision import PrecisionPlugin
     from pytorch_lightning.plugins.training_type import SingleDevicePlugin
 

@@ -347,3 +347,62 @@ def test_v1_6_0_deprecated_accelerator_collective():
     with pytest.deprecated_call(match="will be removed in v1.6"):
         tensor = torch.rand(2, 2, requires_grad=True)
         accelerator.all_gather(tensor)
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        model = BoringModel()
+        accelerator.connect(model)
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.post_training_step()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        tensor = torch.rand(2, 2, requires_grad=True)
+        accelerator.training_step_end(tensor)
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        tensor = torch.rand(2, 2, requires_grad=True)
+        accelerator.test_step_end(tensor)
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        tensor = torch.rand(2, 2, requires_grad=True)
+        accelerator.validation_step_end(tensor)
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.lightning_module_state_dict()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        dl = model.train_dataloader()
+        accelerator.process_dataloader(dl)
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.results
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.setup_optimizers_in_pre_dispatch
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.restore_checkpoint_after_pre_dispatch
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_validation_start()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_test_start()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_predict_start()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_validation_end()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_test_end()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_predict_end()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_train_end()
+
+    with pytest.deprecated_call(match="will be removed in v1.6"):
+        accelerator.on_train_batch_start(batch=None, batch_idx=0)
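
In practice, user code that triggers one of these warnings migrates by calling the same member on the training type plugin. A minimal hedged example, mirroring the `SingleDevicePlugin`/`CPUAccelerator` wiring used in the test above (everything else is illustrative):

    import torch

    from pytorch_lightning.accelerators import CPUAccelerator
    from pytorch_lightning.plugins.precision import PrecisionPlugin
    from pytorch_lightning.plugins.training_type import SingleDevicePlugin

    plugin = SingleDevicePlugin(torch.device("cpu"))
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

    # Deprecated pass-through (warns, removed in v1.6):
    #     accelerator.restore_checkpoint_after_pre_dispatch
    # Preferred direct access on the plugin:
    assert accelerator.training_type_plugin.restore_checkpoint_after_pre_dispatch is False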

tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ def test_fully_sharded_plugin_checkpoint_multi_gpus(tmpdir):
 
 def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel):
     # Use FullySharded to get the state dict for the sake of comparison
-    model_state_dict = trainer.accelerator.lightning_module_state_dict()
+    model_state_dict = trainer.training_type_plugin.lightning_module_state_dict()
 
     if trainer.is_global_zero:
         saved_model = cls.load_from_checkpoint(ckpt_path)

tests/plugins/test_ddp_plugin.py

Lines changed: 2 additions & 2 deletions
@@ -108,7 +108,7 @@ def test_ddp_configure_ddp():
     )
     # test wrap the model if fitting
     trainer.state.fn = TrainerFn.FITTING
-    trainer.accelerator.connect(model)
+    trainer.training_type_plugin.connect(model)
     trainer.accelerator.setup_environment()
     trainer.accelerator.setup(trainer)
     trainer.lightning_module.trainer = trainer

@@ -122,7 +122,7 @@ def test_ddp_configure_ddp():
         plugins=[ddp_plugin],
     )
     # test do not wrap the model if trainerFN is not fitting
-    trainer.accelerator.connect(model)
+    trainer.training_type_plugin.connect(model)
     trainer.accelerator.setup_environment()
     trainer.accelerator.setup(trainer)
     trainer.lightning_module.trainer = trainer
