Changed max_steps default from None to -1 #9460

Merged

33 commits

5ff0694
Changed max_steps to default to -1
EricWiener Sep 11, 2021
93900eb
Update fit_loop max_epoch's docstring
EricWiener Sep 12, 2021
f14ea3b
Changes fit_loop signature to have max_epochs = -1, min_epochs = 0
EricWiener Sep 12, 2021
04b5207
Added deprecation warning + is_max_limit_reached util
EricWiener Sep 12, 2021
bf1594f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2021
953dced
Remove checking if max_epochs is less than -1
EricWiener Sep 14, 2021
b718212
Remove `Optional` from `max_steps` signature
EricWiener Sep 16, 2021
868d914
Update trainer docstring for max_steps
EricWiener Sep 16, 2021
75f3e77
Make is_max_limit_reached private
EricWiener Sep 16, 2021
6ec1c47
Move max_steps validation to training_epoch_loop
EricWiener Sep 16, 2021
4b68bfa
Moved max_epochs logic into fit_loop
EricWiener Sep 21, 2021
c803c7d
Updated logic for checking max_epochs in fit_loop
EricWiener Sep 22, 2021
dce15b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 23, 2021
1cc4642
Checkpoint connector passes mypy
EricWiener Sep 23, 2021
43d0156
Cleaned up checkpoint connector if
EricWiener Sep 23, 2021
0500ffe
Removed ClosureResult from bad merge
EricWiener Sep 23, 2021
47dc657
Changed positive int to non-negative int
EricWiener Sep 24, 2021
a834a57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2021
49fb7cc
fix expected values in trainer test
awaelchli Sep 26, 2021
a554e14
update cli tests
awaelchli Sep 26, 2021
c5017da
update utility function signature and docstring
awaelchli Sep 26, 2021
0af12ae
move max_epochs=None parsing back to trainer
awaelchli Sep 26, 2021
f751354
update epoch loop signature and deprecation message
awaelchli Sep 26, 2021
a83a8a3
update fit loop max_steps signature and deprecation message
awaelchli Sep 26, 2021
c14bad2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2021
4a327dd
rebase and update tests
rohitgr7 Oct 21, 2021
8a2b729
fix issues
rohitgr7 Oct 21, 2021
33ccd90
mypy
rohitgr7 Oct 21, 2021
a8662f9
update changelog
awaelchli Oct 21, 2021
eefcdee
update mypy and error msg
rohitgr7 Oct 22, 2021
abf24d3
Merge branch 'master' into feature/clean_up_max_steps_epochs
rohitgr7 Oct 25, 2021
8cb57ff
Merge branch 'master' into feature/clean_up_max_steps_epochs
rohitgr7 Oct 25, 2021
e8b6436
Merge branch 'master' into feature/clean_up_max_steps_epochs
rohitgr7 Oct 25, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -334,6 +334,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated several places in the loops and trainer to access `training_type_plugin` directly instead of `accelerator` ([#9901](https://github.com/PyTorchLightning/pytorch-lightning/pull/9901))


- Changed default value of the `max_steps` Trainer argument from `None` to -1 ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))


### Deprecated

@@ -408,6 +410,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Deprecated `GPUStatsMonitor` and `XLAStatsMonitor` in favor of `DeviceStatsMonitor` callback ([#9924](https://github.com/PyTorchLightning/pytorch-lightning/pull/9924))


- Deprecated setting `Trainer(max_steps=None)`. To turn off the limit, set `Trainer(max_steps=-1)` (default) ([#9460](https://github.com/PyTorchLightning/pytorch-lightning/pull/9460))


### Removed

- Removed deprecated `metrics` ([#8586](https://github.com/PyTorchLightning/pytorch-lightning/pull/8586/))
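Taken together, the two changelog entries above amount to the following user-facing behaviour. A minimal, illustrative sketch (standard `Trainer` construction; expected values mirror the updated tests further down in this diff):

```python
from pytorch_lightning import Trainer

# New default: -1 now means "no step limit" (previously None played this role).
trainer = Trainer()
assert trainer.max_steps == -1

# Turning the limit off explicitly is now spelled -1 as well.
trainer = Trainer(max_steps=-1)

# Passing None still works but emits a deprecation warning; support is slated for removal in v1.7.
trainer = Trainer(max_steps=None)
```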
25 changes: 16 additions & 9 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -20,15 +20,15 @@
from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _get_active_optimizers, _update_dataloader_iter
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -41,13 +41,20 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
max_steps: The maximum number of steps (batches) to process
"""

def __init__(self, min_steps: int, max_steps: int):
def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
super().__init__()
self.min_steps: int = min_steps

if max_steps and max_steps < -1:
raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {max_steps}.")
self.max_steps: int = max_steps
if max_steps is None:
rank_zero_deprecation(
"Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
" Use `max_steps = -1` instead."
)
max_steps = -1
elif max_steps < -1:
raise MisconfigurationException(
f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
)
self.min_steps = min_steps
self.max_steps = max_steps

self.global_step: int = 0
self.batch_progress = BatchProgress()
@@ -79,7 +86,7 @@ def batch_idx(self) -> int:

@property
def _is_training_done(self) -> bool:
max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps)
return max_steps_reached or self._num_ready_batches_reached()

@property
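As an illustration of the constructor validation added above, a small sketch (constructing the loop standalone, outside a `Trainer`, is assumed to be acceptable for demonstration purposes):

```python
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# -1 (the new default) disables the step limit; 0 and positive values are accepted.
loop = TrainingEpochLoop(min_steps=0, max_steps=-1)

# None is deprecated: a deprecation warning is emitted and the value is coerced to -1.
loop = TrainingEpochLoop(min_steps=0, max_steps=None)
assert loop.max_steps == -1

# Anything below -1 is rejected up front.
try:
    TrainingEpochLoop(min_steps=0, max_steps=-2)
except MisconfigurationException as err:
    print(err)  # `max_steps` must be a non-negative integer or -1 (infinite steps). You passed in -2.
```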
45 changes: 23 additions & 22 deletions pytorch_lightning/loops/fit_loop.py
@@ -16,9 +16,11 @@

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)
@@ -29,15 +31,19 @@ class FitLoop(Loop):

Args:
min_epochs: The minimum number of epochs
max_epochs: The maximum number of epochs
max_epochs: The maximum number of epochs; can be set to -1 to turn this limit off
"""

def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
def __init__(
self,
min_epochs: Optional[int] = 1,
max_epochs: int = 1000,
) -> None:
super().__init__()
# Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done
if max_epochs and max_epochs < -1:
if max_epochs < -1:
# Allow max_epochs to be zero, since this will be handled by fit_loop.done
raise MisconfigurationException(
f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}."
f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
)

self.max_epochs = max_epochs
@@ -102,8 +108,16 @@ def max_steps(self) -> int:
def max_steps(self, value: int) -> None:
"""Sets the maximum number of steps (forwards to epoch_loop)"""
# TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
if value and value < -1:
raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {value}.")
if value is None:
rank_zero_deprecation(
"Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
" Use `max_steps = -1` instead."
)
value = -1
elif value < -1:
raise MisconfigurationException(
f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
)
self.epoch_loop.max_steps = value

@property
@@ -141,8 +155,8 @@ def done(self) -> bool:
is reached.
"""
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps
stop_epochs = FitLoop._is_max_limit_enabled(self.max_epochs) and self.current_epoch >= self.max_epochs
stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)

should_stop = False
if self.trainer.should_stop:
@@ -249,16 +263,3 @@ def teardown(self) -> None:
def _should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated."""
return self.epoch_loop._should_accumulate()

@staticmethod
def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
"""Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
is enabled.

Args:
max_value: the value to check

Returns:
whether the limit for this value should be enabled
"""
return max_value not in (None, -1)
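The `FitLoop` changes follow the same pattern for `max_epochs`; a hedged sketch of the validation (using the import path that this PR removes from the checkpoint connector):

```python
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# Explicit defaults are now min_epochs=1, max_epochs=1000.
FitLoop()
# -1 turns the epoch limit off, and 0 is allowed because `FitLoop.done` handles it.
FitLoop(max_epochs=-1)
FitLoop(max_epochs=0)

# Values below -1 are rejected.
try:
    FitLoop(max_epochs=-100)
except MisconfigurationException as err:
    print(err)  # `max_epochs` must be a non-negative integer or -1. You passed in -100.
```

The `max_steps` setter on `FitLoop` mirrors the `TrainingEpochLoop` constructor: assigning `None` emits the same deprecation warning and stores `-1`.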
13 changes: 13 additions & 0 deletions pytorch_lightning/loops/utilities.py
@@ -168,3 +168,16 @@ def _get_active_optimizers(
# find optimizer index by looking for the first {item > current_place} in the cumsum list
opt_idx = np.searchsorted(freq_cumsum, current_place_in_loop, side="right")
return [(opt_idx, optimizers[opt_idx])]


def _is_max_limit_reached(current: int, maximum: int = -1) -> bool:
"""Check if the limit has been reached (if enabled).

Args:
current: the current value
maximum: the maximum value (or -1 to disable limit)

Returns:
bool: whether the limit has been reached
"""
return maximum != -1 and current >= maximum
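A few illustrative calls of the new helper (values chosen purely for demonstration):

```python
from pytorch_lightning.loops.utilities import _is_max_limit_reached

assert _is_max_limit_reached(current=10, maximum=10)          # limit hit exactly
assert _is_max_limit_reached(current=11, maximum=10)          # limit exceeded
assert not _is_max_limit_reached(current=9, maximum=10)       # still below the limit
assert not _is_max_limit_reached(current=10_000, maximum=-1)  # -1 disables the check entirely
assert _is_max_limit_reached(current=0, maximum=0)            # 0 means "stop immediately"
```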
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -21,7 +21,7 @@

import pytorch_lightning as pl
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
@@ -220,7 +220,7 @@ def restore_loops(self) -> None:

# crash if max_epochs is lower than the current epoch from the checkpoint
if (
FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
self.trainer.max_epochs != -1
and self.trainer.max_epochs is not None
and self.trainer.current_epoch > self.trainer.max_epochs
):
@@ -351,7 +351,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
# dump epoch/global_step/pytorch-lightning_version
current_epoch = self.trainer.current_epoch
global_step = self.trainer.global_step
has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step
has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps)

global_step += 1
if not has_reached_max_steps:
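The switch from the old truthiness check to `_is_max_limit_reached` in `dump_checkpoint` matters because of the new sentinel: with `-1` meaning "no limit", the old expression would evaluate to `True` for every step. A small sketch of the difference (plain Python, no Lightning objects involved):

```python
max_steps, global_step = -1, 100

old_check = bool(max_steps and max_steps <= global_step)  # True  -> would wrongly report the limit as reached
new_check = max_steps != -1 and global_step >= max_steps  # False -> -1 correctly means "no limit"
```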
19 changes: 10 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -145,7 +145,7 @@ def __init__(
accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
max_epochs: Optional[int] = None,
min_epochs: Optional[int] = None,
max_steps: Optional[int] = None,
max_steps: int = -1,
min_steps: Optional[int] = None,
max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
limit_train_batches: Union[int, float] = 1.0,
@@ -327,9 +327,9 @@ def __init__(
min_epochs: Force training for at least these many epochs. Disabled by default (None).
If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.

max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None``
and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set
``max_steps`` to ``-1``.
max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
``max_epochs`` to ``-1``.

min_steps: Force training for at least these number of steps. Disabled by default (None).

@@ -455,10 +455,11 @@ def __init__(
self.signal_connector = SignalConnector(self)
self.tuner = Tuner(self)

# max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
fit_loop = FitLoop(
min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
max_epochs=(
max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
),
)
training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
training_batch_loop = TrainingBatchLoop()
@@ -1332,7 +1333,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_

if not ckpt_path:
raise MisconfigurationException(
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
)
return ckpt_path
@@ -1933,15 +1934,15 @@ def current_epoch(self) -> int:
return self.fit_loop.current_epoch

@property
def max_epochs(self) -> Optional[int]:
def max_epochs(self) -> int:
return self.fit_loop.max_epochs

@property
def min_epochs(self) -> Optional[int]:
return self.fit_loop.min_epochs

@property
def max_steps(self) -> Optional[int]:
def max_steps(self) -> int:
return self.fit_loop.max_steps

@property
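Putting the `Trainer`-level pieces together: `max_epochs` only falls back to 1000 when neither `max_steps` nor `max_time` is given, otherwise it defaults to `-1`. A short sketch, with expected values taken from the updated tests below:

```python
from pytorch_lightning import Trainer

# Neither limit given: the 1000-epoch fallback keeps training from running forever.
assert Trainer().max_epochs == 1000

# An explicit max_epochs always wins.
assert Trainer(max_epochs=5).max_epochs == 5
assert Trainer(max_epochs=-1).max_epochs == -1

# If max_steps is set (max_epochs left unset), the 1000-epoch fallback is skipped.
assert Trainer(max_steps=42).max_epochs == -1
assert Trainer(max_steps=0).max_epochs == -1
```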
2 changes: 1 addition & 1 deletion tests/callbacks/test_timer.py
@@ -49,7 +49,7 @@ def on_fit_start(self):
timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
assert timer._duration == 1
assert trainer.max_epochs == -1
assert trainer.max_steps is None
assert trainer.max_steps == -1


@pytest.mark.parametrize(
9 changes: 9 additions & 0 deletions tests/deprecated_api/test_remove_1-7.py
@@ -387,3 +387,12 @@ def test_v1_7_0_deprecate_gpu_stats_monitor(tmpdir):
def test_v1_7_0_deprecate_xla_stats_monitor(tmpdir):
with pytest.deprecated_call(match="The `XLAStatsMonitor` callback was deprecated in v1.5"):
_ = XLAStatsMonitor()


def test_v1_7_0_deprecated_max_steps_none(tmpdir):
with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
_ = Trainer(max_steps=None)

trainer = Trainer()
with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
trainer.fit_loop.max_steps = None
4 changes: 3 additions & 1 deletion tests/trainer/flags/test_env_vars.py
@@ -21,10 +21,12 @@ def test_passing_no_env_variables():
"""Testing overwriting trainer arguments."""
trainer = Trainer()
assert trainer.logger is not None
assert trainer.max_steps is None
assert trainer.max_steps == -1
assert trainer.max_epochs == 1000
trainer = Trainer(False, max_steps=42)
assert trainer.logger is None
assert trainer.max_steps == 42
assert trainer.max_epochs == -1


@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"})
2 changes: 1 addition & 1 deletion tests/trainer/optimization/test_optimizers.py
@@ -383,7 +383,7 @@ def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch):
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
max_epochs = 1 if complete_epoch else None
max_steps = None if complete_epoch else 1
max_steps = -1 if complete_epoch else 1
trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps)

model.configure_optimizers = lambda: {
19 changes: 9 additions & 10 deletions tests/trainer/test_trainer.py
@@ -498,31 +498,30 @@ def test_trainer_max_steps_and_epochs(tmpdir):


@pytest.mark.parametrize(
"max_epochs,max_steps,incorrect_variable,incorrect_value",
"max_epochs,max_steps,incorrect_variable",
[
(-100, None, "max_epochs", -100),
(1, -2, "max_steps", -2),
(-100, -1, "max_epochs"),
(1, -2, "max_steps"),
],
)
def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value):
def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable):
"""Don't allow max_epochs or max_steps to be less than -1 or a float."""
with pytest.raises(
MisconfigurationException,
match=f"`{incorrect_variable}` must be a positive integer or -1. You passed in {incorrect_value}",
match=f"`{incorrect_variable}` must be a non-negative integer or -1",
):
Trainer(max_epochs=max_epochs, max_steps=max_steps)


@pytest.mark.parametrize(
"max_epochs,max_steps,is_done,correct_trainer_epochs",
[
(None, None, False, 1000),
(-1, None, False, -1),
(None, -1, False, None),
(None, -1, False, 1000),
(-1, -1, False, -1),
(5, -1, False, 5),
(-1, 10, False, -1),
(None, 0, True, None),
(0, None, True, 0),
(None, 0, True, -1),
(0, -1, True, 0),
(-1, 0, True, -1),
(0, -1, True, 0),
],
1 change: 0 additions & 1 deletion tests/trainer/test_trainer_cli.py
@@ -127,7 +127,6 @@ def _raise():
# These parameters are marked as Optional[...] in Trainer.__init__, with None as default.
# They should not be changed by the argparse interface.
"min_steps": None,
"max_steps": None,
"accelerator": None,
"weights_save_path": None,
"resume_from_checkpoint": None,
1 change: 0 additions & 1 deletion tests/utilities/test_cli.py
@@ -134,7 +134,6 @@ def _raise():
# with None as default. They should not be changed by the argparse
# interface.
min_steps=None,
max_steps=None,
accelerator=None,
weights_save_path=None,
resume_from_checkpoint=None,