diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a8e9d2c0e7ad..6bc7ee502437b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,6 +86,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316)) +- Fixed interception of `__init__` arguments for sub-classed DataLoader re-instantiation in Lite ([#10334](https://github.com/PyTorchLightning/pytorch-lightning/issues/10334)) + + - Fixed failure when `DataLoader(batch_size=None)` is passed ([#10345](https://github.com/PyTorchLightning/pytorch-lightning/issues/10345)) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index ad01b44ef30f4..881a663fdb9e5 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -14,6 +14,7 @@ import functools import inspect from contextlib import contextmanager +from itertools import chain from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union import torch @@ -109,7 +110,7 @@ def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None: params = dict(inspect.signature(module._old_init).parameters) params.pop("args") params.pop("kwargs") - for init_name, init_arg in zip(params, args): + for init_name, init_arg in chain(zip(params, args), kwargs.items()): setattr(module, init_name, init_arg) f(module, *args, **kwargs) @@ -118,15 +119,15 @@ def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None: # https://stackoverflow.com/a/63851681/9201239 def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: - subclass_list = [] + subclasses = set() def recurse(cl: Type[Any]) -> None: for subclass in cl.__subclasses__(): - subclass_list.append(subclass) + subclasses.add(subclass) recurse(subclass) recurse(cls) - return set(subclass_list) + return subclasses def _enable_class(cls: Type[Any]) -> None: diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 8eac30f9cf823..b563e56e2fdec 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -164,6 +164,34 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 +def test_setup_dataloaders_with_custom_type(): + """Test that Lite intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as + attributes.""" + + class DataLoaderSubclass1(DataLoader): + def __init__(self, attribute1, *args, **kwargs): + # intentionally not setting this attribute, calling super with different args + # self.attribute1 = attribute1 + super().__init__(*args, **kwargs) + + class DataLoaderSubclass2(DataLoaderSubclass1): + def __init__(self, attribute1, attribute2, *args, **kwargs): + # intentionally not setting this attribute, calling super with different args + # self.attribute2 = attribute2 + super().__init__(attribute1, *args, **kwargs) + + class LiteWithCustomDataLoader(LightningLite): + def run(self): + dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2) + assert dataloader.attribute1 == "attribute1" + assert dataloader.attribute2 == "attribute2" + lite_dataloader = self.setup_dataloaders(dataloader) + assert lite_dataloader.attribute1 == "attribute1" + assert lite_dataloader.attribute2 == "attribute2" + + LiteWithCustomDataLoader().run() + + def test_setup_custom_dataloaders(): """Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader.""" lite = EmptyLite()