From 210b401fece33f273484ddf0d233d7cca13eb2f6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 29 Sep 2021 06:42:23 +0200 Subject: [PATCH 1/6] Add typing for `LightningOptimizer` --- pyproject.toml | 1 + pytorch_lightning/core/optimizer.py | 93 ++++++++----------- .../training_type/training_type_plugin.py | 1 + 3 files changed, 42 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4161af7849572..5acd5ea13ffa9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ module = [ "pytorch_lightning.callbacks.model_summary", "pytorch_lightning.callbacks.pruning", "pytorch_lightning.callbacks.rich_model_summary", + "pytorch_lightning.core.optimizer", "pytorch_lightning.loops.optimization.*", "pytorch_lightning.loops.evaluation_loop", "pytorch_lightning.trainer.connectors.checkpoint_connector", diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index ba81644b9bd9a..0ad7f8a2040c3 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Callable, Optional +from typing import Any, Callable, Generator, List, Optional from weakref import proxy from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.exceptions import MisconfigurationException -def do_nothing_closure(): +def do_nothing_closure() -> None: return @@ -44,39 +45,38 @@ def __init__(self, optimizer: Optimizer): self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}) self._optimizer = optimizer - self._trainer = None - self._optimizer_idx = None - self._total_optimizer_step_calls = 0 + self._trainer: Optional["pl.Trainer"] = None + self._optimizer_idx = 0 @property - def optimizer(self): + def optimizer(self) -> Optimizer: return self._optimizer @property - def defaults(self): + def defaults(self) -> dict: return self._optimizer.defaults @defaults.setter - def defaults(self, defaults): + def defaults(self, defaults: dict) -> None: self._optimizer.defaults = defaults @property - def state(self): + def state(self) -> dict: return self._optimizer.state @state.setter - def state(self, state): + def state(self, state: dict) -> None: self._optimizer.state = state @property - def param_groups(self): + def param_groups(self) -> List[dict]: return self._optimizer.param_groups @param_groups.setter - def param_groups(self, param_groups): + def param_groups(self, param_groups: List[dict]) -> None: self._optimizer.param_groups = param_groups - def _on_trainer_init(self, trainer): + def _on_trainer_init(self, trainer: "pl.Trainer") -> None: self._trainer = proxy(trainer) for opt_idx, opt in enumerate(trainer.optimizers): if opt == self._optimizer: @@ -84,25 +84,28 @@ def _on_trainer_init(self, trainer): break @classmethod - def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx): + def _to_lightning_optimizer(cls, optimizer: Optimizer, trainer: "pl.Trainer", opt_idx: int) -> "LightningOptimizer": # apex overrides .step function and need to be wrapped on each step - if trainer.amp_backend == AMPType.APEX: - optimizer = cls(optimizer) - optimizer._on_trainer_init(trainer) + if trainer.amp_backend is not None and trainer.amp_backend == AMPType.APEX: + lightning_optimizer = cls(optimizer) + 
lightning_optimizer._on_trainer_init(trainer) else: - optimizer = trainer.lightning_optimizers[opt_idx] - return optimizer + lightning_optimizer = trainer.lightning_optimizers[opt_idx] + return lightning_optimizer - def _toggle_model(self): + def _toggle_model(self) -> None: + assert self._trainer is not None model_ref = self._trainer.lightning_module model_ref.toggle_optimizer(self, self._optimizer_idx) - def _untoggle_model(self): + def _untoggle_model(self) -> None: + assert self._trainer is not None model_ref = self._trainer.lightning_module + # FIXME: what? model_ref.untoggle_optimizer(self) @contextmanager - def toggle_model(self, sync_grad: bool = True): + def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: """This function is just a helper for advanced users. Considering the current optimizer as A and all other optimizers as B. @@ -116,34 +119,22 @@ def toggle_model(self, sync_grad: bool = True): # local import here to avoid circular import from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior + assert self._trainer is not None with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)): self._toggle_model() yield self._untoggle_model() - def __optimizer_step(self, closure: Callable, profiler_name: str = None, **kwargs): - trainer = self._trainer - - with trainer.profiler.profile(profiler_name): - trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) - - def step(self, closure: Optional[Callable] = None, **kwargs): - """Call this directly from your training_step when doing optimizations manually. By using this we can - ensure that all the proper scaling when using 16-bit, accelerator etc is been done properly for you. - - .. note:: In Manual Optimization, the user is expected to know when to call zero_grad, - perform accumulated_grad_batches, etc ... Lightning will only take care of precision and accelerators + def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None: + """Performs a single optimization step (parameter update). Args: - - closure: One could provide its own optimizer_closure. Set to None by default. - - kwargs: Any parameters provided to wrapped optimizer.step() + closure: An optional optimizer_closure. + kwargs: Any additional arguments to the ``optimizer.step()`` call. Example:: - # Scenario for a GAN. 
- + # Scenario for a GAN using manual optimization def training_step(...): opt_gen, opt_dis = self.optimizers() @@ -165,8 +156,7 @@ def training_step(...): opt_dis.step() - # Scenario for a GAN advanced - + # A more advanced example def training_step(self, batch, batch_idx, ...): opt_gen, opt_dis = self.optimizers() @@ -192,18 +182,15 @@ def closure_dis(): with opt_dis.toggle_model(sync_grad=accumulated_grad_batches): opt_dis.step(closure=closure_dis) """ - if closure is None: - profiler_name = f"closure_{self._optimizer_idx}" - closure = do_nothing_closure - else: - if not callable(closure): - raise MisconfigurationException("When closure is provided, it should be a function") - profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" - - self.__optimizer_step(closure=closure, profiler_name=profiler_name, **kwargs) - self._total_optimizer_step_calls += 1 + closure = closure or do_nothing_closure + if not callable(closure): + raise MisconfigurationException("When closure is provided, it should be a function") + trainer = self._trainer + assert trainer is not None + with trainer.profiler.profile(f"optimizer_step_and_closure_{self._optimizer_idx}"): + trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs) - def __repr__(self): + def __repr__(self) -> str: groups = [ {k: round(v, 12) if isinstance(v, float) else v for k, v in sorted(group.items()) if k != "params"} for group in self.param_groups diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6caaea8632354..28153d41f1cb0 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -156,6 +156,7 @@ def model(self, new_model: Optional[Module]) -> None: @property def lightning_module(self) -> "pl.LightningModule": """Returns the pure LightningModule without potential wrappers.""" + assert self._model is not None return unwrap_lightning_module(self._model) @property From ad88177770aafec6c4568279e2935d4b93ec07a9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 29 Sep 2021 19:25:48 +0200 Subject: [PATCH 2/6] Resolve FIXME --- pytorch_lightning/core/lightning.py | 27 +++++++++---------- pytorch_lightning/core/optimizer.py | 18 +++---------- .../training_type/training_type_plugin.py | 1 - 3 files changed, 16 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 19f39403e4270..39f753dafd78b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1416,17 +1416,16 @@ def backward(self, loss, optimizer, optimizer_idx): """ loss.backward(*args, **kwargs) - def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): + def toggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer], optimizer_idx: int) -> None: """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step - to prevent dangling gradients in multiple-optimizer setup. It works with :meth:`untoggle_optimizer` to make - sure ``param_requires_grad_state`` is properly reset. Override for your own behavior. + to prevent dangling gradients in multiple-optimizer setup. 
- Args: - optimizer: Current optimizer used in the training loop - optimizer_idx: Current optimizer idx in the training loop + This is only called automatically when automatic optimization is enabled and multiple optimizers are used. + It works with :meth:`untoggle_optimizer` to make sure ``param_requires_grad_state`` is properly reset. - Note: - Only called when using multiple optimizers + Args: + optimizer: The optimizer to toggle. + optimizer_idx: The index of the optimizer to toggle. """ # Iterate over all optimizer parameters to preserve their `requires_grad` information # in case these are pre-defined during `configure_optimizers` @@ -1447,15 +1446,13 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int): param.requires_grad = param_requires_grad_state[param] self._param_requires_grad_state = param_requires_grad_state - def untoggle_optimizer(self, optimizer_idx: int): - """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. Override for - your own behavior. + def untoggle_optimizer(self, optimizer_idx: int) -> None: + """Resets the state of required gradients that were toggled with :meth:`toggle_optimizer`. - Args: - optimizer_idx: Current optimizer idx in the training loop + This is only called automatically when automatic optimization is enabled and multiple optimizers are used. - Note: - Only called when using multiple optimizers + Args: + optimizer_idx: The index of the optimizer to untoggle. """ for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)): if optimizer_idx != opt_idx: diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 0ad7f8a2040c3..bb01d42b7de18 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -93,17 +93,6 @@ def _to_lightning_optimizer(cls, optimizer: Optimizer, trainer: "pl.Trainer", op lightning_optimizer = trainer.lightning_optimizers[opt_idx] return lightning_optimizer - def _toggle_model(self) -> None: - assert self._trainer is not None - model_ref = self._trainer.lightning_module - model_ref.toggle_optimizer(self, self._optimizer_idx) - - def _untoggle_model(self) -> None: - assert self._trainer is not None - model_ref = self._trainer.lightning_module - # FIXME: what? - model_ref.untoggle_optimizer(self) - @contextmanager def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: """This function is just a helper for advanced users. @@ -111,7 +100,6 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: Considering the current optimizer as A and all other optimizers as B. Toggling means all parameters from B exclusive to A will have ``requires_grad`` set to False. - When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. Setting `sync_grad` to False will block this synchronization and improve performance. 
@@ -120,10 +108,12 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior assert self._trainer is not None + lightning_module = self._trainer.lightning_module + with _block_parallel_sync_behavior(self._trainer, block=(not sync_grad)): - self._toggle_model() + lightning_module.toggle_optimizer(self, self._optimizer_idx) yield - self._untoggle_model() + lightning_module.untoggle_optimizer(self._optimizer_idx) def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None: """Performs a single optimization step (parameter update). diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 28153d41f1cb0..6caaea8632354 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -156,7 +156,6 @@ def model(self, new_model: Optional[Module]) -> None: @property def lightning_module(self) -> "pl.LightningModule": """Returns the pure LightningModule without potential wrappers.""" - assert self._model is not None return unwrap_lightning_module(self._model) @property From 663775abaf387e652e2c77d50c2125487f7234cb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 15 Oct 2021 23:02:01 +0200 Subject: [PATCH 3/6] Remove manual tracking of optimizer steps --- pytorch_lightning/core/optimizer.py | 2 -- tests/accelerators/test_tpu.py | 15 +++++----- tests/core/test_lightning_optimizer.py | 1 - .../optimization/test_manual_optimization.py | 28 +++++++++++-------- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index ba81644b9bd9a..65759c55f1a09 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -46,7 +46,6 @@ def __init__(self, optimizer: Optimizer): self._optimizer = optimizer self._trainer = None self._optimizer_idx = None - self._total_optimizer_step_calls = 0 @property def optimizer(self): @@ -201,7 +200,6 @@ def closure_dis(): profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}" self.__optimizer_step(closure=closure, profiler_name=profiler_name, **kwargs) - self._total_optimizer_step_calls += 1 def __repr__(self): groups = [ diff --git a/tests/accelerators/test_tpu.py b/tests/accelerators/test_tpu.py index df5444ac776a6..6059322e0fa59 100644 --- a/tests/accelerators/test_tpu.py +++ b/tests/accelerators/test_tpu.py @@ -13,6 +13,7 @@ # limitations under the License import collections from copy import deepcopy +from unittest.mock import patch import pytest import torch @@ -21,7 +22,6 @@ from pytorch_lightning import Trainer from pytorch_lightning.accelerators.cpu import CPUAccelerator from pytorch_lightning.accelerators.tpu import TPUAccelerator -from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import TPUSpawnPlugin from pytorch_lightning.utilities import find_shared_parameters from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -189,16 +189,18 @@ def on_train_batch_end(self, outputs, batch, batch_idx): assert torch.all(self.layer.weight.grad == 0) self.count += 1 + def on_train_start(self): + opt = self.optimizers() + self.opt_step_patch = patch.object(opt, "step", wraps=opt.step) + self.opt_step_mock = self.opt_step_patch.start() + def on_train_end(self): assert self.called["training_step"] == 5 
assert self.called["on_train_batch_start"] == 5 assert self.called["on_train_batch_end"] == 5 - class TestManualOptimizationCallack(Callback): - def on_train_end(self, trainer, pl_module): - - opt = pl_module.optimizers() - assert opt._total_optimizer_step_calls == 3 + self.opt_step_patch.stop() + assert self.opt_step_mock.call_count == 3 model = ManualOptimizationModel() model_copy = deepcopy(model) @@ -212,7 +214,6 @@ def on_train_end(self, trainer, pl_module): limit_test_batches=0, limit_val_batches=0, tpu_cores=8, - callbacks=[TestManualOptimizationCallack()], ) trainer.fit(model) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index f4f1287c122d8..05de6f44b9e44 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -161,7 +161,6 @@ def test_state(tmpdir): "zero_grad", "__setstate__", "add_param_group", - "_total_optimizer_step_calls", ] for k, v in lightning_optimizer.__dict__.items(): diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index c250caee092f7..27c6c50690335 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -23,7 +23,6 @@ from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.accelerators import Accelerator -from pytorch_lightning.callbacks import Callback from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf @@ -706,14 +705,6 @@ def configure_optimizers(self): mock_adam_step.assert_has_calls(expected_calls) -class TestManualOptimizationDDPCallack(Callback): - def on_train_end(self, trainer, pl_module): - - opt_a, opt_b = pl_module.optimizers() - assert opt_a._total_optimizer_step_calls == 4 - assert opt_b._total_optimizer_step_calls == 2 - - class TesManualOptimizationDDPModel(BoringModel): def __init__(self): super().__init__() @@ -787,8 +778,22 @@ def configure_optimizers(self): optimizer_dis = torch.optim.Adam(self.layer.parameters(), lr=0.001) return [optimizer_gen, optimizer_dis] + def on_train_start(self): + # this is done here instead of in the calling function due to `spawn` + sgd, adam = self.optimizers() + self.sgd_step_patch = patch.object(sgd, "step", wraps=sgd.step) + self.sgd_step_mock = self.sgd_step_patch.start() + self.adam_step_patch = patch.object(adam, "step", wraps=adam.step) + self.adam_step_mock = self.adam_step_patch.start() + + def on_train_end(self): + self.sgd_step_patch.stop() + assert self.sgd_step_mock.call_count == 4 + self.adam_step_patch.stop() + assert self.adam_step_mock.call_count == 2 + -def train_manual_optimization(tmpdir, accelerator, model_cls=TesManualOptimizationDDPModel): +def train_manual_optimization(tmpdir, strategy, model_cls=TesManualOptimizationDDPModel): seed_everything(42) @@ -805,8 +810,7 @@ def train_manual_optimization(tmpdir, accelerator, model_cls=TesManualOptimizati max_epochs=1, log_every_n_steps=1, gpus=2, - accelerator=accelerator, - callbacks=[TestManualOptimizationDDPCallack()], + strategy=strategy, ) trainer.fit(model) From f927fb942742f7227183a19020b1328ad8eb830b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 18 Oct 2021 03:48:23 +0200 Subject: [PATCH 4/6] Fix signature --- pytorch_lightning/core/lightning.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 
103d74767b02f..549a0ccacfc16 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1565,14 +1565,14 @@ def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_va def optimizer_step( self, - epoch: int = None, - batch_idx: int = None, - optimizer: Optimizer = None, - optimizer_idx: int = None, - optimizer_closure: Optional[Callable] = None, - on_tpu: bool = None, - using_native_amp: bool = None, - using_lbfgs: bool = None, + epoch: int, + batch_idx: int, + optimizer: Union[Optimizer, LightningOptimizer], + optimizer_idx: int = 0, + optimizer_closure: Optional[Callable[[], Any]] = None, + on_tpu: bool = False, + using_native_amp: bool = False, + using_lbfgs: bool = False, ) -> None: r""" Override this method to adjust the default way the @@ -1581,10 +1581,6 @@ def optimizer_step( once per optimizer. This method (and ``zero_grad()``) won't be called during the accumulation phase when ``Trainer(accumulate_grad_batches != 1)``. - Warning: - If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter - to ``optimizer.step()`` function as shown in the examples. - Args: epoch: Current epoch batch_idx: Index of current batch From d8508b2576923db82cbaf317f864238e82df52f4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 18 Oct 2021 15:06:04 +0200 Subject: [PATCH 5/6] overload --- pytorch_lightning/core/lightning.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 549a0ccacfc16..999999d1f8e54 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -21,7 +21,7 @@ import tempfile from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, overload, Tuple, Union import torch from torch import ScriptModule, Tensor @@ -120,6 +120,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # deprecated, will be removed in 1.6 self._loaded_optimizer_states_dict = {} + @overload + def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[LightningOptimizer, List[LightningOptimizer]]: + ... + + @overload + def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Optimizer, List[Optimizer]]: + ... 
+ def optimizers( self, use_pl_optimizer: bool = True ) -> Union[Optimizer, LightningOptimizer, List[Optimizer], List[LightningOptimizer]]: From 10f9f26167ea612015038cdaaa3b5b77bd09f06a Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 18 Oct 2021 15:29:34 +0200 Subject: [PATCH 6/6] fix import --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 999999d1f8e54..ca4b2af7eee17 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -21,13 +21,14 @@ import tempfile from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, overload, Tuple, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union import torch from torch import ScriptModule, Tensor from torch.nn import Module from torch.optim.optimizer import Optimizer from torchmetrics import Metric +from typing_extensions import Literal import pytorch_lightning as pl from pytorch_lightning.callbacks.progress import base as progress_base
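
Notes on two patterns used in this series (standalone Python sketches for reference;
`Plain`, `Wrapped` and `Owner` below are illustrative stand-ins rather than Lightning
classes, and the shapes/counts in the second sketch are arbitrary).

The `@overload` + `Literal` pair added in PATCH 5/6 lets a type checker narrow the return
type of `optimizers()` from the literal value passed for `use_pl_optimizer`, while a single
runtime implementation remains. A minimal sketch of the same pattern::

    from typing import List, Union, overload

    from typing_extensions import Literal  # `typing.Literal` requires Python >= 3.8

    class Plain:
        ...

    class Wrapped:
        ...

    class Owner:
        @overload
        def optimizers(self, use_pl_optimizer: Literal[True] = True) -> Union[Wrapped, List[Wrapped]]:
            ...

        @overload
        def optimizers(self, use_pl_optimizer: Literal[False]) -> Union[Plain, List[Plain]]:
            ...

        def optimizers(self, use_pl_optimizer: bool = True) -> Union[Plain, Wrapped, List[Plain], List[Wrapped]]:
            # single runtime implementation; the overloads above only guide the type checker
            return Wrapped() if use_pl_optimizer else Plain()

    owner = Owner()
    a = owner.optimizers()                        # checker infers Union[Wrapped, List[Wrapped]]
    b = owner.optimizers(use_pl_optimizer=False)  # checker infers Union[Plain, List[Plain]]

PATCH 3/6 removes the `_total_optimizer_step_calls` counter and instead asserts on a mock
that wraps the real `step`, so the optimizer still updates parameters while every call is
recorded. The same counting technique applied to plain torch objects::

    from unittest.mock import patch

    import torch

    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # `wraps=` keeps the original step behaviour; the mock only records the calls
    step_patch = patch.object(optimizer, "step", wraps=optimizer.step)
    step_mock = step_patch.start()

    for _ in range(3):
        model(torch.rand(1, 2)).sum().backward()
        optimizer.step()
        optimizer.zero_grad()

    step_patch.stop()
    assert step_mock.call_count == 3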