diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 0d73c9fbe8bef..1204099411c2a 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -187,6 +187,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated duplicate `SaveConfigCallback` parameters in `LightningCLI.__init__`: `save_config_kwargs`, `save_config_overwrite` and `save_config_multifile`. New `save_config_kwargs` parameter should be used instead ([#14998](https://github.com/Lightning-AI/lightning/pull/14998) +- Deprecated `TrainerFn.TUNING`, `RunningStage.TUNING` and `trainer.tuning` property ([#15100](https://github.com/Lightning-AI/lightning/pull/15100) + + ### Removed - Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011)) diff --git a/src/pytorch_lightning/callbacks/timer.py b/src/pytorch_lightning/callbacks/timer.py index ca9a2c9861faa..75763ae3ac868 100644 --- a/src/pytorch_lightning/callbacks/timer.py +++ b/src/pytorch_lightning/callbacks/timer.py @@ -95,8 +95,8 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} - self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()} + self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()} self._offset = 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -161,7 +161,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) - self._check_time_remaining(trainer) def state_dict(self) -> Dict[str, Any]: - return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}} + return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 82f0aed9d1366..d7e4cb500665f 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -605,7 +605,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: Args: trainer: the Trainer, these optimizers should be connected to """ - if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer.state.fn != TrainerFn.FITTING: return # Skip initializing optimizers here as DeepSpeed handles optimizers via config. # User may have specified config options instead in configure_optimizers, but this is handled diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 66c11b26c90ff..87354545aa932 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -130,7 +130,7 @@ def setup(self, trainer: "pl.Trainer") -> None: # Separate models are instantiated for different stages, but they share the same weights on host. # When validation/test models are run, weights are synced first. 
trainer_fn = self.lightning_module.trainer.state.fn - if trainer_fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer_fn == TrainerFn.FITTING: # Create model for training and validation which will run on fit training_opts = self.training_opts inference_opts = self.inference_opts diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 8e0ca583ff682..31d69eb55de92 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -136,7 +136,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: Args: trainer: the Trainer, these optimizers should be connected to """ - if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer.state.fn != TrainerFn.FITTING: return assert self.lightning_module is not None self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers( diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index c1947354027e8..3f76967f10385 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -37,7 +37,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: if trainer.state.fn is None: raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.") - if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer.state.fn == TrainerFn.FITTING: __verify_train_val_loop_configuration(trainer, model) __verify_manual_optimization_support(trainer, model) __check_training_step_requires_dataloader_iter(model) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 25fd02a91a280..924b870ce7d9d 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -368,7 +368,7 @@ def restore_loops(self) -> None: assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: - if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if self.trainer.state.fn == TrainerFn.FITTING: fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index a81073cccc1c0..0063ef3fabe96 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -12,12 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass, field -from typing import Optional +from enum import Enum, EnumMeta +from typing import Any, List, Optional + +from lightning_utilities.core.rank_zero import rank_zero_deprecation from pytorch_lightning.utilities import LightningEnum from pytorch_lightning.utilities.enums import _FaultTolerantMode +class _DeprecationManagingEnumMeta(EnumMeta): + """Enum that calls `deprecate()` whenever a member is accessed. 
+ + Adapted from: https://stackoverflow.com/a/62309159/208880 + """ + + def __getattribute__(cls, name: str) -> Any: + obj = super().__getattribute__(name) + # ignore __dunder__ names -- prevents potential recursion errors + if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum): + obj.deprecate() + return obj + + def __getitem__(cls, name: str) -> Any: + member: _DeprecationManagingEnumMeta = super().__getitem__(name) + member.deprecate() + return member + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + obj = super().__call__(*args, **kwargs) + if isinstance(obj, Enum): + obj.deprecate() + return obj + + class TrainerStatus(LightningEnum): """Enum for the status of the :class:`~pytorch_lightning.trainer.trainer.Trainer`""" @@ -31,7 +59,7 @@ def stopped(self) -> bool: return self in (self.FINISHED, self.INTERRUPTED) -class TrainerFn(LightningEnum): +class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta): """ Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer` such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and @@ -44,16 +72,19 @@ class TrainerFn(LightningEnum): PREDICTING = "predict" TUNING = "tune" - @property - def _setup_fn(self) -> "TrainerFn": - """``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders. + def deprecate(self) -> None: + if self == self.TUNING: + rank_zero_deprecation( + f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0." + ) - This is used for the ``setup()`` and ``teardown()`` hooks - """ - return TrainerFn.FITTING if self == TrainerFn.TUNING else self + @classmethod + def _without_tune(cls) -> List["TrainerFn"]: + fns = [fn for fn in cls if fn != "tune"] + return fns -class RunningStage(LightningEnum): +class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta): """Enum for the current running stage. This stage complements :class:`TrainerFn` by specifying the current running stage for each function. @@ -79,12 +110,23 @@ def evaluating(self) -> bool: @property def dataloader_prefix(self) -> Optional[str]: - if self in (self.SANITY_CHECKING, self.TUNING): + if self == self.SANITY_CHECKING: return None if self == self.VALIDATING: return "val" return self.value + def deprecate(self) -> None: + if self == self.TUNING: + rank_zero_deprecation( + f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0." 
+ ) + + @classmethod + def _without_tune(cls) -> List["RunningStage"]: + fns = [fn for fn in cls if fn != "tune"] + return fns + @dataclass class TrainerState: diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index a385c96cdf415..8f8a1e24b42b0 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -961,7 +961,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None def _run( self, model: "pl.LightningModule", ckpt_path: Optional[str] = None ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self ) @@ -1233,7 +1233,7 @@ def _run_sanity_check(self) -> None: def _call_setup_hook(self) -> None: assert self.state.fn is not None - fn = self.state.fn._setup_fn + fn = self.state.fn self.strategy.barrier("pre_setup") @@ -1256,7 +1256,7 @@ def _call_configure_sharded_model(self) -> None: def _call_teardown_hook(self) -> None: assert self.state.fn is not None - fn = self.state.fn._setup_fn + fn = self.state.fn if self.datamodule is not None: self._call_lightning_datamodule_hook("teardown", stage=fn) @@ -1449,7 +1449,7 @@ def __setup_profiler(self) -> None: assert self.state.fn is not None local_rank = self.local_rank if self.world_size > 1 else None self.profiler._lightning_module = proxy(self.lightning_module) - self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir) + self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir) """ Data loading methods @@ -1965,10 +1965,13 @@ def predicting(self, val: bool) -> None: @property def tuning(self) -> bool: + rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.") return self.state.stage == RunningStage.TUNING @tuning.setter def tuning(self, val: bool) -> None: + rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.") + if val: self.state.stage = RunningStage.TUNING elif self.tuning: @@ -2097,7 +2100,7 @@ def predict_loop(self, loop: PredictionLoop) -> None: @property def _evaluation_loop(self) -> EvaluationLoop: - if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if self.state.fn == TrainerFn.FITTING: return self.fit_loop.epoch_loop.val_loop if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop diff --git a/src/pytorch_lightning/tuner/tuning.py b/src/pytorch_lightning/tuner/tuning.py index dd18b06c698f2..1abba552b98cd 100644 --- a/src/pytorch_lightning/tuner/tuning.py +++ b/src/pytorch_lightning/tuner/tuning.py @@ -20,7 +20,7 @@ from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.callbacks.lr_finder import LearningRateFinder from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus +from pytorch_lightning.trainer.states import TrainerStatus from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -77,9 +77,7 @@ def _tune( # Run learning rate finder: if self.trainer.auto_lr_find: - self.trainer.state.fn = TrainerFn.TUNING self.trainer.state.status = 
TrainerStatus.RUNNING - self.tuning = True # TODO: Remove this once LRFinder is converted to a Callback # if a datamodule comes in as the second arg, then fix it for the user @@ -112,7 +110,6 @@ def _run(self, *args: Any, **kwargs: Any) -> None: self.trainer.state.status = TrainerStatus.RUNNING # last `_run` call might have set it to `FINISHED` self.trainer.training = True self.trainer._run(*args, **kwargs) - self.trainer.tuning = True def scale_batch_size( self, @@ -170,10 +167,6 @@ def scale_batch_size( - ``model.hparams`` - ``trainer.datamodule`` (the datamodule passed to the tune method) """ - # TODO: Remove TrainerFn.TUNING since we are now calling fit/validate/test/predict methods directly - self.trainer.state.fn = TrainerFn.TUNING - self.tuning = True - _check_tuner_configuration(self.trainer, train_dataloaders, val_dataloaders, dataloaders, method) batch_size_finder: Callback = BatchSizeFinder( @@ -254,9 +247,6 @@ def lr_find( If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``, or if you are using more than one optimizer. """ - self.trainer.state.fn = TrainerFn.TUNING - self.tuning = True - if method != "fit": raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.") diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 5378f6647e67b..ea70aa8be6eb8 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -35,6 +35,7 @@ from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule from pytorch_lightning.strategies.utils import on_colab_kaggle +from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities.apply_func import ( apply_to_collection, apply_to_collections, @@ -297,3 +298,27 @@ def test_lite_convert_deprecated_tpus_argument(tpu_available): def test_lightningCLI_save_config_init_params_deprecation_warning(name, value): with mock.patch("sys.argv", ["any.py"]), pytest.deprecated_call(match=f".*{name!r} init parameter is deprecated.*"): LightningCLI(BoringModel, run=False, **{name: value}) + + +def test_tuning_enum(): + with pytest.deprecated_call( + match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0." + ): + TrainerFn.TUNING + + with pytest.deprecated_call( + match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0." + ): + RunningStage.TUNING + + +def test_tuning_trainer_property(): + trainer = Trainer() + + with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0."): + trainer.tuning + + with pytest.deprecated_call( + match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0." 
+ ): + trainer.tuning = True diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index a5413e8ee0876..b21bf51de7af3 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -183,39 +183,28 @@ def _check_model_state_dict(self): for actual, expected in zip(self.state_dict(), state_dict["state_dict"]) ) - def _test_on_val_test_predict_tune_start(self): + def _test_on_val_test_predict_start(self): assert self.trainer.current_epoch == state_dict["epoch"] assert self.trainer.global_step == state_dict["global_step"] assert self._check_model_state_dict() - # no optimizes and schedulers are loaded otherwise - if self.trainer.state.fn != TrainerFn.TUNING: - return - - assert not self._check_optimizers() - assert not self._check_schedulers() - def on_train_start(self): - if self.trainer.state.fn == TrainerFn.TUNING: - self._test_on_val_test_predict_tune_start() - else: - assert self.trainer.current_epoch == state_dict["epoch"] + 1 - assert self.trainer.global_step == state_dict["global_step"] - assert self._check_model_state_dict() - assert self._check_optimizers() - assert self._check_schedulers() + assert self.trainer.current_epoch == state_dict["epoch"] + 1 + assert self.trainer.global_step == state_dict["global_step"] + assert self._check_model_state_dict() + assert self._check_optimizers() + assert self._check_schedulers() def on_validation_start(self): if self.trainer.state.fn == TrainerFn.VALIDATING: - self._test_on_val_test_predict_tune_start() + self._test_on_val_test_predict_start() def on_test_start(self): - self._test_on_val_test_predict_tune_start() + self._test_on_val_test_predict_start() for fn in ("fit", "validate", "test", "predict"): model = CustomClassifModel() dm = ClassifDataModule() - trainer_args["auto_scale_batch_size"] = (fn == "tune",) trainer = Trainer(**trainer_args) trainer_fn = getattr(trainer, fn) trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt) diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py index 7a517bf103fcd..171554f01205b 100644 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py @@ -155,9 +155,7 @@ def test_ddp_configure_ddp(): @RunIf(min_cuda_gpus=1) -@pytest.mark.parametrize( - "trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TUNING, TrainerFn.TESTING, TrainerFn.PREDICTING) -) +@pytest.mark.parametrize("trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING)) def test_ddp_dont_configure_sync_batchnorm(trainer_fn): model = BoringModelGPU() model.layer = torch.nn.BatchNorm1d(10) diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 78331aaaf992a..6c5032cf05a1e 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -123,22 +123,22 @@ def test_loops_restore(tmpdir): trainer = Trainer(**trainer_args) trainer.strategy.connect(model) - for fn in TrainerFn: - if fn != TrainerFn.TUNING: - trainer_fn = getattr(trainer, f"{fn}_loop") - trainer_fn.load_state_dict = mock.Mock() - - for fn in TrainerFn: - if fn != TrainerFn.TUNING: - trainer.state.fn = fn - trainer._checkpoint_connector.resume_start(ckpt_path) - trainer._checkpoint_connector.restore_loops() - - trainer_loop = getattr(trainer, f"{fn}_loop") - 
trainer_loop.load_state_dict.assert_called() - trainer_loop.load_state_dict.reset_mock() - - for fn2 in TrainerFn: - if fn2 not in (fn, TrainerFn.TUNING): + trainer_fns = [fn for fn in TrainerFn._without_tune()] + + for fn in trainer_fns: + trainer_fn = getattr(trainer, f"{fn}_loop") + trainer_fn.load_state_dict = mock.Mock() + + for fn in trainer_fns: + trainer.state.fn = fn + trainer._checkpoint_connector.resume_start(ckpt_path) + trainer._checkpoint_connector.restore_loops() + + trainer_loop = getattr(trainer, f"{fn}_loop") + trainer_loop.load_state_dict.assert_called() + trainer_loop.load_state_dict.reset_mock() + + for fn2 in trainer_fns: + if fn2 != fn: trainer_loop2 = getattr(trainer, f"{fn2}_loop") trainer_loop2.load_state_dict.assert_not_called()
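
For context, here is a minimal, self-contained sketch of the deprecation-on-access pattern that this diff adds to `pytorch_lightning/trainer/states.py`. It is an illustration only: the `Stage` class and its members below are placeholders, and it uses the standard-library `warnings.warn` where Lightning's implementation calls `rank_zero_deprecation`.

```python
import warnings
from enum import Enum, EnumMeta
from typing import Any


class _DeprecationManagingEnumMeta(EnumMeta):
    """Metaclass that calls ``deprecate()`` whenever a member is looked up on the enum class."""

    def __getattribute__(cls, name: str) -> Any:
        obj = super().__getattribute__(name)
        # Skip __dunder__ names so the enum machinery's own lookups don't recurse into deprecate().
        if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
            obj.deprecate()
        return obj

    def __getitem__(cls, name: str) -> Any:
        member = super().__getitem__(name)
        member.deprecate()
        return member

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        obj = super().__call__(*args, **kwargs)
        if isinstance(obj, Enum):
            obj.deprecate()
        return obj


class Stage(str, Enum, metaclass=_DeprecationManagingEnumMeta):
    TRAINING = "train"
    TUNING = "tune"  # stand-in for the deprecated member

    def deprecate(self) -> None:
        # Compare by name so this check does not re-enter the metaclass hook.
        if self.name == "TUNING":
            warnings.warn("`Stage.TUNING` is deprecated.", DeprecationWarning, stacklevel=3)


Stage.TRAINING   # no warning
Stage.TUNING     # attribute access warns
Stage["TUNING"]  # item lookup warns
Stage("tune")    # value lookup warns
```

The three hooks cover the three ways a member is normally reached (attribute access, `Enum[name]`, and `Enum(value)`), while plain iteration over the class does not pass through them, which is why the PR also adds `_without_tune()` for places that enumerate every stage.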
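
Downstream code that branched on the tuning state migrates roughly as follows. This is an illustrative sketch, not part of the diff: `needs_optimizers` is a hypothetical call site, and `_without_tune()` is the private helper introduced above, so relying on it outside Lightning is an assumption rather than a supported API.

```python
from pytorch_lightning.trainer.states import RunningStage, TrainerFn

# Before this change, tuning re-used the fitting code path, so call sites had to
# accept both states:
#     if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): ...
# The tuner now drives `fit`/`validate` directly, so an equality check suffices:
def needs_optimizers(trainer) -> bool:
    return trainer.state.fn == TrainerFn.FITTING


# Per-stage bookkeeping (e.g. the Timer callback's start/end times and its
# state_dict) now excludes the deprecated member instead of iterating the full enum:
start_time = {stage: None for stage in RunningStage._without_tune()}
```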