
Commit 03f01fb

carmocca, awaelchli, pre-commit-ci[bot], and rohitgr7 authored
Fix gradient norm tracking and gradient clipping (#9287)
* WIP
* Progress
* Undo test change
* Fix plugin closure execution order
* Update CHANGELOG
* Fix manual optimization on AMP and skipping backward
* Fix for deepspeed
* Typo
* Hook test for manual closure
* Add skipping test with AMP
* You are hideous, apex
* Add deepspeed test
* Update CHANGELOG
* Fix for broken master
* Add RunIf
* FIXMEs
* Rename
* Fix grad norm
* add a simple test
* update test
* update test
* update test
* fix merge conflicts
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Sea of changes
* Undo change
* Introduce TPUPrecisionPlugin
* Undo changes
* Undo changes
* Resolve FIXME
* Undo change
* Undo change
* Undo change
* Fix FIXMEs
* Fix FIXME
* Correct value
* Bad merge
* Fix circular imports
* WIP
* Fixing clipping
* Fixes
* Bad merge
* Move optimizer step and clipping into the `PrecisionPlugin`
* Fix AMP
* Update CHANGELOG
* Fix tests
* Underscore
* Progress
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Remove pre_optimizer_step
* Missed one
* Progress
* Progress
* Fix test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update FIXMEs
* Fix test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Fix test
* DeepSpeed warning. mypy
* Rename
* Finish tests
* Update CHANGELOG
* Dumb fixes
* accelerator=auto
* Apply suggestions from code review (Co-authored-by: Rohit Gupta <[email protected]>)
* Update on comments
* Use ClassifModule

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 5262b63 commit 03f01fb

File tree: 11 files changed, +182 −110 lines


CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -662,6 +662,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
 
 
+- Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287))
+
+
 - Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118))
 
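
For orientation, the fix above is user-visible through these `Trainer` flags: with native AMP, clipping and `track_grad_norm` logging now operate on unscaled gradients. The sketch below is illustrative and not part of the commit; `MyModel` is a hypothetical stand-in for any `LightningModule`, and a GPU is assumed since native AMP requires one.

```python
import torch
from pytorch_lightning import LightningModule, Trainer


class MyModel(LightningModule):  # hypothetical stand-in for any LightningModule
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = Trainer(
    gpus=1,                          # native AMP needs a GPU
    precision=16,
    amp_backend="native",
    gradient_clip_val=0.5,           # clipping now sees unscaled gradients
    gradient_clip_algorithm="norm",
    track_grad_norm=2,               # logged grad norms are unscaled as well
)
# trainer.fit(MyModel(), train_dataloader) would run the loop; the dataloader is omitted here
```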

pytorch_lightning/accelerators/accelerator.py
Lines changed: 12 additions & 0 deletions

@@ -333,6 +333,18 @@ def optimizer_step(
         """
         model = model or self.lightning_module
         self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs)
+        trainer = model.trainer
+        assert isinstance(trainer, pl.Trainer)
+        # TODO: this is done for the entire model but should be changed to per-optimizer
+        if opt_idx == 0:
+            self.precision_plugin._track_grad_norm(trainer)
+        self.precision_plugin._clip_gradients(
+            model,
+            optimizer,
+            opt_idx,
+            trainer.gradient_clip_val,
+            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
+        )
 
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""

pytorch_lightning/loops/optimization/optimizer_loop.py
Lines changed: 2 additions & 43 deletions

@@ -30,7 +30,7 @@
 )
 from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
 from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
+from pytorch_lightning.utilities import AMPType, DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -228,25 +228,6 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
         outputs, self._outputs = self._outputs, {}  # free memory
         return outputs
 
-    def _backward(
-        self, loss: Tensor, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any
-    ) -> None:
-        """Performs the backward step.
-
-        Args:
-            loss: The loss value to back-propagate on
-            optimizer: Current optimizer being used
-            opt_idx: Index of the current optimizer being used
-        """
-        self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
-
-        if not self.trainer.fit_loop._should_accumulate():
-            # track gradients
-            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer, opt_idx=opt_idx)
-            if grad_norm_dict:
-                self.trainer.lightning_module._current_fx_name = "on_after_backward"
-                self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
-
     def _run_optimization(
         self, split_batch: Any, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
     ) -> ClosureResult:
@@ -343,7 +324,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
             return None
 
         def backward_fn(loss: Tensor) -> None:
-            self._backward(loss, optimizer, opt_idx)
+            self.trainer.accelerator.backward(loss, optimizer, opt_idx)
 
             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -473,25 +454,3 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
             self.trainer._results.cpu()
 
         return result
-
-    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -> Dict[str, float]:
-        """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
-
-        Args:
-            optimizer: the current optimizer
-        """
-        # track gradient norms
-        grad_norm_dict = {}
-        if self.trainer.track_grad_norm != -1:
-            grad_norm_dict = grad_norm(
-                self.trainer.lightning_module, self.trainer.track_grad_norm, self.trainer.logger.group_separator
-            )
-
-        # clip gradients
-        self.trainer.lightning_module.configure_gradient_clipping(
-            optimizer,
-            opt_idx,
-            gradient_clip_val=self.trainer.gradient_clip_val,
-            gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
-        )
-        return grad_norm_dict
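
The removed `_track_and_norm_grad` built a per-parameter norm dictionary via the `grad_norm` utility, which the `PrecisionPlugin` now produces instead (see the precision_plugin.py diff below). As a rough, hypothetical sketch of what such a dictionary contains, here is a standalone version; the key names are illustrative and may not match the utility's exact format.

```python
from typing import Dict

import torch


def simple_grad_norm(module: torch.nn.Module, norm_type: float = 2.0, sep: str = "/") -> Dict[str, float]:
    # per-parameter gradient norms, plus an aggregated total norm
    norms = {
        f"grad_{norm_type}_norm{sep}{name}": p.grad.detach().norm(norm_type).item()
        for name, p in module.named_parameters()
        if p.grad is not None
    }
    if norms:
        norms[f"grad_{norm_type}_norm_total"] = torch.tensor(list(norms.values())).norm(norm_type).item()
    return norms


layer = torch.nn.Linear(4, 4)
layer(torch.randn(2, 4)).sum().backward()
print(simple_grad_norm(layer))  # e.g. {"grad_2.0_norm/weight": ..., "grad_2.0_norm_total": ...}
```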

pytorch_lightning/plugins/precision/deepspeed_precision.py
Lines changed: 9 additions & 0 deletions

@@ -82,3 +82,12 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """DeepSpeed handles gradient clipping internally."""
+
+    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
+        if trainer.track_grad_norm == -1:
+            return
+        # the gradients are not available in the model due to gradient partitioning in zero stage >= 2
+        warning_cache.warn(
+            f"You set `Trainer(track_grad_norm={trainer.track_grad_norm!r})' but this is not supported for DeepSpeed."
+            " The setting will be ignored."
+        )

pytorch_lightning/plugins/precision/precision_plugin.py
Lines changed: 29 additions & 1 deletion

@@ -21,7 +21,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks
-from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
+from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType, rank_zero_deprecation
 from pytorch_lightning.utilities.types import _PARAMETERS
 
 
@@ -123,6 +123,34 @@ def optimizer_step(
         model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         optimizer.step(closure=lambda_closure, **kwargs)
 
+    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
+        if float(trainer.track_grad_norm) == -1:
+            return
+        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
+        if grad_norm_dict:
+            prev_fx = trainer.lightning_module._current_fx_name
+            trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
+            trainer.lightning_module.log_grad_norm(grad_norm_dict)
+            trainer.lightning_module._current_fx_name = prev_fx
+
+    def _clip_gradients(
+        self,
+        model: Union["pl.LightningModule", Module],
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None,
+    ) -> None:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
+            # the configuration validator disallows clipping on manual
+            return
+        model.configure_gradient_clipping(
+            optimizer,
+            optimizer_idx,
+            gradient_clip_val=clip_val,
+            gradient_clip_algorithm=gradient_clip_algorithm,
+        )
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
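
Since `_clip_gradients` routes through `LightningModule.configure_gradient_clipping`, users can customize clipping by overriding that hook. A minimal sketch, assuming the hook signature shown in the tests below and the built-in `self.clip_gradients` helper; the class name is hypothetical and only the clipping hook is shown.

```python
from pytorch_lightning import LightningModule


class ClipOnlyFirstOptimizer(LightningModule):
    def configure_gradient_clipping(
        self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None
    ):
        # by this point, native-AMP gradients have already been unscaled (the fix in this commit)
        if optimizer_idx == 0:
            # defer to the built-in clipping, but only for the first optimizer
            self.clip_gradients(
                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
            )
```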

pytorch_lightning/plugins/training_type/deepspeed.py
Lines changed: 4 additions & 2 deletions

@@ -441,8 +441,10 @@ def init_deepspeed(self):
         # deepspeed handles gradient clipping internally
         if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
             rank_zero_warn(
-                "Since deepspeed handles gradient clipping internally, `LightningModule.configure_gradient_clipping`"
-                " will be ignored. Consider setting `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
+                "Since DeepSpeed handles gradient clipping internally, the default"
+                " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
+                " The hook will still be called. Consider setting"
+                " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
                 " which will use the internal mechanism."
             )
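
A sketch of the configuration the updated warning recommends: pass the clip settings to the `Trainer` so that DeepSpeed's internal mechanism performs the clipping, instead of overriding `configure_gradient_clipping`. The values are illustrative; a CUDA device and an installed `deepspeed` are assumed.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

trainer = Trainer(
    gpus=1,
    precision=16,
    strategy=DeepSpeedPlugin(),
    gradient_clip_val=0.5,           # forwarded to DeepSpeed's internal clipping
    gradient_clip_algorithm="norm",  # DeepSpeed clips by norm
)
```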

tests/models/test_hooks.py
Lines changed: 8 additions & 2 deletions

@@ -234,7 +234,7 @@ def __init__(self, called):
         super().__init__()
         pl_module_hooks = get_members(LightningModule)
         # remove non-hooks
-        pl_module_hooks.difference_update({"optimizers"})
+        pl_module_hooks.difference_update({"optimizers", "log", "log_dict"})
         # remove most `nn.Module` hooks
         module_hooks = get_members(torch.nn.Module)
         module_hooks.difference_update({"forward", "zero_grad", "train"})
@@ -305,6 +305,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
            *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
            dict(name="Callback.on_after_backward", args=(trainer, model)),
            dict(name="on_after_backward"),
+           *(on_before_optimizer_step if using_plugin else []),
+           *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
            dict(
                name="clip_gradients",
                args=(ANY,),
@@ -315,7 +317,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                args=(ANY, 0),
                kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
            ),
-           *(on_before_optimizer_step if using_plugin else []),
+           # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+           # the actual call to `PrecisionPlugin.optimizer_step`
            dict(
                name="optimizer_step",
                args=(current_epoch, i, ANY, 0, ANY),
@@ -357,6 +360,7 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
            dict(name="on_before_optimizer_step", args=(ANY, 0)),
            # without a precision plugin, we execute the closure inside the `optimizer.step`
            *([] if using_plugin else [dict(name="closure")]),
+           *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
            dict(name="training_step", args=(ANY, i)),
            dict(name="training_step_end", args=(dict(loss=ANY),)),
            dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
@@ -484,6 +488,7 @@ def training_step(self, batch, batch_idx):
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[callback],
+       track_grad_norm=1,
        **kwargs,
    )
    assert called == [
@@ -608,6 +613,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[callback],
+       track_grad_norm=1,
    )
    assert called == [
        dict(name="Callback.on_init_start", args=(trainer,)),

tests/plugins/test_amp_plugins.py
Lines changed: 47 additions & 6 deletions

@@ -72,10 +72,49 @@ def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, p
 
 
 class GradientUnscaleBoringModel(BoringModel):
-    def on_before_optimizer_step(self, *_):
-        norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
-        if not (torch.isinf(norm) or torch.isnan(norm)):
-            assert norm.item() < 15.0
+    # sister test: tests/trainer/optimization/test_manual_optimization.py::test_multiple_optimizers_step
+    def on_after_backward(self) -> None:
+        # check grads are scaled
+        scale = self.trainer.precision_plugin.scaler.get_scale()
+        assert scale != 1.0  # the return value if not enabled
+        grads = [p.grad for p in self.parameters()]
+        inv_scale = 1 / scale
+        self.original_grads = [p * inv_scale for p in grads]
+
+    def check_grads_unscaled(self, optimizer=None):
+        if optimizer is not None:
+            scaler = self.trainer.precision_plugin.scaler
+            state = scaler._per_optimizer_states[id(optimizer)]
+            assert state["stage"].name == "UNSCALED"
+
+        grads = [p.grad for p in self.parameters()]
+        assert len(grads) == len(self.original_grads)
+        for actual, expected in zip(grads, self.original_grads):
+            torch.testing.assert_allclose(actual, expected)
+
+    def on_before_optimizer_step(self, optimizer, *_):
+        self.check_grads_unscaled(optimizer)
+        # manually clip
+        self.clipped_parameters = []
+        for p in self.parameters():
+            copy = p.detach().clone()
+            copy.grad = p.grad.clone()
+            self.clipped_parameters.append(copy)
+        clip_val = self.trainer.gradient_clip_val
+        torch.nn.utils.clip_grad_value_(self.clipped_parameters, clip_val)
+
+    def configure_gradient_clipping(self, *args, **kwargs):
+        # let lightning clip
+        super().configure_gradient_clipping(*args, **kwargs)
+        # check clipping worked as expected
+        parameters = list(self.parameters())
+        assert len(parameters) == len(self.clipped_parameters)
+        for actual, expected in zip(parameters, self.clipped_parameters):
+            torch.testing.assert_allclose(actual.grad, expected.grad)
+
+    def log_grad_norm(self, grad_norm_dict):
+        self.check_grads_unscaled()
+        assert len(grad_norm_dict)
 
 
 @RunIf(min_gpus=2)
@@ -87,13 +126,15 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
        max_epochs=2,
        default_root_dir=tmpdir,
        limit_train_batches=2,
-       limit_test_batches=2,
-       limit_val_batches=2,
+       limit_val_batches=0,
        amp_backend="native",
        strategy="ddp_spawn",
        gpus=2,
        precision=16,
        track_grad_norm=2,
+       # use a tiny value to make sure it works
+       gradient_clip_val=1e-3,
+       gradient_clip_algorithm="value",
        log_every_n_steps=1,
        accumulate_grad_batches=accum,
    )

tests/plugins/test_deepspeed_plugin.py
Lines changed: 7 additions & 4 deletions

@@ -203,17 +203,20 @@ def test_deepspeed_defaults(tmpdir):
 
 
 @RunIf(min_gpus=1, deepspeed=True, special=True)
-def test_warn_deepspeed_override_backward(tmpdir):
-    """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning."""
-
+def test_warn_deepspeed_ignored(tmpdir):
     class TestModel(BoringModel):
         def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
            return loss.backward()
 
     model = TestModel()
-    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, strategy=DeepSpeedPlugin(), gpus=1, precision=16)
+    trainer = Trainer(
+        fast_dev_run=True, default_root_dir=tmpdir, strategy=DeepSpeedPlugin(), gpus=1, precision=16, track_grad_norm=2
+    )
+    from pytorch_lightning.plugins.precision.deepspeed_precision import warning_cache
+
     with pytest.warns(UserWarning, match="will be ignored since DeepSpeed handles the backward"):
         trainer.fit(model)
+    assert any("track_grad_norm=2.0)' but this is not supported" in w for w in warning_cache)
 
 
 @RunIf(min_gpus=1, deepspeed=True)

tests/trainer/optimization/test_manual_optimization.py
Lines changed: 28 additions & 11 deletions

@@ -393,17 +393,7 @@ def test_multiple_optimizers_step(tmpdir):
     """Tests that `step` works with several optimizers."""
 
     class TestModel(ManualOptModel):
-
-        called = False
-
-        def on_before_optimizer_step(self, *args):
-            self.called = True
-            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
-            if not (torch.isinf(norm) or torch.isnan(norm)):
-                assert norm.item() < 100, norm.item()
-
         def training_step(self, batch, batch_idx):
-            # manual
             opt_a, opt_b = self.optimizers()
             x = batch[0]
 
@@ -436,6 +426,33 @@ def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
             assert len(outputs) == 2
 
+        # sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
+        def on_after_backward(self) -> None:
+            # check grads are scaled
+            scale = self.trainer.precision_plugin.scaler.get_scale()
+            assert scale != 1.0  # the return value if not enabled
+            grads = [p.grad for p in self.parameters()]
+            inv_scale = 1 / scale
+            self.original_grads = [p * inv_scale for p in grads]
+
+        def check_grads_unscaled(self, optimizer=None):
+            if optimizer is not None:
+                scaler = self.trainer.precision_plugin.scaler
+                state = scaler._per_optimizer_states[id(optimizer)]
+                assert state["stage"].name == "UNSCALED"
+
+            grads = [p.grad for p in self.parameters()]
+            assert len(grads) == len(self.original_grads)
+            for actual, expected in zip(grads, self.original_grads):
+                torch.testing.assert_allclose(actual, expected)
+
+        def on_before_optimizer_step(self, optimizer, *_):
+            self.check_grads_unscaled(optimizer)
+
+        def log_grad_norm(self, grad_norm_dict):
+            self.check_grads_unscaled()
+            assert len(grad_norm_dict)
+
     model = TestModel()
     model.val_dataloader = None
 
@@ -450,12 +467,12 @@ def training_epoch_end(self, outputs) -> None:
         precision=16,
         amp_backend="native",
         gpus=1,
+        track_grad_norm=2,
     )
 
     with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
         trainer.fit(model)
         assert bwd_mock.call_count == limit_train_batches * 3
-    assert model.called
 
 
 def test_step_with_optimizer_closure(tmpdir):
