Skip to content

Commit 41be61c

Browse files
author
Sean Naren
authored
[IPU] Add hooks for IPU lifecycle 4/5 (#7864)
1 parent ea71cf4 commit 41be61c

File tree

3 files changed

+83
-12
lines changed

3 files changed

+83
-12
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6262
- Added reset dataloader hooks to Training Plugins and Accelerators ([#7861](https://github.com/PyTorchLightning/pytorch-lightning/pull/7861))
6363

6464

65+
- Added trainer stage hooks for Training Plugins and Accelerators ([#7864](https://github.com/PyTorchLightning/pytorch-lightning/pull/7864))
66+
67+
6568
### Changed
6669

6770
- Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)

pytorch_lightning/accelerators/accelerator.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,6 @@ def batch_to_device(
179179

180180
return move_data_to_device(batch, device)
181181

182-
def on_train_start(self) -> None:
183-
"""Hook to do something upon the training start"""
184-
pass
185-
186182
def training_step(
187183
self,
188184
step_kwargs: Dict[str, Union[Any, int]],
@@ -348,14 +344,6 @@ def clip_gradients(
348344
model=self.model,
349345
)
350346

351-
def on_train_epoch_end(self) -> None:
352-
"""Hook to do something on the end of an training epoch."""
353-
pass
354-
355-
def on_train_end(self) -> None:
356-
"""Hook to do something at the end of the training"""
357-
pass
358-
359347
def setup_optimizers(self, trainer: 'pl.Trainer') -> None:
360348
"""
361349
Creates optimizers and schedulers
@@ -563,3 +551,45 @@ def setup_optimizers_in_pre_dispatch(self) -> bool:
563551

564552
def update_global_step(self, total_batch_idx: int, current_global_step: int) -> int:
565553
return self.training_type_plugin.update_global_step(total_batch_idx, current_global_step)
554+
555+
def on_train_epoch_end(self) -> None:
556+
"""Hook to do something on the end of an training epoch."""
557+
pass
558+
559+
def on_train_start(self) -> None:
560+
"""Called when train begins."""
561+
return self.training_type_plugin.on_train_start()
562+
563+
def on_validation_start(self) -> None:
564+
"""Called when validation begins."""
565+
return self.training_type_plugin.on_validation_start()
566+
567+
def on_test_start(self) -> None:
568+
"""Called when test begins."""
569+
return self.training_type_plugin.on_test_start()
570+
571+
def on_predict_start(self) -> None:
572+
"""Called when predict begins."""
573+
return self.training_type_plugin.on_predict_start()
574+
575+
def on_validation_end(self) -> None:
576+
"""Called when validation ends."""
577+
return self.training_type_plugin.on_validation_end()
578+
579+
def on_test_end(self) -> None:
580+
"""Called when test end."""
581+
return self.training_type_plugin.on_test_end()
582+
583+
def on_predict_end(self) -> None:
584+
"""Called when predict ends."""
585+
return self.training_type_plugin.on_predict_end()
586+
587+
def on_train_end(self) -> None:
588+
"""Called when train ends."""
589+
return self.training_type_plugin.on_train_end()
590+
591+
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
592+
"""
593+
Called in the training loop before anything happens for that batch.
594+
"""
595+
return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,41 @@ def register_plugins(cls, plugin_registry):
330330
def should_rank_save_checkpoint(self) -> bool:
331331
"""Returns whether the checkpoint should be saved (rank based)"""
332332
return self.is_global_zero
333+
334+
def on_train_start(self) -> None:
335+
"""Called when train begins."""
336+
pass
337+
338+
def on_validation_start(self) -> None:
339+
"""Called when validation begins."""
340+
pass
341+
342+
def on_test_start(self) -> None:
343+
"""Called when test begins."""
344+
pass
345+
346+
def on_predict_start(self) -> None:
347+
"""Called when predict begins."""
348+
pass
349+
350+
def on_train_end(self) -> None:
351+
"""Called when train ends."""
352+
pass
353+
354+
def on_validation_end(self) -> None:
355+
"""Called when validation ends."""
356+
pass
357+
358+
def on_test_end(self) -> None:
359+
"""Called when test end."""
360+
pass
361+
362+
def on_predict_end(self):
363+
"""Called when predict ends."""
364+
pass
365+
366+
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
367+
"""
368+
Called in the training loop before anything happens for that batch.
369+
"""
370+
pass

0 commit comments

Comments
 (0)