Skip to content

Commit 9eccb31

Browse files
authored
Loop and test restructuring (#9383)
1 parent d773407 commit 9eccb31

19 files changed

+18
-19
lines changed

pyproject.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,8 @@ ignore_errors = "True"
6262
[[tool.mypy.overrides]]
6363
module = [
6464
"pytorch_lightning.callbacks.pruning",
65-
"pytorch_lightning.loops.closure",
66-
"pytorch_lightning.loops.batch.manual",
67-
"pytorch_lightning.loops.optimizer",
68-
"pytorch_lightning.trainer.evaluation_loop",
65+
"pytorch_lightning.loops.optimization.*",
66+
"pytorch_lightning.loops.evaluation_loop",
6967
"pytorch_lightning.trainer.connectors.logger_connector.*",
7068
"pytorch_lightning.trainer.progress",
7169
"pytorch_lightning.tuner.auto_gpu_select",

pytorch_lightning/loops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
1919
from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
2020
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
21-
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
21+
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401

pytorch_lightning/loops/batch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from pytorch_lightning.loops.batch.manual import ManualOptimization # noqa: F401
1615
from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
16+
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization # noqa: F401

pytorch_lightning/loops/batch/training_batch_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from torch.optim import Optimizer
2020

2121
from pytorch_lightning.loops.base import Loop
22-
from pytorch_lightning.loops.batch.manual import ManualOptimization
23-
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop
22+
from pytorch_lightning.loops.optimization.manual_loop import ManualOptimization
23+
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
2424
from pytorch_lightning.trainer.supporters import TensorRunningAccum
2525
from pytorch_lightning.utilities import AttributeDict
2626
from pytorch_lightning.utilities.types import STEP_OUTPUT

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pytorch_lightning import loops # import as loops to avoid circular imports
1919
from pytorch_lightning.loops.batch import TrainingBatchLoop
20-
from pytorch_lightning.loops.closure import ClosureResult
20+
from pytorch_lightning.loops.optimization.closure import ClosureResult
2121
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
2222
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
2323
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress

pytorch_lightning/loops/optimizer/__init__.py renamed to pytorch_lightning/loops/optimization/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from pytorch_lightning.loops.optimizer.optimizer_loop import OptimizerLoop # noqa: F401
15+
from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop # noqa: F401

pytorch_lightning/loops/batch/manual.py renamed to pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, Optional
1515

1616
from pytorch_lightning.loops import Loop
17-
from pytorch_lightning.loops.closure import ClosureResult
17+
from pytorch_lightning.loops.optimization.closure import ClosureResult
1818
from pytorch_lightning.loops.utilities import (
1919
_build_training_step_kwargs,
2020
_check_training_step_output,

pytorch_lightning/loops/optimizer/optimizer_loop.py renamed to pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pytorch_lightning.core.optimizer import LightningOptimizer
2222
from pytorch_lightning.loops import Loop
23-
from pytorch_lightning.loops.closure import Closure, ClosureResult
23+
from pytorch_lightning.loops.optimization.closure import Closure, ClosureResult
2424
from pytorch_lightning.loops.utilities import (
2525
_block_parallel_sync_behavior,
2626
_build_training_step_kwargs,
@@ -43,7 +43,7 @@ class OptimizerLoop(Loop):
4343
This loop implements what is known in Lightning as Automatic Optimization.
4444
"""
4545

46-
def __init__(self):
46+
def __init__(self) -> None:
4747
super().__init__()
4848
# TODO: use default dict here to simplify logic in loop
4949
self.outputs: _OUTPUTS_TYPE = []
@@ -71,7 +71,7 @@ def on_run_start(self, batch: Any, optimizers: List[Optimizer], batch_idx: int)
7171
self._batch_idx = batch_idx
7272
self._optimizers = optimizers
7373

74-
def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override]
74+
def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
7575
result = self._run_optimization(
7676
batch,
7777
self._batch_idx,
@@ -183,7 +183,7 @@ def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer)
183183
if not is_first_batch_to_accumulate:
184184
return None
185185

186-
def zero_grad_fn():
186+
def zero_grad_fn() -> None:
187187
self._on_before_zero_grad(optimizer)
188188
self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)
189189

@@ -198,7 +198,7 @@ def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Call
198198
if self._skip_backward:
199199
return None
200200

201-
def backward_fn(loss: Tensor):
201+
def backward_fn(loss: Tensor) -> Tensor:
202202
self.backward(loss, optimizer, opt_idx)
203203

204204
# check if model weights are nan
@@ -332,6 +332,7 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
332332

333333
if self.trainer.move_metrics_to_cpu:
334334
# hiddens and the training step output are not moved as they are not considered "metrics"
335+
assert self.trainer._results is not None
335336
self.trainer._results.cpu()
336337

337338
return result

tests/core/test_lightning_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pytorch_lightning import Trainer
2222
from pytorch_lightning.core.optimizer import LightningOptimizer
23-
from pytorch_lightning.loops.closure import Closure
23+
from pytorch_lightning.loops.optimization.closure import Closure
2424
from tests.helpers.boring_model import BoringModel
2525

2626

tests/loops/test_closure.py renamed to tests/loops/optimization/test_closure.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616

1717
from pytorch_lightning import Trainer
18-
from pytorch_lightning.loops.closure import ClosureResult
18+
from pytorch_lightning.loops.optimization.closure import ClosureResult
1919
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2020
from tests.helpers import BoringModel
2121

File renamed without changes.

tests/trainer/loops/test_training_loop_flow_scalar.py renamed to tests/loops/test_training_loop_flow_scalar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pytorch_lightning import Trainer
2020
from pytorch_lightning.core.lightning import LightningModule
21-
from pytorch_lightning.loops.closure import Closure
21+
from pytorch_lightning.loops.optimization.closure import Closure
2222
from pytorch_lightning.trainer.states import RunningStage
2323
from tests.helpers.boring_model import BoringModel, RandomDataset
2424
from tests.helpers.deterministic_model import DeterministicModel

0 commit comments

Comments
 (0)