diff --git a/CHANGELOG.md b/CHANGELOG.md
index 29a5ec5b002f4..a14abdbf6178e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -662,6 +662,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed undesired side effects being caused by `Trainer` patching dataloader methods on the `LightningModule` ([#9764](https://github.com/PyTorchLightning/pytorch-lightning/pull/9764))
 
+- Fixed gradients not being unscaled when clipping or logging the gradient norm ([#9287](https://github.com/PyTorchLightning/pytorch-lightning/pull/9287))
+
+
 - Fixed monitor value in `ModelCheckpoint` getting moved to the wrong device in a special case where it becomes NaN ([#10118](https://github.com/PyTorchLightning/pytorch-lightning/pull/10118))
 
 
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index ffe05172cc66c..a8acec23c6ed3 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -333,6 +333,18 @@ def optimizer_step(
         """
         model = model or self.lightning_module
         self.precision_plugin.optimizer_step(model, optimizer, opt_idx, lambda_closure, **kwargs)
+        trainer = model.trainer
+        assert isinstance(trainer, pl.Trainer)
+        # TODO: this is done for the entire model but should be changed to per-optimizer
+        if opt_idx == 0:
+            self.precision_plugin._track_grad_norm(trainer)
+        self.precision_plugin._clip_gradients(
+            model,
+            optimizer,
+            opt_idx,
+            trainer.gradient_clip_val,
+            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
+        )
 
     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
         """Zeros all model parameter's gradients."""
diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py
index b55578b2f0c74..ecd52d76bd0da 100644
--- a/pytorch_lightning/loops/optimization/optimizer_loop.py
+++ b/pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -30,7 +30,7 @@
 )
 from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
 from pytorch_lightning.trainer.progress import OptimizationProgress
-from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
+from pytorch_lightning.utilities import AMPType, DeviceType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
 from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
@@ -228,25 +228,6 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
         outputs, self._outputs = self._outputs, {}  # free memory
         return outputs
 
-    def _backward(
-        self, loss: Tensor, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any
-    ) -> None:
-        """Performs the backward step.
-
-        Args:
-            loss: The loss value to back-propagate on
-            optimizer: Current optimizer being used
-            opt_idx: Index of the current optimizer being used
-        """
-        self.trainer.accelerator.backward(loss, optimizer, opt_idx, *args, **kwargs)
-
-        if not self.trainer.fit_loop._should_accumulate():
-            # track gradients
-            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer, opt_idx=opt_idx)
-            if grad_norm_dict:
-                self.trainer.lightning_module._current_fx_name = "on_after_backward"
-                self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
-
     def _run_optimization(
         self, split_batch: Any, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
     ) -> ClosureResult:
@@ -343,7 +324,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
             return None
 
         def backward_fn(loss: Tensor) -> None:
-            self._backward(loss, optimizer, opt_idx)
+            self.trainer.accelerator.backward(loss, optimizer, opt_idx)
 
             # check if model weights are nan
             if self.trainer._terminate_on_nan:
@@ -473,25 +454,3 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
                 self.trainer._results.cpu()
 
             return result
-
-    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -> Dict[str, float]:
-        """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
-
-        Args:
-            optimizer: the current optimizer
-        """
-        # track gradient norms
-        grad_norm_dict = {}
-        if self.trainer.track_grad_norm != -1:
-            grad_norm_dict = grad_norm(
-                self.trainer.lightning_module, self.trainer.track_grad_norm, self.trainer.logger.group_separator
-            )
-
-        # clip gradients
-        self.trainer.lightning_module.configure_gradient_clipping(
-            optimizer,
-            opt_idx,
-            gradient_clip_val=self.trainer.gradient_clip_val,
-            gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
-        )
-        return grad_norm_dict
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 478e8e529bbd9..704e8fe3f5c69 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -82,3 +82,12 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """DeepSpeed handles gradient clipping internally."""
+
+    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
+        if trainer.track_grad_norm == -1:
+            return
+        # the gradients are not available in the model due to gradient partitioning in zero stage >= 2
+        warning_cache.warn(
+            f"You set `Trainer(track_grad_norm={trainer.track_grad_norm!r})' but this is not supported for DeepSpeed."
+            " The setting will be ignored."
+        )
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 62e1e33232480..dcbaa76b48559 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -21,7 +21,7 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks
-from pytorch_lightning.utilities import GradClipAlgorithmType, rank_zero_deprecation
+from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType, rank_zero_deprecation
 from pytorch_lightning.utilities.types import _PARAMETERS
 
 
@@ -123,6 +123,34 @@ def optimizer_step(
         model.trainer.call_hook("on_before_optimizer_step", optimizer, optimizer_idx)
         optimizer.step(closure=lambda_closure, **kwargs)
 
+    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
+        if float(trainer.track_grad_norm) == -1:
+            return
+        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, trainer.logger.group_separator)
+        if grad_norm_dict:
+            prev_fx = trainer.lightning_module._current_fx_name
+            trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
+            trainer.lightning_module.log_grad_norm(grad_norm_dict)
+            trainer.lightning_module._current_fx_name = prev_fx
+
+    def _clip_gradients(
+        self,
+        model: Union["pl.LightningModule", Module],
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None,
+    ) -> None:
+        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
+            # the configuration validator disallows clipping on manual
+            return
+        model.configure_gradient_clipping(
+            optimizer,
+            optimizer_idx,
+            gradient_clip_val=clip_val,
+            gradient_clip_algorithm=gradient_clip_algorithm,
+        )
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index f6b5481dd5ef9..955f8973d3c18 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -441,8 +441,10 @@ def init_deepspeed(self):
         # deepspeed handles gradient clipping internally
         if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
             rank_zero_warn(
-                "Since deepspeed handles gradient clipping internally, `LightningModule.configure_gradient_clipping`"
-                " will be ignored. Consider setting `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
+                "Since DeepSpeed handles gradient clipping internally, the default"
+                " `LightningModule.configure_gradient_clipping` implementation will not actually clip gradients."
+                " The hook will still be called. Consider setting"
+                " `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
                 " which will use the internal mechanism."
             )
 
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index cec9032e226e4..ccd31af4dbd3e 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -234,7 +234,7 @@ def __init__(self, called):
         super().__init__()
         pl_module_hooks = get_members(LightningModule)
         # remove non-hooks
-        pl_module_hooks.difference_update({"optimizers"})
+        pl_module_hooks.difference_update({"optimizers", "log", "log_dict"})
         # remove most `nn.Module` hooks
         module_hooks = get_members(torch.nn.Module)
         module_hooks.difference_update({"forward", "zero_grad", "train"})
@@ -305,6 +305,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                     *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                     dict(name="Callback.on_after_backward", args=(trainer, model)),
                     dict(name="on_after_backward"),
+                    *(on_before_optimizer_step if using_plugin else []),
+                    *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(
                         name="clip_gradients",
                         args=(ANY,),
@@ -315,7 +317,8 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                         args=(ANY, 0),
                         kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
                     ),
-                    *(on_before_optimizer_step if using_plugin else []),
+                    # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+                    # the actual call to `PrecisionPlugin.optimizer_step`
                     dict(
                         name="optimizer_step",
                         args=(current_epoch, i, ANY, 0, ANY),
@@ -357,6 +360,7 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k
                     dict(name="on_before_optimizer_step", args=(ANY, 0)),
                     # without a precision plugin, we execute the closure inside the `optimizer.step`
                     *([] if using_plugin else [dict(name="closure")]),
+                    *([dict(name="log_grad_norm", args=ANY)] if not using_deepspeed else []),
                     dict(name="training_step", args=(ANY, i)),
                     dict(name="training_step_end", args=(dict(loss=ANY),)),
                     dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
@@ -484,6 +488,7 @@ def training_step(self, batch, batch_idx):
         enable_progress_bar=False,
         enable_model_summary=False,
         callbacks=[callback],
+        track_grad_norm=1,
         **kwargs,
     )
     assert called == [
@@ -608,6 +613,7 @@ def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
         enable_progress_bar=False,
         enable_model_summary=False,
         callbacks=[callback],
+        track_grad_norm=1,
     )
     assert called == [
         dict(name="Callback.on_init_start", args=(trainer,)),
diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py
index fa9926a6f5a18..a14f3af4d9c17 100644
--- a/tests/plugins/test_amp_plugins.py
+++ b/tests/plugins/test_amp_plugins.py
@@ -72,10 +72,49 @@ def test_amp_apex_ddp(mocked_device_count, strategy, gpus, amp, custom_plugin, p
 
 
 class GradientUnscaleBoringModel(BoringModel):
-    def on_before_optimizer_step(self, *_):
-        norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
-        if not (torch.isinf(norm) or torch.isnan(norm)):
-            assert norm.item() < 15.0
+    # sister test: tests/trainer/optimization/test_manual_optimization.py::test_multiple_optimizers_step
+    def on_after_backward(self) -> None:
+        # check grads are scaled
+        scale = self.trainer.precision_plugin.scaler.get_scale()
+        assert scale != 1.0  # the return value if not enabled
+        grads = [p.grad for p in self.parameters()]
+        inv_scale = 1 / scale
+        self.original_grads = [p * inv_scale for p in grads]
+
+    def check_grads_unscaled(self, optimizer=None):
+        if optimizer is not None:
+            scaler = self.trainer.precision_plugin.scaler
+            state = scaler._per_optimizer_states[id(optimizer)]
+            assert state["stage"].name == "UNSCALED"
+
+        grads = [p.grad for p in self.parameters()]
+        assert len(grads) == len(self.original_grads)
+        for actual, expected in zip(grads, self.original_grads):
+            torch.testing.assert_allclose(actual, expected)
+
+    def on_before_optimizer_step(self, optimizer, *_):
+        self.check_grads_unscaled(optimizer)
+        # manually clip
+        self.clipped_parameters = []
+        for p in self.parameters():
+            copy = p.detach().clone()
+            copy.grad = p.grad.clone()
+            self.clipped_parameters.append(copy)
+        clip_val = self.trainer.gradient_clip_val
+        torch.nn.utils.clip_grad_value_(self.clipped_parameters, clip_val)
+
+    def configure_gradient_clipping(self, *args, **kwargs):
+        # let lightning clip
+        super().configure_gradient_clipping(*args, **kwargs)
+        # check clipping worked as expected
+        parameters = list(self.parameters())
+        assert len(parameters) == len(self.clipped_parameters)
+        for actual, expected in zip(parameters, self.clipped_parameters):
+            torch.testing.assert_allclose(actual.grad, expected.grad)
+
+    def log_grad_norm(self, grad_norm_dict):
+        self.check_grads_unscaled()
+        assert len(grad_norm_dict)
 
 
 @RunIf(min_gpus=2)
@@ -87,13 +126,15 @@ def test_amp_gradient_unscale(tmpdir, accum: int):
         max_epochs=2,
         default_root_dir=tmpdir,
         limit_train_batches=2,
-        limit_test_batches=2,
-        limit_val_batches=2,
+        limit_val_batches=0,
         amp_backend="native",
         strategy="ddp_spawn",
         gpus=2,
         precision=16,
         track_grad_norm=2,
+        # use a tiny value to make sure it works
+        gradient_clip_val=1e-3,
+        gradient_clip_algorithm="value",
         log_every_n_steps=1,
         accumulate_grad_batches=accum,
     )
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index e9fc6b1eab9ae..c599b96c6ffdb 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -203,17 +203,20 @@ def test_deepspeed_defaults(tmpdir):
 
 
 @RunIf(min_gpus=1, deepspeed=True, special=True)
-def test_warn_deepspeed_override_backward(tmpdir):
-    """Test to ensure that if the backward hook in the LightningModule is overridden, we throw a warning."""
-
+def test_warn_deepspeed_ignored(tmpdir):
     class TestModel(BoringModel):
         def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
             return loss.backward()
 
     model = TestModel()
-    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, strategy=DeepSpeedPlugin(), gpus=1, precision=16)
+    trainer = Trainer(
+        fast_dev_run=True, default_root_dir=tmpdir, strategy=DeepSpeedPlugin(), gpus=1, precision=16, track_grad_norm=2
+    )
+    from pytorch_lightning.plugins.precision.deepspeed_precision import warning_cache
+
     with pytest.warns(UserWarning, match="will be ignored since DeepSpeed handles the backward"):
         trainer.fit(model)
+    assert any("track_grad_norm=2.0)' but this is not supported" in w for w in warning_cache)
 
 
 @RunIf(min_gpus=1, deepspeed=True)
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 2dfe7a68338d3..5c86fd6343002 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -393,17 +393,7 @@ def test_multiple_optimizers_step(tmpdir):
     """Tests that `step` works with several optimizers."""
 
     class TestModel(ManualOptModel):
-
-        called = False
-
-        def on_before_optimizer_step(self, *args):
-            self.called = True
-            norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2)
-            if not (torch.isinf(norm) or torch.isnan(norm)):
-                assert norm.item() < 100, norm.item()
-
         def training_step(self, batch, batch_idx):
-            # manual
             opt_a, opt_b = self.optimizers()
             x = batch[0]
 
@@ -436,6 +426,33 @@ def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
             assert len(outputs) == 2
 
+        # sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
+        def on_after_backward(self) -> None:
+            # check grads are scaled
+            scale = self.trainer.precision_plugin.scaler.get_scale()
+            assert scale != 1.0  # the return value if not enabled
+            grads = [p.grad for p in self.parameters()]
+            inv_scale = 1 / scale
+            self.original_grads = [p * inv_scale for p in grads]
+
+        def check_grads_unscaled(self, optimizer=None):
+            if optimizer is not None:
+                scaler = self.trainer.precision_plugin.scaler
+                state = scaler._per_optimizer_states[id(optimizer)]
+                assert state["stage"].name == "UNSCALED"
+
+            grads = [p.grad for p in self.parameters()]
+            assert len(grads) == len(self.original_grads)
+            for actual, expected in zip(grads, self.original_grads):
+                torch.testing.assert_allclose(actual, expected)
+
+        def on_before_optimizer_step(self, optimizer, *_):
+            self.check_grads_unscaled(optimizer)
+
+        def log_grad_norm(self, grad_norm_dict):
+            self.check_grads_unscaled()
+            assert len(grad_norm_dict)
+
     model = TestModel()
     model.val_dataloader = None
 
@@ -450,12 +467,12 @@ def training_epoch_end(self, outputs) -> None:
         precision=16,
         amp_backend="native",
         gpus=1,
+        track_grad_norm=2,
     )
 
     with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 3
-    assert model.called
 
 
 def test_step_with_optimizer_closure(tmpdir):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index a45bf105722cf..95b2840b19cb9 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -47,7 +47,9 @@
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel, RandomDataset
 from tests.helpers.boring_model import RandomIterableDataset, RandomIterableDatasetWithLen
+from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
+from tests.helpers.simple_models import ClassificationModel
 
 
 @pytest.mark.parametrize("url_ckpt", [True, False])
@@ -991,75 +993,65 @@ def on_keyboard_interrupt(self, trainer, pl_module):
     assert isinstance(handle_interrupt_callback.exception, MisconfigurationException)
 
 
-@pytest.mark.parametrize(
-    "precision",
-    [32, pytest.param(16, marks=RunIf(min_gpus=1))],
-)
+@pytest.mark.parametrize("precision", [32, pytest.param(16, marks=RunIf(min_gpus=1))])
 def test_gradient_clipping_by_norm(tmpdir, precision):
     """Test gradient clipping by norm."""
     tutils.reset_seed()
 
-    model = EvalModelTemplate()  # TODO: when precision=16, BoringModel produces NaN, but EvalModelTemplate not
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_steps=1,
         max_epochs=1,
-        gpus=int(torch.cuda.is_available()),
+        accelerator="auto",
+        devices=1,
         precision=precision,
         gradient_clip_algorithm="norm",
-        gradient_clip_val=1.0,
+        gradient_clip_val=0.05,
    )
 
-    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward
+    class TestModel(ClassificationModel):
+        def configure_gradient_clipping(self, *args, **kwargs):
+            super().configure_gradient_clipping(*args, **kwargs)
+            # test that gradient is clipped correctly
+            parameters = self.parameters()
+            grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+            torch.testing.assert_allclose(grad_norm, torch.tensor(0.05))
+            self.assertion_called = True
 
-    def backward(*args, **kwargs):
-        # test that gradient is clipped correctly
-        ret_val = old_backward(*args, **kwargs)
-        parameters = model.parameters()
-        grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
-        assert (grad_norm - 1.0).abs() < 0.01, f"Gradient norm != 1.0: {grad_norm}"
-        return ret_val
-
-    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward = backward
-    trainer.fit(model)
+    model = TestModel()
+    trainer.fit(model, ClassifDataModule())
+    assert model.assertion_called
 
 
-@pytest.mark.parametrize(
-    "precision",
-    [32, pytest.param(16, marks=RunIf(min_gpus=1))],
-)
+@pytest.mark.parametrize("precision", [32, pytest.param(16, marks=RunIf(min_gpus=1))])
 def test_gradient_clipping_by_value(tmpdir, precision):
     """Test gradient clipping by value."""
     tutils.reset_seed()
 
-    model = BoringModel()
-
-    grad_clip_val = 1e-10
     trainer = Trainer(
+        default_root_dir=tmpdir,
         max_steps=1,
         max_epochs=1,
+        accelerator="auto",
+        devices=1,
         precision=precision,
-        gpus=int(torch.cuda.is_available()),
-        gradient_clip_val=grad_clip_val,
         gradient_clip_algorithm="value",
-        default_root_dir=tmpdir,
+        gradient_clip_val=1e-10,
     )
 
-    old_backward = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward
-
-    def backward(*args, **kwargs):
-        # test that gradient is clipped correctly
-        ret_val = old_backward(*args, **kwargs)
-        parameters = model.parameters()
-        grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
-        grad_max = torch.max(torch.stack(grad_max_list))
-        assert (
-            abs(grad_max.item() - grad_clip_val) < 1e-11
-        ), f"Gradient max value {grad_max} != grad_clip_val {grad_clip_val} ."
-        return ret_val
+    class TestModel(BoringModel):
+        def configure_gradient_clipping(self, *args, **kwargs):
+            super().configure_gradient_clipping(*args, **kwargs)
+            # test that gradient is clipped correctly
+            parameters = self.parameters()
+            grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters]
+            grad_max = torch.max(torch.stack(grad_max_list))
+            torch.testing.assert_allclose(grad_max.abs(), torch.tensor(1e-10))
+            self.assertion_called = True
 
-    trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._backward = backward
+    model = TestModel()
     trainer.fit(model)
+    assert model.assertion_called
 
 
 def test_invalid_gradient_clip_value(tmpdir):
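
Reviewer note (not part of the patch): the hunks above move gradient-norm tracking and clipping so that they run only after the precision plugin has unscaled the gradients. Below is a minimal pure-PyTorch sketch of the ordering this enforces when native AMP is used; it is not the Lightning implementation. The toy model, optimizer, data and clip value are hypothetical placeholders, and a CUDA device is assumed.

import torch
from torch.cuda.amp import GradScaler, autocast

# hypothetical toy setup: any model/optimizer/data would do
model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler()

for _ in range(4):
    inputs = torch.randn(8, 32, device="cuda")
    with autocast():
        loss = model(inputs).sum()

    optimizer.zero_grad()
    # backward produces *scaled* gradients (roughly where `on_after_backward` fires)
    scaler.scale(loss).backward()

    # unscale first, so norm tracking and clipping see the true gradients
    scaler.unscale_(optimizer)
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"grad norm before clipping: {total_norm:.4f}")  # what `log_grad_norm` would see

    # `scaler.step` skips the update if any gradient is inf/nan, then the scale is refreshed
    scaler.step(optimizer)
    scaler.update()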