Skip to content

Commit 29ed451

Browse files
1 parent c33df26 commit 29ed451

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
186186

187187
- Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
188188

189+
<<<<<<< HEAD
189190

190191
- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
192+
=======
193+
>>>>>>> parent of 6429de894 (Add support for `len(datamodule)` (#9895))
191194
192195

193196
- Added `enable_model_summary` flag to Trainer ([#9699](https://github.com/PyTorchLightning/pytorch-lightning/pull/9699))

pytorch_lightning/core/datamodule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
2323
from pytorch_lightning.core.mixins import HyperparametersMixin
2424
from pytorch_lightning.utilities import rank_zero_deprecation
25-
from pytorch_lightning.utilities.apply_func import apply_to_collection
2625
from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
27-
from pytorch_lightning.utilities.data import has_len
28-
from pytorch_lightning.utilities.warnings import rank_zero_warn
2926

3027

3128
class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -484,6 +481,7 @@ def __getstate__(self) -> dict:
484481
for fn in ("prepare_data", "setup", "teardown"):
485482
del d[fn]
486483
return d
484+
<<<<<<< HEAD
487485

488486
def __len__(self) -> int:
489487
"""Returns the total number of batches in all dataloaders defined in the datamodule."""
@@ -521,3 +519,5 @@ def get_num_batches(dataloader: DataLoader, name: str) -> None:
521519
rank_zero_warn("You datamodule does not have any valid dataloader so `__len__` will be returned as 0.")
522520

523521
return num_batches
522+
=======
523+
>>>>>>> parent of 6429de894 (Add support for `len(datamodule)` (#9895))

tests/core/test_datamodules.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,13 @@
2121
import pytest
2222
import torch
2323
from omegaconf import OmegaConf
24-
from torch.utils.data import DataLoader
2524

2625
from pytorch_lightning import LightningDataModule, Trainer
2726
from pytorch_lightning.callbacks import ModelCheckpoint
28-
from pytorch_lightning.trainer.supporters import CombinedLoader
2927
from pytorch_lightning.utilities import AttributeDict
3028
from pytorch_lightning.utilities.exceptions import MisconfigurationException
3129
from pytorch_lightning.utilities.model_helpers import is_overridden
32-
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
30+
from tests.helpers import BoringDataModule, BoringModel
3331
from tests.helpers.datamodules import ClassifDataModule
3432
from tests.helpers.runif import RunIf
3533
from tests.helpers.simple_models import ClassificationModel
@@ -566,14 +564,13 @@ class BoringDataModule1(LightningDataModule):
566564
batch_size: int
567565
dims: int = 2
568566

569-
def train_dataloader(self):
570-
return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
567+
def __post_init__(self):
568+
super().__init__(dims=self.dims)
571569

572570
# asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
573571
# __repr__, __eq__, __lt__, __le__, etc.
574572
assert BoringDataModule1(batch_size=64).dims == 2
575573
assert BoringDataModule1(batch_size=32)
576-
assert len(BoringDataModule1(batch_size=32)) == 2
577574
assert hasattr(BoringDataModule1, "__repr__")
578575
assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
579576

@@ -584,9 +581,7 @@ class BoringDataModule2(LightningDataModule):
584581

585582
# asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
586583
# __init__, __repr__, __eq__, __lt__, __le__, etc.
587-
assert BoringDataModule2(batch_size=32) is not None
588-
assert BoringDataModule2(batch_size=32).batch_size == 32
589-
assert len(BoringDataModule2(batch_size=32)) == 0
584+
assert BoringDataModule2(batch_size=32)
590585
assert hasattr(BoringDataModule2, "__repr__")
591586
assert BoringDataModule2(batch_size=32).prepare_data() is None
592587
assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)

0 commit comments

Comments
 (0)