Skip to content

Commit c409ae4

Browse files
Revert "Add support for len(datamodule) (Lightning-AI#9895)"
This reverts commit 6429de8.
1 parent c33df26 commit c409ae4

File tree

3 files changed

+4
-125
lines changed

3 files changed

+4
-125
lines changed

CHANGELOG.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
187187
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
188188

189189

190-
- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
191-
192-
193190
- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))
194191

195192

pytorch_lightning/core/datamodule.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
2323
from pytorch_lightning.core.mixins import HyperparametersMixin
2424
from pytorch_lightning.utilities import rank_zero_deprecation
25-
from pytorch_lightning.utilities.apply_func import apply_to_collection
2625
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
27-
from pytorch_lightning.utilities.data import has_len
28-
from pytorch_lightning.utilities.warnings import rank_zero_warn
2926

3027

3128
class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -484,40 +481,3 @@ def __getstate__(self) -> dict:
484481
for fn in ("prepare_data", "setup", "teardown"):
485482
del d[fn]
486483
return d
487-
488-
def __len__(self) -> int:
489-
"""Returns the total number of batches in all dataloaders defined in the datamodule."""
490-
491-
from pytorch_lightning.trainer.supporters import CombinedLoader
492-
493-
num_batches = 0
494-
not_implemented_count = 0
495-
496-
def get_num_batches(dataloader: DataLoader, name: str) -> None:
497-
nonlocal num_batches
498-
if not has_len(dataloader):
499-
rank_zero_warn(
500-
f"The number of batches for a dataloader in `{name}` is counted as 0 "
501-
"because it does not have `__len__` defined."
502-
)
503-
else:
504-
num_batches += len(dataloader)
505-
506-
for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
507-
dataloader_method = getattr(self, method_name)
508-
if not callable(dataloader_method):
509-
not_implemented_count += 1
510-
continue
511-
try:
512-
dataloader = dataloader_method()
513-
except NotImplementedError:
514-
not_implemented_count += 1
515-
continue
516-
if isinstance(dataloader, CombinedLoader):
517-
dataloader = dataloader.loaders
518-
apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
519-
520-
if not_implemented_count == 4:
521-
rank_zero_warn("Your datamodule does not have any valid dataloader so `__len__` will be returned as 0.")
522-
523-
return num_batches

tests/core/test_datamodules.py

Lines changed: 4 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,13 @@
2121
import pytest
2222
import torch
2323
from omegaconf import OmegaConf
24-
from torch.utils.data import DataLoader
2524

2625
from pytorch_lightning import LightningDataModule, Trainer
2726
from pytorch_lightning.callbacks import ModelCheckpoint
28-
from pytorch_lightning.trainer.supporters import CombinedLoader
2927
from pytorch_lightning.utilities import AttributeDict
3028
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3129
from pytorch_lightning.utilities.model_helpers import is_overridden
32-
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
30+
from tests.helpers import BoringDataModule, BoringModel
3331
from tests.helpers.datamodules import ClassifDataModule
3432
from tests.helpers.runif import RunIf
3533
from tests.helpers.simple_models import ClassificationModel
@@ -566,14 +564,13 @@ class BoringDataModule1(LightningDataModule):
566564
batch_size: int
567565
dims: int = 2
568566

569-
def train_dataloader(self):
570-
return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
567+
def __post_init__(self):
568+
super().__init__(dims=self.dims)
571569

572570
# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
573571
# __repr__, __eq__, __lt__, __le__, etc.
574572
assert BoringDataModule1(batch_size=64).dims == 2
575573
assert BoringDataModule1(batch_size=32)
576-
assert len(BoringDataModule1(batch_size=32)) == 2
577574
assert hasattr(BoringDataModule1, "__repr__")
578575
assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
579576

@@ -584,9 +581,7 @@ class BoringDataModule2(LightningDataModule):
584581

585582
# asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
586583
# __init__, __repr__, __eq__, __lt__, __le__, etc.
587-
assert BoringDataModule2(batch_size=32) is not None
588-
assert BoringDataModule2(batch_size=32).batch_size == 32
589-
assert len(BoringDataModule2(batch_size=32)) == 0
584+
assert BoringDataModule2(batch_size=32)
590585
assert hasattr(BoringDataModule2, "__repr__")
591586
assert BoringDataModule2(batch_size=32).prepare_data() is None
592587
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
@@ -630,76 +625,3 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
630625
trainer.model = model
631626
trainer.datamodule = dm
632627
trainer._data_connector.prepare_data()
633-
634-
635-
DATALOADER = DataLoader(RandomDataset(1, 32))
636-
637-
638-
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
639-
@pytest.mark.parametrize(
640-
["dataloader", "expected"],
641-
[
642-
[DATALOADER, 32],
643-
[[DATALOADER, DATALOADER], 64],
644-
[[[DATALOADER], [DATALOADER, DATALOADER]], 96],
645-
[[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
646-
[{"foo": DATALOADER, "bar": DATALOADER}, 64],
647-
[{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
648-
[{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
649-
[CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
650-
],
651-
)
652-
def test_len_different_types(method_name, dataloader, expected):
653-
dm = LightningDataModule()
654-
setattr(dm, method_name, lambda: dataloader)
655-
assert len(dm) == expected
656-
657-
658-
@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
659-
def test_len_dataloader_no_len(method_name):
660-
class CustomNotImplementedErrorDataloader(DataLoader):
661-
def __len__(self):
662-
raise NotImplementedError
663-
664-
dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
665-
dm = LightningDataModule()
666-
setattr(dm, method_name, lambda: dataloader)
667-
with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
668-
assert len(dm) == 0
669-
670-
671-
def test_len_all_dataloader_methods_implemented():
672-
class BoringDataModule(LightningDataModule):
673-
def __init__(self, dataloader):
674-
super().__init__()
675-
self.dataloader = dataloader
676-
677-
def train_dataloader(self):
678-
return {"foo": self.dataloader, "bar": self.dataloader}
679-
680-
def val_dataloader(self):
681-
return self.dataloader
682-
683-
def test_dataloader(self):
684-
return [self.dataloader]
685-
686-
def predict_dataloader(self):
687-
return [self.dataloader, self.dataloader]
688-
689-
dm = BoringDataModule(DATALOADER)
690-
691-
# 6 dataloaders each producing 32 batches: 6 * 32 = 192
692-
assert len(dm) == 192
693-
694-
695-
def test_len_no_dataloader_methods_implemented():
696-
dm = LightningDataModule()
697-
with pytest.warns(UserWarning, match="Your datamodule does not have any valid dataloader"):
698-
assert len(dm) == 0
699-
700-
dm.train_dataloader = None
701-
dm.val_dataloader = None
702-
dm.test_dataloader = None
703-
dm.predict_dataloader = None
704-
with pytest.warns(UserWarning, match="Your datamodule does not have any valid dataloader"):
705-
assert len(dm) == 0

0 commit comments

Comments
 (0)