Deprecate tuning enum and trainer properties #15100

Merged: 17 commits, Oct 13, 2022
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated duplicate `SaveConfigCallback` parameters in `LightningCLI.__init__`: `save_config_kwargs`, `save_config_overwrite` and `save_config_multifile`. New `save_config_kwargs` parameter should be used instead ([#14998](https://github.com/Lightning-AI/lightning/pull/14998))


- Deprecated `TrainerFn.TUNING`, `RunningStage.TUNING` and `trainer.tuning` property ([#15100](https://github.com/Lightning-AI/lightning/pull/15100))


### Removed

- Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/deepspeed.py
@@ -605,7 +605,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer.state.fn != TrainerFn.FITTING:
return
# Skip initializing optimizers here as DeepSpeed handles optimizers via config.
# User may have specified config options instead in configure_optimizers, but this is handled
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/ipu.py
@@ -130,7 +130,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
# Separate models are instantiated for different stages, but they share the same weights on host.
# When validation/test models are run, weights are synced first.
trainer_fn = self.lightning_module.trainer.state.fn
if trainer_fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer_fn == TrainerFn.FITTING:
# Create model for training and validation which will run on fit
training_opts = self.training_opts
inference_opts = self.inference_opts
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/strategy.py
@@ -136,7 +136,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
Args:
trainer: the Trainer, these optimizers should be connected to
"""
if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer.state.fn != TrainerFn.FITTING:
return
assert self.lightning_module is not None
self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
2 changes: 1 addition & 1 deletion src/pytorch_lightning/trainer/configuration_validator.py
@@ -37,7 +37,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:

if trainer.state.fn is None:
raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if trainer.state.fn == TrainerFn.FITTING:
__verify_train_val_loop_configuration(trainer, model)
__verify_manual_optimization_support(trainer, model)
__check_training_step_requires_dataloader_iter(model)
@@ -368,7 +368,7 @@ def restore_loops(self) -> None:
assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None:
if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if self.trainer.state.fn == TrainerFn.FITTING:
fit_loop.load_state_dict(state_dict["fit_loop"])
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
54 changes: 43 additions & 11 deletions src/pytorch_lightning/trainer/states.py
@@ -12,12 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
from enum import Enum, EnumMeta
from typing import Any, Optional

from lightning_utilities.core.rank_zero import rank_zero_deprecation

from pytorch_lightning.utilities import LightningEnum
from pytorch_lightning.utilities.enums import _FaultTolerantMode


class _DeprecatedEnumMeta(EnumMeta):
"""Enum that calls `deprecate()` whenever a member is accessed.

Adapted from: https://stackoverflow.com/a/62309159/208880
"""

def __getattribute__(cls, name: str) -> Any:
obj = super().__getattribute__(name)
# ignore __dunder__ names -- prevents potential recursion errors
if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
obj.deprecate()
return obj

def __getitem__(cls, name: str) -> Any:
member: _DeprecatedEnumMeta = super().__getitem__(name)
member.deprecate()
return member

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
obj = super().__call__(*args, **kwargs)
if isinstance(obj, Enum):
obj.deprecate()
return obj


class TrainerStatus(LightningEnum):
"""Enum for the status of the :class:`~pytorch_lightning.trainer.trainer.Trainer`"""

@@ -31,7 +59,7 @@ def stopped(self) -> bool:
return self in (self.FINISHED, self.INTERRUPTED)


class TrainerFn(LightningEnum):
class TrainerFn(LightningEnum, metaclass=_DeprecatedEnumMeta):
"""
Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
@@ -44,16 +72,14 @@ class TrainerFn(LightningEnum):
PREDICTING = "predict"
TUNING = "tune"

@property
def _setup_fn(self) -> "TrainerFn":
"""``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.

This is used for the ``setup()`` and ``teardown()`` hooks
"""
return TrainerFn.FITTING if self == TrainerFn.TUNING else self

def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0."
)


class RunningStage(LightningEnum):
class RunningStage(LightningEnum, metaclass=_DeprecatedEnumMeta):
"""Enum for the current running stage.

This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
@@ -79,12 +105,18 @@ def evaluating(self) -> bool:

@property
def dataloader_prefix(self) -> Optional[str]:
if self in (self.SANITY_CHECKING, self.TUNING):
if self == self.SANITY_CHECKING:
return None
if self == self.VALIDATING:
return "val"
return self.value

def deprecate(self) -> None:
if self == self.TUNING:
rank_zero_deprecation(
f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0."
)


@dataclass
class TrainerState:
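The `_DeprecatedEnumMeta` metaclass added to `states.py` above is what turns plain member access into a deprecation warning: attribute lookup, `[]` lookup, and value lookup on the enum class are all routed through the member's `deprecate()` hook. Below is a minimal, self-contained sketch of the same pattern; `Stage` and the plain `warnings.warn` call stand in for Lightning's `TrainerFn`/`RunningStage` and `rank_zero_deprecation`, so the names are illustrative rather than library API.

```python
# Standalone sketch of the deprecated-enum-member pattern (not the Lightning source).
import warnings
from enum import Enum, EnumMeta
from typing import Any


class _DeprecatedEnumMeta(EnumMeta):
    """Metaclass that routes every member access through ``deprecate()``."""

    def __getattribute__(cls, name: str) -> Any:
        obj = super().__getattribute__(name)
        # Skip dunder lookups to avoid recursion while the enum class is being built.
        if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
            obj.deprecate()
        return obj

    def __getitem__(cls, name: str) -> Any:
        member = super().__getitem__(name)
        member.deprecate()  # covers Stage["TUNING"]
        return member

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        obj = super().__call__(*args, **kwargs)
        if isinstance(obj, Enum):
            obj.deprecate()  # covers Stage("tune")
        return obj


class Stage(Enum, metaclass=_DeprecatedEnumMeta):
    FITTING = "fit"
    TUNING = "tune"

    def deprecate(self) -> None:
        # Compare by name so the check does not re-enter the metaclass hooks.
        if self.name == "TUNING":
            warnings.warn("`Stage.TUNING` is deprecated.", DeprecationWarning)


warnings.simplefilter("always", DeprecationWarning)
Stage.FITTING    # no warning
Stage.TUNING     # warns via __getattribute__
Stage["TUNING"]  # warns via __getitem__
Stage("tune")    # warns via __call__
```

Comparing by `self.name` inside `deprecate()` keeps the warning check from going back through the metaclass and recursing.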
15 changes: 10 additions & 5 deletions src/pytorch_lightning/trainer/trainer.py
@@ -955,7 +955,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
def _run(
self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if self.state.fn == TrainerFn.FITTING:
min_epochs, max_epochs = _parse_loop_limits(
self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
)
@@ -1225,7 +1225,7 @@ def _run_sanity_check(self) -> None:

def _call_setup_hook(self) -> None:
assert self.state.fn is not None
fn = self.state.fn._setup_fn
fn = self.state.fn

self.strategy.barrier("pre_setup")

@@ -1248,7 +1248,7 @@ def _call_teardown_hook(self) -> None:

def _call_teardown_hook(self) -> None:
assert self.state.fn is not None
fn = self.state.fn._setup_fn
fn = self.state.fn

if self.datamodule is not None:
self._call_lightning_datamodule_hook("teardown", stage=fn)
@@ -1441,7 +1441,7 @@ def __setup_profiler(self) -> None:
assert self.state.fn is not None
local_rank = self.local_rank if self.world_size > 1 else None
self.profiler._lightning_module = proxy(self.lightning_module)
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir)

"""
Data loading methods
@@ -1957,10 +1957,15 @@ def predicting(self, val: bool) -> None:

@property
def tuning(self) -> bool:
rank_zero_deprecation("`trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0.")
return self.state.stage == RunningStage.TUNING

@tuning.setter
def tuning(self, val: bool) -> None:
rank_zero_deprecation(
"Setting `trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0."
)

if val:
self.state.stage = RunningStage.TUNING
elif self.tuning:
@@ -2089,7 +2094,7 @@ def predict_loop(self, loop: PredictionLoop) -> None:

@property
def _evaluation_loop(self) -> EvaluationLoop:
if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
if self.state.fn == TrainerFn.FITTING:
return self.fit_loop.epoch_loop.val_loop
if self.state.fn == TrainerFn.VALIDATING:
return self.validate_loop
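The `trainer.tuning` property itself stays for the deprecation window, so downstream code keeps working but now sees a warning on both read and write. A short sketch of how that surfaces outside the test suite, assuming a default `Trainer()` on a version that includes this change (the `warnings` capture is standard library, not Lightning API):

```python
import warnings

from pytorch_lightning import Trainer

trainer = Trainer()

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    still_tuning = trainer.tuning  # getter: warns, still returns a bool
    trainer.tuning = False         # setter: warns, then applies the assignment as before

print(still_tuning)                      # False for a fresh Trainer
print([str(w.message) for w in caught])  # includes the v1.8.0 deprecation messages
```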
12 changes: 1 addition & 11 deletions src/pytorch_lightning/tuner/tuning.py
@@ -20,7 +20,7 @@
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
from pytorch_lightning.trainer.states import TrainerStatus
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -77,9 +77,7 @@ def _tune(

# Run learning rate finder:
if self.trainer.auto_lr_find:
self.trainer.state.fn = TrainerFn.TUNING
self.trainer.state.status = TrainerStatus.RUNNING
self.tuning = True

# TODO: Remove this once LRFinder is converted to a Callback
# if a datamodule comes in as the second arg, then fix it for the user
@@ -112,7 +110,6 @@ def _run(self, *args: Any, **kwargs: Any) -> None:
self.trainer.state.status = TrainerStatus.RUNNING # last `_run` call might have set it to `FINISHED`
self.trainer.training = True
self.trainer._run(*args, **kwargs)
self.trainer.tuning = True

def scale_batch_size(
self,
@@ -170,10 +167,6 @@ def scale_batch_size(
- ``model.hparams``
- ``trainer.datamodule`` (the datamodule passed to the tune method)
"""
# TODO: Remove TrainerFn.TUNING since we are now calling fit/validate/test/predict methods directly
self.trainer.state.fn = TrainerFn.TUNING
self.tuning = True

_check_tuner_configuration(self.trainer, train_dataloaders, val_dataloaders, dataloaders, method)

batch_size_finder: Callback = BatchSizeFinder(
@@ -254,9 +247,6 @@ def lr_find(
If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
or if you are using more than one optimizer.
"""
self.trainer.state.fn = TrainerFn.TUNING
self.tuning = True

if method != "fit":
raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.")

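With the `TrainerFn.TUNING` assignments removed from the tuner, `scale_batch_size` and `lr_find` now run through the regular fit path rather than a dedicated tuning state. A hedged sketch of driving both through the 1.8-era `Trainer` flags; `TinyModel` and its hyperparameters are illustrative, the only real requirement being the `batch_size` and `lr` attributes the two finders look for:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3, batch_size: int = 8):
        super().__init__()
        self.lr = lr                  # read and updated by the LR finder
        self.batch_size = batch_size  # read and updated by the batch-size finder
        self.layer = torch.nn.Linear(16, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(256, 16), torch.randn(256, 1))
        return DataLoader(dataset, batch_size=self.batch_size)


trainer = pl.Trainer(auto_scale_batch_size="power", auto_lr_find=True)
trainer.tune(TinyModel())  # both finders run through the regular fit machinery; no TUNING state is set
```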
27 changes: 27 additions & 0 deletions tests/tests_pytorch/deprecated_api/test_remove_1-10.py
@@ -35,6 +35,7 @@
from pytorch_lightning.strategies.bagua import LightningBaguaModule
from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.strategies.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.apply_func import (
apply_to_collection,
apply_to_collections,
@@ -297,3 +298,29 @@ def test_lite_convert_deprecated_tpus_argument(tpu_available):
def test_lightningCLI_save_config_init_params_deprecation_warning(name, value):
with mock.patch("sys.argv", ["any.py"]), pytest.deprecated_call(match=f".*{name!r} init parameter is deprecated.*"):
LightningCLI(BoringModel, run=False, **{name: value})


def test_tuning_enum():
with pytest.deprecated_call(
match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0."
):
TrainerFn.TUNING

with pytest.deprecated_call(
match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0."
):
RunningStage.TUNING


def test_tuning_trainer_property():
trainer = Trainer()

with pytest.deprecated_call(
match="`trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0."
):
trainer.tuning

with pytest.deprecated_call(
match="Setting `trainer.tuning` property has been deprecated in v1.8.0 and will be removed in v1.10.0."
):
trainer.tuning = True
2 changes: 1 addition & 1 deletion tests/tests_pytorch/helpers/datamodules.py
@@ -20,7 +20,7 @@
from pytorch_lightning.core.datamodule import LightningDataModule
from tests_pytorch.helpers.datasets import MNIST, SklearnDataset, TrialMNIST

_SKLEARN_AVAILABLE = RequirementCache("sklearn")
_SKLEARN_AVAILABLE = RequirementCache("scikit-learn")


class MNISTDataModule(LightningDataModule):
27 changes: 8 additions & 19 deletions tests/tests_pytorch/models/test_restore.py
@@ -183,39 +183,28 @@ def _check_model_state_dict(self):
for actual, expected in zip(self.state_dict(), state_dict["state_dict"])
)

def _test_on_val_test_predict_tune_start(self):
def _test_on_val_test_predict_start(self):
assert self.trainer.current_epoch == state_dict["epoch"]
assert self.trainer.global_step == state_dict["global_step"]
assert self._check_model_state_dict()

# no optimizes and schedulers are loaded otherwise
if self.trainer.state.fn != TrainerFn.TUNING:
return

assert not self._check_optimizers()
assert not self._check_schedulers()

def on_train_start(self):
if self.trainer.state.fn == TrainerFn.TUNING:
self._test_on_val_test_predict_tune_start()
else:
assert self.trainer.current_epoch == state_dict["epoch"] + 1
assert self.trainer.global_step == state_dict["global_step"]
assert self._check_model_state_dict()
assert self._check_optimizers()
assert self._check_schedulers()
assert self.trainer.current_epoch == state_dict["epoch"] + 1
assert self.trainer.global_step == state_dict["global_step"]
assert self._check_model_state_dict()
assert self._check_optimizers()
assert self._check_schedulers()

def on_validation_start(self):
if self.trainer.state.fn == TrainerFn.VALIDATING:
self._test_on_val_test_predict_tune_start()
self._test_on_val_test_predict_start()

def on_test_start(self):
self._test_on_val_test_predict_tune_start()
self._test_on_val_test_predict_start()

for fn in ("fit", "validate", "test", "predict"):
model = CustomClassifModel()
dm = ClassifDataModule()
trainer_args["auto_scale_batch_size"] = (fn == "tune",)
trainer = Trainer(**trainer_args)
trainer_fn = getattr(trainer, fn)
trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)
4 changes: 1 addition & 3 deletions tests/tests_pytorch/strategies/test_ddp_strategy.py
@@ -155,9 +155,7 @@ def test_ddp_configure_ddp():


@RunIf(min_cuda_gpus=1)
@pytest.mark.parametrize(
"trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TUNING, TrainerFn.TESTING, TrainerFn.PREDICTING)
)
@pytest.mark.parametrize("trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING))
def test_ddp_dont_configure_sync_batchnorm(trainer_fn):
model = BoringModelGPU()
model.layer = torch.nn.BatchNorm1d(10)
@@ -124,21 +124,19 @@ def test_loops_restore(tmpdir):
trainer.strategy.connect(model)

for fn in TrainerFn:
if fn != TrainerFn.TUNING:
trainer_fn = getattr(trainer, f"{fn}_loop")
trainer_fn.load_state_dict = mock.Mock()
trainer_fn = getattr(trainer, f"{fn}_loop")
trainer_fn.load_state_dict = mock.Mock()

for fn in TrainerFn:
if fn != TrainerFn.TUNING:
trainer.state.fn = fn
trainer._checkpoint_connector.resume_start(ckpt_path)
trainer._checkpoint_connector.restore_loops()
trainer.state.fn = fn
trainer._checkpoint_connector.resume_start(ckpt_path)
trainer._checkpoint_connector.restore_loops()

trainer_loop = getattr(trainer, f"{fn}_loop")
trainer_loop.load_state_dict.assert_called()
trainer_loop.load_state_dict.reset_mock()
trainer_loop = getattr(trainer, f"{fn}_loop")
trainer_loop.load_state_dict.assert_called()
trainer_loop.load_state_dict.reset_mock()

for fn2 in TrainerFn:
if fn2 not in (fn, TrainerFn.TUNING):
if fn2 != fn:
trainer_loop2 = getattr(trainer, f"{fn2}_loop")
trainer_loop2.load_state_dict.assert_not_called()