
Commit 321502f

awaelchli and tchaton authored
Update backward hook for PrecisionPlugin (#10008)
Co-authored-by: thomas chaton <[email protected]>
1 parent 8f14e77 commit 321502f

File tree

3 files changed: +18 -1 lines changed


pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any
         deepspeed_engine = model.trainer.model
         deepspeed_engine.backward(closure_loss, *args, **kwargs)
 
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+        model.backward(tensor, *args, **kwargs)
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
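The DeepSpeed override delegates backward to the wrapped engine rather than calling backward on the tensor, since the engine owns loss scaling and gradient handling. Below is a minimal sketch (not part of the commit) of that delegation; ToyEngine is a hypothetical stand-in for a DeepSpeed engine, not a Lightning or DeepSpeed API.

import torch
from torch import Tensor, nn


class ToyEngine(nn.Module):
    """Wraps a module and owns the backward call, like a DeepSpeed engine would."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: Tensor) -> Tensor:
        return self.module(x)

    def backward(self, loss: Tensor) -> None:
        # A real engine would also handle loss scaling / gradient partitioning here.
        loss.backward()


engine = ToyEngine(nn.Linear(4, 1))
loss = engine(torch.randn(2, 4)).sum()
engine.backward(loss)  # mirrors DeepSpeedPrecisionPlugin._run_backward: model.backward(tensor)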

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 7 additions & 0 deletions
@@ -15,6 +15,8 @@
 from typing import Any, Callable, Dict, Generator, Union
 
 import torch
+from torch import Tensor
+from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
 
 import pytorch_lightning as pl
@@ -68,6 +70,11 @@ def pre_backward(self, model: "pl.LightningModule", closure_loss: torch.Tensor)
         closure_loss = self.scaler.scale(closure_loss)
         return super().pre_backward(model, closure_loss)
 
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+        if not self.is_bfloat16:
+            tensor = self.scaler.scale(tensor)
+        super()._run_backward(tensor, model, *args, **kwargs)
+
     def pre_optimizer_step(
         self,
         model: "pl.LightningModule",

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 8 additions & 1 deletion
@@ -76,7 +76,7 @@ def backward(
         if model is not None and isinstance(model, pl.LightningModule):
             model.backward(closure_loss, optimizer, *args, **kwargs)
         else:
-            closure_loss.backward(*args, **kwargs)
+            self._run_backward(closure_loss, *args, **kwargs)
 
     def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
         """Run after precision plugin executes backward.
@@ -90,6 +90,13 @@ def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Te
         model.trainer.call_hook("on_after_backward")
         return closure_loss
 
+    def _run_backward(self, tensor: Tensor, model: Module, *args: Any, **kwargs: Any) -> None:
+        """Lightning-independent backward logic.
+
+        Currently only used by Lightning Lite. Subject to further refactors.
+        """
+        tensor.backward(*args, **kwargs)
+
     def pre_optimizer_step(
         self,
         model: "pl.LightningModule",
