Commit af6b692

add exception
1 parent 1f74eb6 commit af6b692

6 files changed: +93 -20 lines changed

pytorch_lightning/core/lightning.py

Lines changed: 32 additions & 6 deletions
@@ -1478,7 +1478,12 @@ def untoggle_optimizer(self, optimizer_idx: int):
         # save memory
         self._param_requires_grad_state = {}
 
-    def clip_gradients(self, optimizer: Optimizer, gradient_clip_val: Union[int, float], gradient_clip_algorithm: str):
+    def clip_gradients(
+        self,
+        optimizer: Optimizer,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[Union[str, GradClipAlgorithmType]] = None,
+    ):
         """Handles gradient clipping internally.
 
         Note:
@@ -1491,24 +1496,45 @@ def clip_gradients(self, optimizer: Optimizer, gradient_clip_val: Union[int, flo
             gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
                 to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.
         """
-        # gradient clipping
+        if gradient_clip_val is None:
+            gradient_clip_val = self.trainer.gradient_clip_val or 0.0
+        elif self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val != gradient_clip_val:
+            raise MisconfigurationException(
+                "You have set `Trainer(gradient_clip_val)` and have passed"
+                " `gradient_clip_val` inside `clip_gradients`. Please use only one of them."
+            )
+
+        if gradient_clip_algorithm is None:
+            gradient_clip_algorithm = self.trainer.gradient_clip_algorithm or "norm"
+        else:
+            gradient_clip_algorithm = gradient_clip_algorithm.lower()
+            if (
+                self.trainer.gradient_clip_algorithm is not None
+                and self.trainer.gradient_clip_algorithm != gradient_clip_algorithm
+            ):
+                raise MisconfigurationException(
+                    "You have set `Trainer(gradient_clip_algorithm)` and have passed"
+                    " `gradient_clip_algorithm` inside `clip_gradients`. Please use only one of them."
+                )
+
         if not isinstance(gradient_clip_val, (int, float)):
             raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
 
         if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
             raise MisconfigurationException(
-                f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
-                f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
+                f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid."
+                f" Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
             )
 
+        gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm)
         self.trainer.accelerator.clip_gradients(optimizer, gradient_clip_val, gradient_clip_algorithm)
 
     def configure_gradient_clipping(
         self,
         optimizer: Optimizer,
         optimizer_idx: int,
-        gradient_clip_val: Union[int, float],
-        gradient_clip_algorithm: str,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
     ):
         """Perform gradient clipping for the optimizer parameters. Called before :meth:`optimizer_step`.
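
For illustration only (not part of the diff), a minimal sketch of how the reworked hook pair is meant to be used after this change; the `LitModel` name and the clipping value are made up:

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
        # Both arguments now default to None and fall back to the Trainer
        # settings (or 0.0 / "norm"). Passing a value here that conflicts with
        # an explicit `Trainer(gradient_clip_val=...)` raises the new
        # MisconfigurationException.
        self.clip_gradients(optimizer, gradient_clip_val=0.5, gradient_clip_algorithm="norm")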

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 1 addition & 1 deletion
@@ -184,7 +184,7 @@ def __verify_dp_batch_transfer_support(self, model: "pl.LightningModule") -> Non
     def __verify_manual_optimization_support(self, model: "pl.LightningModule") -> None:
         if model.automatic_optimization:
             return
-        if self.trainer.gradient_clip_val > 0:
+        if self.trainer.gradient_clip_val is not None and self.trainer.gradient_clip_val > 0:
             raise MisconfigurationException(
                 "Automatic gradient clipping is not supported for manual optimization."
                 f" Remove `Trainer(gradient_clip_val={self.trainer.gradient_clip_val})`"

pytorch_lightning/trainer/connectors/training_trick_connector.py

Lines changed: 12 additions & 6 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union
 
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -23,19 +23,21 @@ def __init__(self, trainer):
 
     def on_trainer_init(
         self,
-        gradient_clip_val: Union[int, float],
-        gradient_clip_algorithm: str,
+        gradient_clip_val: Optional[Union[int, float]],
+        gradient_clip_algorithm: Optional[str],
         track_grad_norm: Union[int, float, str],
         terminate_on_nan: bool,
     ):
         if not isinstance(terminate_on_nan, bool):
             raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")
 
         # gradient clipping
-        if not isinstance(gradient_clip_val, (int, float)):
+        if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
             raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")
 
-        if not GradClipAlgorithmType.supported_type(gradient_clip_algorithm.lower()):
+        if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
+            gradient_clip_algorithm.lower()
+        ):
             raise MisconfigurationException(
                 f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
                 f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
@@ -49,5 +51,9 @@ def on_trainer_init(
 
         self.trainer.terminate_on_nan = terminate_on_nan
         self.trainer.gradient_clip_val = gradient_clip_val
-        self.trainer.gradient_clip_algorithm = GradClipAlgorithmType(gradient_clip_algorithm.lower())
+        self.trainer.gradient_clip_algorithm = (
+            GradClipAlgorithmType(gradient_clip_algorithm.lower())
+            if gradient_clip_algorithm is not None
+            else gradient_clip_algorithm
+        )
         self.trainer.track_grad_norm = float(track_grad_norm)
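
A standalone sketch of the None-passthrough pattern the connector now follows. The enum members mirror the real `GradClipAlgorithmType`; the `normalize` helper is hypothetical, written only to isolate the pattern:

from enum import Enum
from typing import Optional


class GradClipAlgorithmType(str, Enum):
    # Mirrors the members of pytorch_lightning.utilities.GradClipAlgorithmType.
    VALUE = "value"
    NORM = "norm"


def normalize(gradient_clip_algorithm: Optional[str]) -> Optional[GradClipAlgorithmType]:
    # None must survive untouched so `clip_gradients` can later distinguish
    # "user never chose an algorithm" from an explicit "norm".
    if gradient_clip_algorithm is None:
        return None
    return GradClipAlgorithmType(gradient_clip_algorithm.lower())


assert normalize(None) is None
assert normalize("Value") is GradClipAlgorithmType.VALUE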

pytorch_lightning/trainer/trainer.py

Lines changed: 6 additions & 5 deletions
@@ -123,8 +123,8 @@ def __init__(
         checkpoint_callback: bool = True,
         callbacks: Optional[Union[List[Callback], Callback]] = None,
         default_root_dir: Optional[str] = None,
-        gradient_clip_val: Union[int, float] = 0.0,
-        gradient_clip_algorithm: str = "norm",
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
         process_position: int = 0,
         num_nodes: int = 1,
         num_processes: int = 1,
@@ -240,11 +240,12 @@ def __init__(
 
             gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node
 
-            gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=0`` disables gradient
-                clipping.
+            gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
+                gradient clipping.
 
             gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
-                to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm.
+                to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
+                be set to ``"norm"``.
 
             limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).
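
How the new defaults read from the user's side, as a usage sketch consistent with the docstring above (not taken from the commit):

from pytorch_lightning import Trainer

# Clipping is off by default now that both arguments default to None.
trainer = Trainer()

# Setting only the value clips by norm, the implicit default algorithm.
trainer = Trainer(gradient_clip_val=0.5)

# Both can still be set explicitly.
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="value")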

tests/core/test_lightning_module.py

Lines changed: 40 additions & 0 deletions
@@ -22,6 +22,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities import _TORCH_SHARDED_TENSOR_AVAILABLE
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -338,6 +339,8 @@ def test_sharded_tensor_state_dict(tmpdir, single_process_pg):
 
 
 def test_lightning_module_configure_gradient_clipping(tmpdir):
+    """Test custom gradient clipping inside `configure_gradient_clipping` hook."""
+
     class TestModel(BoringModel):
 
         has_validated_gradients = False
@@ -366,3 +369,40 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx):
     )
     trainer.fit(model)
     assert model.has_validated_gradients
+
+
+def test_lightning_module_configure_gradient_clipping_different_argument_values(tmpdir):
+    """Test that setting gradient clipping arguments in `Trainer` and customizing gradient clipping inside
+    `configure_gradient_clipping` with different values raises an exception."""
+
+    class TestModel(BoringModel):
+        custom_gradient_clip_val = 1e-2
+
+        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+            self.clip_gradients(optimizer, gradient_clip_val=self.custom_gradient_clip_val)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, limit_val_batches=0, gradient_clip_val=1e-4
+    )
+    with pytest.raises(MisconfigurationException, match=r".*have set `Trainer\(gradient_clip_val\)` and have passed.*"):
+        trainer.fit(model)
+
+    class TestModel(BoringModel):
+        custom_gradient_clip_algorithm = "value"
+
+        def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
+            self.clip_gradients(optimizer, gradient_clip_algorithm=self.custom_gradient_clip_algorithm)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=0,
+        gradient_clip_algorithm="norm",
+    )
+    with pytest.raises(
+        MisconfigurationException, match=r".*have set `Trainer\(gradient_clip_algorithm\)` and have passed.*"
+    ):
+        trainer.fit(model)

tests/models/test_hooks.py

Lines changed: 2 additions & 2 deletions
@@ -309,12 +309,12 @@ def _auto_train_batch(trainer, model, batches, device=torch.device("cpu"), curre
             dict(
                 name="clip_gradients",
                 args=(ANY,),
-                kwargs=dict(gradient_clip_val=0.0, gradient_clip_algorithm="norm"),
+                kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
             ),
             dict(
                 name="configure_gradient_clipping",
                 args=(ANY, 0),
-                kwargs=dict(gradient_clip_val=0.0, gradient_clip_algorithm="norm"),
+                kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None),
             ),
             dict(
                 name="optimizer_step",
