@@ -1748,6 +1748,155 @@ def _on_exception(self) -> None:
         file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
         self.save_checkpoint(file_path)
 
+    """
+    Data loading methods
+    """
+
+    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
+        etc.).
+
+        Args:
+            model: The ``LightningModule`` if calling this outside of the trainer scope.
+        """
+        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
+
+        if self.overfit_batches > 0:
+            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
+
+        # automatically add samplers
+        self.train_dataloader = apply_to_collection(
+            self.train_dataloader,
+            DataLoader,
+            self._data_connector._prepare_dataloader,
+            shuffle=True,
+            mode=RunningStage.TRAINING,
+        )
+
+        # check the workers recursively
+        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
+
+        # add worker_init_fn for correct seeding in worker processes
+        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
+
+        # add collate_fn to collect metadata for fault tolerant training
+        if _fault_tolerant_training():
+            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
+
+        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
+        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
+
+        module = model or self.lightning_module or self.datamodule
+        self.num_training_batches = (
+            len(self.train_dataloader)
+            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
+            else float("inf")
+        )
+
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
+        elif self.num_training_batches != float("inf"):
+            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+        elif self.limit_train_batches != 1.0:
+            raise MisconfigurationException(
+                "When using an IterableDataset for `limit_train_batches`,"
+                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
+                " `num_training_batches` to use."
+            )
+
+        # determine when to check validation
+        # if int passed in, val checks that often
+        # otherwise, it checks in [0, 1.0] % range of a training epoch
+        if isinstance(self.val_check_interval, int):
+            self.val_check_batch = self.val_check_interval
+            if self.val_check_batch > self.num_training_batches:
+                raise ValueError(
+                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
+                    f"to the number of the training batches ({self.num_training_batches}). "
+                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
+                )
+        else:
+            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
+                if self.val_check_interval == 1.0:
+                    self.val_check_batch = float("inf")
+                else:
+                    raise MisconfigurationException(
+                        "When using an IterableDataset for `train_dataloader`,"
+                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
+                        " checking validation every k training batches."
+                    )
+            else:
+                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
+                self.val_check_batch = max(1, self.val_check_batch)
+
+        if self.logger and self.num_training_batches < self.log_every_n_steps:
+            rank_zero_warn(
+                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
+                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
+                " you want to see logs for the training epoch.",
+                category=PossibleUserWarning,
+            )
+
+        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+        self._last_train_dl_reload_epoch = self.current_epoch
+
+    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the validation dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._val_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("validation_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.VALIDATING, model=pl_module
+            )
+
+            # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+            self._last_val_dl_reload_epoch = self.current_epoch
+
+    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the test dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._test_dataloader_source
+        pl_module = self.lightning_module or model
+        has_step = is_overridden("test_step", pl_module)
+        if source.is_defined() and has_step:
+            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.TESTING, model=pl_module
+            )
+
+    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets the predict dataloader and determines the number of batches.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        source = self._data_connector._predict_dataloader_source
+        pl_module = self.lightning_module or model
+        if source.is_defined():
+            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
+                RunningStage.PREDICTING, model=pl_module
+            )
+
+    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
+        """Resets train and val dataloaders if none are attached to the trainer.
+
+        The val dataloader must be initialized before the training loop starts, as the training loop
+        inspects the val dataloader to determine whether to run the evaluation loop.
+
+        Args:
+            model: The ``LightningModule`` if called outside of the trainer scope.
+        """
+        if self.train_dataloader is None:
+            self.reset_train_dataloader(model=model)
+        if self.val_dataloaders is None:
+            self.reset_val_dataloader(model=model)
+
     """
     Accelerator properties
     """
@@ -2378,151 +2527,6 @@ def terminate_on_nan(self, val: bool) -> None:
         )
         self._terminate_on_nan = val  # : 212
 
-    def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the train dataloader and initialises required variables (number of batches, when to validate,
-        etc.).
-
-        Args:
-            model: The ``LightningModule`` if calling this outside of the trainer scope.
-        """
-        self.train_dataloader = self._data_connector._request_dataloader(RunningStage.TRAINING, model=model)
-
-        if self.overfit_batches > 0:
-            self.train_dataloader = self._data_connector._resolve_overfit_batches(self.train_dataloader)
-
-        # automatically add samplers
-        self.train_dataloader = apply_to_collection(
-            self.train_dataloader,
-            DataLoader,
-            self._data_connector._prepare_dataloader,
-            shuffle=True,
-            mode=RunningStage.TRAINING,
-        )
-
-        # check the workers recursively
-        apply_to_collection(self.train_dataloader, DataLoader, self._data_connector._worker_check, "train_dataloader")
-
-        # add worker_init_fn for correct seeding in worker processes
-        apply_to_collection(self.train_dataloader, DataLoader, _auto_add_worker_init_fn, rank=self.global_rank)
-
-        # add collate_fn to collect metadata for fault tolerant training
-        if _fault_tolerant_training():
-            apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)
-
-        # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
-        self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
-
-        module = model or self.lightning_module or self.datamodule
-        self.num_training_batches = (
-            len(self.train_dataloader)
-            if has_len_all_ranks(self.train_dataloader, self.strategy, module)
-            else float("inf")
-        )
-
-        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
-            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
-        elif self.num_training_batches != float("inf"):
-            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-        elif self.limit_train_batches != 1.0:
-            raise MisconfigurationException(
-                "When using an IterableDataset for `limit_train_batches`,"
-                " `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies"
-                " `num_training_batches` to use."
-            )
-
-        # determine when to check validation
-        # if int passed in, val checks that often
-        # otherwise, it checks in [0, 1.0] % range of a training epoch
-        if isinstance(self.val_check_interval, int):
-            self.val_check_batch = self.val_check_interval
-            if self.val_check_batch > self.num_training_batches:
-                raise ValueError(
-                    f"`val_check_interval` ({self.val_check_interval}) must be less than or equal "
-                    f"to the number of the training batches ({self.num_training_batches}). "
-                    "If you want to disable validation set `limit_val_batches` to 0.0 instead."
-                )
-        else:
-            if not has_len_all_ranks(self.train_dataloader, self.strategy, module):
-                if self.val_check_interval == 1.0:
-                    self.val_check_batch = float("inf")
-                else:
-                    raise MisconfigurationException(
-                        "When using an IterableDataset for `train_dataloader`,"
-                        " `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies"
-                        " checking validation every k training batches."
-                    )
-            else:
-                self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
-                self.val_check_batch = max(1, self.val_check_batch)
-
-        if self.logger and self.num_training_batches < self.log_every_n_steps:
-            rank_zero_warn(
-                f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"
-                f" Trainer(log_every_n_steps={self.log_every_n_steps}). Set a lower value for log_every_n_steps if"
-                " you want to see logs for the training epoch.",
-                category=PossibleUserWarning,
-            )
-
-        # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-        self._last_train_dl_reload_epoch = self.current_epoch
-
-    def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the validation dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._val_dataloader_source
-        pl_module = self.lightning_module or model
-        has_step = is_overridden("validation_step", pl_module)
-        if source.is_defined() and has_step:
-            self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.VALIDATING, model=pl_module
-            )
-
-            # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-            self._last_val_dl_reload_epoch = self.current_epoch
-
-    def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the test dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._test_dataloader_source
-        pl_module = self.lightning_module or model
-        has_step = is_overridden("test_step", pl_module)
-        if source.is_defined() and has_step:
-            self.num_test_batches, self.test_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.TESTING, model=pl_module
-            )
-
-    def reset_predict_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets the predict dataloader and determines the number of batches.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        source = self._data_connector._predict_dataloader_source
-        pl_module = self.lightning_module or model
-        if source.is_defined():
-            self.num_predict_batches, self.predict_dataloaders = self._data_connector._reset_eval_dataloader(
-                RunningStage.PREDICTING, model=pl_module
-            )
-
-    def reset_train_val_dataloaders(self, model: Optional["pl.LightningModule"] = None) -> None:
-        """Resets train and val dataloaders if none are attached to the trainer.
-
-        The val dataloader must be initialized before the training loop starts, as the training loop
-        inspects the val dataloader to determine whether to run the evaluation loop.
-
-        Args:
-            model: The ``LightningModule`` if called outside of the trainer scope.
-        """
-        if self.train_dataloader is None:
-            self.reset_train_dataloader(model=model)
-        if self.val_dataloaders is None:
-            self.reset_val_dataloader(model=model)
-
 
 def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1:
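
Note on the `val_check_interval` branch moved in this diff: an int means "validate every k training batches", a float means a fraction of the training epoch, and length-less (iterable) loaders support only `1.0`. A minimal sketch mirroring that branch structure, with hypothetical names that are not part of this diff:

from typing import Union

def resolve_val_check_batch(
    interval: Union[int, float], num_training_batches: float, has_len: bool
) -> float:
    if isinstance(interval, int):
        # Validate every `interval` training batches.
        if interval > num_training_batches:
            raise ValueError("`val_check_interval` exceeds the number of training batches")
        return interval
    if not has_len:
        # Iterable (length-less) datasets only support interval == 1.0.
        if interval == 1.0:
            return float("inf")
        raise ValueError("`val_check_interval` must be `1.0` or an int for an IterableDataset")
    # Interpret a float as a fraction of the training epoch, at least one batch.
    return max(1, int(num_training_batches * interval))

assert resolve_val_check_batch(0.5, 100, has_len=True) == 50
assert resolve_val_check_batch(20, 100, has_len=True) == 20
assert resolve_val_check_batch(1.0, float("inf"), has_len=False) == float("inf")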