
Improve collision check on hparams between LightningModule and DataModule #9492


Closed
ananthsub opened this issue Sep 13, 2021 · 2 comments · Fixed by #9496
Labels: feature (Is an improvement or enhancement), good first issue (Good for newcomers), help wanted (Open to be worked on), let's do it! (approved to implement)

Comments

ananthsub (Contributor) commented Sep 13, 2021

🚀 Feature

Motivation

With the recently added ability to log hyperparameters from the DataModule, an exception was introduced for the case where keys overlap between the LightningModule and the DataModule:
https://github.com/PyTorchLightning/pytorch-lightning/blob/ec828b826717cd3b5beabcb6d0cacf41b2320a98/pytorch_lightning/trainer/trainer.py#L1043-L1053

However, this check is overly strict: if the same hparams are shared across the LightningModule and DataModule, an error is raised even when their values agree.
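
As a minimal illustration (the class names here are hypothetical), both components below save a batch_size hparam with the same value, yet the current check fails purely because the key appears in both:

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self, batch_size: int = 32, lr: float = 1e-3):
        super().__init__()
        # hparams_initial will contain batch_size and lr
        self.save_hyperparameters()

class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        # hparams_initial will also contain batch_size, with the same value
        self.save_hyperparameters()

# Logging hyperparameters with MyModel + MyDataModule currently raises a
# MisconfigurationException, even though batch_size agrees in both.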

Pitch

Only raise an exception if there are overlapping keys whose values differ between the LightningModule and the DataModule:

datamodule_hparams = self.datamodule.hparams_initial
lightning_hparams = self.lightning_module.hparams_initial
colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
if colliding_keys:
    # Only treat a collision as an error when the values actually differ.
    inconsistent_keys = [
        key for key in colliding_keys
        if lightning_hparams[key] != datamodule_hparams[key]
    ]
    if inconsistent_keys:
        raise MisconfigurationException(
            f"Error while merging hparams: the keys {inconsistent_keys} are present "
            "in both the LightningModule's and LightningDataModule's hparams and have different values."
        )
hparams_initial = {**lightning_hparams, **datamodule_hparams}
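
Note that with this check in place, the order of the final merge no longer matters: any key present in both mappings is guaranteed to hold the same value, so letting the datamodule's entries overwrite the lightning module's is harmless.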

Alternatives

Additional context



ananthsub added the feature, help wanted, and good first issue labels on Sep 13, 2021
tchaton (Contributor) commented Sep 13, 2021

Good catch! Sounds like a good idea.

ananthsub (Contributor, Author) commented Sep 15, 2021

The current workaround is to call self.save_hyperparameters(..., logger=False) in one of the components, so that its hparams are excluded from logging and the collision is avoided.
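
For example, a minimal sketch of this workaround on the datamodule side (the class is hypothetical):

import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        # logger=False still stores the hparams on the datamodule but keeps
        # them out of the set sent to the logger, so no collision is checked.
        self.save_hyperparameters(logger=False)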
