
Commit 81d15c5

Implement double optimizer closure for hook structure consistency (#10167)
1 parent c211adb commit 81d15c5
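
Summary of the change: the base `PrecisionPlugin` (and the TPU plugin) used to call `on_before_optimizer_step` before handing the closure to `optimizer.step`, so the hook ran before `backward` and could not observe any gradients, while plugins such as Apex and DeepSpeed already executed the closure first. This commit renames `lambda_closure` to `closure` and wraps it in a second closure (`_wrap_closure`) that runs the original closure and then the hook, making the order consistent across plugins. A minimal sketch of the resulting ordering in plain PyTorch — the names `wrap_closure`, `training_closure`, and `hook` are illustrative, not Lightning API:

```python
from functools import partial

import torch


def wrap_closure(closure, on_before_optimizer_step):
    # double closure: run the training closure (forward + backward) first,
    # then the hook, then return the loss so `optimizer.step` can use it
    result = closure()
    on_before_optimizer_step()
    return result


model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def training_closure():
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    return loss


def hook():
    # gradients exist here because the closure above has already run backward
    assert all(p.grad is not None for p in model.parameters())


optimizer.step(closure=partial(wrap_closure, training_closure, hook))
```

With SGD (or any optimizer that accepts a `closure`), the wrapped closure is evaluated inside `optimizer.step`, so the hook fires after `backward` but before the parameter update.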

File tree

10 files changed: +59, -43 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -666,6 +666,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287))
 
 
+- Fixed `on_before_optimizer_step` getting called before the optimizer closure (including backward) has run ([#10167](https://github.com/PyTorchLightning/pytorch-lightning/pull/10167))
+
+
 - Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118))
 
 

pytorch_lightning/accelerators/accelerator.py

Lines changed: 8 additions & 3 deletions
@@ -318,7 +318,7 @@ def optimizer_step(
         self,
         optimizer: Optimizer,
         opt_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         model: Optional[Union["pl.LightningModule", Module]] = None,
         **kwargs: Any
     ) -> None:
@@ -327,12 +327,17 @@ def optimizer_step(
         Args:
             optimizer: the optimizer performing the step
             opt_idx: index of the current optimizer
-            lambda_closure: closure calculating the loss value
+            closure: closure calculating the loss value
             model: reference to the model, optionally defining optimizer step related hooks
             **kwargs: Any extra arguments to ``optimizer.step``
         """
         model = model or self.lightning_module
-        self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs)
+        self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
+
+        if not isinstance(model, pl.LightningModule):
+            # gradient clipping and norm tracking only available with a LightningModule/Trainer
+            return
+
         trainer = model.trainer
         assert isinstance(trainer, pl.Trainer)
         # TODO: this is done for the entire model but should be changed to per-optimizer

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 2 additions & 2 deletions
@@ -101,14 +101,14 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         if isinstance(model, pl.LightningModule):
             model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         skipped_backward = closure_result is None
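
Apex, like the DeepSpeed and IPU plugins below, cannot pass the closure to `optimizer.step`, so it keeps calling the closure itself and fires the hook right after; only the argument name changes here. A hedged sketch of the control flow these three plugins share — `manual_closure_step` is a made-up helper and the skip condition is simplified compared to the real plugins:

```python
from typing import Any, Callable, Optional

import torch


def manual_closure_step(
    optimizer: torch.optim.Optimizer,
    closure: Callable[[], Optional[Any]],
    on_before_optimizer_step: Callable[[], None],
) -> None:
    # 1) run the closure ourselves: forward, loss, backward
    closure_result = closure()
    # 2) the hook now sees the freshly computed gradients
    on_before_optimizer_step()
    # 3) a `None` result signals that backward was skipped (e.g. manual optimization),
    #    in which case the parameter update is skipped as well (simplified condition)
    if closure_result is not None:
        optimizer.step()
```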

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 2 additions & 2 deletions
@@ -55,14 +55,14 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         if isinstance(model, pl.LightningModule):
             model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         skipped_backward = closure_result is None

pytorch_lightning/plugins/precision/ipu_precision.py

Lines changed: 2 additions & 2 deletions
@@ -43,15 +43,15 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         """IPUs handle the optimizer step internally."""
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         if isinstance(model, pl.LightningModule):
             model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         skipped_backward = closure_result is None

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 3 additions & 3 deletions
@@ -72,17 +72,17 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         if self.scaler is None:
             # skip scaler logic, as bfloat16 does not require scaler
-            return super().optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
+            return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
         if isinstance(optimizer, LBFGS):
             raise MisconfigurationException(
                 f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
-        closure_result = lambda_closure()
+        closure_result = closure()
         # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
         self.scaler.unscale_(optimizer)
         if isinstance(model, pl.LightningModule):
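
Native AMP adds the ordering constraint spelled out in the comment above: `unscale_` must run after the closure (so the scaled gradients exist) but before `on_before_optimizer_step`, so the hook and any subsequent clipping see unscaled values. A standalone toy version of that sequence, not the plugin itself — the model, data, and clipping value are made up, and the scaler disables itself on CPU so the snippet stays runnable anywhere:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# the scaler silently becomes a no-op when disabled, keeping this sketch CPU-friendly
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())


def closure():
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4, device=device)).sum()
    scaler.scale(loss).backward()  # gradients are scaled after this
    return loss


closure_result = closure()
scaler.unscale_(optimizer)  # gradients are now in their true scale
# `on_before_optimizer_step` would fire here and could log or clip unscaled gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)  # skips the step if inf/nan gradients were found
scaler.update()
```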

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 22 additions & 4 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import contextlib
+from functools import partial
 from typing import Any, Callable, Generator, List, Optional, Tuple, Union
 
 import torch
@@ -110,21 +111,38 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k
         """
         tensor.backward(*args, **kwargs)
 
+    def _wrap_closure(
+        self,
+        model: "pl.LightningModule",
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        closure: Callable[[], Any],
+    ) -> Any:
+        """This double-closure makes sure the ``closure`` is executed before the
+        ``on_before_optimizer_step`` hook is called.
+
+        The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
+        consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
+        """
+        closure_result = closure()
+        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        return closure_result
+
     def optimizer_step(
         self,
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any,
     ) -> None:
         """Hook to run the optimizer step."""
         if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
-        optimizer.step(closure=lambda_closure, **kwargs)
+            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
+        optimizer.step(closure=closure, **kwargs)
 
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
-        if float(trainer.track_grad_norm) == -1:
+        if trainer.track_grad_norm == -1:
             return
         grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
         if grad_norm_dict:
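
The net effect of `_wrap_closure` for users is that `on_before_optimizer_step` now always runs after the closure — and therefore after `backward` — for every precision plugin. A hedged sketch of a `LightningModule` relying on that guarantee; only the hook signature comes from this diff, the module and attribute names are illustrative:

```python
import pytorch_lightning as pl
import torch


class GradientInspectingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.last_grad_norm = None  # illustrative attribute, not a Lightning field

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # the wrapped closure has already run forward + backward at this point,
        # so every parameter that received a gradient can be inspected here
        grads = [p.grad for p in self.parameters() if p.grad is not None]
        assert grads, "backward ran inside the closure before this hook"
        self.last_grad_norm = torch.norm(torch.stack([g.norm() for g in grads]))
```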

pytorch_lightning/plugins/precision/tpu.py

Lines changed: 4 additions & 3 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Any, Callable, Union
 
 from torch.nn import Module
@@ -31,12 +32,12 @@ def optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable[[], Any],
+        closure: Callable[[], Any],
         **kwargs: Any
     ) -> None:
         if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": lambda_closure, **kwargs})
+            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
+        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:

tests/core/test_lightning_module.py

Lines changed: 9 additions & 12 deletions
@@ -349,23 +349,20 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va
 
             for pg in optimizer.param_groups:
                 for p in pg["params"]:
-                    p.grad[p.grad > self.custom_gradient_clip_val] = self.custom_gradient_clip_val
-                    p.grad[p.grad <= 0] = 0
-
-        def on_before_optimizer_step(self, optimizer, optimizer_idx):
-            for pg in optimizer.param_groups:
-                for p in pg["params"]:
-                    if p.grad is not None and p.grad.abs().sum() > 0:
-                        self.has_validated_gradients = True
-                        assert p.grad.min() >= 0
-                        assert p.grad.max() <= self.custom_gradient_clip_val
+                    p.grad.clamp_(min=0, max=self.custom_gradient_clip_val)
 
     model = TestModel()
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=0, gradient_clip_val=1e-4
     )
     trainer.fit(model)
-    assert model.has_validated_gradients
+
+    optimizer = model.optimizers()
+    for pg in optimizer.param_groups:
+        for p in pg["params"]:
+            if p.grad is not None:
+                assert p.grad.min() >= 0
+                assert p.grad.max() <= model.custom_gradient_clip_val
 
 
 def test_lightning_module_configure_gradient_clipping_different_argument_values(tmpdir):

tests/models/test_hooks.py

Lines changed: 4 additions & 12 deletions
@@ -275,12 +275,7 @@ def _train_batch(self, *args, **kwargs):
     def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
         using_native_amp = kwargs.get("amp_backend") == "native"
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
-        using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
         out = []
-        on_before_optimizer_step = [
-            dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
-            dict(name="on_before_optimizer_step", args=(ANY, 0)),
-        ]
         for i in range(batches):
             out.extend(
                 [
@@ -291,8 +286,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     dict(name="Callback.on_batch_start", args=(trainer, model)),
                     dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)),
                     dict(name="on_train_batch_start", args=(ANY, i)),
-                    # without a precision plugin, we execute the closure inside the `optimizer.step`
-                    *([] if using_plugin else on_before_optimizer_step),
                     dict(name="forward", args=(ANY,)),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -305,7 +298,9 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                     dict(name="Callback.on_after_backward", args=(trainer, model)),
                     dict(name="on_after_backward"),
-                    *(on_before_optimizer_step if using_plugin else []),
+                    # note: unscaling happens here in the case of AMP
+                    dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
+                    dict(name="on_before_optimizer_step", args=(ANY, 0)),
                     *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(
                         name="clip_gradients",
@@ -334,7 +329,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
     @staticmethod
     def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **kwargs):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
-        using_plugin = kwargs.get("amp_backend") or kwargs.get("strategy")
         out = []
         for i in range(batches):
             out.extend(
@@ -355,11 +349,9 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
                     dict(name="on_after_backward"),
                     # `manual_backward` calls the previous 3
                     dict(name="manual_backward", args=(ANY,)),
-                    *([dict(name="closure")] if using_plugin else []),
+                    dict(name="closure"),
                     dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
                     dict(name="on_before_optimizer_step", args=(ANY, 0)),
-                    # without a precision plugin, we execute the closure inside the `optimizer.step`
-                    *([] if using_plugin else [dict(name="closure")]),
                     *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
