
Commit 23e8b59

Add configure_gradient_clipping hook in LightningModule (#9584)
* init hook
* docs
* dep train args
* update tests
* doc
* doc
* .gitignore
* not dep
* add trainer args
* add & update tests
* fix tests
* pre-commit
* docs
* add docs
* add exception
* code review
* deepspeed
* update tests
* not
* try fix
* Apply suggestions from code review
* update deepspeed
* disable some tests
* disable some tests
* enable all tests
1 parent 05b15e6 commit 23e8b59

12 files changed, +317 -20 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -156,3 +156,4 @@ cifar-10-batches-py
 *.pt
 # ctags
 tags
+.tags

docs/source/common/lightning_module.rst

Lines changed: 7 additions & 0 deletions
@@ -1195,6 +1195,7 @@ for more information.
     on_after_backward()
 
     on_before_optimizer_step()
+    configure_gradient_clipping()
     optimizer_step()
 
     on_train_batch_end()
@@ -1452,6 +1453,12 @@ on_before_optimizer_step
 .. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_optimizer_step
     :noindex:
 
+configure_gradient_clipping
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping
+    :noindex:
+
 optimizer_step
 ~~~~~~~~~~~~~~
 

docs/source/common/optimizers.rst

Lines changed: 45 additions & 1 deletion
@@ -69,7 +69,7 @@ Here is a minimal example of manual optimization.
 Gradient accumulation
 ---------------------
 You can accumulate gradients over batches similarly to
-:attr:`~pytorch_lightning.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
+:attr:`~pytorch_lightning.trainer.trainer.Trainer.accumulate_grad_batches` of automatic optimization.
 To perform gradient accumulation with one optimizer, you can do as such.
 
 .. testcode:: python
@@ -516,3 +516,47 @@ to perform a step, Lightning won't be able to support accelerators and precision
     ):
         optimizer = optimizer.optimizer
         optimizer.step(closure=optimizer_closure)
+
+-----
+
+Configure gradient clipping
+---------------------------
+To configure custom gradient clipping, consider overriding
+the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method.
+The :attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_val` and
+:attr:`~pytorch_lightning.trainer.trainer.Trainer.gradient_clip_algorithm` attributes are passed in as the
+respective arguments here and Lightning will handle gradient clipping for you. If you want to pass
+different values for these arguments and still let Lightning handle the gradient clipping, you can
+use the built-in :meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients` method and pass
+the arguments along with your optimizer.
+
+.. note::
+    Make sure not to override the :meth:`~pytorch_lightning.core.lightning.LightningModule.clip_gradients`
+    method. If you want to customize gradient clipping, consider using the
+    :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_gradient_clipping` method instead.
+
+For example, here we will apply gradient clipping only to the gradients associated with optimizer A.
+
+.. testcode:: python
+
+    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+        if optimizer_idx == 0:
+            # Lightning will handle the gradient clipping
+            self.clip_gradients(
+                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+            )
+
+Here we configure gradient clipping differently for optimizer B.
+
+.. testcode:: python
+
+    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+        if optimizer_idx == 0:
+            # Lightning will handle the gradient clipping
+            self.clip_gradients(
+                optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+            )
+        elif optimizer_idx == 1:
+            self.clip_gradients(
+                optimizer, gradient_clip_val=gradient_clip_val * 2, gradient_clip_algorithm=gradient_clip_algorithm
+            )
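
Taken together with the Trainer flags, the documented hook can be exercised end to end. Below is a minimal sketch that is not part of the commit: the `LitModel` class and its layer size are hypothetical, while the hook signature, `clip_gradients`, and the `Trainer` arguments follow the diff above.

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical toy module, not from the commit
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # any scalar loss will do for this sketch
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        # defer to Lightning's built-in clipping, but double the threshold passed via the Trainer
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val * 2,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )


# gradient_clip_val / gradient_clip_algorithm are forwarded into the hook arguments
trainer = pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm", max_epochs=1)
# trainer.fit(LitModel(), ...)  # supply a DataLoader of 32-feature float batches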

pytorch_lightning/core/lightning.py

Lines changed: 97 additions & 2 deletions
@@ -36,7 +36,12 @@
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.core.saving import ModelIO
 from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
-from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE, rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import (
+    _TORCH_SHARDED_TENSOR_AVAILABLE,
+    GradClipAlgorithmType,
+    rank_zero_deprecation,
+    rank_zero_warn,
+)
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.distributed import distributed_available, sync_ddp
@@ -1460,7 +1465,7 @@ def untoggle_optimizer(self, optimizer_idx: int):
             optimizer_idx: Current optimizer idx in the training loop
 
         Note:
-            Only called when using multiple optimizers
+            Only called when using multiple_optimizers
         """
         for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
             if optimizer_idx != opt_idx:
@@ -1471,6 +1476,96 @@ def untoggle_optimizer(self, optimizer_idx: int):
         # save memory
         self._param_requires_grad_state = {}
 
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[Union[str, GradClipAlgorithmType]] = None,
+    ):
+        """Handles gradient clipping internally.
+
+        Note:
+            Do not override this method. If you want to customize gradient clipping, consider
+            using the :meth:`configure_gradient_clipping` method.
+
+        Args:
+            optimizer: Current optimizer being used.
+            gradient_clip_val: The value at which to clip gradients.
+            gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
+                to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.
+        """
+        if gradient_clip_val is None:
+            gradient_clip_val = self.trainer.gradient_clip_val or 0.0
+        elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
+            raise MisconfigurationException(
+                "You have set `Trainer(gradient_clip_val)` and have passed"
+                " `gradient_clip_val` inside `clip_gradients`. Please use only one of them."
+            )
+
+        if gradient_clip_algorithm is None:
+            gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
+        else:
+            gradient_clip_algorithm = gradient_clip_algorithm.lower()
+            if (
+                self.trainer.gradient_clip_algorithm is not None
+                and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
+            ):
+                raise MisconfigurationException(
+                    "You have set `Trainer(gradient_clip_algorithm)` and have passed"
+                    " `gradient_clip_algorithm` inside `clip_gradients`. Please use only one of them."
+                )
+
+        if not isinstance(gradient_clip_val, (int, float)):
+            raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
+
+        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+            raise MisconfigurationException(
+                f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
+                f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
+            )
+
+        gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
+        self.trainer.accelerator.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
+
+    def configure_gradient_clipping(
+        self,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
+    ):
+        """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
+
+        Note:
+            This hook won't be called when using deepspeed since it handles gradient clipping internally.
+            Consider setting ``gradient_clip_val`` and ``gradient_clip_algorithm`` inside ``Trainer``.
+
+        Args:
+            optimizer: Current optimizer being used.
+            optimizer_idx: Index of the current optimizer being used.
+            gradient_clip_val: The value at which to clip gradients. By default, the value passed in Trainer
+                will be available here.
+            gradient_clip_algorithm: The gradient clipping algorithm to use. By default, the value
+                passed in Trainer will be available here.
+
+        Example::
+
+            # Perform gradient clipping on gradients associated with discriminator (optimizer_idx=1) in GAN
+            def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+                if optimizer_idx == 1:
+                    # Lightning will handle the gradient clipping
+                    self.clip_gradients(
+                        optimizer,
+                        gradient_clip_val=gradient_clip_val,
+                        gradient_clip_algorithm=gradient_clip_algorithm
+                    )
+                else:
+                    # implement your own custom logic to clip gradients for generator (optimizer_idx=0)
+        """
+        self.clip_gradients(
+            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
+        )
+
     def optimizer_step(
         self,
         epoch: int = None,
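
The fallback and conflict rules in `clip_gradients` above (a hook argument is used only when the corresponding `Trainer` flag is unset; a differing explicit value raises) can be summarised in a small standalone sketch. `resolve_clip_val` is a hypothetical helper written only for illustration, and it raises a plain `ValueError` instead of Lightning's `MisconfigurationException` so that it runs without any dependencies.

from typing import Optional, Union


def resolve_clip_val(
    passed_val: Optional[Union[int, float]], trainer_val: Optional[Union[int, float]]
) -> float:
    """Mirrors the gradient_clip_val resolution in LightningModule.clip_gradients (sketch only)."""
    if passed_val is None:
        # fall back to the Trainer flag; None means "no clipping", expressed as 0.0
        return trainer_val or 0.0
    if trainer_val is not None and trainer_val != passed_val:
        # corresponds to the MisconfigurationException raised in the diff above
        raise ValueError("Set `Trainer(gradient_clip_val)` or pass `gradient_clip_val`, not both.")
    return passed_val


assert resolve_clip_val(None, None) == 0.0   # nothing configured -> clipping disabled
assert resolve_clip_val(None, 0.5) == 0.5    # Trainer flag wins when the hook passes nothing
assert resolve_clip_val(1.0, None) == 1.0    # explicit hook argument is used as-is
# resolve_clip_val(1.0, 0.5) would raise, mirroring the conflict check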

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 9 additions & 5 deletions
@@ -240,7 +240,7 @@ def _backward(
 
         if not self.trainer.fit_loop._should_accumulate():
             # track gradients
-            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
+            grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer, opt_idx=opt_idx)
             if grad_norm_dict:
                 self.trainer.lightning_module._current_fx_name = "on_after_backward"
                 self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
@@ -470,7 +470,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
 
         return result
 
-    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
+    def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -> Dict[str, float]:
         """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.
 
         Args:
@@ -484,7 +484,11 @@ def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, fl
         grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm)
 
         # clip gradients
-        self.trainer.accelerator.clip_gradients(
-            optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm
-        )
+        if not self.trainer.accelerator_connector.use_deepspeed:
+            self.trainer.lightning_module.configure_gradient_clipping(
+                optimizer,
+                opt_idx,
+                gradient_clip_val=self.trainer.gradient_clip_val,
+                gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
+            )
         return grad_norm_dict
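
Since the loop now routes clipping through the module hook once per optimizer (and skips it entirely under DeepSpeed), an override can also opt a given optimizer out of clipping altogether. A hedged sketch of such an override, with illustrative optimizer indices; it belongs inside a LightningModule subclass.

def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
    # clip only the first optimizer, using the Trainer-provided settings
    if optimizer_idx == 0:
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )
    # returning without calling clip_gradients leaves the other optimizer's gradients untouched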

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 15 additions & 1 deletion
@@ -34,8 +34,10 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
+from pytorch_lightning.utilities.enums import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerTypeTuple
 from pytorch_lightning.utilities.warnings import rank_zero_warn, WarningCache
@@ -376,6 +378,18 @@ def pre_dispatch(self):
         self.barrier()
 
     def init_deepspeed(self):
+        # check that the `configure_gradient_clipping` hook isn't overridden, since deepspeed handles
+        # gradient clipping internally
+        if is_overridden("configure_gradient_clipping", self.lightning_module):
+            rank_zero_warn(
+                "Since deepspeed handles gradient clipping internally, this hook will"
+                " be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
+                " inside `Trainer`."
+            )
+
+        if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
+            raise MisconfigurationException("Deepspeed does not support clipping gradients by value.")
+
         accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
 
         if accumulation_scheduler.epochs != [0]:
@@ -569,7 +583,7 @@ def _format_batch_size_and_grad_accum_config(self):
             batch_size = self._auto_select_batch_size()
             self.config["train_micro_batch_size_per_gpu"] = batch_size
         if "gradient_clipping" not in self.config:
-            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val
+            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val or 0.0
 
     def _auto_select_batch_size(self):
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
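
With the DeepSpeed plugin, clipping is therefore driven purely by the `Trainer` flags (the value is copied into the `gradient_clipping` field of the generated DeepSpeed config), and clipping by value is rejected. A hedged sketch, assuming a GPU machine with deepspeed installed and the `plugins="deepspeed"` spelling in use at the time of this commit.

import pytorch_lightning as pl

# the 0.5 below ends up as `gradient_clipping` in the generated DeepSpeed config
trainer = pl.Trainer(gpus=1, precision=16, plugins="deepspeed", gradient_clip_val=0.5)

# per the check added in init_deepspeed above, this combination raises MisconfigurationException:
# pl.Trainer(gpus=1, precision=16, plugins="deepspeed", gradient_clip_algorithm="value")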

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 1 addition & 1 deletion
@@ -201,7 +201,7 @@ def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> Non
     def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> None:
         if model.automatic_optimization:
             return
-        if self.trainer.gradient_clip_val > 0:
+        if self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val > 0:
             raise MisconfigurationException(
                 "Automatic gradient clipping is not supported for manual optimization."
                 f" Remove `Trainer(gradient_clip_val={self.trainer.gradient_clip_val})`"

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 11 additions & 5 deletions
@@ -23,8 +23,8 @@ def __init__(self, trainer):
 
     def on_trainer_init(
         self,
-        gradient_clip_val: Union[int, float],
-        gradient_clip_algorithm: str,
+        gradient_clip_val: Optional[Union[int, float]],
+        gradient_clip_algorithm: Optional[str],
         track_grad_norm: Union[int, float, str],
         terminate_on_nan: Optional[bool],
     ):
@@ -37,10 +37,12 @@ def on_trainer_init(
             raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
-        if not isinstance(gradient_clip_val, (int, float)):
+        if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
             raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
 
-        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+        if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
+            gradient_clip_algorithm.lower()
+        ):
             raise MisconfigurationException(
                 f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
                 f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
@@ -54,5 +56,9 @@ def on_trainer_init(
 
         self.trainer._terminate_on_nan = terminate_on_nan
         self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
+        self.trainer.gradient_clip_algorithm = (
+            GradClipAlgorithmType(gradient_clip_algorithm.lower())
+            if gradient_clip_algorithm is not None
+            else gradient_clip_algorithm
+        )
         self.trainer.track_grad_norm = float(track_grad_norm)
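
The connector still validates any value that is passed explicitly; only `None` skips the type and algorithm checks. A short sketch of what is accepted versus rejected at `Trainer` construction, following the checks above.

import pytorch_lightning as pl

pl.Trainer(gradient_clip_val=0.5)                                      # ok, algorithm left as None
pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")     # ok, clip by value
pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="NORM")      # ok, lower-cased before conversion

# rejected immediately, per the checks above:
# pl.Trainer(gradient_clip_val="0.5")          # -> TypeError
# pl.Trainer(gradient_clip_algorithm="mean")   # -> MisconfigurationException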

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 5 deletions
@@ -124,8 +124,8 @@ def __init__(
         enable_checkpointing: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
-        gradient_clip_val: Union[int, float] = 0.0,
-        gradient_clip_algorithm: str = "norm",
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
         process_position: int = 0,
         num_nodes: int = 1,
         num_processes: int = 1,
@@ -254,11 +254,12 @@ def __init__(
 
             gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node
 
-            gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=0`` disables gradient
-                clipping.
+            gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
+                gradient clipping.
 
             gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
-                for clip_by_value, and ``gradient_clip_algorithm="norm"`` for clip_by_norm.
+                to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
+                be set to ``"norm"``.
 
             limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).
 
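
With the new defaults, both flags start out as `None` on the Trainer, and clipping only happens when a value is supplied (or when the module hook calls `clip_gradients` with one). A small sketch of the new default state, consistent with the connector change above.

import pytorch_lightning as pl

trainer = pl.Trainer()
assert trainer.gradient_clip_val is None         # previously 0.0
assert trainer.gradient_clip_algorithm is None   # previously GradClipAlgorithmType.NORM

# explicit configuration; "norm" is also what clip_gradients falls back to when unset
trainer = pl.Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")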
