Skip to content

Commit e3c9f22

Browse files
committed
fix format
1 parent b30d9e5 commit e3c9f22

File tree

5 files changed

+49
-53
lines changed

5 files changed

+49
-53
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: Trai
6161
def connect(self, model: "pl.LightningModule") -> None:
6262
"""Transfers ownership of the model to this plugin.
6363
64+
See deprecation warning below.
65+
6466
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
6567
`training_type_plugin.connect` directly.
6668
"""
@@ -73,17 +75,9 @@ def connect(self, model: "pl.LightningModule") -> None:
7375
def setup_environment(self) -> None:
7476
"""Setup any processes or distributed connections.
7577
76-
.. deprecated:: v1.5
77-
This method is deprecated in v1.5 and will be removed in v1.6.
78-
Please call `training_type_plugin.setup_environment` directly.
79-
8078
This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator
8179
environment before setup is complete.
8280
"""
83-
rank_zero_deprecation(
84-
"`Accelerator.setup_environment` is deprecated in v1.5 and will be removed in v1.6. "
85-
"`setup_environment` logic is implemented directly in the `TrainingTypePlugin` implementations."
86-
)
8781
self.training_type_plugin.setup_environment()
8882

8983
def setup(self, trainer: "pl.Trainer") -> None:
@@ -189,16 +183,8 @@ def root_device(self) -> torch.device:
189183
def teardown(self) -> None:
190184
"""This method is called to teardown the training process.
191185
192-
.. deprecated:: v1.5
193-
This method is deprecated in v1.5 and will be removed in v1.6.
194-
Please call `training_type_plugin.teardown` directly.
195-
196186
It is the right place to release memory and free other resources.
197187
"""
198-
rank_zero_deprecation(
199-
"`Accelerator.teardown` is deprecated in v1.5 and will be removed in v1.6. "
200-
"`teardown` logic is implemented directly in the `TrainingTypePlugin` implementations."
201-
)
202188
self.training_type_plugin.teardown()
203189

204190
def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
@@ -588,22 +574,17 @@ def get_device_stats(self, device: Union[str, torch.device]) -> Dict[str, Any]:
588574
raise NotImplementedError
589575

590576
def on_train_start(self) -> None:
591-
"""Called when train begins.
592-
593-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
594-
`training_type_plugin.on_train_start` directly.
595-
"""
596-
rank_zero_deprecation(
597-
"`Accelerator.on_train_start` is deprecated in v1.5 and will be removed in v1.6. "
598-
"`on_train_start` logic is implemented directly in the `TrainingTypePlugin` implementations."
599-
)
577+
"""Called when train begins."""
600578
return self.training_type_plugin.on_train_start()
601579

602580
def on_validation_start(self) -> None:
603581
"""Called when validation begins.
604582
605-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
606-
`training_type_plugin.on_validation_start` directly.
583+
See deprecation warning below.
584+
585+
.. deprecated:: v1.5
586+
This method is deprecated in v1.5 and will be removed in v1.6.
587+
Please call `training_type_plugin.on_validation_start` directly.
607588
"""
608589
rank_zero_deprecation(
609590
"`Accelerator.on_validation_start` is deprecated in v1.5 and will be removed in v1.6. "
@@ -614,8 +595,11 @@ def on_validation_start(self) -> None:
614595
def on_test_start(self) -> None:
615596
"""Called when test begins.
616597
617-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
618-
`training_type_plugin.on_test_start` directly.
598+
See deprecation warning below.
599+
600+
.. deprecated:: v1.5
601+
This method is deprecated in v1.5 and will be removed in v1.6.
602+
Please call `training_type_plugin.on_test_start` directly.
619603
"""
620604
rank_zero_deprecation(
621605
"`Accelerator.on_test_start` is deprecated in v1.5 and will be removed in v1.6. "
@@ -626,8 +610,11 @@ def on_test_start(self) -> None:
626610
def on_predict_start(self) -> None:
627611
"""Called when predict begins.
628612
629-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
630-
`training_type_plugin.on_predict_start` directly.
613+
See deprecation warning below.
614+
615+
.. deprecated:: v1.5
616+
This method is deprecated in v1.5 and will be removed in v1.6.
617+
Please call `training_type_plugin.on_predict_start` directly.
631618
"""
632619
rank_zero_deprecation(
633620
"`Accelerator.on_predict_start` is deprecated in v1.5 and will be removed in v1.6. "
@@ -638,8 +625,11 @@ def on_predict_start(self) -> None:
638625
def on_validation_end(self) -> None:
639626
"""Called when validation ends.
640627
641-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
642-
`training_type_plugin.on_validation_end` directly.
628+
See deprecation warning below.
629+
630+
.. deprecated:: v1.5
631+
This method is deprecated in v1.5 and will be removed in v1.6.
632+
Please call `training_type_plugin.on_validation_end` directly.
643633
"""
644634
rank_zero_deprecation(
645635
"`Accelerator.on_validation_end` is deprecated in v1.5 and will be removed in v1.6. "
@@ -650,8 +640,11 @@ def on_validation_end(self) -> None:
650640
def on_test_end(self) -> None:
651641
"""Called when test end.
652642
653-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
654-
`training_type_plugin.on_test_end` directly.
643+
See deprecation warning below.
644+
645+
.. deprecated:: v1.5
646+
This method is deprecated in v1.5 and will be removed in v1.6.
647+
Please call `training_type_plugin.on_test_end` directly.
655648
"""
656649
rank_zero_deprecation(
657650
"`Accelerator.on_test_end` is deprecated in v1.5 and will be removed in v1.6. "
@@ -662,8 +655,11 @@ def on_test_end(self) -> None:
662655
def on_predict_end(self) -> None:
663656
"""Called when predict ends.
664657
665-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
666-
`training_type_plugin.on_predict_end` directly.
658+
See deprecation warning below.
659+
660+
.. deprecated:: v1.5
661+
This method is deprecated in v1.5 and will be removed in v1.6.
662+
Please call `training_type_plugin.on_predict_end` directly.
667663
"""
668664
rank_zero_deprecation(
669665
"`Accelerator.on_predict_end` is deprecated in v1.5 and will be removed in v1.6. "
@@ -674,8 +670,11 @@ def on_predict_end(self) -> None:
674670
def on_train_end(self) -> None:
675671
"""Called when train ends.
676672
677-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
678-
`training_type_plugin.on_train_end` directly.
673+
See deprecation warning below.
674+
675+
.. deprecated:: v1.5
676+
This method is deprecated in v1.5 and will be removed in v1.6.
677+
Please call `training_type_plugin.on_train_end` directly.
679678
"""
680679
rank_zero_deprecation(
681680
"`Accelerator.on_train_end` is deprecated in v1.5 and will be removed in v1.6. "
@@ -687,8 +686,11 @@ def on_train_end(self) -> None:
687686
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
688687
"""Called in the training loop before anything happens for that batch.
689688
690-
.. deprecated:: v1.5 This method is deprecated in v1.5 and will be removed in v1.6. Please call
691-
`training_type_plugin.on_train_batch_start` directly.
689+
See deprecation warning below.
690+
691+
.. deprecated:: v1.5
692+
This method is deprecated in v1.5 and will be removed in v1.6.
693+
Please call `training_type_plugin.on_train_batch_start` directly.
692694
"""
693695
rank_zero_deprecation(
694696
"`Accelerator.on_train_batch_start` is deprecated in v1.5 and will be removed in v1.6. "

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,7 +1014,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
10141014
# SET UP TRAINING
10151015
# ----------------------------
10161016
self.call_hook("on_before_accelerator_backend_setup")
1017-
self.training_type_plugin.setup_environment()
1017+
self.accelerator.setup_environment()
10181018
self._call_setup_hook() # allow user to setup lightning_module in accelerator environment
10191019

10201020
# check if we should delay restoring checkpoint till later
@@ -1147,7 +1147,7 @@ def _post_dispatch(self):
11471147
self.accelerator.post_dispatch(self)
11481148
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
11491149
# which need to happen before.
1150-
self.training_type_plugin.teardown()
1150+
self.accelerator.teardown()
11511151
self.data_connector.teardown()
11521152
self._active_loop.teardown()
11531153
self.logger_connector.teardown()

tests/deprecated_api/test_remove_1-6.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def test_v1_6_0_deprecated_device_dtype_mixin_import():
330330
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin # noqa: F401
331331

332332

333-
def test_v1_6_0_deprecated_accelerator_collective():
333+
def test_v1_6_0_deprecated_accelerator_pass_through_functions():
334334
from pytorch_lightning.plugins.precision import PrecisionPlugin
335335
from pytorch_lightning.plugins.training_type import SingleDevicePlugin
336336

@@ -352,9 +352,6 @@ def test_v1_6_0_deprecated_accelerator_collective():
352352
model = BoringModel()
353353
accelerator.connect(model)
354354

355-
with pytest.deprecated_call(match="will be removed in v1.6"):
356-
accelerator.setup_environment()
357-
358355
with pytest.deprecated_call(match="will be removed in v1.6"):
359356
accelerator.teardown()
360357

@@ -389,9 +386,6 @@ def test_v1_6_0_deprecated_accelerator_collective():
389386
with pytest.deprecated_call(match="will be removed in v1.6"):
390387
accelerator.restore_checkpoint_after_pre_dispatch
391388

392-
with pytest.deprecated_call(match="will be removed in v1.6"):
393-
accelerator.on_train_start()
394-
395389
with pytest.deprecated_call(match="will be removed in v1.6"):
396390
accelerator.on_validation_start()
397391

tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_invalid_on_cpu(tmpdir):
2424
):
2525
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins="fsdp")
2626
assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin)
27-
trainer.training_type_plugin.setup_environment()
27+
trainer.accelerator.setup_environment()
2828

2929

3030
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})

tests/plugins/test_ddp_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def test_ddp_configure_ddp():
109109
# test wrap the model if fitting
110110
trainer.state.fn = TrainerFn.FITTING
111111
trainer.training_type_plugin.connect(model)
112-
trainer.training_type_plugin.setup_environment()
112+
trainer.accelerator.setup_environment()
113113
trainer.accelerator.setup(trainer)
114114
trainer.lightning_module.trainer = trainer
115115
assert isinstance(trainer.model, LightningModule)
@@ -123,7 +123,7 @@ def test_ddp_configure_ddp():
123123
)
124124
# test do not wrap the model if trainerFN is not fitting
125125
trainer.training_type_plugin.connect(model)
126-
trainer.training_type_plugin.setup_environment()
126+
trainer.accelerator.setup_environment()
127127
trainer.accelerator.setup(trainer)
128128
trainer.lightning_module.trainer = trainer
129129
trainer._pre_dispatch()

0 commit comments

Comments
 (0)