Mark trainer.data_connector as protected #10031

Merged · 4 commits · Oct 25, 2021
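
The change is a rename: trainer.data_connector becomes trainer._data_connector, marking the DataConnector as protected, internal API. A minimal sketch of what that means for code outside Lightning (assuming a Lightning version that includes this PR; the Trainer arguments below are placeholders):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=1)

# Before this PR the connector was exposed under a public-looking name:
#     connector = trainer.data_connector
# After this PR the attribute carries a leading underscore, signalling that it
# is internal and may change without a deprecation cycle:
connector = trainer._data_connector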
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -101,7 +101,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

dataloader_idx: int = self.current_dataloader_idx
dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
- dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
+ dataloader = self.trainer._data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
dl_max_batches = self._max_batches[dataloader_idx]

dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -107,7 +107,7 @@ def advance(
if batch is None:
raise StopIteration

- if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device:
+ if not self.trainer._data_connector.evaluation_data_fetcher.store_on_device:
with self.trainer.profiler.profile("evaluation_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -147,7 +147,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:

batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)

- if not self.trainer.data_connector.train_data_fetcher.store_on_device:
+ if not self.trainer._data_connector.train_data_fetcher.store_on_device:
with self.trainer.profiler.profile("training_batch_to_device"):
batch = self.trainer.accelerator.batch_to_device(batch)

2 changes: 1 addition & 1 deletion pytorch_lightning/loops/fit_loop.py
@@ -212,7 +212,7 @@ def on_advance_start(self) -> None:
def advance(self) -> None:
"""Runs one whole epoch."""
dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
- data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)
+ data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)

with self.trainer.profiler.profile("run_training_epoch"):
self.epoch_loop.run(data_fetcher)
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -623,7 +623,7 @@ def _auto_select_batch_size(self):
# train_micro_batch_size_per_gpu is used for throughput logging purposes
# by default we try to use the batch size of the loader
batch_size = 1
- train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source
+ train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source
if train_dl_source.is_defined():
train_dataloader = train_dl_source.dataloader()
if hasattr(train_dataloader, "batch_sampler"):
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -95,7 +95,7 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> No
@staticmethod
def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
"""Validate and fail fast if the dataloaders were passed directly to fit."""
- connector: DataConnector = model.trainer.data_connector
+ connector: DataConnector = model.trainer._data_connector
sources = (
connector._train_dataloader_source,
connector._val_dataloader_source,
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -65,7 +65,7 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightnin
# -----------------------------------
# verify model has a train dataloader
# -----------------------------------
- has_train_dataloader = trainer.data_connector._train_dataloader_source.is_defined()
+ has_train_dataloader = trainer._data_connector._train_dataloader_source.is_defined()
if not has_train_dataloader:
raise MisconfigurationException(
"No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
@@ -176,7 +176,7 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) ->


def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
- has_predict_dataloader = trainer.data_connector._predict_dataloader_source.is_defined()
+ has_predict_dataloader = trainer._data_connector._predict_dataloader_source.is_defined()
if not has_predict_dataloader:
raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
# ----------------------------------------------
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/data_loading.py
@@ -343,7 +343,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -
apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
- self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)
+ self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)

self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")

@@ -488,7 +488,7 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
- source = self.data_connector._val_dataloader_source
+ source = self._data_connector._val_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("validation_step", pl_module)
if source.is_defined() and has_step:
@@ -502,7 +502,7 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
- source = self.data_connector._test_dataloader_source
+ source = self._data_connector._test_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("test_step", pl_module)
if source.is_defined() and has_step:
@@ -516,7 +516,7 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None)
Args:
model: The `LightningModule` if called outside of the trainer scope.
"""
- source = self.data_connector._predict_dataloader_source
+ source = self._data_connector._predict_dataloader_source
pl_module = self.lightning_module or model
if source.is_defined():
self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(
@@ -545,7 +545,7 @@ def request_dataloader(
Returns:
The requested dataloader
"""
- source = getattr(self.data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
+ source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")

hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/trainer.py
@@ -424,7 +424,7 @@ def __init__(
gpu_ids, tpu_cores = self._parse_devices(gpus, auto_select_gpus, tpu_cores)

# init connectors
- self.data_connector = DataConnector(self, multiple_trainloader_mode)
+ self._data_connector = DataConnector(self, multiple_trainloader_mode)
self.optimizer_connector = OptimizerConnector(self)

self.accelerator_connector = AcceleratorConnector(
@@ -514,7 +514,7 @@ def __init__(
self.optimizer_connector.on_trainer_init()

# init data flags
- self.data_connector.on_trainer_init(
+ self._data_connector.on_trainer_init(
check_val_every_n_epoch,
reload_dataloaders_every_n_epochs,
reload_dataloaders_every_epoch,
@@ -663,7 +663,7 @@ def _fit_impl(
)

# links data to the trainer
- self.data_connector.attach_data(
+ self._data_connector.attach_data(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

@@ -747,7 +747,7 @@ def _validate_impl(
)

# links data to the trainer
- self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
+ self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

self.validated_ckpt_path = self.__set_ckpt_path(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
@@ -837,7 +837,7 @@ def _test_impl(
)

# links data to the trainer
- self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
+ self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

self.tested_ckpt_path = self.__set_ckpt_path(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
@@ -921,7 +921,7 @@ def _predict_impl(
)

# links data to the trainer
- self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
+ self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

self.predicted_ckpt_path = self.__set_ckpt_path(
ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
@@ -985,7 +985,7 @@ def tune(
)

# links data to the trainer
- self.data_connector.attach_data(
+ self._data_connector.attach_data(
model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
)

@@ -1027,7 +1027,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
self.training_type_plugin.connect(model)

# hook
- self.data_connector.prepare_data()
+ self._data_connector.prepare_data()
self.callback_connector._attach_model_callbacks()

if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
@@ -1171,7 +1171,7 @@ def _post_dispatch(self):
# these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
# which need to happen before.
self.accelerator.teardown()
- self.data_connector.teardown()
+ self._data_connector.teardown()
self._active_loop.teardown()
self.logger_connector.teardown()

@@ -1258,7 +1258,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
return self.predict_loop.run()

def _run_sanity_check(self, ref_model):
- using_val_step = self.data_connector._val_dataloader_source.is_defined() and is_overridden(
+ using_val_step = self._data_connector._val_dataloader_source.is_defined() and is_overridden(
"validation_step", ref_model
)
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
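
Taken together, the trainer.py hunks show where the connector sits in the trainer's lifecycle: it is created in __init__, receives the user's dataloaders and datamodule in the fit/validate/test/predict/tune entry points, runs the prepare_data hook inside _run, and is torn down in _post_dispatch. A hypothetical, self-contained stub of that call order (not Lightning code, just the sequence made explicit):

class StubDataConnector:
    """Stand-in for pytorch_lightning's DataConnector, for illustration only."""

    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, **flags):
        print("on_trainer_init:", flags)

    def attach_data(self, model, **dataloaders):
        print("attach_data:", sorted(dataloaders))

    def prepare_data(self):
        print("prepare_data")

    def teardown(self):
        print("teardown")


class StubTrainer:
    def __init__(self):
        # created once in __init__; protected name after this PR
        self._data_connector = StubDataConnector(self)
        self._data_connector.on_trainer_init(check_val_every_n_epoch=1)

    def fit(self, model, train_dataloaders=None, datamodule=None):
        # entry points link the user's data to the trainer, then run
        self._data_connector.attach_data(
            model, train_dataloaders=train_dataloaders, datamodule=datamodule
        )
        self._data_connector.prepare_data()

    def teardown(self):
        self._data_connector.teardown()


trainer = StubTrainer()
trainer.fit(model=object())
trainer.teardown()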
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/batch_size_scaling.py
@@ -51,7 +51,7 @@ def scale_batch_size(
" If this is not the intended behavior, please remove either one."
)

- if not trainer.data_connector._train_dataloader_source.is_module():
+ if not trainer._data_connector._train_dataloader_source.is_module():
raise MisconfigurationException(
"The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
" Please disable the feature or incorporate the dataloader into the model."
16 changes: 8 additions & 8 deletions tests/core/test_datamodules.py
@@ -51,7 +51,7 @@ def test_can_prepare_data(local_rank, node_rank):
local_rank.return_value = 0
assert trainer.local_rank == 0

- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()
assert dm.random_full is not None

# local rank = 1 (False)
@@ -60,7 +60,7 @@ def test_can_prepare_data(local_rank, node_rank):
local_rank.return_value = 1
assert trainer.local_rank == 1

- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()
assert dm.random_full is None

# prepare_data_per_node = False (prepare across all nodes)
@@ -71,7 +71,7 @@ def test_can_prepare_data(local_rank, node_rank):
node_rank.return_value = 0
local_rank.return_value = 0

- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()
assert dm.random_full is not None

# global rank = 1 (False)
@@ -80,13 +80,13 @@ def test_can_prepare_data(local_rank, node_rank):
node_rank.return_value = 1
local_rank.return_value = 0

- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()
assert dm.random_full is None

node_rank.return_value = 0
local_rank.return_value = 1

- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()
assert dm.random_full is None

# 2 dm
@@ -100,13 +100,13 @@ def test_can_prepare_data(local_rank, node_rank):
# has been called
# False
dm._has_prepared_data = True
- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()
dm_mock.assert_not_called()

# has not been called
# True
dm._has_prepared_data = False
- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()
dm_mock.assert_called_once()


@@ -629,7 +629,7 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
trainer = Trainer(prepare_data_per_node=False)
trainer.model = model
trainer.datamodule = dm
- trainer.data_connector.prepare_data()
+ trainer._data_connector.prepare_data()


DATALOADER = DataLoader(RandomDataset(1, 32))
2 changes: 1 addition & 1 deletion tests/plugins/test_tpu_spawn.py
@@ -65,7 +65,7 @@ def test_error_iterable_dataloaders_passed_to_fit(
model = BoringModelNoDataloaders()
model.trainer = trainer

- trainer.data_connector.attach_dataloaders(
+ trainer._data_connector.attach_dataloaders(
model,
train_dataloaders=train_dataloaders,
val_dataloaders=val_dataloaders,
18 changes: 9 additions & 9 deletions tests/trainer/test_trainer_tricks.py
@@ -85,7 +85,7 @@ def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
trainer = Trainer(overfit_batches=4)
model.trainer = trainer
- trainer.data_connector.attach_dataloaders(model=model)
+ trainer._data_connector.attach_dataloaders(model=model)
trainer.reset_train_dataloader(model)
assert trainer.num_training_batches == 4

@@ -96,7 +96,7 @@ def test_overfit_batch_limits(tmpdir):

trainer = Trainer(overfit_batches=0.11)
model.trainer = trainer
- trainer.data_connector.attach_dataloaders(model=model)
+ trainer._data_connector.attach_dataloaders(model=model)
trainer.reset_train_dataloader(model)
# The dataloader should have been overwritten with a Sequential sampler.
assert trainer.train_dataloader is not train_loader
@@ -116,7 +116,7 @@ def test_overfit_batch_limits(tmpdir):
# test overfit_batches as percent
# ------------------------------------------------------
trainer = Trainer(overfit_batches=0.11)
- trainer.data_connector.attach_dataloaders(model)
+ trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == num_train_samples

@@ -132,11 +132,11 @@ def test_overfit_batch_limits(tmpdir):
# test overfit_batches as int
# ------------------------------------------------------
trainer = Trainer(overfit_batches=1)
- trainer.data_connector.attach_dataloaders(model)
+ trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 1
trainer = Trainer(overfit_batches=5)
- trainer.data_connector.attach_dataloaders(model)
+ trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 5

@@ -145,21 +145,21 @@ def test_overfit_batch_limits(tmpdir):
# ------------------------------------------------------
if split == RunningStage.VALIDATING:
trainer = Trainer(limit_val_batches=0.1)
- trainer.data_connector.attach_dataloaders(model)
+ trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == int(0.1 * len(val_loader))

trainer = Trainer(limit_val_batches=10)
- trainer.data_connector.attach_dataloaders(model)
+ trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 10
else:
trainer = Trainer(limit_test_batches=0.1)
- trainer.data_connector.attach_dataloaders(model)
+ trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == int(0.1 * len(test_loader))

trainer = Trainer(limit_test_batches=10)
- trainer.data_connector.attach_dataloaders(model)
+ trainer._data_connector.attach_dataloaders(model)
loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
assert loader_num_batches[0] == 10
8 changes: 4 additions & 4 deletions tests/utilities/test_fetching.py
@@ -185,9 +185,9 @@ def __init__(self, check_inter_batch: bool):

def on_train_epoch_end(self, trainer, lightning_module):
if self._check_inter_batch:
- assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
+ assert isinstance(trainer._data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
else:
- assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
+ assert isinstance(trainer._data_connector.train_data_fetcher, DataFetcher)

trainer_kwargs = dict(
default_root_dir=tmpdir,
@@ -232,7 +232,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs):

def training_step(self, dataloader_iter, batch_idx):
assert self.count == batch_idx
- assert isinstance(self.trainer.data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
+ assert isinstance(self.trainer._data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
# fetch 2 batches
self.batches.append(next(dataloader_iter))
self.batches.append(next(dataloader_iter))
@@ -255,7 +255,7 @@ def training_step(self, dataloader_iter, batch_idx):

def training_epoch_end(self, *_):
assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
- assert self.trainer.data_connector.train_data_fetcher.fetched == 64
+ assert self.trainer._data_connector.train_data_fetcher.fetched == 64
assert self.count == 64

model = TestModel(automatic_optimization=automatic_optimization)
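
Third-party plugins, callbacks, or tests that reached into the connector under the old public name break on this rename. A small, hypothetical compatibility helper (not part of this PR) that works on Lightning versions from before and after the change:

def get_data_connector(trainer):
    """Return the trainer's data connector across the #10031 rename."""
    # Prefer the new protected attribute, fall back to the pre-rename public name.
    connector = getattr(trainer, "_data_connector", None)
    if connector is None:
        connector = getattr(trainer, "data_connector", None)
    return connector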