@@ -17,7 +17,7 @@
 from abc import ABC
 from copy import deepcopy
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -56,6 +56,7 @@ class TrainerDataLoadingMixin(ABC):
     accelerator: Accelerator
     accelerator_connector: AcceleratorConnector
     dev_debugger: InternalDebugger
+    call_hook: Callable
 
     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         if not isinstance(dataloader, DataLoader):
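The Callable import in the first hunk exists to support this new class-level annotation. The annotation creates no attribute at runtime; it declares, for static type checkers, that any concrete class composing TrainerDataLoadingMixin (in practice, the Trainer) provides a call_hook implementation. A minimal sketch of that pattern, with simplified, hypothetical names:

from abc import ABC
from typing import Any, Callable


class DataLoadingMixin(ABC):
    # Annotation only: no attribute exists at runtime. It tells type
    # checkers that whatever class mixes this in provides `call_hook`.
    call_hook: Callable

    def request_dataloader(self, stage: str) -> None:
        # Resolves to the implementation on the composed class.
        self.call_hook(f"on_{stage}_dataloader")


class Trainer(DataLoadingMixin):
    def call_hook(self, hook_name: str, *args: Any, **kwargs: Any) -> None:
        print(f"running hook: {hook_name}")


Trainer().request_dataloader("train")  # prints: running hook: on_train_dataloader

Because the mixin only ever executes as part of a fully constructed Trainer, self.call_hook resolves without any guard, which is what the next hunk exploits.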
@@ -437,8 +438,7 @@ def request_dataloader(self, model: LightningModule, stage: str) -> DataLoader:
         Returns:
             The dataloader
         """
-        if model.trainer is not None:
-            model.trainer.call_hook(f"on_{stage}_dataloader")
+        self.call_hook(f"on_{stage}_dataloader")
         dataloader: DataLoader = getattr(model, f'{stage}_dataloader')()
         dataloader = self._flatten_dl_only(dataloader)
         self.accelerator.barrier('get_dataloaders')
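This hunk fires the hook through self.call_hook instead of reaching through model.trainer, which lets the model.trainer is not None guard be dropped. The stage-specific dataloader is then resolved by name with getattr. A minimal, runnable sketch of that dispatch, using a hypothetical bare Model in place of a LightningModule:

import torch
from torch.utils.data import DataLoader, TensorDataset


class Model:
    # Hypothetical stand-in for a LightningModule with a per-stage method.
    def train_dataloader(self) -> DataLoader:
        return DataLoader(TensorDataset(torch.zeros(4, 2)), batch_size=2)


model = Model()
stage = "train"
# Same dispatch as request_dataloader: build the method name from the
# stage, look it up on the model, and call it to obtain the DataLoader.
dataloader: DataLoader = getattr(model, f"{stage}_dataloader")()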