From e36c234af727a60c509ed92c0dfa39b03093b4fa Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 12 Oct 2022 14:54:20 +0530 Subject: [PATCH 01/15] Deprecate TrainerFn.TUNING and RunningStage.TUNING --- src/pytorch_lightning/trainer/states.py | 50 +++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index a81073cccc1c0..4d290951f7d28 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -12,12 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass, field -from typing import Optional +from enum import Enum, EnumMeta +from typing import Any, Optional + +from lightning_utilities.core.rank_zero import rank_zero_deprecation from pytorch_lightning.utilities import LightningEnum from pytorch_lightning.utilities.enums import _FaultTolerantMode +class _DeprecatedEnumMeta(EnumMeta): + """Enum that calls `deprecate()` whenever a member is accessed. + + Adapted from: https://stackoverflow.com/a/62309159/208880 + """ + + def __getattribute__(cls, name: str) -> Any: + obj = super().__getattribute__(name) + # ignore __dunder__ names -- prevents potential recursion errors + if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum): + obj.deprecate() + return obj + + def __getitem__(cls, name: str) -> Any: + member: _DeprecatedEnumMeta = super().__getitem__(name) + breakpoint() + member.deprecate() + return member + + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + obj = super().__call__(*args, **kwargs) + breakpoint() + if isinstance(obj, Enum): + obj.deprecate() + return obj + + class TrainerStatus(LightningEnum): """Enum for the status of the :class:`~pytorch_lightning.trainer.trainer.Trainer`""" @@ -31,7 +61,7 @@ def stopped(self) -> bool: return self in (self.FINISHED, self.INTERRUPTED) -class TrainerFn(LightningEnum): +class TrainerFn(LightningEnum, metaclass=_DeprecatedEnumMeta): """ Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer` such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and @@ -52,8 +82,15 @@ def _setup_fn(self) -> "TrainerFn": """ return TrainerFn.FITTING if self == TrainerFn.TUNING else self + def deprecate(self) -> None: + return + rank_zero_deprecation( + "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8." + f" Use the string value `{self.value!r}` instead." + ) -class RunningStage(LightningEnum): + +class RunningStage(LightningEnum, metaclass=_DeprecatedEnumMeta): """Enum for the current running stage. This stage complements :class:`TrainerFn` by specifying the current running stage for each function. @@ -85,6 +122,13 @@ def dataloader_prefix(self) -> Optional[str]: return "val" return self.value + def deprecate(self) -> None: + return + rank_zero_deprecation( + "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8." + f" Use the string value `{self.value!r}` instead." + ) + @dataclass class TrainerState: From a1417c8603419e39d7e3ee5eca9b785f12700476 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 12 Oct 2022 17:02:45 +0530 Subject: [PATCH 02/15] update --- src/pytorch_lightning/strategies/deepspeed.py | 2 +- src/pytorch_lightning/strategies/ipu.py | 2 +- src/pytorch_lightning/strategies/strategy.py | 2 +- .../trainer/configuration_validator.py | 2 +- .../connectors/checkpoint_connector.py | 2 +- src/pytorch_lightning/trainer/states.py | 28 ++++++------------- src/pytorch_lightning/trainer/trainer.py | 15 ++++++---- src/pytorch_lightning/tuner/tuning.py | 12 +------- .../deprecated_api/test_remove_1-10.py | 27 ++++++++++++++++++ tests/tests_pytorch/helpers/datamodules.py | 2 +- tests/tests_pytorch/models/test_restore.py | 27 ++++++------------ .../strategies/test_ddp_strategy.py | 4 +-- .../connectors/test_checkpoint_connector.py | 20 ++++++------- 13 files changed, 71 insertions(+), 74 deletions(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 82f0aed9d1366..d7e4cb500665f 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -605,7 +605,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: Args: trainer: the Trainer, these optimizers should be connected to """ - if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer.state.fn != TrainerFn.FITTING: return # Skip initializing optimizers here as DeepSpeed handles optimizers via config. # User may have specified config options instead in configure_optimizers, but this is handled diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 66c11b26c90ff..87354545aa932 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -130,7 +130,7 @@ def setup(self, trainer: "pl.Trainer") -> None: # Separate models are instantiated for different stages, but they share the same weights on host. # When validation/test models are run, weights are synced first. trainer_fn = self.lightning_module.trainer.state.fn - if trainer_fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer_fn == TrainerFn.FITTING: # Create model for training and validation which will run on fit training_opts = self.training_opts inference_opts = self.inference_opts diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 8e0ca583ff682..31d69eb55de92 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -136,7 +136,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None: Args: trainer: the Trainer, these optimizers should be connected to """ - if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer.state.fn != TrainerFn.FITTING: return assert self.lightning_module is not None self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers( diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index c1947354027e8..3f76967f10385 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -37,7 +37,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: if trainer.state.fn is None: raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.") - if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if trainer.state.fn == TrainerFn.FITTING: __verify_train_val_loop_configuration(trainer, model) __verify_manual_optimization_support(trainer, model) __check_training_step_requires_dataloader_iter(model) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 25fd02a91a280..924b870ce7d9d 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -368,7 +368,7 @@ def restore_loops(self) -> None: assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: - if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if self.trainer.state.fn == TrainerFn.FITTING: fit_loop.load_state_dict(state_dict["fit_loop"]) elif self.trainer.state.fn == TrainerFn.VALIDATING: self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index 4d290951f7d28..29d1f256c146f 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -74,20 +74,11 @@ class TrainerFn(LightningEnum, metaclass=_DeprecatedEnumMeta): PREDICTING = "predict" TUNING = "tune" - @property - def _setup_fn(self) -> "TrainerFn": - """``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders. - - This is used for the ``setup()`` and ``teardown()`` hooks - """ - return TrainerFn.FITTING if self == TrainerFn.TUNING else self - def deprecate(self) -> None: - return - rank_zero_deprecation( - "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8." - f" Use the string value `{self.value!r}` instead." - ) + if self == self.TUNING: + rank_zero_deprecation( + f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0." + ) class RunningStage(LightningEnum, metaclass=_DeprecatedEnumMeta): @@ -116,18 +107,17 @@ def evaluating(self) -> bool: @property def dataloader_prefix(self) -> Optional[str]: - if self in (self.SANITY_CHECKING, self.TUNING): + if self == self.SANITY_CHECKING: return None if self == self.VALIDATING: return "val" return self.value def deprecate(self) -> None: - return - rank_zero_deprecation( - "`DistributedType` Enum has been deprecated in v1.6 and will be removed in v1.8." - f" Use the string value `{self.value!r}` instead." - ) + if self == self.TUNING: + rank_zero_deprecation( + f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0." + ) @dataclass diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 190a75ec2b286..a0f0d0c8aa3ea 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -955,7 +955,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None def _run( self, model: "pl.LightningModule", ckpt_path: Optional[str] = None ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: - if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if self.state.fn == TrainerFn.FITTING: min_epochs, max_epochs = _parse_loop_limits( self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self ) @@ -1225,7 +1225,7 @@ def _run_sanity_check(self) -> None: def _call_setup_hook(self) -> None: assert self.state.fn is not None - fn = self.state.fn._setup_fn + fn = self.state.fn self.strategy.barrier("pre_setup") @@ -1248,7 +1248,7 @@ def _call_configure_sharded_model(self) -> None: def _call_teardown_hook(self) -> None: assert self.state.fn is not None - fn = self.state.fn._setup_fn + fn = self.state.fn if self.datamodule is not None: self._call_lightning_datamodule_hook("teardown", stage=fn) @@ -1441,7 +1441,7 @@ def __setup_profiler(self) -> None: assert self.state.fn is not None local_rank = self.local_rank if self.world_size > 1 else None self.profiler._lightning_module = proxy(self.lightning_module) - self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir) + self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir) """ Data loading methods @@ -1957,10 +1957,15 @@ def predicting(self, val: bool) -> None: @property def tuning(self) -> bool: + rank_zero_deprecation("`trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0.") return self.state.stage == RunningStage.TUNING @tuning.setter def tuning(self, val: bool) -> None: + rank_zero_deprecation( + "Setting `trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0." + ) + if val: self.state.stage = RunningStage.TUNING elif self.tuning: @@ -2089,7 +2094,7 @@ def predict_loop(self, loop: PredictionLoop) -> None: @property def _evaluation_loop(self) -> EvaluationLoop: - if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING): + if self.state.fn == TrainerFn.FITTING: return self.fit_loop.epoch_loop.val_loop if self.state.fn == TrainerFn.VALIDATING: return self.validate_loop diff --git a/src/pytorch_lightning/tuner/tuning.py b/src/pytorch_lightning/tuner/tuning.py index dd18b06c698f2..1abba552b98cd 100644 --- a/src/pytorch_lightning/tuner/tuning.py +++ b/src/pytorch_lightning/tuner/tuning.py @@ -20,7 +20,7 @@ from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.callbacks.lr_finder import LearningRateFinder from pytorch_lightning.core.datamodule import LightningDataModule -from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus +from pytorch_lightning.trainer.states import TrainerStatus from pytorch_lightning.tuner.lr_finder import _LRFinder from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -77,9 +77,7 @@ def _tune( # Run learning rate finder: if self.trainer.auto_lr_find: - self.trainer.state.fn = TrainerFn.TUNING self.trainer.state.status = TrainerStatus.RUNNING - self.tuning = True # TODO: Remove this once LRFinder is converted to a Callback # if a datamodule comes in as the second arg, then fix it for the user @@ -112,7 +110,6 @@ def _run(self, *args: Any, **kwargs: Any) -> None: self.trainer.state.status = TrainerStatus.RUNNING # last `_run` call might have set it to `FINISHED` self.trainer.training = True self.trainer._run(*args, **kwargs) - self.trainer.tuning = True def scale_batch_size( self, @@ -170,10 +167,6 @@ def scale_batch_size( - ``model.hparams`` - ``trainer.datamodule`` (the datamodule passed to the tune method) """ - # TODO: Remove TrainerFn.TUNING since we are now calling fit/validate/test/predict methods directly - self.trainer.state.fn = TrainerFn.TUNING - self.tuning = True - _check_tuner_configuration(self.trainer, train_dataloaders, val_dataloaders, dataloaders, method) batch_size_finder: Callback = BatchSizeFinder( @@ -254,9 +247,6 @@ def lr_find( If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``, or if you are using more than one optimizer. """ - self.trainer.state.fn = TrainerFn.TUNING - self.tuning = True - if method != "fit": raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.") diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 5378f6647e67b..2d5e804cddbd2 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -35,6 +35,7 @@ from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule from pytorch_lightning.strategies.utils import on_colab_kaggle +from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities.apply_func import ( apply_to_collection, apply_to_collections, @@ -297,3 +298,29 @@ def test_lite_convert_deprecated_tpus_argument(tpu_available): def test_lightningCLI_save_config_init_params_deprecation_warning(name, value): with mock.patch("sys.argv", ["any.py"]), pytest.deprecated_call(match=f".*{name!r} init parameter is deprecated.*"): LightningCLI(BoringModel, run=False, **{name: value}) + + +def test_tuning_enum(): + with pytest.deprecated_call( + match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0." + ): + TrainerFn.TUNING + + with pytest.deprecated_call( + match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0." + ): + RunningStage.TUNING + + +def test_tuning_trainer_property(): + trainer = Trainer() + + with pytest.deprecated_call( + match="`trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0." + ): + trainer.tuning + + with pytest.deprecated_call( + match="Setting `trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0." + ): + trainer.tuning = True diff --git a/tests/tests_pytorch/helpers/datamodules.py b/tests/tests_pytorch/helpers/datamodules.py index 4984914c275dd..8076d80af4b03 100644 --- a/tests/tests_pytorch/helpers/datamodules.py +++ b/tests/tests_pytorch/helpers/datamodules.py @@ -20,7 +20,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from tests_pytorch.helpers.datasets import MNIST, SklearnDataset, TrialMNIST -_SKLEARN_AVAILABLE = RequirementCache("sklearn") +_SKLEARN_AVAILABLE = RequirementCache("scikit-learn") class MNISTDataModule(LightningDataModule): diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index a5413e8ee0876..b21bf51de7af3 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -183,39 +183,28 @@ def _check_model_state_dict(self): for actual, expected in zip(self.state_dict(), state_dict["state_dict"]) ) - def _test_on_val_test_predict_tune_start(self): + def _test_on_val_test_predict_start(self): assert self.trainer.current_epoch == state_dict["epoch"] assert self.trainer.global_step == state_dict["global_step"] assert self._check_model_state_dict() - # no optimizes and schedulers are loaded otherwise - if self.trainer.state.fn != TrainerFn.TUNING: - return - - assert not self._check_optimizers() - assert not self._check_schedulers() - def on_train_start(self): - if self.trainer.state.fn == TrainerFn.TUNING: - self._test_on_val_test_predict_tune_start() - else: - assert self.trainer.current_epoch == state_dict["epoch"] + 1 - assert self.trainer.global_step == state_dict["global_step"] - assert self._check_model_state_dict() - assert self._check_optimizers() - assert self._check_schedulers() + assert self.trainer.current_epoch == state_dict["epoch"] + 1 + assert self.trainer.global_step == state_dict["global_step"] + assert self._check_model_state_dict() + assert self._check_optimizers() + assert self._check_schedulers() def on_validation_start(self): if self.trainer.state.fn == TrainerFn.VALIDATING: - self._test_on_val_test_predict_tune_start() + self._test_on_val_test_predict_start() def on_test_start(self): - self._test_on_val_test_predict_tune_start() + self._test_on_val_test_predict_start() for fn in ("fit", "validate", "test", "predict"): model = CustomClassifModel() dm = ClassifDataModule() - trainer_args["auto_scale_batch_size"] = (fn == "tune",) trainer = Trainer(**trainer_args) trainer_fn = getattr(trainer, fn) trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt) diff --git a/tests/tests_pytorch/strategies/test_ddp_strategy.py b/tests/tests_pytorch/strategies/test_ddp_strategy.py index 7a517bf103fcd..171554f01205b 100644 --- a/tests/tests_pytorch/strategies/test_ddp_strategy.py +++ b/tests/tests_pytorch/strategies/test_ddp_strategy.py @@ -155,9 +155,7 @@ def test_ddp_configure_ddp(): @RunIf(min_cuda_gpus=1) -@pytest.mark.parametrize( - "trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TUNING, TrainerFn.TESTING, TrainerFn.PREDICTING) -) +@pytest.mark.parametrize("trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING)) def test_ddp_dont_configure_sync_batchnorm(trainer_fn): model = BoringModelGPU() model.layer = torch.nn.BatchNorm1d(10) diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 78331aaaf992a..53d23b4d68d4e 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -124,21 +124,19 @@ def test_loops_restore(tmpdir): trainer.strategy.connect(model) for fn in TrainerFn: - if fn != TrainerFn.TUNING: - trainer_fn = getattr(trainer, f"{fn}_loop") - trainer_fn.load_state_dict = mock.Mock() + trainer_fn = getattr(trainer, f"{fn}_loop") + trainer_fn.load_state_dict = mock.Mock() for fn in TrainerFn: - if fn != TrainerFn.TUNING: - trainer.state.fn = fn - trainer._checkpoint_connector.resume_start(ckpt_path) - trainer._checkpoint_connector.restore_loops() + trainer.state.fn = fn + trainer._checkpoint_connector.resume_start(ckpt_path) + trainer._checkpoint_connector.restore_loops() - trainer_loop = getattr(trainer, f"{fn}_loop") - trainer_loop.load_state_dict.assert_called() - trainer_loop.load_state_dict.reset_mock() + trainer_loop = getattr(trainer, f"{fn}_loop") + trainer_loop.load_state_dict.assert_called() + trainer_loop.load_state_dict.reset_mock() for fn2 in TrainerFn: - if fn2 not in (fn, TrainerFn.TUNING): + if fn2 != fn: trainer_loop2 = getattr(trainer, f"{fn2}_loop") trainer_loop2.load_state_dict.assert_not_called() From dfc504efe64a15642ea1db7b242abd3dbbdb5a52 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 12 Oct 2022 17:37:05 +0530 Subject: [PATCH 03/15] chlog --- src/pytorch_lightning/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ea8a7aa6f6a2a..17cf8c82f3901 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated duplicate `SaveConfigCallback` parameters in `LightningCLI.__init__`: `save_config_kwargs`, `save_config_overwrite` and `save_config_multifile`. New `save_config_kwargs` parameter should be used instead ([#14998](https://github.com/Lightning-AI/lightning/pull/14998) +- Deprecated `TrainerFn.TUNING`, `RunningStage.TUNING` and `trainer.tuning` property ([#15100](https://github.com/Lightning-AI/lightning/pull/15100) + + ### Removed - Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011)) From ad75577389e210ef4bbd0219c8e71dfa38b48d69 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 12 Oct 2022 17:54:16 +0530 Subject: [PATCH 04/15] remove breakpoints --- src/pytorch_lightning/trainer/states.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index 29d1f256c146f..44302a7fa6ce5 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -36,13 +36,11 @@ def __getattribute__(cls, name: str) -> Any: def __getitem__(cls, name: str) -> Any: member: _DeprecatedEnumMeta = super().__getitem__(name) - breakpoint() member.deprecate() return member def __call__(cls, *args: Any, **kwargs: Any) -> Any: obj = super().__call__(*args, **kwargs) - breakpoint() if isinstance(obj, Enum): obj.deprecate() return obj From becd2946685547db45b5058f98ca5660b80e8677 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 12 Oct 2022 18:17:26 +0530 Subject: [PATCH 05/15] Apply suggestions from code review --- src/pytorch_lightning/trainer/states.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index 44302a7fa6ce5..3e1bd75fd80ca 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -21,7 +21,7 @@ from pytorch_lightning.utilities.enums import _FaultTolerantMode -class _DeprecatedEnumMeta(EnumMeta): +class _DeprecationManagingEnumMeta(EnumMeta): """Enum that calls `deprecate()` whenever a member is accessed. Adapted from: https://stackoverflow.com/a/62309159/208880 @@ -59,7 +59,7 @@ def stopped(self) -> bool: return self in (self.FINISHED, self.INTERRUPTED) -class TrainerFn(LightningEnum, metaclass=_DeprecatedEnumMeta): +class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta): """ Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer` such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and From 0ffbd48688a773e7aa792dec075d1319a71e5ae1 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 12 Oct 2022 18:21:15 +0530 Subject: [PATCH 06/15] Apply suggestions from code review --- src/pytorch_lightning/trainer/states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index 3e1bd75fd80ca..add4b93643cd4 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -79,7 +79,7 @@ def deprecate(self) -> None: ) -class RunningStage(LightningEnum, metaclass=_DeprecatedEnumMeta): +class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta): """Enum for the current running stage. This stage complements :class:`TrainerFn` by specifying the current running stage for each function. From af155466d6a1ac9a06bc4d935d8d3189e71a07d2 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Wed, 12 Oct 2022 18:41:12 +0530 Subject: [PATCH 07/15] Update src/pytorch_lightning/trainer/states.py Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> --- src/pytorch_lightning/trainer/states.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index add4b93643cd4..0abfbe7a33154 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -35,7 +35,7 @@ def __getattribute__(cls, name: str) -> Any: return obj def __getitem__(cls, name: str) -> Any: - member: _DeprecatedEnumMeta = super().__getitem__(name) + member: _DeprecationManagingEnumMeta = super().__getitem__(name) member.deprecate() return member From d0904a9749c62c01c54a3c9e474276c2ecad5c11 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 12 Oct 2022 19:10:10 +0530 Subject: [PATCH 08/15] update tests --- src/pytorch_lightning/callbacks/timer.py | 10 +++++++--- .../trainer/connectors/test_checkpoint_connector.py | 8 +++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/callbacks/timer.py b/src/pytorch_lightning/callbacks/timer.py index ca9a2c9861faa..9b841f8dd0762 100644 --- a/src/pytorch_lightning/callbacks/timer.py +++ b/src/pytorch_lightning/callbacks/timer.py @@ -95,8 +95,10 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} - self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage} + self._start_time: Dict[RunningStage, Optional[float]] = { + stage: None for stage in RunningStage if stage != "tune" + } + self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage if stage != "tune"} self._offset = 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -161,7 +163,9 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) - self._check_time_remaining(trainer) def state_dict(self) -> Dict[str, Any]: - return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}} + return { + "time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage) if stage != "tune"} + } def load_state_dict(self, state_dict: Dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index 53d23b4d68d4e..e47ea78d60722 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -123,11 +123,13 @@ def test_loops_restore(tmpdir): trainer = Trainer(**trainer_args) trainer.strategy.connect(model) - for fn in TrainerFn: + trainer_fns = [fn for fn in TrainerFn if fn != "tune"] + + for fn in trainer_fns: trainer_fn = getattr(trainer, f"{fn}_loop") trainer_fn.load_state_dict = mock.Mock() - for fn in TrainerFn: + for fn in trainer_fns: trainer.state.fn = fn trainer._checkpoint_connector.resume_start(ckpt_path) trainer._checkpoint_connector.restore_loops() @@ -136,7 +138,7 @@ def test_loops_restore(tmpdir): trainer_loop.load_state_dict.assert_called() trainer_loop.load_state_dict.reset_mock() - for fn2 in TrainerFn: + for fn2 in trainer_fns: if fn2 != fn: trainer_loop2 = getattr(trainer, f"{fn2}_loop") trainer_loop2.load_state_dict.assert_not_called() From 8b98d52ccf46f607034ae96a8ea78cec9d9339ab Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 12 Oct 2022 19:38:01 +0530 Subject: [PATCH 09/15] flaky test --- tests/tests_pytorch/core/test_datamodules.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 82bb30c85ba31..8a2abe15e5812 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -180,7 +180,6 @@ def test_train_val_loop_only(tmpdir): # fit model trainer.fit(model, datamodule=dm) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert trainer.callback_metrics["train_loss"] < 1.0 def test_dm_checkpoint_save_and_load(tmpdir): From 3ac47d31cfc79e1e80d5bfb9fc965eed4140911a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 12 Oct 2022 19:53:08 +0530 Subject: [PATCH 10/15] flaky test --- tests/tests_pytorch/core/test_datamodules.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index 8a2abe15e5812..b502688ad3709 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -164,7 +164,6 @@ def test_train_loop_only(tmpdir): # fit model trainer.fit(model, datamodule=dm) assert trainer.state.finished, f"Training failed with {trainer.state}" - assert trainer.callback_metrics["train_loss"] < 1.0 def test_train_val_loop_only(tmpdir): From df501c4c2242d099162d6060ca1a15590b0d068d Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 13 Oct 2022 18:01:57 +0530 Subject: [PATCH 11/15] reviews --- src/pytorch_lightning/callbacks/timer.py | 10 +++------- src/pytorch_lightning/trainer/states.py | 10 ++++++++++ src/pytorch_lightning/trainer/trainer.py | 6 ++---- tests/tests_pytorch/deprecated_api/test_remove_1-10.py | 6 ++---- .../trainer/connectors/test_checkpoint_connector.py | 2 +- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/pytorch_lightning/callbacks/timer.py b/src/pytorch_lightning/callbacks/timer.py index 9b841f8dd0762..75763ae3ac868 100644 --- a/src/pytorch_lightning/callbacks/timer.py +++ b/src/pytorch_lightning/callbacks/timer.py @@ -95,10 +95,8 @@ def __init__( self._duration = duration.total_seconds() if duration is not None else None self._interval = interval self._verbose = verbose - self._start_time: Dict[RunningStage, Optional[float]] = { - stage: None for stage in RunningStage if stage != "tune" - } - self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage if stage != "tune"} + self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()} + self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()} self._offset = 0 def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]: @@ -163,9 +161,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) - self._check_time_remaining(trainer) def state_dict(self) -> Dict[str, Any]: - return { - "time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage) if stage != "tune"} - } + return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: time_elapsed = state_dict.get("time_elapsed", {}) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index 0abfbe7a33154..905f4fec57cde 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -78,6 +78,11 @@ def deprecate(self) -> None: f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0." ) + @classmethod + def _without_tune(cls): + fns = [fn for fn in cls if fn != "tune"] + return fns + class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta): """Enum for the current running stage. @@ -117,6 +122,11 @@ def deprecate(self) -> None: f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0." ) + @classmethod + def _without_tune(cls): + fns = [fn for fn in cls if fn != "tune"] + return fns + @dataclass class TrainerState: diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 01bc4293c95bf..8f8a1e24b42b0 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1965,14 +1965,12 @@ def predicting(self, val: bool) -> None: @property def tuning(self) -> bool: - rank_zero_deprecation("`trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0.") + rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.") return self.state.stage == RunningStage.TUNING @tuning.setter def tuning(self, val: bool) -> None: - rank_zero_deprecation( - "Setting `trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0." - ) + rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.") if val: self.state.stage = RunningStage.TUNING diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 2d5e804cddbd2..ea70aa8be6eb8 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -315,12 +315,10 @@ def test_tuning_enum(): def test_tuning_trainer_property(): trainer = Trainer() - with pytest.deprecated_call( - match="`trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0." - ): + with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0."): trainer.tuning with pytest.deprecated_call( - match="Setting `trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0." + match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0." ): trainer.tuning = True diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index e47ea78d60722..6c5032cf05a1e 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -123,7 +123,7 @@ def test_loops_restore(tmpdir): trainer = Trainer(**trainer_args) trainer.strategy.connect(model) - trainer_fns = [fn for fn in TrainerFn if fn != "tune"] + trainer_fns = [fn for fn in TrainerFn._without_tune()] for fn in trainer_fns: trainer_fn = getattr(trainer, f"{fn}_loop") From 0e8e5d923c48cd872d220b4a0e176bd9ca5687ec Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 13 Oct 2022 18:02:07 +0530 Subject: [PATCH 12/15] rev --- tests/tests_pytorch/core/test_datamodules.py | 2 ++ tests/tests_pytorch/helpers/datamodules.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index b502688ad3709..82bb30c85ba31 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -164,6 +164,7 @@ def test_train_loop_only(tmpdir): # fit model trainer.fit(model, datamodule=dm) assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.callback_metrics["train_loss"] < 1.0 def test_train_val_loop_only(tmpdir): @@ -179,6 +180,7 @@ def test_train_val_loop_only(tmpdir): # fit model trainer.fit(model, datamodule=dm) assert trainer.state.finished, f"Training failed with {trainer.state}" + assert trainer.callback_metrics["train_loss"] < 1.0 def test_dm_checkpoint_save_and_load(tmpdir): diff --git a/tests/tests_pytorch/helpers/datamodules.py b/tests/tests_pytorch/helpers/datamodules.py index 8076d80af4b03..4984914c275dd 100644 --- a/tests/tests_pytorch/helpers/datamodules.py +++ b/tests/tests_pytorch/helpers/datamodules.py @@ -20,7 +20,7 @@ from pytorch_lightning.core.datamodule import LightningDataModule from tests_pytorch.helpers.datasets import MNIST, SklearnDataset, TrialMNIST -_SKLEARN_AVAILABLE = RequirementCache("scikit-learn") +_SKLEARN_AVAILABLE = RequirementCache("sklearn") class MNISTDataModule(LightningDataModule): From a99c29d8962367640b59cd0691f5d104f6781891 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 13 Oct 2022 18:03:49 +0530 Subject: [PATCH 13/15] try run --- .github/workflows/ci-pytorch-test-full.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-pytorch-test-full.yml b/.github/workflows/ci-pytorch-test-full.yml index 2f040f15c90f0..fa70f416f1007 100644 --- a/.github/workflows/ci-pytorch-test-full.yml +++ b/.github/workflows/ci-pytorch-test-full.yml @@ -153,7 +153,7 @@ jobs: - name: Testing PyTorch working-directory: tests/tests_pytorch # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - run: coverage run --source pytorch_lightning -m pytest -v --durations=50 --junitxml=results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ matrix.release }}.xml + run: coverage run --source pytorch_lightning -m pytest core/test_datamodules.py -v --durations=50 --junitxml=results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ matrix.release }}.xml - name: Upload pytest results if: failure() From 711138d5d9afd5e05c0a3ef4ba5f5336ec026754 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 13 Oct 2022 18:11:31 +0530 Subject: [PATCH 14/15] rev --- .github/workflows/ci-pytorch-test-full.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-pytorch-test-full.yml b/.github/workflows/ci-pytorch-test-full.yml index fa70f416f1007..2f040f15c90f0 100644 --- a/.github/workflows/ci-pytorch-test-full.yml +++ b/.github/workflows/ci-pytorch-test-full.yml @@ -153,7 +153,7 @@ jobs: - name: Testing PyTorch working-directory: tests/tests_pytorch # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003 - run: coverage run --source pytorch_lightning -m pytest core/test_datamodules.py -v --durations=50 --junitxml=results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ matrix.release }}.xml + run: coverage run --source pytorch_lightning -m pytest -v --durations=50 --junitxml=results-${{ runner.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ matrix.release }}.xml - name: Upload pytest results if: failure() From 1863f3b3f77dff82a4620cf0079b541316e9634e Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Thu, 13 Oct 2022 18:21:34 +0530 Subject: [PATCH 15/15] mypy --- src/pytorch_lightning/trainer/states.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/trainer/states.py b/src/pytorch_lightning/trainer/states.py index 905f4fec57cde..0063ef3fabe96 100644 --- a/src/pytorch_lightning/trainer/states.py +++ b/src/pytorch_lightning/trainer/states.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass, field from enum import Enum, EnumMeta -from typing import Any, Optional +from typing import Any, List, Optional from lightning_utilities.core.rank_zero import rank_zero_deprecation @@ -79,7 +79,7 @@ def deprecate(self) -> None: ) @classmethod - def _without_tune(cls): + def _without_tune(cls) -> List["TrainerFn"]: fns = [fn for fn in cls if fn != "tune"] return fns @@ -123,7 +123,7 @@ def deprecate(self) -> None: ) @classmethod - def _without_tune(cls): + def _without_tune(cls) -> List["RunningStage"]: fns = [fn for fn in cls if fn != "tune"] return fns