Move newly added Trainer methods to be with other methods #11335

Merged · 2 commits · Jan 6, 2022
294 changes: 149 additions & 145 deletions pytorch_lightning/trainer/trainer.py
@@ -1748,6 +1748,155 @@ def _on_exception(self) -> None:
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
self.save_checkpoint(file_path)

"""
Data loading methods
"""

def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the train dataloader and initialises required variables (number of batches, when to validate,
etc.).

Args:
model: The ``LightningModule`` if calling this outside of the trainer scope.
"""
self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)

if self.overfit_batches > 0:
self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)

# automatically add samplers
self.train_dataloader = apply_to_collection(
self.train_dataloader,
DataLoader,
self._data_connector._prepare_dataloader,
shuffle=True,
mode=RunningStage.TRAINING,
)

# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_training():
apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
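# Illustrative note (editorial, not part of this diff): if the LightningModule `train_dataloader()`
# hook returns several loaders, e.g. {"a": loader_a, "b": loader_b}, the CombinedLoader yields one
# batch per loader each step; with "max_size_cycle" the shorter loaders cycle and the length follows
# the longest loader, with "min_size" iteration stops at the shortest loader.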

module = model or self.lightning_module or self.datamodule
self.num_training_batches = (
len(self.train_dataloader)
if has_len_all_ranks(self.train_dataloader, self.strategy, module)
else float("inf")
)

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float("inf"):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an IterableDataset for `train_dataloader`,"
" `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
" `num_training_batches` to use."
)
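# Worked example (editorial, not part of this diff): with 1000 resolvable training batches,
# `limit_train_batches=0.1` gives num_training_batches = int(1000 * 0.1) = 100, while the int
# `limit_train_batches=5` gives min(1000, 5) = 5.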

# determine when to check validation
# if int passed in, val checks that often
# otherwise, it checks in [0, 1.0] % range of a training epoch
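# worked example (editorial, not part of this diff): with num_training_batches = 100,
# `val_check_interval=0.25` yields val_check_batch = max(1, int(100 * 0.25)) = 25, whereas
# the int `val_check_interval=10` runs validation every 10 training batches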
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
"If you want to disable validation set `limit_val_batches` to 0.0 instead."
)
else:
if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
if self.val_check_interval == 1.0:
self.val_check_batch = float("inf")
else:
raise MisconfigurationException(
"When using an IterableDataset for `train_dataloader`,"
" `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
" checking validation every k training batches."
)
else:
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

if self.logger and self.num_training_batches < self.log_every_n_steps:
rank_zero_warn(
f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
" you want to see logs for the training epoch.",
category=PossibleUserWarning,
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch

def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the validation dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._val_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("validation_step", pl_module)
if source.is_defined() and has_step:
self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.VALIDATING, model=pl_module
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_val_dl_reload_epoch = self.current_epoch

def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the test dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._test_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("test_step", pl_module)
if source.is_defined() and has_step:
self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.TESTING, model=pl_module
)

def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the predict dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._predict_dataloader_source
pl_module = self.lightning_module or model
if source.is_defined():
self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)

def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets train and val dataloaders if none are attached to the trainer.

The val dataloader must be initialized before the training loop starts, as the training loop
inspects the val dataloader to determine whether to run the evaluation loop.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
if self.train_dataloader is None:
self.reset_train_dataloader(model=model)
if self.val_dataloaders is None:
self.reset_val_dataloader(model=model)

"""
Accelerator properties
"""
@@ -2377,151 +2526,6 @@ def terminate_on_nan(self, val: bool) -> None:
)
self._terminate_on_nan = val # : 212

def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the train dataloader and initialises required variables (number of batches, when to validate,
etc.).

Args:
model: The ``LightningModule`` if calling this outside of the trainer scope.
"""
self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)

if self.overfit_batches > 0:
self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)

# automatically add samplers
self.train_dataloader = apply_to_collection(
self.train_dataloader,
DataLoader,
self._data_connector._prepare_dataloader,
shuffle=True,
mode=RunningStage.TRAINING,
)

# check the workers recursively
apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")

# add worker_init_fn for correct seeding in worker processes
apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_training():
apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)

module = model or self.lightning_module or self.datamodule
self.num_training_batches = (
len(self.train_dataloader)
if has_len_all_ranks(self.train_dataloader, self.strategy, module)
else float("inf")
)

if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
elif self.num_training_batches != float("inf"):
self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
elif self.limit_train_batches != 1.0:
raise MisconfigurationException(
"When using an IterableDataset for `train_dataloader`,"
" `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
" `num_training_batches` to use."
)

# determine when to check validation
# if int passed in, val checks that often
# otherwise, it checks in [0, 1.0] % range of a training epoch
if isinstance(self.val_check_interval, int):
self.val_check_batch = self.val_check_interval
if self.val_check_batch > self.num_training_batches:
raise ValueError(
f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
f"to the number of the training batches ({self.num_training_batches}). "
"If you want to disable validation set `limit_val_batches` to 0.0 instead."
)
else:
if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
if self.val_check_interval == 1.0:
self.val_check_batch = float("inf")
else:
raise MisconfigurationException(
"When using an IterableDataset for `train_dataloader`,"
" `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
" checking validation every k training batches."
)
else:
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
self.val_check_batch = max(1, self.val_check_batch)

if self.logger and self.num_training_batches < self.log_every_n_steps:
rank_zero_warn(
f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
" you want to see logs for the training epoch.",
category=PossibleUserWarning,
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_train_dl_reload_epoch = self.current_epoch

def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the validation dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._val_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("validation_step", pl_module)
if source.is_defined() and has_step:
self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.VALIDATING, model=pl_module
)

# store epoch of dataloader reset for reload_dataloaders_every_n_epochs
self._last_val_dl_reload_epoch = self.current_epoch

def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the test dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._test_dataloader_source
pl_module = self.lightning_module or model
has_step = is_overridden("test_step", pl_module)
if source.is_defined() and has_step:
self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.TESTING, model=pl_module
)

def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets the predict dataloader and determines the number of batches.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
source = self._data_connector._predict_dataloader_source
pl_module = self.lightning_module or model
if source.is_defined():
self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
RunningStage.PREDICTING, model=pl_module
)

def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
"""Resets train and val dataloaders if none are attached to the trainer.

The val dataloader must be initialized before the training loop starts, as the training loop
inspects the val dataloader to determine whether to run the evaluation loop.

Args:
model: The ``LightningModule`` if called outside of the trainer scope.
"""
if self.train_dataloader is None:
self.reset_train_dataloader(model=model)
if self.val_dataloaders is None:
self.reset_val_dataloader(model=model)


def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
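For context, a minimal usage sketch of the settings the moved methods consume. The `model`, `train_loader` and `val_loader` names are assumed to be user-defined, and the snippet is editorial rather than part of this PR.

import pytorch_lightning as pl

# Editorial sketch: with these settings, `reset_train_dataloader` computes
# num_training_batches = int(0.5 * len(train_loader)) and
# val_check_batch = max(1, int(num_training_batches * 0.25)); the reset methods
# themselves are invoked internally before training starts.
trainer = pl.Trainer(max_epochs=1, limit_train_batches=0.5, val_check_interval=0.25)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)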