Skip to content

Commit 44c80d8

Browse files
authored
Merge pull request #98 from honno/dask-fixes
Fixes related to Dask
2 parents cb2e7d0 + 92def25 commit 44c80d8

8 files changed

+90
-42
lines changed

README.md

+6-7
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ Use the `--ci` flag to run only the primary and special cases tests. You can
157157
ignore the other test cases as they are redundant for the purposes of checking
158158
compliance.
159159

160+
#### Data-dependent shapes
161+
162+
Use the `--disable-data-dependent-shapes` flag to skip testing functions which have
163+
[data-dependent shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html).
164+
160165
#### Extensions
161166

162167
By default, tests for the optional Array API extensions such as
@@ -200,16 +205,10 @@ instead of having a seperate `skips.txt` file, e.g.:
200205
# Skip test cases with known issues
201206
cat << EOF >> skips.txt
202207
203-
# Skip specific test case, e.g. when argsort() does not respect relative order
204-
# https://github.com/numpy/numpy/issues/20778
208+
# Comments can still work here
205209
array_api_tests/test_sorting_functions.py::test_argsort
206-
207-
# Skip specific test case parameter, e.g. you forgot to implement in-place adds
208210
array_api_tests/test_add[__iadd__(x1, x2)]
209211
array_api_tests/test_add[__iadd__(x, s)]
210-
211-
# Skip module, e.g. when your set functions treat NaNs as non-distinct
212-
# https://github.com/numpy/numpy/issues/20326
213212
array_api_tests/test_set_functions.py
214213
215214
EOF

array_api_tests/meta/test_pytest_helpers.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pytest import raises
22

3-
from .. import pytest_helpers as ph
43
from .. import _array_module as xp
4+
from .. import pytest_helpers as ph
55

66

77
def test_assert_dtype():
@@ -11,3 +11,12 @@ def test_assert_dtype():
1111
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
1212
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
1313
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
14+
15+
16+
def test_assert_array():
17+
ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0))
18+
ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
19+
with raises(AssertionError):
20+
ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
21+
with raises(AssertionError):
22+
ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))

array_api_tests/pytest_helpers.py

+34-10
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
"doesnt_raise",
1515
"nargs",
1616
"fmt_kw",
17+
"is_pos_zero",
18+
"is_neg_zero",
1719
"assert_dtype",
1820
"assert_kw_dtype",
1921
"assert_default_float",
@@ -22,6 +24,7 @@
2224
"assert_shape",
2325
"assert_result_shape",
2426
"assert_keepdimable_shape",
27+
"assert_0d_equals",
2528
"assert_fill",
2629
"assert_array",
2730
]
@@ -69,6 +72,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str:
6972
return ", ".join(f"{k}={v}" for k, v in kw.items())
7073

7174

75+
def is_pos_zero(n: float) -> bool:
76+
return n == 0 and math.copysign(1, n) == 1
77+
78+
79+
def is_neg_zero(n: float) -> bool:
80+
return n == 0 and math.copysign(1, n) == -1
81+
82+
7283
def assert_dtype(
7384
func_name: str,
7485
in_dtype: Union[DataType, Sequence[DataType]],
@@ -232,15 +243,28 @@ def assert_fill(
232243
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
233244
assert_dtype(func_name, out.dtype, expected.dtype)
234245
assert_shape(func_name, out.shape, expected.shape, **kw)
235-
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
246+
f_func = f"[{func_name}({fmt_kw(kw)})]"
236247
if dh.is_float_dtype(out.dtype):
237-
neg_zeros = expected == -0.0
238-
assert xp.all((out == -0.0) == neg_zeros), msg
239-
pos_zeros = expected == +0.0
240-
assert xp.all((out == +0.0) == pos_zeros), msg
241-
nans = xp.isnan(expected)
242-
assert xp.all(xp.isnan(out) == nans), msg
243-
mask = ~(neg_zeros | pos_zeros | nans)
244-
assert xp.all(out[mask] == expected[mask]), msg
248+
for idx in sh.ndindex(out.shape):
249+
at_out = out[idx]
250+
at_expected = expected[idx]
251+
msg = (
252+
f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} "
253+
f"{f_func}"
254+
)
255+
if xp.isnan(at_expected):
256+
assert xp.isnan(at_out), msg
257+
elif at_expected == 0.0 or at_expected == -0.0:
258+
scalar_at_expected = float(at_expected)
259+
scalar_at_out = float(at_out)
260+
if is_pos_zero(scalar_at_expected):
261+
assert is_pos_zero(scalar_at_out), msg
262+
else:
263+
assert is_neg_zero(scalar_at_expected) # sanity check
264+
assert is_neg_zero(scalar_at_out), msg
265+
else:
266+
assert at_out == at_expected, msg
245267
else:
246-
assert xp.all(out == expected), msg
268+
assert xp.all(out == expected), (
269+
f"out not as expected {f_func}\n" f"{out=}\n{expected=}"
270+
)

array_api_tests/test_creation_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def test_asarray_arrays(x, data):
280280
if copy:
281281
assert not xp.all(
282282
out == x
283-
), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
283+
), f"xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
284284
elif copy is False:
285285
pass # TODO
286286

array_api_tests/test_operators_and_elementwise_functions.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ def default_filter(s: Scalar) -> bool:
123123
124124
Used by default as these values are typically special-cased.
125125
"""
126-
return math.isfinite(s) and s is not -0.0 and s is not +0.0
126+
if isinstance(s, int): # note bools are ints
127+
return True
128+
else:
129+
return math.isfinite(s) and s != 0
127130

128131

129132
T = TypeVar("T")
@@ -538,7 +541,7 @@ def test_abs(ctx, data):
538541
abs, # type: ignore
539542
expr_template="abs({})={}",
540543
filter_=lambda s: (
541-
s == float("infinity") or (math.isfinite(s) and s is not -0.0)
544+
s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s))
542545
),
543546
)
544547

array_api_tests/test_searching_functions.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,10 @@
2121
data=st.data(),
2222
)
2323
def test_argmax(x, data):
24-
kw = data.draw(
25-
hh.kwargs(
26-
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
27-
keepdims=st.booleans(),
28-
),
29-
label="kw",
30-
)
24+
axis_strat = st.none()
25+
if x.ndim > 0:
26+
axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0))
27+
kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw")
3128

3229
out = xp.argmax(x, **kw)
3330

@@ -56,13 +53,10 @@ def test_argmax(x, data):
5653
data=st.data(),
5754
)
5855
def test_argmin(x, data):
59-
kw = data.draw(
60-
hh.kwargs(
61-
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
62-
keepdims=st.booleans(),
63-
),
64-
label="kw",
65-
)
56+
axis_strat = st.none()
57+
if x.ndim > 0:
58+
axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0))
59+
kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw")
6660

6761
out = xp.argmin(x, **kw)
6862

@@ -82,7 +76,7 @@ def test_argmin(x, data):
8276
ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected)
8377

8478

85-
# TODO: skip if opted out
79+
@pytest.mark.data_dependent_shapes
8680
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
8781
def test_nonzero(x):
8882
out = xp.nonzero(x)

array_api_tests/test_set_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import shape_helpers as sh
1313
from . import xps
1414

15-
pytestmark = pytest.mark.ci
15+
pytestmark = [pytest.mark.ci, pytest.mark.data_dependent_shapes]
1616

1717

1818
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))

conftest.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ def pytest_addoption(parser):
3535
default=[],
3636
help="disable testing for Array API extension(s)",
3737
)
38+
# data-dependent shape
39+
parser.addoption(
40+
"--disable-data-dependent-shapes",
41+
"--disable-dds",
42+
action="store_true",
43+
help="disable testing functions with output shapes dependent on input",
44+
)
3845
# CI
3946
parser.addoption(
4047
"--ci",
@@ -47,6 +54,9 @@ def pytest_configure(config):
4754
config.addinivalue_line(
4855
"markers", "xp_extension(ext): tests an Array API extension"
4956
)
57+
config.addinivalue_line(
58+
"markers", "data_dependent_shapes: output shapes are dependent on inputs"
59+
)
5060
config.addinivalue_line("markers", "ci: primary test")
5161
# Hypothesis
5262
hypothesis_max_examples = config.getoption("--hypothesis-max-examples")
@@ -83,9 +93,15 @@ def xp_has_ext(ext: str) -> bool:
8393

8494
def pytest_collection_modifyitems(config, items):
8595
disabled_exts = config.getoption("--disable-extension")
96+
disabled_dds = config.getoption("--disable-data-dependent-shapes")
8697
ci = config.getoption("--ci")
8798
for item in items:
8899
markers = list(item.iter_markers())
100+
# skip if specified in skips.txt
101+
for id_ in skip_ids:
102+
if item.nodeid.startswith(id_):
103+
item.add_marker(mark.skip(reason="skips.txt"))
104+
break
89105
# skip if disabled or non-existent extension
90106
ext_mark = next((m for m in markers if m.name == "xp_extension"), None)
91107
if ext_mark is not None:
@@ -96,11 +112,14 @@ def pytest_collection_modifyitems(config, items):
96112
)
97113
elif not xp_has_ext(ext):
98114
item.add_marker(mark.skip(reason=f"{ext} not found in array module"))
99-
# skip if specified in skips.txt
100-
for id_ in skip_ids:
101-
if item.nodeid.startswith(id_):
102-
item.add_marker(mark.skip(reason="skips.txt"))
103-
break
115+
# skip if disabled by dds flag
116+
if disabled_dds:
117+
for m in markers:
118+
if m.name == "data_dependent_shapes":
119+
item.add_marker(
120+
mark.skip(reason="disabled via --disable-data-dependent-shapes")
121+
)
122+
break
104123
# skip if test not appropiate for CI
105124
if ci:
106125
ci_mark = next((m for m in markers if m.name == "ci"), None)

0 commit comments

Comments
 (0)