Commit 0e20119

EricWiener authored
Change default value of the max_steps Trainer argument from None to -1 (#9460)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: thomas chaton <[email protected]>
Co-authored-by: rohitgr7 <[email protected]>
1 parent d9dfb2e commit 0e20119
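
In practice, the change behaves as follows (a minimal sketch, not part of the diff; it assumes PyTorch Lightning 1.5 with this commit applied):

from pytorch_lightning import Trainer

# New default: max_steps == -1 means "no step limit"; max_epochs still falls back to 1000.
trainer = Trainer()
assert trainer.max_steps == -1
assert trainer.max_epochs == 1000

# Passing max_steps=None is deprecated: it emits a rank-zero deprecation warning
# and is coerced to -1; support is scheduled for removal in v1.7.
trainer = Trainer(max_steps=None)
assert trainer.max_steps == -1

# Values below -1 are rejected:
# Trainer(max_steps=-2)  # raises MisconfigurationException ("must be a non-negative integer or -1")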

File tree

13 files changed: +94 -58 lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -337,6 +337,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))


+- Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
+
+
 - Disable quantization aware training observers by default during validating/testing/predicting stages ([#8540](https://github.com/PyTorchLightning/pytorch-lightning/pull/8540))


@@ -414,6 +417,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))


+- Deprecated setting `Trainer(max_steps=None)`. To turn off the limit, set `Trainer(max_steps=-1)` (default) ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))
+
+
 - Deprecated access to the `AcceleratorConnector.is_slurm_managing_tasks` attribute and marked it as protected ([#10101](https://github.com/PyTorchLightning/pytorch-lightning/pull/10101))


pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 16 additions & 9 deletions
@@ -20,15 +20,15 @@
 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
 from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
-from pytorch_lightning.loops.utilities import _get_active_optimizers, _update_dataloader_iter
+from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.warnings import WarningCache
+from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache

 _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -41,13 +41,20 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
         max_steps: The maximum number of steps (batches) to process
     """

-    def __init__(self, min_steps: int, max_steps: int):
+    def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
         super().__init__()
-        self.min_steps: int = min_steps
-
-        if max_steps and max_steps < -1:
-            raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {max_steps}.")
-        self.max_steps: int = max_steps
+        if max_steps is None:
+            rank_zero_deprecation(
+                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
+                " Use `max_steps = -1` instead."
+            )
+            max_steps = -1
+        elif max_steps < -1:
+            raise MisconfigurationException(
+                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
+            )
+        self.min_steps = min_steps
+        self.max_steps = max_steps

         self.global_step: int = 0
         self.batch_progress = BatchProgress()
@@ -79,7 +86,7 @@ def batch_idx(self) -> int:

     @property
     def _is_training_done(self) -> bool:
-        max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
+        max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps)
         return max_steps_reached or self._num_ready_batches_reached()

     @property
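
The same normalization can be exercised directly on the loop (a hedged sketch; the loop is normally constructed by the Trainer rather than by users, import path as in the fit loop diff below):

from pytorch_lightning.loops.epoch import TrainingEpochLoop

# max_steps=None is still accepted but warns and is stored as -1
loop = TrainingEpochLoop(min_steps=0, max_steps=None)
assert loop.max_steps == -1

# values below -1 are rejected
# TrainingEpochLoop(min_steps=0, max_steps=-2)  # raises MisconfigurationException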

pytorch_lightning/loops/fit_loop.py

Lines changed: 23 additions & 22 deletions
@@ -16,9 +16,11 @@

 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.epoch import TrainingEpochLoop
+from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
+from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 log = logging.getLogger(__name__)
@@ -29,15 +31,19 @@ class FitLoop(Loop):

     Args:
         min_epochs: The minimum number of epochs
-        max_epochs: The maximum number of epochs
+        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
     """

-    def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
+    def __init__(
+        self,
+        min_epochs: Optional[int] = 1,
+        max_epochs: int = 1000,
+    ) -> None:
         super().__init__()
-        # Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done
-        if max_epochs and max_epochs < -1:
+        if max_epochs < -1:
+            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
             raise MisconfigurationException(
-                f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}."
+                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
             )

         self.max_epochs = max_epochs
@@ -102,8 +108,16 @@ def max_steps(self) -> int:
     def max_steps(self, value: int) -> None:
         """Sets the maximum number of steps (forwards to epoch_loop)"""
         # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
-        if value and value < -1:
-            raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {value}.")
+        if value is None:
+            rank_zero_deprecation(
+                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
+                " Use `max_steps = -1` instead."
+            )
+            value = -1
+        elif value < -1:
+            raise MisconfigurationException(
+                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
+            )
         self.epoch_loop.max_steps = value

     @property
@@ -141,8 +155,8 @@ def done(self) -> bool:
         is reached.
         """
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
-        stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps
-        stop_epochs = FitLoop._is_max_limit_enabled(self.max_epochs) and self.current_epoch >= self.max_epochs
+        stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
+        stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)

         should_stop = False
         if self.trainer.should_stop:
@@ -249,16 +263,3 @@ def teardown(self) -> None:
     def _should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated."""
         return self.epoch_loop._should_accumulate()
-
-    @staticmethod
-    def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
-        """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
-        is enabled.
-
-        Args:
-            max_value: the value to check
-
-        Returns:
-            whether the limit for this value should be enabled
-        """
-        return max_value not in (None, -1)

pytorch_lightning/loops/utilities.py

Lines changed: 13 additions & 0 deletions
@@ -168,3 +168,16 @@ def _get_active_optimizers(
     # find optimizer index by looking for the first {item > current_place} in the cumsum list
     opt_idx = np.searchsorted(freq_cumsum, current_place_in_loop, side="right")
     return [(opt_idx, optimizers[opt_idx])]
+
+
+def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:
+    """Check if the limit has been reached (if enabled).
+
+    Args:
+        current: the current value
+        maximum: the maximum value (or -1 to disable limit)
+
+    Returns:
+        bool: whether the limit has been reached
+    """
+    return maximum != -1 and current >= maximum
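
A quick sketch of how the new helper evaluates limits (illustrative only; -1 is the sentinel for "no limit"):

from pytorch_lightning.loops.utilities import _is_max_limit_reached

assert _is_max_limit_reached(current=10, maximum=10)      # limit reached
assert not _is_max_limit_reached(current=3, maximum=10)   # still below the limit
assert not _is_max_limit_reached(current=10, maximum=-1)  # -1 disables the check entirely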

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.loops.fit_loop import FitLoop
+from pytorch_lightning.loops.utilities import _is_max_limit_reached
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -227,7 +227,7 @@ def restore_loops(self) -> None:

         # crash if max_epochs is lower then the current epoch from the checkpoint
         if (
-            FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
+            self.trainer.max_epochs != -1
             and self.trainer.max_epochs is not None
             and self.trainer.current_epoch > self.trainer.max_epochs
         ):
@@ -358,7 +358,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         # dump epoch/global_step/pytorch-lightning_version
         current_epoch = self.trainer.current_epoch
         global_step = self.trainer.global_step
-        has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step
+        has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps)

         global_step += 1
         if not has_reached_max_steps:

pytorch_lightning/trainer/trainer.py

Lines changed: 10 additions & 9 deletions
@@ -145,7 +145,7 @@ def __init__(
         accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
-        max_steps: Optional[int] = None,
+        max_steps: int = -1,
         min_steps: Optional[int] = None,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
         limit_train_batches: Union[int, float] = 1.0,
@@ -327,9 +327,9 @@ def __init__(
             min_epochs: Force training for at least these many epochs. Disabled by default (None).
                 If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.

-            max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None``
-                and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set
-                ``max_steps`` to ``-1``.
+            max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
+                and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
+                ``max_epochs`` to ``-1``.

             min_steps: Force training for at least these number of steps. Disabled by default (None).

@@ -460,10 +460,11 @@ def __init__(
         self.signal_connector = SignalConnector(self)
         self.tuner = Tuner(self)

-        # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
         fit_loop = FitLoop(
             min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
-            max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
+            max_epochs=(
+                max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
+            ),
         )
         training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
         training_batch_loop = TrainingBatchLoop()
@@ -1332,7 +1333,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_

         if not ckpt_path:
             raise MisconfigurationException(
-                f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
+                f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
                 f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
             )
         return ckpt_path
@@ -1937,15 +1938,15 @@ def current_epoch(self) -> int:
         return self.fit_loop.current_epoch

     @property
-    def max_epochs(self) -> Optional[int]:
+    def max_epochs(self) -> int:
         return self.fit_loop.max_epochs

     @property
     def min_epochs(self) -> Optional[int]:
         return self.fit_loop.min_epochs

     @property
-    def max_steps(self) -> Optional[int]:
+    def max_steps(self) -> int:
         return self.fit_loop.max_steps

     @property
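
The `max_epochs` fallback in the `FitLoop` construction above can be read as a standalone expression (a sketch of the resolution logic, not a public API; the names are the Trainer constructor arguments):

# - an explicit max_epochs is used as-is
# - max_epochs=None with the default max_steps == -1 and no max_time -> 1000 (the historical default)
# - max_epochs=None with max_steps or max_time set -> -1, i.e. epochs are unlimited and
#   training stops on max_steps/max_time instead
max_epochs = max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)

This matches the updated `test_env_vars.py` test below: `Trainer()` reports `max_epochs == 1000`, while `Trainer(False, max_steps=42)` reports `max_epochs == -1`.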

tests/callbacks/test_timer.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ def on_fit_start(self):
     timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
     assert timer._duration == 1
     assert trainer.max_epochs == -1
-    assert trainer.max_steps is None
+    assert trainer.max_steps == -1


 @pytest.mark.parametrize(

tests/deprecated_api/test_remove_1-7.py

Lines changed: 9 additions & 0 deletions
@@ -390,6 +390,15 @@ def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
         _ = XLAStatsMonitor()


+def test_v1_7_0_deprecated_max_steps_none(tmpdir):
+    with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
+        _ = Trainer(max_steps=None)
+
+    trainer = Trainer()
+    with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
+        trainer.fit_loop.max_steps = None
+
+
 def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
     with pytest.deprecated_call(match=r"Setting `Trainer\(resume_from_checkpoint=\)` is deprecated in v1.5"):
         trainer = Trainer(resume_from_checkpoint="a")

tests/trainer/flags/test_env_vars.py

Lines changed: 3 additions & 1 deletion
@@ -21,10 +21,12 @@ def test_passing_no_env_variables():
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
     assert trainer.logger is not None
-    assert trainer.max_steps is None
+    assert trainer.max_steps == -1
+    assert trainer.max_epochs == 1000
     trainer = Trainer(False, max_steps=42)
     assert trainer.logger is None
     assert trainer.max_steps == 42
+    assert trainer.max_epochs == -1


 @mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"})

tests/trainer/optimization/test_optimizers.py

Lines changed: 1 addition & 1 deletion
@@ -383,7 +383,7 @@ def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch):
     optimizer = optim.Adam(model.parameters())
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
     max_epochs = 1 if complete_epoch else None
-    max_steps = None if complete_epoch else 1
+    max_steps = -1 if complete_epoch else 1
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps)

     model.configure_optimizers = lambda: {

tests/trainer/test_trainer.py

Lines changed: 9 additions & 10 deletions
@@ -498,31 +498,30 @@ def test_trainer_max_steps_and_epochs(tmpdir):


 @pytest.mark.parametrize(
-    "max_epochs,max_steps,incorrect_variable,incorrect_value",
+    "max_epochs,max_steps,incorrect_variable",
     [
-        (-100, None, "max_epochs", -100),
-        (1, -2, "max_steps", -2),
+        (-100, -1, "max_epochs"),
+        (1, -2, "max_steps"),
     ],
 )
-def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value):
+def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable):
     """Don't allow max_epochs or max_steps to be less than -1 or a float."""
     with pytest.raises(
         MisconfigurationException,
-        match=f"`{incorrect_variable}` must be a positive integer or -1. You passed in {incorrect_value}",
+        match=f"`{incorrect_variable}` must be a non-negative integer or -1",
     ):
         Trainer(max_epochs=max_epochs, max_steps=max_steps)


 @pytest.mark.parametrize(
     "max_epochs,max_steps,is_done,correct_trainer_epochs",
     [
-        (None, None, False, 1000),
-        (-1, None, False, -1),
-        (None, -1, False, None),
+        (None, -1, False, 1000),
+        (-1, -1, False, -1),
         (5, -1, False, 5),
         (-1, 10, False, -1),
-        (None, 0, True, None),
-        (0, None, True, 0),
+        (None, 0, True, -1),
+        (0, -1, True, 0),
         (-1, 0, True, -1),
         (0, -1, True, 0),
     ],

tests/trainer/test_trainer_cli.py

Lines changed: 0 additions & 1 deletion
@@ -127,7 +127,6 @@ def _raise():
     # These parameters are marked as Optional[...] in Trainer.__init__, with None as default.
     # They should not be changed by the argparse interface.
     "min_steps": None,
-    "max_steps": None,
     "accelerator": None,
     "weights_save_path": None,
     "profiler": None,

tests/utilities/test_cli.py

Lines changed: 0 additions & 1 deletion
@@ -134,7 +134,6 @@ def _raise():
         # with None as default. They should not be changed by the argparse
         # interface.
         min_steps=None,
-        max_steps=None,
         accelerator=None,
         weights_save_path=None,
         profiler=None,
