Commit 57567ed

Move newly added Trainer methods to be with other methods (#11335)
1 parent 42a1c72 commit 57567ed

File tree

1 file changed: +149 −145 lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 149 additions & 145 deletions
@@ -1748,6 +1748,155 @@ def _on_exception(self) -> None:
         file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
         self.save_checkpoint(file_path)
 
+    """
+    Data loading methods
+    """
+
+    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
+        etc.).
+
+        Args:
+            model: The ``LightningModule`` if calling this outside of the trainer scope.
+        """
+        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
+
+        if self.overfit_batches > 0:
+            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
+
+        # automatically add samplers
+        self.train_dataloader = apply_to_collection(
+            self.train_dataloader,
+            DataLoader,
+            self._data_connector._prepare_dataloader,
+            shuffle=True,
+            mode=RunningStage.TRAINING,
+        )
+
+        # check the workers recursively
+        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
+
+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
+
+        # add collate_fn to collect metadata for fault tolerant training
+        if _fault_tolerant_training():
+            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
+
+        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
+
+        module = model or self.lightning_module or self.datamodule
+        self.num_training_batches = (
+            len(self.train_dataloader)
+            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
+            else float("inf")
+        )
+
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
+        elif self.num_training_batches != float("inf"):
+            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+        elif self.limit_train_batches != 1.0:
+            raise MisconfigurationException(
+                "When using an IterableDataset for `limit_train_batches`,"
+                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
+                " `num_training_batches` to use."
+            )
+
+        # determine when to check validation
+        # if int passed in, val checks that often
+        # otherwise, it checks in [0, 1.0] % range of a training epoch
+        if isinstance(self.val_check_interval, int):
+            self.val_check_batch = self.val_check_interval
+            if self.val_check_batch > self.num_training_batches:
+                raise ValueError(
+                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
+                    f"to the number of the training batches ({self.num_training_batches}). "
+                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                )
+        else:
+            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float("inf")
+                else:
+                    raise MisconfigurationException(
+                        "When using an IterableDataset for `train_dataloader`,"
+                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                        " checking validation every k training batches."
+                    )
+            else:
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)
+
+        if self.logger and self.num_training_batches < self.log_every_n_steps:
+            rank_zero_warn(
+                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
+                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
+                " you want to see logs for the training epoch.",
+                category=PossibleUserWarning,
+            )
+
+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_train_dl_reload_epoch = self.current_epoch
+
+    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the validation dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._val_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("validation_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.VALIDATING, model=pl_module
+            )
+
+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_val_dl_reload_epoch = self.current_epoch
+
+    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the test dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._test_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("test_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.TESTING, model=pl_module
+            )
+
+    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the predict dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._predict_dataloader_source
+        pl_module = self.lightning_module or model
+        if source.is_defined():
+            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.PREDICTING, model=pl_module
+            )
+
+    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets train and val dataloaders if none are attached to the trainer.
+
+        The val dataloader must be initialized before training loop starts, as the training loop
+        inspects the val dataloader to determine whether to run the evaluation loop.
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        if self.train_dataloader is None:
+            self.reset_train_dataloader(model=model)
+        if self.val_dataloaders is None:
+            self.reset_val_dataloader(model=model)
+
     """
     Accelerator properties
     """
@@ -2378,151 +2527,6 @@ def terminate_on_nan(self, val: bool) -> None:
         )
         self._terminate_on_nan = val  # : 212
 
-    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
-        etc.).
-
-        Args:
-            model: The ``LightningModule`` if calling this outside of the trainer scope.
-        """
-        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
-
-        if self.overfit_batches > 0:
-            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
-
-        # automatically add samplers
-        self.train_dataloader = apply_to_collection(
-            self.train_dataloader,
-            DataLoader,
-            self._data_connector._prepare_dataloader,
-            shuffle=True,
-            mode=RunningStage.TRAINING,
-        )
-
-        # check the workers recursively
-        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
-
-        # add worker_init_fn for correct seeding in worker processes
-        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
-
-        # add collate_fn to collect metadata for fault tolerant training
-        if _fault_tolerant_training():
-            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
-
-        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
-        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
-
-        module = model or self.lightning_module or self.datamodule
-        self.num_training_batches = (
-            len(self.train_dataloader)
-            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
-            else float("inf")
-        )
-
-        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
-            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
-        elif self.num_training_batches != float("inf"):
-            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-        elif self.limit_train_batches != 1.0:
-            raise MisconfigurationException(
-                "When using an IterableDataset for `limit_train_batches`,"
-                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
-                " `num_training_batches` to use."
-            )
-
-        # determine when to check validation
-        # if int passed in, val checks that often
-        # otherwise, it checks in [0, 1.0] % range of a training epoch
-        if isinstance(self.val_check_interval, int):
-            self.val_check_batch = self.val_check_interval
-            if self.val_check_batch > self.num_training_batches:
-                raise ValueError(
-                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
-                    f"to the number of the training batches ({self.num_training_batches}). "
-                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
-                )
-        else:
-            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
-                if self.val_check_interval == 1.0:
-                    self.val_check_batch = float("inf")
-                else:
-                    raise MisconfigurationException(
-                        "When using an IterableDataset for `train_dataloader`,"
-                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
-                        " checking validation every k training batches."
-                    )
-            else:
-                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-                self.val_check_batch = max(1, self.val_check_batch)
-
-        if self.logger and self.num_training_batches < self.log_every_n_steps:
-            rank_zero_warn(
-                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
-                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
-                " you want to see logs for the training epoch.",
-                category=PossibleUserWarning,
-            )
-
-        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-        self._last_train_dl_reload_epoch = self.current_epoch
-
-    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the validation dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._val_dataloader_source
-        pl_module = self.lightning_module or model
-        has_step = is_overridden("validation_step", pl_module)
-        if source.is_defined() and has_step:
-            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.VALIDATING, model=pl_module
-            )
-
-        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-        self._last_val_dl_reload_epoch = self.current_epoch
-
-    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the test dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._test_dataloader_source
-        pl_module = self.lightning_module or model
-        has_step = is_overridden("test_step", pl_module)
-        if source.is_defined() and has_step:
-            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.TESTING, model=pl_module
-            )
-
-    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the predict dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._predict_dataloader_source
-        pl_module = self.lightning_module or model
-        if source.is_defined():
-            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.PREDICTING, model=pl_module
-            )
-
-    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets train and val dataloaders if none are attached to the trainer.
-
-        The val dataloader must be initialized before training loop starts, as the training loop
-        inspects the val dataloader to determine whether to run the evaluation loop.
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        if self.train_dataloader is None:
-            self.reset_train_dataloader(model=model)
-        if self.val_dataloaders is None:
-            self.reset_val_dataloader(model=model)
-
 
 def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1:
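
The relocated `reset_train_dataloader` calls `apply_to_collection` three times (adding samplers, checking workers, seeding worker processes), always with the same pattern: apply a callable to every `DataLoader` found in a possibly nested collection and leave everything else untouched. A small self-contained illustration, assuming a PyTorch Lightning release from around this commit (the `pytorch_lightning.utilities.apply_func` import path is the one trainer.py uses here; later releases moved the helper elsewhere):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.utilities.apply_func import apply_to_collection

# A nested collection mixing DataLoaders with other values.
loaders = {
    "main": DataLoader(TensorDataset(torch.arange(8.0)), batch_size=2),
    "extra": [DataLoader(TensorDataset(torch.arange(4.0)), batch_size=1), "not a loader"],
}

# Apply `len` to every DataLoader in the collection; non-DataLoader entries pass through unchanged.
batch_counts = apply_to_collection(loaders, DataLoader, len)
print(batch_counts)  # {'main': 4, 'extra': [4, 'not a loader']}
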
