Skip to content

Commit e613d74

Browse files
committed
Enable hook factories to take converters
1 parent 436d651 commit e613d74

File tree

8 files changed

+169
-51
lines changed

8 files changed

+169
-51
lines changed

src/cattrs/converters.py

+80-12
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dataclasses import Field
66
from enum import Enum
77
from functools import partial
8+
from inspect import Signature
89
from pathlib import Path
910
from typing import Any, Callable, Iterable, Optional, Tuple, TypeVar, overload
1011

@@ -55,6 +56,7 @@
5556
MultiStrategyDispatch,
5657
StructuredValue,
5758
StructureHook,
59+
TargetType,
5860
UnstructuredValue,
5961
UnstructureHook,
6062
)
@@ -85,11 +87,25 @@
8587

8688
T = TypeVar("T")
8789
V = TypeVar("V")
90+
8891
UnstructureHookFactory = TypeVar(
8992
"UnstructureHookFactory", bound=HookFactory[UnstructureHook]
9093
)
94+
95+
# The Extended factory also takes a converter.
96+
ExtendedUnstructureHookFactory = TypeVar(
97+
"ExtendedUnstructureHookFactory",
98+
bound=Callable[[TargetType, "BaseConverter"], UnstructureHook],
99+
)
100+
91101
StructureHookFactory = TypeVar("StructureHookFactory", bound=HookFactory[StructureHook])
92102

103+
# The Extended factory also takes a converter.
104+
ExtendedStructureHookFactory = TypeVar(
105+
"ExtendedStructureHookFactory",
106+
bound=Callable[[TargetType, "BaseConverter"], StructureHook],
107+
)
108+
93109

94110
class UnstructureStrategy(Enum):
95111
"""`attrs` classes unstructuring strategies."""
@@ -151,7 +167,9 @@ def __init__(
151167
self._unstructure_attrs = self.unstructure_attrs_astuple
152168
self._structure_attrs = self.structure_attrs_fromtuple
153169

154-
self._unstructure_func = MultiStrategyDispatch(unstructure_fallback_factory)
170+
self._unstructure_func = MultiStrategyDispatch(
171+
unstructure_fallback_factory, self
172+
)
155173
self._unstructure_func.register_cls_list(
156174
[(bytes, identity), (str, identity), (Path, str)]
157175
)
@@ -163,12 +181,12 @@ def __init__(
163181
),
164182
(
165183
lambda t: get_final_base(t) is not None,
166-
lambda t: self._unstructure_func.dispatch(get_final_base(t)),
184+
lambda t: self.get_unstructure_hook(get_final_base(t)),
167185
True,
168186
),
169187
(
170188
is_type_alias,
171-
lambda t: self._unstructure_func.dispatch(get_type_alias_base(t)),
189+
lambda t: self.get_unstructure_hook(get_type_alias_base(t)),
172190
True,
173191
),
174192
(is_mapping, self._unstructure_mapping),
@@ -185,7 +203,7 @@ def __init__(
185203
# Per-instance register of to-attrs converters.
186204
# Singledispatch dispatches based on the first argument, so we
187205
# store the function and switch the arguments in self.loads.
188-
self._structure_func = MultiStrategyDispatch(structure_fallback_factory)
206+
self._structure_func = MultiStrategyDispatch(structure_fallback_factory, self)
189207
self._structure_func.register_func_list(
190208
[
191209
(
@@ -308,6 +326,12 @@ def register_unstructure_hook_factory(
308326
) -> Callable[[UnstructureHookFactory], UnstructureHookFactory]:
309327
...
310328

329+
@overload
330+
def register_unstructure_hook_factory(
331+
self, predicate: Callable[[Any], bool]
332+
) -> Callable[[ExtendedUnstructureHookFactory], ExtendedUnstructureHookFactory]:
333+
...
334+
311335
@overload
312336
def register_unstructure_hook_factory(
313337
self, predicate: Callable[[Any], bool], factory: UnstructureHookFactory
@@ -325,7 +349,10 @@ def register_unstructure_hook_factory(
325349
"""
326350
Register a hook factory for a given predicate.
327351
328-
May also be used as a decorator.
352+
May also be used as a decorator. When used as a decorator, the hook
353+
factory may expose an additional required parameter. In this case,
354+
the current converter will be provided to the hook factory as that
355+
parameter.
329356
330357
:param predicate: A function that, given a type, returns whether the factory
331358
can produce a hook for that type.
@@ -336,7 +363,23 @@ def register_unstructure_hook_factory(
336363
This method may now be used as a decorator.
337364
"""
338365
if factory is None:
339-
return partial(self.register_unstructure_hook_factory, predicate)
366+
367+
def decorator(factory):
368+
# Is this an extended factory (takes a converter too)?
369+
sig = signature(factory)
370+
if (
371+
len(sig.parameters) >= 2
372+
and (list(sig.parameters.values())[1]).default is Signature.empty
373+
):
374+
self._unstructure_func.register_func_list(
375+
[(predicate, factory, "extended")]
376+
)
377+
else:
378+
self._unstructure_func.register_func_list(
379+
[(predicate, factory, True)]
380+
)
381+
382+
return decorator
340383
self._unstructure_func.register_func_list([(predicate, factory, True)])
341384
return factory
342385

@@ -420,6 +463,12 @@ def register_structure_hook_factory(
420463
) -> Callable[[StructureHookFactory, StructureHookFactory]]:
421464
...
422465

466+
@overload
467+
def register_structure_hook_factory(
468+
self, predicate: Callable[[Any, bool]]
469+
) -> Callable[[ExtendedStructureHookFactory, ExtendedStructureHookFactory]]:
470+
...
471+
423472
@overload
424473
def register_structure_hook_factory(
425474
self, predicate: Callable[[Any], bool], factory: StructureHookFactory
@@ -434,7 +483,10 @@ def register_structure_hook_factory(
434483
"""
435484
Register a hook factory for a given predicate.
436485
437-
May also be used as a decorator.
486+
May also be used as a decorator. When used as a decorator, the hook
487+
factory may expose an additional required parameter. In this case,
488+
the current converter will be provided to the hook factory as that
489+
parameter.
438490
439491
:param predicate: A function that, given a type, returns whether the factory
440492
can produce a hook for that type.
@@ -445,7 +497,23 @@ def register_structure_hook_factory(
445497
This method may now be used as a decorator.
446498
"""
447499
if factory is None:
448-
return partial(self.register_structure_hook_factory, predicate)
500+
# Decorator use.
501+
def decorator(factory):
502+
# Is this an extended factory (takes a converter too)?
503+
sig = signature(factory)
504+
if (
505+
len(sig.parameters) >= 2
506+
and (list(sig.parameters.values())[1]).default is Signature.empty
507+
):
508+
self._structure_func.register_func_list(
509+
[(predicate, factory, "extended")]
510+
)
511+
else:
512+
self._structure_func.register_func_list(
513+
[(predicate, factory, True)]
514+
)
515+
516+
return decorator
449517
self._structure_func.register_func_list([(predicate, factory, True)])
450518
return factory
451519

@@ -684,7 +752,7 @@ def _structure_list(self, obj: Iterable[T], cl: Any) -> list[T]:
684752
def _structure_deque(self, obj: Iterable[T], cl: Any) -> deque[T]:
685753
"""Convert an iterable to a potentially generic deque."""
686754
if is_bare(cl) or cl.__args__[0] in ANIES:
687-
res = deque(e for e in obj)
755+
res = deque(obj)
688756
else:
689757
elem_type = cl.__args__[0]
690758
handler = self._structure_func.dispatch(elem_type)
@@ -1048,7 +1116,7 @@ def __init__(
10481116
)
10491117
self.register_unstructure_hook_factory(
10501118
lambda t: get_newtype_base(t) is not None,
1051-
lambda t: self._unstructure_func.dispatch(get_newtype_base(t)),
1119+
lambda t: self.get_unstructure_hook(get_newtype_base(t)),
10521120
)
10531121

10541122
self.register_structure_hook_factory(is_annotated, self.gen_structure_annotated)
@@ -1070,7 +1138,7 @@ def get_structure_newtype(self, type: type[T]) -> Callable[[Any, Any], T]:
10701138

10711139
def gen_unstructure_annotated(self, type):
10721140
origin = type.__origin__
1073-
return self._unstructure_func.dispatch(origin)
1141+
return self.get_unstructure_hook(origin)
10741142

10751143
def gen_structure_annotated(self, type) -> Callable:
10761144
"""A hook factory for annotated types."""
@@ -1111,7 +1179,7 @@ def gen_unstructure_optional(self, cl: type[T]) -> Callable[[T], Any]:
11111179
if isinstance(other, TypeVar):
11121180
handler = self.unstructure
11131181
else:
1114-
handler = self._unstructure_func.dispatch(other)
1182+
handler = self.get_unstructure_hook(other)
11151183

11161184
def unstructure_optional(val, _handler=handler):
11171185
return None if val is None else _handler(val)

src/cattrs/dispatch.py

+49-30
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
from functools import lru_cache, partial, singledispatch
2-
from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
1+
from __future__ import annotations
32

4-
from attrs import Factory, define, field
3+
from functools import lru_cache, singledispatch
4+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, TypeVar
5+
6+
from attrs import Factory, define
57

68
from cattrs._compat import TypeAlias
79

10+
if TYPE_CHECKING:
11+
from .converters import BaseConverter
12+
813
T = TypeVar("T")
914

1015
TargetType: TypeAlias = Any
@@ -33,23 +38,25 @@ class FunctionDispatch:
3338
objects that help determine dispatch should be instantiated objects.
3439
"""
3540

36-
_handler_pairs: List[
37-
Tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool]
41+
_converter: BaseConverter
42+
_handler_pairs: list[
43+
tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool, bool]
3844
] = Factory(list)
3945

4046
def register(
4147
self,
42-
can_handle: Callable[[Any], bool],
48+
predicate: Callable[[Any], bool],
4349
func: Callable[..., Any],
4450
is_generator=False,
51+
takes_converter=False,
4552
) -> None:
46-
self._handler_pairs.insert(0, (can_handle, func, is_generator))
53+
self._handler_pairs.insert(0, (predicate, func, is_generator, takes_converter))
4754

48-
def dispatch(self, typ: Any) -> Optional[Callable[..., Any]]:
55+
def dispatch(self, typ: Any) -> Callable[..., Any] | None:
4956
"""
5057
Return the appropriate handler for the object passed.
5158
"""
52-
for can_handle, handler, is_generator in self._handler_pairs:
59+
for can_handle, handler, is_generator, takes_converter in self._handler_pairs:
5360
# can handle could raise an exception here
5461
# such as issubclass being called on an instance.
5562
# it's easier to just ignore that case.
@@ -59,6 +66,8 @@ def dispatch(self, typ: Any) -> Optional[Callable[..., Any]]:
5966
continue
6067
if ch:
6168
if is_generator:
69+
if takes_converter:
70+
return handler(typ, self._converter)
6271
return handler(typ)
6372

6473
return handler
@@ -67,11 +76,11 @@ def dispatch(self, typ: Any) -> Optional[Callable[..., Any]]:
6776
def get_num_fns(self) -> int:
6877
return len(self._handler_pairs)
6978

70-
def copy_to(self, other: "FunctionDispatch", skip: int = 0) -> None:
79+
def copy_to(self, other: FunctionDispatch, skip: int = 0) -> None:
7180
other._handler_pairs = self._handler_pairs[:-skip] + other._handler_pairs
7281

7382

74-
@define
83+
@define(init=False)
7584
class MultiStrategyDispatch(Generic[Hook]):
7685
"""
7786
MultiStrategyDispatch uses a combination of exact-match dispatch,
@@ -85,18 +94,20 @@ class MultiStrategyDispatch(Generic[Hook]):
8594
"""
8695

8796
_fallback_factory: HookFactory[Hook]
88-
_direct_dispatch: Dict[TargetType, Hook] = field(init=False, factory=dict)
89-
_function_dispatch: FunctionDispatch = field(init=False, factory=FunctionDispatch)
90-
_single_dispatch: Any = field(
91-
init=False, factory=partial(singledispatch, _DispatchNotFound)
92-
)
93-
dispatch: Callable[[TargetType], Hook] = field(
94-
init=False,
95-
default=Factory(
96-
lambda self: lru_cache(maxsize=None)(self.dispatch_without_caching),
97-
takes_self=True,
98-
),
99-
)
97+
_converter: BaseConverter
98+
_direct_dispatch: dict[TargetType, Hook]
99+
_function_dispatch: FunctionDispatch
100+
_single_dispatch: Any
101+
dispatch: Callable[[TargetType, BaseConverter], Hook]
102+
103+
def __init__(
104+
self, fallback_factory: HookFactory[Hook], converter: BaseConverter
105+
) -> None:
106+
self._fallback_factory = fallback_factory
107+
self._direct_dispatch = {}
108+
self._function_dispatch = FunctionDispatch(converter)
109+
self._single_dispatch = singledispatch(_DispatchNotFound)
110+
self.dispatch = lru_cache(maxsize=None)(self.dispatch_without_caching)
100111

101112
def dispatch_without_caching(self, typ: TargetType) -> Hook:
102113
"""Dispatch on the type but without caching the result."""
@@ -126,15 +137,18 @@ def register_cls_list(self, cls_and_handler, direct: bool = False) -> None:
126137

127138
def register_func_list(
128139
self,
129-
pred_and_handler: List[
130-
Union[
131-
Tuple[Callable[[Any], bool], Any],
132-
Tuple[Callable[[Any], bool], Any, bool],
140+
pred_and_handler: list[
141+
tuple[Callable[[Any], bool], Any]
142+
| tuple[Callable[[Any], bool], Any, bool]
143+
| tuple[
144+
Callable[[Any], bool],
145+
Callable[[Any, BaseConverter], Any],
146+
Literal["extended"],
133147
]
134148
],
135149
):
136150
"""
137-
Register a predicate function to determine if the handle
151+
Register a predicate function to determine if the handler
138152
should be used for the type.
139153
"""
140154
for tup in pred_and_handler:
@@ -143,7 +157,12 @@ def register_func_list(
143157
self._function_dispatch.register(func, handler)
144158
else:
145159
func, handler, is_gen = tup
146-
self._function_dispatch.register(func, handler, is_generator=is_gen)
160+
if is_gen == "extended":
161+
self._function_dispatch.register(
162+
func, handler, is_generator=is_gen, takes_converter=True
163+
)
164+
else:
165+
self._function_dispatch.register(func, handler, is_generator=is_gen)
147166
self.clear_direct()
148167
self.dispatch.cache_clear()
149168

@@ -159,7 +178,7 @@ def clear_cache(self) -> None:
159178
def get_num_fns(self) -> int:
160179
return self._function_dispatch.get_num_fns()
161180

162-
def copy_to(self, other: "MultiStrategyDispatch", skip: int = 0) -> None:
181+
def copy_to(self, other: MultiStrategyDispatch, skip: int = 0) -> None:
163182
self._function_dispatch.copy_to(other._function_dispatch, skip=skip)
164183
for cls, fn in self._single_dispatch.registry.items():
165184
other._single_dispatch.register(cls, fn)

src/cattrs/gen/typeddicts.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def make_dict_unstructure_fn(
139139
if nrb is not NOTHING:
140140
t = nrb
141141
try:
142-
handler = converter._unstructure_func.dispatch(t)
142+
handler = converter.get_unstructure_hook(t)
143143
except RecursionError:
144144
# There's a circular reference somewhere down the line
145145
handler = converter.unstructure
@@ -185,7 +185,7 @@ def make_dict_unstructure_fn(
185185
if nrb is not NOTHING:
186186
t = nrb
187187
try:
188-
handler = converter._unstructure_func.dispatch(t)
188+
handler = converter.get_unstructure_hook(t)
189189
except RecursionError:
190190
# There's a circular reference somewhere down the line
191191
handler = converter.unstructure

src/cattrs/preconf/orjson.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def key_handler(v):
5656
# (For example base85 encoding for bytes.)
5757
# In that case, we want to use the override.
5858

59-
kh = converter._unstructure_func.dispatch(args[0])
59+
kh = converter.get_unstructure_hook(args[0])
6060
if kh != identity:
6161
key_handler = kh
6262

0 commit comments

Comments
 (0)