
Commit b2d4c7e

Fix manual optimization on AMP and skipping backward

1 parent 066ae70 · commit b2d4c7e

4 files changed: +19 −14 lines

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
@@ -632,7 +632,7 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
             - ``None`` - Training will skip to the next batch

         Note:
-            Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled.
+            Returning ``None`` is currently not supported for multi-GPU or TPU.

         In this step you'd normally do the forward pass and calculate the loss for a batch.
         You can also do fancier things like multiple forward passes or something model specific.
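
The docstring change above describes automatic optimization: when ``training_step`` returns ``None``, backward and the optimizer step are skipped for that batch. A minimal sketch of that usage (the module and the skip condition are illustrative only, not part of this commit):

import torch
from torch import nn
import pytorch_lightning as pl


class SkipBadBatches(pl.LightningModule):
    """Hypothetical module that skips batches whose loss is non-finite."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        if not torch.isfinite(loss):
            # Returning None skips backward and the optimizer step for this batch
            # (per the docstring above, not supported on multi-GPU or TPU).
            return None
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)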

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 6 additions & 3 deletions
@@ -97,10 +97,13 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        lambda_closure()  # APEX amp does not support closures
+        result = lambda_closure()  # APEX amp does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        optimizer.step(**kwargs)
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+            optimizer.step(**kwargs)
         return False

     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
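
The DeepSpeed and native AMP plugins below receive the same gating. As a standalone sketch of the condition being added (the helper name is hypothetical, not a Lightning API): in automatic optimization the closure returns the ``training_step`` output, so ``None`` signals a skipped backward; in manual optimization the closure returns nothing, so the step must always run.

from typing import Any, Optional


def should_run_optimizer_step(automatic_optimization: bool, closure_result: Optional[Any]) -> bool:
    # Hypothetical helper mirroring the condition added in this commit:
    # skip `optimizer.step()` only when automatic optimization reported a
    # skipped backward (closure returned None). In manual optimization the
    # closure never returns a value, so the step is never skipped here.
    skipped_backward = closure_result is None
    return not automatic_optimization or not skipped_backward


assert should_run_optimizer_step(True, closure_result=1.23) is True    # normal automatic step
assert should_run_optimizer_step(True, closure_result=None) is False   # training_step returned None
assert should_run_optimizer_step(False, closure_result=None) is True   # manual optimization always steps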

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 7 additions & 4 deletions
@@ -42,11 +42,14 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
-        lambda_closure()  # DeepSpeed does not support closures
+        result = lambda_closure()  # DeepSpeed does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        deepspeed_engine = model.trainer.model
-        deepspeed_engine.step()
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+            deepspeed_engine = model.trainer.model
+            deepspeed_engine.step()
         return False

     def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 5 additions & 6 deletions
@@ -96,14 +96,13 @@ def pre_optimizer_step(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        result = True
-        # FIXME: is this correct for manual?
-        if model.automatic_optimization:
-            result = lambda_closure()
+        result = lambda_closure()
         self.scaler.unscale_(optimizer)
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # lambda_closure returning None indicates that backward has been skipped
-        if result is not None:
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             self.scaler.step(optimizer)
             self.scaler.update()
         return False
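
For context on the repeated "in manual optimization, the closure does not return a value" comment: with manual optimization the user calls backward and the optimizer step from ``training_step`` themselves, so the closure the plugin runs has no return value to inspect. A minimal manual-optimization sketch (illustrative only, not taken from this commit):

import torch
from torch import nn
import pytorch_lightning as pl


class ManualOptModule(pl.LightningModule):
    """Hypothetical module using manual optimization."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # the case handled by `not model.automatic_optimization`
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        x, y = batch
        loss = nn.functional.mse_loss(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)  # routes backward through the precision plugin (e.g. AMP scaling)
        opt.step()
        # Nothing needs to be returned here, so the plugin sees a closure result of None
        # even though backward was not skipped.

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)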
