@@ -723,9 +723,9 @@ def predict_step(self, batch, *_):
723
723
assert getattr (trainer , path_attr ) == ckpt_path
724
724
725
725
726
- @pytest .mark .parametrize ("enable_model_summary " , (False , True ))
726
+ @pytest .mark .parametrize ("enable_checkpointing " , (False , True ))
727
727
@pytest .mark .parametrize ("fn" , ("validate" , "test" , "predict" ))
728
- def test_tested_checkpoint_path_best (tmpdir , enable_model_summary , fn ):
728
+ def test_tested_checkpoint_path_best (tmpdir , enable_checkpointing , fn ):
729
729
class TestModel (BoringModel ):
730
730
def validation_step (self , batch , batch_idx ):
731
731
self .log ("foo" , - batch_idx )
@@ -746,15 +746,15 @@ def predict_step(self, batch, *_):
746
746
limit_predict_batches = 1 ,
747
747
enable_progress_bar = False ,
748
748
default_root_dir = tmpdir ,
749
- enable_model_summary = enable_model_summary ,
749
+ enable_checkpointing = enable_checkpointing ,
750
750
)
751
751
trainer .fit (model )
752
752
753
753
trainer_fn = getattr (trainer , fn )
754
754
path_attr = f"{ fn } { 'd' if fn == 'validate' else 'ed' } _ckpt_path"
755
755
assert getattr (trainer , path_attr ) is None
756
756
757
- if enable_model_summary :
757
+ if enable_checkpointing :
758
758
trainer_fn (ckpt_path = "best" )
759
759
assert getattr (trainer , path_attr ) == trainer .checkpoint_callback .best_model_path
760
760
0 commit comments