66
66
)
67
67
from pytorch_lightning .profilers import Profiler
68
68
from pytorch_lightning .strategies import ParallelStrategy , Strategy
69
- from pytorch_lightning .trainer import setup , teardown
69
+ from pytorch_lightning .trainer import call , setup
70
70
from pytorch_lightning .trainer .configuration_validator import verify_loop_configurations
71
71
from pytorch_lightning .trainer .connectors .accelerator_connector import _LITERAL_WARN , AcceleratorConnector
72
72
from pytorch_lightning .trainer .connectors .callback_connector import CallbackConnector
@@ -393,7 +393,6 @@ def __init__(
393
393
Trainer ._log_api_event ("init" )
394
394
log .detail (f"{ self .__class__ .__name__ } : Initializing trainer with parameters: { locals ()} " )
395
395
self .state = TrainerState ()
396
- self .num_sanity_val_steps : int
397
396
398
397
# init connectors
399
398
self ._data_connector = DataConnector (self , multiple_trainloader_mode )
@@ -498,15 +497,16 @@ def __init__(
498
497
self .tuner .on_trainer_init (auto_lr_find , auto_scale_batch_size )
499
498
500
499
# configure profiler
501
- setup .init_profiler (self , profiler )
500
+ setup ._init_profiler (self , profiler )
502
501
503
502
# init logger flags
504
503
self ._loggers : List [Logger ]
505
504
self ._logger_connector .on_trainer_init (logger , log_every_n_steps , move_metrics_to_cpu )
506
505
507
506
# init debugging flags
508
507
self .val_check_interval : Union [int , float ]
509
- setup .init_debugging_flags (
508
+ self .num_sanity_val_steps : Union [float , int ]
509
+ setup ._init_debugging_flags (
510
510
self ,
511
511
limit_train_batches ,
512
512
limit_val_batches ,
@@ -522,7 +522,7 @@ def __init__(
522
522
self ._call_callback_hooks ("on_init_end" )
523
523
524
524
def _setup_on_init (self ) -> None :
525
- setup .log_device_info (self )
525
+ setup ._log_device_info (self )
526
526
527
527
self .should_stop = False
528
528
self .state = TrainerState ()
@@ -568,7 +568,7 @@ def fit(
568
568
if not isinstance (model , pl .LightningModule ):
569
569
raise TypeError (f"`Trainer.fit()` requires a `LightningModule`, got: { model .__class__ .__qualname__ } " )
570
570
self .strategy ._lightning_module = model
571
- teardown . call_and_handle_interrupt (
571
+ call . _call_and_handle_interrupt (
572
572
self , self ._fit_impl , model , train_dataloaders , val_dataloaders , datamodule , ckpt_path
573
573
)
574
574
@@ -648,7 +648,7 @@ def validate(
648
648
if model is not None and not isinstance (model , pl .LightningModule ):
649
649
raise TypeError (f"`Trainer.validate()` requires a `LightningModule`, got: { model .__class__ .__qualname__ } " )
650
650
self .strategy ._lightning_module = model or self .lightning_module
651
- return teardown . call_and_handle_interrupt (
651
+ return call . _call_and_handle_interrupt (
652
652
self , self ._validate_impl , model , dataloaders , ckpt_path , verbose , datamodule
653
653
)
654
654
@@ -740,7 +740,7 @@ def test(
740
740
if model is not None and not isinstance (model , pl .LightningModule ):
741
741
raise TypeError (f"`Trainer.test()` requires a `LightningModule`, got: { model .__class__ .__qualname__ } " )
742
742
self .strategy ._lightning_module = model or self .lightning_module
743
- return teardown . call_and_handle_interrupt (
743
+ return call . _call_and_handle_interrupt (
744
744
self , self ._test_impl , model , dataloaders , ckpt_path , verbose , datamodule
745
745
)
746
746
@@ -831,7 +831,7 @@ def predict(
831
831
if model is not None and not isinstance (model , pl .LightningModule ):
832
832
raise TypeError (f"`Trainer.predict()` requires a `LightningModule`, got: { model .__class__ .__qualname__ } " )
833
833
self .strategy ._lightning_module = model or self .lightning_module
834
- return teardown . call_and_handle_interrupt (
834
+ return call . _call_and_handle_interrupt (
835
835
self , self ._predict_impl , model , dataloaders , datamodule , return_predictions , ckpt_path
836
836
)
837
837
0 commit comments