Commit 3a0e42d

Implement double-closure

1 parent 6ed7a0c

5 files changed: +39 −19 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -665,6 +665,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287))
 
 
+- Fixed `on_before_optimizer_step` getting called before the optimizer closure (including backward) has run ([#10167](https://github.com/PyTorchLightning/pytorch-lightning/pull/10167))
+
+
 - Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118))
 
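The entry above describes a behavioral fix: the optimizer closure (which runs ``backward``) now executes before ``on_before_optimizer_step`` is called, so gradients are populated when the hook runs. A minimal sketch of what this enables at the user level; the module, layer size, and printing are illustrative and not part of this commit, while the ``on_before_optimizer_step(optimizer, optimizer_idx)`` signature matches the hook as exercised in the tests below:

import torch
from torch import nn
import pytorch_lightning as pl


class GradientInspectionModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # with this commit, the closure (and backward) has already run,
        # so p.grad is populated by the time this hook fires
        grads = [p.grad.norm() for p in self.parameters() if p.grad is not None]
        print("grad norms before step:", [g.item() for g in grads])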

pytorch_lightning/accelerators/accelerator.py

Lines changed: 5 additions & 0 deletions

@@ -333,6 +333,11 @@ def optimizer_step(
         """
         model = model or self.lightning_module
         self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs)
+
+        if not isinstance(model, pl.LightningModule):
+            # gradient clipping and norm tracking only available with a LightningModule/Trainer
+            return
+
         trainer = model.trainer
         assert isinstance(trainer, pl.Trainer)
         # TODO: this is done for the entire model but should be changed to per-optimizer

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 25 additions & 4 deletions

@@ -110,6 +110,28 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k
         """
         tensor.backward(*args, **kwargs)
 
+    def _wrap_closure(
+        self,
+        model: Union["pl.LightningModule", Module],
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        lambda_closure: Callable[[], Any],
+    ) -> Callable[[], Any]:
+        """This double-closure makes sure the ``lambda_closure`` is executed before the
+        ``on_before_optimizer_step`` hook is called.
+
+        The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
+        consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(lambda_closure)`` directly.
+        """
+
+        def inner() -> Any:
+            closure_result = lambda_closure()
+            if isinstance(model, pl.LightningModule):
+                model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+            return closure_result
+
+        return inner
+
     def optimizer_step(
         self,
         model: Union["pl.LightningModule", Module],
@@ -119,12 +141,11 @@ def optimizer_step(
         **kwargs: Any,
     ) -> None:
         """Hook to run the optimizer step."""
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
-        optimizer.step(closure=lambda_closure, **kwargs)
+        closure = self._wrap_closure(model, optimizer, optimizer_idx, lambda_closure)
+        optimizer.step(closure=closure, **kwargs)
 
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
-        if float(trainer.track_grad_norm) == -1:
+        if trainer.track_grad_norm == -1:
             return
         grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
         if grad_norm_dict:
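The ``_wrap_closure`` method above is the core of the commit: the hook now fires inside the closure, after ``lambda_closure`` (and therefore ``backward``) has run, but before the optimizer applies the update. Below is a minimal standalone sketch of the same double-closure idea in plain PyTorch, with no Lightning involved; the names ``wrap_closure``, ``lambda_closure``, and ``on_before_optimizer_step`` mirror the diff for readability but are illustrative only.

import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.randn(8, 4)


def lambda_closure():
    # what the training-loop closure does: zero grads, forward, backward
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    return loss


def on_before_optimizer_step():
    # gradients exist here because the wrapped closure already ran backward
    print("grad norm before step:", model.weight.grad.norm().item())


def wrap_closure(closure, hook):
    def inner():
        result = closure()
        hook()  # after backward, before the optimizer applies the update
        return result

    return inner


# the optimizer calls `inner()` internally, then performs the parameter update
optimizer.step(closure=wrap_closure(lambda_closure, on_before_optimizer_step))

The same wrapper is what the TPU plugin below passes to ``xm.optimizer_step``, so the hook ordering stays identical across precision plugins.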

pytorch_lightning/plugins/precision/tpu.py

Lines changed: 2 additions & 3 deletions

@@ -34,9 +34,8 @@ def optimizer_step(
         lambda_closure: Callable[[], Any],
         **kwargs: Any
     ) -> None:
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
+        closure = self._wrap_closure(model, optimizer, optimizer_idx, lambda_closure)
+        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:

tests/models/test_hooks.py

Lines changed: 4 additions & 12 deletions

@@ -275,12 +275,7 @@ def _train_batch(self, *args, **kwargs):
     def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
         using_native_amp = kwargs.get("amp_backend") == "native"
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
-        using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
         out = []
-        on_before_optimizer_step = [
-            dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
-            dict(name="on_before_optimizer_step", args=(ANY, 0)),
-        ]
         for i in range(batches):
             out.extend(
                 [
@@ -291,8 +286,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     dict(name="Callback.on_batch_start", args=(trainer, model)),
                     dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)),
                     dict(name="on_train_batch_start", args=(ANY, i)),
-                    # without a precision plugin, we execute the closure inside the `optimizer.step`
-                    *([] if using_plugin else on_before_optimizer_step),
                     dict(name="forward", args=(ANY,)),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -305,7 +298,9 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                     dict(name="Callback.on_after_backward", args=(trainer, model)),
                     dict(name="on_after_backward"),
-                    *(on_before_optimizer_step if using_plugin else []),
+                    # note: unscaling happens here in the case of AMP
+                    dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
+                    dict(name="on_before_optimizer_step", args=(ANY, 0)),
                     *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(
                         name="clip_gradients",
@@ -334,7 +329,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
     @staticmethod
     def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
-        using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
         out = []
         for i in range(batches):
             out.extend(
@@ -355,11 +349,9 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
                     dict(name="on_after_backward"),
                     # `manual_backward` calls the previous 3
                     dict(name="manual_backward", args=(ANY,)),
-                    *([dict(name="closure")] if using_plugin else []),
+                    dict(name="closure"),
                     dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
                     dict(name="on_before_optimizer_step", args=(ANY, 0)),
-                    # without a precision plugin, we execute the closure inside the `optimizer.step`
-                    *([] if using_plugin else [dict(name="closure")]),
                     *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
