Skip to content

Commit a7549b0

Browse files
gh-112281: Allow Union with unhashable Annotated metadata (#112283)
Co-authored-by: Alex Waygood <[email protected]>
1 parent 2713c2a commit a7549b0

File tree

4 files changed

+156
-18
lines changed

4 files changed

+156
-18
lines changed

Lib/test/test_types.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,26 @@ def test_hash(self):
713713
self.assertEqual(hash(int | str), hash(str | int))
714714
self.assertEqual(hash(int | str), hash(typing.Union[int, str]))
715715

716+
def test_union_of_unhashable(self):
717+
class UnhashableMeta(type):
718+
__hash__ = None
719+
720+
class A(metaclass=UnhashableMeta): ...
721+
class B(metaclass=UnhashableMeta): ...
722+
723+
self.assertEqual((A | B).__args__, (A, B))
724+
union1 = A | B
725+
with self.assertRaises(TypeError):
726+
hash(union1)
727+
728+
union2 = int | B
729+
with self.assertRaises(TypeError):
730+
hash(union2)
731+
732+
union3 = A | int
733+
with self.assertRaises(TypeError):
734+
hash(union3)
735+
716736
def test_instancecheck_and_subclasscheck(self):
717737
for x in (int | str, typing.Union[int, str]):
718738
with self.subTest(x=x):

Lib/test/test_typing.py

Lines changed: 103 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import collections
33
import collections.abc
44
from collections import defaultdict
5-
from functools import lru_cache, wraps
5+
from functools import lru_cache, wraps, reduce
66
import gc
77
import inspect
88
import itertools
9+
import operator
910
import pickle
1011
import re
1112
import sys
@@ -1769,6 +1770,26 @@ def test_union_union(self):
17691770
v = Union[u, Employee]
17701771
self.assertEqual(v, Union[int, float, Employee])
17711772

1773+
def test_union_of_unhashable(self):
1774+
class UnhashableMeta(type):
1775+
__hash__ = None
1776+
1777+
class A(metaclass=UnhashableMeta): ...
1778+
class B(metaclass=UnhashableMeta): ...
1779+
1780+
self.assertEqual(Union[A, B].__args__, (A, B))
1781+
union1 = Union[A, B]
1782+
with self.assertRaises(TypeError):
1783+
hash(union1)
1784+
1785+
union2 = Union[int, B]
1786+
with self.assertRaises(TypeError):
1787+
hash(union2)
1788+
1789+
union3 = Union[A, int]
1790+
with self.assertRaises(TypeError):
1791+
hash(union3)
1792+
17721793
def test_repr(self):
17731794
self.assertEqual(repr(Union), 'typing.Union')
17741795
u = Union[Employee, int]
@@ -5506,10 +5527,8 @@ def some(self):
55065527
self.assertFalse(hasattr(WithOverride.some, "__override__"))
55075528

55085529
def test_multiple_decorators(self):
5509-
import functools
5510-
55115530
def with_wraps(f): # similar to `lru_cache` definition
5512-
@functools.wraps(f)
5531+
@wraps(f)
55135532
def wrapper(*args, **kwargs):
55145533
return f(*args, **kwargs)
55155534
return wrapper
@@ -8524,6 +8543,76 @@ def test_flatten(self):
85248543
self.assertEqual(A.__metadata__, (4, 5))
85258544
self.assertEqual(A.__origin__, int)
85268545

8546+
def test_deduplicate_from_union(self):
8547+
# Regular:
8548+
self.assertEqual(get_args(Annotated[int, 1] | int),
8549+
(Annotated[int, 1], int))
8550+
self.assertEqual(get_args(Union[Annotated[int, 1], int]),
8551+
(Annotated[int, 1], int))
8552+
self.assertEqual(get_args(Annotated[int, 1] | Annotated[int, 2] | int),
8553+
(Annotated[int, 1], Annotated[int, 2], int))
8554+
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[int, 2], int]),
8555+
(Annotated[int, 1], Annotated[int, 2], int))
8556+
self.assertEqual(get_args(Annotated[int, 1] | Annotated[str, 1] | int),
8557+
(Annotated[int, 1], Annotated[str, 1], int))
8558+
self.assertEqual(get_args(Union[Annotated[int, 1], Annotated[str, 1], int]),
8559+
(Annotated[int, 1], Annotated[str, 1], int))
8560+
8561+
# Duplicates:
8562+
self.assertEqual(Annotated[int, 1] | Annotated[int, 1] | int,
8563+
Annotated[int, 1] | int)
8564+
self.assertEqual(Union[Annotated[int, 1], Annotated[int, 1], int],
8565+
Union[Annotated[int, 1], int])
8566+
8567+
# Unhashable metadata:
8568+
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[int, set()] | int),
8569+
(str, Annotated[int, {}], Annotated[int, set()], int))
8570+
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[int, set()], int]),
8571+
(str, Annotated[int, {}], Annotated[int, set()], int))
8572+
self.assertEqual(get_args(str | Annotated[int, {}] | Annotated[str, {}] | int),
8573+
(str, Annotated[int, {}], Annotated[str, {}], int))
8574+
self.assertEqual(get_args(Union[str, Annotated[int, {}], Annotated[str, {}], int]),
8575+
(str, Annotated[int, {}], Annotated[str, {}], int))
8576+
8577+
self.assertEqual(get_args(Annotated[int, 1] | str | Annotated[str, {}] | int),
8578+
(Annotated[int, 1], str, Annotated[str, {}], int))
8579+
self.assertEqual(get_args(Union[Annotated[int, 1], str, Annotated[str, {}], int]),
8580+
(Annotated[int, 1], str, Annotated[str, {}], int))
8581+
8582+
import dataclasses
8583+
@dataclasses.dataclass
8584+
class ValueRange:
8585+
lo: int
8586+
hi: int
8587+
v = ValueRange(1, 2)
8588+
self.assertEqual(get_args(Annotated[int, v] | None),
8589+
(Annotated[int, v], types.NoneType))
8590+
self.assertEqual(get_args(Union[Annotated[int, v], None]),
8591+
(Annotated[int, v], types.NoneType))
8592+
self.assertEqual(get_args(Optional[Annotated[int, v]]),
8593+
(Annotated[int, v], types.NoneType))
8594+
8595+
# Unhashable metadata duplicated:
8596+
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
8597+
Annotated[int, {}] | int)
8598+
self.assertEqual(Annotated[int, {}] | Annotated[int, {}] | int,
8599+
int | Annotated[int, {}])
8600+
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
8601+
Union[Annotated[int, {}], int])
8602+
self.assertEqual(Union[Annotated[int, {}], Annotated[int, {}], int],
8603+
Union[int, Annotated[int, {}]])
8604+
8605+
def test_order_in_union(self):
8606+
expr1 = Annotated[int, 1] | str | Annotated[str, {}] | int
8607+
for args in itertools.permutations(get_args(expr1)):
8608+
with self.subTest(args=args):
8609+
self.assertEqual(expr1, reduce(operator.or_, args))
8610+
8611+
expr2 = Union[Annotated[int, 1], str, Annotated[str, {}], int]
8612+
for args in itertools.permutations(get_args(expr2)):
8613+
with self.subTest(args=args):
8614+
self.assertEqual(expr2, Union[args])
8615+
85278616
def test_specialize(self):
85288617
L = Annotated[List[T], "my decoration"]
85298618
LI = Annotated[List[int], "my decoration"]
@@ -8544,6 +8633,16 @@ def test_hash_eq(self):
85448633
{Annotated[int, 4, 5], Annotated[int, 4, 5], Annotated[T, 4, 5]},
85458634
{Annotated[int, 4, 5], Annotated[T, 4, 5]}
85468635
)
8636+
# Unhashable `metadata` raises `TypeError`:
8637+
a1 = Annotated[int, []]
8638+
with self.assertRaises(TypeError):
8639+
hash(a1)
8640+
8641+
class A:
8642+
__hash__ = None
8643+
a2 = Annotated[int, A()]
8644+
with self.assertRaises(TypeError):
8645+
hash(a2)
85478646

85488647
def test_instantiate(self):
85498648
class C:

Lib/typing.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -308,19 +308,33 @@ def _unpack_args(args):
308308
newargs.append(arg)
309309
return newargs
310310

311-
def _deduplicate(params):
311+
def _deduplicate(params, *, unhashable_fallback=False):
312312
# Weed out strict duplicates, preserving the first of each occurrence.
313-
all_params = set(params)
314-
if len(all_params) < len(params):
315-
new_params = []
316-
for t in params:
317-
if t in all_params:
318-
new_params.append(t)
319-
all_params.remove(t)
320-
params = new_params
321-
assert not all_params, all_params
322-
return params
323-
313+
try:
314+
return dict.fromkeys(params)
315+
except TypeError:
316+
if not unhashable_fallback:
317+
raise
318+
# Happens for cases like `Annotated[dict, {'x': IntValidator()}]`
319+
return _deduplicate_unhashable(params)
320+
321+
def _deduplicate_unhashable(unhashable_params):
322+
new_unhashable = []
323+
for t in unhashable_params:
324+
if t not in new_unhashable:
325+
new_unhashable.append(t)
326+
return new_unhashable
327+
328+
def _compare_args_orderless(first_args, second_args):
329+
first_unhashable = _deduplicate_unhashable(first_args)
330+
second_unhashable = _deduplicate_unhashable(second_args)
331+
t = list(second_unhashable)
332+
try:
333+
for elem in first_unhashable:
334+
t.remove(elem)
335+
except ValueError:
336+
return False
337+
return not t
324338

325339
def _remove_dups_flatten(parameters):
326340
"""Internal helper for Union creation and substitution.
@@ -335,7 +349,7 @@ def _remove_dups_flatten(parameters):
335349
else:
336350
params.append(p)
337351

338-
return tuple(_deduplicate(params))
352+
return tuple(_deduplicate(params, unhashable_fallback=True))
339353

340354

341355
def _flatten_literal_params(parameters):
@@ -1555,7 +1569,10 @@ def copy_with(self, params):
15551569
def __eq__(self, other):
15561570
if not isinstance(other, (_UnionGenericAlias, types.UnionType)):
15571571
return NotImplemented
1558-
return set(self.__args__) == set(other.__args__)
1572+
try: # fast path
1573+
return set(self.__args__) == set(other.__args__)
1574+
except TypeError: # not hashable, slow path
1575+
return _compare_args_orderless(self.__args__, other.__args__)
15591576

15601577
def __hash__(self):
15611578
return hash(frozenset(self.__args__))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Allow creating :ref:`union of types<types-union>` for
2+
:class:`typing.Annotated` with unhashable metadata.

0 commit comments

Comments
 (0)