
Commit 552f12c (parent 6da480d)

Instantiator receives values applied by instantiation links to set in hparams (Lightning-AI#20311).
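For context, this is the pattern the fix targets, condensed from the new tests below (the CLI subclass is illustrative):

    from lightning.pytorch.cli import LightningCLI

    class MyCLI(LightningCLI):  # illustrative subclass, condensed from the tests
        def add_arguments_to_parser(self, parser):
            # num_classes is only known once the datamodule is instantiated,
            # so the link is applied on instantiation rather than on parse.
            parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

    # Previously, a model calling self.save_hyperparameters() in __init__ did not
    # get the linked num_classes recorded in hparams; with this commit the
    # instantiator receives the applied link values and sets them in hparams.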

4 files changed: +112 −9 lines

requirements/pytorch/extra.txt

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 matplotlib>3.1, <3.9.0
 omegaconf >=2.2.3, <2.4.0
 hydra-core >=1.2.0, <1.4.0
-jsonargparse[signatures] >=4.27.7, <=4.35.0
+jsonargparse[signatures] >=4.39.0, <4.40.0
 rich >=12.3.0, <13.6.0
 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
 bitsandbytes >=0.45.2,<0.45.3; platform_system != "Darwin"

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
--
+- Fixed `save_hyperparameters` not working correctly with `LightningCLI` when there are parsing links applied on instantiation ([#???](https://github.com/Lightning-AI/pytorch-lightning/pull/???))
 
 
 ---

src/lightning/pytorch/cli.py

Lines changed: 37 additions & 3 deletions
@@ -320,6 +320,7 @@ def __init__(
         args: ArgsType = None,
         run: bool = True,
         auto_configure_optimizers: bool = True,
+        load_from_checkpoint_support: bool = True,
     ) -> None:
         """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
         called / instantiated using a parsed configuration file and / or command line args.
@@ -360,6 +361,11 @@ def __init__(
                 ``dict`` or ``jsonargparse.Namespace``.
             run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
                 method. If set to ``False``, the trainer and model classes will be instantiated only.
+            auto_configure_optimizers: Whether to automatically add default optimizer and lr_scheduler arguments.
+            load_from_checkpoint_support: Whether ``save_hyperparameters`` should save the original parsed
+                hyperparameters (instead of what ``__init__`` receives), such that it is possible for
+                ``load_from_checkpoint`` to correctly instantiate classes even when using complex nesting and
+                dependency injection.
 
         """
         self.save_config_callback = save_config_callback
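A hedged usage sketch of the new flag (the model class here is a placeholder, not from the commit):

    from lightning.pytorch import LightningModule
    from lightning.pytorch.cli import LightningCLI

    class MyModel(LightningModule):  # placeholder model for illustration
        ...

    # Opt out: no instantiators are registered, so save_hyperparameters records
    # exactly what __init__ receives rather than the original parsed config.
    cli = LightningCLI(MyModel, load_from_checkpoint_support=False)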
@@ -389,7 +395,8 @@ def __init__(
 
         self._set_seed()
 
-        self._add_instantiators()
+        if load_from_checkpoint_support:
+            self._add_instantiators()
         self.before_instantiate_classes()
         self.instantiate_classes()
         self.after_instantiate_classes()
@@ -537,11 +544,14 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
         else:
             self.config = parser.parse_args(args)
 
-    def _add_instantiators(self) -> None:
+    def _dump_config(self) -> None:
+        if hasattr(self, "config_dump"):
+            return
         self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
         if "subcommand" in self.config:
             self.config_dump = self.config_dump[self.config.subcommand]
 
+    def _add_instantiators(self) -> None:
         self.parser.add_instantiator(
             _InstantiatorFn(cli=self, key="model"),
             _get_module_type(self._model_class),
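For intuition, `_dump_config` caches the parsed config as plain dicts the first time an instantiator runs; a plausible `config_dump` for the subclass-mode test below might look like this (illustrative values, trimmed):

    config_dump = {
        "model": {
            "class_path": "tests_pytorch.test_cli.BoringModelRequiredClasses",
            "init_args": {"num_classes": 5, "batch_size": 12},
        },
        # "data" and "trainer" entries follow the same parsed structure
    }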
@@ -792,12 +802,27 @@ def _get_module_type(value: Union[Callable, type]) -> type:
     return value
 
 
+def _set_dict_nested(data: dict, key: str, value: Any) -> None:
+    keys = key.split(".")
+    for k in keys[:-1]:
+        assert k in data, f"Expected key {key} to be in data"
+        data = data[k]
+    data[keys[-1]] = value
+
+
 class _InstantiatorFn:
     def __init__(self, cli: LightningCLI, key: str) -> None:
         self.cli = cli
         self.key = key
 
-    def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
+    def __call__(
+        self,
+        class_type: type[ModuleType],
+        *args: Any,
+        applied_instantiation_links: dict,
+        **kwargs: Any,
+    ) -> ModuleType:
+        self.cli._dump_config()
         hparams = self.cli.config_dump.get(self.key, {})
         if "class_path" in hparams:
             # To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the
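A brief usage sketch of the `_set_dict_nested` helper introduced above: it walks a dotted key through nested dicts and writes the leaf value, asserting each intermediate key already exists:

    hparams = {"optimizer": {"init_args": {"lr": 1e-3}}}
    _set_dict_nested(hparams, "optimizer.init_args.num_classes", 5)
    assert hparams == {"optimizer": {"init_args": {"lr": 1e-3, "num_classes": 5}}}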
@@ -808,6 +833,15 @@ def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
             **hparams.get("init_args", {}),
             **hparams.get("dict_kwargs", {}),
         }
+        # get instantiation link target values from kwargs
+        for key, value in applied_instantiation_links.items():
+            if not key.startswith(f"{self.key}."):
+                continue
+            key = key[len(f"{self.key}.") :]
+            if key.startswith("init_args."):
+                key = key[len("init_args.") :]
+            _set_dict_nested(hparams, key, value)
+
         with _given_hyperparameters_context(
             hparams=hparams,
             instantiator="lightning.pytorch.cli.instantiate_module",
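To make the prefix stripping concrete, here is a standalone trace for `self.key == "model"`, mirroring the nested-target test added below (values illustrative):

    link_key = "model.init_args.optimizer.init_args.num_classes"
    link_key = link_key[len("model."):]          # "init_args.optimizer.init_args.num_classes"
    if link_key.startswith("init_args."):
        link_key = link_key[len("init_args."):]  # "optimizer.init_args.num_classes"
    # _set_dict_nested(hparams, link_key, 5) then yields
    # hparams["optimizer"]["init_args"]["num_classes"] == 5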

tests/tests_pytorch/test_cli.py

Lines changed: 73 additions & 4 deletions
@@ -560,6 +560,7 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[
 class BoringModelRequiredClasses(BoringModel):
     def __init__(self, num_classes: int, batch_size: int = 8):
         super().__init__()
+        self.save_hyperparameters()
         self.num_classes = num_classes
         self.batch_size = batch_size
 
@@ -577,29 +578,97 @@ def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.batch_size")
             parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
 
-    cli_args = ["--data.batch_size=12"]
+    cli_args = ["--data.batch_size=12", "--trainer.max_epochs=1"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)
 
     assert cli.model.batch_size == 12
     assert cli.model.num_classes == 5
 
-    class MyLightningCLI(LightningCLI):
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 12, "num_classes": 5}
+
+    class MyLightningCLI2(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.init_args.batch_size")
             parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
 
-    cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
+    cli_args[0] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
-        cli = MyLightningCLI(
+        cli = MyLightningCLI2(
             BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
         )
 
     assert cli.model.batch_size == 8
     assert cli.model.num_classes == 5
 
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 8, "num_classes": 5}
+
+
+class CustomAdam(torch.optim.Adam):
+    def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
+        super().__init__(params, **kwargs)
+
+
+class DeepLinkTargetModel(BoringModel):
+    def __init__(
+        self,
+        optimizer: OptimizerCallable = torch.optim.Adam,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.optimizer = optimizer
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer(self.parameters())
+        return {"optimizer": optimizer}
+
+
+def test_lightning_cli_link_arguments_subcommands_nested_target(cleandir):
+    class MyLightningCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments(
+                "data.num_classes",
+                "model.init_args.optimizer.init_args.num_classes",
+                apply_on="instantiate",
+            )
+
+    cli_args = [
+        "fit",
+        "--data.batch_size=12",
+        "--trainer.max_epochs=1",
+        "--model=tests_pytorch.test_cli.DeepLinkTargetModel",
+        "--model.optimizer=tests_pytorch.test_cli.CustomAdam",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = MyLightningCLI(
+            DeepLinkTargetModel,
+            BoringDataModuleBatchSizeAndClasses,
+            subclass_mode_model=True,
+            auto_configure_optimizers=False,
+        )
+
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    assert hparams["optimizer"]["class_path"] == "tests_pytorch.test_cli.CustomAdam"
+    assert hparams["optimizer"]["init_args"]["num_classes"] == 5
+
 
 class EarlyExitTestModel(BoringModel):
     def on_fit_start(self):
