
Deprecate TrainerDataLoadingMixin and move logic to DataConnector #11282

Merged: 18 commits, Jan 5, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -225,6 +225,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))


- Deprecated `TrainerDataLoadingMixin` and moved functionality to `Trainer` and `DataConnector` ([#11282](https://github.com/PyTorchLightning/pytorch-lightning/pull/11282))


- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))


2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -188,7 +188,7 @@ def _reload_evaluation_dataloaders(self) -> None:
"""Reloads dataloaders if necessary."""
if self.trainer.testing:
self.trainer.reset_test_dataloader()
elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl:
elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl:
self.trainer.reset_val_dataloader()

def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
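The reload flag consulted here now lives on the `DataConnector` rather than the `Trainer`. Since that diff is collapsed further down, here is a minimal sketch of the shape such a property plausibly has, derived from the `reload_dataloaders_every_n_epochs` option and the `_last_val_dl_reload_epoch` bookkeeping added to `trainer.py` below; the exact body is an assumption, not the rendered diff:

```python
class DataConnector:
    # Sketch only: assumed implementation of the reload flag used by the loops.
    def __init__(self, trainer):
        self.trainer = trainer

    @property
    def _should_reload_val_dl(self) -> bool:
        # Reload every N epochs, counted from the epoch recorded at the last reset.
        n_epochs = self.trainer.reload_dataloaders_every_n_epochs
        return bool(n_epochs) and self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs
```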
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -203,7 +203,7 @@ def on_advance_start(self) -> None: # type: ignore[override]
model = self.trainer.lightning_module

# reset train dataloader
if not self._is_fresh_start_epoch and self.trainer._should_reload_train_dl:
if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
self.trainer.reset_train_dataloader(model)
self._is_fresh_start_epoch = False

@@ -223,6 +223,7 @@ def on_advance_start(self) -> None: # type: ignore[override]

def advance(self) -> None: # type: ignore[override]
"""Runs one whole epoch."""
assert self.trainer.train_dataloader is not None
dataloader = self.trainer.strategy.process_dataloader(self.trainer.train_dataloader)
data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

6 changes: 3 additions & 3 deletions pytorch_lightning/strategies/ipu.py
@@ -121,8 +121,8 @@ def setup(self, trainer: "pl.Trainer") -> None:
# patch the dataloader creation function with the custom `poptorch.DataLoader`.
# this violates the intended control flow for the plugins, but since this is experimental, we have chosen
# to use the simpler solution before adding abstractions to override the `DataLoader` class
self._update_dataloader_original = pl.trainer.data_loading._update_dataloader
pl.trainer.data_loading._update_dataloader = self._convert_to_poptorch_loader
self._update_dataloader_original = pl.trainer.connectors.data_connector._update_dataloader
pl.trainer.connectors.data_connector._update_dataloader = self._convert_to_poptorch_loader

super().setup(trainer)

@@ -278,7 +278,7 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
def teardown(self) -> None:
super().teardown()
# undo dataloader patching
pl.trainer.data_loading._update_dataloader = self._update_dataloader_original
pl.trainer.connectors.data_connector._update_dataloader = self._update_dataloader_original

for model in self.poptorch_models.values():
model.destroy()
310 changes: 307 additions & 3 deletions pytorch_lightning/trainer/connectors/data_connector.py

Large diffs are not rendered by default.
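Because this diff is collapsed, here is a rough outline of the protected helpers that `DataConnector` gains in this PR, inferred from the call sites visible elsewhere on this page; the signatures are assumptions rather than the rendered diff:

```python
from torch.utils.data import DataLoader

class DataConnector:
    # Inferred outline only; bodies are omitted and exact signatures may differ.
    def _request_dataloader(self, stage, model=None): ...               # fetch dataloader(s) for a RunningStage
    def _prepare_dataloader(self, dataloader, shuffle, mode=None): ...  # inject samplers, handle distributed mode
    def _resolve_overfit_batches(self, dataloader): ...                 # used when overfit_batches > 0
    def _worker_check(self, dataloader, name): ...                      # warn when num_workers is too low
    def _reset_eval_dataloader(self, mode, model=None): ...             # returns (num_batches, dataloaders)

def _update_dataloader(dataloader: DataLoader, sampler, mode=None): ...  # module-level; patched by IPUStrategy above
```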

483 changes: 26 additions & 457 deletions pytorch_lightning/trainer/data_loading.py

Large diffs are not rendered by default.
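This collapsed diff reduces `TrainerDataLoadingMixin` to thin deprecation shims. A minimal sketch of what one of them presumably looks like, with the warning text taken from the new test further down and the rest assumed:

```python
from pytorch_lightning.utilities import rank_zero_deprecation

class TrainerDataLoadingMixin:
    # Sketch only: the deprecated public method warns and forwards to the DataConnector.
    def prepare_dataloader(self, dataloader, shuffle, mode=None):
        rank_zero_deprecation(
            "`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6"
            " and will be removed in v1.8."
        )
        return self._data_connector._prepare_dataloader(dataloader, shuffle, mode=mode)
```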

156 changes: 153 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -27,6 +27,7 @@
import torch
from packaging.version import Version
from torch.optim import Optimizer
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
@@ -70,6 +71,7 @@
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import (
@@ -85,14 +87,17 @@
rank_zero_info,
rank_zero_warn,
)
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.argparse import (
_defaults_from_env_vars,
add_argparse_args,
from_argparse_args,
parse_argparser,
parse_env_variables,
)
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
@@ -119,9 +124,9 @@


class Trainer(
TrainerCallbackHookMixin,
TrainerOptimizersMixin,
TrainerDataLoadingMixin,
TrainerCallbackHookMixin, # TODO: Remove in v1.8
TrainerOptimizersMixin, # TODO: Remove in v1.8
TrainerDataLoadingMixin, # TODO: Remove in v1.8
):
# Needed because of LightningOptimizer
_lightning_optimizers = None
@@ -2364,6 +2369,151 @@ def terminate_on_nan(self, val: bool) -> None:
)
self._terminate_on_nan = val # : 212

def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
Review thread on `reset_train_dataloader` (the diff resumes after the thread):

Contributor
How about moving the reload logic to the `DataConnector`, so that `trainer.reset_train_dataloader` just calls `trainer._data_connector._reset_train_dataloader`? Not sure if it's a good idea, but I'm just trying to avoid keeping all of the logic inside the single `Trainer` module.
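For illustration only, a minimal sketch of the delegation proposed in this comment (this is not what the PR implements; the PR keeps the full logic on the `Trainer`, as shown below):

```python
class Trainer:
    # Hypothetical alternative: the public method stays a thin wrapper and the
    # reset logic would live on the DataConnector instead.
    def reset_train_dataloader(self, model=None):
        return self._data_connector._reset_train_dataloader(model=model)
```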

Contributor Author
I don't feel strongly either way. I have a slight preference for keeping it as it is now, since a function that just calls the private function adds a bit of unnecessary abstraction/complexity. If others feel strongly, we can change it.

Contributor
I'd say it's fine as it is. I'd only do that if it aids in testing, but we don't have unit tests for each of these separate methods.

Contributor (@awaelchli, Jan 5, 2022)
It is fine because the trainer owns the dataloaders, so resetting them is its responsibility. But this won't be the case in the future, I'm pretty sure, which is why I also suggested in my previous comment to map the public property to the protected one in the connector.

In any case, let's just please not mix properties and methods like this.

Contributor
> In any case, let's just please not mix properties and methods like this.

@daniellepintz I believe my comment was missed here. Could you send a follow-up PR? I believe the reviewers did not catch this properly.

Contributor Author
@awaelchli sorry, I am confused by that line. Could you clarify which property you are talking about?

Contributor
I mean the disorganization that this has caused. Previously, methods and properties were organized, but new methods have now been added and it's mixed up:

method
method
property
property
new method
new method

It should be:

method
method
new method
new method
property
property

given that Trainer wants to have all properties at the bottom.

Contributor Author
Ah, I see, good point. I will send a follow-up PR.

"""Resets the train dataloader and initialises required variables (number of batches, when to validate,
etc.).

Args:
model: The ``LightningModule`` if calling this outside of the trainer scope.
"""
self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)

if self.overfit_batches > 0:
self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)

# automatically add samplers
self.train_dataloader = apply_to_collection(
self.train_dataloader,
DataLoader,
self._data_connector._prepare_dataloader,
shuffle=True,
mode=RunningStage.TRAINING,
)

# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_training():
apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)

module = model or self.lightning_module or self.datamodule
self.num_training_batches = (
len(self.train_dataloader)
if has_len_all_ranks(self.train_dataloader, self.strategy, module)
else float("inf")
)

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float("inf"):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an IterableDataset for `limit_train_batches`,"
" `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
" `num_training_batches` to use."
)

# determine when to check validation
# if int passed in, val checks that often
# otherwise, it checks in [0, 1.0] % range of a training epoch
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
"If you want to disable validation set `limit_val_batches` to 0.0 instead."
)
else:
if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
if self.val_check_interval == 1.0:
self.val_check_batch = float("inf")
else:
raise MisconfigurationException(
"When using an IterableDataset for `train_dataloader`,"
" `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
" checking validation every k training batches."
)
else:
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

if self.logger and self.num_training_batches < self.log_every_n_steps:
rank_zero_warn(
f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
" you want to see logs for the training epoch.",
category=PossibleUserWarning,
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch

def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the validation dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._val_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("validation_step", pl_module)
if source.is_defined() and has_step:
self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.VALIDATING, model=pl_module
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_val_dl_reload_epoch = self.current_epoch

def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the test dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._test_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("test_step", pl_module)
if source.is_defined() and has_step:
self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.TESTING, model=pl_module
)

def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the predict dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._predict_dataloader_source
pl_module = self.lightning_module or model
if source.is_defined():
self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)

def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets train and val dataloaders if none are attached to the trainer.

The val dataloader must be initialized before training loop starts, as the training loop
inspects the val dataloader to determine whether to run the evaluation loop.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
if self.train_dataloader is None:
self.reset_train_dataloader(model=model)
if self.val_dataloaders is None:
self.reset_val_dataloader(model=model)


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
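To make the batch-limit arithmetic in `reset_train_dataloader` concrete, a small worked example with illustrative numbers (not code from the PR):

```python
# Suppose the train dataloader yields 1000 batches per epoch.
num_training_batches = 1000

# limit_train_batches=0.1 keeps 10% of them.
limit_train_batches = 0.1
num_training_batches = int(num_training_batches * limit_train_batches)  # 100

# val_check_interval=0.25 then validates four times per epoch, i.e. every 25 batches.
val_check_interval = 0.25
val_check_batch = max(1, int(num_training_batches * val_check_interval))  # 25

# An integer val_check_interval is used directly as val_check_batch, but it must not
# exceed num_training_batches, otherwise the ValueError above is raised.
```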
19 changes: 18 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
@@ -31,11 +31,12 @@
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
from tests.helpers.boring_model import BoringModel
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator

@@ -271,6 +272,22 @@ def test_v1_8_0_deprecated_training_type_plugin_property():
trainer.training_type_plugin


def test_v1_8_0_deprecate_trainer_data_loading_mixin():
trainer = Trainer(max_epochs=1)
model = BoringModel()
dm = BoringDataModule()
trainer.fit(model, datamodule=dm)

with pytest.deprecated_call(
match=r"`TrainerDataLoadingMixin.prepare_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
):
trainer.prepare_dataloader(dataloader=model.train_dataloader, shuffle=False)
with pytest.deprecated_call(
match=r"`TrainerDataLoadingMixin.request_dataloader` was deprecated in v1.6 and will be removed in v1.8.",
):
trainer.request_dataloader(stage=RunningStage.TRAINING)


def test_v_1_8_0_deprecated_device_stats_monitor_prefix_metric_keys():
from pytorch_lightning.callbacks.device_stats_monitor import prefix_metric_keys

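The remaining test files in this PR migrate from the deprecated `Trainer` methods to the protected connector API. Given an existing `trainer`, `dataloader`, `stage`, and `model`, the before/after looks like this; note that the replacement is internal, protected API rather than a public substitute:

```python
# Deprecated in v1.6 (emits the warnings asserted in the test above):
result = trainer.prepare_dataloader(dataloader, shuffle=False)
num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)

# Internal replacements used by the updated tests below:
result = trainer._data_connector._prepare_dataloader(dataloader, shuffle=False)
num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model)
```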
2 changes: 1 addition & 1 deletion tests/trainer/flags/test_limit_batches.py
@@ -62,7 +62,7 @@ def test_eval_limit_batches(stage, mode, limit_batches):
trainer = Trainer(**{limit_eval_batches: limit_batches})
model.trainer = trainer
trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model)
expected_batches = int(limit_batches * len(eval_loader)) if isinstance(limit_batches, float) else limit_batches
assert loader_num_batches[0] == expected_batches
assert len(dataloaders[0]) == len(eval_loader)
2 changes: 1 addition & 1 deletion tests/trainer/flags/test_overfit_batches.py
@@ -80,7 +80,7 @@ def test_overfit_batch_limits_eval(stage, mode, overfit_batches):
model.trainer = trainer
trainer._data_connector.attach_datamodule(model, datamodule=dm)

loader_num_batches, dataloaders = trainer._reset_eval_dataloader(stage, model=model)
loader_num_batches, dataloaders = trainer._data_connector._reset_eval_dataloader(stage, model=model)
if stage == RunningStage.VALIDATING:
assert loader_num_batches[0] == 0
else:
12 changes: 6 additions & 6 deletions tests/trainer/test_data_loading.py
@@ -248,17 +248,17 @@ def __init__(
class CustomDummyObj:
sampler = None

result = trainer.prepare_dataloader(CustomDummyObj(), shuffle=True)
result = trainer._data_connector._prepare_dataloader(CustomDummyObj(), shuffle=True)
assert isinstance(result, CustomDummyObj), "Wrongly reinstantiated data loader"

dataset = list(range(10))
result = trainer.prepare_dataloader(CustomDataLoader(dataset), shuffle=True)
result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset), shuffle=True)
assert isinstance(result, DataLoader)
assert isinstance(result, CustomDataLoader)
assert result.dummy_kwarg is None

# Shuffled DataLoader should also work
result = trainer.prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True)
result = trainer._data_connector._prepare_dataloader(CustomDataLoader(dataset, shuffle=True), shuffle=True)
assert isinstance(result, DataLoader)
assert isinstance(result, CustomDataLoader)
assert result.dummy_kwarg is None
@@ -269,7 +269,7 @@ class CustomSampler(Sampler):
# Should raise an error if existing sampler is being replaced
dataloader = CustomDataLoader(dataset, sampler=CustomSampler(dataset))
with pytest.raises(MisconfigurationException, match="will be replaced by `DistributedSampler`"):
trainer.prepare_dataloader(dataloader, shuffle=True)
trainer._data_connector._prepare_dataloader(dataloader, shuffle=True)


class LoaderTestModel(BoringModel):
@@ -351,7 +351,7 @@ def test_error_raised_with_float_limited_eval_batches():
MisconfigurationException,
match=fr"{limit_val_batches} \* {dl_size} < 1. Please increase the `limit_val_batches`",
):
trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)


@pytest.mark.parametrize(
@@ -375,4 +375,4 @@ def test_non_sequential_sampler_warning_is_raised_for_eval_dataloader(val_dl):
model = BoringModel()
trainer._data_connector.attach_data(model, val_dataloaders=val_dl)
with pytest.warns(PossibleUserWarning, match="recommended .* turn this off for val/test/predict"):
trainer._reset_eval_dataloader(RunningStage.VALIDATING, model)
trainer._data_connector._reset_eval_dataloader(RunningStage.VALIDATING, model)
4 changes: 2 additions & 2 deletions tests/trainer/test_dataloaders.py
@@ -563,7 +563,7 @@ def train_dataloader(self):
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("stage", ("train", "test", "val"))
@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4)
@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
"""Test that error is raised if dataloader with only a few workers is used."""

@@ -593,7 +593,7 @@ def test_warning_with_few_workers(_, tmpdir, ckpt_path, stage):
@RunIf(skip_windows=True)
@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("stage", ("train", "test", "val"))
@patch("pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count", return_value=4)
@patch("pytorch_lightning.trainer.connectors.data_connector.multiprocessing.cpu_count", return_value=4)
def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):
"""Test that error is raised if dataloader with only a few workers is used."""
