
Fix gradient norm tracking and gradient clipping #9287


Merged: 89 commits merged into master from bugfix/track-grad-norm on Oct 28, 2021
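For context, the user-facing switches that this PR's gradient-norm tracking and clipping fixes apply to are the `track_grad_norm`, `gradient_clip_val`, and `gradient_clip_algorithm` Trainer flags. A minimal configuration sketch (not part of the diff; flag names as of the 1.4/1.5-era API this PR targets):

from pytorch_lightning import Trainer

trainer = Trainer(
    track_grad_norm=2,               # log the 2-norm of gradients via `LightningModule.log_grad_norm`
    gradient_clip_val=0.5,           # clip gradients before every optimizer step
    gradient_clip_algorithm="norm",  # "norm" or "value"
    log_every_n_steps=10,            # logging cadence used by `self.log`
)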
Commits (89)
a890c4d  WIP (carmocca, Sep 2, 2021)
cb4533b  Progress (carmocca, Sep 2, 2021)
5335611  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Sep 3, 2021)
1b9a4cb  Undo test change (carmocca, Sep 3, 2021)
14a9a93  Fix plugin closure execution order (carmocca, Sep 2, 2021)
7fe78fd  Update CHANGELOG (carmocca, Sep 2, 2021)
73b03d4  Fix manual optimization on AMP and skipping backward (carmocca, Sep 3, 2021)
d8a57e7  Fix for deepspeed (carmocca, Sep 6, 2021)
b945c1d  Typo (carmocca, Sep 6, 2021)
5696fb1  Hook test for manual closure (carmocca, Sep 6, 2021)
35a7bbc  Add skipping test with AMP (carmocca, Sep 6, 2021)
9b7df18  You are hideous, apex (carmocca, Sep 6, 2021)
1ba6432  Add deepspeed test (carmocca, Sep 6, 2021)
33ecf13  Update CHANGELOG (carmocca, Sep 6, 2021)
d09e753  Fix for broken master (carmocca, Sep 6, 2021)
df99e8d  Add RunIf (carmocca, Sep 6, 2021)
45947a7  Merge branch 'bugfix/plugin-closure-execution' into bugfix/track-grad… (carmocca, Sep 6, 2021)
904d6d3  FIXMEs (carmocca, Sep 6, 2021)
fd93617  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Sep 7, 2021)
fd902be  Rename (carmocca, Sep 17, 2021)
b4eb544  Fix grad norm (carmocca, Sep 17, 2021)
6b78ff5  add a simple test (awaelchli, Sep 20, 2021)
452bbb5  update test (awaelchli, Sep 20, 2021)
e7370a0  update test (awaelchli, Sep 20, 2021)
efac53f  update test (awaelchli, Sep 20, 2021)
ff435ac  Merge branch 'master' into bugfix/track-grad-norm (awaelchli, Sep 20, 2021)
2a8e286  fix merge conflicts (awaelchli, Sep 20, 2021)
d79e0c7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 20, 2021)
abad9dc  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Sep 29, 2021)
5ba5711  Sea of changes (carmocca, Sep 29, 2021)
1ad85fe  Undo change (carmocca, Sep 29, 2021)
2951698  Introduce TPUPrecisionPlugin (carmocca, Sep 29, 2021)
d587f4d  Merge gradient clipping customization changes (carmocca, Oct 15, 2021)
7817eb6  Undo changes (carmocca, Oct 15, 2021)
92c6429  Undo changes (carmocca, Oct 15, 2021)
f650272  Resolve FIXME (carmocca, Oct 15, 2021)
f44d9af  Undo change (carmocca, Oct 15, 2021)
5ac6376  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 18, 2021)
5d90455  Undo change (carmocca, Oct 18, 2021)
0594879  Undo change (carmocca, Oct 18, 2021)
3a77d04  Fix FIXMEs (carmocca, Oct 18, 2021)
b544e9c  Fix FIXME (carmocca, Oct 18, 2021)
da5a2e8  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 18, 2021)
a13a32d  Correct value (carmocca, Oct 18, 2021)
fc75821  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 19, 2021)
bb2c504  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 19, 2021)
9cb9873  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 21, 2021)
0f42fa5  Bad merge (carmocca, Oct 21, 2021)
b32d5ac  Fix circular imports (carmocca, Oct 21, 2021)
6c40167  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 25, 2021)
569e103  WIP (carmocca, Oct 25, 2021)
1f98d7b  Fixing clipping (carmocca, Oct 25, 2021)
aeed0ff  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 25, 2021)
2d5040c  Fixes (carmocca, Oct 25, 2021)
18886db  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 25, 2021)
f73539f  Bad merge (carmocca, Oct 25, 2021)
fa8901a  Move optimizer step and clipping into the `PrecisionPlugin` (carmocca, Oct 25, 2021)
8f8b601  Fix AMP (carmocca, Oct 25, 2021)
7eb639e  Update CHANGELOG (carmocca, Oct 25, 2021)
1a3e226  Fix tests (carmocca, Oct 25, 2021)
e52cef6  Underscore (carmocca, Oct 25, 2021)
df00f15  Merge branch 'refactotr/move-opt-step' into bugfix/track-grad-norm (carmocca, Oct 25, 2021)
b9cdc55  Progress (carmocca, Oct 26, 2021)
a5caefe  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 26, 2021)
e62ed4f  Remove pre_optimizer_step (carmocca, Oct 26, 2021)
e7060a4  Missed one (carmocca, Oct 26, 2021)
b24a533  Merge branch 'refactotr/move-opt-step' into bugfix/track-grad-norm (carmocca, Oct 26, 2021)
f82900d  Progress (carmocca, Oct 26, 2021)
f1f1709  Progress (carmocca, Oct 26, 2021)
818ca75  Fix test (carmocca, Oct 26, 2021)
a9b2cd0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 26, 2021)
981ef1d  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 26, 2021)
6e84b2f  Update FIXMEs (carmocca, Oct 26, 2021)
a0cc715  Fix test (carmocca, Oct 26, 2021)
d934e72  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Oct 26, 2021)
e6d1f97  Fix test (carmocca, Oct 26, 2021)
f3d126e  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 26, 2021)
f068c04  DeepSpeed warning. mypy (carmocca, Oct 26, 2021)
3ce645f  Rename (carmocca, Oct 26, 2021)
9aae00d  Finish tests (carmocca, Oct 27, 2021)
0b488f8  Merge branch 'bugfix/track-grad-norm' of https://github.com/PyTorchLi… (carmocca, Oct 27, 2021)
a272f1d  Update CHANGELOG (carmocca, Oct 27, 2021)
737f3c5  Dumb fixes (carmocca, Oct 27, 2021)
6a4d1c8  accelerator=auto (carmocca, Oct 27, 2021)
8a30bc5  Apply suggestions from code review (carmocca, Oct 27, 2021)
c41d46b  Update on comments (carmocca, Oct 27, 2021)
a36d396  Use ClassifModule (carmocca, Oct 27, 2021)
e83bbb9  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 28, 2021)
3067fb4  Merge branch 'master' into bugfix/track-grad-norm (carmocca, Oct 28, 2021)
Files changed (showing changes from 2 commits)
13 changes: 1 addition & 12 deletions pytorch_lightning/accelerators/accelerator.py
@@ -27,7 +27,7 @@
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _NATIVE_AMP_AVAILABLE:
@@ -304,17 +304,6 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
model_ref = self.lightning_module
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""clips all the optimizer parameters to the given value"""
self.precision_plugin.clip_gradients(
optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=self.model
)

def setup_optimizers(self, trainer: "pl.Trainer") -> None:
"""
Creates optimizers and schedulers
5 changes: 1 addition & 4 deletions pytorch_lightning/core/lightning.py
@@ -1414,11 +1414,8 @@ def training_step(...):
*args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward`
**kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
"""
# make sure we're using manual opt
self._verify_is_manual_optimization("manual_backward")

# backward
self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, None, None, *args, **kwargs)
self.trainer.accelerator.backward(loss, None, None, *args, **kwargs)

def backward(
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
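The change above routes `manual_backward` straight through the accelerator instead of the batch loop. For reference, a minimal manual-optimization sketch that exercises this path (standard LightningModule API, not part of the diff):

import torch
import pytorch_lightning as pl

class ManualLitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # enables `manual_backward`
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        self.manual_backward(loss)  # now dispatches to `trainer.accelerator.backward` directly
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)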
20 changes: 0 additions & 20 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -16,7 +16,6 @@
from typing import Any, Callable, List, Optional, Tuple

import numpy as np
import torch
from deprecate import void
from torch import Tensor
from torch.optim import Optimizer
@@ -246,25 +245,6 @@ def _tbptt_split_batch(self, batch: Any) -> List[Any]:
splits = model_ref.tbptt_split_batch(batch, tbptt_steps)
return splits

# TODO: remove this method and update tests
def backward(
self,
loss: Tensor,
optimizer: Optional[torch.optim.Optimizer],
opt_idx: Optional[int] = None,
*args: Any,
**kwargs: Any,
) -> Tensor:
"""Performs the backward step.

Args:
loss: The loss value to back-propagate on
optimizer: Current optimizer being used. ``None`` if using manual optimization.
opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
"""
self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
return loss

def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value"""
if self.trainer.lightning_module.automatic_optimization:
50 changes: 3 additions & 47 deletions pytorch_lightning/loops/optimizer/optimizer_loop.py
@@ -14,7 +14,7 @@

from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple

import torch
from torch import Tensor
@@ -32,7 +32,7 @@
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

@@ -91,31 +91,6 @@ def on_run_end(self) -> Tuple[_OUTPUTS_TYPE, Optional[Any]]:
self._hiddens = None
return outputs, hiddens

def backward(
self,
loss: Tensor,
optimizer: Optional[torch.optim.Optimizer],
opt_idx: Optional[int] = None,
*args: Any,
**kwargs: Any,
) -> Tensor:
"""Performs the backward step.

Args:
loss: The loss value to back-propagate on
optimizer: Current optimizer being used. ``None`` if using manual optimization.
opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
"""
self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)

if not self.trainer.fit_loop.should_accumulate():
# track gradients
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
return loss

def _run_optimization(
self,
batch_idx: int,
@@ -217,7 +192,7 @@ def _make_backward_fn(
"""

def backward_fn(loss: Tensor):
self.backward(loss, optimizer, opt_idx)
self.trainer.accelerator.backward(loss, optimizer, opt_idx)

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
@@ -355,22 +330,3 @@ def _training_step(
# the loss will get scaled for amp. avoid any modifications to it
loss = closure_loss.detach().clone()
return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

Args:
optimizer: the current optimizer
"""
# track gradient norms
grad_norm_dict = {}
can_log = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
should_track = float(self.trainer.track_grad_norm) > 0
if should_track and can_log:
grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)

# clip gradients
self.trainer.accelerator.clip_gradients(
optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
)
return grad_norm_dict
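The `_track_and_norm_grad` helper removed above combined norm tracking with clipping; the tracking half relied on the `grad_norm` utility. A standalone sketch of what such a per-parameter norm computation looks like (an approximation, not the library's exact implementation; the key naming is an assumption modelled on the removed code's usage):

import torch
from torch import nn

def compute_grad_norms(module: nn.Module, norm_type: float = 2.0) -> dict:
    # One entry per parameter that has a gradient, plus a total norm over all of them.
    norms = {
        f"grad_{norm_type}_norm/{name}": p.grad.detach().norm(norm_type).item()
        for name, p in module.named_parameters()
        if p.grad is not None
    }
    norms[f"grad_{norm_type}_norm_total"] = torch.tensor(list(norms.values())).norm(norm_type).item()
    return norms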
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/apex_amp.py
@@ -97,9 +97,9 @@ def pre_optimizer_step(
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
lambda_closure() # APEX amp does not support closures
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # APEX amp does not support closures
optimizer.step(**kwargs)
return False

pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -42,9 +42,9 @@ def pre_optimizer_step(
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
lambda_closure() # DeepSpeed does not support closures
super().pre_optimizer_step(model, optimizer, optimizer_idx, lambda_closure, **kwargs)
# the following should be in a `optimizer_step` hook but we don't have one in the precision plugin.
lambda_closure() # DeepSpeed does not support closures
deepspeed_engine = model.trainer.model
deepspeed_engine.step()
return False
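Both the Apex and DeepSpeed hunks above shuffle the point where `lambda_closure()` runs, because these integrations cannot hand the closure to `optimizer.step` the way plain PyTorch optimizers can. A plain-PyTorch sketch of the two call patterns involved (illustrative only, not Lightning internals):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batch = torch.randn(8, 4)

def closure():
    optimizer.zero_grad()
    loss = model(batch).pow(2).mean()
    loss.backward()
    return loss

optimizer.step(closure)  # pattern A: the optimizer drives the closure itself

closure()                # pattern B (Apex/DeepSpeed here): run the closure manually,
optimizer.step()         # then step without one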
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -97,6 +97,7 @@ def pre_optimizer_step(
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)
result = True
# FIXME: is this correct for manual?
if model.automatic_optimization:
result = lambda_closure()
self.scaler.unscale_(optimizer)
18 changes: 17 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -21,7 +21,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType
from pytorch_lightning.utilities.types import _PARAMETERS


@@ -100,9 +100,25 @@ def pre_optimizer_step(
**kwargs: Any,
) -> bool:
"""Hook to do something before each optimizer step."""
trainer = model.trainer
self._track_and_norm_grad(trainer)
self.clip_gradients(
optimizer, trainer.gradient_clip_val, gradient_clip_algorithm=trainer.gradient_clip_algorithm, model=model
)
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
return True

def _track_and_norm_grad(self, trainer: "pl.Trainer") -> None:
"""Tracks the model's gradient norms."""
if float(trainer.track_grad_norm) < 0:
return
grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm)
if grad_norm_dict:
prev_fx = trainer.lightning_module._current_fx_name
trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
trainer.lightning_module.log_grad_norm(grad_norm_dict)
trainer.lightning_module._current_fx_name = prev_fx

def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""

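With tracking and clipping now living in `PrecisionPlugin.pre_optimizer_step`, gradient norms are logged under the `on_before_optimizer_step` hook rather than `on_after_backward`; the test changes below assert exactly that ordering. A user-level sketch of observing gradients at that point, assuming the standard two-argument hook signature:

import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # Gradients exist here and have already been tracked/clipped by the precision plugin.
        grads = [p.grad.norm(2) for p in self.parameters() if p.grad is not None]
        if grads:
            self.log("debug/grad_2.0_norm_total", torch.stack(grads).norm(2))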
18 changes: 11 additions & 7 deletions tests/models/test_hooks.py
@@ -236,7 +236,7 @@ def __init__(self, called):
super().__init__()
pl_module_hooks = get_members(LightningModule)
# remove non-hooks
pl_module_hooks.difference_update({"optimizers"})
pl_module_hooks.difference_update({"optimizers", "log", "log_dict"})
# remove most `nn.Module` hooks
module_hooks = get_members(torch.nn.Module)
module_hooks.difference_update({"forward", "zero_grad", "train"})
@@ -277,8 +277,10 @@ def _train_batch(self, *args, **kwargs):
def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), current_epoch=0, **kwargs):
using_native_amp = kwargs.get("amp_backend") == "native"
using_deepspeed = kwargs.get("plugins") == "deepspeed"
using_plugin = kwargs.get("amp_backend") or kwargs.get("plugins")
out = []
on_before_optimizer_step = [
dict(name="log_grad_norm", args=ANY),
dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
dict(name="on_before_optimizer_step", args=(ANY, 0)),
]
@@ -292,10 +294,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
dict(name="Callback.on_batch_start", args=(trainer, model)),
dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)),
dict(name="on_train_batch_start", args=(ANY, i, 0)),
# these are before the training step because
# they are not part of the `training_step_and_backward` closure, however,
# with native amp, the closure is run first and then the optimizer step.
*(on_before_optimizer_step if not using_native_amp else []),
# without a precision plugin, we execute the closure inside the `optimizer.step`
*(on_before_optimizer_step if not using_plugin else []),
dict(name="forward", args=(ANY,)),
dict(name="training_step", args=(ANY, i)),
dict(name="training_step_end", args=(dict(loss=ANY),)),
@@ -308,7 +308,7 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
*([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name="Callback.on_after_backward", args=(trainer, model)),
dict(name="on_after_backward"),
*(on_before_optimizer_step if using_native_amp else []),
*(on_before_optimizer_step if using_plugin else []),
dict(
name="optimizer_step",
args=(current_epoch, i, ANY, 0, ANY),
@@ -344,6 +344,7 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
dict(name="on_after_backward"),
# `manual_backward` calls the previous 3
dict(name="manual_backward", args=(ANY,)),
dict(name="log_grad_norm", args=ANY),
dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
dict(name="on_before_optimizer_step", args=(ANY, 0)),
dict(name="training_step", args=(ANY, i)),
@@ -456,6 +457,7 @@ def training_step(self, batch, batch_idx):
progress_bar_refresh_rate=0,
weights_summary=None,
callbacks=[callback],
track_grad_norm=1,
**kwargs,
)

@@ -554,7 +556,9 @@ def training_step(self, batch, batch_idx):
dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage="fit")),
dict(name="teardown", kwargs=dict(stage="fit")),
]
assert called == expected
assert [c["name"] for c in called] == [e["name"] for e in expected]
for c, e in zip(called, expected):
assert c["name"] == e["name"]


def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
Expand Down