Commit 5b59c95

Authors: daniellepintz, rohitgr7, akihironitta, carmocca
Deprecate TrainerDataLoadingMixin and move logic to DataConnector (#11282)
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Aki Nitta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent c0726ba commit 5b59c95

File tree: 14 files changed, +529 −485 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -241,6 +241,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))


+- Deprecated `TrainerDataLoadingMixin` and moved functionality to `Trainer` and `DataConnector` ([#11282](https://github.com/PyTorchLightning/pytorch-lightning/pull/11282))
+
+
 - Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))


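For context, a rough migration sketch for downstream code (illustrative only; it assumes the private `DataConnector` replacements exercised by the updated tests below, and that the deprecated Trainer methods keep working until v1.8 as the TODO comments and deprecation messages in this commit state):

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import Trainer

loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
trainer = Trainer(max_epochs=1)

# Deprecated path: still available until v1.8, now emits a deprecation warning.
prepared = trainer.prepare_dataloader(loader, shuffle=False)

# Replacement used throughout the updated tests (note: a private API).
prepared = trainer._data_connector._prepare_dataloader(loader, shuffle=False)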
pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 1 addition & 1 deletion
@@ -188,7 +188,7 @@ def _reload_evaluation_dataloaders(self) -> None:
         """Reloads dataloaders if necessary."""
         if self.trainer.testing:
             self.trainer.reset_test_dataloader()
-        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl:
+        elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl:
             self.trainer.reset_val_dataloader()

     def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:

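Only the attribute's home changes here: `_should_reload_val_dl` moves from the `Trainer` to its `DataConnector`. The property body sits in the un-rendered data_connector.py diff; judging from how `_last_val_dl_reload_epoch` is maintained in the new `Trainer.reset_val_dataloader` further down, it presumably looks roughly like this (a sketch, not the verbatim code):

# Hypothetical sketch of the relocated property; the attribute names come from
# elsewhere in this commit, but the exact expression is an assumption.
@property
def _should_reload_val_dl(self) -> bool:
    """Check if the validation dataloader should be reloaded in the current epoch."""
    n_epochs = self.trainer.reload_dataloaders_every_n_epochs
    return n_epochs and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs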
pytorch_lightning/loops/fit_loop.py

Lines changed: 2 additions & 1 deletion
@@ -203,7 +203,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
         model = self.trainer.lightning_module

         # reset train dataloader
-        if not self._is_fresh_start_epoch and self.trainer._should_reload_train_dl:
+        if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
             self.trainer.reset_train_dataloader(model)
         self._is_fresh_start_epoch = False

@@ -223,6 +223,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]

     def advance(self) -> None:  # type: ignore[override]
         """Runs one whole epoch."""
+        assert self.trainer.train_dataloader is not None
         dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
         data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

pytorch_lightning/strategies/ipu.py

Lines changed: 3 additions & 3 deletions
@@ -121,8 +121,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
         # patch the dataloader creation function with the custom `poptorch.DataLoader`.
         # this violates the intended control flow for the plugins, but since this is experimental, we have chosen
         # to use the simpler solution before adding abstractions to override the `DataLoader` class
-        self._update_dataloader_original = pl.trainer.data_loading._update_dataloader
-        pl.trainer.data_loading._update_dataloader = self._convert_to_poptorch_loader
+        self._update_dataloader_original = pl.trainer.connectors.data_connector._update_dataloader
+        pl.trainer.connectors.data_connector._update_dataloader = self._convert_to_poptorch_loader

         super().setup(trainer)

@@ -278,7 +278,7 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
     def teardown(self) -> None:
         super().teardown()
         # undo dataloader patching
-        pl.trainer.data_loading._update_dataloader = self._update_dataloader_original
+        pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original

         for model in self.poptorch_models.values():
             model.destroy()

pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 307 additions & 3 deletions
Large diffs are not rendered by default.
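The 307 added lines are not rendered, but the call sites elsewhere in this commit imply roughly the following new surface on `DataConnector`. Signatures are inferred from usage in trainer.py, the loops, and the tests; this is a sketch, not the actual file:

# Inferred from call sites in trainer.py, fit_loop.py, evaluation_loop.py and the tests.
# Signatures and return types are best guesses, not copied from the un-rendered diff.
from typing import Any, List, Optional, Tuple, Union

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage


class DataConnector:
    # Also referenced in this commit: `multiple_trainloader_mode` and the
    # `_val/_test/_predict_dataloader_source` objects exposing `.is_defined()`.

    @property
    def _should_reload_train_dl(self) -> bool: ...

    @property
    def _should_reload_val_dl(self) -> bool: ...

    def _request_dataloader(
        self, stage: RunningStage, model: Optional["pl.LightningModule"] = None
    ) -> Union[DataLoader, List[DataLoader]]: ...

    def _prepare_dataloader(
        self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None
    ) -> Any: ...

    def _resolve_overfit_batches(self, dataloader: Any) -> Any: ...

    def _worker_check(self, dataloader: DataLoader, name: str) -> None: ...

    def _reset_eval_dataloader(
        self, mode: RunningStage, model: Optional["pl.LightningModule"] = None
    ) -> Tuple[List[Union[int, float]], List[DataLoader]]: ...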

pytorch_lightning/trainer/data_loading.py

Lines changed: 26 additions & 457 deletions
Large diffs are not rendered by default.
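Also un-rendered; the 26 lines that remain presumably reduce `TrainerDataLoadingMixin` to a deprecation shim whose warning text matches the assertions in tests/deprecated_api/test_remove_1-8.py below. An assumed sketch of that shape (the delegation targets are a guess):

from typing import Any, Optional

import pytorch_lightning as pl
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_deprecation


class TrainerDataLoadingMixin:
    def prepare_dataloader(self, dataloader: Any, shuffle: bool, mode: Optional[RunningStage] = None) -> Any:
        # Warning text matches the deprecation tests added in this commit.
        rank_zero_deprecation(
            "`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 and will be removed in v1.8."
        )
        return self._data_connector._prepare_dataloader(dataloader, shuffle, mode)

    def request_dataloader(self, stage: RunningStage, model: Optional["pl.LightningModule"] = None) -> Any:
        rank_zero_deprecation(
            "`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8."
        )
        return self._data_connector._request_dataloader(stage, model=model)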

pytorch_lightning/trainer/trainer.py

Lines changed: 153 additions & 3 deletions
@@ -27,6 +27,7 @@
 import torch
 from packaging.version import Version
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader

 import pytorch_lightning as pl
 from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
@@ -70,6 +71,7 @@
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
 from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.tuner.lr_finder import _LRFinder
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities import (
@@ -85,14 +87,17 @@
     rank_zero_info,
     rank_zero_warn,
 )
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.argparse import (
     _defaults_from_env_vars,
     add_argparse_args,
     from_argparse_args,
     parse_argparser,
     parse_env_variables,
 )
+from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
 from pytorch_lightning.utilities.cloud_io import get_filesystem
+from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
@@ -119,9 +124,9 @@


 class Trainer(
-    TrainerCallbackHookMixin,
-    TrainerOptimizersMixin,
-    TrainerDataLoadingMixin,
+    TrainerCallbackHookMixin,  # TODO: Remove in v1.8
+    TrainerOptimizersMixin,  # TODO: Remove in v1.8
+    TrainerDataLoadingMixin,  # TODO: Remove in v1.8
 ):
     # Needed because of LightningOptimizer
     _lightning_optimizers = None
@@ -2372,6 +2377,151 @@ def terminate_on_nan(self, val: bool) -> None:
         )
         self._terminate_on_nan = val  # : 212

+    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
+        etc.).
+
+        Args:
+            model: The ``LightningModule`` if calling this outside of the trainer scope.
+        """
+        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
+
+        if self.overfit_batches > 0:
+            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
+
+        # automatically add samplers
+        self.train_dataloader = apply_to_collection(
+            self.train_dataloader,
+            DataLoader,
+            self._data_connector._prepare_dataloader,
+            shuffle=True,
+            mode=RunningStage.TRAINING,
+        )
+
+        # check the workers recursively
+        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
+
+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
+
+        # add collate_fn to collect metadata for fault tolerant training
+        if _fault_tolerant_training():
+            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
+
+        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
+
+        module = model or self.lightning_module or self.datamodule
+        self.num_training_batches = (
+            len(self.train_dataloader)
+            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
+            else float("inf")
+        )
+
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
+        elif self.num_training_batches != float("inf"):
+            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+        elif self.limit_train_batches != 1.0:
+            raise MisconfigurationException(
+                "When using an IterableDataset for `limit_train_batches`,"
+                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
+                " `num_training_batches` to use."
+            )
+
+        # determine when to check validation
+        # if int passed in, val checks that often
+        # otherwise, it checks in [0, 1.0] % range of a training epoch
+        if isinstance(self.val_check_interval, int):
+            self.val_check_batch = self.val_check_interval
+            if self.val_check_batch > self.num_training_batches:
+                raise ValueError(
+                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
+                    f"to the number of the training batches ({self.num_training_batches}). "
+                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                )
+        else:
+            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float("inf")
+                else:
+                    raise MisconfigurationException(
+                        "When using an IterableDataset for `train_dataloader`,"
+                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                        " checking validation every k training batches."
+                    )
+            else:
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)
+
+        if self.logger and self.num_training_batches < self.log_every_n_steps:
+            rank_zero_warn(
+                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
+                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
+                " you want to see logs for the training epoch.",
+                category=PossibleUserWarning,
+            )
+
+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_train_dl_reload_epoch = self.current_epoch
+
+    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the validation dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._val_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("validation_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.VALIDATING, model=pl_module
+            )
+
+            # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+            self._last_val_dl_reload_epoch = self.current_epoch
+
+    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the test dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._test_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("test_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.TESTING, model=pl_module
+            )
+
+    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the predict dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._predict_dataloader_source
+        pl_module = self.lightning_module or model
+        if source.is_defined():
+            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.PREDICTING, model=pl_module
+            )
+
+    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets train and val dataloaders if none are attached to the trainer.
+
+        The val dataloader must be initialized before training loop starts, as the training loop
+        inspects the val dataloader to determine whether to run the evaluation loop.
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        if self.train_dataloader is None:
+            self.reset_train_dataloader(model=model)
+        if self.val_dataloaders is None:
+            self.reset_val_dataloader(model=model)
+

 def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1:
tests/deprecated_api/test_remove_1-8.py

Lines changed: 18 additions & 1 deletion
@@ -31,11 +31,12 @@
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
 from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
+from pytorch_lightning.trainer.states import RunningStage
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.enums import DeviceType, DistributedType
 from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
-from tests.helpers.boring_model import BoringModel
+from tests.helpers.boring_model import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
 from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator

@@ -271,6 +272,22 @@ def test_v1_8_0_deprecated_training_type_plugin_property():
     trainer.training_type_plugin


+def test_v1_8_0_deprecate_trainer_data_loading_mixin():
+    trainer = Trainer(max_epochs=1)
+    model = BoringModel()
+    dm = BoringDataModule()
+    trainer.fit(model, datamodule=dm)
+
+    with pytest.deprecated_call(
+        match=r"`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
+    ):
+        trainer.prepare_dataloader(dataloader=model.train_dataloader, shuffle=False)
+    with pytest.deprecated_call(
+        match=r"`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
+    ):
+        trainer.request_dataloader(stage=RunningStage.TRAINING)
+
+
 def test_v_1_8_0_deprecated_device_stats_monitor_prefix_metric_keys():
     from pytorch_lightning.callbacks.device_stats_monitor import prefix_metric_keys

tests/trainer/flags/test_limit_batches.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def test_eval_limit_batches(stage, mode, limit_batches):
     trainer = Trainer(**{limit_eval_batches: limit_batches})
     model.trainer = trainer
     trainer._data_connector.attach_dataloaders(model)
-    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
+    loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model)
     expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
     assert loader_num_batches[0] == expected_batches
     assert len(dataloaders[0]) == len(eval_loader)

tests/trainer/flags/test_overfit_batches.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
     model.trainer = trainer
     trainer._data_connector.attach_datamodule(model, datamodule=dm)

-    loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
+    loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model)
     if stage == RunningStage.VALIDATING:
         assert loader_num_batches[0] == 0
     else:

tests/trainer/test_data_loading.py

Lines changed: 6 additions & 6 deletions
@@ -248,17 +248,17 @@ def __init__(
     class CustomDummyObj:
         sampler = None

-    result = trainer.prepare_dataloader(CustomDummyObj(), shuffle=True)
+    result = trainer._data_connector._prepare_dataloader(CustomDummyObj(), shuffle=True)
     assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

     dataset = list(range(10))
-    result = trainer.prepare_dataloader(CustomDataLoader(dataset), shuffle=True)
+    result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset), shuffle=True)
     assert isinstance(result, DataLoader)
     assert isinstance(result, CustomDataLoader)
     assert result.dummy_kwarg is None

     # Shuffled DataLoader should also work
-    result = trainer.prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True)
+    result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True)
     assert isinstance(result, DataLoader)
     assert isinstance(result, CustomDataLoader)
     assert result.dummy_kwarg is None
@@ -269,7 +269,7 @@ class CustomSampler(Sampler):
     # Should raise an error if existing sampler is being replaced
     dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
     with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
-        trainer.prepare_dataloader(dataloader, shuffle=True)
+        trainer._data_connector._prepare_dataloader(dataloader, shuffle=True)


 class LoaderTestModel(BoringModel):
@@ -351,7 +351,7 @@ def test_error_raised_with_float_limited_eval_batches():
         MisconfigurationException,
         match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
     ):
-        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
+        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)


 @pytest.mark.parametrize(
@@ -375,4 +375,4 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
     model = BoringModel()
     trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
     with pytest.warns(PossibleUserWarning, match="recommended .* turn this off for val/test/predict"):
-        trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
+        trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)

tests/trainer/test_dataloaders.py

Lines changed: 2 additions & 2 deletions
@@ -563,7 +563,7 @@ def train_dataloader(self):
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
 @pytest.mark.parametrize("stage", ("train", "test", "val"))
-@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4)
+@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
 def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
     """Test that error is raised if dataloader with only a few workers is used."""

@@ -593,7 +593,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
 @RunIf(skip_windows=True)
 @pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
 @pytest.mark.parametrize("stage", ("train", "test", "val"))
-@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4)
+@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
 def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
     """Test that error is raised if dataloader with only a few workers is used."""

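The `@patch` targets simply follow the code they stub: the few-workers check (and its `multiprocessing` import) now lives in `pytorch_lightning.trainer.connectors.data_connector`, and `unittest.mock.patch` resolves its target string by traversing that module's attributes, so the old `pytorch_lightning.trainer.data_loading...` path would no longer point at the right place. A minimal illustration:

from unittest.mock import patch

# The target string is resolved attribute-by-attribute starting from the named module,
# so it must go through the module that actually holds the `multiprocessing` reference.
with patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4):
    ...  # code exercising the num_workers warning now sees cpu_count() == 4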