@@ -560,6 +560,7 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[
 class BoringModelRequiredClasses(BoringModel):
     def __init__(self, num_classes: int, batch_size: int = 8):
         super().__init__()
+        self.save_hyperparameters()
         self.num_classes = num_classes
         self.batch_size = batch_size
 
@@ -577,29 +578,97 @@ def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.batch_size")
             parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
 
-    cli_args = ["--data.batch_size=12"]
+    cli_args = ["--data.batch_size=12", "--trainer.max_epochs=1"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)
 
     assert cli.model.batch_size == 12
     assert cli.model.num_classes == 5
 
-    class MyLightningCLI(LightningCLI):
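+    # run a short fit so the logger writes hparams.yaml, then check that the linked values were saved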
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 12, "num_classes": 5}
+
+    class MyLightningCLI2(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.init_args.batch_size")
             parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
 
-    cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
+    cli_args[0] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
-        cli = MyLightningCLI(
+        cli = MyLightningCLI2(
            BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
         )
 
     assert cli.model.batch_size == 8
     assert cli.model.num_classes == 5
 
+    cli.trainer.fit(cli.model)
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    hparams.pop("_instantiator")
+    assert hparams == {"batch_size": 8, "num_classes": 5}
+
+
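+# torch.optim.Adam subclass that accepts an extra num_classes argument, used below as the target of a nested link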
+class CustomAdam(torch.optim.Adam):
+    def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
+        super().__init__(params, **kwargs)
+
+
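+# model that takes an optimizer callable, so the link target lives under model.init_args.optimizer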
+class DeepLinkTargetModel(BoringModel):
+    def __init__(
+        self,
+        optimizer: OptimizerCallable = torch.optim.Adam,
+    ):
+        super().__init__()
+        self.save_hyperparameters()
+        self.optimizer = optimizer
+
+    def configure_optimizers(self):
+        optimizer = self.optimizer(self.parameters())
+        return {"optimizer": optimizer}
+
+
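+# a link whose target is nested in the optimizer's init_args should be applied on instantiation and saved to hparams.yaml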
+def test_lightning_cli_link_arguments_subcommands_nested_target(cleandir):
+    class MyLightningCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments(
+                "data.num_classes",
+                "model.init_args.optimizer.init_args.num_classes",
+                apply_on="instantiate",
+            )
+
+    cli_args = [
+        "fit",
+        "--data.batch_size=12",
+        "--trainer.max_epochs=1",
+        "--model=tests_pytorch.test_cli.DeepLinkTargetModel",
+        "--model.optimizer=tests_pytorch.test_cli.CustomAdam",
+    ]
+
+    with mock.patch("sys.argv", ["any.py"] + cli_args):
+        cli = MyLightningCLI(
+            DeepLinkTargetModel,
+            BoringDataModuleBatchSizeAndClasses,
+            subclass_mode_model=True,
+            auto_configure_optimizers=False,
+        )
+
+    hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_path.is_file()
+    hparams = yaml.safe_load(hparams_path.read_text())
+
+    assert hparams["optimizer"]["class_path"] == "tests_pytorch.test_cli.CustomAdam"
+    assert hparams["optimizer"]["init_args"]["num_classes"] == 5
+
 
 
 class EarlyExitTestModel(BoringModel):
     def on_fit_start(self):