
Commit 24db14e

carmocca, awaelchli, and rohitgr7 authored and committed
Minor fixes related to clipping (Lightning-AI#10130)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 1dc5957 commit 24db14e

15 files changed: +81 -101 lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 3 deletions

@@ -329,6 +329,7 @@ def optimizer_step(
             opt_idx: index of the current optimizer
             lambda_closure: closure calculating the loss value
             model: reference to the model, optionally defining optimizer step related hooks
+            **kwargs: Any extra arguments to ``optimizer.step``
         """
         model = model or self.lightning_module
         make_optimizer_step = self.precision_plugin.pre_optimizer_step(
@@ -349,9 +350,7 @@ def clip_gradients(
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
         """clips all the optimizer parameters to the given value."""
-        self.precision_plugin.clip_gradients(
-            optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm, model=self.model
-        )
+        self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
 
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         """Creates optimizers and schedulers.

pytorch_lightning/core/lightning.py

Lines changed: 7 additions & 12 deletions

@@ -1411,10 +1411,7 @@ def training_step(...):
             *args: Additional positional arguments to be forwarded to :meth:`~torch.Tensor.backward`
             **kwargs: Additional keyword arguments to be forwarded to :meth:`~torch.Tensor.backward`
         """
-        # make sure we're using manual opt
         self._verify_is_manual_optimization("manual_backward")
-
-        # backward
         self.trainer.accelerator.backward(loss, None, None, *args, **kwargs)
 
     def backward(
@@ -1487,7 +1484,7 @@ def clip_gradients(
         self,
         optimizer: Optimizer,
         gradient_clip_val: Optional[Union[int, float]] = None,
-        gradient_clip_algorithm: Optional[Union[str, GradClipAlgorithmType]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
     ):
         """Handles gradient clipping internally.
 
@@ -1505,8 +1502,9 @@ def clip_gradients(
             gradient_clip_val = self.trainer.gradient_clip_val or 0.0
         elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
             raise MisconfigurationException(
-                "You have set `Trainer(gradient_clip_val)` and have passed"
-                " `gradient_clip_val` inside `clip_gradients`. Please use only one of them."
+                f"You have set `Trainer(gradient_clip_val={self.trainer.gradient_clip_val!r})`"
+                f" and have passed `clip_gradients(gradient_clip_val={gradient_clip_val!r})`."
+                " Please use only one of them."
             )
 
         if gradient_clip_algorithm is None:
@@ -1518,8 +1516,9 @@ def clip_gradients(
             and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
         ):
             raise MisconfigurationException(
-                "You have set `Trainer(gradient_clip_algorithm)` and have passed"
-                " `gradient_clip_algorithm` inside `clip_gradients`. Please use only one of them."
+                f"You have set `Trainer(gradient_clip_algorithm={self.trainer.gradient_clip_algorithm.value!r})`"
+                f" and have passed `clip_gradients(gradient_clip_algorithm={gradient_clip_algorithm!r})"
+                " Please use only one of them."
            )
 
         if not isinstance(gradient_clip_val, (int, float)):
@@ -1543,10 +1542,6 @@ def configure_gradient_clipping(
     ):
         """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
 
-        Note:
-            This hook won't be called when using deepspeed since it handles gradient clipping internally.
-            Consider setting ``gradient_clip_val`` and ``gradient_clip_algorithm`` inside ``Trainer``."
-
         Args:
             optimizer: Current optimizer being used.
             optimizer_idx: Index of the current optimizer being used.
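For context, a minimal sketch of how a user-defined module might hook into this API, based only on the signatures shown in the diff (the model itself is hypothetical); if the matching ``Trainer`` argument is also set to a conflicting value, the ``MisconfigurationException`` above is raised:

from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        # Delegate to the built-in helper, forcing clip-by-value for the first optimizer only.
        if optimizer_idx == 0:
            self.clip_gradients(optimizer, gradient_clip_val=0.01, gradient_clip_algorithm="value")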

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 6 additions & 7 deletions

@@ -488,11 +488,10 @@ def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer, opt_idx: int) -
         )
 
         # clip gradients
-        if not self.trainer.accelerator_connector.use_deepspeed:
-            self.trainer.lightning_module.configure_gradient_clipping(
-                optimizer,
-                opt_idx,
-                gradient_clip_val=self.trainer.gradient_clip_val,
-                gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
-            )
+        self.trainer.lightning_module.configure_gradient_clipping(
+            optimizer,
+            opt_idx,
+            gradient_clip_val=self.trainer.gradient_clip_val,
+            gradient_clip_algorithm=self.trainer.gradient_clip_algorithm,
+        )
         return grad_norm_dict

pytorch_lightning/plugins/precision/apex_amp.py

Lines changed: 1 addition & 1 deletion

@@ -101,7 +101,7 @@ def pre_optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable,
+        lambda_closure: Callable[[], Any],
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""

pytorch_lightning/plugins/precision/deepspeed_precision.py

Lines changed: 4 additions & 5 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 
 from torch import Tensor
 from torch.nn import Module
@@ -51,7 +51,7 @@ def pre_optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable,
+        lambda_closure: Callable[[], Any],
         **kwargs: Any,
     ) -> bool:
         """Hook to do something before each optimizer step."""
@@ -69,14 +69,13 @@ def pre_optimizer_step(
         )
         # DeepSpeed handles the optimizer step internally
         deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
-        deepspeed_engine.step()
+        deepspeed_engine.step(**kwargs)
         return False
 
     def clip_gradients(
         self,
         optimizer: Optimizer,
-        clip_val: Union[int, float],
+        clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
-        model: Optional[Module] = None,
     ) -> None:
         """DeepSpeed handles gradient clipping internally."""

pytorch_lightning/plugins/precision/fully_sharded_native_amp.py

Lines changed: 6 additions & 19 deletions

@@ -11,36 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
-
-from torch.nn import Module
-from torch.optim import Optimizer
+from typing import Any
 
 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
-from pytorch_lightning.utilities import GradClipAlgorithmType
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
     """Mixed Precision for Full Sharded Training."""
 
     precision = "mixed"
 
-    def clip_gradients(
-        self,
-        optimizer: Optimizer,
-        clip_val: Union[int, float],
-        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.VALUE,
-        model: Optional[Module] = None,
-    ) -> None:
-        clip_val = float(clip_val)
-        if clip_val <= 0:
-            return
+    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html
         # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
         # for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
         # however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
         # trace back the root FSDP. Now we only support clip by value.
-        assert (
-            gradient_clip_algorithm == GradClipAlgorithmType.VALUE
-        ), "`gradient_clip_algorithm`: `norm` is currently not supported for `FullyShardedNativeMixedPrecisionPlugin`"
-        self.clip_grad_by_value(optimizer, clip_val)
+        raise MisconfigurationException(
+            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
+        )

pytorch_lightning/plugins/precision/ipu_precision.py

Lines changed: 3 additions & 6 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Union
 
 from torch.nn import Module
 from torch.optim import LBFGS, Optimizer
@@ -67,12 +67,9 @@ def pre_optimizer_step(
     def clip_gradients(
         self,
         optimizer: Optimizer,
-        clip_val: Union[int, float],
+        clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
-        model: Optional[Module] = None,
     ) -> None:
-        """Clips the gradients."""
-        if clip_val is None or float(clip_val) <= 0:
+        if clip_val <= 0:
             return
-
         raise MisconfigurationException("IPUs currently do not support clipping gradients.")

pytorch_lightning/plugins/precision/native_amp.py

Lines changed: 2 additions & 2 deletions

@@ -69,7 +69,7 @@ def pre_optimizer_step(
         model: Union["pl.LightningModule", Module],
         optimizer: Optimizer,
         optimizer_idx: int,
-        lambda_closure: Callable,
+        lambda_closure: Callable[[], Any],
         **kwargs: Any,
     ) -> bool:
         if self.is_bfloat16:
@@ -86,7 +86,7 @@ def pre_optimizer_step(
         # in manual optimization, the closure does not return a value
         if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
             # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
-            self.scaler.step(optimizer)
+            self.scaler.step(optimizer, **kwargs)
             self.scaler.update()
         return False
 

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 1 addition & 8 deletions

@@ -113,22 +113,15 @@ def pre_optimizer_step(
     def clip_gradients(
         self,
         optimizer: Optimizer,
-        clip_val: Union[int, float],
+        clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
-        model: Optional[Module] = None,
     ) -> None:
         """Clips the gradients."""
-        if clip_val is None:
-            return
-
-        clip_val = float(clip_val)
         if clip_val <= 0:
             return
-
         if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
             self.clip_grad_by_value(optimizer, clip_val)
         elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
-            # TODO: there should be a mechanism to set `norm_type`
             self.clip_grad_by_norm(optimizer, clip_val)
 
     def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
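A standalone sketch of the dispatch that ``clip_gradients`` performs after this change, written with the equivalent ``torch.nn.utils`` helpers (an illustration, not the plugin's actual implementation): a non-positive ``clip_val`` disables clipping, otherwise the call is routed to a by-value or by-norm helper.

from typing import Union

import torch
from torch.optim import Optimizer


def clip_gradients(optimizer: Optimizer, clip_val: Union[int, float] = 0.0, algorithm: str = "norm") -> None:
    if clip_val <= 0:
        return
    # collect every parameter the optimizer updates
    params = [p for group in optimizer.param_groups for p in group["params"]]
    if algorithm == "value":
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_val)
    elif algorithm == "norm":
        torch.nn.utils.clip_grad_norm_(params, max_norm=clip_val)


param = torch.nn.Parameter(torch.ones(4))
param.grad = torch.full((4,), 10.0)
clip_gradients(torch.optim.SGD([param], lr=0.1), clip_val=1.0, algorithm="norm")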

pytorch_lightning/plugins/precision/sharded_native_amp.py

Lines changed: 2 additions & 4 deletions

@@ -29,7 +29,5 @@ def __init__(self, precision: Union[int, str] = 16, use_cpu: bool = False) -> No
         if not self.use_cpu:
             self.scaler = ShardedGradScaler()
 
-    def clip_grad_by_norm(
-        self, optimizer: "OSS", clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6
-    ) -> None:
-        optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
+    def clip_grad_by_norm(self, optimizer: "OSS", clip_val: Union[int, float]) -> None:
+        optimizer.clip_grad_norm(clip_val)

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 8 additions & 10 deletions

@@ -33,10 +33,9 @@
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.enums import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -422,17 +421,16 @@ def _setup_model_and_optimizer(
         return deepspeed_engine, deepspeed_optimizer
 
     def init_deepspeed(self):
-        # check that `configure_gradient_clipping` hook isn't overriden since deepspeed handles
-        # gradient clipping internally
+        # deepspeed handles gradient clipping internally
         if is_overridden("configure_gradient_clipping", self.lightning_module, pl.LightningModule):
             rank_zero_warn(
-                "Since deepspeed handles gradient clipping internally, this hook will"
-                " be ignored. Consider setting `gradient_clip_val` and `gradient_clip_algorithm`"
-                " inside `Trainer`."
+                "Since deepspeed handles gradient clipping internally, `LightningModule.configure_gradient_clipping`"
+                " will be ignored. Consider setting `Trainer(gradient_clip_val=..., gradient_clip_algorithm='norm')`"
+                " which will use the internal mechanism."
             )
 
         if self.lightning_module.trainer.gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
-            raise MisconfigurationException("Deepspeed does not support clipping gradients by value.")
+            raise MisconfigurationException("DeepSpeed does not support clipping gradients by value.")
 
         accumulation_scheduler = self.lightning_module.trainer.accumulation_scheduler
 
@@ -480,7 +478,7 @@ def _initialize_deepspeed_train(self, model):
         else:
             rank_zero_info(
                 "You have not specified an optimizer or scheduler within the DeepSpeed config."
-                "Using `configure_optimizers` to define optimizer and scheduler."
+                " Using `configure_optimizers` to define optimizer and scheduler."
             )
             optimizer, lr_scheduler, _ = self._init_optimizers()
 
@@ -534,7 +532,7 @@ def _initialize_deepspeed_inference(self, model):
         if "optimizer" not in self.config:
             rank_zero_info(
                 "You have not specified an optimizer or scheduler within the DeepSpeed config."
-                "Using `configure_optimizers` to define optimizer and scheduler."
+                " Using `configure_optimizers` to define optimizer and scheduler."
             )
             optimizer, lr_scheduler, _ = self._init_optimizers()
             scheduler = lr_scheduler["scheduler"]
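Going by the updated warning and error above, clipping with DeepSpeed is configured through the Trainer arguments and applied by the engine itself; a hedged usage sketch (assuming ``plugins="deepspeed"`` is how the plugin is selected in this setup):

from pytorch_lightning import Trainer

trainer = Trainer(
    plugins="deepspeed",  # assumed way of enabling the DeepSpeed plugin here
    gradient_clip_val=1.0,
    gradient_clip_algorithm="norm",  # "value" raises MisconfigurationException with DeepSpeed
)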

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 2 deletions

@@ -257,7 +257,7 @@ def __init__(
             gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node
 
             gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
-                gradient clipping.
+                gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
 
             gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
                 to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
@@ -380,7 +380,8 @@ def __init__(
 
             ipus: How many IPUs to train on.
 
-            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
+            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. If using
+                Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them.
 
             val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                 use int to check every n steps (batches).
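A minimal usage sketch of the documented arguments: clip gradients to a maximum 2-norm of 0.5 and track the gradient 2-norm; per the docstring additions above, with AMP both operate on the unscaled gradients.

from pytorch_lightning import Trainer

trainer = Trainer(
    gradient_clip_val=0.5,
    gradient_clip_algorithm="norm",
    track_grad_norm=2,
)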

tests/core/test_lightning_module.py

Lines changed: 7 additions & 3 deletions

@@ -386,11 +386,14 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va
     trainer = Trainer(
         default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4
     )
-    with pytest.raises(MisconfigurationException, match=r".*have set `Trainer\(gradient_clip_val\)` and have passed.*"):
+    with pytest.raises(
+        MisconfigurationException,
+        match=r"gradient_clip_val=0.0001\)` and have passed `clip_gradients\(gradient_clip_val=0.01",
+    ):
         trainer.fit(model)
 
     class TestModel(BoringModel):
-        custom_gradient_clip_algorithm = "value"
+        custom_gradient_clip_algorithm = "foo"
 
         def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
             self.clip_gradients(optimizer, gradient_clip_algorithm=self.custom_gradient_clip_algorithm)
@@ -404,6 +407,7 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va
         gradient_clip_algorithm="norm",
     )
     with pytest.raises(
-        MisconfigurationException, match=r".*have set `Trainer\(gradient_clip_algorithm\)` and have passed.*"
+        MisconfigurationException,
+        match=r"gradient_clip_algorithm='norm'\)` and have passed `clip_gradients\(gradient_clip_algorithm='foo'",
    ):
         trainer.fit(model)

tests/models/test_hooks.py

Lines changed: 10 additions & 19 deletions

@@ -281,24 +281,6 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
            dict(name="Callback.on_before_optimizer_step", args=(trainer, model, ANY, 0)),
            dict(name="on_before_optimizer_step", args=(ANY, 0)),
        ]
-
-        # deepspeed handles gradient clipping internally
-        configure_gradient_clipping = (
-            []
-            if using_deepspeed
-            else [
-                dict(
-                    name="clip_gradients",
-                    args=(ANY,),
-                    kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
-                ),
-                dict(
-                    name="configure_gradient_clipping",
-                    args=(ANY, 0),
-                    kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
-                ),
-            ]
-        )
        for i in range(batches):
            out.extend(
                [
@@ -323,7 +305,16 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
                    *([dict(name="backward", args=(ANY, ANY, 0))] if not using_deepspeed else []),
                    dict(name="Callback.on_after_backward", args=(trainer, model)),
                    dict(name="on_after_backward"),
-                    *configure_gradient_clipping,
+                    dict(
+                        name="clip_gradients",
+                        args=(ANY,),
+                        kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
+                    ),
+                    dict(
+                        name="configure_gradient_clipping",
+                        args=(ANY, 0),
+                        kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
+                    ),
                    *(on_before_optimizer_step if using_plugin else []),
                    dict(
                        name="optimizer_step",
