@@ -13,17 +13,20 @@
 # limitations under the License.
 import pickle
 from argparse import Namespace
-from typing import Optional
+from copy import deepcopy
+from typing import Any, Dict, Optional
 from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pytest
+import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
 from pytorch_lightning.utilities import rank_zero_only
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringDataModule, BoringModel
 
 
 def test_logger_collection():
@@ -288,3 +291,77 @@ def __init__(self, param_one, param_two):
         log_hyperparams_mock.assert_called()
     else:
         log_hyperparams_mock.assert_not_called()
+
+
+@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
+def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self, hparams: Dict[str, Any]) -> None:
+            super().__init__()
+            self.save_hyperparameters(hparams)
+
+    class TestDataModule(BoringDataModule):
+        def __init__(self, hparams: Dict[str, Any]) -> None:
+            super().__init__()
+            self.save_hyperparameters(hparams)
+
+    class _Test:
+        ...
+
+    same_params = {1: 1, "2": 2, "three": 3.0, "test": _Test(), "4": torch.tensor(4)}
+    model = TestModel(same_params)
+    dm = TestDataModule(same_params)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+        checkpoint_callback=False,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    # there should be no exceptions raised for the same key/value pair in the hparams of both
+    # the lightning module and data module
+    trainer.fit(model, dm)
+
+    obj_params = deepcopy(same_params)
+    obj_params["test"] = _Test()
+    model = TestModel(same_params)
+    dm = TestDataModule(obj_params)
+    trainer.fit(model, dm)  # distinct object instances of the same type should still merge cleanly
+
+    diff_params = deepcopy(same_params)
+    diff_params.update({1: 0, "test": _Test()})
+    model = TestModel(same_params)
+    dm = TestDataModule(diff_params)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+        checkpoint_callback=False,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+        trainer.fit(model, dm)
+
+    tensor_params = deepcopy(same_params)
+    tensor_params.update({"4": torch.tensor(3)})
+    model = TestModel(same_params)
+    dm = TestDataModule(tensor_params)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+        checkpoint_callback=False,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+        trainer.fit(model, dm)
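
Note: the merge logic that raises this MisconfigurationException is not part of this diff. Based only on what the assertions above require — identical values merge cleanly, differing scalar or tensor values raise, and arbitrary objects such as `_Test` are tolerated — a minimal sketch of such a check might look like the following. The helper name `_merge_hparams` and the primitive-type filter are illustrative assumptions, not the actual Trainer internals.

    import torch

    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # only values of these types are compared; anything else (e.g. arbitrary
    # objects like `_Test`) is skipped entirely
    _COMPARABLE_TYPES = (bool, int, float, str, torch.Tensor)


    def _merge_hparams(lightning_hparams: dict, datamodule_hparams: dict) -> dict:
        """Hypothetical sketch: merge two hparams dicts, raising on conflicting values."""
        inconsistent_keys = []
        for key in lightning_hparams.keys() & datamodule_hparams.keys():
            lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
            if not isinstance(lm_val, _COMPARABLE_TYPES) or not isinstance(dm_val, _COMPARABLE_TYPES):
                # distinct object instances of the same type are not treated as collisions
                continue
            if type(lm_val) is not type(dm_val):
                inconsistent_keys.append(key)
            elif isinstance(lm_val, torch.Tensor):
                # `==` on tensors is elementwise, so compare whole tensors with torch.equal
                if not torch.equal(lm_val, dm_val):
                    inconsistent_keys.append(key)
            elif lm_val != dm_val:
                inconsistent_keys.append(key)
        if inconsistent_keys:
            raise MisconfigurationException(
                f"Error while merging hparams: the keys {inconsistent_keys} are present "
                "in both the LightningModule's and LightningDataModule's hparams "
                "but have different values."
            )
        return {**lightning_hparams, **datamodule_hparams}

Under these assumptions, `_merge_hparams(model.hparams_initial, dm.hparams_initial)` would reproduce the pass/raise pattern that the four `fit` calls above assert: the `same_params` and `obj_params` cases merge without error, while the `diff_params` and `tensor_params` cases raise.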