Skip to content

Commit 4d9215b

Browse files
bottlerfacebook-github-bot
authored andcommitted
fix to get_default_args(instance)
Summary: Small config system fix. Allows get_default_args to work on an instance which has been created with a dict (instead of a DictConfig) as an args field. E.g. ``` gm = GenericModel( raysampler_AdaptiveRaySampler_args={"scene_extent": 4.0} ) OmegaConf.structured(gm1) ``` Reviewed By: shapovalov Differential Revision: D40341047 fbshipit-source-id: 587d0e8262e271df442a80858949a48e5d6db3df
1 parent 76cddd9 commit 4d9215b

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

pytorch3d/implicitron/tools/config.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -709,8 +709,8 @@ def create_x(self):...
709709
710710
with
711711
712-
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
713-
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
712+
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
713+
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
714714
def create_x(self):
715715
args = self.getattr(f"x_{self.x_class_type}_args")
716716
self.create_x_impl(self.x_class_type, args)
@@ -733,8 +733,8 @@ def create_x(self):...
733733
734734
with
735735
736-
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
737-
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
736+
x_Y_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Y))
737+
x_Z_args: dict = dataclasses.field(default_factory=lambda: get_default_args(Z))
738738
def create_x(self):
739739
if self.x_class_type is None:
740740
args = None
@@ -764,7 +764,7 @@ def create_x(self):...
764764
765765
will be replaced with
766766
767-
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
767+
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
768768
def create_x(self):
769769
self.create_x_impl(True, self.x_args)
770770
@@ -786,7 +786,7 @@ def create_x(self):...
786786
787787
with
788788
789-
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
789+
x_args: dict = dataclasses.field(default_factory=lambda: get_default_args(X))
790790
x_enabled: bool = False
791791
def create_x(self):
792792
self.create_x_impl(self.x_enabled, self.x_args)
@@ -818,6 +818,11 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
818818
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
819819
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
820820
821+
Note that although the *_args members are intended to have type DictConfig, they
822+
are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig
823+
in place of a dict, but not vice-versa. Allowing dict lets a class user specify
824+
x_args as an explicit dict without getting an incomprehensible error.
825+
821826
Args:
822827
some_class: the class to be processed
823828
_do_not_process: Internal use for get_default_args: Because get_default_args calls
@@ -1040,7 +1045,7 @@ def _process_member(
10401045
raise ValueError(
10411046
f"Cannot generate {args_name} because it is already present."
10421047
)
1043-
some_class.__annotations__[args_name] = DictConfig
1048+
some_class.__annotations__[args_name] = dict
10441049
if hook is not None:
10451050
hook_closed = partial(hook, derived_type)
10461051
else:
@@ -1064,7 +1069,7 @@ def _process_member(
10641069
if issubclass(type_, some_class) or type_ in _do_not_process:
10651070
raise ValueError(f"Cannot process {type_} inside {some_class}")
10661071

1067-
some_class.__annotations__[args_name] = DictConfig
1072+
some_class.__annotations__[args_name] = dict
10681073
if hook is not None:
10691074
hook_closed = partial(hook, type_)
10701075
else:

tests/implicitron/test_config.py

+28
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,34 @@ class MainTestWrapper(Configurable):
687687
remove_unused_components(args)
688688
self.assertEqual(OmegaConf.to_yaml(args), "mt_enabled: false\n")
689689

690+
def test_get_instance_args(self):
691+
mt1, mt2 = [
692+
MainTest(
693+
n_ids=0,
694+
n_reps=909,
695+
the_fruit_class_type="Pear",
696+
the_second_fruit_class_type="Pear",
697+
the_fruit_Pear_args=DictConfig({}),
698+
the_second_fruit_Pear_args={},
699+
)
700+
for _ in range(2)
701+
]
702+
# Two equivalent ways to get the DictConfig back out of an instance.
703+
cfg1 = OmegaConf.structured(mt1)
704+
cfg2 = get_default_args(mt2)
705+
self.assertEqual(cfg1, cfg2)
706+
self.assertEqual(len(cfg1.the_second_fruit_Pear_args), 0)
707+
self.assertEqual(len(mt2.the_second_fruit_Pear_args), 0)
708+
709+
from_cfg = MainTest(**cfg2)
710+
self.assertEqual(len(from_cfg.the_second_fruit_Pear_args), 0)
711+
712+
# If you want the complete args, merge with the defaults.
713+
merged_args = OmegaConf.merge(get_default_args(MainTest), cfg2)
714+
from_merged = MainTest(**merged_args)
715+
self.assertEqual(len(from_merged.the_second_fruit_Pear_args), 1)
716+
self.assertEqual(from_merged.n_reps, 909)
717+
690718
def test_tweak_hook(self):
691719
class A(Configurable):
692720
n: int = 9

0 commit comments

Comments
 (0)