Skip to content

Commit ef6ea4e

Browse files
committed
Test newaxis in test_getitem, bumps Hypothesis min pin
1 parent fdde7ec commit ef6ea4e

File tree

3 files changed

+24
-15
lines changed

3 files changed

+24
-15
lines changed

.github/workflows/numpy.yml

+3
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ jobs:
4242
# https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448
4343
array_api_tests/test_set_functions.py
4444
45+
# https://github.com/numpy/numpy/issues/21373
46+
array_api_tests/test_array_object.py::test_getitem
47+
4548
# missing copy arg
4649
array_api_tests/test_signatures.py::test_func_signature[reshape]
4750

array_api_tests/test_array_object.py

+20-14
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,35 @@ def test_getitem(shape, data):
3131
obj = data.draw(scalar_objects(dtype, shape), label="obj")
3232
x = xp.asarray(obj, dtype=dtype)
3333
note(f"{x=}")
34-
key = data.draw(xps.indices(shape=shape), label="key")
34+
key = data.draw(xps.indices(shape=shape, allow_newaxis=True), label="key")
3535

3636
out = x[key]
3737

3838
ph.assert_dtype("__getitem__", x.dtype, out.dtype)
3939
_key = tuple(key) if isinstance(key, tuple) else (key,)
4040
if Ellipsis in _key:
41-
start_a = _key.index(Ellipsis)
42-
stop_a = start_a + (len(shape) - (len(_key) - 1))
43-
slices = tuple(slice(None, None) for _ in range(start_a, stop_a))
44-
_key = _key[:start_a] + slices + _key[start_a + 1 :]
41+
nonexpanding_key = tuple(i for i in _key if i is not None)
42+
start_a = nonexpanding_key.index(Ellipsis)
43+
stop_a = start_a + (len(shape) - (len(nonexpanding_key) - 1))
44+
slices = tuple(slice(None) for _ in range(start_a, stop_a))
45+
start_pos = _key.index(Ellipsis)
46+
_key = _key[:start_pos] + slices + _key[start_pos + 1 :]
4547
axes_indices = []
4648
out_shape = []
47-
for a, i in enumerate(_key):
48-
if isinstance(i, int):
49-
axes_indices.append([i])
49+
a = 0
50+
for i in _key:
51+
if i is None:
52+
out_shape.append(1)
5053
else:
51-
side = shape[a]
52-
indices = range(side)[i]
53-
axes_indices.append(indices)
54-
out_shape.append(len(indices))
54+
if isinstance(i, int):
55+
axes_indices.append([i])
56+
else:
57+
assert isinstance(i, slice) # sanity check
58+
side = shape[a]
59+
indices = range(side)[i]
60+
axes_indices.append(indices)
61+
out_shape.append(len(indices))
62+
a += 1
5563
out_shape = tuple(out_shape)
5664
ph.assert_shape("__getitem__", out.shape, out_shape)
5765
assume(all(len(indices) > 0 for indices in axes_indices))
@@ -104,8 +112,6 @@ def test_setitem(shape, data):
104112
)
105113

106114

107-
# TODO: make mask tests optional
108-
109115
@pytest.mark.data_dependent_shapes
110116
@given(hh.shapes(), st.data())
111117
def test_getitem_masking(shape, data):

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
pytest
2-
hypothesis>=6.31.1
2+
hypothesis>=6.45.0
33
ndindex>=1.6

0 commit comments

Comments
 (0)