Commit 9237106

Clip before step (#10248)
1 parent d33429c commit 9237106

9 files changed: +57 −36 lines

pytorch_lightning/accelerators/accelerator.py

Lines changed: 0 additions & 17 deletions

@@ -335,23 +335,6 @@ def optimizer_step(
         model = model or self.lightning_module
         self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

-        if not isinstance(model, pl.LightningModule):
-            # gradient clipping and norm tracking only available with a LightingModule/Trainer
-            return
-
-        trainer = model.trainer
-        assert isinstance(trainer, pl.Trainer)
-        # TODO: this is done for the entire model but should be changed to per-optimizer
-        if opt_idx == 0:
-            self.precision_plugin._track_grad_norm(trainer)
-        self.precision_plugin._clip_gradients(
-            model,
-            optimizer,
-            opt_idx,
-            trainer.gradient_clip_val,
-            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
-        )
-
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
         model_ref = self.lightning_module

pytorch_lightning/core/hooks.py

Lines changed: 3 additions & 0 deletions

@@ -258,10 +258,13 @@ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) ->

         The hook is only called if gradients do not need to be accumulated.
         See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.
+
         If using native AMP, the loss will be unscaled before calling this hook.
         See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
         for more information on the scaling of gradients.

+        If clipping gradients, the gradients will not have been clipped yet.
+
         Args:
             optimizer: Current optimizer being used.
             optimizer_idx: Index of the current optimizer being used.
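
This doc change matters for anyone overriding the hook. As an illustration only (the GradNormInspectionModel below is hypothetical, not part of this commit), a module could log the pre-clipping gradient norm from here:

import torch
from pytorch_lightning import LightningModule


class GradNormInspectionModel(LightningModule):
    """Hypothetical module that inspects gradients in ``on_before_optimizer_step``."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # With this commit, gradients here are already unscaled (native AMP)
        # but have NOT been clipped yet, so this is the pre-clipping norm.
        grads = [p.grad.norm(2) for p in self.parameters() if p.grad is not None]
        self.log("pre_clip_grad_norm", torch.stack(grads).norm(2))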

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 0 deletions

@@ -578,6 +578,8 @@ def __to_tensor(self, value: numbers.Number) -> torch.Tensor:
     def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
         """Override this method to change the default behaviour of ``log_grad_norm``.

+        If clipping gradients, the gradients will not have been clipped yet.
+
         Args:
             grad_norm_dict: Dictionary containing current grad norm metrics

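
``log_grad_norm`` gets the same caveat: the norms it receives are computed before clipping. A minimal sketch of an override (hypothetical; only relevant when norm tracking is enabled, e.g. Trainer(track_grad_norm=2)):

from pytorch_lightning import LightningModule


class QuietGradNormModel(LightningModule):
    def log_grad_norm(self, grad_norm_dict):
        # These norms are computed before any clipping is applied.
        # Log them per step only, without epoch aggregation or the progress bar.
        self.log_dict(grad_norm_dict, on_step=True, on_epoch=False, prog_bar=False)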

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 1 addition & 2 deletions

@@ -109,8 +109,7 @@ def optimizer_step(
                 f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
         closure_result = closure()
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 1 addition & 2 deletions

@@ -63,8 +63,7 @@ def optimizer_step(
                 f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
             )
         closure_result = closure()
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:

pytorch_lightning/plugins/precision/ipu_precision.py

Lines changed: 1 addition & 2 deletions

@@ -52,8 +52,7 @@ def optimizer_step(
                 f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
         closure_result = closure()
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 1 addition & 2 deletions

@@ -85,8 +85,7 @@ def optimizer_step(
         closure_result = closure()
         # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
         self.scaler.unscale_(optimizer)
-        if isinstance(model, pl.LightningModule):
-            model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         skipped_backward = closure_result is None
         # in manual optimization, the closure does not return a value
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
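
The native AMP plugin keeps its ordering: unscale after the closure, then _after_closure fires the hook, tracks the norm, and clips, so scaler.step sees clipped gradients. This mirrors the pattern in the PyTorch AMP notes linked above; a plain-PyTorch sketch of that order (assuming a CUDA device is available, names and values illustrative):

import torch

# backward (scaled) -> unscale_ -> clip -> scaler.step -> scaler.update
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(2):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # gradients are now in real scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip before the step
    scaler.step(optimizer)  # skips the step if gradients are inf/nan
    scaler.update()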

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 22 additions & 1 deletion

@@ -111,6 +111,27 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k
         """
         tensor.backward(*args, **kwargs)

+    def _after_closure(
+        self, model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int
+    ) -> None:
+        """Utility to share some code after the closure has been run."""
+        if not isinstance(model, pl.LightningModule):
+            # none of this applies to Lite
+            return
+        trainer = model.trainer
+        assert trainer is not None
+        trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        # TODO: this is done for the entire model but should be changed to per-optimizer
+        if optimizer_idx == 0:
+            self._track_grad_norm(trainer)
+        self._clip_gradients(
+            model,
+            optimizer,
+            optimizer_idx,
+            trainer.gradient_clip_val,
+            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
+        )
+
     def _wrap_closure(
         self,
         model: "pl.LightningModule",
@@ -125,7 +146,7 @@ def _wrap_closure(
         consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
         """
         closure_result = closure()
-        model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
+        self._after_closure(model, optimizer, optimizer_idx)
         return closure_result

     def optimizer_step(
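
For users, the configuration surface is unchanged; only the point at which clipping runs moves into the precision plugin, after on_before_optimizer_step and before optimizer.step(). A minimal usage sketch (argument values are illustrative):

from pytorch_lightning import Trainer

# Clipping is still configured on the Trainer; with this commit it is applied
# by the precision plugin after `on_before_optimizer_step` and before `optimizer.step()`.
trainer = Trainer(
    gradient_clip_val=0.5,
    gradient_clip_algorithm="norm",  # or "value"
    track_grad_norm=2,               # optional: log the (pre-clipping) 2-norm
)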

tests/plugins/test_amp_plugins.py

Lines changed: 26 additions & 10 deletions

@@ -71,7 +71,13 @@ def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, p
     assert isinstance(trainer.precision_plugin, plugin_cls)


-class GradientUnscaleBoringModel(BoringModel):
+class TestClippingOptimizer(torch.optim.SGD):
+    def step(self, *args, pl_module=None):
+        pl_module.check_grads_clipped()
+        return super().step(*args)
+
+
+class TestPrecisionModel(BoringModel):
     # sister test: tests/trainer/optimization/test_manual_optimization.py::test_multiple_optimizers_step
     def on_after_backward(self) -> None:
         # check grads are scaled
@@ -92,6 +98,12 @@ def check_grads_unscaled(self, optimizer=None):
         for actual, expected in zip(grads, self.original_grads):
             torch.testing.assert_allclose(actual, expected)

+    def check_grads_clipped(self):
+        parameters = list(self.parameters())
+        assert len(parameters) == len(self.clipped_parameters)
+        for actual, expected in zip(parameters, self.clipped_parameters):
+            torch.testing.assert_allclose(actual.grad, expected.grad)
+
     def on_before_optimizer_step(self, optimizer, *_):
         self.check_grads_unscaled(optimizer)
         # manually clip
@@ -103,24 +115,28 @@ def on_before_optimizer_step(self, optimizer, *_):
         clip_val = self.trainer.gradient_clip_val
         torch.nn.utils.clip_grad_value_(self.clipped_parameters, clip_val)

+    def log_grad_norm(self, grad_norm_dict):
+        self.check_grads_unscaled()
+        assert len(grad_norm_dict)
+
     def configure_gradient_clipping(self, *args, **kwargs):
         # let lightning clip
         super().configure_gradient_clipping(*args, **kwargs)
         # check clipping worked as expected
-        parameters = list(self.parameters())
-        assert len(parameters) == len(self.clipped_parameters)
-        for actual, expected in zip(parameters, self.clipped_parameters):
-            torch.testing.assert_allclose(actual.grad, expected.grad)
+        self.check_grads_clipped()

-    def log_grad_norm(self, grad_norm_dict):
-        self.check_grads_unscaled()
-        assert len(grad_norm_dict)
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, closure, **_):
+        # pass self as a kwarg
+        optimizer.step(closure, pl_module=self)
+
+    def configure_optimizers(self):
+        return TestClippingOptimizer(self.layer.parameters(), lr=0.1)


 @RunIf(min_gpus=2)
 @pytest.mark.parametrize("accum", [1, 2])
 def test_amp_gradient_unscale(tmpdir, accum: int):
-    model = GradientUnscaleBoringModel()
+    model = TestPrecisionModel()

     trainer = Trainer(
         max_epochs=2,
@@ -137,6 +153,7 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
         gradient_clip_algorithm="value",
         log_every_n_steps=1,
         accumulate_grad_batches=accum,
+        enable_progress_bar=False,
     )
     trainer.fit(model)

@@ -200,7 +217,6 @@ def training_step(self, batch, batch_idx):
 @RunIf(min_gpus=2, amp_apex=True)
 @pytest.mark.parametrize("amp_level", ["O2"])
 def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):
-
     trainer = Trainer(
         default_root_dir=tmpdir,
         fast_dev_run=True,