
Commit b87c9ff

Merge branch 'master' into ci/drop-1.6

2 parents: dce9a9c + 45f6a3b

7 files changed: +43 additions, -38 deletions

CHANGELOG.md

Lines changed: 9 additions & 0 deletions

@@ -72,6 +72,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed deprecated `has_prepared_data`, `has_setup_fit`, `has_setup_validate`, `has_setup_test`, `has_setup_predict`, `has_teardown_fit`, `has_teardown_validate`, `has_teardown_test` and `has_teardown_predict` datamodule lifecycle properties ([#10350](https://github.com/PyTorchLightning/pytorch-lightning/pull/10350))

+- Removed deprecated `every_n_val_epochs` parameter of ModelCheckpoint ([#10366](https://github.com/PyTorchLightning/pytorch-lightning/pull/10366))
+
+- Removed deprecated property `configure_slurm_ddp` from accelerator connector ([#10370](https://github.com/PyTorchLightning/pytorch-lightning/pull/10370))
+
 - Removed deprecated arguments `num_nodes` and `sync_batchnorm` from `DDPPlugin`, `DDPSpawnPlugin`, `DeepSpeedPlugin` ([#10357](https://github.com/PyTorchLightning/pytorch-lightning/pull/10357))

@@ -85,6 +91,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316))

+- Fixed interception of `__init__` arguments for sub-classed DataLoader re-instantiation in Lite ([#10334](https://github.com/PyTorchLightning/pytorch-lightning/issues/10334))
+
 - Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345))

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 15 deletions

@@ -33,7 +33,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT

@@ -123,12 +123,6 @@ class ModelCheckpoint(Callback):
             where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
         save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
             If this is ``False``, then the check runs at the end of the validation.
-        every_n_val_epochs: Number of epochs between checkpoints.
-
-            .. warning::
-                This argument has been deprecated in v1.4 and will be removed in v1.6.
-
-                Use ``every_n_epochs`` instead.

     Note:
         For extra customization, ModelCheckpoint includes the following attributes:

@@ -214,7 +208,6 @@ def __init__(
         train_time_interval: Optional[timedelta] = None,
         every_n_epochs: Optional[int] = None,
         save_on_train_epoch_end: Optional[bool] = None,
-        every_n_val_epochs: Optional[int] = None,
     ):
         super().__init__()
         self.monitor = monitor

@@ -233,13 +226,6 @@
         self.best_model_path = ""
         self.last_model_path = ""

-        if every_n_val_epochs is not None:
-            rank_zero_deprecation(
-                "`ModelCheckpoint(every_n_val_epochs)` is deprecated in v1.4 and will be removed in v1.6."
-                " Please use `every_n_epochs` instead."
-            )
-            every_n_epochs = every_n_val_epochs

         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename)
         self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
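
With the deprecated path gone, checkpoint cadence is expressed only through `every_n_epochs`. A minimal migration sketch (the `monitor` value and the `Trainer` wiring are illustrative, not taken from this diff):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Before (removed in this commit): ModelCheckpoint(every_n_val_epochs=2)
# After: the consolidated argument covers the same cadence.
checkpoint_cb = ModelCheckpoint(
    monitor="val_loss",  # illustrative metric name
    every_n_epochs=2,    # replaces the removed `every_n_val_epochs`
)
trainer = Trainer(callbacks=[checkpoint_cb])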

pytorch_lightning/lite/wrappers.py

Lines changed: 5 additions & 4 deletions

@@ -14,6 +14,7 @@

 import functools
 import inspect
 from contextlib import contextmanager
+from itertools import chain
 from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union

 import torch

@@ -109,7 +110,7 @@ def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
         params = dict(inspect.signature(module._old_init).parameters)
         params.pop("args")
         params.pop("kwargs")
-        for init_name, init_arg in zip(params, args):
+        for init_name, init_arg in chain(zip(params, args), kwargs.items()):
             setattr(module, init_name, init_arg)
         f(module, *args, **kwargs)

@@ -118,15 +119,15 @@

 # https://stackoverflow.com/a/63851681/9201239
 def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
-    subclass_list = []
+    subclasses = set()

     def recurse(cl: Type[Any]) -> None:
         for subclass in cl.__subclasses__():
-            subclass_list.append(subclass)
+            subclasses.add(subclass)
             recurse(subclass)

     recurse(cls)
-    return set(subclass_list)
+    return subclasses


 def _enable_class(cls: Type[Any]) -> None:
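
The one-line `chain` change is the heart of the fix: `zip(params, args)` pairs parameter names only with positional values, so anything the caller passed by keyword never became an attribute. A self-contained sketch of the pattern (class and attribute names are hypothetical):

import inspect
from itertools import chain

class Configurable:
    def __init__(self, attribute1, *args, **kwargs):
        # Pair declared parameter names with positional values, then fold
        # in keyword arguments; plain zip(params, args) drops the kwargs.
        params = list(inspect.signature(Configurable.__init__).parameters)
        for reserved in ("self", "args", "kwargs"):
            params.remove(reserved)
        for name, value in chain(zip(params, (attribute1,)), kwargs.items()):
            setattr(self, name, value)

obj = Configurable("a", attribute2="b")
assert obj.attribute1 == "a"
assert obj.attribute2 == "b"  # was lost before this commit's change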

pytorch_lightning/loggers/wandb.py

Lines changed: 0 additions & 1 deletion

@@ -490,7 +490,6 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
             "save_top_k",
             "save_weights_only",
             "_every_n_train_steps",
-            "_every_n_val_epochs",
         ]
         # ensure it does not break if `ModelCheckpoint` args change
         if hasattr(checkpoint_callback, k)

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 0 additions & 6 deletions

@@ -990,12 +990,6 @@ def update_device_type_if_training_type_plugin_passed(self) -> None:
         elif self.has_gpu:
             self._device_type = DeviceType.GPU

-    def configure_slurm_ddp(self) -> None:
-        rank_zero_deprecation(
-            "`AcceleratorConnector.configure_slurm_ddp()` was deprecated in v1.5 and will be removed in v1.6."
-        )
-        self._configure_slurm_ddp()
-
     def _configure_slurm_ddp(self):
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes

tests/deprecated_api/test_remove_1-6.py

Lines changed: 0 additions & 12 deletions

@@ -19,7 +19,6 @@

 from torch.optim import Optimizer

 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins import PrecisionPlugin
 from pytorch_lightning.plugins.training_type import DDPPlugin

@@ -167,11 +166,6 @@ def test_v1_6_0_deprecated_disable_validation():
     _ = trainer.disable_validation


-def test_v1_6_0_every_n_val_epochs():
-    with pytest.deprecated_call(match="use `every_n_epochs` instead"):
-        _ = ModelCheckpoint(every_n_val_epochs=1)
-
-
 def test_v1_6_0_deprecated_hpc_load(tmpdir):
     model = BoringModel()
     trainer = Trainer(default_root_dir=tmpdir, max_steps=1)

@@ -267,12 +261,6 @@ def test_v1_6_0_deprecated_accelerator_pass_through_functions():
     accelerator.on_train_batch_start(batch=None, batch_idx=0)


-def test_v1_6_0_configure_slurm_ddp():
-    trainer = Trainer()
-    with pytest.deprecated_call(match=r"`AcceleratorConnector.configure_slurm_ddp\(\)` was deprecated in v1.5"):
-        trainer._accelerator_connector.configure_slurm_ddp()
-
-
 def test_v1_6_0_master_params():
     with pytest.deprecated_call(match="`PrecisionPlugin.master_params` was deprecated in v1.5"):
         PrecisionPlugin().master_params(Mock(spec=Optimizer))

tests/lite/test_lite.py

Lines changed: 28 additions & 0 deletions

@@ -164,6 +164,34 @@ def test_setup_dataloaders_return_type():
     assert lite_dataloader1.dataset is dataset1


+def test_setup_dataloaders_with_custom_type():
+    """Test that Lite intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as
+    attributes."""
+
+    class DataLoaderSubclass1(DataLoader):
+        def __init__(self, attribute1, *args, **kwargs):
+            # intentionally not setting this attribute, calling super with different args
+            # self.attribute1 = attribute1
+            super().__init__(*args, **kwargs)
+
+    class DataLoaderSubclass2(DataLoaderSubclass1):
+        def __init__(self, attribute1, attribute2, *args, **kwargs):
+            # intentionally not setting this attribute, calling super with different args
+            # self.attribute2 = attribute2
+            super().__init__(attribute1, *args, **kwargs)
+
+    class LiteWithCustomDataLoader(LightningLite):
+        def run(self):
+            dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
+            assert dataloader.attribute1 == "attribute1"
+            assert dataloader.attribute2 == "attribute2"
+            lite_dataloader = self.setup_dataloaders(dataloader)
+            assert lite_dataloader.attribute1 == "attribute1"
+            assert lite_dataloader.attribute2 == "attribute2"
+
+    LiteWithCustomDataLoader().run()
+
+
 def test_setup_custom_dataloaders():
     """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
     lite = EmptyLite()
