@@ -661,7 +661,7 @@ def _check_batch_size_seq_length(self, attribute, value):
661
661
)
662
662
663
663
@staticmethod
664
- def dict_to_defaultdict (d : Dict , t : type ) -> DefaultDict :
664
+ def dict_to_trainerdict (d : Dict , t : type ) -> "TrainerSettings.DefaultTrainerDict" :
665
665
return TrainerSettings .DefaultTrainerDict (
666
666
cattr .structure (d , Dict [str , TrainerSettings ])
667
667
)
@@ -718,12 +718,26 @@ def __init__(self, *args):
718
718
super ().__init__ (* args )
719
719
else :
720
720
super ().__init__ (TrainerSettings , * args )
721
+ self ._config_specified = True
722
+
723
+ def set_config_specified (self , require_config_specified : bool ) -> None :
724
+ self ._config_specified = require_config_specified
721
725
722
726
def __missing__ (self , key : Any ) -> "TrainerSettings" :
723
727
if TrainerSettings .default_override is not None :
724
- return copy .deepcopy (TrainerSettings .default_override )
728
+ self [key ] = copy .deepcopy (TrainerSettings .default_override )
729
+ elif self ._config_specified :
730
+ raise TrainerConfigError (
731
+ f"The behavior name { key } has not been specified in the trainer configuration. "
732
+ f"Please add an entry in the configuration file for { key } , or set default_settings."
733
+ )
725
734
else :
726
- return TrainerSettings ()
735
+ logger .warn (
736
+ f"Behavior name { key } does not match any behaviors specified "
737
+ f"in the trainer configuration file. A default configuration will be used."
738
+ )
739
+ self [key ] = TrainerSettings ()
740
+ return self [key ]
727
741
728
742
729
743
# COMMAND LINE #########################################################################
@@ -788,7 +802,7 @@ class TorchSettings:
788
802
@attr .s (auto_attribs = True )
789
803
class RunOptions (ExportableSettings ):
790
804
default_settings : Optional [TrainerSettings ] = None
791
- behaviors : DefaultDict [ str , TrainerSettings ] = attr .ib (
805
+ behaviors : TrainerSettings . DefaultTrainerDict = attr .ib (
792
806
factory = TrainerSettings .DefaultTrainerDict
793
807
)
794
808
env_settings : EnvironmentSettings = attr .ib (factory = EnvironmentSettings )
@@ -800,7 +814,8 @@ class RunOptions(ExportableSettings):
800
814
# These are options that are relevant to the run itself, and not the engine or environment.
801
815
# They will be left here.
802
816
debug : bool = parser .get_default ("debug" )
803
- # Strict conversion
817
+
818
+ # Convert to settings while making sure all fields are valid
804
819
cattr .register_structure_hook (EnvironmentSettings , strict_to_cls )
805
820
cattr .register_structure_hook (EngineSettings , strict_to_cls )
806
821
cattr .register_structure_hook (CheckpointSettings , strict_to_cls )
@@ -816,7 +831,7 @@ class RunOptions(ExportableSettings):
816
831
)
817
832
cattr .register_structure_hook (TrainerSettings , TrainerSettings .structure )
818
833
cattr .register_structure_hook (
819
- DefaultDict [ str , TrainerSettings ] , TrainerSettings .dict_to_defaultdict
834
+ TrainerSettings . DefaultTrainerDict , TrainerSettings .dict_to_trainerdict
820
835
)
821
836
cattr .register_unstructure_hook (collections .defaultdict , defaultdict_to_dict )
822
837
@@ -839,8 +854,12 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
839
854
"engine_settings" : {},
840
855
"torch_settings" : {},
841
856
}
857
+ _require_all_behaviors = True
842
858
if config_path is not None :
843
859
configured_dict .update (load_config (config_path ))
860
+ else :
861
+ # If we're not loading from a file, we don't require all behavior names to be specified.
862
+ _require_all_behaviors = False
844
863
845
864
# Use the YAML file values for all values not specified in the CLI.
846
865
for key in configured_dict .keys ():
@@ -868,6 +887,10 @@ def from_argparse(args: argparse.Namespace) -> "RunOptions":
868
887
configured_dict [key ] = val
869
888
870
889
final_runoptions = RunOptions .from_dict (configured_dict )
890
+ # Need check to bypass type checking but keep structure on dict working
891
+ if isinstance (final_runoptions .behaviors , TrainerSettings .DefaultTrainerDict ):
892
+ # configure whether or not we should require all behavior names to be found in the config YAML
893
+ final_runoptions .behaviors .set_config_specified (_require_all_behaviors )
871
894
return final_runoptions
872
895
873
896
@staticmethod
0 commit comments