|
21 | 21 | import os
|
22 | 22 | import uuid
|
23 | 23 | from copy import deepcopy
|
24 |
| -from typing import Optional, Tuple |
| 24 | +from typing import List, Optional, Tuple, TypedDict, Union |
25 | 25 |
|
26 | 26 | from torch.utils.data.dataloader import DataLoader
|
27 | 27 |
|
@@ -87,6 +87,18 @@ def __init__(
|
87 | 87 |
|
88 | 88 | self._early_exit = False
|
89 | 89 |
|
| 90 | + from pytorch_lightning.loggers.base import LightningLoggerBase |
| 91 | + |
| 92 | + class _BatchSizeFinderDumpedParams(TypedDict): |
| 93 | + callbacks: List[Callback] |
| 94 | + logger: Optional[LightningLoggerBase] |
| 95 | + max_steps: int |
| 96 | + global_step: Optional[None] |
| 97 | + limit_val_batches: Union[int, float] |
| 98 | + limit_eval_batches: Union[int, float] |
| 99 | + |
| 100 | + self._dumped_params: _BatchSizeFinderDumpedParams = {} |
| 101 | + |
90 | 102 | def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
|
91 | 103 | if trainer.fast_dev_run:
|
92 | 104 | rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
|
@@ -284,13 +296,13 @@ def _dump_params(self, trainer: "pl.Trainer") -> None:
|
284 | 296 | self._dumped_params["limit_val_batches"] = trainer.limit_val_batches
|
285 | 297 | elif trainer.state.fn == TrainerFn.VALIDATING:
|
286 | 298 | loop = trainer.validate_loop
|
287 |
| - self._dumped_params["limit_val_batches"] = trainer.limit_val_batches |
| 299 | + self._dumped_params["limit_eval_batches"] = trainer.limit_val_batches |
288 | 300 | elif trainer.state.fn == TrainerFn.TESTING:
|
289 | 301 | loop = trainer.test_loop
|
290 |
| - self._dumped_params["limit_test_batches"] = trainer.limit_test_batches |
| 302 | + self._dumped_params["limit_eval_batches"] = trainer.limit_test_batches |
291 | 303 | elif trainer.state.fn == TrainerFn.PREDICTING:
|
292 | 304 | loop = trainer.predict_loop
|
293 |
| - self._dumped_params["limit_predict_batches"] = trainer.limit_predict_batches |
| 305 | + self._dumped_params["limit_eval_batches"] = trainer.limit_predict_batches |
294 | 306 |
|
295 | 307 | self._dumped_params["loop_state_dict"] = deepcopy(loop.state_dict(force_save_progress=True))
|
296 | 308 | if hasattr(loop, "verbose"):
|
@@ -328,13 +340,13 @@ def _restore_params(self, trainer: "pl.Trainer") -> None:
|
328 | 340 | trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
|
329 | 341 | elif trainer.state.fn == TrainerFn.VALIDATING:
|
330 | 342 | loop = trainer.validate_loop
|
331 |
| - trainer.limit_val_batches = self._dumped_params["limit_val_batches"] |
| 343 | + trainer.limit_val_batches = self._dumped_params["limit_eval_batches"] |
332 | 344 | elif trainer.state.fn == TrainerFn.TESTING:
|
333 | 345 | loop = trainer.test_loop
|
334 |
| - trainer.limit_test_batches = self._dumped_params["limit_test_batches"] |
| 346 | + trainer.limit_test_batches = self._dumped_params["limit_eval_batches"] |
335 | 347 | elif trainer.state.fn == TrainerFn.PREDICTING:
|
336 | 348 | loop = trainer.predict_loop
|
337 |
| - trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"] |
| 349 | + trainer.limit_predict_batches = self._dumped_params["limit_eval_batches"] |
338 | 350 |
|
339 | 351 | loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
|
340 | 352 | if "loop_verbose" in self._dumped_params:
|
@@ -386,18 +398,18 @@ def _adjust_batch_size(
|
386 | 398 | if desc:
|
387 | 399 | rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")
|
388 | 400 |
|
389 |
| - # TODO improve this for CombinedLoader and multi dataloaders |
| 401 | + # TODO improve this for eval CombinedLoader and multi dataloaders |
390 | 402 | if trainer.state.fn == TrainerFn.FITTING:
|
391 | 403 | if not self._is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
|
392 | 404 | new_size = min(new_size, len(trainer.train_dataloader.dataset))
|
393 | 405 | if trainer.state.fn == TrainerFn.VALIDATING:
|
394 |
| - if not self._is_valid_batch_size(new_size, trainer.val_dataloaders, trainer): |
| 406 | + if not self._is_valid_batch_size(new_size, trainer.val_dataloaders[0], trainer): |
395 | 407 | new_size = min(new_size, len(trainer.val_dataloaders[0].dataset))
|
396 | 408 | if trainer.state.fn == TrainerFn.TESTING:
|
397 |
| - if not self._is_valid_batch_size(new_size, trainer.test_dataloaders, trainer): |
| 409 | + if not self._is_valid_batch_size(new_size, trainer.test_dataloaders[0], trainer): |
398 | 410 | new_size = min(new_size, len(trainer.test_dataloaders[0].dataset))
|
399 | 411 | if trainer.state.fn == TrainerFn.PREDICTING:
|
400 |
| - if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders, trainer): |
| 412 | + if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders[0], trainer): |
401 | 413 | new_size = min(new_size, len(trainer.predict_dataloaders[0].dataset))
|
402 | 414 |
|
403 | 415 | changed = new_size != batch_size
|
|
0 commit comments