Skip to content

Commit 0fe3379

Browse files
daniellepintzkaushikb11rohitgr7pre-commit-ci[bot]awaelchli
authored
Deprecate weights_save_path from the Trainer constructor (#12084)
Co-authored-by: Kaushik B <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6309a59 commit 0fe3379

File tree

16 files changed

+70
-34
lines changed

16 files changed

+70
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
442442
- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))
443443

444444

445+
- Deprecated passing `weights_save_path` to the `Trainer` constructor in favor of adding the `ModelCheckpoint` callback with `dirpath` directly to the list of callbacks ([#12084](https://github.com/PyTorchLightning/pytorch-lightning/pull/12084))
446+
447+
445448
- Deprecated `BaseProfiler.profile_iterable` ([#12102](https://github.com/PyTorchLightning/pytorch-lightning/pull/12102))
446449

447450

docs/source/common/trainer.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,6 +1584,12 @@ Can specify as float or int.
15841584
weights_save_path
15851585
^^^^^^^^^^^^^^^^^
15861586

1587+
1588+
.. warning:: `weights_save_path` has been deprecated in v1.6 and will be removed in v1.8. Please pass
1589+
``dirpath`` directly to the :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
1590+
callback.
1591+
1592+
15871593
.. raw:: html
15881594

15891595
<video width="50%" max-width="400px" controls

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -579,18 +579,20 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
579579
580580
1. Checkpoint callback's path (if passed in)
581581
2. The default_root_dir from trainer if trainer has no logger
582-
3. The weights_save_path from trainer, if user provides it
582+
3. The weights_save_path from trainer, if user provides it (deprecated)
583583
4. User provided weights_saved_path
584584
585585
The base path gets extended with logger name and version (if these are available)
586586
and subfolder "checkpoints".
587587
"""
588588
if self.dirpath is not None:
589589
return # short circuit
590+
591+
# TODO: Remove weights_save_path logic here in v1.8
590592
if trainer.loggers:
591-
if trainer.weights_save_path != trainer.default_root_dir:
593+
if trainer._weights_save_path_internal != trainer.default_root_dir:
592594
# the user has changed weights_save_path, it overrides anything
593-
save_dir = trainer.weights_save_path
595+
save_dir = trainer._weights_save_path_internal
594596
elif len(trainer.loggers) == 1:
595597
save_dir = trainer.logger.save_dir or trainer.default_root_dir
596598
else:
@@ -602,7 +604,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
602604

603605
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
604606
else:
605-
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
607+
ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
606608

607609
ckpt_path = trainer.strategy.broadcast(ckpt_path)
608610

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,5 +365,5 @@ def _check_deprecated_callback_hooks(trainer: "pl.Trainer") -> None:
365365
if is_overridden(method_name=hook, instance=callback):
366366
rank_zero_deprecation(
367367
f"The `Callback.{hook}` hook has been deprecated in v1.6 and"
368-
f" will be removed in v1.8. Please use `Callback.on_fit_start` instead."
368+
" will be removed in v1.8. Please use `Callback.on_fit_start` instead."
369369
)

pytorch_lightning/trainer/connectors/callback_connector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ def on_trainer_init(
5353
):
5454
# init folder paths for checkpoint + weights save callbacks
5555
self.trainer._default_root_dir = default_root_dir or os.getcwd()
56+
if weights_save_path:
57+
rank_zero_deprecation(
58+
"Setting `Trainer(weights_save_path=)` has been deprecated in v1.6 and will be"
59+
" removed in v1.8. Please pass ``dirpath`` directly to the `ModelCheckpoint` callback"
60+
)
61+
5662
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
5763
if stochastic_weight_avg:
5864
rank_zero_deprecation(

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,12 @@ def __init__(self, trainer: "pl.Trainer", resume_from_checkpoint: Optional[_PATH
5555

5656
@property
5757
def _hpc_resume_path(self) -> Optional[str]:
58-
weights_save_path = self.trainer.weights_save_path
59-
fs = get_filesystem(weights_save_path)
60-
if not fs.isdir(weights_save_path):
58+
# TODO: in v1.8 set this equal to self.trainer.default_root_dir
59+
dir_path_hpc = self.trainer._weights_save_path_internal
60+
fs = get_filesystem(dir_path_hpc)
61+
if not fs.isdir(dir_path_hpc):
6162
return None
62-
dir_path_hpc = str(weights_save_path)
63+
dir_path_hpc = str(dir_path_hpc)
6364
max_version = self.__max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
6465
if max_version is not None:
6566
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")

pytorch_lightning/trainer/connectors/signal_connector.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
6868
# save logger to make sure we get all the metrics
6969
for logger in self.trainer.loggers:
7070
logger.finalize("finished")
71-
hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer.weights_save_path)
71+
# TODO: in v1.8 change this to use self.trainer.default_root_dir
72+
hpc_save_path = self.trainer._checkpoint_connector.hpc_save_path(self.trainer._weights_save_path_internal)
7273
self.trainer.save_checkpoint(hpc_save_path)
7374

7475
if self.trainer.is_global_zero:

pytorch_lightning/trainer/trainer.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def __init__(
169169
precision: Union[int, str] = 32,
170170
enable_model_summary: bool = True,
171171
weights_summary: Optional[str] = "top",
172-
weights_save_path: Optional[str] = None,
172+
weights_save_path: Optional[str] = None, # TODO: Remove in 1.8
173173
num_sanity_val_steps: int = 2,
174174
resume_from_checkpoint: Optional[Union[Path, str]] = None,
175175
profiler: Optional[Union[BaseProfiler, str]] = None,
@@ -447,6 +447,11 @@ def __init__(
447447
Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
448448
Defaults to `default_root_dir`.
449449
450+
.. deprecated:: v1.6
451+
``weights_save_path`` has been deprecated in v1.6 and will be removed in v1.8. Please pass
452+
``dirpath`` directly to the :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
453+
callback.
454+
450455
move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
451456
This can save some gpu memory, but can make training slower. Use with attention.
452457
Default: ``False``.
@@ -2210,6 +2215,20 @@ def weights_save_path(self) -> str:
22102215
"""
22112216
The default root location to save weights (checkpoints), e.g., when the
22122217
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
2218+
2219+
.. deprecated:: v1.6
2220+
`Trainer.weights_save_path` has been deprecated in v1.6 and will be removed in v1.8.
2221+
"""
2222+
rank_zero_deprecation("`Trainer.weights_save_path` has been deprecated in v1.6 and will be removed in v1.8.")
2223+
return self._weights_save_path_internal
2224+
2225+
# TODO: Remove _weights_save_path_internal in v1.8
2226+
@property
2227+
def _weights_save_path_internal(self) -> str:
2228+
"""This is an internal implementation of weights_save_path which allows weights_save_path to be used
2229+
internally by the framework without emitting a deprecation warning.
2230+
2231+
To be removed in v1.8.
22132232
"""
22142233
if get_filesystem(self._weights_save_path).protocol == "file":
22152234
return os.path.normpath(self._weights_save_path)

tests/deprecated_api/test_remove_1-8.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ def on_pretrain_routine_start(self, trainer, pl_module):
590590
default_root_dir=tmpdir,
591591
)
592592
with pytest.deprecated_call(
593-
match="The `Callback.on_pretrain_routine_start` hook has been deprecated in v1.6" " and will be removed in v1.8"
593+
match="The `Callback.on_pretrain_routine_start` hook has been deprecated in v1.6 and will be removed in v1.8"
594594
):
595595
trainer.fit(model)
596596

@@ -607,11 +607,18 @@ def on_pretrain_routine_end(self, trainer, pl_module):
607607
default_root_dir=tmpdir,
608608
)
609609
with pytest.deprecated_call(
610-
match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6" " and will be removed in v1.8"
610+
match="The `Callback.on_pretrain_routine_end` hook has been deprecated in v1.6 and will be removed in v1.8"
611611
):
612612
trainer.fit(model)
613613

614614

615+
def test_v1_8_0_weights_save_path(tmpdir):
616+
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_save_path=\)` has been deprecated in v1.6"):
617+
trainer = Trainer(weights_save_path=tmpdir)
618+
with pytest.deprecated_call(match=r"`Trainer.weights_save_path` has been deprecated in v1.6"):
619+
_ = trainer.weights_save_path
620+
621+
615622
@pytest.mark.flaky(reruns=3)
616623
@pytest.mark.parametrize(["action", "expected"], [("a", [3, 1]), ("b", [2]), ("c", [1])])
617624
def test_simple_profiler_iterable_durations(tmpdir, action: str, expected: list):

tests/loggers/test_all.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,26 +208,28 @@ def name(self):
208208
logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
209209
trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
210210
trainer.fit(model)
211-
assert trainer.weights_save_path == trainer.default_root_dir
211+
assert trainer._weights_save_path_internal == trainer.default_root_dir
212212
assert trainer.checkpoint_callback.dirpath == os.path.join(logger.save_dir, "name", "version", "checkpoints")
213213
assert trainer.default_root_dir == tmpdir
214214

215215
# with weights_save_path given, the logger path and checkpoint path should be different
216216
save_dir = tmpdir / "logs"
217217
weights_save_path = tmpdir / "weights"
218218
logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
219-
trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
219+
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_save_path=\)` has been deprecated in v1.6"):
220+
trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
220221
trainer.fit(model)
221-
assert trainer.weights_save_path == weights_save_path
222+
assert trainer._weights_save_path_internal == weights_save_path
222223
assert trainer.logger.save_dir == save_dir
223224
assert trainer.checkpoint_callback.dirpath == weights_save_path / "name" / "version" / "checkpoints"
224225
assert trainer.default_root_dir == tmpdir
225226

226227
# no logger given
227228
weights_save_path = tmpdir / "weights"
228-
trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path)
229+
with pytest.deprecated_call(match=r"Setting `Trainer\(weights_save_path=\)` has been deprecated in v1.6"):
230+
trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path)
229231
trainer.fit(model)
230-
assert trainer.weights_save_path == weights_save_path
232+
assert trainer._weights_save_path_internal == weights_save_path
231233
assert trainer.checkpoint_callback.dirpath == weights_save_path / "checkpoints"
232234
assert trainer.default_root_dir == tmpdir
233235

tests/models/data/horovod/test_train_script.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ def test_horovod_model_script(tmpdir):
1919
"""This just for testing/debugging horovod script without horovod..."""
2020
trainer_options = dict(
2121
default_root_dir=str(tmpdir),
22-
weights_save_path=str(tmpdir),
2322
gradient_clip_val=1.0,
2423
enable_progress_bar=False,
2524
max_epochs=1,

tests/models/data/horovod/train_default_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def run_test_from_config(trainer_options, on_gpu, check_size=True):
4949
set_random_main_port()
5050
reset_seed()
5151

52-
ckpt_path = trainer_options["weights_save_path"]
52+
ckpt_path = trainer_options["default_root_dir"]
5353
trainer_options.update(callbacks=[ModelCheckpoint(dirpath=ckpt_path)])
5454

5555
class TestModel(BoringModel):

tests/models/test_cpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def test_cpu_slurm_save_load(tmpdir):
6666
# save logger to make sure we get all the metrics
6767
if logger:
6868
logger.finalize("finished")
69-
hpc_save_path = trainer._checkpoint_connector.hpc_save_path(trainer.weights_save_path)
69+
hpc_save_path = trainer._checkpoint_connector.hpc_save_path(trainer.default_root_dir)
7070
trainer.save_checkpoint(hpc_save_path)
7171
assert os.path.exists(hpc_save_path)
7272

tests/models/test_horovod.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def test_horovod_cpu(tmpdir):
8080
"""Test Horovod running multi-process on CPU."""
8181
trainer_options = dict(
8282
default_root_dir=str(tmpdir),
83-
weights_save_path=str(tmpdir),
8483
gradient_clip_val=1.0,
8584
enable_progress_bar=False,
8685
max_epochs=1,
@@ -110,7 +109,6 @@ def test_horovod_cpu_clip_grad_by_value(tmpdir):
110109
"""Test Horovod running multi-process on CPU."""
111110
trainer_options = dict(
112111
default_root_dir=str(tmpdir),
113-
weights_save_path=str(tmpdir),
114112
gradient_clip_val=1.0,
115113
gradient_clip_algorithm="value",
116114
enable_progress_bar=False,
@@ -127,7 +125,6 @@ def test_horovod_cpu_implicit(tmpdir):
127125
"""Test Horovod without specifying a backend, inferring from env set by `horovodrun`."""
128126
trainer_options = dict(
129127
default_root_dir=str(tmpdir),
130-
weights_save_path=str(tmpdir),
131128
gradient_clip_val=1.0,
132129
enable_progress_bar=False,
133130
max_epochs=1,
@@ -142,7 +139,6 @@ def test_horovod_multi_gpu(tmpdir):
142139
"""Test Horovod with multi-GPU support."""
143140
trainer_options = dict(
144141
default_root_dir=str(tmpdir),
145-
weights_save_path=str(tmpdir),
146142
gradient_clip_val=1.0,
147143
enable_progress_bar=False,
148144
max_epochs=1,
@@ -193,7 +189,6 @@ def test_horovod_multi_gpu_grad_by_value(tmpdir):
193189
"""Test Horovod with multi-GPU support."""
194190
trainer_options = dict(
195191
default_root_dir=str(tmpdir),
196-
weights_save_path=str(tmpdir),
197192
gradient_clip_val=1.0,
198193
gradient_clip_algorithm="value",
199194
enable_progress_bar=False,
@@ -216,7 +211,6 @@ def test_horovod_apex(tmpdir):
216211
"""Test Horovod with multi-GPU support using apex amp."""
217212
trainer_options = dict(
218213
default_root_dir=str(tmpdir),
219-
weights_save_path=str(tmpdir),
220214
gradient_clip_val=1.0,
221215
enable_progress_bar=False,
222216
max_epochs=1,
@@ -236,7 +230,6 @@ def test_horovod_amp(tmpdir):
236230
"""Test Horovod with multi-GPU support using native amp."""
237231
trainer_options = dict(
238232
default_root_dir=str(tmpdir),
239-
weights_save_path=str(tmpdir),
240233
gradient_clip_val=1.0,
241234
enable_progress_bar=False,
242235
max_epochs=1,
@@ -256,7 +249,6 @@ def test_horovod_gather(tmpdir):
256249
"""Test Horovod with multi-GPU support using native amp."""
257250
trainer_options = dict(
258251
default_root_dir=str(tmpdir),
259-
weights_save_path=str(tmpdir),
260252
gradient_clip_val=1.0,
261253
enable_progress_bar=False,
262254
max_epochs=1,

tests/trainer/test_trainer_cli.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_default_args(mock_argparse, tmpdir):
4343
assert trainer.max_epochs == 5
4444

4545

46-
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []])
46+
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
4747
def test_add_argparse_args_redefined(cli_args: list):
4848
"""Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
4949
parser = ArgumentParser(add_help=False)
@@ -64,7 +64,7 @@ def test_add_argparse_args_redefined(cli_args: list):
6464
assert isinstance(trainer, Trainer)
6565

6666

67-
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []])
67+
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
6868
def test_add_argparse_args(cli_args: list):
6969
"""Simple test ensuring Trainer.add_argparse_args works."""
7070
parser = ArgumentParser(add_help=False)
@@ -128,7 +128,6 @@ def _raise():
128128
# They should not be changed by the argparse interface.
129129
"min_steps": None,
130130
"accelerator": None,
131-
"weights_save_path": None,
132131
"profiler": None,
133132
},
134133
),

tests/utilities/test_cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def test_default_args(mock_argparse):
7575
assert trainer.max_epochs == 5
7676

7777

78-
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], ["--weights_save_path=./"], []])
78+
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
7979
def test_add_argparse_args_redefined(cli_args):
8080
"""Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
8181
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
@@ -139,7 +139,6 @@ def _raise():
139139
# interface.
140140
min_steps=None,
141141
accelerator=None,
142-
weights_save_path=None,
143142
profiler=None,
144143
),
145144
),

0 commit comments

Comments
 (0)