Skip to content

Commit 87834dd

Browse files
authored
Merge pull request #124 from honno/indexing-improvements
Index testing improvements
2 parents 167b4f7 + 8a5103b commit 87834dd

9 files changed

+152
-87
lines changed

Diff for: array_api_tests/dtype_helpers.py

-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,6 @@ def is_float_dtype(dtype):
121121
# See https://github.com/numpy/numpy/issues/18434
122122
if dtype is None:
123123
return False
124-
# TODO: Return True for float dtypes that aren't part of the spec e.g. np.float16
125124
return dtype in float_dtypes
126125

127126

Diff for: array_api_tests/meta/test_pytest_helpers.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ def test_assert_dtype():
1313
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
1414

1515

16-
def test_assert_array():
17-
ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0))
18-
ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
16+
def test_assert_array_elements():
17+
ph.assert_array_elements("int zeros", xp.asarray(0), xp.asarray(0))
18+
ph.assert_array_elements("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
1919
with raises(AssertionError):
20-
ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
20+
ph.assert_array_elements("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
2121
with raises(AssertionError):
22-
ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
22+
ph.assert_array_elements("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))

Diff for: array_api_tests/meta/test_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,15 @@ def test_roll_ndindex(shape, shifts, axes, expected):
9191
((), "x"),
9292
(42, "x[42]"),
9393
((42,), "x[42]"),
94+
((42, 7), "x[42, 7]"),
9495
(slice(None, 2), "x[:2]"),
9596
(slice(2, None), "x[2:]"),
9697
(slice(0, 2), "x[0:2]"),
9798
(slice(0, 2, -1), "x[0:2:-1]"),
9899
(slice(None, None, -1), "x[::-1]"),
99100
(slice(None, None), "x[:]"),
100101
(..., "x[...]"),
102+
((None, 42), "x[None, 42]"),
101103
],
102104
)
103105
def test_fmt_idx(idx, expected):

Diff for: array_api_tests/pytest_helpers.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
"assert_keepdimable_shape",
2626
"assert_0d_equals",
2727
"assert_fill",
28-
"assert_array",
28+
"assert_array_elements",
2929
]
3030

3131

@@ -301,7 +301,7 @@ def assert_0d_equals(
301301
>>> x = xp.asarray([0, 1, 2])
302302
>>> res = xp.asarray(x, copy=True)
303303
>>> res[0] = 42
304-
>>> assert_0d_equals('__setitem__', 'x[0]', x[0], 'x[0]', res[0])
304+
>>> assert_0d_equals('asarray', 'x[0]', x[0], 'x[0]', res[0])
305305
306306
is equivalent to
307307
@@ -374,28 +374,30 @@ def assert_fill(
374374
assert xp.all(xp.equal(out, xp.asarray(fill_value, dtype=dtype))), msg
375375

376376

377-
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
377+
def assert_array_elements(
378+
func_name: str, out: Array, expected: Array, /, *, out_repr: str = "out", **kw
379+
):
378380
"""
379-
Assert array is (strictly) as expected, e.g.
381+
Assert array elements are (strictly) as expected, e.g.
380382
381383
>>> x = xp.arange(5)
382384
>>> out = xp.asarray(x)
383-
>>> assert_array('asarray', out, x)
385+
>>> assert_array_elements('asarray', out, x)
384386
385387
is equivalent to
386388
387389
>>> assert xp.all(out == x)
388390
389391
"""
390-
assert_dtype(func_name, out.dtype, expected.dtype)
391-
assert_shape(func_name, out.shape, expected.shape, **kw)
392+
dh.result_type(out.dtype, expected.dtype) # sanity check
393+
assert_shape(func_name, out.shape, expected.shape, **kw) # sanity check
392394
f_func = f"[{func_name}({fmt_kw(kw)})]"
393395
if dh.is_float_dtype(out.dtype):
394396
for idx in sh.ndindex(out.shape):
395397
at_out = out[idx]
396398
at_expected = expected[idx]
397399
msg = (
398-
f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} "
400+
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
399401
f"{f_func}"
400402
)
401403
if xp.isnan(at_expected):
@@ -411,6 +413,6 @@ def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
411413
else:
412414
assert at_out == at_expected, msg
413415
else:
414-
assert xp.all(out == expected), (
415-
f"out not as expected {f_func}\n" f"{out=}\n{expected=}"
416-
)
416+
assert xp.all(
417+
out == expected
418+
), f"{out_repr} not as expected {f_func}\n{out_repr}={out!r}\n{expected=}"

Diff for: array_api_tests/shape_helpers.py

+2
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ def fmt_i(i: AtomicIndex) -> str:
156156
if i.step is not None:
157157
res += f":{i.step}"
158158
return res
159+
elif i is None:
160+
return "None"
159161
else:
160162
return "..."
161163

Diff for: array_api_tests/test_array_object.py

+87-50
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import List, get_args
3+
from typing import List, Sequence, Tuple, Union, get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -12,30 +12,29 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15-
from .typing import DataType, Param, Scalar, ScalarType, Shape
15+
from .test_operators_and_elementwise_functions import oneway_promotable_dtypes
16+
from .typing import DataType, Index, Param, Scalar, ScalarType, Shape
1617

1718
pytestmark = pytest.mark.ci
1819

1920

20-
def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]:
21+
def scalar_objects(
22+
dtype: DataType, shape: Shape
23+
) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
2124
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
2225
size = math.prod(shape)
2326
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
2427
lambda l: sh.reshape(l, shape)
2528
)
2629

2730

28-
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
29-
def test_getitem(shape, data):
30-
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
31-
obj = data.draw(scalar_objects(dtype, shape), label="obj")
32-
x = xp.asarray(obj, dtype=dtype)
33-
note(f"{x=}")
34-
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
35-
36-
out = x[key]
31+
def normalise_key(key: Index, shape: Shape) -> Tuple[Union[int, slice], ...]:
32+
"""
33+
Normalise an indexing key.
3734
38-
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
35+
* If a non-tuple index, wrap as a tuple.
36+
* Represent ellipsis as equivalent slices.
37+
"""
3938
_key = tuple(key) if isinstance(key, tuple) else (key,)
4039
if Ellipsis in _key:
4140
nonexpanding_key = tuple(i for i in _key if i is not None)
@@ -44,71 +43,109 @@ def test_getitem(shape, data):
4443
slices = tuple(slice(None) for _ in range(start_a, stop_a))
4544
start_pos = _key.index(Ellipsis)
4645
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
46+
return _key
47+
48+
49+
def get_indexed_axes_and_out_shape(
50+
key: Tuple[Union[int, slice, None], ...], shape: Shape
51+
) -> Tuple[Tuple[Sequence[int], ...], Shape]:
52+
"""
53+
From the (normalised) key and input shape, calculates:
54+
55+
* indexed_axes: For each dimension, the axes which the key indexes.
56+
* out_shape: The resulting shape of indexing an array (of the input shape)
57+
with the key.
58+
"""
4759
axes_indices = []
4860
out_shape = []
4961
a = 0
50-
for i in _key:
62+
for i in key:
5163
if i is None:
5264
out_shape.append(1)
5365
else:
66+
side = shape[a]
5467
if isinstance(i, int):
55-
axes_indices.append([i])
68+
if i < 0:
69+
i += side
70+
axes_indices.append((i,))
5671
else:
57-
assert isinstance(i, slice) # sanity check
58-
side = shape[a]
5972
indices = range(side)[i]
6073
axes_indices.append(indices)
6174
out_shape.append(len(indices))
6275
a += 1
63-
out_shape = tuple(out_shape)
76+
return tuple(axes_indices), tuple(out_shape)
77+
78+
79+
@given(shape=hh.shapes(), dtype=xps.scalar_dtypes(), data=st.data())
80+
def test_getitem(shape, dtype, data):
81+
zero_sided = any(side == 0 for side in shape)
82+
if zero_sided:
83+
x = xp.zeros(shape, dtype=dtype)
84+
else:
85+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
86+
x = xp.asarray(obj, dtype=dtype)
87+
note(f"{x=}")
88+
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
89+
90+
out = x[key]
91+
92+
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
93+
_key = normalise_key(key, shape)
94+
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
6495
ph.assert_shape("__getitem__", out.shape, out_shape)
65-
assume(all(len(indices) > 0 for indices in axes_indices))
66-
out_obj = []
67-
for idx in product(*axes_indices):
68-
val = obj
69-
for i in idx:
70-
val = val[i]
71-
out_obj.append(val)
72-
out_obj = sh.reshape(out_obj, out_shape)
73-
expected = xp.asarray(out_obj, dtype=dtype)
74-
ph.assert_array("__getitem__", out, expected)
75-
76-
77-
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
78-
def test_setitem(shape, data):
79-
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
80-
obj = data.draw(scalar_objects(dtype, shape), label="obj")
81-
x = xp.asarray(obj, dtype=dtype)
96+
out_zero_sided = any(side == 0 for side in out_shape)
97+
if not zero_sided and not out_zero_sided:
98+
out_obj = []
99+
for idx in product(*axes_indices):
100+
val = obj
101+
for i in idx:
102+
val = val[i]
103+
out_obj.append(val)
104+
out_obj = sh.reshape(out_obj, out_shape)
105+
expected = xp.asarray(out_obj, dtype=dtype)
106+
ph.assert_array_elements("__getitem__", out, expected)
107+
108+
109+
@given(
110+
shape=hh.shapes(),
111+
dtypes=oneway_promotable_dtypes(dh.all_dtypes),
112+
data=st.data(),
113+
)
114+
def test_setitem(shape, dtypes, data):
115+
zero_sided = any(side == 0 for side in shape)
116+
if zero_sided:
117+
x = xp.zeros(shape, dtype=dtypes.result_dtype)
118+
else:
119+
obj = data.draw(scalar_objects(dtypes.result_dtype, shape), label="obj")
120+
x = xp.asarray(obj, dtype=dtypes.result_dtype)
82121
note(f"{x=}")
83-
# TODO: test setting non-0d arrays
84-
key = data.draw(xps.indices(shape=shape, max_dims=0), label="key")
85-
value = data.draw(
86-
xps.from_dtype(dtype) | xps.arrays(dtype=dtype, shape=()), label="value"
87-
)
122+
key = data.draw(xps.indices(shape=shape), label="key")
123+
_key = normalise_key(key, shape)
124+
axes_indices, out_shape = get_indexed_axes_and_out_shape(_key, shape)
125+
value_strat = xps.arrays(dtype=dtypes.result_dtype, shape=out_shape)
126+
if out_shape == ():
127+
# We can pass scalars if we're only indexing one element
128+
value_strat |= xps.from_dtype(dtypes.result_dtype)
129+
value = data.draw(value_strat, label="value")
88130

89131
res = xp.asarray(x, copy=True)
90132
res[key] = value
91133

92134
ph.assert_dtype("__setitem__", x.dtype, res.dtype, repr_name="x.dtype")
93135
ph.assert_shape("__setitem__", res.shape, x.shape, repr_name="x.shape")
136+
f_res = sh.fmt_idx("x", key)
94137
if isinstance(value, get_args(Scalar)):
95-
msg = f"x[{key}]={res[key]!r}, but should be {value=} [__setitem__()]"
138+
msg = f"{f_res}={res[key]!r}, but should be {value=} [__setitem__()]"
96139
if math.isnan(value):
97140
assert xp.isnan(res[key]), msg
98141
else:
99142
assert res[key] == value, msg
100143
else:
101-
ph.assert_0d_equals(
102-
"__setitem__", "value", value, f"modified x[{key}]", res[key]
103-
)
104-
_key = key if isinstance(key, tuple) else (key,)
105-
assume(all(isinstance(i, int) for i in _key)) # TODO: normalise slices and ellipsis
106-
_key = tuple(i if i >= 0 else s + i for i, s in zip(_key, x.shape))
107-
unaffected_indices = list(sh.ndindex(res.shape))
108-
unaffected_indices.remove(_key)
144+
ph.assert_array_elements("__setitem__", res[key], value, out_repr=f_res)
145+
unaffected_indices = set(sh.ndindex(res.shape)) - set(product(*axes_indices))
109146
for idx in unaffected_indices:
110147
ph.assert_0d_equals(
111-
"__setitem__", f"old x[{idx}]", x[idx], f"modified x[{idx}]", res[idx]
148+
"__setitem__", f"old {f_res}", x[idx], f"modified {f_res}", res[idx]
112149
)
113150

114151

0 commit comments

Comments
 (0)