Skip to content

Commit f2b8f3d

Browse files
Tinchematmel
authored andcommitted
Tweak the include_subclasses strategy
1 parent d2654da commit f2b8f3d

File tree

2 files changed

+75
-62
lines changed

2 files changed

+75
-62
lines changed

src/cattrs/strategies/_subclasses.py

+66-24
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""Strategies for customizing subclass behaviors."""
22
from gc import collect
3-
from typing import Dict, Optional, Tuple, Type, Union, List, Callable, Any, get_args
3+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, get_args
44

5-
from ..converters import Converter, BaseConverter
6-
from ..gen import AttributeOverride, make_dict_structure_fn, make_dict_unstructure_fn
5+
from ..converters import BaseConverter, Converter
6+
from ..gen import (
7+
AttributeOverride,
8+
_already_generating,
9+
make_dict_structure_fn,
10+
make_dict_unstructure_fn,
11+
)
712

813

914
def _make_subclasses_tree(cl: Type) -> List[Type]:
@@ -12,7 +17,8 @@ def _make_subclasses_tree(cl: Type) -> List[Type]:
1217
]
1318

1419

15-
def _has_subclasses(cl: Type, given_subclasses: Tuple[Type]):
20+
def _has_subclasses(cl: Type, given_subclasses: Tuple[Type, ...]) -> bool:
21+
"""Whether the given class has subclasses from `given_subclasses`."""
1622
actual = set(cl.__subclasses__())
1723
given = set(given_subclasses)
1824
return bool(actual & given)
@@ -31,7 +37,7 @@ def _get_union_type(cl: Type, given_subclasses_tree: Tuple[Type]) -> Optional[Ty
3137
def include_subclasses(
3238
cl: Type,
3339
converter: Converter,
34-
subclasses: Optional[Tuple[Type]] = None,
40+
subclasses: Optional[Tuple[Type, ...]] = None,
3541
union_strategy: Optional[Callable[[Any, BaseConverter], Any]] = None,
3642
overrides: Optional[Dict[str, AttributeOverride]] = None,
3743
) -> None:
@@ -139,40 +145,76 @@ def unstruct_hook(
139145

140146
def _include_subclasses_with_union_strategy(
141147
converter: Converter,
142-
union_classes: Tuple[Type],
148+
union_classes: Tuple[Type, ...],
143149
union_strategy: Callable[[Any, BaseConverter], Any],
144150
overrides: Dict[str, AttributeOverride],
145151
):
152+
"""
153+
This function is tricky because we're dealing with what is essentially a circular reference.
154+
155+
We need to generate a structure hook for a class that is both:
156+
* specific for that particular class and its own fields
157+
* but should handle specific functions for all its descendants too
158+
159+
Hence the dance with registering below.
160+
"""
161+
146162
parent_classes = [cl for cl in union_classes if _has_subclasses(cl, union_classes)]
147163
if not parent_classes:
148164
return
149165

166+
original_unstruct_hooks = {}
167+
original_struct_hooks = {}
150168
for cl in union_classes:
169+
# In the first pass, every class gets its own unstructure function according to
170+
# the overrides.
171+
# We just generate the hooks, and do not register them. This allows us to manipulate
172+
# the _already_generating set to force runtime dispatch.
173+
_already_generating.working_set = set(union_classes) - {cl}
174+
try:
175+
unstruct_hook = make_dict_unstructure_fn(cl, converter, **overrides)
176+
struct_hook = make_dict_structure_fn(cl, converter, **overrides)
177+
finally:
178+
_already_generating.working_set = set()
179+
original_unstruct_hooks[cl] = unstruct_hook
180+
original_struct_hooks[cl] = struct_hook
181+
182+
# Now that's done, we can register all the hooks and generate the
183+
# union handler. The union handler needs them.
184+
final_union = Union[union_classes] # type: ignore
185+
186+
for cl, hook in original_unstruct_hooks.items():
151187

152188
def cls_is_cl(cls, _cl=cl):
153189
return cls is _cl
154190

155-
converter.register_structure_hook_func(
156-
cls_is_cl, make_dict_structure_fn(cl, converter, **overrides)
157-
)
158-
converter.register_unstructure_hook_func(
159-
cls_is_cl, make_dict_unstructure_fn(cl, converter, **overrides)
160-
)
191+
converter.register_unstructure_hook_func(cls_is_cl, hook)
161192

162-
for cl in parent_classes:
163-
subclass_union = _get_union_type(cl, union_classes)
164-
sub_union_classes = get_args(subclass_union)
165-
union_strategy(subclass_union, converter)
166-
struct_hook = converter._union_struct_registry[subclass_union]
167-
unstruct_hook = converter._unstructure_func.dispatch(subclass_union)
193+
for cl, hook in original_struct_hooks.items():
168194

169195
def cls_is_cl(cls, _cl=cl):
170196
return cls is _cl
171197

172-
def cls_is_in_union(cls, _union_classes=sub_union_classes):
173-
return cls in _union_classes
198+
converter.register_structure_hook_func(cls_is_cl, hook)
174199

175-
# This needs to use function dispatch, using singledispatch will again
176-
# match A and all subclasses, which is not what we want.
177-
converter.register_structure_hook_func(cls_is_cl, struct_hook)
178-
converter.register_unstructure_hook_func(cls_is_in_union, unstruct_hook)
200+
union_strategy(final_union, converter)
201+
unstruct_hook = converter._unstructure_func.dispatch(final_union)
202+
struct_hook = converter._structure_func.dispatch(final_union)
203+
204+
for cl in union_classes:
205+
# In the second pass, we overwrite the hooks with the union hook.
206+
207+
def cls_is_cl(cls, _cl=cl):
208+
return cls is _cl
209+
210+
converter.register_unstructure_hook_func(cls_is_cl, unstruct_hook)
211+
subclasses = tuple([c for c in union_classes if issubclass(c, cl)])
212+
if len(subclasses) > 1:
213+
u = Union[subclasses] # type: ignore
214+
union_strategy(u, converter)
215+
struct_hook = converter._structure_func.dispatch(u)
216+
217+
def sh(payload: dict, _, _u=u, _s=struct_hook) -> cl:
218+
return _s(payload, _u)
219+
220+
converter.register_structure_hook_func(cls_is_cl, sh)

tests/strategies/test_include_subclasses.py

+9-38
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import typing
2-
from functools import partial
32
from copy import deepcopy
3+
from functools import partial
4+
from typing import Tuple
45

56
import attr
67
import pytest
78

89
from cattrs import Converter, override
910
from cattrs.errors import ClassValidationError
10-
from cattrs.strategies._subclasses import _make_subclasses_tree
1111
from cattrs.strategies import configure_tagged_union, include_subclasses
1212

1313

@@ -148,8 +148,8 @@ def conv_w_subclasses(request):
148148
"struct_unstruct", IDS_TO_STRUCT_UNSTRUCT.values(), ids=IDS_TO_STRUCT_UNSTRUCT
149149
)
150150
def test_structuring_with_inheritance(
151-
conv_w_subclasses: typing.Tuple[Converter, bool], struct_unstruct
152-
):
151+
conv_w_subclasses: Tuple[Converter, bool], struct_unstruct
152+
) -> None:
153153
structured, unstructured = struct_unstruct
154154

155155
converter, included_subclasses_param = conv_w_subclasses
@@ -176,7 +176,7 @@ def test_structuring_with_inheritance(
176176

177177
if structured.__class__ in {Parent, Child1, Child2}:
178178
with pytest.raises(ClassValidationError):
179-
converter.structure(unstructured, GrandChild)
179+
_ = converter.structure(unstructured, GrandChild)
180180

181181

182182
def test_structure_as_union():
@@ -201,14 +201,11 @@ def test_circular_reference(conv_w_subclasses):
201201
if included_subclasses_param != "with-subclasses-and-tagged-union":
202202
unstruct = _remove_type_name(unstruct)
203203

204-
if included_subclasses_param == "wo-subclasses":
205-
# We already now that it will fail
206-
return
204+
if "wo-subclasses" in included_subclasses_param:
205+
pytest.xfail("Cannot succeed if include_subclasses strategy is not used")
207206

208207
res = c.unstructure(struct)
209-
if "wo-subclasses" or "tagged-union" in included_subclasses_param:
210-
# TODO: tagged-union should work here, but it does not yet.
211-
pytest.xfail("Cannot succeed if include_subclasses strategy is not used")
208+
212209
assert res == unstruct
213210

214211
res = c.unstructure(struct, CircularA)
@@ -222,7 +219,7 @@ def test_circular_reference(conv_w_subclasses):
222219
"struct_unstruct", IDS_TO_STRUCT_UNSTRUCT.values(), ids=IDS_TO_STRUCT_UNSTRUCT
223220
)
224221
def test_unstructuring_with_inheritance(
225-
conv_w_subclasses: typing.Tuple[Converter, bool], struct_unstruct
222+
conv_w_subclasses: Tuple[Converter, bool], struct_unstruct
226223
):
227224
structured, unstructured = struct_unstruct
228225
converter, included_subclasses_param = conv_w_subclasses
@@ -326,29 +323,3 @@ def test_overrides(with_union_strategy: bool, struct_unstruct: str):
326323
assert c.unstructure(structured) == unstructured
327324
assert c.structure(unstructured, Parent) == structured
328325
assert c.structure(unstructured, structured.__class__) == structured
329-
330-
331-
def test_class_tree_generator():
332-
class P:
333-
pass
334-
335-
class C1(P):
336-
pass
337-
338-
class C2(P):
339-
pass
340-
341-
class GC1(C2):
342-
pass
343-
344-
class GC2(C2):
345-
pass
346-
347-
tree_c1 = _make_subclasses_tree(C1)
348-
assert tree_c1 == [C1]
349-
350-
tree_c2 = _make_subclasses_tree(C2)
351-
assert tree_c2 == [C2, GC1, GC2]
352-
353-
tree_p = _make_subclasses_tree(P)
354-
assert tree_p == [P, C1, C2, GC1, GC2]

0 commit comments

Comments
 (0)