
Commit 6ed476c

rohitgr7 authored and nicolai86 committed
Deprecate tuning enum and trainer properties (#15100)
1 parent e4ef4bc commit 6ed476c

14 files changed (+123 −73 lines)


src/pytorch_lightning/CHANGELOG.md

+3
@@ -90,6 +90,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated duplicate `SaveConfigCallback` parameters in `LightningCLI.__init__`: `save_config_kwargs`, `save_config_overwrite` and `save_config_multifile`. New `save_config_kwargs` parameter should be used instead ([#14998](https://github.com/Lightning-AI/lightning/pull/14998))
 
 
+- Deprecated `TrainerFn.TUNING`, `RunningStage.TUNING` and `trainer.tuning` property ([#15100](https://github.com/Lightning-AI/lightning/pull/15100))
+
+
 ### Removed
 
 - Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))
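
For readers skimming the diff, a minimal sketch of what this deprecation means for user code. The warning texts are the ones added below in states.py and trainer.py; the surrounding script is illustrative only, not part of the commit:

import warnings

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import RunningStage, TrainerFn

# Accessing the deprecated enum members or the `Trainer.tuning` property keeps
# working, but now emits a deprecation warning via `rank_zero_deprecation`.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = TrainerFn.TUNING       # "`TrainerFn.TUNING` has been deprecated in v1.8.0 ..."
    _ = RunningStage.TUNING    # "`RunningStage.TUNING` has been deprecated in v1.8.0 ..."
    _ = Trainer().tuning       # "`Trainer.tuning` has been deprecated in v1.8.0 ..."

for warning in caught:
    print(warning.message)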

src/pytorch_lightning/callbacks/timer.py

+3 −3

@@ -95,8 +95,8 @@ def __init__(
         self._duration = duration.total_seconds() if duration is not None else None
         self._interval = interval
         self._verbose = verbose
-        self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
-        self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
+        self._start_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
+        self._end_time: Dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage._without_tune()}
         self._offset = 0
 
     def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:
@@ -161,7 +161,7 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -
         self._check_time_remaining(trainer)
 
     def state_dict(self) -> Dict[str, Any]:
-        return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in list(RunningStage)}}
+        return {"time_elapsed": {stage.value: self.time_elapsed(stage) for stage in RunningStage._without_tune()}}
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         time_elapsed = state_dict.get("time_elapsed", {})
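
The Timer change above swaps blanket iteration over `RunningStage` for the new `_without_tune()` helper, so the callback no longer tracks the deprecated tune stage. A small sketch of the difference; the printed lists assume the 1.8 member set:

from pytorch_lightning.trainer.states import RunningStage

# Plain iteration still yields every member, including the deprecated TUNING
# (iteration does not go through the metaclass attribute hook, so it does not
# warn); the private helper simply filters it out for callers like Timer.
print([stage.value for stage in RunningStage])
# e.g. ['train', 'sanity_check', 'validate', 'test', 'predict', 'tune']

print([stage.value for stage in RunningStage._without_tune()])
# e.g. ['train', 'sanity_check', 'validate', 'test', 'predict']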

src/pytorch_lightning/strategies/deepspeed.py

+1 −1

@@ -605,7 +605,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         Args:
             trainer: the Trainer, these optimizers should be connected to
         """
-        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
+        if trainer.state.fn != TrainerFn.FITTING:
             return
         # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
         # User may have specified config options instead in configure_optimizers, but this is handled

src/pytorch_lightning/strategies/ipu.py

+1 −1

@@ -130,7 +130,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         # Separate models are instantiated for different stages, but they share the same weights on host.
         # When validation/test models are run, weights are synced first.
         trainer_fn = self.lightning_module.trainer.state.fn
-        if trainer_fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+        if trainer_fn == TrainerFn.FITTING:
             # Create model for training and validation which will run on fit
             training_opts = self.training_opts
             inference_opts = self.inference_opts

src/pytorch_lightning/strategies/strategy.py

+1 −1

@@ -136,7 +136,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         Args:
             trainer: the Trainer, these optimizers should be connected to
         """
-        if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
+        if trainer.state.fn != TrainerFn.FITTING:
             return
         assert self.lightning_module is not None
         self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(

src/pytorch_lightning/trainer/configuration_validator.py

+1 −1

@@ -37,7 +37,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
 
     if trainer.state.fn is None:
         raise ValueError("Unexpected: Trainer state fn must be set before validating loop configuration.")
-    if trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+    if trainer.state.fn == TrainerFn.FITTING:
         __verify_train_val_loop_configuration(trainer, model)
         __verify_manual_optimization_support(trainer, model)
         __check_training_step_requires_dataloader_iter(model)

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

+1 −1

@@ -368,7 +368,7 @@ def restore_loops(self) -> None:
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
         if state_dict is not None:
-            if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+            if self.trainer.state.fn == TrainerFn.FITTING:
                 fit_loop.load_state_dict(state_dict["fit_loop"])
             elif self.trainer.state.fn == TrainerFn.VALIDATING:
                 self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])

src/pytorch_lightning/trainer/states.py

+52 −10

@@ -12,12 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass, field
-from typing import Optional
+from enum import Enum, EnumMeta
+from typing import Any, List, Optional
+
+from lightning_utilities.core.rank_zero import rank_zero_deprecation
 
 from pytorch_lightning.utilities import LightningEnum
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 
 
+class _DeprecationManagingEnumMeta(EnumMeta):
+    """Enum that calls `deprecate()` whenever a member is accessed.
+
+    Adapted from: https://stackoverflow.com/a/62309159/208880
+    """
+
+    def __getattribute__(cls, name: str) -> Any:
+        obj = super().__getattribute__(name)
+        # ignore __dunder__ names -- prevents potential recursion errors
+        if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
+            obj.deprecate()
+        return obj
+
+    def __getitem__(cls, name: str) -> Any:
+        member: _DeprecationManagingEnumMeta = super().__getitem__(name)
+        member.deprecate()
+        return member
+
+    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
+        obj = super().__call__(*args, **kwargs)
+        if isinstance(obj, Enum):
+            obj.deprecate()
+        return obj
+
+
 class TrainerStatus(LightningEnum):
     """Enum for the status of the :class:`~pytorch_lightning.trainer.trainer.Trainer`"""
@@ -31,7 +59,7 @@ def stopped(self) -> bool:
         return self in (self.FINISHED, self.INTERRUPTED)
 
 
-class TrainerFn(LightningEnum):
+class TrainerFn(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
     """
     Enum for the user-facing functions of the :class:`~pytorch_lightning.trainer.trainer.Trainer`
     such as :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit` and
@@ -44,16 +72,19 @@ class TrainerFn(LightningEnum):
     PREDICTING = "predict"
     TUNING = "tune"
 
-    @property
-    def _setup_fn(self) -> "TrainerFn":
-        """``FITTING`` is used instead of ``TUNING`` as there are no "tune" dataloaders.
+    def deprecate(self) -> None:
+        if self == self.TUNING:
+            rank_zero_deprecation(
+                f"`TrainerFn.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0."
+            )
 
-        This is used for the ``setup()`` and ``teardown()`` hooks
-        """
-        return TrainerFn.FITTING if self == TrainerFn.TUNING else self
+    @classmethod
+    def _without_tune(cls) -> List["TrainerFn"]:
+        fns = [fn for fn in cls if fn != "tune"]
+        return fns
 
 
-class RunningStage(LightningEnum):
+class RunningStage(LightningEnum, metaclass=_DeprecationManagingEnumMeta):
     """Enum for the current running stage.
 
     This stage complements :class:`TrainerFn` by specifying the current running stage for each function.
@@ -79,12 +110,23 @@ def evaluating(self) -> bool:
 
     @property
     def dataloader_prefix(self) -> Optional[str]:
-        if self in (self.SANITY_CHECKING, self.TUNING):
+        if self == self.SANITY_CHECKING:
             return None
         if self == self.VALIDATING:
             return "val"
         return self.value
 
+    def deprecate(self) -> None:
+        if self == self.TUNING:
+            rank_zero_deprecation(
+                f"`RunningStage.{self.name}` has been deprecated in v1.8.0 and will be removed in v1.10.0."
+            )
+
+    @classmethod
+    def _without_tune(cls) -> List["RunningStage"]:
+        fns = [fn for fn in cls if fn != "tune"]
+        return fns
+
 
 @dataclass
 class TrainerState:
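
The metaclass above is the core of the change: member access on the class (`TrainerFn.TUNING`), lookup by name (`TrainerFn["TUNING"]`) and lookup by value (`TrainerFn("tune")`) all route through the member's `deprecate()` hook, while dunder lookups are skipped. A self-contained sketch of the same pattern, using a made-up `Color` enum rather than the Lightning classes:

import warnings
from enum import Enum, EnumMeta
from typing import Any


class DeprecationManagingEnumMeta(EnumMeta):
    """Calls `deprecate()` on a member whenever it is looked up on the class."""

    def __getattribute__(cls, name: str) -> Any:
        obj = super().__getattribute__(name)
        # Ignore __dunder__ names -- same guard as the Lightning code above,
        # preventing potential recursion during class creation.
        if not (name.startswith("__") and name.endswith("__")) and isinstance(obj, Enum):
            obj.deprecate()
        return obj

    def __getitem__(cls, name: str) -> Any:
        member = super().__getitem__(name)
        member.deprecate()
        return member

    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        obj = super().__call__(*args, **kwargs)
        if isinstance(obj, Enum):
            obj.deprecate()
        return obj


class Color(Enum, metaclass=DeprecationManagingEnumMeta):
    RED = "red"
    BLUE = "blue"  # pretend this member is the deprecated one

    def deprecate(self) -> None:
        # Compare by value so we do not look the member up on the class again,
        # which would re-enter the metaclass hook.
        if self.value == "blue":
            warnings.warn("`Color.BLUE` is deprecated.", DeprecationWarning)


warnings.simplefilter("always")
Color.RED        # no warning
Color.BLUE       # warns via __getattribute__
Color["BLUE"]    # warns via __getitem__
Color("blue")    # warns via __call__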

src/pytorch_lightning/trainer/trainer.py

+8 −5

@@ -961,7 +961,7 @@ def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None
     def _run(
         self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
     ) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
-        if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+        if self.state.fn == TrainerFn.FITTING:
             min_epochs, max_epochs = _parse_loop_limits(
                 self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
             )
@@ -1233,7 +1233,7 @@ def _run_sanity_check(self) -> None:
 
     def _call_setup_hook(self) -> None:
         assert self.state.fn is not None
-        fn = self.state.fn._setup_fn
+        fn = self.state.fn
 
         self.strategy.barrier("pre_setup")
 
@@ -1256,7 +1256,7 @@ def _call_configure_sharded_model(self) -> None:
 
     def _call_teardown_hook(self) -> None:
         assert self.state.fn is not None
-        fn = self.state.fn._setup_fn
+        fn = self.state.fn
 
         if self.datamodule is not None:
             self._call_lightning_datamodule_hook("teardown", stage=fn)
@@ -1449,7 +1449,7 @@ def __setup_profiler(self) -> None:
         assert self.state.fn is not None
         local_rank = self.local_rank if self.world_size > 1 else None
         self.profiler._lightning_module = proxy(self.lightning_module)
-        self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
+        self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir)
 
     """
     Data loading methods
@@ -1965,10 +1965,13 @@ def predicting(self, val: bool) -> None:
 
     @property
     def tuning(self) -> bool:
+        rank_zero_deprecation("`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.")
         return self.state.stage == RunningStage.TUNING
 
     @tuning.setter
    def tuning(self, val: bool) -> None:
+        rank_zero_deprecation("Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0.")
+
         if val:
             self.state.stage = RunningStage.TUNING
         elif self.tuning:
@@ -2097,7 +2100,7 @@ def predict_loop(self, loop: PredictionLoop) -> None:
 
     @property
     def _evaluation_loop(self) -> EvaluationLoop:
-        if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
+        if self.state.fn == TrainerFn.FITTING:
             return self.fit_loop.epoch_loop.val_loop
         if self.state.fn == TrainerFn.VALIDATING:
             return self.validate_loop
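
Because the tuner no longer switches the trainer into a separate TUNING state, the `_setup_fn` remapping removed above becomes unnecessary: `setup()`/`teardown()` hooks and the profiler now receive `trainer.state.fn` directly. A rough sketch of what that means for a user callback (the `setup` override follows the standard Callback hook; the `PrintStage` callback itself is made up for illustration):

import pytorch_lightning as pl


class PrintStage(pl.Callback):
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        # `stage` is the trainer's actual state fn ("fit", "validate", "test",
        # "predict"). Previously a tune run was remapped to "fit" through the
        # removed `TrainerFn._setup_fn` property; after this change the tuner
        # runs under the fit state to begin with, so the remap is redundant.
        print(f"setup() called with stage={stage}")


trainer = pl.Trainer(callbacks=[PrintStage()])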

src/pytorch_lightning/tuner/tuning.py

+1 −11

@@ -20,7 +20,7 @@
 from pytorch_lightning.callbacks.callback import Callback
 from pytorch_lightning.callbacks.lr_finder import LearningRateFinder
 from pytorch_lightning.core.datamodule import LightningDataModule
-from pytorch_lightning.trainer.states import TrainerFn, TrainerStatus
+from pytorch_lightning.trainer.states import TrainerStatus
 from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
@@ -77,9 +77,7 @@ def _tune(
 
         # Run learning rate finder:
         if self.trainer.auto_lr_find:
-            self.trainer.state.fn = TrainerFn.TUNING
             self.trainer.state.status = TrainerStatus.RUNNING
-            self.tuning = True
 
             # TODO: Remove this once LRFinder is converted to a Callback
             # if a datamodule comes in as the second arg, then fix it for the user
@@ -112,7 +110,6 @@ def _run(self, *args: Any, **kwargs: Any) -> None:
         self.trainer.state.status = TrainerStatus.RUNNING  # last `_run` call might have set it to `FINISHED`
         self.trainer.training = True
         self.trainer._run(*args, **kwargs)
-        self.trainer.tuning = True
 
     def scale_batch_size(
         self,
@@ -170,10 +167,6 @@ def scale_batch_size(
             - ``model.hparams``
             - ``trainer.datamodule`` (the datamodule passed to the tune method)
         """
-        # TODO: Remove TrainerFn.TUNING since we are now calling fit/validate/test/predict methods directly
-        self.trainer.state.fn = TrainerFn.TUNING
-        self.tuning = True
-
         _check_tuner_configuration(self.trainer, train_dataloaders, val_dataloaders, dataloaders, method)
 
         batch_size_finder: Callback = BatchSizeFinder(
@@ -254,9 +247,6 @@ def lr_find(
             If learning rate/lr in ``model`` or ``model.hparams`` isn't overridden when ``auto_lr_find=True``,
             or if you are using more than one optimizer.
         """
-        self.trainer.state.fn = TrainerFn.TUNING
-        self.tuning = True
-
         if method != "fit":
             raise MisconfigurationException("method='fit' is an invalid configuration to run lr finder.")

tests/tests_pytorch/deprecated_api/test_remove_1-10.py

+25

@@ -35,6 +35,7 @@
 from pytorch_lightning.strategies.bagua import LightningBaguaModule
 from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.strategies.utils import on_colab_kaggle
+from pytorch_lightning.trainer.states import RunningStage, TrainerFn
 from pytorch_lightning.utilities.apply_func import (
     apply_to_collection,
     apply_to_collections,
@@ -297,3 +298,27 @@ def test_lite_convert_deprecated_tpus_argument(tpu_available):
 def test_lightningCLI_save_config_init_params_deprecation_warning(name, value):
     with mock.patch("sys.argv", ["any.py"]), pytest.deprecated_call(match=f".*{name!r} init parameter is deprecated.*"):
         LightningCLI(BoringModel, run=False, **{name: value})
+
+
+def test_tuning_enum():
+    with pytest.deprecated_call(
+        match="`TrainerFn.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0."
+    ):
+        TrainerFn.TUNING
+
+    with pytest.deprecated_call(
+        match="`RunningStage.TUNING` has been deprecated in v1.8.0 and will be removed in v1.10.0."
+    ):
+        RunningStage.TUNING
+
+
+def test_tuning_trainer_property():
+    trainer = Trainer()
+
+    with pytest.deprecated_call(match="`Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0."):
+        trainer.tuning
+
+    with pytest.deprecated_call(
+        match="Setting `Trainer.tuning` has been deprecated in v1.8.0 and will be removed in v1.10.0."
+    ):
+        trainer.tuning = True

tests/tests_pytorch/models/test_restore.py

+8 −19

@@ -183,39 +183,28 @@ def _check_model_state_dict(self):
                 for actual, expected in zip(self.state_dict(), state_dict["state_dict"])
             )
 
-        def _test_on_val_test_predict_tune_start(self):
+        def _test_on_val_test_predict_start(self):
             assert self.trainer.current_epoch == state_dict["epoch"]
             assert self.trainer.global_step == state_dict["global_step"]
             assert self._check_model_state_dict()
 
-            # no optimizes and schedulers are loaded otherwise
-            if self.trainer.state.fn != TrainerFn.TUNING:
-                return
-
-            assert not self._check_optimizers()
-            assert not self._check_schedulers()
-
         def on_train_start(self):
-            if self.trainer.state.fn == TrainerFn.TUNING:
-                self._test_on_val_test_predict_tune_start()
-            else:
-                assert self.trainer.current_epoch == state_dict["epoch"] + 1
-                assert self.trainer.global_step == state_dict["global_step"]
-                assert self._check_model_state_dict()
-                assert self._check_optimizers()
-                assert self._check_schedulers()
+            assert self.trainer.current_epoch == state_dict["epoch"] + 1
+            assert self.trainer.global_step == state_dict["global_step"]
+            assert self._check_model_state_dict()
+            assert self._check_optimizers()
+            assert self._check_schedulers()
 
         def on_validation_start(self):
             if self.trainer.state.fn == TrainerFn.VALIDATING:
-                self._test_on_val_test_predict_tune_start()
+                self._test_on_val_test_predict_start()
 
         def on_test_start(self):
-            self._test_on_val_test_predict_tune_start()
+            self._test_on_val_test_predict_start()
 
     for fn in ("fit", "validate", "test", "predict"):
         model = CustomClassifModel()
         dm = ClassifDataModule()
-        trainer_args["auto_scale_batch_size"] = (fn == "tune",)
         trainer = Trainer(**trainer_args)
         trainer_fn = getattr(trainer, fn)
         trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)

tests/tests_pytorch/strategies/test_ddp_strategy.py

+1 −3

@@ -155,9 +155,7 @@ def test_ddp_configure_ddp():
 
 
 @RunIf(min_cuda_gpus=1)
-@pytest.mark.parametrize(
-    "trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TUNING, TrainerFn.TESTING, TrainerFn.PREDICTING)
-)
+@pytest.mark.parametrize("trainer_fn", (TrainerFn.VALIDATING, TrainerFn.TESTING, TrainerFn.PREDICTING))
 def test_ddp_dont_configure_sync_batchnorm(trainer_fn):
     model = BoringModelGPU()
     model.layer = torch.nn.BatchNorm1d(10)
