
Commit 6ff43cb

fix resuming from checkpoint for fault-tolerant in case of no failure (#9371)
Co-authored-by: Justus Schock <[email protected]>
1 parent 7ca038b commit 6ff43cb

File tree

5 files changed: +239 -5 lines changed

CHANGELOG.md
pytorch_lightning/loops/epoch/training_epoch_loop.py
pytorch_lightning/loops/fit_loop.py
pytorch_lightning/loops/optimization/optimizer_loop.py
tests/loops/test_loops.py

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
@@ -23,8 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
 - Progress tracking
-    * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598)
-    * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320)
+    * Integrate `TrainingEpochLoop.total_batch_idx` ([#8598](https://github.com/PyTorchLightning/pytorch-lightning/pull/8598))
+    * Avoid optional `Tracker` attributes ([#9320](https://github.com/PyTorchLightning/pytorch-lightning/pull/9320))
+    * Reset `current` progress counters when restarting an epoch loop that had already finished ([#9371](https://github.com/PyTorchLightning/pytorch-lightning/pull/9371))
 
 
 - Added `batch_size` and `rank_zero_only` arguments for `log_dict` to match `log` ([#8628](https://github.com/PyTorchLightning/pytorch-lightning/pull/8628))

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def reset(self) -> None:
         # track epoch output
         self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))]
 
-        if not self.restarting:
+        if not self.restarting or self._num_training_batches_reached():
             self.batch_progress.current.reset()
             self.scheduler_progress.current.reset()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_epoch()
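For context, each loop tracks its work in progress objects that keep a monotonically increasing `total` tracker and an epoch-scoped `current` tracker; the condition changed above decides when the `current` side gets cleared. The following is a minimal sketch of that idea, not the actual `pytorch_lightning.trainer.progress` classes — `Tracker`, `Progress`, `reset_on_epoch_start`, and `epoch_finished` are illustrative stand-ins for the real attributes and for `self._num_training_batches_reached()`:

```python
from dataclasses import dataclass, field


@dataclass
class Tracker:
    ready: int = 0
    completed: int = 0

    def reset(self) -> None:
        self.ready = 0
        self.completed = 0


@dataclass
class Progress:
    total: Tracker = field(default_factory=Tracker)    # monotonic across the whole run
    current: Tracker = field(default_factory=Tracker)  # scoped to the epoch in flight

    def increment_ready(self) -> None:
        self.total.ready += 1
        self.current.ready += 1

    def increment_completed(self) -> None:
        self.total.completed += 1
        self.current.completed += 1


def reset_on_epoch_start(progress: Progress, restarting: bool, epoch_finished: bool) -> None:
    # Mirrors the condition changed in this commit: keep the mid-epoch counters
    # when resuming from a mid-epoch checkpoint, but clear them when the previous
    # epoch already ran to completion.
    if not restarting or epoch_finished:
        progress.current.reset()
```

With `restarting=True` and `epoch_finished=True` (the newly handled case), the `current` counters are cleared so the next epoch starts counting from zero instead of inheriting the finished epoch's values.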

pytorch_lightning/loops/fit_loop.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def teardown(self) -> None:
 
     def on_save_checkpoint(self) -> Dict:
         state_dict = super().on_save_checkpoint()
-        # FIXME(@tchaton) Should pass has_completed=True when iterator is exhausted ?
+        # TODO: update has_completed to its proper value
         state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(has_completed=False)
         return state_dict
 
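The `dataloader_state_dict` entry above is produced by the loop's `on_save_checkpoint` hook, which lets a loop persist extra state next to its progress counters. A rough sketch of that hook pair on a custom loop, assuming the `Loop` base class exported from `pytorch_lightning.loops` in this release (the `MyLoop` class and its `seen_batches` field are hypothetical, for illustration only):

```python
from typing import Any, Dict

from pytorch_lightning.loops import Loop


class MyLoop(Loop):
    """Hypothetical loop used only to illustrate the checkpoint hooks overridden above."""

    def __init__(self) -> None:
        super().__init__()
        self.seen_batches = 0

    @property
    def done(self) -> bool:
        return False

    def reset(self) -> None:
        pass

    def advance(self, *args: Any, **kwargs: Any) -> None:
        self.seen_batches += 1

    def on_save_checkpoint(self) -> Dict:
        # Extra, loop-specific state returned here is persisted alongside the
        # progress counters, just like "dataloader_state_dict" in FitLoop above.
        return {"seen_batches": self.seen_batches}

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        self.seen_batches = state_dict["seen_batches"]
```

The dictionary returned by `on_save_checkpoint` is stored under the loop's entry in `checkpoint["loops"]` and is handed back to `on_load_checkpoint` when the loop state is restored.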

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def connect(self, **kwargs: "Loop") -> None:
         raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
 
     def reset(self) -> None:
-        if not self.restarting:
+        if not self.restarting or self.done:
             self.optim_progress.optimizer_idx = 0
             self.outputs = [[] for _ in range(len(self.trainer.optimizers))]
 
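Both `reset` changes in this commit apply the same rule: only clear per-run state when the loop is starting fresh or when its previous run already finished. A toy rendering of that rule follows (not PyTorch Lightning's `Loop` base class; `completed_steps`, `outputs`, and the simple `done` check are placeholders for the real progress attributes):

```python
class RestartAwareLoop:
    """Toy loop illustrating the reset rule from this commit (not PL's Loop base class)."""

    def __init__(self, total_steps: int) -> None:
        self.total_steps = total_steps
        self.restarting = False    # becomes True after the loop state is reloaded
        self.completed_steps = 0   # per-run progress counter
        self.outputs: list = []    # per-run state rebuilt on every fresh start

    @property
    def done(self) -> bool:
        # The real loops check progress counters instead, e.g. whether every
        # optimizer was stepped or every training batch was consumed.
        return self.completed_steps >= self.total_steps

    def reset(self) -> None:
        if not self.restarting or self.done:
            # Fresh start, or restarting a loop that had already finished:
            # stale per-run state must not leak into the next run.
            self.completed_steps = 0
            self.outputs = []
        # When restarting mid-run, the partial state is kept so the loop
        # resumes exactly where the checkpoint left off.
```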

tests/loops/test_loops.py

Lines changed: 233 additions & 0 deletions
@@ -22,6 +22,7 @@
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loops import Loop, TrainingBatchLoop
 from pytorch_lightning.trainer.progress import BaseProgress
 from tests.helpers import BoringModel
@@ -513,3 +514,235 @@ def configure_optimizers_multiple(self):
     assert state_dict != checkpoint["loops"]["fit_loop"]
     assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
     assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@pytest.mark.parametrize("n_optimizers", (1, 3, 5))
+@RunIf(min_torch="1.7.0")
+def test_loop_state_on_complete_run(n_optimizers, tmpdir):
+    n_epochs = 3
+    n_batches = 3
+    accumulate_grad_batches = 1
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            if n_optimizers > 1:
+                self.configure_optimizers = self.configure_optimizers_multiple
+
+        def training_step(self, batch, batch_idx, optimizer_idx=0):
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers_multiple(self):
+            optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)]
+
+            lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1)
+            lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1)
+            # no scheduler for optimizer_2
+            lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}]
+
+            return optimizers, lr_schedulers
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=n_epochs,
+        limit_train_batches=n_batches,
+        limit_val_batches=0,
+        accumulate_grad_batches=accumulate_grad_batches,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        checkpoint_callback=True,
+    )
+    trainer.fit(model)
+
+    ckpt_path = trainer.checkpoint_callback.best_model_path
+    assert os.path.exists(ckpt_path)
+    checkpoint = torch.load(ckpt_path)
+
+    n_sch_steps_total = n_epochs
+    n_sch_steps_current = 1
+    if n_optimizers > 1:
+        n_sch_steps_total = n_epochs + n_epochs * n_batches
+        n_sch_steps_current = n_batches + 1
+
+    expected = {
+        "state_dict": ANY,
+        "epoch_progress": {
+            "total": {
+                "ready": n_epochs,
+                "started": n_epochs,
+                "processed": n_epochs,
+                # TODO: the following "-1" offset will be fixed by
+                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
+                "completed": n_epochs - 1,
+            },
+            "current": {
+                "ready": n_epochs,
+                "started": n_epochs,
+                "processed": n_epochs,
+                # TODO: the following "-1" offset will be fixed by
+                # https://github.com/PyTorchLightning/pytorch-lightning/pull/8578
+                "completed": n_epochs - 1,
+            },
+        },
+        "epoch_loop.state_dict": ANY,
+        "epoch_loop.batch_progress": {
+            "total": {
+                "ready": n_epochs * n_batches,
+                "started": n_epochs * n_batches,
+                "processed": n_epochs * n_batches,
+                "completed": n_epochs * n_batches,
+            },
+            "current": {
+                "ready": n_batches,
+                "started": n_batches,
+                "processed": n_batches,
+                "completed": n_batches,
+            },
+        },
+        "epoch_loop.scheduler_progress": {
+            "total": {"ready": n_sch_steps_total, "completed": n_sch_steps_total},
+            "current": {"ready": n_sch_steps_current, "completed": n_sch_steps_current},
+        },
+        "epoch_loop.batch_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
+        "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
+            "optimizer_idx": n_optimizers,
+            "optimizer": {
+                "step": {
+                    "total": {
+                        "ready": n_epochs * n_batches * n_optimizers,
+                        "completed": n_epochs * n_batches * n_optimizers,
+                    },
+                    "current": {
+                        "ready": n_batches * n_optimizers,
+                        "completed": n_batches * n_optimizers,
+                    },
+                },
+                "zero_grad": {
+                    "total": {
+                        "ready": n_epochs * n_batches * n_optimizers,
+                        "started": n_epochs * n_batches * n_optimizers,
+                        "completed": n_epochs * n_batches * n_optimizers,
+                    },
+                    "current": {
+                        "ready": n_batches * n_optimizers,
+                        "started": n_batches * n_optimizers,
+                        "completed": n_batches * n_optimizers,
+                    },
+                },
+            },
+        },
+        "epoch_loop.val_loop.state_dict": ANY,
+        "epoch_loop.val_loop.dataloader_progress": ANY,
+        "epoch_loop.val_loop.epoch_loop.state_dict": ANY,
+        "epoch_loop.val_loop.epoch_loop.batch_progress": ANY,
+        "epoch_loop.val_loop._results": ANY,
+        "epoch_loop._results": ANY,
+    }
+    assert checkpoint["loops"]["fit_loop"] == expected
+
+
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+@RunIf(min_torch="1.7.0")
+def test_fit_loop_reset(tmpdir):
+    """Test that the reset logic in fit- and epoch loop is aware of whether the loop is restarting from a completed
+    loop or from a mid-epoch checkpoint."""
+
+    # generate checkpoints at end of epoch and mid-epoch
+    model = BoringModel()
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmpdir,
+        every_n_train_steps=2,
+        save_top_k=-1,
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=4,
+        num_sanity_val_steps=0,
+        max_epochs=2,
+        callbacks=[checkpoint_callback],
+        logger=False,
+        weights_summary=None,
+    )
+    trainer.fit(model)
+
+    # reset state loaded from a checkpoint from mid-epoch
+    mid_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=1.ckpt"))
+    fit_loop = trainer.fit_loop
+    epoch_loop = fit_loop.epoch_loop
+    optimizer_loop = epoch_loop.batch_loop.optimizer_loop
+    assert not fit_loop.restarting
+    assert not epoch_loop.restarting
+    assert not optimizer_loop.restarting
+
+    fit_loop.load_state_dict(mid_epoch_ckpt["loops"]["fit_loop"])
+
+    def mid_epoch_reset_assertions():
+        assert fit_loop.restarting
+        assert fit_loop.epoch_progress.total.ready == 1
+        assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint was saved mid epoch
+        assert fit_loop.epoch_progress.current.ready == 0
+        assert fit_loop.epoch_progress.current.completed == 0
+
+        assert epoch_loop.restarting
+        assert epoch_loop.batch_progress.total.ready == 2
+        assert epoch_loop.batch_progress.total.completed == 1  # the checkpoint was saved on train_batch_end
+        assert epoch_loop.batch_progress.current.ready == 2
+        assert epoch_loop.batch_progress.current.completed == 2
+
+    # resetting from a mid-epoch checkpoint should not change progress counters
+    mid_epoch_reset_assertions()
+    assert optimizer_loop.optim_progress.optimizer_idx == 1
+    fit_loop.reset()
+    epoch_loop.reset()
+    optimizer_loop.reset()
+    mid_epoch_reset_assertions()
+    assert optimizer_loop.optim_progress.optimizer_idx == 0
+
+    # reset state loaded from a checkpoint from the end of an epoch
+    end_of_epoch_ckpt = torch.load(str(tmpdir / "epoch=0-step=3.ckpt"))
+    fit_loop = trainer.fit_loop
+    epoch_loop = fit_loop.epoch_loop
+    fit_loop.restarting = False
+    epoch_loop.restarting = False
+    optimizer_loop.restarting = False
+
+    fit_loop.load_state_dict(end_of_epoch_ckpt["loops"]["fit_loop"])
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
+    assert fit_loop.epoch_progress.current.ready == 0
+    assert fit_loop.epoch_progress.current.completed == 0
+
+    assert epoch_loop.restarting
+    assert epoch_loop.batch_progress.total.ready == 4
+    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 4
+    assert epoch_loop.batch_progress.current.completed == 4
+
+    assert optimizer_loop.optim_progress.optimizer_idx == 1
+
+    # resetting from an end-of-epoch checkpoint should reset the current counters to 0
+    fit_loop.reset()
+    epoch_loop.reset()
+    optimizer_loop.reset()
+
+    assert fit_loop.restarting
+    assert fit_loop.epoch_progress.total.ready == 1
+    assert fit_loop.epoch_progress.total.completed == 0  # the checkpoint saves before the epoch completes
+    assert fit_loop.epoch_progress.current.ready == 0
+    assert fit_loop.epoch_progress.current.completed == 0
+
+    assert epoch_loop.restarting
+    assert epoch_loop.batch_progress.total.ready == 4
+    assert epoch_loop.batch_progress.total.completed == 3  # the checkpoint was saved on train_batch_end
+    assert epoch_loop.batch_progress.current.ready == 0
+    assert epoch_loop.batch_progress.current.completed == 0
+
+    assert optimizer_loop.optim_progress.optimizer_idx == 0
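Beyond the unit tests, the behaviour being fixed can be exercised end to end. The sketch below is illustrative rather than taken from the repository: it assumes the `PL_FAULT_TOLERANT_TRAINING` environment variable and `tests.helpers.BoringModel` used by the tests above, plus the `resume_from_checkpoint` Trainer argument available in this release; the `checkpoints` directory name is arbitrary.

```python
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel  # repo test helper; swap in your own LightningModule

# Opt in to fault-tolerant training the same way the tests above do.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

model = BoringModel()
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints", every_n_train_steps=2, save_top_k=-1)

trainer = Trainer(
    max_epochs=1,
    limit_train_batches=4,
    limit_val_batches=0,
    callbacks=[checkpoint_callback],
    logger=False,
)
trainer.fit(model)

# Resume in a fresh Trainer from one of the saved checkpoints.
resumed_trainer = Trainer(
    max_epochs=2,
    limit_train_batches=4,
    limit_val_batches=0,
    resume_from_checkpoint=checkpoint_callback.best_model_path,
    logger=False,
)
resumed_trainer.fit(model)
```

With the reset fix, resuming from a checkpoint saved at an epoch boundary starts the next epoch with fresh `current` progress counters instead of treating the finished epoch as still in flight.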
