Skip to content

Commit 63ebadb

Browse files
authored
Merge pull request #112 from honno/use-dtype-eq
Implement `EqualityMapping` and use for relevant dtype helpers
2 parents 9816011 + ed23bfa commit 63ebadb

File tree

3 files changed

+180
-75
lines changed

3 files changed

+180
-75
lines changed

Diff for: array_api_tests/__init__.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
from functools import wraps
22

33
from hypothesis import strategies as st
4-
from hypothesis.extra.array_api import make_strategies_namespace
4+
from hypothesis.extra import array_api
55

66
from ._array_module import mod as _xp
77

88
__all__ = ["xps"]
99

10-
xps = make_strategies_namespace(_xp)
11-
1210

1311
# We monkey patch floats() to always disable subnormals as they are out-of-scope
1412

@@ -23,5 +21,29 @@ def floats(*a, **kw):
2321

2422
st.floats = floats
2523

24+
25+
# We do the same with xps.from_dtype() - this is not strictly necessary, as
26+
# the underlying floats() will never generate subnormals. We only do this
27+
# because internal logic in xps.from_dtype() assumes xp.finfo() has its
28+
# attributes as scalar floats, which is expected behaviour but disrupts many
29+
# unrelated tests.
30+
try:
31+
__from_dtype = array_api._from_dtype
32+
33+
@wraps(__from_dtype)
34+
def _from_dtype(*a, **kw):
35+
kw["allow_subnormal"] = False
36+
return __from_dtype(*a, **kw)
37+
38+
array_api._from_dtype = _from_dtype
39+
except AttributeError:
40+
# Ignore monkey patching if Hypothesis changes the private API
41+
pass
42+
43+
44+
xps = array_api.make_strategies_namespace(_xp)
45+
46+
2647
from . import _version
27-
__version__ = _version.get_versions()['version']
48+
49+
__version__ = _version.get_versions()["version"]

Diff for: array_api_tests/dtype_helpers.py

+117-71
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from collections.abc import Mapping
12
from functools import lru_cache
2-
from typing import NamedTuple, Tuple, Union
3+
from typing import Any, NamedTuple, Sequence, Tuple, Union
34
from warnings import warn
45

56
from . import _array_module as xp
@@ -36,6 +37,49 @@
3637
]
3738

3839

40+
class EqualityMapping(Mapping):
41+
"""
42+
Mapping that uses equality for indexing
43+
44+
Typical mappings (e.g. the built-in dict) use hashing for indexing. This
45+
isn't ideal for the Array API, as no __hash__() method is specified for
46+
dtype objects - but __eq__() is!
47+
48+
See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
49+
"""
50+
51+
def __init__(self, key_value_pairs: Sequence[Tuple[Any, Any]]):
52+
keys = [k for k, _ in key_value_pairs]
53+
for i, key in enumerate(keys):
54+
if not (key == key): # specifically checking __eq__, not __neq__
55+
raise ValueError("Key {key!r} does not have equality with itself")
56+
other_keys = keys[:]
57+
other_keys.pop(i)
58+
for other_key in other_keys:
59+
if key == other_key:
60+
raise ValueError("Key {key!r} has equality with key {other_key!r}")
61+
self._key_value_pairs = key_value_pairs
62+
63+
def __getitem__(self, key):
64+
for k, v in self._key_value_pairs:
65+
if key == k:
66+
return v
67+
else:
68+
raise KeyError(f"{key!r} not found")
69+
70+
def __iter__(self):
71+
return (k for k, _ in self._key_value_pairs)
72+
73+
def __len__(self):
74+
return len(self._key_value_pairs)
75+
76+
def __str__(self):
77+
return "{" + ", ".join(f"{k!r}: {v!r}" for k, v in self._key_value_pairs) + "}"
78+
79+
def __repr__(self):
80+
return f"EqualityMapping({self})"
81+
82+
3983
_uint_names = ("uint8", "uint16", "uint32", "uint64")
4084
_int_names = ("int8", "int16", "int32", "int64")
4185
_float_names = ("float32", "float64")
@@ -51,14 +95,16 @@
5195
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes
5296

5397

54-
dtype_to_name = {getattr(xp, name): name for name in _dtype_names}
98+
dtype_to_name = EqualityMapping([(getattr(xp, name), name) for name in _dtype_names])
5599

56100

57-
dtype_to_scalars = {
58-
xp.bool: [bool],
59-
**{d: [int] for d in all_int_dtypes},
60-
**{d: [int, float] for d in float_dtypes},
61-
}
101+
dtype_to_scalars = EqualityMapping(
102+
[
103+
(xp.bool, [bool]),
104+
*[(d, [int]) for d in all_int_dtypes],
105+
*[(d, [int, float]) for d in float_dtypes],
106+
]
107+
)
62108

63109

64110
def is_int_dtype(dtype):
@@ -90,31 +136,32 @@ class MinMax(NamedTuple):
90136
max: Union[int, float]
91137

92138

93-
dtype_ranges = {
94-
xp.int8: MinMax(-128, +127),
95-
xp.int16: MinMax(-32_768, +32_767),
96-
xp.int32: MinMax(-2_147_483_648, +2_147_483_647),
97-
xp.int64: MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807),
98-
xp.uint8: MinMax(0, +255),
99-
xp.uint16: MinMax(0, +65_535),
100-
xp.uint32: MinMax(0, +4_294_967_295),
101-
xp.uint64: MinMax(0, +18_446_744_073_709_551_615),
102-
xp.float32: MinMax(-3.4028234663852886e+38, 3.4028234663852886e+38),
103-
xp.float64: MinMax(-1.7976931348623157e+308, 1.7976931348623157e+308),
104-
}
139+
dtype_ranges = EqualityMapping(
140+
[
141+
(xp.int8, MinMax(-128, +127)),
142+
(xp.int16, MinMax(-32_768, +32_767)),
143+
(xp.int32, MinMax(-2_147_483_648, +2_147_483_647)),
144+
(xp.int64, MinMax(-9_223_372_036_854_775_808, +9_223_372_036_854_775_807)),
145+
(xp.uint8, MinMax(0, +255)),
146+
(xp.uint16, MinMax(0, +65_535)),
147+
(xp.uint32, MinMax(0, +4_294_967_295)),
148+
(xp.uint64, MinMax(0, +18_446_744_073_709_551_615)),
149+
(xp.float32, MinMax(-3.4028234663852886e38, 3.4028234663852886e38)),
150+
(xp.float64, MinMax(-1.7976931348623157e308, 1.7976931348623157e308)),
151+
]
152+
)
105153

106-
dtype_nbits = {
107-
**{d: 8 for d in [xp.int8, xp.uint8]},
108-
**{d: 16 for d in [xp.int16, xp.uint16]},
109-
**{d: 32 for d in [xp.int32, xp.uint32, xp.float32]},
110-
**{d: 64 for d in [xp.int64, xp.uint64, xp.float64]},
111-
}
154+
dtype_nbits = EqualityMapping(
155+
[(d, 8) for d in [xp.int8, xp.uint8]]
156+
+ [(d, 16) for d in [xp.int16, xp.uint16]]
157+
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
158+
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]]
159+
)
112160

113161

114-
dtype_signed = {
115-
**{d: True for d in int_dtypes},
116-
**{d: False for d in uint_dtypes},
117-
}
162+
dtype_signed = EqualityMapping(
163+
[(d, True) for d in int_dtypes] + [(d, False) for d in uint_dtypes]
164+
)
118165

119166

120167
if isinstance(xp.asarray, _UndefinedStub):
@@ -137,52 +184,51 @@ class MinMax(NamedTuple):
137184
default_uint = xp.uint64
138185

139186

140-
_numeric_promotions = {
187+
_numeric_promotions = [
141188
# ints
142-
(xp.int8, xp.int8): xp.int8,
143-
(xp.int8, xp.int16): xp.int16,
144-
(xp.int8, xp.int32): xp.int32,
145-
(xp.int8, xp.int64): xp.int64,
146-
(xp.int16, xp.int16): xp.int16,
147-
(xp.int16, xp.int32): xp.int32,
148-
(xp.int16, xp.int64): xp.int64,
149-
(xp.int32, xp.int32): xp.int32,
150-
(xp.int32, xp.int64): xp.int64,
151-
(xp.int64, xp.int64): xp.int64,
189+
((xp.int8, xp.int8), xp.int8),
190+
((xp.int8, xp.int16), xp.int16),
191+
((xp.int8, xp.int32), xp.int32),
192+
((xp.int8, xp.int64), xp.int64),
193+
((xp.int16, xp.int16), xp.int16),
194+
((xp.int16, xp.int32), xp.int32),
195+
((xp.int16, xp.int64), xp.int64),
196+
((xp.int32, xp.int32), xp.int32),
197+
((xp.int32, xp.int64), xp.int64),
198+
((xp.int64, xp.int64), xp.int64),
152199
# uints
153-
(xp.uint8, xp.uint8): xp.uint8,
154-
(xp.uint8, xp.uint16): xp.uint16,
155-
(xp.uint8, xp.uint32): xp.uint32,
156-
(xp.uint8, xp.uint64): xp.uint64,
157-
(xp.uint16, xp.uint16): xp.uint16,
158-
(xp.uint16, xp.uint32): xp.uint32,
159-
(xp.uint16, xp.uint64): xp.uint64,
160-
(xp.uint32, xp.uint32): xp.uint32,
161-
(xp.uint32, xp.uint64): xp.uint64,
162-
(xp.uint64, xp.uint64): xp.uint64,
200+
((xp.uint8, xp.uint8), xp.uint8),
201+
((xp.uint8, xp.uint16), xp.uint16),
202+
((xp.uint8, xp.uint32), xp.uint32),
203+
((xp.uint8, xp.uint64), xp.uint64),
204+
((xp.uint16, xp.uint16), xp.uint16),
205+
((xp.uint16, xp.uint32), xp.uint32),
206+
((xp.uint16, xp.uint64), xp.uint64),
207+
((xp.uint32, xp.uint32), xp.uint32),
208+
((xp.uint32, xp.uint64), xp.uint64),
209+
((xp.uint64, xp.uint64), xp.uint64),
163210
# ints and uints (mixed sign)
164-
(xp.int8, xp.uint8): xp.int16,
165-
(xp.int8, xp.uint16): xp.int32,
166-
(xp.int8, xp.uint32): xp.int64,
167-
(xp.int16, xp.uint8): xp.int16,
168-
(xp.int16, xp.uint16): xp.int32,
169-
(xp.int16, xp.uint32): xp.int64,
170-
(xp.int32, xp.uint8): xp.int32,
171-
(xp.int32, xp.uint16): xp.int32,
172-
(xp.int32, xp.uint32): xp.int64,
173-
(xp.int64, xp.uint8): xp.int64,
174-
(xp.int64, xp.uint16): xp.int64,
175-
(xp.int64, xp.uint32): xp.int64,
211+
((xp.int8, xp.uint8), xp.int16),
212+
((xp.int8, xp.uint16), xp.int32),
213+
((xp.int8, xp.uint32), xp.int64),
214+
((xp.int16, xp.uint8), xp.int16),
215+
((xp.int16, xp.uint16), xp.int32),
216+
((xp.int16, xp.uint32), xp.int64),
217+
((xp.int32, xp.uint8), xp.int32),
218+
((xp.int32, xp.uint16), xp.int32),
219+
((xp.int32, xp.uint32), xp.int64),
220+
((xp.int64, xp.uint8), xp.int64),
221+
((xp.int64, xp.uint16), xp.int64),
222+
((xp.int64, xp.uint32), xp.int64),
176223
# floats
177-
(xp.float32, xp.float32): xp.float32,
178-
(xp.float32, xp.float64): xp.float64,
179-
(xp.float64, xp.float64): xp.float64,
180-
}
181-
promotion_table = {
182-
(xp.bool, xp.bool): xp.bool,
183-
**_numeric_promotions,
184-
**{(d2, d1): res for (d1, d2), res in _numeric_promotions.items()},
185-
}
224+
((xp.float32, xp.float32), xp.float32),
225+
((xp.float32, xp.float64), xp.float64),
226+
((xp.float64, xp.float64), xp.float64),
227+
]
228+
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
229+
_promotion_table = list(set(_numeric_promotions))
230+
_promotion_table.insert(0, ((xp.bool, xp.bool), xp.bool))
231+
promotion_table = EqualityMapping(_promotion_table)
186232

187233

188234
def result_type(*dtypes: DataType):

Diff for: array_api_tests/meta/test_equality_mapping.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import pytest
2+
3+
from ..dtype_helpers import EqualityMapping
4+
5+
6+
def test_raises_on_distinct_eq_key():
7+
with pytest.raises(ValueError):
8+
EqualityMapping([(float("nan"), "value")])
9+
10+
11+
def test_raises_on_indistinct_eq_keys():
12+
class AlwaysEq:
13+
def __init__(self, hash):
14+
self._hash = hash
15+
16+
def __eq__(self, other):
17+
return True
18+
19+
def __hash__(self):
20+
return self._hash
21+
22+
with pytest.raises(ValueError):
23+
EqualityMapping([(AlwaysEq(0), "value1"), (AlwaysEq(1), "value2")])
24+
25+
26+
def test_key_error():
27+
mapping = EqualityMapping([("key", "value")])
28+
with pytest.raises(KeyError):
29+
mapping["nonexistent key"]
30+
31+
32+
def test_iter():
33+
mapping = EqualityMapping([("key", "value")])
34+
it = iter(mapping)
35+
assert next(it) == "key"
36+
with pytest.raises(StopIteration):
37+
next(it)

0 commit comments

Comments
 (0)