Commit e94dcf6

Mark trainer.data_connector as protected (#10031)
Co-authored-by: tchaton <[email protected]>
1 parent f95ba20 commit e94dcf6

File tree: 14 files changed, +46 −46 lines
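The rename in this commit is purely conventional: Python does not enforce access control, so the leading underscore in `_data_connector` only signals that the attribute is internal Trainer API. A minimal sketch of the before/after pattern, using stand-in classes rather than the real ones:

class DataConnector:
    """Stand-in for pytorch_lightning's internal data connector."""

    def __init__(self, trainer: "Trainer") -> None:
        self.trainer = trainer


class Trainer:
    def __init__(self) -> None:
        # Before: self.data_connector = DataConnector(self)
        # After: the underscore marks the connector as protected by convention,
        # i.e. not part of the public Trainer API and free to change without notice.
        self._data_connector = DataConnector(self)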

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
 
         dataloader_idx: int = self.current_dataloader_idx
         dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
-        dataloader = self.trainer.data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
+        dataloader = self.trainer._data_connector.get_profiled_dataloader(dataloader, dataloader_idx=dataloader_idx)
         dl_max_batches = self._max_batches[dataloader_idx]
 
         dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)

pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ def advance(
         if batch is None:
             raise StopIteration
 
-        if not self.trainer.data_connector.evaluation_data_fetcher.store_on_device:
+        if not self.trainer._data_connector.evaluation_data_fetcher.store_on_device:
             with self.trainer.profiler.profile("evaluation_batch_to_device"):
                 batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx)
 
pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def advance(self, *args: Any, **kwargs: Any) -> None:
 
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
 
-        if not self.trainer.data_connector.train_data_fetcher.store_on_device:
+        if not self.trainer._data_connector.train_data_fetcher.store_on_device:
             with self.trainer.profiler.profile("training_batch_to_device"):
                 batch = self.trainer.accelerator.batch_to_device(batch)
 
pytorch_lightning/loops/fit_loop.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def on_advance_start(self) -> None:
     def advance(self) -> None:
         """Runs one whole epoch."""
         dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
-        data_fetcher = self.trainer.data_connector.get_profiled_dataloader(dataloader)
+        data_fetcher = self.trainer._data_connector.get_profiled_dataloader(dataloader)
 
         with self.trainer.profiler.profile("run_training_epoch"):
             self.epoch_loop.run(data_fetcher)
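Across the three loop files above, the call being touched is `get_profiled_dataloader`, which hands the raw dataloader to the data connector so that batch fetching can be timed by the trainer's profiler. A simplified sketch of that wrapping idea (hypothetical, not Lightning's actual implementation):

from typing import Any, Iterable, Iterator


class ProfiledIterable:
    """Iterate a dataloader while timing each batch fetch (illustrative only)."""

    def __init__(self, iterable: Iterable, profiler: Any, action: str = "fetch_next_batch") -> None:
        self.iterable = iterable
        self.profiler = profiler
        self.action = action

    def __iter__(self) -> Iterator:
        iterator = iter(self.iterable)
        while True:
            # time only the work of pulling the next batch from the loader
            with self.profiler.profile(self.action):
                try:
                    batch = next(iterator)
                except StopIteration:
                    return
            yield batch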

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 1 addition & 1 deletion
@@ -623,7 +623,7 @@ def _auto_select_batch_size(self):
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # by default we try to use the batch size of the loader
         batch_size = 1
-        train_dl_source = self.lightning_module.trainer.data_connector._train_dataloader_source
+        train_dl_source = self.lightning_module.trainer._data_connector._train_dataloader_source
         if train_dl_source.is_defined():
             train_dataloader = train_dl_source.dataloader()
             if hasattr(train_dataloader, "batch_sampler"):

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 1 addition & 1 deletion
@@ -95,7 +95,7 @@ def _validate_dataloader(dataloaders: Union[List[DataLoader], DataLoader]) -> None:
     @staticmethod
     def _validate_patched_dataloaders(model: "pl.LightningModule") -> None:
         """Validate and fail fast if the dataloaders were passed directly to fit."""
-        connector: DataConnector = model.trainer.data_connector
+        connector: DataConnector = model.trainer._data_connector
         sources = (
             connector._train_dataloader_source,
             connector._val_dataloader_source,

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 2 additions & 2 deletions
@@ -65,7 +65,7 @@ def __verify_train_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
     # -----------------------------------
     # verify model has a train dataloader
     # -----------------------------------
-    has_train_dataloader = trainer.data_connector._train_dataloader_source.is_defined()
+    has_train_dataloader = trainer._data_connector._train_dataloader_source.is_defined()
     if not has_train_dataloader:
         raise MisconfigurationException(
             "No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a"
@@ -176,7 +176,7 @@ def __verify_eval_loop_configuration(model: "pl.LightningModule", stage: str) -> None:
 
 
 def __verify_predict_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
-    has_predict_dataloader = trainer.data_connector._predict_dataloader_source.is_defined()
+    has_predict_dataloader = trainer._data_connector._predict_dataloader_source.is_defined()
     if not has_predict_dataloader:
         raise MisconfigurationException("Dataloader not found for `Trainer.predict`")
     # ----------------------------------------------

pytorch_lightning/trainer/data_loading.py

Lines changed: 5 additions & 5 deletions
@@ -343,7 +343,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
             apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)
 
         # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
-        self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
 
         self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float("inf")
 
@@ -488,7 +488,7 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
         Args:
             model: The `LightningModule` if called outside of the trainer scope.
         """
-        source = self.data_connector._val_dataloader_source
+        source = self._data_connector._val_dataloader_source
         pl_module = self.lightning_module or model
         has_step = is_overridden("validation_step", pl_module)
         if source.is_defined() and has_step:
@@ -502,7 +502,7 @@ def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
         Args:
             model: The `LightningModule` if called outside of the trainer scope.
         """
-        source = self.data_connector._test_dataloader_source
+        source = self._data_connector._test_dataloader_source
         pl_module = self.lightning_module or model
         has_step = is_overridden("test_step", pl_module)
         if source.is_defined() and has_step:
@@ -516,7 +516,7 @@ def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
         Args:
             model: The `LightningModule` if called outside of the trainer scope.
         """
-        source = self.data_connector._predict_dataloader_source
+        source = self._data_connector._predict_dataloader_source
         pl_module = self.lightning_module or model
         if source.is_defined():
            self.num_predict_batches, self.predict_dataloaders = self._reset_eval_dataloader(
@@ -545,7 +545,7 @@ def request_dataloader(
         Returns:
             The requested dataloader
         """
-        source = getattr(self.data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
+        source = getattr(self._data_connector, f"_{stage.dataloader_prefix}_dataloader_source")
 
         hook = f"{stage.dataloader_prefix}_dataloader"
         self.call_hook("on_" + hook, pl_module=model)

pytorch_lightning/trainer/trainer.py

Lines changed: 10 additions & 10 deletions
@@ -424,7 +424,7 @@ def __init__(
         gpu_ids, tpu_cores = self._parse_devices(gpus, auto_select_gpus, tpu_cores)
 
         # init connectors
-        self.data_connector = DataConnector(self, multiple_trainloader_mode)
+        self._data_connector = DataConnector(self, multiple_trainloader_mode)
         self.optimizer_connector = OptimizerConnector(self)
 
         self.accelerator_connector = AcceleratorConnector(
@@ -514,7 +514,7 @@ def __init__(
         self.optimizer_connector.on_trainer_init()
 
         # init data flags
-        self.data_connector.on_trainer_init(
+        self._data_connector.on_trainer_init(
             check_val_every_n_epoch,
             reload_dataloaders_every_n_epochs,
             reload_dataloaders_every_epoch,
@@ -663,7 +663,7 @@ def _fit_impl(
         )
 
         # links data to the trainer
-        self.data_connector.attach_data(
+        self._data_connector.attach_data(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )
 
@@ -747,7 +747,7 @@ def _validate_impl(
         )
 
         # links data to the trainer
-        self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
+        self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)
 
         self.validated_ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
@@ -837,7 +837,7 @@ def _test_impl(
         )
 
         # links data to the trainer
-        self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
+        self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)
 
         self.tested_ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
@@ -921,7 +921,7 @@ def _predict_impl(
         )
 
         # links data to the trainer
-        self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
+        self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)
 
         self.predicted_ckpt_path = self.__set_ckpt_path(
             ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
@@ -985,7 +985,7 @@ def tune(
         )
 
         # links data to the trainer
-        self.data_connector.attach_data(
+        self._data_connector.attach_data(
             model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
         )
 
@@ -1027,7 +1027,7 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self.training_type_plugin.connect(model)
 
         # hook
-        self.data_connector.prepare_data()
+        self._data_connector.prepare_data()
         self.callback_connector._attach_model_callbacks()
 
         if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
@@ -1171,7 +1171,7 @@ def _post_dispatch(self):
         # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
         # which need to happen before.
         self.accelerator.teardown()
-        self.data_connector.teardown()
+        self._data_connector.teardown()
         self._active_loop.teardown()
         self.logger_connector.teardown()
 
@@ -1258,7 +1258,7 @@ def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
         return self.predict_loop.run()
 
     def _run_sanity_check(self, ref_model):
-        using_val_step = self.data_connector._val_dataloader_source.is_defined() and is_overridden(
+        using_val_step = self._data_connector._val_dataloader_source.is_defined() and is_overridden(
            "validation_step", ref_model
         )
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
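Note that the diff renames the attribute outright; no shim for the old public name appears in these hunks. A common companion pattern for this kind of change (a sketch of the general technique, not something this commit is shown to add) is a deprecated read-only property that forwards to the protected attribute:

import warnings


class Trainer:
    def __init__(self) -> None:
        self._data_connector = object()  # stand-in for the real DataConnector

    @property
    def data_connector(self):
        # Hypothetical shim: keeps old external accesses working while warning
        # that the connector is now considered internal.
        warnings.warn(
            "`trainer.data_connector` is protected and should not be accessed directly.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._data_connector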

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def scale_batch_size(
             " If this is not the intended behavior, please remove either one."
         )
 
-    if not trainer.data_connector._train_dataloader_source.is_module():
+    if not trainer._data_connector._train_dataloader_source.is_module():
         raise MisconfigurationException(
             "The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
             " Please disable the feature or incorporate the dataloader into the model."

tests/core/test_datamodules.py

Lines changed: 8 additions & 8 deletions
@@ -51,7 +51,7 @@ def test_can_prepare_data(local_rank, node_rank):
     local_rank.return_value = 0
     assert trainer.local_rank == 0
 
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
     assert dm.random_full is not None
 
     # local rank = 1 (False)
@@ -60,7 +60,7 @@ def test_can_prepare_data(local_rank, node_rank):
     local_rank.return_value = 1
     assert trainer.local_rank == 1
 
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
     assert dm.random_full is None
 
     # prepare_data_per_node = False (prepare across all nodes)
@@ -71,7 +71,7 @@ def test_can_prepare_data(local_rank, node_rank):
     node_rank.return_value = 0
     local_rank.return_value = 0
 
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
     assert dm.random_full is not None
 
     # global rank = 1 (False)
@@ -80,13 +80,13 @@ def test_can_prepare_data(local_rank, node_rank):
     node_rank.return_value = 1
     local_rank.return_value = 0
 
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
     assert dm.random_full is None
 
     node_rank.return_value = 0
     local_rank.return_value = 1
 
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
     assert dm.random_full is None
 
     # 2 dm
@@ -100,13 +100,13 @@ def test_can_prepare_data(local_rank, node_rank):
     # has been called
     # False
     dm._has_prepared_data = True
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
     dm_mock.assert_not_called()
 
     # has not been called
     # True
     dm._has_prepared_data = False
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
     dm_mock.assert_called_once()
 
 
@@ -629,7 +629,7 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
     trainer = Trainer(prepare_data_per_node=False)
     trainer.model = model
     trainer.datamodule = dm
-    trainer.data_connector.prepare_data()
+    trainer._data_connector.prepare_data()
 
 
 DATALOADER = DataLoader(RandomDataset(1, 32))
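The `test_can_prepare_data` hunks above exercise the rank gating inside the connector's `prepare_data`: with `prepare_data_per_node=True` (the default) only `local_rank == 0` prepares, otherwise only the first process of the first node does. A condensed sketch of that gate, under the assumption that the connector can see both ranks:

def should_prepare_data(prepare_data_per_node: bool, node_rank: int, local_rank: int) -> bool:
    """Return True if this process should run `prepare_data` (simplified)."""
    if prepare_data_per_node:
        # one process per node downloads/prepares the data
        return local_rank == 0
    # a single process in the whole cluster prepares the data
    return node_rank == 0 and local_rank == 0


assert should_prepare_data(True, node_rank=1, local_rank=0)
assert not should_prepare_data(False, node_rank=1, local_rank=0)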

tests/plugins/test_tpu_spawn.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def test_error_iterable_dataloaders_passed_to_fit(
     model = BoringModelNoDataloaders()
     model.trainer = trainer
 
-    trainer.data_connector.attach_dataloaders(
+    trainer._data_connector.attach_dataloaders(
         model,
         train_dataloaders=train_dataloaders,
         val_dataloaders=val_dataloaders,

tests/trainer/test_trainer_tricks.py

Lines changed: 9 additions & 9 deletions
@@ -85,7 +85,7 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     trainer = Trainer(overfit_batches=4)
     model.trainer = trainer
-    trainer.data_connector.attach_dataloaders(model=model)
+    trainer._data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     assert trainer.num_training_batches == 4
 
@@ -96,7 +96,7 @@ def test_overfit_batch_limits(tmpdir):
 
     trainer = Trainer(overfit_batches=0.11)
     model.trainer = trainer
-    trainer.data_connector.attach_dataloaders(model=model)
+    trainer._data_connector.attach_dataloaders(model=model)
     trainer.reset_train_dataloader(model)
     # The dataloader should have been overwritten with a Sequential sampler.
     assert trainer.train_dataloader is not train_loader
@@ -116,7 +116,7 @@ def test_overfit_batch_limits(tmpdir):
     # test overfit_batches as percent
     # ------------------------------------------------------
     trainer = Trainer(overfit_batches=0.11)
-    trainer.data_connector.attach_dataloaders(model)
+    trainer._data_connector.attach_dataloaders(model)
     loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
     assert loader_num_batches[0] == num_train_samples
 
@@ -132,11 +132,11 @@ def test_overfit_batch_limits(tmpdir):
     # test overfit_batches as int
     # ------------------------------------------------------
     trainer = Trainer(overfit_batches=1)
-    trainer.data_connector.attach_dataloaders(model)
+    trainer._data_connector.attach_dataloaders(model)
     loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
     assert loader_num_batches[0] == 1
     trainer = Trainer(overfit_batches=5)
-    trainer.data_connector.attach_dataloaders(model)
+    trainer._data_connector.attach_dataloaders(model)
     loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
     assert loader_num_batches[0] == 5
 
@@ -145,21 +145,21 @@ def test_overfit_batch_limits(tmpdir):
     # ------------------------------------------------------
     if split == RunningStage.VALIDATING:
         trainer = Trainer(limit_val_batches=0.1)
-        trainer.data_connector.attach_dataloaders(model)
+        trainer._data_connector.attach_dataloaders(model)
         loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == int(0.1 * len(val_loader))
 
         trainer = Trainer(limit_val_batches=10)
-        trainer.data_connector.attach_dataloaders(model)
+        trainer._data_connector.attach_dataloaders(model)
         loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == 10
     else:
         trainer = Trainer(limit_test_batches=0.1)
-        trainer.data_connector.attach_dataloaders(model)
+        trainer._data_connector.attach_dataloaders(model)
         loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == int(0.1 * len(test_loader))
 
         trainer = Trainer(limit_test_batches=10)
-        trainer.data_connector.attach_dataloaders(model)
+        trainer._data_connector.attach_dataloaders(model)
         loader_num_batches, dataloaders = trainer._reset_eval_dataloader(split, model=model)
         assert loader_num_batches[0] == 10

tests/utilities/test_fetching.py

Lines changed: 4 additions & 4 deletions
@@ -185,9 +185,9 @@ def __init__(self, check_inter_batch: bool):
 
         def on_train_epoch_end(self, trainer, lightning_module):
             if self._check_inter_batch:
-                assert isinstance(trainer.data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
+                assert isinstance(trainer._data_connector.train_data_fetcher, InterBatchParallelDataFetcher)
             else:
-                assert isinstance(trainer.data_connector.train_data_fetcher, DataFetcher)
+                assert isinstance(trainer._data_connector.train_data_fetcher, DataFetcher)
 
     trainer_kwargs = dict(
         default_root_dir=tmpdir,
@@ -232,7 +232,7 @@ def __init__(self, *args, automatic_optimization: bool = False, **kwargs):
 
         def training_step(self, dataloader_iter, batch_idx):
             assert self.count == batch_idx
-            assert isinstance(self.trainer.data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
+            assert isinstance(self.trainer._data_connector.train_data_fetcher, DataLoaderIterDataFetcher)
             # fetch 2 batches
             self.batches.append(next(dataloader_iter))
             self.batches.append(next(dataloader_iter))
@@ -255,7 +255,7 @@ def training_step(self, dataloader_iter, batch_idx):
 
         def training_epoch_end(self, *_):
             assert self.trainer.fit_loop.epoch_loop.batch_progress.current.ready == 33
-            assert self.trainer.data_connector.train_data_fetcher.fetched == 64
+            assert self.trainer._data_connector.train_data_fetcher.fetched == 64
             assert self.count == 64
 
     model = TestModel(automatic_optimization=automatic_optimization)
