
Commit a05ca53

Run plugin closure before on_before_optimizer_step [1/2] (#9288)
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 597a1c8 commit a05ca53

8 files changed, +98 -55 lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
@@ -15,6 +15,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))
 - Fixed `EarlyStopping` running on train epoch end when `check_val_every_n_epoch>1` is set ([#9156](https://github.com/PyTorchLightning/pytorch-lightning/pull/9156))
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
+
+
+- Fixed the Apex and DeepSpeed plugin closure running after the `on_before_optimizer_step` hook ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288))
+
+
+- Fixed the Native AMP plugin closure not running with manual optimization ([#9288](https://github.com/PyTorchLightning/pytorch-lightning/issues/9288))
+
+
 - Fixed bug where data-loading functions where not getting the correct running stage passed ([#8858](https://github.com/PyTorchLightning/pytorch-lightning/pull/8858))
 - Fixed intra-epoch evaluation outputs staying in memory when the respective `*_epoch_end` hook wasn't overridden ([#9261](https://github.com/PyTorchLightning/pytorch-lightning/pull/9261))
 - Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 4 deletions
@@ -669,10 +669,8 @@ def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
 
         - :class:`~torch.Tensor` - The loss tensor
         - ``dict`` - A dictionary. Can include any keys, but must include the key ``'loss'``
-        - ``None`` - Training will skip to the next batch
-
-        Note:
-            Returning ``None`` is currently not supported for multi-GPU or TPU, or with 16-bit precision enabled.
+        - ``None`` - Training will skip to the next batch. This is only for automatic optimization.
+            This is not supported for multi-GPU or TPU, or using ``DeepSpeed``.
 
         In this step you'd normally do the forward pass and calculate the loss for a batch.
         You can also do fancier things like multiple forward passes or something model specific.
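
To make the documented behavior concrete: a minimal sketch (not part of this commit; the model, loss, and skip condition are made up for illustration) of a LightningModule that skips a batch by returning `None` from `training_step`. The skip only applies with automatic optimization on a single device, and not with DeepSpeed.

import torch
from pytorch_lightning import LightningModule


class SkipOnBadLossModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        if not torch.isfinite(loss):
            # skip to the next batch; honored only in automatic optimization,
            # not on multi-GPU/TPU or with DeepSpeed
            return None
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)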

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 6 additions & 3 deletions
@@ -97,10 +97,13 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        result = lambda_closure()  # APEX amp does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        lambda_closure()  # APEX amp does not support closures
-        optimizer.step(**kwargs)
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
+            optimizer.step(**kwargs)
         return False
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
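
The Apex change above shares its guard with the native AMP plugin further down: run the closure eagerly, then only call `optimizer.step` when optimization is manual or when the closure actually produced a result. A hedged restatement of that decision in isolation (this helper is illustrative, not part of the plugin API):

def should_run_optimizer_step(automatic_optimization: bool, closure_result) -> bool:
    # a `None` closure result means the user skipped backward in `training_step`
    skipped_backward = closure_result is None
    # in manual optimization the closure returns nothing, so never treat that as a skip
    return not automatic_optimization or not skipped_backward


assert should_run_optimizer_step(True, closure_result=1.23) is True   # normal step
assert should_run_optimizer_step(True, closure_result=None) is False  # user skipped the batch
assert should_run_optimizer_step(False, closure_result=None) is True  # manual optimization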

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 7 additions & 1 deletion
@@ -20,6 +20,7 @@
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
 from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.warnings import WarningCache
 
@@ -42,9 +43,14 @@ def pre_optimizer_step(
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
+        result = lambda_closure()  # DeepSpeed does not support closures
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+        # in manual optimization, the closure does not return a value
+        if model.automatic_optimization and result is None:
+            raise MisconfigurationException(
+                "Skipping backward by returning `None` from your `training_step` is not supported by `DeepSpeed`"
+            )
         # the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
-        lambda_closure()  # DeepSpeed does not support closures
         deepspeed_engine = model.trainer.model
         deepspeed_engine.step()
         return False
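
Because the new check only fires under automatic optimization, a manual-optimization module running on DeepSpeed is unaffected. A hedged sketch (the layer, loss, and optimizer are illustrative, not taken from this commit):

import torch
from pytorch_lightning import LightningModule


class ManualDeepSpeedModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization: the closure returns nothing
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        self.manual_backward(loss)
        opt.step()
        # returning nothing here is fine: the MisconfigurationException above is
        # only raised when `automatic_optimization` is True

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)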

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 5 additions & 5 deletions
@@ -54,13 +54,13 @@ def pre_optimizer_step(
                 f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
                 " To request, please file a Github issue in PyTorch and tag @mcarilli"
             )
-        result = True
-        if model.automatic_optimization:
-            result = lambda_closure()
+        result = lambda_closure()  # native amp does not support closures
         self.scaler.unscale_(optimizer)
         super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
-        # lambda_closure returning None indicates that backward has been skipped
-        if result is not None:
+        skipped_backward = result is None
+        # in manual optimization, the closure does not return a value
+        if not model.automatic_optimization or not skipped_backward:
+            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
             self.scaler.step(optimizer)
             self.scaler.update()
         return False
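
A user-visible consequence of running the closure before `on_before_optimizer_step` (the ordering asserted by the hook tests below) is that gradients are already computed, and with native AMP already unscaled, when the hook fires. A hedged sketch using the hook signature of this Lightning version; the gradient-norm bookkeeping is illustrative only:

import torch
from pytorch_lightning import LightningModule


class GradNormModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return {"loss": self.layer(batch).sum()}

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # the closure (forward + backward) has already run, so `p.grad` is populated
        grads = [p.grad.norm() for p in self.parameters() if p.grad is not None]
        self.last_grad_norm = torch.stack(grads).norm().item()  # stash for inspection

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)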

tests/models/test_hooks.py

Lines changed: 9 additions & 6 deletions
@@ -271,6 +271,7 @@ def _train_batch(self, *args, **kwargs):
     def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
         using_native_amp = kwargs.get("amp_backend") == "native"
         using_deepspeed = kwargs.get("plugins") == "deepspeed"
+        using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
         out = []
         on_before_optimizer_step = [
             dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
@@ -286,10 +287,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     dict(name="Callback.on_batch_start", args=(trainer, model)),
                     dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)),
                     dict(name="on_train_batch_start", args=(ANY, i, 0)),
-                    # these are before the training step because
-                    # they are not part of the `training_step_and_backward` closure, however,
-                    # with native amp, the closure is run first and then the optimizer step.
-                    *(on_before_optimizer_step if not using_native_amp else []),
+                    # without a precision plugin, we execute the closure inside the `optimizer.step`
+                    *([] if using_plugin else on_before_optimizer_step),
                     dict(name="forward", args=(ANY,)),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -302,7 +301,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                     dict(name="Callback.on_after_backward", args=(trainer, model)),
                     dict(name="on_after_backward"),
-                    *(on_before_optimizer_step if using_native_amp else []),
+                    *(on_before_optimizer_step if using_plugin else []),
                     dict(
                         name="optimizer_step",
                         args=(current_epoch, i, ANY, 0, ANY),
@@ -318,6 +317,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
     @staticmethod
     def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
         using_deepspeed = kwargs.get("plugins") == "deepspeed"
+        using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
         out = []
         for i in range(batches):
             out.extend(
@@ -338,8 +338,11 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
                     dict(name="on_after_backward"),
                     # `manual_backward` calls the previous 3
                     dict(name="manual_backward", args=(ANY,)),
+                    *([dict(name="closure")] if using_plugin else []),
                     dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
                     dict(name="on_before_optimizer_step", args=(ANY, 0)),
+                    # without a precision plugin, we execute the closure inside the `optimizer.step`
+                    *([] if using_plugin else [dict(name="closure")]),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
                     dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
@@ -435,7 +438,7 @@ def training_step(self, batch, batch_idx):
             opt = self.optimizers()
             opt.zero_grad()
             self.manual_backward(loss)
-            opt.step()
+            opt.step(lambda: called.append({"name": "closure"}))
             return {"loss": loss}
 
         model = TestModel(called)
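
The last hunk passes a closure positionally to `LightningOptimizer.step` so the test can record when the precision plugin executes it. Outside the test harness, the same pattern looks roughly like the following hedged sketch (model and loss are illustrative):

import torch
from pytorch_lightning import LightningModule


class ManualClosureModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def closure():
            # forward/backward wrapped in a closure; the precision plugin decides
            # whether to run it eagerly or inside the underlying `optimizer.step`
            opt.zero_grad()
            loss = self.layer(batch).sum()
            self.manual_backward(loss)

        opt.step(closure)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)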

tests/plugins/test_deepspeed_plugin.py

Lines changed: 12 additions & 0 deletions
@@ -755,3 +755,15 @@ def test_deepspeed_multigpu_no_schedulers(tmpdir):
     trainer.fit(model)
 
     _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelNoSchedulers)
+
+
+@RunIf(min_gpus=1, deepspeed=True)
+def test_deepspeed_skip_backward_raises(tmpdir):
+    class TestModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            return None
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, plugins=[DeepSpeedPlugin()], gpus=1, fast_dev_run=True, precision=16)
+    with pytest.raises(MisconfigurationException, match="returning `None` .* is not supported"):
+        trainer.fit(model)

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 49 additions & 36 deletions
@@ -64,17 +64,50 @@ def configure_optimizers(self):
         return optimizer, optimizer_2
 
 
-def test_multiple_optimizers_manual_no_return(tmpdir):
+@pytest.mark.parametrize(
+    "kwargs",
+    [
+        {},
+        pytest.param({"gpus": 1, "precision": 16, "amp_backend": "native"}, marks=RunIf(amp_native=True, min_gpus=1)),
+        pytest.param(
+            {"gpus": 1, "precision": 16, "amp_backend": "apex", "amp_level": "O2"},
+            marks=RunIf(amp_apex=True, min_gpus=1),
+        ),
+    ],
+)
+def test_multiple_optimizers_manual_no_return(tmpdir, kwargs):
+    apex_optimizer_patches = []
+    apex_optimizer_steps = []
+
     class TestModel(ManualOptModel):
         def training_step(self, batch, batch_idx):
             # avoid returning a value
             super().training_step(batch, batch_idx)
 
-        def training_epoch_end(self, outputs) -> None:
+        def training_epoch_end(self, outputs):
             # outputs is empty as training_step does not return
             # and it is not automatic optimization
             assert not outputs
 
+        def on_train_start(self):
+            if kwargs.get("amp_backend") != "apex":
+                return
+            # extremely ugly. APEX patches all the native torch optimizers on `_initialize` which we call on
+            # `ApexMixedPrecisionPlugin.dispatch`. Additionally, their replacement `new_step` functions are locally
+            # defined so can't even patch those, thus we need to create the mock after APEX has been initialized
+            nonlocal apex_optimizer_patches, apex_optimizer_steps
+            for opt in self.trainer.optimizers:
+                # `amp.scale_loss` will also patch the step to avoid it when gradient overflow happens. avoid it
+                opt._amp_stash.already_patched = True
+                patch = mock.patch.object(opt, "step")
+                apex_optimizer_patches.append(patch)
+                apex_optimizer_steps.append(patch.start())
+
+        def on_train_end(self):
+            if kwargs.get("amp_backend") == "apex":
+                for p in apex_optimizer_patches:
+                    p.stop()
+
     model = TestModel()
     model.val_dataloader = None
 
@@ -86,12 +119,26 @@ def training_epoch_end(self, outputs) -> None:
         max_epochs=1,
         log_every_n_steps=1,
         weights_summary=None,
+        **kwargs,
     )
 
+    if kwargs.get("amp_backend") == "native":
+        # mock the scaler instead of the optimizer step because it can be skipped with NaNs
+        scaler_step_patch = mock.patch.object(
+            trainer.precision_plugin.scaler, "step", wraps=trainer.precision_plugin.scaler.step
+        )
+        scaler_step = scaler_step_patch.start()
+
     with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 3
 
+    if kwargs.get("amp_backend") == "native":
+        scaler_step_patch.stop()
+        assert scaler_step.call_count == len(model.optimizers()) * limit_train_batches
+    if kwargs.get("amp_backend") == "apex":
+        assert [s.call_count for s in apex_optimizer_steps] == [len(model.optimizers())] * limit_train_batches
+
 
 def test_multiple_optimizers_manual_return(tmpdir):
     class TestModel(ManualOptModel):
@@ -171,40 +218,6 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):
     assert bwd_mock.call_count == limit_train_batches * 3
 
 
-@RunIf(min_gpus=1, amp_apex=True)
-def test_multiple_optimizers_manual_apex_no_return(tmpdir):
-    class TestModel(ManualOptModel):
-        def training_step(self, batch, batch_idx):
-            # avoid returning a value
-            super().training_step(batch, batch_idx)
-
-        def training_epoch_end(self, outputs) -> None:
-            # outputs is empty as training_step does not return
-            # and it is not automatic optimization
-            assert len(outputs) == 0
-
-    model = TestModel()
-    model.val_dataloader = None
-
-    limit_train_batches = 2
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
-        max_epochs=1,
-        log_every_n_steps=1,
-        weights_summary=None,
-        precision=16,
-        amp_level="O2",
-        amp_backend="apex",
-        gpus=1,
-    )
-
-    with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
-        trainer.fit(model)
-    assert bwd_mock.call_count == limit_train_batches * 3
-
-
 class ManualOptimizationExtendedModel(BoringModel):
 
     count = 0
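
The native-AMP branch above wraps the scaler's `step` rather than the optimizer's, because the scaler may legitimately skip the optimizer step when it finds non-finite gradients. A standalone sketch of that mocking technique, with the scaler disabled so it runs on CPU (the toy parameter and loss are illustrative):

import torch
from unittest import mock

scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled: no CUDA required
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)

with mock.patch.object(scaler, "step", wraps=scaler.step) as scaler_step:
    loss = scaler.scale(params[0].sum())  # with the scaler disabled this is a pass-through
    loss.backward()
    scaler.step(optimizer)  # counted by the mock, then forwarded to the real method
    scaler.update()

assert scaler_step.call_count == 1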
