
Commit b7a4448

ananthsub and awaelchli authored

Remove model.trainer call inside of dataloading mixin (#7317)

* Update data_loading.py
* Update data_loading.py

Co-authored-by: Adrian Wälchli <[email protected]>

1 parent 78a6fd5 · commit b7a4448

File tree

1 file changed: +3 −3 lines changed


pytorch_lightning/trainer/data_loading.py

Lines changed: 3 additions & 3 deletions
@@ -17,7 +17,7 @@
 from abc import ABC
 from copy import deepcopy
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -56,6 +56,7 @@ class TrainerDataLoadingMixin(ABC):
     accelerator: Accelerator
     accelerator_connector: AcceleratorConnector
     dev_debugger: InternalDebugger
+    call_hook: Callable

     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         if not isinstance(dataloader, DataLoader):
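
The added `call_hook: Callable` line is an annotation-only attribute: no value is assigned, so the mixin defines nothing at runtime; it simply declares to type checkers that whatever class composes the mixin will provide `call_hook`. A minimal sketch of the pattern (the `HookMixin` and `run_stage_hook` names here are hypothetical, not Lightning's):

```python
from abc import ABC
from typing import Callable


class HookMixin(ABC):
    # Annotation only: no value is assigned, so the mixin creates no
    # class attribute. It merely declares that the composing class
    # supplies `call_hook`, letting type checkers accept the call below.
    call_hook: Callable

    def run_stage_hook(self, stage: str) -> None:
        # Resolved at runtime on the concrete class mixing this in.
        self.call_hook(f"on_{stage}_dataloader")


class Trainer(HookMixin):
    def call_hook(self, hook_name: str) -> None:
        print(f"calling {hook_name}")


Trainer().run_stage_hook("train")  # prints: calling on_train_dataloader
```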
@@ -437,8 +438,7 @@ def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader:
         Returns:
             The dataloader
         """
-        if model.trainer is not None:
-            model.trainer.call_hook(f"on_{stage}_dataloader")
+        self.call_hook(f"on_{stage}_dataloader")
         dataloader: DataLoader = getattr(model, f'{stage}_dataloader')()
         dataloader = self._flatten_dl_only(dataloader)
         self.accelerator.barrier('get_dataloaders')
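
The deleted guard existed because `model.trainer` could be `None` before the model was attached to a trainer, in which case the hook was silently skipped. Since `TrainerDataLoadingMixin` is always composed into the `Trainer`, `self.call_hook` is always available and the guard becomes unnecessary. A rough, self-contained sketch of the resulting control flow (`DemoModel` and `DemoTrainer` are stand-ins, not the real Lightning classes):

```python
class DemoModel:
    # Stand-in for a LightningModule defining a dataloader and its hook.
    def train_dataloader(self):
        return [1, 2, 3]

    def on_train_dataloader(self):
        print("on_train_dataloader hook ran")


class DemoTrainer:
    def call_hook(self, hook_name: str) -> None:
        # Look up and invoke the hook on the model by name, mirroring
        # the f-string dispatch used in request_dataloader.
        getattr(self.model, hook_name)()

    def request_dataloader(self, model, stage: str):
        self.model = model
        self.call_hook(f"on_{stage}_dataloader")        # -> on_train_dataloader
        return getattr(model, f"{stage}_dataloader")()  # -> train_dataloader()


loader = DemoTrainer().request_dataloader(DemoModel(), "train")
print(loader)  # [1, 2, 3]
```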
