Commit 38ed26e

Do not require omegaconf to run tests (#10832)
1 parent a81accb commit 38ed26e

File tree (6 files changed: +64, -58 lines)

  tests/checkpointing/test_model_checkpoint.py
  tests/core/test_datamodules.py
  tests/helpers/runif.py
  tests/loggers/test_tensorboard.py
  tests/models/test_hparams.py
  tests/trainer/test_trainer.py
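
Every touched test module applies the same two-part pattern: the unconditional `from omegaconf import ...` at the top of the file is replaced by an import guarded on `_OMEGACONF_AVAILABLE`, and tests that actually need omegaconf are marked so they are skipped, rather than erroring at collection time, when the package is missing. A minimal sketch of the guard as it appears in the diffs below:

    # Guarded optional import: the test module stays importable even when
    # omegaconf is not installed; omegaconf-specific tests are skipped instead.
    from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE

    if _OMEGACONF_AVAILABLE:
        from omegaconf import OmegaConf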

tests/checkpointing/test_model_checkpoint.py

Lines changed: 10 additions & 7 deletions
@@ -29,7 +29,6 @@
 import pytest
 import torch
 import yaml
-from omegaconf import Container, OmegaConf
 from torch import optim
 
 import pytorch_lightning as pl
@@ -39,9 +38,13 @@
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import Container, OmegaConf
+
 
 def test_model_checkpoint_state_key():
     early_stopping = ModelCheckpoint(monitor="val_loss")
@@ -1094,8 +1097,8 @@ def training_step(self, *args):
     assert model_checkpoint.current_score == expected
 
 
-@pytest.mark.parametrize("hparams_type", [dict, Container])
-def test_hparams_type(tmpdir, hparams_type):
+@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
+def test_hparams_type(tmpdir, use_omegaconf):
     class TestModel(BoringModel):
         def __init__(self, hparams):
             super().__init__()
@@ -1113,15 +1116,15 @@ def __init__(self, hparams):
         enable_model_summary=False,
     )
     hp = {"test_hp_0": 1, "test_hp_1": 2}
-    hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
+    hp = OmegaConf.create(hp) if use_omegaconf else Namespace(**hp)
     model = TestModel(hp)
     trainer.fit(model)
     ckpt = trainer.checkpoint_connector.dump_checkpoint()
-    if hparams_type == Container:
-        assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
+    if use_omegaconf:
+        assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], Container)
     else:
         # make sure it's not AttributeDict
-        assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is hparams_type
+        assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is dict
 
 
 def test_ckpt_version_after_rerun_new_trainer(tmpdir):
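
Note the parametrization trick above: rather than skipping the whole test when omegaconf is missing, only the `True` case carries the skip marker, so the plain-dict path always runs. A standalone sketch of that pytest idiom, using `importlib` instead of the Lightning helpers (the test name and body are illustrative):

    import importlib.util

    import pytest

    _OMEGACONF_AVAILABLE = importlib.util.find_spec("omegaconf") is not None

    # Only the omegaconf case is skipped when the optional dependency is absent;
    # the plain-dict case runs unconditionally.
    @pytest.mark.parametrize(
        "use_omegaconf",
        [False, pytest.param(True, marks=pytest.mark.skipif(not _OMEGACONF_AVAILABLE, reason="omegaconf"))],
    )
    def test_hparams_roundtrip_sketch(use_omegaconf):
        hp = {"test_hp_0": 1, "test_hp_1": 2}
        if use_omegaconf:
            from omegaconf import OmegaConf

            hp = OmegaConf.create(hp)
        assert hp["test_hp_0"] == 1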

tests/core/test_datamodules.py

Lines changed: 7 additions & 4 deletions
@@ -20,12 +20,11 @@
 
 import pytest
 import torch
-from omegaconf import OmegaConf
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from tests.helpers import BoringDataModule, BoringModel
@@ -34,6 +33,9 @@
 from tests.helpers.simple_models import ClassificationModel
 from tests.helpers.utils import reset_seed
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import OmegaConf
+
 
 @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
 @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
@@ -440,8 +442,9 @@ def test_hyperparameters_saving():
     data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
     assert data.hparams == AttributeDict({"hello": "world"})
 
-    data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
-    assert data.hparams == OmegaConf.create({"hello": "world"})
+    if _OMEGACONF_AVAILABLE:
+        data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
+        assert data.hparams == OmegaConf.create({"hello": "world"})
 
 
 def test_define_as_dataclass():

tests/helpers/runif.py

Lines changed: 8 additions & 1 deletion
@@ -27,6 +27,7 @@
     _FAIRSCALE_FULLY_SHARDED_AVAILABLE,
     _HOROVOD_AVAILABLE,
     _IPU_AVAILABLE,
+    _OMEGACONF_AVAILABLE,
     _RICH_AVAILABLE,
     _TORCH_QUANTIZE_AVAILABLE,
     _TPU_AVAILABLE,
@@ -70,6 +71,7 @@ def __new__(
         deepspeed: bool = False,
         rich: bool = False,
         skip_49370: bool = False,
+        omegaconf: bool = False,
         **kwargs,
     ):
         """
@@ -89,9 +91,10 @@ def __new__(
             standalone: Mark the test as standalone, our CI will run it in a separate process.
             fairscale: Require that facebookresearch/fairscale is installed.
             fairscale_fully_sharded: Require that `fairscale` fully sharded support is available.
-            deepspeed: Require that Microsoft/DeepSpeed is installed.
+            deepspeed: Require that microsoft/DeepSpeed is installed.
             rich: Require that willmcgugan/rich is installed.
             skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
+            omegaconf: Require that omry/omegaconf is installed.
             **kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
         """
         conditions = []
@@ -177,6 +180,10 @@ def __new__(
             conditions.append(ge_3_9 and old_torch)
             reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
 
+        if omegaconf:
+            conditions.append(not _OMEGACONF_AVAILABLE)
+            reasons.append("omegaconf")
+
         reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
         return pytest.mark.skipif(
             *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
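
Functionally, the new `omegaconf` flag just contributes one more (condition, reason) pair that `RunIf` folds into a single `pytest.mark.skipif`. A reduced, self-contained sketch of that mechanism (the `run_if` helper here is a stand-in for `tests.helpers.runif.RunIf`, not the real class):

    import importlib.util

    import pytest

    _OMEGACONF_AVAILABLE = importlib.util.find_spec("omegaconf") is not None

    def run_if(omegaconf: bool = False):
        # Collect (condition, reason) pairs and fold them into one skipif marker.
        conditions, reasons = [], []
        if omegaconf:
            conditions.append(not _OMEGACONF_AVAILABLE)
            reasons.append("omegaconf")
        reasons = [r for cond, r in zip(conditions, reasons) if cond]
        return pytest.mark.skipif(any(conditions), reason=f"Requires: [{' + '.join(reasons)}]")

    @run_if(omegaconf=True)
    def test_needs_omegaconf():
        from omegaconf import OmegaConf

        assert OmegaConf.create({"a": 1}).a == 1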

tests/loggers/test_tensorboard.py

Lines changed: 6 additions & 4 deletions
@@ -21,13 +21,16 @@
 import pytest
 import torch
 import yaml
-from omegaconf import OmegaConf
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.loggers.base import LoggerCollection
-from pytorch_lightning.utilities.imports import _compare_version
+from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
 from tests.helpers import BoringModel
+from tests.helpers.runif import RunIf
+
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import OmegaConf
 
 
 @pytest.mark.skipif(
@@ -205,6 +208,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
     logger.log_hyperparams(hparams, metrics)
 
 
+@RunIf(omegaconf=True)
 def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
     logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
     hparams = {
@@ -214,8 +218,6 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
         "bool": True,
         "dict": {"a": {"b": "c"}},
         "list": [1, 2, 3],
-        # "namespace": Namespace(foo=Namespace(bar="buzz")),
-        # "layer": torch.nn.BatchNorm1d,
     }
     hparams = OmegaConf.create(hparams)
 
tests/models/test_hparams.py

Lines changed: 23 additions & 35 deletions
@@ -24,21 +24,24 @@
 import pytest
 import torch
 from fsspec.implementations.local import LocalFileSystem
-from omegaconf import Container, OmegaConf
-from omegaconf.dictconfig import DictConfig
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
-from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable
+from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
+from tests.helpers.runif import RunIf
 
 if _HYDRA_EXPERIMENTAL_AVAILABLE:
     from hydra.experimental import compose, initialize
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import Container, OmegaConf
+    from omegaconf.dictconfig import DictConfig
+
 
 class SaveHparamsModel(BoringModel):
     """Tests that a model can take an object."""
@@ -117,6 +120,7 @@ def test_dict_hparams(tmpdir, cls):
     _run_standard_hparams_test(tmpdir, model, cls)
 
 
+@RunIf(omegaconf=True)
 @pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel])
 def test_omega_conf_hparams(tmpdir, cls):
     # init model
@@ -275,10 +279,18 @@ def __init__(obj, *more_args, other_arg=300, **more_kwargs):
         obj.save_hyperparameters()
 
 
-class DictConfSubClassBoringModel(SubClassBoringModel):
-    def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
-        super().__init__(*args, **kwargs)
-        self.save_hyperparameters()
+if _OMEGACONF_AVAILABLE:
+
+    class DictConfSubClassBoringModel(SubClassBoringModel):
+        def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
+            super().__init__(*args, **kwargs)
+            self.save_hyperparameters()
+
+
+else:
+
+    class DictConfSubClassBoringModel:
+        ...
 
 
 @pytest.mark.parametrize(
@@ -290,7 +302,7 @@ def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something"))
         SubSubClassBoringModel,
         AggSubClassBoringModel,
         UnconventionalArgsBoringModel,
-        DictConfSubClassBoringModel,
+        pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
     ],
 )
 def test_collect_init_arguments(tmpdir, cls):
@@ -383,31 +395,6 @@ def test_collect_init_arguments_with_local_vars(cls):
     assert model.hparams["arg2"] == 2
 
 
-# @pytest.mark.parametrize("cls,config", [
-#     (SaveHparamsModel, Namespace(my_arg=42)),
-#     (SaveHparamsModel, dict(my_arg=42)),
-#     (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))),
-#     (AssignHparamsModel, Namespace(my_arg=42)),
-#     (AssignHparamsModel, dict(my_arg=42)),
-#     (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))),
-# ])
-# def test_single_config_models(tmpdir, cls, config):
-#     """ Test that the model automatically saves the arguments passed into the constructor """
-#     model = cls(config)
-#
-#     # no matter how you do it, it should be assigned
-#     assert model.hparams.my_arg == 42
-#
-#     # verify that the checkpoint saved the correct values
-#     trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
-#     trainer.fit(model)
-#
-#     # verify that model loads correctly
-#     raw_checkpoint_path = _raw_checkpoint_path(trainer)
-#     model = cls.load_from_checkpoint(raw_checkpoint_path)
-#     assert model.hparams.my_arg == 42
-
-
 class AnotherArgModel(BoringModel):
     def __init__(self, arg1):
         super().__init__()
@@ -511,8 +498,9 @@ def _compare_params(loaded_params, default_params: dict):
     save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
     _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
 
-    save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
-    _compare_params(load_hparams_from_yaml(path_yaml), hparams)
+    if _OMEGACONF_AVAILABLE:
+        save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
+        _compare_params(load_hparams_from_yaml(path_yaml), hparams)
 
 
 class NoArgsSubClassBoringModel(CustomBoringModel):
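
One detail worth calling out in this file: `DictConfSubClassBoringModel` uses `OmegaConf.create(...)` as a default argument, so the class cannot even be defined when omegaconf is absent. The diff therefore guards the whole class definition and leaves a bare placeholder in the `else` branch, so the parametrize list can still reference the name while `RunIf(omegaconf=True)` skips that case. A minimal sketch of the same trick with illustrative names:

    import importlib.util

    _OMEGACONF_AVAILABLE = importlib.util.find_spec("omegaconf") is not None

    if _OMEGACONF_AVAILABLE:
        from omegaconf import OmegaConf

        class ConfiguredModel:
            # The default value is evaluated at definition time, which is why
            # the whole class body must sit behind the availability check.
            def __init__(self, dict_conf=OmegaConf.create({"my_param": "something"})):
                self.dict_conf = dict_conf

    else:

        class ConfiguredModel:  # placeholder so test parametrizations can still name it
            ...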

tests/trainer/test_trainer.py

Lines changed: 10 additions & 7 deletions
@@ -26,7 +26,6 @@
 import cloudpickle
 import pytest
 import torch
-from omegaconf import OmegaConf
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD
 from torch.utils.data import DataLoader, IterableDataset
@@ -51,6 +50,7 @@
 from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
+from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.seed import seed_everything
 from tests.base import EvalModelTemplate
 from tests.helpers import BoringModel, RandomDataset
@@ -59,6 +59,9 @@
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import OmegaConf
+
 
 @pytest.mark.parametrize("url_ckpt", [True, False])
 def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
@@ -1271,12 +1274,12 @@ def __init__(self, **kwargs):
         TrainerSubclass(abcdefg="unknown_arg")
 
 
-@pytest.mark.parametrize(
-    "trainer_params", [OmegaConf.create(dict(max_epochs=1, gpus=1)), OmegaConf.create(dict(max_epochs=1, gpus=[0]))]
-)
-@RunIf(min_gpus=1)
-def test_trainer_omegaconf(trainer_params):
-    Trainer(**trainer_params)
+@RunIf(omegaconf=True)
+@pytest.mark.parametrize("trainer_params", [{"max_epochs": 1, "gpus": 1}, {"max_epochs": 1, "gpus": [0]}])
+@mock.patch("torch.cuda.device_count", return_value=1)
+def test_trainer_omegaconf(_, trainer_params):
+    config = OmegaConf.create(trainer_params)
+    Trainer(**config)
 
 
 def test_trainer_pickle(tmpdir):
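
The rewritten `test_trainer_omegaconf` also drops the hard GPU requirement: `RunIf(min_gpus=1)` is replaced by patching `torch.cuda.device_count`, so constructing a `Trainer` with `gpus=1` can be exercised on CPU-only CI machines. A small sketch of that mocking pattern on its own (whether the patch alone is sufficient for any given Trainer code path is an assumption; the point is only the patching idiom):

    from unittest import mock

    import torch

    # Pretend a single CUDA device is visible, even on a CPU-only machine.
    with mock.patch("torch.cuda.device_count", return_value=1):
        assert torch.cuda.device_count() == 1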
