diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc9b167b81190..b37af8d34b03a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Changed
 
+- Internal checks for PrecisionType, StrategyType and AcceleratorType have been removed in favor of instance-checks against the respective classes ([#12069](https://github.com/PyTorchLightning/pytorch-lightning/pull/12069))
+
+
 - Drop PyTorch 1.7 support ([#12191](https://github.com/PyTorchLightning/pytorch-lightning/pull/12191)), ([#12432](https://github.com/PyTorchLightning/pytorch-lightning/pull/12432))
 
 
@@ -423,6 +426,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Deprecated
 
+- Deprecated `amp_backend` property of the `Trainer` in favor of instance-checks ([#12069](https://github.com/PyTorchLightning/pytorch-lightning/pull/12069))
+
+- Deprecated `backend` property of `MixedPrecisionPlugin` in favor of instance-checks ([#12069](https://github.com/PyTorchLightning/pytorch-lightning/pull/12069))
+
 - Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))
 
 
diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py
index 117125f29d8db..7f9433846112d 100644
--- a/pytorch_lightning/lite/lite.py
+++ b/pytorch_lightning/lite/lite.py
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
@@ -24,10 +24,20 @@
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators import Accelerator, CPUAccelerator, GPUAccelerator, TPUAccelerator
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
 from pytorch_lightning.plugins import PLUGIN_INPUT
-from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy, TPUSpawnStrategy
+from pytorch_lightning.strategies import (
+    DataParallelStrategy,
+    DDPShardedStrategy,
+    DDPSpawnShardedStrategy,
+    DDPSpawnStrategy,
+    DDPStrategy,
+    DeepSpeedStrategy,
+    SingleDeviceStrategy,
+    Strategy,
+    TPUSpawnStrategy,
+)
 from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.utilities import _AcceleratorType, _StrategyType, move_data_to_device
@@ -78,8 +88,9 @@ def __init__(
         gpus: Optional[Union[List[int], str, int]] = None,
         tpu_cores: Optional[Union[List[int], str, int]] = None,
     ) -> None:
-        self._check_accelerator_support(accelerator)
-        self._check_strategy_support(strategy)
+        self._check_accelerator_flag(accelerator)
+        self._check_strategy_flag(strategy)
+
         gpu_ids, tpu_cores = _parse_devices(gpus=gpus, auto_select_gpus=False, tpu_cores=tpu_cores)
         self._accelerator_connector = AcceleratorConnector(
             num_processes=None,
@@ -103,6 +114,10 @@ def __init__(
         self._strategy = self._accelerator_connector.strategy
         self._accelerator = self._strategy.accelerator
         self._precision_plugin = self._strategy.precision_plugin
+
+        self._check_accelerator_type(self._accelerator)
+        self._check_strategy_type(self._strategy)
+
         self._models_setup: int = 0
 
         # wrap the run method so we can inject setup logic or spawn processes for the user
@@ -442,7 +457,7 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut
         kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
         return DistributedSampler(dataloader.dataset, **kwargs)
 
-    def _check_accelerator_support(self, accelerator: Optional[Union[str, Accelerator]]) -> None:
+    def _check_accelerator_flag(self, accelerator: Optional[Union[str, Accelerator]]) -> None:
         supported = [t.value.lower() for t in self._supported_device_types()] + ["auto"]
         valid = accelerator is None or isinstance(accelerator, Accelerator) or accelerator in supported
         if not valid:
@@ -451,7 +466,7 @@ def _check_accelerator_support(self, accelerato
                 f" Choose one of {supported} or pass in a `Accelerator` instance."
             )
 
-    def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> None:
+    def _check_strategy_flag(self, strategy: Optional[Union[str, Strategy]]) -> None:
         supported = [t.lower() for t in self._supported_strategy_types()]
         valid = strategy is None or isinstance(strategy, Strategy) or strategy in supported
         if not valid:
@@ -460,6 +475,26 @@ def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> N
             f" Choose one of {supported} or pass in a `Strategy` instance."
         )
 
+    def _check_accelerator_type(self, accelerator: Accelerator) -> None:
+        if not isinstance(accelerator, self._supported_accelerators()):
+            supported_values = ["auto"] + [x.lower() for x in self._supported_device_types()]
+            raise MisconfigurationException(
+                f"`accelerator={accelerator!r}` is not a valid choice for `LightningLite`."
+                f" Choose one of {supported_values} or pass in a `Accelerator` instance."
+            )
+
+    def _check_strategy_type(self, strategy: Optional[Union[str, Strategy]]) -> None:
+        if not isinstance(strategy, self._supported_strategies()):
+            valid = [t.lower() for t in self._supported_strategy_types()]
+            raise MisconfigurationException(
+                f"`strategy={strategy!r}` is not a valid choice for `LightningLite`."
+                f" Choose one of {valid} or pass in a `Strategy` instance."
+            )
+
+    @staticmethod
+    def _supported_accelerators() -> Tuple[Type[Accelerator], ...]:
+        return (CPUAccelerator, GPUAccelerator, TPUAccelerator)
+
     @staticmethod
     def _supported_device_types() -> Sequence[_AcceleratorType]:
         return (
@@ -468,6 +503,18 @@ def _supported_device_types() -> Sequence[_AcceleratorType]:
             _AcceleratorType.TPU,
         )
 
+    @staticmethod
+    def _supported_strategies() -> Tuple[Type[Strategy], ...]:
+        return (
+            SingleDeviceStrategy,
+            DataParallelStrategy,
+            DDPStrategy,
+            DDPSpawnStrategy,
+            DeepSpeedStrategy,
+            DDPShardedStrategy,
+            DDPSpawnShardedStrategy,
+        )
+
     @staticmethod
     def _supported_strategy_types() -> Sequence[_StrategyType]:
         return (
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index bab025466789a..b40d8f21f4088 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -29,8 +29,9 @@
     _extract_hiddens,
     check_finite_loss,
 )
+from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -353,7 +354,7 @@ def _optimizer_step(
         is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
 
         # wraps into LightningOptimizer only for running step
-        if self.trainer.amp_backend == AMPType.APEX:
+        if isinstance(self.trainer.strategy.precision_plugin, ApexMixedPrecisionPlugin):
             # apex overrides .step function and need to be wrapped on each step
             optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer.strategy, opt_idx)
         else:
@@ -374,7 +375,7 @@ def _optimizer_step(
             opt_idx,
             train_step_and_backward_closure,
             on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
-            using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
+            using_native_amp=isinstance(self.trainer.strategy.precision_plugin, NativeMixedPrecisionPlugin),
             using_lbfgs=is_lbfgs,
         )
 
diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py
index c329aedcf6f00..12410aea4fd79 100644
--- a/pytorch_lightning/plugins/precision/apex_amp.py
+++ b/pytorch_lightning/plugins/precision/apex_amp.py
@@ -21,6 +21,7 @@
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
 from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 from pytorch_lightning.utilities.types import _PARAMETERS
 
 if _APEX_AVAILABLE:
@@ -30,8 +31,6 @@
 class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
     """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""
 
-    backend = AMPType.APEX
-
     def __init__(self, amp_level: str = "O2") -> None:
         if not _APEX_AVAILABLE:
             raise MisconfigurationException(
@@ -98,3 +97,11 @@ def state_dict(self) -> Dict[str, Any]:
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         amp.load_state_dict(state_dict)
+
+    @property
+    def backend(self) -> AMPType:
+        rank_zero_deprecation(
+            "The backend property has been deprecated in v1.6 and will be removed in v1.7."
+ " Please switch to `isinstance(X, ApexMixedPrecisionPlugin)` check instead." + ) + return AMPType.APEX diff --git a/pytorch_lightning/plugins/precision/mixed.py b/pytorch_lightning/plugins/precision/mixed.py index 52c8b96d42882..eeade0f8ebaf7 100644 --- a/pytorch_lightning/plugins/precision/mixed.py +++ b/pytorch_lightning/plugins/precision/mixed.py @@ -11,16 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Union +from typing import Union from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin - -if TYPE_CHECKING: - from pytorch_lightning.utilities import AMPType +from pytorch_lightning.utilities import AMPType class MixedPrecisionPlugin(PrecisionPlugin): """Base Class for mixed precision.""" - backend: "AMPType" + @property + def backend(self) -> AMPType: + """AMP-Backend used by this plugin. + + Typically one of AMPType.NATIVE | AMPType.APEX + + .. deprecated:: v1.6 + This property is deprecated in 1.6 and will be removed in 1.7. + Please use instance checks against the plugin class instead. + """ + precision: Union[str, int] = "mixed" diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index fa749af1d4a08..fa33f52d8ac99 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -23,6 +23,7 @@ from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation if _TORCH_GREATER_EQUAL_1_10: from torch import autocast as new_autocast @@ -39,8 +40,6 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): scaler: An optional :class:`torch.cuda.amp.GradScaler` to use. """ - backend = AMPType.NATIVE - def __init__( self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None ) -> None: @@ -116,3 +115,11 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: if self.scaler is not None: self.scaler.load_state_dict(state_dict) + + @property + def backend(self) -> AMPType: + rank_zero_deprecation( + "The backend property has been deprecated in v1.6 and will be removed in v1.7." + " Please switch to `isinstance(X, NativeMixedPrecisionPlugin)` check instead." 
+        )
+        return AMPType.NATIVE
diff --git a/pytorch_lightning/strategies/deepspeed.py b/pytorch_lightning/strategies/deepspeed.py
index bdec69c43b2f4..62d0a0f4bfebd 100644
--- a/pytorch_lightning/strategies/deepspeed.py
+++ b/pytorch_lightning/strategies/deepspeed.py
@@ -30,6 +30,8 @@
 from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.precision import PrecisionPlugin
+from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.strategies.ddp import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import GradClipAlgorithmType
@@ -39,7 +41,7 @@
     get_default_process_group_backend_for_device,
     log,
 )
-from pytorch_lightning.utilities.enums import AMPType, PrecisionType
+from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -651,7 +653,7 @@ def _auto_select_batch_size(self):
 
     def _format_precision_config(self) -> None:
         if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
-            if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
+            if "fp16" not in self.config and isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
                 self.config["fp16"] = {
@@ -662,7 +664,7 @@ def _format_precision_config(self) -> None:
                     "hysteresis": self.hysteresis,
                     "min_loss_scale": self.min_loss_scale,
                 }
-            elif "amp" not in self.config and self.precision_plugin.amp_type == AMPType.APEX:
+            elif "amp" not in self.config and isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
                 rank_zero_info("Enabling DeepSpeed APEX Implementation.")
                 self.config["amp"] = {"enabled": True, "opt_level": self.precision_plugin.amp_level}
 
diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py
index e86aa95e7b848..05fdaf993d833 100644
--- a/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -70,7 +70,6 @@
     TPUSpawnStrategy,
 )
 from pytorch_lightning.utilities import (
-    _StrategyType,
     AMPType,
     device_parser,
     LightningEnum,
@@ -78,6 +77,7 @@
     rank_zero_info,
     rank_zero_warn,
 )
+from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE, _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
 
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 6e883b3230701..5d856272e2437 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -23,6 +23,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.accelerators.ipu import IPUAccelerator
 from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
+from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
@@ -165,7 +166,7 @@ def _copy_trainer_model_properties(self, model):
         for m in [model, ref_model]:
             m.trainer = proxy(self.trainer)
             # Remove setting use_amp in v1.8
-            m._use_amp = self.trainer.amp_backend is not None
+            m._use_amp = isinstance(self.trainer.strategy.precision_plugin, MixedPrecisionPlugin)
             m.precision = self.trainer.precision
 
     def attach_dataloaders(
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e5168dfef4c83..4a6381b77f212 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -2179,6 +2179,10 @@ def optimizer_frequencies(self, new_freqs: List[int]) -> None:
 
     @property
     def amp_backend(self) -> Optional[AMPType]:
+        rank_zero_deprecation(
+            "amp_backend is deprecated in v1.6 and will be removed in v1.7. "
+            "Use `isinstance` check against the `PrecisionPlugins` directly."
+        )
         if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin):
             return AMPType.APEX
         if isinstance(self.precision_plugin, NativeMixedPrecisionPlugin):
diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py
index 8cfaa843a9f10..cb7c724e75d54 100644
--- a/tests/deprecated_api/test_remove_1-7.py
+++ b/tests/deprecated_api/test_remove_1-7.py
@@ -34,6 +34,8 @@
     SLURMEnvironment,
     TorchElasticEnvironment,
 )
+from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import SingleDeviceStrategy
 from tests.deprecated_api import _soft_unimport_module
 from tests.helpers import BoringModel
@@ -516,3 +518,23 @@ def post_dispatch(self, trainer):
 
         with pytest.deprecated_call(match=escape("`CustomPlugin.post_dispatch()` has been deprecated in v1.6")):
             CustomPlugin(torch.device("cpu"))
+
+
+def test_v1_7_0_trainer_amp_backend():
+    trainer = Trainer()
+    with pytest.deprecated_call(match="amp_backend is deprecated in v1.6 and will be removed in v1.7."):
+        trainer.amp_backend
+
+
+def test_v1_7_0_mixed_precision_plugin_backend_native():
+    plugin = NativeMixedPrecisionPlugin(16, "cpu")
+
+    with pytest.deprecated_call(match="The backend property has been deprecated in v1.6 and will be removed in v1.7."):
+        plugin.backend
+
+
+@RunIf(amp_apex=True)
+def test_v1_7_0_mixed_precision_plugin_backend_apex():
+    plugin = ApexMixedPrecisionPlugin()
+    with pytest.deprecated_call(match="The backend property has been deprecated in v1.6 and will be removed in v1.7."):
+        plugin.backend
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 3fb42fb0ce29e..bccf6712ca72e 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -22,6 +22,8 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins.environments import SLURMEnvironment
+from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
 
@@ -180,10 +182,13 @@ def test_amp_without_apex(bwd_mock, tmpdir):
     model = BoringModel()
 
     trainer = Trainer(default_root_dir=tmpdir, amp_backend="native")
-    assert trainer.amp_backend is None
+
+    assert not isinstance(trainer.precision_plugin, NativeMixedPrecisionPlugin)
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, amp_backend="apex")
-    assert trainer.amp_backend is None
+
+    assert not isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
+
     trainer.fit(model)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     assert not bwd_mock.called
@@ -207,11 +212,10 @@ def configure_optimizers(self):
 
     model = CustomModel()
     model.training_epoch_end = None
-
     trainer = Trainer(
         default_root_dir=tmpdir, max_steps=5, precision=16, amp_backend="apex", accelerator="gpu", devices=1
     )
-    assert str(trainer.amp_backend) == "AMPType.APEX"
+    assert isinstance(trainer.precision_plugin, ApexMixedPrecisionPlugin)
     trainer.fit(model)
     assert trainer.state.finished, f"Training failed with {trainer.state}"
     # `max_steps` is fulfilled in the third batch first optimizer, but we don't check the loop
diff --git a/tests/tuner/test_scale_batch_size.py b/tests/tuner/test_scale_batch_size.py
index 13112ee9f4a51..a50c4fa4a1697 100644
--- a/tests/tuner/test_scale_batch_size.py
+++ b/tests/tuner/test_scale_batch_size.py
@@ -20,8 +20,8 @@
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -246,7 +246,7 @@ def test_auto_scale_batch_size_with_amp(tmpdir):
     )
     trainer.tune(model)
     after_batch_size = model.batch_size
-    assert trainer.amp_backend == AMPType.NATIVE
+    assert isinstance(trainer.strategy.precision_plugin, NativeMixedPrecisionPlugin)
     assert trainer.scaler is not None
    assert after_batch_size != before_batch_size
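
Migration note (illustrative sketch, not part of the patch): user code that compared the deprecated `trainer.amp_backend` against `AMPType` values can switch to the instance-checks this diff introduces. The helper names below are hypothetical; the imports and the `trainer.strategy.precision_plugin` attribute path are taken from the changes above.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin


def uses_native_amp(trainer: Trainer) -> bool:
    # Before (deprecated in v1.6): trainer.amp_backend == AMPType.NATIVE
    return isinstance(trainer.strategy.precision_plugin, NativeMixedPrecisionPlugin)


def uses_apex_amp(trainer: Trainer) -> bool:
    # Before (deprecated in v1.6): trainer.amp_backend == AMPType.APEX
    return isinstance(trainer.strategy.precision_plugin, ApexMixedPrecisionPlugin)
```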