Commit 4c53eae

Self-review of the recent Trainer changes (#14916)
1 parent 4eb7766 · commit 4c53eae

5 files changed: +18 / -21 lines

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -58,7 +58,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.trainer.trainer",
-    "pytorch_lightning.trainer.connectors.checkpoint_connector",
     "lightning_app.api.http_methods",
     "lightning_app.api.request_types",
     "lightning_app.cli.app-template.app",

src/pytorch_lightning/trainer/teardown.py renamed to src/pytorch_lightning/trainer/call.py

Lines changed: 2 additions & 1 deletion
@@ -14,13 +14,14 @@
 import traceback
 from typing import Any, Callable
 
+import pytorch_lightning as pl
 from lightning_lite.utilities.distributed import distributed_available
 from pytorch_lightning.trainer.states import TrainerStatus
 from pytorch_lightning.utilities.exceptions import _TunerExitException
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 
 
-def call_and_handle_interrupt(trainer: Any, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
+def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
     r"""
     Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
     as all errors should funnel through them
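
Note: two things change in this rename. The helper gains a leading underscore to mark it as internal, and `trainer: Any` becomes the quoted annotation `trainer: "pl.Trainer"`. The quoted form gives type checkers the real class while sidestepping the circular import between this module and the `pytorch_lightning` package, because string annotations are never evaluated at runtime. A minimal sketch of the pattern, with hypothetical names:

    # helpers.py -- hypothetical module inside a package `mypackage` whose
    # __init__ imports this module, creating a potential import cycle
    import mypackage as mp


    def _configure(core: "mp.Core") -> None:
        # "mp.Core" is resolved only by static type checkers, so it is safe
        # even if mp.Core is not yet defined while the package is still
        # initializing; at runtime this is an ordinary method call.
        core.start()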

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 3 deletions
@@ -13,11 +13,9 @@
 # limitations under the License.
 
 import logging
-import operator
 import os
 import re
 from copy import deepcopy
-from functools import partial
 from typing import Any, Dict, Optional
 
 import torch

@@ -172,7 +170,7 @@ def _set_ckpt_path(
                     " or last checkpoint available. No checkpoint will be loaded."
                 )
                 return None
-            ckpt_path = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))
+            ckpt_path = max(candidates_ts, key=candidates_ts.get)  # type: ignore[arg-type]
 
         elif ckpt_path == "hpc":
             if not self._hpc_resume_path:
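
Note: the replacement is behavior-preserving. `partial(operator.getitem, candidates_ts)` and `candidates_ts.get` are both key functions mapping a checkpoint path to its timestamp, so `max` returns the path with the newest timestamp either way; the new spelling just drops two imports. A quick check of the equivalence, with made-up timestamps:

    import operator
    from functools import partial

    # hypothetical checkpoint-path -> modification-timestamp mapping
    candidates_ts = {"epoch=0.ckpt": 50.0, "epoch=2.ckpt": 250.0, "epoch=1.ckpt": 100.0}

    old = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))
    new = max(candidates_ts, key=candidates_ts.get)  # same result, fewer imports

    assert old == new == "epoch=2.ckpt"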

src/pytorch_lightning/trainer/setup.py

Lines changed: 6 additions & 7 deletions
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """Houses the methods used to set up the Trainer."""
 
-from typing import Any, Optional, Union
+from typing import Optional, Union
 
+import pytorch_lightning as pl
 from lightning_lite.utilities.warnings import PossibleUserWarning
 from pytorch_lightning.accelerators import (
     CUDAAccelerator,

@@ -38,8 +38,8 @@
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 
 
-def init_debugging_flags(
-    trainer: Any,
+def _init_debugging_flags(
+    trainer: "pl.Trainer",
     limit_train_batches: Optional[Union[int, float]],
     limit_val_batches: Optional[Union[int, float]],
     limit_test_batches: Optional[Union[int, float]],

@@ -128,7 +128,7 @@ def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) ->
     )
 
 
-def init_profiler(trainer: Any, profiler: Optional[Union[Profiler, str]]) -> None:
+def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str]]) -> None:
     if isinstance(profiler, str):
         PROFILERS = {
             "simple": SimpleProfiler,

@@ -147,8 +147,7 @@ def init_profiler(trainer: Any, profiler: Optional[Union[Profiler, str]]) -> None:
     trainer.profiler = profiler or PassThroughProfiler()
 
 
-def log_device_info(trainer: Any) -> None:
-
+def _log_device_info(trainer: "pl.Trainer") -> None:
     if CUDAAccelerator.is_available():
         gpu_available = True
         gpu_type = " (cuda)"

src/pytorch_lightning/trainer/trainer.py

Lines changed: 9 additions & 9 deletions
@@ -66,7 +66,7 @@
 )
 from pytorch_lightning.profilers import Profiler
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
-from pytorch_lightning.trainer import setup, teardown
+from pytorch_lightning.trainer import call, setup
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
 from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector

@@ -393,7 +393,6 @@ def __init__(
         Trainer._log_api_event("init")
         log.detail(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")
         self.state = TrainerState()
-        self.num_sanity_val_steps: int
 
         # init connectors
         self._data_connector = DataConnector(self, multiple_trainloader_mode)

@@ -498,15 +497,16 @@ def __init__(
         self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
 
         # configure profiler
-        setup.init_profiler(self, profiler)
+        setup._init_profiler(self, profiler)
 
         # init logger flags
         self._loggers: List[Logger]
         self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
         self.val_check_interval: Union[int, float]
-        setup.init_debugging_flags(
+        self.num_sanity_val_steps: Union[float, int]
+        setup._init_debugging_flags(
             self,
             limit_train_batches,
             limit_val_batches,

@@ -522,7 +522,7 @@ def __init__(
         self._call_callback_hooks("on_init_end")
 
     def _setup_on_init(self) -> None:
-        setup.log_device_info(self)
+        setup._log_device_info(self)
 
         self.should_stop = False
         self.state = TrainerState()

@@ -568,7 +568,7 @@ def fit(
         if not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.fit()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model
-        teardown.call_and_handle_interrupt(
+        call._call_and_handle_interrupt(
             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )

@@ -648,7 +648,7 @@ def validate(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.validate()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return teardown.call_and_handle_interrupt(
+        return call._call_and_handle_interrupt(
             self, self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule
         )

@@ -740,7 +740,7 @@ def test(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.test()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return teardown.call_and_handle_interrupt(
+        return call._call_and_handle_interrupt(
             self, self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule
         )

@@ -831,7 +831,7 @@ def predict(
         if model is not None and not isinstance(model, pl.LightningModule):
             raise TypeError(f"`Trainer.predict()` requires a `LightningModule`, got: {model.__class__.__qualname__}")
         self.strategy._lightning_module = model or self.lightning_module
-        return teardown.call_and_handle_interrupt(
+        return call._call_and_handle_interrupt(
             self, self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
         )
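
Note: with the module renamed from `teardown` to `call`, all four public entry points (`fit`, `validate`, `test`, `predict`) funnel through `call._call_and_handle_interrupt`, keeping interrupt and error handling in one place instead of in each method. The sketch below shows the general shape of such a wrapper; it is an illustration of the pattern, not Lightning's actual implementation:

    from typing import Any, Callable


    def call_and_handle_interrupt_sketch(trainer_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Illustrative only: run an entry point, turning Ctrl-C into a graceful exit."""
        try:
            return trainer_fn(*args, **kwargs)
        except KeyboardInterrupt:
            # The real helper additionally updates trainer state and tears
            # down distributed workers before exiting.
            print("Detected KeyboardInterrupt, attempting graceful shutdown...")
            return None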
