
Fix gradient norm tracking and gradient clipping #9287


Merged: 89 commits merged into master from bugfix/track-grad-norm on Oct 28, 2021
Changes from all commits (89 commits):
a890c4d
WIP
carmocca Sep 2, 2021
cb4533b
Progress
carmocca Sep 2, 2021
5335611
Merge branch 'master' into bugfix/track-grad-norm
carmocca Sep 3, 2021
1b9a4cb
Undo test change
carmocca Sep 3, 2021
14a9a93
Fix plugin closure execution order
carmocca Sep 2, 2021
7fe78fd
Update CHANGELOG
carmocca Sep 2, 2021
73b03d4
Fix manual optimization on AMP and skipping backward
carmocca Sep 3, 2021
d8a57e7
Fix for deepspeed
carmocca Sep 6, 2021
b945c1d
Typo
carmocca Sep 6, 2021
5696fb1
Hook test for manual closure
carmocca Sep 6, 2021
35a7bbc
Add skipping test with AMP
carmocca Sep 6, 2021
9b7df18
You are hideous, apex
carmocca Sep 6, 2021
1ba6432
Add deepspeed test
carmocca Sep 6, 2021
33ecf13
Update CHANGELOG
carmocca Sep 6, 2021
d09e753
Fix for broken master
carmocca Sep 6, 2021
df99e8d
Add RunIf
carmocca Sep 6, 2021
45947a7
Merge branch 'bugfix/plugin-closure-execution' into bugfix/track-grad…
carmocca Sep 6, 2021
904d6d3
FIXMEs
carmocca Sep 6, 2021
fd93617
Merge branch 'master' into bugfix/track-grad-norm
carmocca Sep 7, 2021
fd902be
Rename
carmocca Sep 17, 2021
b4eb544
Fix grad norm
carmocca Sep 17, 2021
6b78ff5
add a simple test
awaelchli Sep 20, 2021
452bbb5
update test
awaelchli Sep 20, 2021
e7370a0
update test
awaelchli Sep 20, 2021
efac53f
update test
awaelchli Sep 20, 2021
ff435ac
Merge branch 'master' into bugfix/track-grad-norm
awaelchli Sep 20, 2021
2a8e286
fix merge conflicts
awaelchli Sep 20, 2021
d79e0c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 20, 2021
abad9dc
Merge branch 'master' into bugfix/track-grad-norm
carmocca Sep 29, 2021
5ba5711
Sea of changes
carmocca Sep 29, 2021
1ad85fe
Undo change
carmocca Sep 29, 2021
2951698
Introduce TPUPrecisionPlugin
carmocca Sep 29, 2021
d587f4d
Merge gradient clipping customization changes
carmocca Oct 15, 2021
7817eb6
Undo changes
carmocca Oct 15, 2021
92c6429
Undo changes
carmocca Oct 15, 2021
f650272
Resolve FIXME
carmocca Oct 15, 2021
f44d9af
Undo change
carmocca Oct 15, 2021
5ac6376
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 18, 2021
5d90455
Undo change
carmocca Oct 18, 2021
0594879
Undo change
carmocca Oct 18, 2021
3a77d04
Fix FIXMEs
carmocca Oct 18, 2021
b544e9c
Fix FIXME
carmocca Oct 18, 2021
da5a2e8
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 18, 2021
a13a32d
Correct value
carmocca Oct 18, 2021
fc75821
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 19, 2021
bb2c504
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 19, 2021
9cb9873
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 21, 2021
0f42fa5
Bad merge
carmocca Oct 21, 2021
b32d5ac
Fix circular imports
carmocca Oct 21, 2021
6c40167
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 25, 2021
569e103
WIP
carmocca Oct 25, 2021
1f98d7b
Fixing clipping
carmocca Oct 25, 2021
aeed0ff
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 25, 2021
2d5040c
Fixes
carmocca Oct 25, 2021
18886db
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 25, 2021
f73539f
Bad merge
carmocca Oct 25, 2021
fa8901a
Move optimizer step and clipping into the `PrecisionPlugin`
carmocca Oct 25, 2021
8f8b601
Fix AMP
carmocca Oct 25, 2021
7eb639e
Update CHANGELOG
carmocca Oct 25, 2021
1a3e226
Fix tests
carmocca Oct 25, 2021
e52cef6
Underscore
carmocca Oct 25, 2021
df00f15
Merge branch 'refactotr/move-opt-step' into bugfix/track-grad-norm
carmocca Oct 25, 2021
b9cdc55
Progress
carmocca Oct 26, 2021
a5caefe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2021
e62ed4f
Remove pre_optimizer_step
carmocca Oct 26, 2021
e7060a4
Missed one
carmocca Oct 26, 2021
b24a533
Merge branch 'refactotr/move-opt-step' into bugfix/track-grad-norm
carmocca Oct 26, 2021
f82900d
Progress
carmocca Oct 26, 2021
f1f1709
Progress
carmocca Oct 26, 2021
818ca75
Fix test
carmocca Oct 26, 2021
a9b2cd0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2021
981ef1d
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 26, 2021
6e84b2f
Update FIXMEs
carmocca Oct 26, 2021
a0cc715
Fix test
carmocca Oct 26, 2021
d934e72
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 26, 2021
e6d1f97
Fix test
carmocca Oct 26, 2021
f3d126e
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 26, 2021
f068c04
DeepSpeed warning. mypy
carmocca Oct 26, 2021
3ce645f
Rename
carmocca Oct 26, 2021
9aae00d
Finish tests
carmocca Oct 27, 2021
0b488f8
Merge branch 'bugfix/track-grad-norm' of https://github.com/PyTorchLi…
carmocca Oct 27, 2021
a272f1d
Update CHANGELOG
carmocca Oct 27, 2021
737f3c5
Dumb fixes
carmocca Oct 27, 2021
6a4d1c8
accelerator=auto
carmocca Oct 27, 2021
8a30bc5
Apply suggestions from code review
carmocca Oct 27, 2021
c41d46b
Update on comments
carmocca Oct 27, 2021
a36d396
Use ClassifModule
carmocca Oct 27, 2021
e83bbb9
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 28, 2021
3067fb4
Merge branch 'master' into bugfix/track-grad-norm
carmocca Oct 28, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -662,6 +662,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))


- Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287))


- Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118))


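This CHANGELOG entry is the heart of the PR: with native AMP the gradients are scaled after `backward()`, so they must be unscaled before their norm is logged or before clipping is applied, otherwise both operate on scaled values. A minimal plain-PyTorch sketch of that ordering (the model, optimizer, and data are placeholders; a CUDA device is assumed):

```python
import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(2):
    batch = torch.randn(8, 4, device="cuda")
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).sum()
    scaler.scale(loss).backward()      # gradients are now scaled
    scaler.unscale_(optimizer)         # unscale BEFORE clipping or norm logging
    # `clip_grad_norm_` returns the pre-clip total norm of the (now unscaled) gradients
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"grad 2-norm: {total_norm.item():.4f}")
    scaler.step(optimizer)             # skipped internally if inf/nan grads were found
    scaler.update()
```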
12 changes: 12 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -333,6 +333,18 @@ def optimizer_step(
"""
model = model or self.lightning_module
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs)
trainer = model.trainer
assert isinstance(trainer, pl.Trainer)
# TODO: this is done for the entire model but should be changed to per-optimizer
if opt_idx == 0:
self.precision_plugin._track_grad_norm(trainer)
self.precision_plugin._clip_gradients(
model,
optimizer,
opt_idx,
trainer.gradient_clip_val,
gradient_clip_algorithm=trainer.gradient_clip_algorithm,
)

def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
"""Zeros all model parameter's gradients."""
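The hunk above has the accelerator delegate both steps to the precision plugin right after `PrecisionPlugin.optimizer_step`, gated on `opt_idx == 0` so the whole model is handled once per step (the TODO notes this should eventually become per-optimizer). A hedged usage sketch of the Trainer flags that drive those two calls; the tiny module and dataset are illustrative, not part of the PR:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

train_loader = DataLoader(TensorDataset(torch.randn(16, 4)), batch_size=4)
trainer = pl.Trainer(
    max_epochs=1,
    limit_train_batches=2,
    log_every_n_steps=1,
    track_grad_norm=2,               # grad 2-norms are logged via `LightningModule.log_grad_norm`
    gradient_clip_val=0.5,
    gradient_clip_algorithm="norm",  # or "value"
)
trainer.fit(TinyModule(), train_loader)
```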
45 changes: 2 additions & 43 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -30,7 +30,7 @@
)
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
from pytorch_lightning.utilities import AMPType, DeviceType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -228,25 +228,6 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
outputs, self._outputs = self._outputs, {} # free memory
return outputs

def _backward(
self, loss: Tensor, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any
) -> None:
"""Performs the backward step.

Args:
loss: The loss value to back-propagate on
optimizer: Current optimizer being used
opt_idx: Index of the current optimizer being used
"""
self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)

if not self.trainer.fit_loop._should_accumulate():
# track gradients
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer, opt_idx=opt_idx)
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)

def _run_optimization(
self, split_batch: Any, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
) -> ClosureResult:
@@ -343,7 +324,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
return None

def backward_fn(loss: Tensor) -> None:
self._backward(loss, optimizer, opt_idx)
self.trainer.accelerator.backward(loss, optimizer, opt_idx)

# check if model weights are nan
if self.trainer._terminate_on_nan:
@@ -473,25 +454,3 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
self.trainer._results.cpu()

return result

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -> Dict[str, float]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

Args:
optimizer: the current optimizer
"""
# track gradient norms
grad_norm_dict = {}
if self.trainer.track_grad_norm != -1:
grad_norm_dict = grad_norm(
self.trainer.lightning_module, self.trainer.track_grad_norm, self.trainer.logger.group_separator
)

# clip gradients
self.trainer.lightning_module.configure_gradient_clipping(
optimizer,
opt_idx,
gradient_clip_val=self.trainer.gradient_clip_val,
gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
)
return grad_norm_dict
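The removed `_track_and_norm_grad` was the place that called `LightningModule.configure_gradient_clipping`; that call now comes from `PrecisionPlugin._clip_gradients` (see the `precision_plugin.py` hunk below). For reference, a hedged sketch of how a module can override the hook; by default it simply forwards the Trainer settings to `self.clip_gradients`:

```python
import pytorch_lightning as pl

class CustomClippingModule(pl.LightningModule):
    # Fragment meant to be mixed into a full LightningModule; training_step,
    # configure_optimizers, etc. are omitted for brevity.
    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        if optimizer_idx == 0:
            # example: clip only the first optimizer, and twice as hard as configured
            self.clip_gradients(
                optimizer,
                gradient_clip_val=(gradient_clip_val or 1.0) / 2,
                gradient_clip_algorithm=gradient_clip_algorithm,
            )
```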
9 changes: 9 additions & 0 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -82,3 +82,12 @@ def clip_gradients(
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""DeepSpeed handles gradient clipping internally."""

def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
return
# the gradients are not available in the model due to gradient partitioning in zero stage >= 2
warning_cache.warn(
f"You set `Trainer(track_grad_norm={trainer.track_grad_norm!r})' but this is not supported for DeepSpeed."
" The setting will be ignored."
)
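Under ZeRO stage 2 or higher, gradients are partitioned across ranks, so the full per-parameter gradients needed for norm tracking are not available on the module; the plugin therefore warns once and skips tracking. A hedged, illustrative sketch of the warn-once behaviour that `warning_cache` provides (not Lightning's actual `WarningCache` class):

```python
import warnings

class WarnOnceCache(set):
    # Remembers messages it has already emitted and suppresses repeats.
    def warn(self, message: str) -> None:
        if message not in self:
            self.add(message)
            warnings.warn(message, UserWarning)

cache = WarnOnceCache()
cache.warn("`track_grad_norm` is not supported for DeepSpeed; the setting will be ignored.")
cache.warn("`track_grad_norm` is not supported for DeepSpeed; the setting will be ignored.")  # suppressed
```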
30 changes: 29 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -21,7 +21,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType, rank_zero_deprecation
from pytorch_lightning.utilities.types import _PARAMETERS


@@ -123,6 +123,34 @@ def optimizer_step(
model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
optimizer.step(closure=lambda_closure, **kwargs)

def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if float(trainer.track_grad_norm) == -1:
return
grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
if grad_norm_dict:
prev_fx = trainer.lightning_module._current_fx_name
trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
trainer.lightning_module.log_grad_norm(grad_norm_dict)
trainer.lightning_module._current_fx_name = prev_fx

def _clip_gradients(
self,
model: Union["pl.LightningModule", Module],
optimizer: Optimizer,
optimizer_idx: int,
clip_val: Optional[Union[int, float]] = None,
gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None,
) -> None:
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
# the configuration validator disallows clipping on manual
return
model.configure_gradient_clipping(
optimizer,
optimizer_idx,
gradient_clip_val=clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
)

def clip_gradients(
self,
optimizer: Optimizer,
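`_track_grad_norm` computes the norms with `pytorch_lightning.utilities.grad_norm`, temporarily points `_current_fx_name` at `on_before_optimizer_step` so the logging is attributed to that hook, and passes the result to `LightningModule.log_grad_norm`. A hedged sketch of what such a grad-norm dictionary conceptually contains; the key names here are illustrative, the exact format is defined by `grad_norm` itself:

```python
from typing import Dict
import torch

def sketch_grad_norm(module: torch.nn.Module, norm_type: float = 2.0, sep: str = "/") -> Dict[str, float]:
    # One entry per parameter that has a gradient, plus a total norm over all of them.
    norms = {
        f"grad_{norm_type}_norm{sep}{name}": p.grad.detach().norm(norm_type).item()
        for name, p in module.named_parameters()
        if p.grad is not None
    }
    total = torch.tensor(list(norms.values())).norm(norm_type).item() if norms else 0.0
    norms[f"grad_{norm_type}_norm_total"] = total
    return norms

# A LightningModule can consume the dict by overriding `log_grad_norm`, for example:
#     def log_grad_norm(self, grad_norm_dict):
#         self.log_dict(grad_norm_dict, on_step=True, on_epoch=False)
```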
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -441,8 +441,10 @@ def init_deepspeed(self):
# deepspeed handles gradient clipping internally
if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
rank_zero_warn(
"Since deepspeed handles gradient clipping internally, `LightningModule.configure_gradient_clipping`"
" will be ignored. Consider setting `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
"Since DeepSpeed handles gradient clipping internally, the default"
" `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
" The hook will still be called. Consider setting"
" `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
" which will use the internal mechanism."
)

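The reworded warning makes clear that the hook is still called, but its default implementation will not actually clip anything under DeepSpeed. A hedged usage sketch of the recommended path, configuring clipping through the Trainer so DeepSpeed's internal mechanism applies it (the strategy name and an installed `deepspeed` package are assumptions):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    strategy="deepspeed_stage_2",    # requires the deepspeed package and a GPU
    gpus=1,
    precision=16,
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",  # the value recommended by the warning above
)
```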
10 changes: 8 additions & 2 deletions tests/models/test_hooks.py
@@ -234,7 +234,7 @@ def __init__(self, called):
super().__init__()
pl_module_hooks = get_members(LightningModule)
# remove non-hooks
pl_module_hooks.difference_update({"optimizers"})
pl_module_hooks.difference_update({"optimizers", "log", "log_dict"})
# remove most `nn.Module` hooks
module_hooks = get_members(torch.nn.Module)
module_hooks.difference_update({"forward", "zero_grad", "train"})
@@ -305,6 +305,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
*([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
dict(name="Callback.on_after_backward", args=(trainer, model)),
dict(name="on_after_backward"),
*(on_before_optimizer_step if using_plugin else []),
*([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
dict(
name="clip_gradients",
args=(ANY,),
args=(ANY, 0),
kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
),
*(on_before_optimizer_step if using_plugin else []),
# this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
# the actual call to `PrecisionPlugin.optimizer_step`
dict(
name="optimizer_step",
args=(current_epoch, i, ANY, 0, ANY),
@@ -357,6 +360,7 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
dict(name="on_before_optimizer_step", args=(ANY, 0)),
# without a precision plugin, we execute the closure inside the `optimizer.step`
*([] if using_plugin else [dict(name="closure")]),
*([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
dict(name="training_step", args=(ANY, i)),
dict(name="training_step_end", args=(dict(loss=ANY),)),
dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
@@ -484,6 +488,7 @@ def training_step(self, batch, batch_idx):
enable_progress_bar=False,
enable_model_summary=False,
callbacks=[callback],
track_grad_norm=1,
**kwargs,
)
assert called == [
@@ -608,6 +613,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
enable_progress_bar=False,
enable_model_summary=False,
callbacks=[callback],
track_grad_norm=1,
)
assert called == [
dict(name="Callback.on_init_start", args=(trainer,)),
53 changes: 47 additions & 6 deletions tests/plugins/test_amp_plugins.py
@@ -72,10 +72,49 @@ def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, p


class GradientUnscaleBoringModel(BoringModel):
def on_before_optimizer_step(self, *_):
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 15.0
# sister test: tests/trainer/optimization/test_manual_optimization.py::test_multiple_optimizers_step
def on_after_backward(self) -> None:
# check grads are scaled
scale = self.trainer.precision_plugin.scaler.get_scale()
assert scale != 1.0 # the return value if not enabled
grads = [p.grad for p in self.parameters()]
inv_scale = 1 / scale
self.original_grads = [p * inv_scale for p in grads]

def check_grads_unscaled(self, optimizer=None):
if optimizer is not None:
scaler = self.trainer.precision_plugin.scaler
state = scaler._per_optimizer_states[id(optimizer)]
assert state["stage"].name == "UNSCALED"

grads = [p.grad for p in self.parameters()]
assert len(grads) == len(self.original_grads)
for actual, expected in zip(grads, self.original_grads):
torch.testing.assert_allclose(actual, expected)

def on_before_optimizer_step(self, optimizer, *_):
self.check_grads_unscaled(optimizer)
# manually clip
self.clipped_parameters = []
for p in self.parameters():
copy = p.detach().clone()
copy.grad = p.grad.clone()
self.clipped_parameters.append(copy)
clip_val = self.trainer.gradient_clip_val
torch.nn.utils.clip_grad_value_(self.clipped_parameters, clip_val)

def configure_gradient_clipping(self, *args, **kwargs):
# let lightning clip
super().configure_gradient_clipping(*args, **kwargs)
# check clipping worked as expected
parameters = list(self.parameters())
assert len(parameters) == len(self.clipped_parameters)
for actual, expected in zip(parameters, self.clipped_parameters):
torch.testing.assert_allclose(actual.grad, expected.grad)

def log_grad_norm(self, grad_norm_dict):
self.check_grads_unscaled()
assert len(grad_norm_dict)


@RunIf(min_gpus=2)
max_epochs=2,
default_root_dir=tmpdir,
limit_train_batches=2,
limit_test_batches=2,
limit_val_batches=2,
limit_val_batches=0,
amp_backend="native",
strategy="ddp_spawn",
gpus=2,
precision=16,
track_grad_norm=2,
# use a tiny value to make sure it works
gradient_clip_val=1e-3,
gradient_clip_algorithm="value",
log_every_n_steps=1,
accumulate_grad_batches=accum,
)
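The rewritten test saves manually unscaled copies of the gradients in `on_after_backward` and then asserts, in `on_before_optimizer_step`, `configure_gradient_clipping`, and `log_grad_norm`, that the gradients Lightning exposes are already unscaled and correctly clipped. A hedged plain-PyTorch sketch of the same invariant outside Lightning (a CUDA device is assumed):

```python
import torch

model = torch.nn.Linear(4, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(init_scale=2.0**10)

with torch.cuda.amp.autocast():
    loss = model(torch.randn(8, 4, device="cuda")).sum()
scaler.scale(loss).backward()

# divide the scaled gradients by the current scale to get the expected unscaled grads
scale = scaler.get_scale()
expected = [p.grad.clone() / scale for p in model.parameters()]

scaler.unscale_(optimizer)
for p, exp in zip(model.parameters(), expected):
    torch.testing.assert_allclose(p.grad, exp)

scaler.step(optimizer)
scaler.update()
```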
11 changes: 7 additions & 4 deletions tests/plugins/test_deepspeed_plugin.py
@@ -203,17 +203,20 @@ def test_deepspeed_defaults(tmpdir):


@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_warn_deepspeed_override_backward(tmpdir):
"""Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning."""

def test_warn_deepspeed_ignored(tmpdir):
class TestModel(BoringModel):
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
return loss.backward()

model = TestModel()
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, strategy=DeepSpeedPlugin(), gpus=1, precision=16)
trainer = Trainer(
fast_dev_run=True, default_root_dir=tmpdir, strategy=DeepSpeedPlugin(), gpus=1, precision=16, track_grad_norm=2
)
from pytorch_lightning.plugins.precision.deepspeed_precision import warning_cache

with pytest.warns(UserWarning, match="will be ignored since DeepSpeed handles the backward"):
trainer.fit(model)
assert any("track_grad_norm=2.0)' but this is not supported" in w for w in warning_cache)


@RunIf(min_gpus=1, deepspeed=True)
39 changes: 28 additions & 11 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -393,17 +393,7 @@ def test_multiple_optimizers_step(tmpdir):
"""Tests that `step` works with several optimizers."""

class TestModel(ManualOptModel):

called = False

def on_before_optimizer_step(self, *args):
self.called = True
norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
if not (torch.isinf(norm) or torch.isnan(norm)):
assert norm.item() < 100, norm.item()

def training_step(self, batch, batch_idx):
# manual
opt_a, opt_b = self.optimizers()
x = batch[0]

@@ -436,6 +426,33 @@ def training_epoch_end(self, outputs) -> None:
# outputs should be an array with an entry per optimizer
assert len(outputs) == 2

# sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
def on_after_backward(self) -> None:
# check grads are scaled
scale = self.trainer.precision_plugin.scaler.get_scale()
assert scale != 1.0 # the return value if not enabled
grads = [p.grad for p in self.parameters()]
inv_scale = 1 / scale
self.original_grads = [p * inv_scale for p in grads]

def check_grads_unscaled(self, optimizer=None):
if optimizer is not None:
scaler = self.trainer.precision_plugin.scaler
state = scaler._per_optimizer_states[id(optimizer)]
assert state["stage"].name == "UNSCALED"

grads = [p.grad for p in self.parameters()]
assert len(grads) == len(self.original_grads)
for actual, expected in zip(grads, self.original_grads):
torch.testing.assert_allclose(actual, expected)

def on_before_optimizer_step(self, optimizer, *_):
self.check_grads_unscaled(optimizer)

def log_grad_norm(self, grad_norm_dict):
self.check_grads_unscaled()
assert len(grad_norm_dict)

model = TestModel()
model.val_dataloader = None

precision=16,
amp_backend="native",
gpus=1,
track_grad_norm=2,
)

with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
trainer.fit(model)
assert bwd_mock.call_count == limit_train_batches * 3
assert model.called


def test_step_with_optimizer_closure(tmpdir):
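For context, a hedged sketch of the manual-optimization pattern that `ManualOptModel` exercises above: `automatic_optimization = False`, `self.manual_backward` (which applies the AMP loss scaling when `precision=16`), and explicit optimizer steps. Since the new `PrecisionPlugin._clip_gradients` returns early for manual optimization, the Trainer's clipping flags are not applied here:

```python
import torch
import pytorch_lightning as pl

class TwoOptManualModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer_a = torch.nn.Linear(4, 1)
        self.layer_b = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        opt_a, opt_b = self.optimizers()

        loss_a = self.layer_a(x).sum()
        opt_a.zero_grad()
        self.manual_backward(loss_a)   # scales the loss when running with precision=16
        opt_a.step()

        loss_b = self.layer_b(x).sum()
        opt_b.zero_grad()
        self.manual_backward(loss_b)
        opt_b.step()

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.layer_a.parameters(), lr=0.1),
            torch.optim.SGD(self.layer_b.parameters(), lr=0.1),
        )
```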