Skip to content

Fixes related to Dask #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ Use the `--ci` flag to run only the primary and special cases tests. You can
ignore the other test cases as they are redundant for the purposes of checking
compliance.

#### Data-dependent shapes

Use the `--disable-data-dependent-shapes` flag to skip testing functions which have
[data-dependent shapes](https://data-apis.org/array-api/latest/design_topics/data_dependent_output_shapes.html).

#### Extensions

By default, tests for the optional Array API extensions such as
Expand Down Expand Up @@ -200,16 +205,10 @@ instead of having a seperate `skips.txt` file, e.g.:
# Skip test cases with known issues
cat << EOF >> skips.txt

# Skip specific test case, e.g. when argsort() does not respect relative order
# https://github.com/numpy/numpy/issues/20778
# Comments can still work here
array_api_tests/test_sorting_functions.py::test_argsort

# Skip specific test case parameter, e.g. you forgot to implement in-place adds
array_api_tests/test_add[__iadd__(x1, x2)]
array_api_tests/test_add[__iadd__(x, s)]

# Skip module, e.g. when your set functions treat NaNs as non-distinct
# https://github.com/numpy/numpy/issues/20326
array_api_tests/test_set_functions.py

EOF
Expand Down
11 changes: 10 additions & 1 deletion array_api_tests/meta/test_pytest_helpers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pytest import raises

from .. import pytest_helpers as ph
from .. import _array_module as xp
from .. import pytest_helpers as ph


def test_assert_dtype():
Expand All @@ -11,3 +11,12 @@ def test_assert_dtype():
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)


def test_assert_array():
ph.assert_array("int zeros", xp.asarray(0), xp.asarray(0))
ph.assert_array("pos zeros", xp.asarray(0.0), xp.asarray(0.0))
with raises(AssertionError):
ph.assert_array("mixed sign zeros", xp.asarray(0.0), xp.asarray(-0.0))
with raises(AssertionError):
ph.assert_array("mixed sign zeros", xp.asarray(-0.0), xp.asarray(0.0))
44 changes: 34 additions & 10 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"doesnt_raise",
"nargs",
"fmt_kw",
"is_pos_zero",
"is_neg_zero",
"assert_dtype",
"assert_kw_dtype",
"assert_default_float",
Expand All @@ -22,6 +24,7 @@
"assert_shape",
"assert_result_shape",
"assert_keepdimable_shape",
"assert_0d_equals",
"assert_fill",
"assert_array",
]
Expand Down Expand Up @@ -69,6 +72,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str:
return ", ".join(f"{k}={v}" for k, v in kw.items())


def is_pos_zero(n: float) -> bool:
return n == 0 and math.copysign(1, n) == 1


def is_neg_zero(n: float) -> bool:
return n == 0 and math.copysign(1, n) == -1


def assert_dtype(
func_name: str,
in_dtype: Union[DataType, Sequence[DataType]],
Expand Down Expand Up @@ -232,15 +243,28 @@ def assert_fill(
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
assert_dtype(func_name, out.dtype, expected.dtype)
assert_shape(func_name, out.shape, expected.shape, **kw)
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
f_func = f"[{func_name}({fmt_kw(kw)})]"
if dh.is_float_dtype(out.dtype):
neg_zeros = expected == -0.0
assert xp.all((out == -0.0) == neg_zeros), msg
pos_zeros = expected == +0.0
assert xp.all((out == +0.0) == pos_zeros), msg
nans = xp.isnan(expected)
assert xp.all(xp.isnan(out) == nans), msg
mask = ~(neg_zeros | pos_zeros | nans)
assert xp.all(out[mask] == expected[mask]), msg
for idx in sh.ndindex(out.shape):
at_out = out[idx]
at_expected = expected[idx]
msg = (
f"{sh.fmt_idx('out', idx)}={at_out}, should be {at_expected} "
f"{f_func}"
)
if xp.isnan(at_expected):
assert xp.isnan(at_out), msg
elif at_expected == 0.0 or at_expected == -0.0:
scalar_at_expected = float(at_expected)
scalar_at_out = float(at_out)
if is_pos_zero(scalar_at_expected):
assert is_pos_zero(scalar_at_out), msg
else:
assert is_neg_zero(scalar_at_expected) # sanity check
assert is_neg_zero(scalar_at_out), msg
else:
assert at_out == at_expected, msg
else:
assert xp.all(out == expected), msg
assert xp.all(out == expected), (
f"out not as expected {f_func}\n" f"{out=}\n{expected=}"
)
2 changes: 1 addition & 1 deletion array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def test_asarray_arrays(x, data):
if copy:
assert not xp.all(
out == x
), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
), f"xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
elif copy is False:
pass # TODO

Expand Down
7 changes: 5 additions & 2 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def default_filter(s: Scalar) -> bool:

Used by default as these values are typically special-cased.
"""
return math.isfinite(s) and s is not -0.0 and s is not +0.0
if isinstance(s, int): # note bools are ints
return True
else:
return math.isfinite(s) and s != 0


T = TypeVar("T")
Expand Down Expand Up @@ -538,7 +541,7 @@ def test_abs(ctx, data):
abs, # type: ignore
expr_template="abs({})={}",
filter_=lambda s: (
s == float("infinity") or (math.isfinite(s) and s is not -0.0)
s == float("infinity") or (math.isfinite(s) and not ph.is_neg_zero(s))
),
)

Expand Down
24 changes: 9 additions & 15 deletions array_api_tests/test_searching_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,10 @@
data=st.data(),
)
def test_argmax(x, data):
kw = data.draw(
hh.kwargs(
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
keepdims=st.booleans(),
),
label="kw",
)
axis_strat = st.none()
if x.ndim > 0:
axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0))
kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw")

out = xp.argmax(x, **kw)

Expand Down Expand Up @@ -56,13 +53,10 @@ def test_argmax(x, data):
data=st.data(),
)
def test_argmin(x, data):
kw = data.draw(
hh.kwargs(
axis=st.none() | st.integers(-x.ndim, max(x.ndim - 1, 0)),
keepdims=st.booleans(),
),
label="kw",
)
axis_strat = st.none()
if x.ndim > 0:
axis_strat |= st.integers(-x.ndim, max(x.ndim - 1, 0))
kw = data.draw(hh.kwargs(axis=axis_strat, keepdims=st.booleans()), label="kw")

out = xp.argmin(x, **kw)

Expand All @@ -82,7 +76,7 @@ def test_argmin(x, data):
ph.assert_scalar_equals("argmin", int, out_idx, min_i, expected)


# TODO: skip if opted out
@pytest.mark.data_dependent_shapes
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
def test_nonzero(x):
out = xp.nonzero(x)
Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from . import shape_helpers as sh
from . import xps

pytestmark = pytest.mark.ci
pytestmark = [pytest.mark.ci, pytest.mark.data_dependent_shapes]


@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1)))
Expand Down
29 changes: 24 additions & 5 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ def pytest_addoption(parser):
default=[],
help="disable testing for Array API extension(s)",
)
# data-dependent shape
parser.addoption(
"--disable-data-dependent-shapes",
"--disable-dds",
action="store_true",
help="disable testing functions with output shapes dependent on input",
)
# CI
parser.addoption(
"--ci",
Expand All @@ -47,6 +54,9 @@ def pytest_configure(config):
config.addinivalue_line(
"markers", "xp_extension(ext): tests an Array API extension"
)
config.addinivalue_line(
"markers", "data_dependent_shapes: output shapes are dependent on inputs"
)
config.addinivalue_line("markers", "ci: primary test")
# Hypothesis
hypothesis_max_examples = config.getoption("--hypothesis-max-examples")
Expand Down Expand Up @@ -83,9 +93,15 @@ def xp_has_ext(ext: str) -> bool:

def pytest_collection_modifyitems(config, items):
disabled_exts = config.getoption("--disable-extension")
disabled_dds = config.getoption("--disable-data-dependent-shapes")
ci = config.getoption("--ci")
for item in items:
markers = list(item.iter_markers())
# skip if specified in skips.txt
for id_ in skip_ids:
if item.nodeid.startswith(id_):
item.add_marker(mark.skip(reason="skips.txt"))
break
# skip if disabled or non-existent extension
ext_mark = next((m for m in markers if m.name == "xp_extension"), None)
if ext_mark is not None:
Expand All @@ -96,11 +112,14 @@ def pytest_collection_modifyitems(config, items):
)
elif not xp_has_ext(ext):
item.add_marker(mark.skip(reason=f"{ext} not found in array module"))
# skip if specified in skips.txt
for id_ in skip_ids:
if item.nodeid.startswith(id_):
item.add_marker(mark.skip(reason="skips.txt"))
break
# skip if disabled by dds flag
if disabled_dds:
for m in markers:
if m.name == "data_dependent_shapes":
item.add_marker(
mark.skip(reason="disabled via --disable-data-dependent-shapes")
)
break
# skip if test not appropiate for CI
if ci:
ci_mark = next((m for m in markers if m.name == "ci"), None)
Expand Down