27 | 27 | import torch
28 | 28 | from packaging.version import Version
29 | 29 | from torch.optim import Optimizer
| 30 | +from torch.utils.data import DataLoader
30 | 31 |
31 | 32 | import pytorch_lightning as pl
32 | 33 | from pytorch_lightning.accelerators import Accelerator, IPUAccelerator

70 | 71 | from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
71 | 72 | from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
72 | 73 | from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
| 74 | +from pytorch_lightning.trainer.supporters import CombinedLoader
73 | 75 | from pytorch_lightning.tuner.lr_finder import _LRFinder
74 | 76 | from pytorch_lightning.tuner.tuning import Tuner
75 | 77 | from pytorch_lightning.utilities import (

85 | 87 |     rank_zero_info,
86 | 88 |     rank_zero_warn,
87 | 89 | )
| 90 | +from pytorch_lightning.utilities.apply_func import apply_to_collection
88 | 91 | from pytorch_lightning.utilities.argparse import (
89 | 92 |     _defaults_from_env_vars,
90 | 93 |     add_argparse_args,
91 | 94 |     from_argparse_args,
92 | 95 |     parse_argparser,
93 | 96 |     parse_env_variables,
94 | 97 | )
| 98 | +from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
95 | 99 | from pytorch_lightning.utilities.cloud_io import get_filesystem
| 100 | +from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
96 | 101 | from pytorch_lightning.utilities.distributed import distributed_available
97 | 102 | from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
98 | 103 | from pytorch_lightning.utilities.imports import _fault_tolerant_training
119 | 124 |
120 | 125 |
121 | 126 | class Trainer(
122 | | -    TrainerCallbackHookMixin,
123 | | -    TrainerOptimizersMixin,
124 | | -    TrainerDataLoadingMixin,
| 127 | +    TrainerCallbackHookMixin,  # TODO: Remove in v1.8
| 128 | +    TrainerOptimizersMixin,  # TODO: Remove in v1.8
| 129 | +    TrainerDataLoadingMixin,  # TODO: Remove in v1.8
125 | 130 | ):
126 | 131 |     # Needed because of LightningOptimizer
127 | 132 |     _lightning_optimizers = None
@@ -2372,6 +2377,151 @@ def terminate_on_nan(self, val: bool) -> None:
2372 | 2377 |         )
2373 | 2378 |         self._terminate_on_nan = val  # : 212
2374 | 2379 |
| 2380 | +    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
| 2381 | +        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
| 2382 | +        etc.).
| 2383 | +
| 2384 | +        Args:
| 2385 | +            model: The ``LightningModule`` if calling this outside of the trainer scope.
| 2386 | +        """
| 2387 | +        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
| 2388 | +
| 2389 | +        if self.overfit_batches > 0:
| 2390 | +            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
| 2391 | +
| 2392 | +        # automatically add samplers
| 2393 | +        self.train_dataloader = apply_to_collection(
| 2394 | +            self.train_dataloader,
| 2395 | +            DataLoader,
| 2396 | +            self._data_connector._prepare_dataloader,
| 2397 | +            shuffle=True,
| 2398 | +            mode=RunningStage.TRAINING,
| 2399 | +        )
| 2400 | +
| 2401 | +        # check the workers recursively
| 2402 | +        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
| 2403 | +
| 2404 | +        # add worker_init_fn for correct seeding in worker processes
| 2405 | +        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
| 2406 | +
| 2407 | +        # add collate_fn to collect metadata for fault tolerant training
| 2408 | +        if _fault_tolerant_training():
| 2409 | +            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
| 2410 | +
| 2411 | +        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
| 2412 | +        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
| 2413 | +
| 2414 | +        module = model or self.lightning_module or self.datamodule
| 2415 | +        self.num_training_batches = (
| 2416 | +            len(self.train_dataloader)
| 2417 | +            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
| 2418 | +            else float("inf")
| 2419 | +        )
| 2420 | +
| 2421 | +        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
| 2422 | +            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
| 2423 | +        elif self.num_training_batches != float("inf"):
| 2424 | +            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
| 2425 | +        elif self.limit_train_batches != 1.0:
| 2426 | +            raise MisconfigurationException(
| 2427 | +                "When using an IterableDataset for `limit_train_batches`,"
| 2428 | +                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
| 2429 | +                " `num_training_batches` to use."
| 2430 | +            )
| 2431 | +
| 2432 | +        # determine when to check validation
| 2433 | +        # if int passed in, val checks that often
| 2434 | +        # otherwise, it checks in [0, 1.0] % range of a training epoch
| 2435 | +        if isinstance(self.val_check_interval, int):
| 2436 | +            self.val_check_batch = self.val_check_interval
| 2437 | +            if self.val_check_batch > self.num_training_batches:
| 2438 | +                raise ValueError(
| 2439 | +                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
| 2440 | +                    f"to the number of the training batches ({self.num_training_batches}). "
| 2441 | +                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
| 2442 | +                )
| 2443 | +        else:
| 2444 | +            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
| 2445 | +                if self.val_check_interval == 1.0:
| 2446 | +                    self.val_check_batch = float("inf")
| 2447 | +                else:
| 2448 | +                    raise MisconfigurationException(
| 2449 | +                        "When using an IterableDataset for `train_dataloader`,"
| 2450 | +                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
| 2451 | +                        " checking validation every k training batches."
| 2452 | +                    )
| 2453 | +            else:
| 2454 | +                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
| 2455 | +                self.val_check_batch = max(1, self.val_check_batch)
| 2456 | +
| 2457 | +        if self.logger and self.num_training_batches < self.log_every_n_steps:
| 2458 | +            rank_zero_warn(
| 2459 | +                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
| 2460 | +                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
| 2461 | +                " you want to see logs for the training epoch.",
| 2462 | +                category=PossibleUserWarning,
| 2463 | +            )
| 2464 | +
| 2465 | +        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
| 2466 | +        self._last_train_dl_reload_epoch = self.current_epoch
| 2467 | +
| 2468 | +    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
| 2469 | +        """Resets the validation dataloader and determines the number of batches.
| 2470 | +
| 2471 | +        Args:
| 2472 | +            model: The ``LightningModule`` if called outside of the trainer scope.
| 2473 | +        """
| 2474 | +        source = self._data_connector._val_dataloader_source
| 2475 | +        pl_module = self.lightning_module or model
| 2476 | +        has_step = is_overridden("validation_step", pl_module)
| 2477 | +        if source.is_defined() and has_step:
| 2478 | +            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
| 2479 | +                RunningStage.VALIDATING, model=pl_module
| 2480 | +            )
| 2481 | +
| 2482 | +        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
| 2483 | +        self._last_val_dl_reload_epoch = self.current_epoch
| 2484 | +
| 2485 | +    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
| 2486 | +        """Resets the test dataloader and determines the number of batches.
| 2487 | +
| 2488 | +        Args:
| 2489 | +            model: The ``LightningModule`` if called outside of the trainer scope.
| 2490 | +        """
| 2491 | +        source = self._data_connector._test_dataloader_source
| 2492 | +        pl_module = self.lightning_module or model
| 2493 | +        has_step = is_overridden("test_step", pl_module)
| 2494 | +        if source.is_defined() and has_step:
| 2495 | +            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
| 2496 | +                RunningStage.TESTING, model=pl_module
| 2497 | +            )
| 2498 | +
| 2499 | +    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
| 2500 | +        """Resets the predict dataloader and determines the number of batches.
| 2501 | +
| 2502 | +        Args:
| 2503 | +            model: The ``LightningModule`` if called outside of the trainer scope.
| 2504 | +        """
| 2505 | +        source = self._data_connector._predict_dataloader_source
| 2506 | +        pl_module = self.lightning_module or model
| 2507 | +        if source.is_defined():
| 2508 | +            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
| 2509 | +                RunningStage.PREDICTING, model=pl_module
| 2510 | +            )
| 2511 | +
| 2512 | +    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
| 2513 | +        """Resets train and val dataloaders if none are attached to the trainer.
| 2514 | +
| 2515 | +        The val dataloader must be initialized before training loop starts, as the training loop
| 2516 | +        inspects the val dataloader to determine whether to run the evaluation loop.
| 2517 | +        Args:
| 2518 | +            model: The ``LightningModule`` if called outside of the trainer scope.
| 2519 | +        """
| 2520 | +        if self.train_dataloader is None:
| 2521 | +            self.reset_train_dataloader(model=model)
| 2522 | +        if self.val_dataloaders is None:
| 2523 | +            self.reset_val_dataloader(model=model)
| 2524 | +
2375 | 2525 |
2376 | 2526 | def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
2377 | 2527 |     if 0 <= batches <= 1:
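For context, a minimal sketch (not part of the diff) of what the now-public `reset_train_dataloader` derives during `fit()`: how `limit_train_batches` caps `num_training_batches` and how an integer `val_check_interval` becomes `val_check_batch`. `RandomDataset` and `MyModel` are placeholder classes written for this sketch, and the numbers in the comments assume the dataset and batch sizes shown plus a PyTorch Lightning version that includes this change.

```python
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl


class RandomDataset(Dataset):
    # 64 samples -> 8 batches with batch_size=8, before any limiting
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class MyModel(pl.LightningModule):
    # placeholder LightningModule; only train_dataloader/training_step matter here
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)


trainer = pl.Trainer(max_epochs=1, limit_train_batches=4, val_check_interval=2)
trainer.fit(MyModel())

# reset_train_dataloader caps the batch count via limit_train_batches (int -> min)
# and, with an int val_check_interval, uses it directly as val_check_batch.
print(trainer.num_training_batches)  # 4
print(trainer.val_check_batch)  # 2
```

With a float `val_check_interval`, the other branch applies instead: `val_check_batch` becomes `int(num_training_batches * val_check_interval)`, clamped to at least 1.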