Loop and test restructuring #9383

Merged: 5 commits, Sep 10, 2021
6 changes: 2 additions & 4 deletions pyproject.toml
@@ -62,10 +62,8 @@ ignore_errors = "True"
 [[tool.mypy.overrides]]
 module = [
     "pytorch_lightning.callbacks.pruning",
-    "pytorch_lightning.loops.closure",
-    "pytorch_lightning.loops.batch.manual",
-    "pytorch_lightning.loops.optimizer",
-    "pytorch_lightning.trainer.evaluation_loop",
+    "pytorch_lightning.loops.optimization.*",
+    "pytorch_lightning.loops.evaluation_loop",
     "pytorch_lightning.trainer.connectors.logger_connector.*",
     "pytorch_lightning.trainer.progress",
     "pytorch_lightning.tuner.auto_gpu_select",
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/__init__.py
@@ -18,4 +18,4 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop  # noqa: F401
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop  # noqa: F401
 from pytorch_lightning.loops.fit_loop import FitLoop  # noqa: F401
-from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop  # noqa: F401
+from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop  # noqa: F401
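Since the package __init__ keeps re-exporting OptimizerLoop, the package-level import path is unchanged; only deep imports of the relocated modules need updating. A minimal sketch of the before/after import paths, based on the diffs in this PR (the usage shown is illustrative only):

# Package-level import keeps working because __init__.py re-exports the class:
from pytorch_lightning.loops import OptimizerLoop  # noqa: F401

# Deep imports move to the new "optimization" package:
# old: from pytorch_lightning.loops.closure import Closure, ClosureResult
from pytorch_lightning.loops.optimization.closure import Closure, ClosureResult  # noqa: F401
# old: from pytorch_lightning.loops.batch.manual import ManualOptimization
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization  # noqa: F401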
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/batch/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pytorch_lightning.loops.batch.manual import ManualOptimization  # noqa: F401
 from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop  # noqa: F401
+from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization  # noqa: F401
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -19,8 +19,8 @@
 from torch.optim import Optimizer

 from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.batch.manual import ManualOptimization
-from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
+from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
+from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
 from pytorch_lightning.trainer.supporters import TensorRunningAccum
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.types import STEP_OUTPUT
2 changes: 1 addition & 1 deletion pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -17,7 +17,7 @@

 from pytorch_lightning import loops  # import as loops to avoid circular imports
 from pytorch_lightning.loops.batch import TrainingBatchLoop
-from pytorch_lightning.loops.closure import ClosureResult
+from pytorch_lightning.loops.optimization.closure import ClosureResult
 from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
 from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop  # noqa: F401
+from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop  # noqa: F401
@@ -14,7 +14,7 @@
 from typing import Any, Optional

 from pytorch_lightning.loops import Loop
-from pytorch_lightning.loops.closure import ClosureResult
+from pytorch_lightning.loops.optimization.closure import ClosureResult
 from pytorch_lightning.loops.utilities import (
     _build_training_step_kwargs,
     _check_training_step_output,
@@ -20,7 +20,7 @@

 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops import Loop
-from pytorch_lightning.loops.closure import Closure, ClosureResult
+from pytorch_lightning.loops.optimization.closure import Closure, ClosureResult
 from pytorch_lightning.loops.utilities import (
     _block_parallel_sync_behavior,
     _build_training_step_kwargs,
@@ -43,7 +43,7 @@ class OptimizerLoop(Loop):
     This loop implements what is known in Lightning as Automatic Optimization.
     """

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
         # TODO: use default dict here to simplify logic in loop
         self.outputs: _OUTPUTS_TYPE = []
@@ -71,7 +71,7 @@ def on_run_start(self, batch: Any, optimizers: List[Optimizer], batch_idx: int)
         self._batch_idx = batch_idx
         self._optimizers = optimizers

-    def advance(self, batch: Any, *args, **kwargs) -> None:  # type: ignore[override]
+    def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
         result = self._run_optimization(
             batch,
             self._batch_idx,
@@ -183,7 +183,7 @@ def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer)
         if not is_first_batch_to_accumulate:
             return None

-        def zero_grad_fn():
+        def zero_grad_fn() -> None:
             self._on_before_zero_grad(optimizer)
             self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

@@ -198,7 +198,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
         if self._skip_backward:
             return None

-        def backward_fn(loss: Tensor):
+        def backward_fn(loss: Tensor) -> Tensor:
             self.backward(loss, optimizer, opt_idx)

             # check if model weights are nan
@@ -332,6 +332,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

         if self.trainer.move_metrics_to_cpu:
             # hiddens and the training step output are not moved as they are not considered "metrics"
+            assert self.trainer._results is not None
             self.trainer._results.cpu()

         return result
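The added return-type annotations and the `assert self.trainer._results is not None` line are the usual way of satisfying mypy once a package leaves the ignore list, as the pyproject.toml change above does for pytorch_lightning.loops.optimization.*: the assert narrows an Optional attribute before it is used. A minimal, self-contained sketch of that narrowing pattern (the _ResultStore/_Trainer names are illustrative stand-ins, not Lightning's API):

from typing import Optional


class _ResultStore:
    """Illustrative stand-in for the trainer's results collection."""

    def cpu(self) -> None:
        ...


class _Trainer:
    """Illustrative stand-in; `_results` stays Optional until a loop populates it."""

    def __init__(self) -> None:
        self._results: Optional[_ResultStore] = None


def _move_metrics_to_cpu(trainer: _Trainer) -> None:
    # Without the assert, mypy rejects `trainer._results.cpu()` because
    # `_results` may be None; the assert narrows the type to _ResultStore.
    assert trainer._results is not None
    trainer._results.cpu()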
2 changes: 1 addition & 1 deletion tests/core/test_lightning_optimizer.py
@@ -20,7 +20,7 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.loops.closure import Closure
+from pytorch_lightning.loops.optimization.closure import Closure
 from tests.helpers.boring_model import BoringModel
@@ -15,7 +15,7 @@
 import torch

 from pytorch_lightning import Trainer
-from pytorch_lightning.loops.closure import ClosureResult
+from pytorch_lightning.loops.optimization.closure import ClosureResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
File renamed without changes.
@@ -18,7 +18,7 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.loops.closure import Closure
+from pytorch_lightning.loops.optimization.closure import Closure
 from pytorch_lightning.trainer.states import RunningStage
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.deterministic_model import DeterministicModel