
Clip before step #10248

Merged · 5 commits · Oct 30, 2021
17 changes: 0 additions & 17 deletions pytorch_lightning/accelerators/accelerator.py
@@ -334,23 +334,6 @@ def optimizer_step(
model = model or self.lightning_module
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

if not isinstance(model, pl.LightningModule):
# gradient clipping and norm tracking only available with a LightningModule/Trainer
return

trainer = model.trainer
assert isinstance(trainer, pl.Trainer)
# TODO: this is done for the entire model but should be changed to per-optimizer
if opt_idx == 0:
self.precision_plugin._track_grad_norm(trainer)
self.precision_plugin._clip_gradients(
model,
optimizer,
opt_idx,
trainer.gradient_clip_val,
gradient_clip_algorithm=trainer.gradient_clip_algorithm,
)

def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients."""
model_ref = self.lightning_module
3 changes: 3 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -258,10 +258,13 @@ def on_before_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) ->

The hook is only called if gradients do not need to be accumulated.
See: :paramref:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches`.

If using native AMP, the loss will be unscaled before calling this hook.
See these `docs <https://pytorch.org/docs/stable/notes/amp_examples.html#working-with-unscaled-gradients>`__
for more information on the scaling of gradients.

If clipping gradients, the gradients will not have been clipped yet.

Args:
optimizer: Current optimizer being used.
optimizer_idx: Index of the current optimizer being used.
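For readers skimming the diff: the docstring addition above means `on_before_optimizer_step` now sees gradients that are already unscaled (under native AMP) but not yet clipped. A minimal, hypothetical override that logs the pre-clipping gradient norm; the module, layer sizes, and metric name are illustrative and not part of this PR:

```python
import torch
import pytorch_lightning as pl


class PreClipGradientLogger(pl.LightningModule):
    """Toy module that inspects gradients in ``on_before_optimizer_step``."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # Under native AMP the gradients have already been unscaled here; if
        # gradient_clip_val is set on the Trainer, clipping has NOT run yet.
        norms = [p.grad.detach().norm(2) for p in self.parameters() if p.grad is not None]
        self.log("pre_clip_grad_norm", torch.stack(norms).norm(2))
```
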
2 changes: 2 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -576,6 +576,8 @@ def __to_tensor(self, value: numbers.Number) -> torch.Tensor:
def log_grad_norm(self, grad_norm_dict: Dict[str, float]) -> None:
"""Override this method to change the default behaviour of ``log_grad_norm``.

If clipping gradients, the gradients will not have been clipped yet.

Args:
grad_norm_dict: Dictionary containing current grad norm metrics

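The same caveat applies to `log_grad_norm`: the dictionary it receives is computed from the not-yet-clipped gradients, and it only fires when `track_grad_norm` is enabled on the Trainer. A hedged sketch of an override that makes this explicit in the metric names; the `pre_clip/` prefix is illustrative:

```python
import pytorch_lightning as pl


class MyModule(pl.LightningModule):
    # training_step / configure_optimizers omitted for brevity

    def log_grad_norm(self, grad_norm_dict):
        # These norms are tracked before gradient clipping runs
        # (see _after_closure below), so log them under a prefix that says so.
        self.log_dict(
            {f"pre_clip/{name}": value for name, value in grad_norm_dict.items()},
            on_step=True,
            on_epoch=True,
        )
```
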
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -109,8 +109,7 @@ def optimizer_step(
f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = closure()
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -63,8 +63,7 @@ def optimizer_step(
f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = closure()
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/precision/ipu_precision.py
@@ -52,8 +52,7 @@ def optimizer_step(
f"IPUs and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
)
closure_result = closure()
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -85,8 +85,7 @@ def optimizer_step(
closure_result = closure()
# `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
self.scaler.unscale_(optimizer)
if isinstance(model, pl.LightningModule):
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
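The one-line change above keeps the native-AMP plugin aligned with the canonical `torch.cuda.amp` recipe for working with unscaled gradients (linked from the hook docs earlier): unscale, then inspect or clip, then let the scaler step and update. For reference, a plain-PyTorch sketch of that pattern, independent of Lightning:

```python
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(4, 8, device="cuda")).sum()
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # gradients are now in real (unscaled) units
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # clip unscaled grads
    scaler.step(optimizer)  # skips the update if grads contain inf/NaN
    scaler.update()
```
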
23 changes: 22 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -111,6 +111,27 @@ def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **k
"""
tensor.backward(*args, **kwargs)

def _after_closure(
self, model: Union["pl.LightningModule", Module], optimizer: Optimizer, optimizer_idx: int
) -> None:
"""Utility to share some code after the closure has been run."""
if not isinstance(model, pl.LightningModule):
# none of this applies to Lite
return
trainer = model.trainer
assert trainer is not None
trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
# TODO: this is done for the entire model but should be changed to per-optimizer
if optimizer_idx == 0:
self._track_grad_norm(trainer)
self._clip_gradients(
model,
optimizer,
optimizer_idx,
trainer.gradient_clip_val,
gradient_clip_algorithm=trainer.gradient_clip_algorithm,
)

def _wrap_closure(
self,
model: "pl.LightningModule",
@@ -125,7 +146,7 @@ def _wrap_closure(
consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
"""
closure_result = closure()
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
self._after_closure(model, optimizer, optimizer_idx)
return closure_result

def optimizer_step(
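With `_after_closure` called from `_wrap_closure`, the sequence inside a training step becomes: closure (forward/backward), `on_before_optimizer_step`, grad-norm tracking, clipping, and only then the actual parameter update. A hedged way to convince yourself of that ordering in a plain 32-bit run is a custom optimizer that asserts the gradients are already clipped when `step` executes; the class and threshold below are illustrative, and the PR's own test does the same thing with `TestClippingOptimizer`:

```python
import torch


class ClipCheckingSGD(torch.optim.SGD):
    """Illustrative optimizer: checks that grads are clipped before the update."""

    def __init__(self, params, clip_val=0.5, **kwargs):
        super().__init__(params, **kwargs)
        self.clip_val = clip_val

    def step(self, closure=None):
        # Lightning's wrapped closure runs forward/backward, the
        # on_before_optimizer_step hook, norm tracking, and gradient clipping.
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # Holds when the Trainer is configured with
                    # gradient_clip_val=0.5, gradient_clip_algorithm="value".
                    assert p.grad.abs().max() <= self.clip_val + 1e-6
        super().step()  # apply the update to the already-clipped gradients
        return loss
```

Returned from `configure_optimizers`, such an optimizer would fail before this PR (the accelerator used to clip only after the precision plugin's `optimizer_step` had already run) and passes with it.
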
36 changes: 26 additions & 10 deletions tests/plugins/test_amp_plugins.py
@@ -71,7 +71,13 @@ def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, p
assert isinstance(trainer.precision_plugin, plugin_cls)


class GradientUnscaleBoringModel(BoringModel):
class TestClippingOptimizer(torch.optim.SGD):
def step(self, *args, pl_module=None):
pl_module.check_grads_clipped()
return super().step(*args)


class TestPrecisionModel(BoringModel):
# sister test: tests/trainer/optimization/test_manual_optimization.py::test_multiple_optimizers_step
def on_after_backward(self) -> None:
# check grads are scaled
@@ -92,6 +98,12 @@ def check_grads_unscaled(self, optimizer=None):
for actual, expected in zip(grads, self.original_grads):
torch.testing.assert_allclose(actual, expected)

def check_grads_clipped(self):
parameters = list(self.parameters())
assert len(parameters) == len(self.clipped_parameters)
for actual, expected in zip(parameters, self.clipped_parameters):
torch.testing.assert_allclose(actual.grad, expected.grad)

def on_before_optimizer_step(self, optimizer, *_):
self.check_grads_unscaled(optimizer)
# manually clip
@@ -103,24 +115,28 @@ def on_before_optimizer_step(self, optimizer, *_):
clip_val = self.trainer.gradient_clip_val
torch.nn.utils.clip_grad_value_(self.clipped_parameters, clip_val)

def log_grad_norm(self, grad_norm_dict):
self.check_grads_unscaled()
assert len(grad_norm_dict)

def configure_gradient_clipping(self, *args, **kwargs):
# let lightning clip
super().configure_gradient_clipping(*args, **kwargs)
# check clipping worked as expected
parameters = list(self.parameters())
assert len(parameters) == len(self.clipped_parameters)
for actual, expected in zip(parameters, self.clipped_parameters):
torch.testing.assert_allclose(actual.grad, expected.grad)
self.check_grads_clipped()

def log_grad_norm(self, grad_norm_dict):
self.check_grads_unscaled()
assert len(grad_norm_dict)
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, closure, **_):
# pass self as a kwarg
optimizer.step(closure, pl_module=self)

def configure_optimizers(self):
return TestClippingOptimizer(self.layer.parameters(), lr=0.1)


@RunIf(min_gpus=2)
@pytest.mark.parametrize("accum", [1, 2])
def test_amp_gradient_unscale(tmpdir, accum: int):
model = GradientUnscaleBoringModel()
model = TestPrecisionModel()

trainer = Trainer(
max_epochs=2,
@@ -137,6 +153,7 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
gradient_clip_algorithm="value",
log_every_n_steps=1,
accumulate_grad_batches=accum,
enable_progress_bar=False,
)
trainer.fit(model)

@@ -200,7 +217,6 @@ def training_step(self, batch, batch_idx):
@RunIf(min_gpus=2, amp_apex=True)
@pytest.mark.parametrize("amp_level", ["O2"])
def test_amp_apex_ddp_spawn_fit(amp_level, tmpdir):

trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,