Skip to content

Commit afffee9

Browse files
authored
Fix for configmixin with explicit classes (#10129)
Fix for configmixin with explicit classes
1 parent 55b3f86 commit afffee9

File tree

3 files changed

+44
-9
lines changed

3 files changed

+44
-9
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
5151

5252
### Fixed
5353

54-
- Fixed `_recursive_config()` for `torch.nn.ModuleList`([#10124](https://github.com/pyg-team/pytorch_geometric/pull/10124))
54+
- Fixed `_recursive_config()` for `torch.nn.ModuleList` and `torch.nn.ModuleDict` ([#10124](https://github.com/pyg-team/pytorch_geometric/pull/10124), [#10129](https://github.com/pyg-team/pytorch_geometric/pull/10129))
5555
- Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756))
5656
- Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766))
5757
- Fixed `WebQSDataset.process` raising exceptions ([#9665](https://github.com/pyg-team/pytorch_geometric/pull/9665))

test/test_config_mixin.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,18 @@ def __init__(self, p: float):
4141

4242
@register(with_target=True)
4343
class CompoundModule(torch.nn.Module, ConfigMixin):
44-
def __init__(self, z: int, module: Module, submodules: list[SubModule]):
44+
def __init__(
45+
self,
46+
z: str,
47+
module: Module,
48+
submodules: list[SubModule],
49+
key_modules: dict[str, torch.nn.Module],
50+
):
4551
super().__init__()
4652
self.z = z
4753
self.module = module
4854
self.submodules = torch.nn.ModuleList(submodules)
55+
self.key_modules = torch.nn.ModuleDict(key_modules)
4956

5057

5158
def test_config_mixin() -> None:
@@ -107,28 +114,45 @@ def test_config_mixin() -> None:
107114

108115

109116
def test_config_mixin_compound() -> None:
110-
submodules = [SubModule(1.41), SubModule(3.14)]
111117
module = Module(x=0, data=Dataclass(x=1, y=2))
112-
model = CompoundModule(z=3, module=module, submodules=submodules)
118+
submodules = [SubModule(1.41), SubModule(3.14)]
119+
key_modules = {
120+
"key1": Module(x=10, data=Dataclass(x=11, y=12)),
121+
"key2": SubModule(2.71),
122+
}
123+
model = CompoundModule(z="foo", module=module, submodules=submodules,
124+
key_modules=key_modules)
125+
113126
cfg = model.config()
114127
assert is_dataclass(cfg)
115128
assert cfg._target_ == 'test_config_mixin.CompoundModule'
116-
assert cfg.z == 3
129+
assert cfg.z == "foo"
117130
assert cfg.module._target_ == 'test_config_mixin.Module'
118131
assert cfg.module.x == 0
119132
assert isinstance(cfg.module.data, Dataclass)
120133
assert cfg.module.data.x == 1
121134
assert cfg.module.data.y == 2
135+
122136
assert len(cfg.submodules) == 2
123137
assert isinstance(cfg.submodules, Sequence)
124138
assert cfg.submodules[0]._target_ == 'test_config_mixin.SubModule'
125139
assert cfg.submodules[0].p == 1.41
126140
assert cfg.submodules[1]._target_ == 'test_config_mixin.SubModule'
127141
assert cfg.submodules[1].p == 3.14
128142

143+
assert len(cfg.key_modules) == 2
144+
assert cfg.key_modules["key1"]._target_ == 'test_config_mixin.Module'
145+
assert cfg.key_modules["key1"].x == 10
146+
assert isinstance(cfg.key_modules["key1"].data, Dataclass)
147+
assert cfg.key_modules["key1"].data.x == 11
148+
assert cfg.key_modules["key1"].data.y == 12
149+
150+
assert cfg.key_modules["key2"]._target_ == 'test_config_mixin.SubModule'
151+
assert cfg.key_modules["key2"].p == 2.71
152+
129153
model = CompoundModule.from_config(cfg)
130154
assert isinstance(model, CompoundModule)
131-
assert model.z == 3
155+
assert model.z == "foo"
132156
assert isinstance(model.module, Module)
133157
assert model.module.x == 0
134158
assert isinstance(model.module.data, Dataclass)
@@ -140,3 +164,12 @@ def test_config_mixin_compound() -> None:
140164
assert model.submodules[0].p == 1.41
141165
assert isinstance(model.submodules[1], SubModule)
142166
assert model.submodules[1].p == 3.14
167+
assert isinstance(model.key_modules, torch.nn.ModuleDict)
168+
assert len(model.key_modules) == 2
169+
assert isinstance(model.key_modules["key1"], Module)
170+
assert model.key_modules["key1"].x == 10
171+
assert isinstance(model.key_modules["key1"].data, Dataclass)
172+
assert model.key_modules["key1"].data.x == 11
173+
assert model.key_modules["key1"].data.y == 12
174+
assert isinstance(model.key_modules["key2"], SubModule)
175+
assert model.key_modules["key2"].p == 2.71

torch_geometric/config_mixin.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import inspect
22
from dataclasses import fields, is_dataclass
33
from importlib import import_module
4-
from typing import Any, Dict, Iterable
4+
from typing import Any, Dict
5+
6+
from torch.nn import ModuleDict, ModuleList
57

68
from torch_geometric.config_store import (
79
class_from_dataclass,
@@ -71,9 +73,9 @@ def _recursive_config(value: Any) -> Any:
7173
return value.config()
7274
if is_torch_instance(value, ConfigMixin):
7375
return value.config()
74-
if isinstance(value, (tuple, list, Iterable)):
76+
if isinstance(value, (tuple, list, ModuleList)):
7577
return [_recursive_config(v) for v in value]
76-
if isinstance(value, dict):
78+
if isinstance(value, (dict, ModuleDict)):
7779
return {k: _recursive_config(v) for k, v in value.items()}
7880
return value
7981

0 commit comments

Comments
 (0)