Skip to content

Fix DataLoader inspection and re-instantiation in Lite #10334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 5, 2021
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `apply_to_collection(defaultdict)` ([#10316](https://github.com/PyTorchLightning/pytorch-lightning/issues/10316))


-
- Fixed interception of ``__init__`` arguments for sub-classed DataLoader re-instantiation in Lite ([#10334](https://github.com/PyTorchLightning/pytorch-lightning/issues/10334))


-
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import functools
import inspect
from contextlib import contextmanager
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, Optional, Set, Sized, Type, Union

import torch
Expand Down Expand Up @@ -109,7 +110,7 @@ def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
params = dict(inspect.signature(module._old_init).parameters)
params.pop("args")
params.pop("kwargs")
for init_name, init_arg in zip(params, args):
for init_name, init_arg in chain(zip(params, args), kwargs.items()):
setattr(module, init_name, init_arg)
f(module, *args, **kwargs)

Expand Down
28 changes: 28 additions & 0 deletions tests/lite/test_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,34 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


def test_setup_dataloaders_with_custom_type():
"""Test that Lite intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as
attributes."""

class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute1 = attribute1
super().__init__(*args, **kwargs)

class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute1, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
super().__init__(attribute1, *args, **kwargs)

class LiteWithCustomDataLoader(LightningLite):
def run(self):
dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"
assert dataloader.attribute2 == "attribute2"
lite_dataloader = self.setup_dataloaders(dataloader)
assert lite_dataloader.attribute1 == "attribute1"
assert lite_dataloader.attribute2 == "attribute2"

LiteWithCustomDataLoader().run()


def test_setup_custom_dataloaders():
"""Test that the setup_dataloaders method returns the dataloaders wrapped as LiteDataLoader."""
lite = EmptyLite()
Expand Down