Commit 856ed10

Improve collision check on hparams between LightningModule and DataModule (#9496)
* fix hyperparameter logging between LightningModule and DataModule
1 parent 1bb5fcc commit 856ed10

File tree: 3 files changed, +96 -8 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -204,6 +204,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Executing the `optimizer_closure` is now required when overriding the `optimizer_step` hook ([#9360](https://github.com/PyTorchLightning/pytorch-lightning/pull/9360))
 
 
+- Changed logging of `LightningModule` and `LightningDataModule` hyperparameters to raise an exception only if there are colliding keys with different values ([#9496](https://github.com/PyTorchLightning/pytorch-lightning/pull/9496))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
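
Illustratively (this example is not part of the commit; the class names and values are hypothetical), the changelog entry above means a LightningModule and a LightningDataModule may now both save the same hyperparameter key, as long as the values agree:

# Hypothetical user code illustrating the new behavior (not from this commit).
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()  # records {"batch_size": 32}


class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.save_hyperparameters()  # also records {"batch_size": 32}


# Previously, logging hyperparameters raised because "batch_size" appears in both.
# Now identical values are merged and logged once; a MisconfigurationException is
# raised only if the values disagree (e.g. batch_size=32 in the model but 64 in the
# datamodule) when Trainer.fit() logs hyperparameters to the logger.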

pytorch_lightning/trainer/trainer.py

Lines changed: 14 additions & 6 deletions

@@ -1036,7 +1036,7 @@ def _pre_dispatch(self):
         self.accelerator.pre_dispatch(self)
         self._log_hyperparams()
 
-    def _log_hyperparams(self):
+    def _log_hyperparams(self) -> None:
         # log hyper-parameters
         hparams_initial = None
 
@@ -1047,12 +1047,20 @@ def _log_hyperparams(self):
         if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
             datamodule_hparams = self.datamodule.hparams_initial
             lightning_hparams = self.lightning_module.hparams_initial
-
-            colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
-            if colliding_keys:
+            inconsistent_keys = []
+            for key in lightning_hparams.keys() & datamodule_hparams.keys():
+                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
+                if type(lm_val) != type(dm_val):
+                    inconsistent_keys.append(key)
+                elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
+                    inconsistent_keys.append(key)
+                elif lm_val != dm_val:
+                    inconsistent_keys.append(key)
+            if inconsistent_keys:
                 raise MisconfigurationException(
-                    f"Error while merging hparams: the keys {colliding_keys} are present "
-                    "in both the LightningModule's and LightningDataModule's hparams."
+                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
+                    "in both the LightningModule's and LightningDataModule's hparams "
+                    "but have different values."
                 )
             hparams_initial = {**lightning_hparams, **datamodule_hparams}
         elif self.lightning_module._log_hyperparams:
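
For reference, here is a minimal standalone sketch of the comparison rule the new check implements; the `find_inconsistent_keys` helper name is made up for illustration and is not part of the Trainer. A shared key only counts as a collision when the values disagree, and tensor values are compared by object identity to avoid ambiguous elementwise equality.

# Standalone sketch of the consistency rule applied above (hypothetical helper,
# not a Trainer method).
from typing import Any, Dict, List

import torch


def find_inconsistent_keys(lightning_hparams: Dict[str, Any], datamodule_hparams: Dict[str, Any]) -> List[Any]:
    inconsistent = []
    for key in lightning_hparams.keys() & datamodule_hparams.keys():
        lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
        if type(lm_val) != type(dm_val):
            # values of different types can never be the same hyperparameter value
            inconsistent.append(key)
        elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
            # tensors are compared by object identity, not elementwise equality,
            # so the `lm_val != dm_val` branch below is never reached for tensors
            inconsistent.append(key)
        elif lm_val != dm_val:
            inconsistent.append(key)
    return inconsistent


# Shared keys with equal values are fine; a mismatch is reported.
assert find_inconsistent_keys({"lr": 0.1, "bs": 32}, {"lr": 0.1, "bs": 32}) == []
assert find_inconsistent_keys({"lr": 0.1}, {"lr": 0.2}) == ["lr"]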

tests/loggers/test_base.py

Lines changed: 79 additions & 2 deletions

@@ -13,17 +13,20 @@
 # limitations under the License.
 import pickle
 from argparse import Namespace
-from typing import Optional
+from copy import deepcopy
+from typing import Any, Dict, Optional
 from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pytest
+import torch
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
 from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
 from pytorch_lightning.utilities import rank_zero_only
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringDataModule, BoringModel
 
 
 def test_logger_collection():
@@ -288,3 +291,77 @@ def __init__(self, param_one, param_two):
         log_hyperparams_mock.assert_called()
     else:
         log_hyperparams_mock.assert_not_called()
+
+
+@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
+def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self, hparams: Dict[str, Any]) -> None:
+            super().__init__()
+            self.save_hyperparameters(hparams)
+
+    class TestDataModule(BoringDataModule):
+        def __init__(self, hparams: Dict[str, Any]) -> None:
+            super().__init__()
+            self.save_hyperparameters(hparams)
+
+    class _Test:
+        ...
+
+    same_params = {1: 1, "2": 2, "three": 3.0, "test": _Test(), "4": torch.tensor(4)}
+    model = TestModel(same_params)
+    dm = TestDataModule(same_params)
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+        checkpoint_callback=False,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    # there should be no exceptions raised for the same key/value pair in the hparams of both
+    # the lightning module and data module
+    trainer.fit(model)
+
+    obj_params = deepcopy(same_params)
+    obj_params["test"] = _Test()
+    model = TestModel(same_params)
+    dm = TestDataModule(obj_params)
+    trainer.fit(model)
+
+    diff_params = deepcopy(same_params)
+    diff_params.update({1: 0, "test": _Test()})
+    model = TestModel(same_params)
+    dm = TestDataModule(diff_params)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+        checkpoint_callback=False,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+        trainer.fit(model, dm)
+
+    tensor_params = deepcopy(same_params)
+    tensor_params.update({"4": torch.tensor(3)})
+    model = TestModel(same_params)
+    dm = TestDataModule(tensor_params)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=0.1,
+        limit_val_batches=0.1,
+        num_sanity_val_steps=0,
+        checkpoint_callback=False,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+    )
+    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
+        trainer.fit(model, dm)
