
Commit 700521c

four4fish, awaelchli, and pre-commit-ci[bot] authored
1/n Move precision plugin into strategy - update reference (#10570)
* 1/n move precision plugin into strategy - update reference
* update precision plugin reference in tpu_spawn
* add missing reference in error message
* add back removed license line
* update references in tests
* update reference in trainer
* update return annotation for precision_plugin property on TTP
* simplify access to precision plugin reference in sharded plug
* add changelog
* remove precision property from ttp and add deprecation message
* fix make doc and update precision reference
* simplify a reference to precision (accidentally overridden Adrian's change, now add it back)
* Update CHANGELOG.md (add Adrian's change back)
* Update accelerator precision (Add Adrian's change back)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Add none check for precision plugin just to be safe
* Update ipu.py
* update precision_plugin param deprecation message
* Update accelerator.py
* Remove deprecated warning (Tests will fail after 9940)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2c7c4aa commit 700521c

24 files changed: +142 −59 lines

CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -37,6 +37,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise `MisconfigurationException` when `enable_progress_bar=False` and a progress bar instance has been passed in the callback list ([#10520](https://github.com/PyTorchLightning/pytorch-lightning/issues/10520))


+- Moved ownership of the `PrecisionPlugin` into `TrainingTypePlugin` and updated all references ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))
+
+
 -


@@ -50,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `DistributedType` in favor of `_StrategyType` ([#10505](https://github.com/PyTorchLightning/pytorch-lightning/pull/10505))


--
+- Deprecated the `precision_plugin` constructor argument from `Accelerator` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))


 -
@@ -139,6 +142,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `reload_dataloaders_every_epoch` from `Trainer` in favour of `reload_dataloaders_every_n_epochs` ([#10481](https://github.com/PyTorchLightning/pytorch-lightning/pull/10481))


+- Removed the `precision_plugin` attribute from `Accelerator` in favor of its equivalent attribute `precision_plugin` in the `TrainingTypePlugin` ([#10570](https://github.com/PyTorchLightning/pytorch-lightning/pull/10570))

 ### Fixed


pytorch_lightning/accelerators/accelerator.py

Lines changed: 40 additions & 19 deletions
@@ -25,6 +25,7 @@
 from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
 from pytorch_lightning.plugins.training_type import DataParallelPlugin, TrainingTypePlugin
 from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.enums import AMPType, LightningEnum
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -44,15 +45,23 @@ class Accelerator:
     One to handle differences from the training routine and one to handle different precisions.
     """

-    def __init__(self, precision_plugin: PrecisionPlugin, training_type_plugin: TrainingTypePlugin) -> None:
+    def __init__(self, precision_plugin: Optional[PrecisionPlugin], training_type_plugin: TrainingTypePlugin) -> None:
         """
         Args:
             precision_plugin: the plugin to handle precision-specific parts
+
+                .. deprecated::
+                    The ``precision_plugin`` parameter has been deprecated and will be removed soon.
+                    Pass the precision plugin as a parameter to the ``TrainingTypePlugin`` instead.
+
             training_type_plugin: the plugin to handle different training routines
         """
-        self.precision_plugin = precision_plugin
+
         self.training_type_plugin = training_type_plugin

+        if precision_plugin is not None:
+            self.training_type_plugin._precision_plugin = precision_plugin
+
         self.optimizers: List = []
         self.lr_schedulers: List = []
         self.optimizer_frequencies: List = []
@@ -84,7 +93,7 @@ def pre_dispatch(self, trainer: "pl.Trainer") -> None:
         if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
             self.setup_optimizers(trainer)

-        self.precision_plugin.pre_dispatch()
+        self.training_type_plugin.precision_plugin.pre_dispatch()

     def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
         """Moves the state of the optimizers to the GPU if needed."""
@@ -96,12 +105,12 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
     def dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.dispatch(trainer)
-        self.precision_plugin.dispatch(trainer)
+        self.training_type_plugin.precision_plugin.dispatch(trainer)

     def post_dispatch(self, trainer: "pl.Trainer") -> None:
         """Hook to do something after the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch(trainer)
-        self.precision_plugin.post_dispatch()
+        self.training_type_plugin.precision_plugin.post_dispatch()

     @property
     def model(self) -> Module:
@@ -159,31 +168,31 @@ def training_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:

         See :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` for more details
         """
-        with self.precision_plugin.train_step_context():
+        with self.training_type_plugin.precision_plugin.train_step_context():
             return self.training_type_plugin.training_step(*step_kwargs.values())

     def validation_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual validation step.

         See :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step` for more details
         """
-        with self.precision_plugin.val_step_context():
+        with self.training_type_plugin.precision_plugin.val_step_context():
             return self.training_type_plugin.validation_step(*step_kwargs.values())

     def test_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> Optional[STEP_OUTPUT]:
         """The actual test step.

         See :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step` for more details
         """
-        with self.precision_plugin.test_step_context():
+        with self.training_type_plugin.precision_plugin.test_step_context():
             return self.training_type_plugin.test_step(*step_kwargs.values())

     def predict_step(self, step_kwargs: Dict[str, Union[Any, int]]) -> STEP_OUTPUT:
         """The actual predict step.

         See :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` for more details
         """
-        with self.precision_plugin.predict_step_context():
+        with self.training_type_plugin.precision_plugin.predict_step_context():
             return self.training_type_plugin.predict_step(*step_kwargs.values())

     def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
@@ -193,11 +202,11 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         closure_loss: a tensor holding the loss value to backpropagate
         """
         self.training_type_plugin.pre_backward(closure_loss)
-        closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)
+        closure_loss = self.training_type_plugin.precision_plugin.pre_backward(self.lightning_module, closure_loss)

-        self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
+        self.training_type_plugin.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)

-        closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
+        closure_loss = self.training_type_plugin.precision_plugin.post_backward(self.lightning_module, closure_loss)
         self.training_type_plugin.post_backward(closure_loss)

         return closure_loss
@@ -208,7 +217,7 @@ def optimizer_step(
         opt_idx: int,
         closure: Callable[[], Any],
         model: Optional[Union["pl.LightningModule", Module]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         """performs the actual optimizer step.

@@ -220,7 +229,7 @@ def optimizer_step(
             **kwargs: Any extra arguments to ``optimizer.step``
         """
         model = model or self.lightning_module
-        self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
+        self.training_type_plugin.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
@@ -248,26 +257,38 @@ def setup_training_type_plugin(self) -> None:

     def setup_precision_plugin(self) -> None:
         """Attaches the precision plugin to the accelerator."""
-        model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers)
+        model, optimizers, schedulers = self.training_type_plugin.precision_plugin.connect(
+            self.model, self.optimizers, self.lr_schedulers
+        )
         self.model = model
         self.optimizers = optimizers
         self.lr_schedulers = schedulers

     @property
     def amp_backend(self) -> Optional[LightningEnum]:
-        if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
+        if isinstance(self.training_type_plugin.precision_plugin, ApexMixedPrecisionPlugin):
             return AMPType.APEX
-        if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
+        if isinstance(self.training_type_plugin.precision_plugin, NativeMixedPrecisionPlugin):
             return AMPType.NATIVE
         return None

     @property
     def precision(self) -> Union[str, int]:
-        return self.precision_plugin.precision
+        """The type of precision being used with this accelerator.
+
+        .. deprecated::
+            This property been deprecated and will be removed soon.
+            Use ``training_type_plugin.precision_plugin.precision`` instead.
+        """
+        rank_zero_deprecation(
+            f"`{self.__class__.__name__}.precision` has been deprecated and will be removed soon"
+            f" Use `training_type_plugin.precision_plugin.precision` instead."
+        )
+        return self.training_type_plugin.precision_plugin.precision

     @property
     def scaler(self) -> Optional["GradScaler"]:
-        return getattr(self.precision_plugin, "scaler", None)
+        return getattr(self.training_type_plugin.precision_plugin, "scaler", None)

     def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
         """Returns state of an optimizer.

pytorch_lightning/accelerators/tpu.py

Lines changed: 3 additions & 2 deletions
@@ -36,10 +36,11 @@ def setup(self, trainer: "pl.Trainer") -> None:
             ValueError:
                 If the precision or training type plugin are unsupported.
         """
-        if not isinstance(self.precision_plugin, TPUPrecisionPlugin):
+        if not isinstance(self.training_type_plugin.precision_plugin, TPUPrecisionPlugin):
             # this configuration should have been avoided in the accelerator connector
             raise ValueError(
-                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`, found: {self.precision_plugin}."
+                f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
+                f" found: {self.training_type_plugin.precision_plugin}."
             )
         if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
             raise ValueError(
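The check now inspects the precision plugin carried by the strategy. A hedged sketch of the same validation pattern in isolation; `check_tpu_precision` and the bare `strategy` argument are illustrative helpers, not library API:

```python
from pytorch_lightning.plugins.precision import TPUPrecisionPlugin


def check_tpu_precision(strategy) -> None:
    # Mirrors TPUAccelerator.setup: precision is read off the strategy, not the accelerator.
    if not isinstance(strategy.precision_plugin, TPUPrecisionPlugin):
        raise ValueError(
            f"The `TPUAccelerator` can only be used with a `TPUPrecisionPlugin`,"
            f" found: {strategy.precision_plugin}."
        )
```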

pytorch_lightning/lite/lite.py

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ def __init__(
         )
         self._accelerator = self._accelerator_connector.accelerator
         self._strategy = self._accelerator.training_type_plugin
-        self._precision_plugin = self._accelerator.precision_plugin
+        self._precision_plugin = self._strategy.precision_plugin
         self._models_setup: int = 0

         # wrap the run method so we can inject setup logic or spawn processes for the user
# wrap the run method so we can inject setup logic or spawn processes for the user

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,7 @@
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import (
@@ -86,6 +87,7 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
@@ -96,6 +98,7 @@ def __init__(
             parallel_devices=parallel_devices,
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
         )
         self.interactive_ddp_procs = []
         self._num_nodes = 1
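From the user side, the new constructor argument lets a strategy be created with its precision plugin attached and handed to the `Trainer`; `DDPSpawnPlugin` below gains the same keyword. A hedged usage sketch — the `Trainer(strategy=...)` route and the base `PrecisionPlugin` are assumptions about the surrounding 1.5 API, not part of this diff:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin, PrecisionPlugin

# The strategy carries its own precision plugin (plain full precision here).
strategy = DDPPlugin(precision_plugin=PrecisionPlugin())
trainer = Trainer(strategy=strategy, accelerator="gpu", devices=2)
```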

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 3 additions & 0 deletions
@@ -29,6 +29,7 @@
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8, rank_zero_warn
@@ -65,6 +66,7 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
         ddp_comm_state: Optional[object] = None,
         ddp_comm_hook: Optional[callable] = None,
         ddp_comm_wrapper: Optional[callable] = None,
@@ -74,6 +76,7 @@ def __init__(
             parallel_devices=parallel_devices,
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
         )
         self._num_nodes = 1
         self.sync_batchnorm = False

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 5 additions & 3 deletions
@@ -30,6 +30,7 @@
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.trainer.states import TrainerFn
@@ -129,6 +130,7 @@ def __init__(
         synchronize_checkpoint_boundary: bool = False,
         load_full_weights: bool = False,
         partition_module: bool = True,
+        precision_plugin: Optional[PrecisionPlugin] = None,
     ) -> None:
         """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
         billion parameter models. `For more information: https://pytorch-
@@ -273,6 +275,7 @@ def __init__(
         super().__init__(
             parallel_devices=parallel_devices,
             cluster_environment=cluster_environment,
+            precision_plugin=precision_plugin,
         )

         self.config = self._load_config(config)
@@ -331,7 +334,7 @@ def __init__(

     @property
     def precision(self) -> Union[str, int]:
-        return self._precision or self.lightning_module.trainer.precision
+        return self._precision or self.precision_plugin.precision

     @property
     def amp_level(self) -> Optional[str]:
@@ -456,8 +459,7 @@ def init_deepspeed(self):
                 "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs."
             )

-        precision = self.lightning_module.trainer.accelerator.precision
-        model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)
+        model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision)

         if self.zero_stage_3 and self.partition_module:
             # Ensure the entire model has been moved to the appropriate device
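The `precision` property change shows the general pattern for strategies: resolve precision from the strategy's own plugin instead of reaching through `self.lightning_module.trainer`. A hedged sketch of that pattern in a custom strategy subclass (illustrative, not part of the library):

```python
from typing import Union

from pytorch_lightning.plugins import DDPPlugin


class MyStrategy(DDPPlugin):
    """Illustrative subclass showing where precision is now looked up."""

    @property
    def precision(self) -> Union[str, int]:
        # Before this PR: self.lightning_module.trainer.precision
        return self.precision_plugin.precision
```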

pytorch_lightning/plugins/training_type/dp.py

Lines changed: 8 additions & 1 deletion
@@ -18,6 +18,7 @@

 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.enums import _StrategyType
@@ -35,8 +36,14 @@ def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
     ):
-        super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
+        super().__init__(
+            parallel_devices=parallel_devices,
+            cluster_environment=None,
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
+        )

     @property
     def global_rank(self) -> int:

pytorch_lightning/plugins/training_type/fully_sharded.py

Lines changed: 4 additions & 1 deletion
@@ -18,6 +18,7 @@

 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
 from pytorch_lightning.utilities.enums import _StrategyType
@@ -46,6 +47,7 @@ def __init__(
         parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
     ):
         """Plugin for Fully Sharded Data Parallel provided by FairScale.

@@ -97,6 +99,7 @@ def __init__(
             parallel_devices=parallel_devices,
             cluster_environment=cluster_environment,
             checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
         )
         self.cpu_offload = cpu_offload
         self.move_grads_to_cpu = move_grads_to_cpu
@@ -124,7 +127,7 @@ def setup_distributed(self) -> None:

     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
-        precision = self.lightning_module.trainer.precision
+        precision = self.precision_plugin.precision

         def wrap_policy(*args, **kwargs):
             return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

pytorch_lightning/plugins/training_type/horovod.py

Lines changed: 8 additions & 1 deletion
@@ -21,6 +21,7 @@

 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.distributed import distributed_available
@@ -41,8 +42,14 @@ def __init__(
         self,
         parallel_devices: Optional[List[torch.device]] = None,
         checkpoint_io: Optional[CheckpointIO] = None,
+        precision_plugin: Optional[PrecisionPlugin] = None,
     ):
-        super().__init__(parallel_devices=parallel_devices, cluster_environment=None, checkpoint_io=checkpoint_io)
+        super().__init__(
+            parallel_devices=parallel_devices,
+            cluster_environment=None,
+            checkpoint_io=checkpoint_io,
+            precision_plugin=precision_plugin,
+        )
         rank_zero_only.rank = self.global_rank

     @property
