1
1
"""Strategies for customizing subclass behaviors."""
2
2
from gc import collect
3
- from typing import Dict , Optional , Tuple , Type , Union , List , Callable , Any , get_args
3
+ from typing import Any , Callable , Dict , List , Optional , Tuple , Type , Union , get_args
4
4
5
- from ..converters import Converter , BaseConverter
6
- from ..gen import AttributeOverride , make_dict_structure_fn , make_dict_unstructure_fn
5
+ from ..converters import BaseConverter , Converter
6
+ from ..gen import (
7
+ AttributeOverride ,
8
+ _already_generating ,
9
+ make_dict_structure_fn ,
10
+ make_dict_unstructure_fn ,
11
+ )
7
12
8
13
9
14
def _make_subclasses_tree (cl : Type ) -> List [Type ]:
@@ -12,7 +17,8 @@ def _make_subclasses_tree(cl: Type) -> List[Type]:
12
17
]
13
18
14
19
15
- def _has_subclasses (cl : Type , given_subclasses : Tuple [Type ]):
20
+ def _has_subclasses (cl : Type , given_subclasses : Tuple [Type , ...]) -> bool :
21
+ """Whether the given class has subclasses from `given_subclasses`."""
16
22
actual = set (cl .__subclasses__ ())
17
23
given = set (given_subclasses )
18
24
return bool (actual & given )
@@ -31,7 +37,7 @@ def _get_union_type(cl: Type, given_subclasses_tree: Tuple[Type]) -> Optional[Ty
31
37
def include_subclasses (
32
38
cl : Type ,
33
39
converter : Converter ,
34
- subclasses : Optional [Tuple [Type ]] = None ,
40
+ subclasses : Optional [Tuple [Type , ... ]] = None ,
35
41
union_strategy : Optional [Callable [[Any , BaseConverter ], Any ]] = None ,
36
42
overrides : Optional [Dict [str , AttributeOverride ]] = None ,
37
43
) -> None :
@@ -139,40 +145,76 @@ def unstruct_hook(
139
145
140
146
def _include_subclasses_with_union_strategy (
141
147
converter : Converter ,
142
- union_classes : Tuple [Type ],
148
+ union_classes : Tuple [Type , ... ],
143
149
union_strategy : Callable [[Any , BaseConverter ], Any ],
144
150
overrides : Dict [str , AttributeOverride ],
145
151
):
152
+ """
153
+ This function is tricky because we're dealing with what is essentially a circular reference.
154
+
155
+ We need to generate a structure hook for a class that is both:
156
+ * specific for that particular class and its own fields
157
+ * but should handle specific functions for all its descendants too
158
+
159
+ Hence the dance with registering below.
160
+ """
161
+
146
162
parent_classes = [cl for cl in union_classes if _has_subclasses (cl , union_classes )]
147
163
if not parent_classes :
148
164
return
149
165
166
+ original_unstruct_hooks = {}
167
+ original_struct_hooks = {}
150
168
for cl in union_classes :
169
+ # In the first pass, every class gets its own unstructure function according to
170
+ # the overrides.
171
+ # We just generate the hooks, and do not register them. This allows us to manipulate
172
+ # the _already_generating set to force runtime dispatch.
173
+ _already_generating .working_set = set (union_classes ) - {cl }
174
+ try :
175
+ unstruct_hook = make_dict_unstructure_fn (cl , converter , ** overrides )
176
+ struct_hook = make_dict_structure_fn (cl , converter , ** overrides )
177
+ finally :
178
+ _already_generating .working_set = set ()
179
+ original_unstruct_hooks [cl ] = unstruct_hook
180
+ original_struct_hooks [cl ] = struct_hook
181
+
182
+ # Now that's done, we can register all the hooks and generate the
183
+ # union handler. The union handler needs them.
184
+ final_union = Union [union_classes ] # type: ignore
185
+
186
+ for cl , hook in original_unstruct_hooks .items ():
151
187
152
188
def cls_is_cl (cls , _cl = cl ):
153
189
return cls is _cl
154
190
155
- converter .register_structure_hook_func (
156
- cls_is_cl , make_dict_structure_fn (cl , converter , ** overrides )
157
- )
158
- converter .register_unstructure_hook_func (
159
- cls_is_cl , make_dict_unstructure_fn (cl , converter , ** overrides )
160
- )
191
+ converter .register_unstructure_hook_func (cls_is_cl , hook )
161
192
162
- for cl in parent_classes :
163
- subclass_union = _get_union_type (cl , union_classes )
164
- sub_union_classes = get_args (subclass_union )
165
- union_strategy (subclass_union , converter )
166
- struct_hook = converter ._union_struct_registry [subclass_union ]
167
- unstruct_hook = converter ._unstructure_func .dispatch (subclass_union )
193
+ for cl , hook in original_struct_hooks .items ():
168
194
169
195
def cls_is_cl (cls , _cl = cl ):
170
196
return cls is _cl
171
197
172
- def cls_is_in_union (cls , _union_classes = sub_union_classes ):
173
- return cls in _union_classes
198
+ converter .register_structure_hook_func (cls_is_cl , hook )
174
199
175
- # This needs to use function dispatch, using singledispatch will again
176
- # match A and all subclasses, which is not what we want.
177
- converter .register_structure_hook_func (cls_is_cl , struct_hook )
178
- converter .register_unstructure_hook_func (cls_is_in_union , unstruct_hook )
200
+ union_strategy (final_union , converter )
201
+ unstruct_hook = converter ._unstructure_func .dispatch (final_union )
202
+ struct_hook = converter ._structure_func .dispatch (final_union )
203
+
204
+ for cl in union_classes :
205
+ # In the second pass, we overwrite the hooks with the union hook.
206
+
207
+ def cls_is_cl (cls , _cl = cl ):
208
+ return cls is _cl
209
+
210
+ converter .register_unstructure_hook_func (cls_is_cl , unstruct_hook )
211
+ subclasses = tuple ([c for c in union_classes if issubclass (c , cl )])
212
+ if len (subclasses ) > 1 :
213
+ u = Union [subclasses ] # type: ignore
214
+ union_strategy (u , converter )
215
+ struct_hook = converter ._structure_func .dispatch (u )
216
+
217
+ def sh (payload : dict , _ , _u = u , _s = struct_hook ) -> cl :
218
+ return _s (payload , _u )
219
+
220
+ converter .register_structure_hook_func (cls_is_cl , sh )
0 commit comments