Skip to content

Commit 60de814

Browse files
committed
Use non-deprecated options in tests
1 parent 6429de8 commit 60de814

File tree

5 files changed

+16
-15
lines changed

5 files changed

+16
-15
lines changed

tests/callbacks/test_device_stats_monitor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
4646
gpus=1,
4747
callbacks=[device_stats],
4848
logger=DebugLogger(tmpdir),
49-
checkpoint_callback=False,
49+
enable_checkpointing=False,
5050
enable_progress_bar=False,
5151
)
5252

@@ -75,7 +75,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
7575
gpus=1,
7676
callbacks=[device_stats],
7777
logger=DebugLogger(tmpdir),
78-
checkpoint_callback=False,
78+
enable_checkpointing=False,
7979
enable_progress_bar=False,
8080
)
8181

@@ -104,7 +104,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
104104
log_every_n_steps=1,
105105
callbacks=[device_stats],
106106
logger=DebugLogger(tmpdir),
107-
checkpoint_callback=False,
107+
enable_checkpointing=False,
108108
enable_progress_bar=False,
109109
)
110110

@@ -122,7 +122,7 @@ def test_device_stats_monitor_no_logger(tmpdir):
122122
callbacks=[device_stats],
123123
max_epochs=1,
124124
logger=False,
125-
checkpoint_callback=False,
125+
enable_checkpointing=False,
126126
enable_progress_bar=False,
127127
)
128128

tests/callbacks/test_model_summary.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def test_model_summary_callback_present_trainer():
3232

3333

3434
def test_model_summary_callback_with_weights_summary_none():
35-
36-
trainer = Trainer(weights_summary=None)
35+
with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
36+
trainer = Trainer(weights_summary=None)
3737
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
3838

3939
trainer = Trainer(enable_model_summary=False)
@@ -42,7 +42,8 @@ def test_model_summary_callback_with_weights_summary_none():
4242
trainer = Trainer(enable_model_summary=False, weights_summary="full")
4343
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
4444

45-
trainer = Trainer(enable_model_summary=True, weights_summary=None)
45+
with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"):
46+
trainer = Trainer(enable_model_summary=True, weights_summary=None)
4647
assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)
4748

4849

tests/models/test_restore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def configure_optimizers(self):
153153
state_dict = torch.load(resume_ckpt)
154154

155155
trainer_args.update(
156-
{"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "checkpoint_callback": False, "callbacks": []}
156+
{"max_epochs": 3, "resume_from_checkpoint": resume_ckpt, "enable_checkpointing": False, "callbacks": []}
157157
)
158158

159159
class CustomClassifModel(CustomClassifModel):

tests/trainer/test_trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):
468468
"max_epochs": 3,
469469
"max_steps": num_train_samples + 10,
470470
"logger": False,
471-
"weights_summary": None,
471+
"enable_model_summary": False,
472472
"enable_progress_bar": False,
473473
}
474474
trainer = Trainer(**trainer_kwargs)
@@ -555,7 +555,7 @@ def test_trainer_min_steps_and_epochs(tmpdir):
555555
# define less min steps than 1 epoch
556556
"min_steps": num_train_samples // 2,
557557
"logger": False,
558-
"weights_summary": None,
558+
"enable_model_summary": False,
559559
"enable_progress_bar": False,
560560
}
561561
trainer = Trainer(**trainer_kwargs)
@@ -723,9 +723,9 @@ def predict_step(self, batch, *_):
723723
assert getattr(trainer, path_attr) == ckpt_path
724724

725725

726-
@pytest.mark.parametrize("checkpoint_callback", (False, True))
726+
@pytest.mark.parametrize("enable_model_summary", (False, True))
727727
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
728-
def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
728+
def test_tested_checkpoint_path_best(tmpdir, enable_model_summary, fn):
729729
class TestModel(BoringModel):
730730
def validation_step(self, batch, batch_idx):
731731
self.log("foo", -batch_idx)
@@ -746,15 +746,15 @@ def predict_step(self, batch, *_):
746746
limit_predict_batches=1,
747747
enable_progress_bar=False,
748748
default_root_dir=tmpdir,
749-
checkpoint_callback=checkpoint_callback,
749+
enable_model_summary=enable_model_summary,
750750
)
751751
trainer.fit(model)
752752

753753
trainer_fn = getattr(trainer, fn)
754754
path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
755755
assert getattr(trainer, path_attr) is None
756756

757-
if checkpoint_callback:
757+
if enable_model_summary:
758758
trainer_fn(ckpt_path="best")
759759
assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
760760

tests/utilities/test_cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def test_lightning_cli_config_and_subclass_mode(tmpdir):
384384
"fit": {
385385
"model": {"class_path": "tests.helpers.BoringModel"},
386386
"data": {"class_path": "tests.helpers.BoringDataModule", "init_args": {"data_dir": str(tmpdir)}},
387-
"trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "weights_summary": None},
387+
"trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False},
388388
}
389389
}
390390
config_path = tmpdir / "config.yaml"

0 commit comments

Comments
 (0)