Commit 05dd247

fix mypy

1 parent bdd7a88 commit 05dd247

File tree

3 files changed: +45, -42 lines

pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 23 additions & 11 deletions
@@ -21,7 +21,7 @@
 import os
 import uuid
 from copy import deepcopy
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple, TypedDict, Union

 from torch.utils.data.dataloader import DataLoader

@@ -87,6 +87,18 @@ def __init__(

         self._early_exit = False

+        from pytorch_lightning.loggers.base import LightningLoggerBase
+
+        class _BatchSizeFinderDumpedParams(TypedDict):
+            callbacks: List[Callback]
+            logger: Optional[LightningLoggerBase]
+            max_steps: int
+            global_step: Optional[None]
+            limit_val_batches: Union[int, float]
+            limit_eval_batches: Union[int, float]
+
+        self._dumped_params: _BatchSizeFinderDumpedParams = {}
+
     def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if trainer.fast_dev_run:
             rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.")
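
The core of the mypy fix is the _BatchSizeFinderDumpedParams TypedDict introduced above: annotating self._dumped_params with it lets mypy check the value type behind each key instead of falling back to Dict[str, Any]. A minimal stand-alone sketch of the pattern (illustrative names only, not the callback's actual code; total=False is assumed here so that the empty-dict initialisation type-checks):

from typing import List, Optional, TypedDict, Union

class _DumpedParams(TypedDict, total=False):
    # Each key carries its own value type, so mypy can flag bad assignments.
    max_steps: int
    limit_eval_batches: Union[int, float]
    callbacks: List[str]   # stand-in element type for this sketch
    logger: Optional[str]  # stand-in for Optional[LightningLoggerBase]

# total=False makes every key optional, so starting from {} is valid:
params: _DumpedParams = {}
params["max_steps"] = 100            # ok: int matches the declared type
params["limit_eval_batches"] = 0.25  # ok: float is allowed by the Union
# params["max_steps"] = "100"        # mypy: incompatible type "str"; expected "int"

Consolidating the three per-stage keys into a single limit_eval_batches key, as the next hunks do, is what makes a fixed-key TypedDict workable here.
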
@@ -284,13 +296,13 @@ def _dump_params(self, trainer: "pl.Trainer") -> None:
             self._dumped_params["limit_val_batches"] = trainer.limit_val_batches
         elif trainer.state.fn == TrainerFn.VALIDATING:
             loop = trainer.validate_loop
-            self._dumped_params["limit_val_batches"] = trainer.limit_val_batches
+            self._dumped_params["limit_eval_batches"] = trainer.limit_val_batches
         elif trainer.state.fn == TrainerFn.TESTING:
             loop = trainer.test_loop
-            self._dumped_params["limit_test_batches"] = trainer.limit_test_batches
+            self._dumped_params["limit_eval_batches"] = trainer.limit_test_batches
         elif trainer.state.fn == TrainerFn.PREDICTING:
             loop = trainer.predict_loop
-            self._dumped_params["limit_predict_batches"] = trainer.limit_predict_batches
+            self._dumped_params["limit_eval_batches"] = trainer.limit_predict_batches

         self._dumped_params["loop_state_dict"] = deepcopy(loop.state_dict(force_save_progress=True))
         if hasattr(loop, "verbose"):
@@ -328,13 +340,13 @@ def _restore_params(self, trainer: "pl.Trainer") -> None:
             trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
         elif trainer.state.fn == TrainerFn.VALIDATING:
             loop = trainer.validate_loop
-            trainer.limit_val_batches = self._dumped_params["limit_val_batches"]
+            trainer.limit_val_batches = self._dumped_params["limit_eval_batches"]
         elif trainer.state.fn == TrainerFn.TESTING:
             loop = trainer.test_loop
-            trainer.limit_test_batches = self._dumped_params["limit_test_batches"]
+            trainer.limit_test_batches = self._dumped_params["limit_eval_batches"]
         elif trainer.state.fn == TrainerFn.PREDICTING:
             loop = trainer.predict_loop
-            trainer.limit_predict_batches = self._dumped_params["limit_predict_batches"]
+            trainer.limit_predict_batches = self._dumped_params["limit_eval_batches"]

         loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
         if "loop_verbose" in self._dumped_params:
@@ -386,18 +398,18 @@ def _adjust_batch_size(
         if desc:
             rank_zero_info(f"Batch size {batch_size} {desc}, trying batch size {new_size}")

-        # TODO improve this for CombinedLoader and multi dataloaders
+        # TODO improve this for eval CombinedLoader and multi dataloaders
         if trainer.state.fn == TrainerFn.FITTING:
             if not self._is_valid_batch_size(new_size, trainer.train_dataloader, trainer):
                 new_size = min(new_size, len(trainer.train_dataloader.dataset))
         if trainer.state.fn == TrainerFn.VALIDATING:
-            if not self._is_valid_batch_size(new_size, trainer.val_dataloaders, trainer):
+            if not self._is_valid_batch_size(new_size, trainer.val_dataloaders[0], trainer):
                 new_size = min(new_size, len(trainer.val_dataloaders[0].dataset))
         if trainer.state.fn == TrainerFn.TESTING:
-            if not self._is_valid_batch_size(new_size, trainer.test_dataloaders, trainer):
+            if not self._is_valid_batch_size(new_size, trainer.test_dataloaders[0], trainer):
                 new_size = min(new_size, len(trainer.test_dataloaders[0].dataset))
         if trainer.state.fn == TrainerFn.PREDICTING:
-            if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders, trainer):
+            if not self._is_valid_batch_size(new_size, trainer.predict_dataloaders[0], trainer):
                 new_size = min(new_size, len(trainer.predict_dataloaders[0].dataset))

         changed = new_size != batch_size
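
The [0] indexing in the last hunk lines the validity check up with the len(...[0].dataset) clamp on the following line: on this Lightning branch, trainer.val_dataloaders, test_dataloaders and predict_dataloaders hold lists of loaders, so the check has to be handed a single DataLoader. A rough stand-alone illustration with plain torch objects (no Trainer involved):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Eval dataloaders live in a list, so size checks address one element.
val_dataloaders = [DataLoader(TensorDataset(torch.zeros(10, 1)), batch_size=4)]

new_size = 32
# Mirrors the clamp above: never suggest a batch size larger than the dataset.
if new_size > len(val_dataloaders[0].dataset):
    new_size = min(new_size, len(val_dataloaders[0].dataset))
print(new_size)  # -> 10
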

setup.cfg

Lines changed: 4 additions & 4 deletions
@@ -24,13 +24,13 @@ addopts =
   --doctest-modules
   --color=yes
   --disable-pytest-warnings
-filterwarnings =
+# filterwarnings =
   # error out on our deprecation warnings - ensures the code and tests are kept up-to-date
-  error::pytorch_lightning.utilities.warnings.LightningDeprecationWarning
+  # error::pytorch_lightning.utilities.warnings.LightningDeprecationWarning
   # warnings from deprecated modules on import
   # TODO: remove in 1.7
-  ignore::pytorch_lightning.utilities.warnings.LightningDeprecationWarning:pytorch_lightning.core.decorators
-  ignore::pytorch_lightning.utilities.warnings.LightningDeprecationWarning:pytorch_lightning.core.memory
+  # ignore::pytorch_lightning.utilities.warnings.LightningDeprecationWarning:pytorch_lightning.core.decorators
+  # ignore::pytorch_lightning.utilities.warnings.LightningDeprecationWarning:pytorch_lightning.core.memory

 junit_duration_report = call

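Commenting out the filterwarnings block restores pytest's default behaviour: LightningDeprecationWarning is still reported, but no longer escalated into a test failure. The error::<WarningClass> syntax maps onto the standard library's warning filters; a sketch of the same mechanism, using a generic stand-in warning class:

import warnings

class MyDeprecationWarning(DeprecationWarning):  # stand-in for the Lightning class
    pass

# Equivalent in spirit to `filterwarnings = error::MyDeprecationWarning`:
# matching warnings are raised as exceptions instead of being printed.
warnings.filterwarnings("error", category=MyDeprecationWarning)

try:
    warnings.warn("deprecated API", MyDeprecationWarning)
except MyDeprecationWarning as exc:
    print(f"escalated to an error: {exc}")
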
tests/tuner/test_scale_batch_size.py

Lines changed: 18 additions & 27 deletions
@@ -79,54 +79,45 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_bs):
     assert trainer.train_dataloader.loaders.batch_size == new_batch_size


-def test_model_reset_correctly(tmpdir):
-    """Check that model weights are correctly reset after scaling batch size."""
+@pytest.mark.parametrize("trainer_fn", ["fit", "validate", "test", "predict"])
+def test_trainer_reset_correctly(tmpdir, trainer_fn):
+    """Check that model and all trainer parameters are reset correctly after scaling batch size."""
     tutils.reset_seed()

     model = BatchSizeModel(batch_size=2)
-
-    # logger file to get meta
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
-
     before_state_dict = deepcopy(model.state_dict())

-    trainer.tuner.scale_batch_size(model, max_trials=5)
-
-    after_state_dict = model.state_dict()
-
-    for key in before_state_dict.keys():
-        assert torch.all(
-            torch.eq(before_state_dict[key], after_state_dict[key])
-        ), "Model was not reset correctly after scaling batch size"
-
-    assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size_temp_model"))
-
-
-def test_trainer_reset_correctly(tmpdir):
-    """Check that all trainer parameters are reset correctly after scaling batch size."""
-    tutils.reset_seed()
-
-    model = BatchSizeModel(batch_size=2)
-
     # logger file to get meta
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

     changed_attributes = [
+        "logger",
+        "callbacks",
         "global_step",
         "limit_val_batches",
         "max_steps",
-        "logger",
-        "callbacks",
+        "limit_val_batches",
+        "limit_test_batches",
+        "limit_predict_batches",
     ]
+
     expected = {ca: getattr(trainer, ca) for ca in changed_attributes}
     expected_loop_state_dict = trainer.fit_loop.state_dict()
-    trainer.tuner.scale_batch_size(model, max_trials=64)
+    trainer.tuner.scale_batch_size(model, max_trials=64, method=trainer_fn)
     actual = {ca: getattr(trainer, ca) for ca in changed_attributes}
     actual_loop_state_dict = trainer.fit_loop.state_dict()

     assert expected_loop_state_dict == actual_loop_state_dict
     assert actual == expected

+    after_state_dict = model.state_dict()
+    for key in before_state_dict.keys():
+        assert torch.all(
+            torch.eq(before_state_dict[key], after_state_dict[key])
+        ), "Model was not reset correctly after scaling batch size"
+
+    assert not any(f for f in os.listdir(tmpdir) if f.startswith(".scale_batch_size_temp_model"))
+

 @RunIf(min_gpus=1)
 @pytest.mark.parametrize("scale_arg", ["power", "binsearch", True])
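
The merged test keeps the old weight-reset assertion and now runs it for all four trainer entry points via method=trainer_fn. The tensor-comparison idiom it relies on, in isolation (toy state dicts, no Trainer):

import torch

def state_dicts_equal(before: dict, after: dict) -> bool:
    # torch.all(torch.eq(a, b)) is True only if every element matches.
    return all(bool(torch.all(torch.eq(before[k], after[k]))) for k in before)

before = {"layer.weight": torch.ones(2, 2)}
after = {"layer.weight": torch.ones(2, 2)}
assert state_dicts_equal(before, after)  # weights were restored exactly
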
