
Commit 6429de8

Add support for len(datamodule) (#9895)

Authored by kingyiusuen and tchaton
Co-authored-by: tchaton <[email protected]>
1 parent 16213b1, commit 6429de8

3 files changed: +113 -4 lines changed

CHANGELOG.md (1 addition & 0 deletions)

```diff
@@ -186,6 +186,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
 
+- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
 
 - Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
 
```
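The new `__len__` makes `len(dm)` report the total number of batches across every dataloader the module defines. A minimal sketch of the behavior (the `MyDataModule` name and dataset shapes are illustrative, not from the commit):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):  # hypothetical example module
    def train_dataloader(self):
        # 64 samples with batch_size=8 -> 8 batches
        return DataLoader(TensorDataset(torch.randn(64, 10)), batch_size=8)


# val/test/predict dataloaders are not overridden; their defaults raise
# NotImplementedError, so the count simply skips them.
dm = MyDataModule()
print(len(dm))  # 8
```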

pytorch_lightning/core/datamodule.py (37 additions & 0 deletions)

```diff
@@ -22,7 +22,10 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
+from pytorch_lightning.utilities.data import has_len
+from pytorch_lightning.utilities.warnings import rank_zero_warn
 
 
 class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -481,3 +484,37 @@ def __getstate__(self) -> dict:
         for fn in ("prepare_data", "setup", "teardown"):
             del d[fn]
         return d
+
+    def __len__(self) -> int:
+        """Returns the total number of batches in all dataloaders defined in the datamodule."""
+
+        from pytorch_lightning.trainer.supporters import CombinedLoader
+
+        num_batches = 0
+        not_implemented_count = 0
+
+        def get_num_batches(dataloader: DataLoader, name: str) -> None:
+            nonlocal num_batches
+            if not has_len(dataloader):
+                rank_zero_warn(
+                    f"The number of batches for a dataloader in `{name}` is counted as 0 "
+                    "because it does not have `__len__` defined."
+                )
+            else:
+                num_batches += len(dataloader)
+
+        for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
+            dataloader_method = getattr(self, method_name)
+            try:
+                dataloader = dataloader_method()
+            except NotImplementedError:
+                not_implemented_count += 1
+                continue
+            if isinstance(dataloader, CombinedLoader):
+                dataloader = dataloader.loaders
+            apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
+
+        if not_implemented_count == 4:
+            rank_zero_warn("Your datamodule does not have any valid dataloader, so `__len__` will return 0.")
+
+        return num_batches
```
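Not part of the commit, but a standalone sketch of the traversal trick the method relies on: `apply_to_collection` applies a function to every `DataLoader` found in an arbitrarily nested list/dict structure, which is how `__len__` copes with the different container shapes a `*_dataloader` method may return (the collection below mirrors one of the parametrized test cases):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning.utilities.apply_func import apply_to_collection

# batch_size defaults to 1, so this loader yields 32 batches
loader = DataLoader(TensorDataset(torch.randn(32, 1)))

batch_counts = []

def count(dl: DataLoader) -> None:
    # called once per DataLoader found anywhere in the collection
    batch_counts.append(len(dl))

apply_to_collection({"foo": [loader], "bar": [loader, loader]}, DataLoader, count)
print(sum(batch_counts))  # 96
```

A `CombinedLoader` is not a plain collection, which is why `__len__` first unwraps it via `.loaders` before handing it to `apply_to_collection`.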

tests/core/test_datamodules.py (75 additions & 4 deletions)

```diff
@@ -21,13 +21,15 @@
 import pytest
 import torch
 from omegaconf import OmegaConf
+from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
@@ -564,13 +566,14 @@ class BoringDataModule1(LightningDataModule):
         batch_size: int
         dims: int = 2
 
-        def __post_init__(self):
-            super().__init__(dims=self.dims)
+        def train_dataloader(self):
+            return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
 
     # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
     # __repr__, __eq__, __lt__, __le__, etc.
     assert BoringDataModule1(batch_size=64).dims == 2
     assert BoringDataModule1(batch_size=32)
+    assert len(BoringDataModule1(batch_size=32)) == 2
     assert hasattr(BoringDataModule1, "__repr__")
     assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
 
@@ -581,7 +584,9 @@ class BoringDataModule2(LightningDataModule):
 
     # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
     # __init__, __repr__, __eq__, __lt__, __le__, etc.
-    assert BoringDataModule2(batch_size=32)
+    assert BoringDataModule2(batch_size=32) is not None
+    assert BoringDataModule2(batch_size=32).batch_size == 32
+    assert len(BoringDataModule2(batch_size=32)) == 0
     assert hasattr(BoringDataModule2, "__repr__")
     assert BoringDataModule2(batch_size=32).prepare_data() is None
     assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
@@ -625,3 +630,69 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
     trainer.model = model
     trainer.datamodule = dm
     trainer.data_connector.prepare_data()
+
+
+DATALOADER = DataLoader(RandomDataset(1, 32))
+
+
+@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
+@pytest.mark.parametrize(
+    ["dataloader", "expected"],
+    [
+        [DATALOADER, 32],
+        [[DATALOADER, DATALOADER], 64],
+        [[[DATALOADER], [DATALOADER, DATALOADER]], 96],
+        [[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
+        [{"foo": DATALOADER, "bar": DATALOADER}, 64],
+        [{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
+        [{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
+        [CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
+    ],
+)
+def test_len_different_types(method_name, dataloader, expected):
+    dm = LightningDataModule()
+    setattr(dm, method_name, lambda: dataloader)
+    assert len(dm) == expected
+
+
+@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
+def test_len_dataloader_no_len(method_name):
+    class CustomNotImplementedErrorDataloader(DataLoader):
+        def __len__(self):
+            raise NotImplementedError
+
+    dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
+    dm = LightningDataModule()
+    setattr(dm, method_name, lambda: dataloader)
+    with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
+        assert len(dm) == 0
+
+
+def test_len_all_dataloader_methods_implemented():
+    class BoringDataModule(LightningDataModule):
+        def __init__(self, dataloader):
+            super().__init__()
+            self.dataloader = dataloader
+
+        def train_dataloader(self):
+            return {"foo": self.dataloader, "bar": self.dataloader}
+
+        def val_dataloader(self):
+            return self.dataloader
+
+        def test_dataloader(self):
+            return [self.dataloader]
+
+        def predict_dataloader(self):
+            return [self.dataloader, self.dataloader]
+
+    dm = BoringDataModule(DATALOADER)
+
+    # 6 dataloaders each producing 32 batches: 6 * 32 = 192
+    assert len(dm) == 192
+
+
+def test_len_no_dataloader_methods_implemented():
+    dm = LightningDataModule()
+    with pytest.warns(UserWarning, match="Your datamodule does not have any valid dataloader"):
+        assert len(dm) == 0
```
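For completeness, a hedged sketch of the zero-count path exercised by `test_len_dataloader_no_len`, here with an `IterableDataset`-backed loader instead of the custom class above. The `Stream` class is hypothetical; this assumes `has_len` returns False for such a loader because calling `len()` on it raises `TypeError`:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

from pytorch_lightning import LightningDataModule


class Stream(IterableDataset):  # hypothetical, for illustration only
    """A streaming-style dataset that defines no __len__."""

    def __iter__(self):
        return iter(torch.randn(32, 1))


dm = LightningDataModule()
dm.train_dataloader = lambda: DataLoader(Stream())
# Warns that this dataloader is counted as 0 because it has no __len__.
print(len(dm))  # 0
```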
