Skip to content

Commit 8c3e0de

Browse files
committed — commit message: "updates"
1 parent ad58a55 commit 8c3e0de

File tree

4 files changed

+14
-6
lines changed

4 files changed

+14
-6
lines changed

pytorch_lightning/callbacks/batch_size_finder.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,6 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
110110
f"Field {self.batch_arg_name} not found in both `model` and `model.hparams`"
111111
)
112112

113-
if not lightning_hasattr(pl_module, self.batch_arg_name):
114-
raise MisconfigurationException(
115-
f"Field {self.batch_arg_name} not found in both `model` and `model.hparams`"
116-
)
117-
118113
if (
119114
hasattr(pl_module, self.batch_arg_name)
120115
and hasattr(pl_module, "hparams")
@@ -126,6 +121,7 @@ def scale_batch_size(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule
126121
" If this is not the intended behavior, please remove either one."
127122
)
128123

124+
# TODO: check if this can be enabled (#4040)
129125
if not trainer._data_connector._train_dataloader_source.is_module():
130126
raise MisconfigurationException(
131127
"The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`."
@@ -329,6 +325,7 @@ def _reset_params(self, trainer: "pl.Trainer") -> None:
329325
trainer.limit_predict_batches = self.steps_per_trial
330326

331327
def _restore_params(self, trainer: "pl.Trainer") -> None:
328+
# TODO: There are more states that needs to be reset (#4512 and #4870)
332329
from pytorch_lightning.trainer.states import TrainerFn
333330

334331
trainer.logger = self._dumped_params["logger"]
@@ -350,6 +347,7 @@ def _restore_params(self, trainer: "pl.Trainer") -> None:
350347
trainer.limit_predict_batches = self._dumped_params["limit_eval_batches"]
351348

352349
loop.load_state_dict(deepcopy(self._dumped_params["loop_state_dict"]))
350+
loop.restarting = False
353351
if "loop_verbose" in self._dumped_params:
354352
loop.verbose = self._dumped_params["loop_verbose"]
355353

pytorch_lightning/loops/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from pytorch_lightning.trainer.progress import BaseProgress
2424
from pytorch_lightning.utilities.enums import _FaultTolerantMode
2525
from pytorch_lightning.utilities.exceptions import MisconfigurationException
26-
from pytorch_lightning.utilities.imports import _fault_tolerant_training
2726

2827
T = TypeVar("T") # the output type of `run`
2928

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,12 +225,16 @@ def restore_loops(self) -> None:
225225
if state_dict is not None and self.trainer.state.fn != TrainerFn.TUNING:
226226
if self.trainer.state.fn == TrainerFn.FITTING:
227227
self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
228+
self.trainer.fit_loop.restarting = True
228229
elif self.trainer.state.fn == TrainerFn.VALIDATING:
229230
self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
231+
self.trainer.validate_loop.restarting = True
230232
elif self.trainer.state.fn == TrainerFn.TESTING:
231233
self.trainer.test_loop.load_state_dict(state_dict["test_loop"])
234+
self.trainer.test_loop.restarting = True
232235
elif self.trainer.state.fn == TrainerFn.PREDICTING:
233236
self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"])
237+
self.trainer.predict_loop.restarting = True
234238

235239
if self.trainer.state.fn != TrainerFn.FITTING:
236240
return

tests/loops/test_loops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
283283
state_dict["loop_child.state_dict"]["a"] = 3
284284
# check restarting after `load_state_dict`
285285
loop_parent.load_state_dict(state_dict)
286+
loop_parent.restarting = True
286287
assert loop_parent.restarting
287288

288289
loop_parent.run()
@@ -306,6 +307,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:
306307
loop_child = Simple(2)
307308
loop_parent.loop_child = loop_child
308309
loop_parent.load_state_dict(state_dict)
310+
loop_parent.restarting = True
309311
assert loop_parent.progress.increment == 1
310312
assert loop_parent.loop_child.progress.increment == 1
311313

@@ -359,6 +361,7 @@ def val_dataloader(self):
359361
assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected
360362

361363
trainer.fit_loop.load_state_dict(checkpoint)
364+
trainer.fit_loop.restarting = True
362365

363366
# `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch
364367
# the fit-validation total batch progress is reset per epoch so it's not counted for the total value.
@@ -548,6 +551,7 @@ def configure_optimizers_multiple(self):
548551
assert checkpoint["loops"]["fit_loop"] == expected
549552

550553
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
554+
trainer.fit_loop.restarting = True
551555
state_dict = trainer.fit_loop.state_dict()
552556

553557
# need to remove these elements for comparison; comparing with `fit_loop.state_dict()` would require the
@@ -557,6 +561,7 @@ def configure_optimizers_multiple(self):
557561
assert state_dict == checkpoint["loops"]["fit_loop"]
558562

559563
trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
564+
trainer.fit_loop.restarting = True
560565
# test resetting manually, we expect all `ready` counters to be reset to `completed`
561566
trainer.fit_loop.reset()
562567
trainer.fit_loop.epoch_loop.reset()
@@ -753,6 +758,7 @@ def test_fit_loop_reset(tmpdir):
753758

754759
# we load exactly what was saved - no reset yet
755760
fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
761+
fit_loop.restarting = True
756762
# resetting from a mid-of-epoch checkpoint SHOULD NOT reset the current counters to 0
757763
fit_loop.reset()
758764
epoch_loop.reset()
@@ -785,6 +791,7 @@ def test_fit_loop_reset(tmpdir):
785791

786792
# we load exactly what was saved - no reset yet
787793
fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
794+
fit_loop.restarting = True
788795
# resetting from a end-of-epoch checkpoint SHOULD reset the current counters to 0
789796
fit_loop.reset()
790797
epoch_loop.reset()

0 commit comments

Comments (0)