From 6f52056fa44276644ca24d7de2ec83dff97d9e8b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 14:10:27 -0700 Subject: [PATCH 1/2] Update data_loading.py --- pytorch_lightning/trainer/data_loading.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 846e267690e94..c1299b79081b1 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -56,6 +56,7 @@ class TrainerDataLoadingMixin(ABC): accelerator: Accelerator accelerator_connector: AcceleratorConnector dev_debugger: InternalDebugger + call_hook: Callable def _worker_check(self, dataloader: DataLoader, name: str) -> None: if not isinstance(dataloader, DataLoader): @@ -437,8 +438,7 @@ def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader: Returns: The dataloader """ - if model.trainer is not None: - model.trainer.call_hook(f"on_{stage}_dataloader") + self.call_hook(f"on_{stage}_dataloader") dataloader: DataLoader = getattr(model, f'{stage}_dataloader')() dataloader = self._flatten_dl_only(dataloader) self.accelerator.barrier('get_dataloaders') From 72c0919b76e04559ea19dc07dc0ce428a5569b0e Mon Sep 17 00:00:00 2001 From: ananthsub Date: Sun, 2 May 2021 14:17:32 -0700 Subject: [PATCH 2/2] Update data_loading.py --- pytorch_lightning/trainer/data_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index c1299b79081b1..6c66a063c9382 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -17,7 +17,7 @@ from abc import ABC from copy import deepcopy from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler