Skip to content

Increase coverage and other niceties for the operator/elementwise tests #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
f661a23
Updates to op/elwise tests
honno Jan 25, 2022
a590f8d
`sh.fmt_idx()` helper
honno Jan 25, 2022
d7e5e63
Better values testing in `test_not_equal`
honno Jan 26, 2022
1a54bd4
Better values testing for bitwise op/elwise tests
honno Jan 26, 2022
2f8492b
Context objects for unary/binary params
honno Jan 26, 2022
4623214
Apply `iter_indices()` logic to binary op/elwise tests
honno Jan 26, 2022
4b2c41e
Update `test_remainder`
honno Jan 27, 2022
1927c10
Move `broadcast_shapes()` to `shape_helpers.py`
honno Jan 27, 2022
bb836b7
Skip `sh.iter_indices()` generation for 0-sided shapes
honno Jan 27, 2022
f11a6d0
Values testing for `test_sign`
honno Jan 27, 2022
47424e8
Values testing for `test_add` and `test_subtract`
honno Jan 27, 2022
2077986
Rudimentary values testing refactor, updates to logical elwise tests
honno Jan 28, 2022
66a1fd4
Favour lists compared to tuples for `ph.assert_dtypes()`
honno Jan 28, 2022
b6d05da
Favour lists for `ph.assert_result_shape()`
honno Jan 28, 2022
af6d150
Remove `lru_cache` use in `sh.fmt_idx()`
honno Jan 28, 2022
799b4e6
Refactor parametrized unary tests
honno Jan 28, 2022
e2b69df
Op/elwise fixes and improvements
honno Jan 31, 2022
3dfd665
`binary_param_assert_against_refimpl()` to refactor elwise+op tests
honno Jan 31, 2022
a4a7e04
Refactor remaining parametrized elwise+op tests
honno Feb 1, 2022
4d849f1
Finish elwise TODOs
honno Feb 1, 2022
5a82a33
Fix typing issues with refimpl utils
honno Feb 1, 2022
7386615
Remove redundant `in_stype` arg in refimpl utils
honno Feb 1, 2022
80d2909
Skip when refimpl overflows
honno Feb 1, 2022
9521f6b
Values testing for remaining tests for elwise funcs starting with a
honno Feb 1, 2022
e50fc1a
Defaults for `expr_template` in refimpl utils
honno Feb 1, 2022
4a364a5
Refactor majority of elwise tests with refimpl utils
honno Feb 1, 2022
56aa06d
`strict_check` kwarg for refiml utils for testing integrals
honno Feb 1, 2022
dfda4f5
Pass but filter out-of-range values for trig function tests
honno Feb 1, 2022
9d1f4da
Extend note on refimpl utils
honno Feb 1, 2022
e72184e
Refactor remaining elwise/op tests
honno Feb 2, 2022
9edcfcc
Favour use of `operator` for `refimpl`
honno Feb 2, 2022
6e8cda6
Filter undefined dtypes in `hh.two_mutual_arrays()`
honno Feb 2, 2022
493f669
Generic type hint for `refimpl` args
honno Feb 2, 2022
d924ce4
Introduce `right_scalar_assert_against_refimpl()`
honno Feb 2, 2022
3c85cae
Note why you'd want to not strictly check int outputs
honno Feb 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions array_api_tests/array_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,3 @@ def same_sign(x, y):
def assert_same_sign(x, y):
assert all(same_sign(x, y)), "The input arrays do not have the same sign"

def int_to_dtype(x, n, signed):
"""
Convert the Python integer x into an n bit signed or unsigned number.
"""
mask = (1 << n) - 1
x &= mask
if signed:
highest_bit = 1 << (n-1)
if x & highest_bit:
x = -((~x & mask) + 1)
return x
3 changes: 1 addition & 2 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ._array_module import _UndefinedStub
from ._array_module import bool as bool_dtype
from ._array_module import broadcast_to, eye, float32, float64, full
from .algos import broadcast_shapes
from .function_stubs import elementwise_functions
from .pytest_helpers import nargs
from .typing import Array, DataType, Shape
Expand Down Expand Up @@ -243,7 +242,7 @@ def two_broadcastable_shapes(draw):
broadcast to shape1.
"""
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
assume(broadcast_shapes(shape1, shape2) == shape1)
assume(sh.broadcast_shapes(shape1, shape2) == shape1)
return (shape1, shape2)

sizes = integers(0, MAX_ARRAY_SIZE)
Expand Down
16 changes: 1 addition & 15 deletions array_api_tests/meta/test_array_helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from hypothesis import given, assume
from hypothesis.strategies import integers

from ..array_helpers import exactly_equal, notequal, int_to_dtype
from ..hypothesis_helpers import integer_dtypes
from ..dtype_helpers import dtype_nbits, dtype_signed
from .. import _array_module as xp
from ..array_helpers import exactly_equal, notequal

# TODO: These meta-tests currently only work with NumPy

Expand All @@ -22,12 +17,3 @@ def test_notequal():
res = xp.asarray([False, True, False, False, False, True, False, True])
assert xp.all(xp.equal(notequal(a, b), res))

@given(integers(), integer_dtypes)
def test_int_to_dtype(x, dtype):
n = dtype_nbits[dtype]
signed = dtype_signed[dtype]
try:
d = xp.asarray(x, dtype=dtype)
except OverflowError:
assume(False)
assert int_to_dtype(x, n, signed) == d
8 changes: 4 additions & 4 deletions array_api_tests/meta/test_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from ..algos import BroadcastError, _broadcast_shapes
from .. import shape_helpers as sh


@pytest.mark.parametrize(
Expand All @@ -19,7 +19,7 @@
],
)
def test_broadcast_shapes(shape1, shape2, expected):
assert _broadcast_shapes(shape1, shape2) == expected
assert sh._broadcast_shapes(shape1, shape2) == expected


@pytest.mark.parametrize(
Expand All @@ -31,5 +31,5 @@ def test_broadcast_shapes(shape1, shape2, expected):
],
)
def test_broadcast_shapes_fails_on_bad_shapes(shape1, shape2):
with pytest.raises(BroadcastError):
_broadcast_shapes(shape1, shape2)
with pytest.raises(sh.BroadcastError):
sh._broadcast_shapes(shape1, shape2)
4 changes: 2 additions & 2 deletions array_api_tests/meta/test_hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from .. import array_helpers as ah
from .. import dtype_helpers as dh
from .. import hypothesis_helpers as hh
from .. import shape_helpers as sh
from .. import xps
from .._array_module import _UndefinedStub
from ..algos import broadcast_shapes

UNDEFINED_DTYPES = any(isinstance(d, _UndefinedStub) for d in dh.all_dtypes)
pytestmark = [pytest.mark.skipif(UNDEFINED_DTYPES, reason="undefined dtypes")]
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_two_mutually_broadcastable_shapes(pair):
def test_two_broadcastable_shapes(pair):
for shape in pair:
assert valid_shape(shape)
assert broadcast_shapes(pair[0], pair[1]) == pair[0]
assert sh.broadcast_shapes(pair[0], pair[1]) == pair[0]


@given(*hh.two_mutual_arrays())
Expand Down
10 changes: 5 additions & 5 deletions array_api_tests/meta/test_pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@


def test_assert_dtype():
ph.assert_dtype("promoted_func", (xp.uint8, xp.int8), xp.int16)
ph.assert_dtype("promoted_func", [xp.uint8, xp.int8], xp.int16)
with raises(AssertionError):
ph.assert_dtype("bad_func", (xp.uint8, xp.int8), xp.float32)
ph.assert_dtype("bool_func", (xp.uint8, xp.int8), xp.bool, xp.bool)
ph.assert_dtype("single_promoted_func", (xp.uint8,), xp.uint8)
ph.assert_dtype("single_bool_func", (xp.uint8,), xp.bool, xp.bool)
ph.assert_dtype("bad_func", [xp.uint8, xp.int8], xp.float32)
ph.assert_dtype("bool_func", [xp.uint8, xp.int8], xp.bool, xp.bool)
ph.assert_dtype("single_promoted_func", [xp.uint8], xp.uint8)
ph.assert_dtype("single_bool_func", [xp.uint8], xp.bool, xp.bool)
33 changes: 33 additions & 0 deletions array_api_tests/meta/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import pytest
from hypothesis import given, reject
from hypothesis import strategies as st

from .. import _array_module as xp
from .. import xps
from .. import shape_helpers as sh
from ..test_creation_functions import frange
from ..test_manipulation_functions import roll_ndindex
from ..test_operators_and_elementwise_functions import mock_int_dtype
from ..test_signatures import extension_module


Expand Down Expand Up @@ -82,3 +87,31 @@ def test_axes_ndindex(shape, axes, expected):
)
def test_roll_ndindex(shape, shifts, axes, expected):
assert list(roll_ndindex(shape, shifts, axes)) == expected


@pytest.mark.parametrize(
"idx, expected",
[
((), "x"),
(42, "x[42]"),
((42,), "x[42]"),
(slice(None, 2), "x[:2]"),
(slice(2, None), "x[2:]"),
(slice(0, 2), "x[0:2]"),
(slice(0, 2, -1), "x[0:2:-1]"),
(slice(None, None, -1), "x[::-1]"),
(slice(None, None), "x[:]"),
(..., "x[...]"),
],
)
def test_fmt_idx(idx, expected):
assert sh.fmt_idx("x", idx) == expected


@given(x=st.integers(), dtype=xps.unsigned_integer_dtypes() | xps.integer_dtypes())
def test_int_to_dtype(x, dtype):
try:
d = xp.asarray(x, dtype=dtype)
except OverflowError:
reject()
assert mock_int_dtype(x, dtype) == d
15 changes: 7 additions & 8 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import math
from inspect import getfullargspec
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union

from . import _array_module as xp
from . import array_helpers as ah
from . import dtype_helpers as dh
from . import function_stubs
from .algos import broadcast_shapes
from . import shape_helpers as sh
from .typing import Array, DataType, Scalar, ScalarType, Shape

__all__ = [
Expand Down Expand Up @@ -71,15 +71,14 @@ def fmt_kw(kw: Dict[str, Any]) -> str:

def assert_dtype(
func_name: str,
in_dtypes: Union[DataType, Tuple[DataType, ...]],
in_dtype: Union[DataType, Sequence[DataType]],
out_dtype: DataType,
expected: Optional[DataType] = None,
*,
repr_name: str = "out.dtype",
):
if not isinstance(in_dtypes, tuple):
in_dtypes = (in_dtypes,)
f_in_dtypes = dh.fmt_types(in_dtypes)
in_dtypes = in_dtype if isinstance(in_dtype, Sequence) else [in_dtype]
f_in_dtypes = dh.fmt_types(tuple(in_dtypes))
f_out_dtype = dh.dtype_to_name[out_dtype]
if expected is None:
expected = dh.result_type(*in_dtypes)
Expand Down Expand Up @@ -150,7 +149,7 @@ def assert_shape(

def assert_result_shape(
func_name: str,
in_shapes: Tuple[Shape],
in_shapes: Sequence[Shape],
out_shape: Shape,
/,
expected: Optional[Shape] = None,
Expand All @@ -159,7 +158,7 @@ def assert_result_shape(
**kw,
):
if expected is None:
expected = broadcast_shapes(*in_shapes)
expected = sh.broadcast_shapes(*in_shapes)
f_in_shapes = " . ".join(str(s) for s in in_shapes)
f_sig = f" {f_in_shapes} "
if kw:
Expand Down
114 changes: 105 additions & 9 deletions array_api_tests/shape_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,67 @@
from itertools import product
from typing import Iterator, List, Optional, Tuple, Union

from .typing import Scalar, Shape
from ndindex import iter_indices as _iter_indices

__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]
from .typing import AtomicIndex, Index, Scalar, Shape

__all__ = [
"broadcast_shapes",
"normalise_axis",
"ndindex",
"axis_ndindex",
"axes_ndindex",
"reshape",
"fmt_idx",
]


class BroadcastError(ValueError):
"""Shapes do not broadcast with eachother"""


def _broadcast_shapes(shape1: Shape, shape2: Shape) -> Shape:
"""Broadcasts `shape1` and `shape2`"""
N1 = len(shape1)
N2 = len(shape2)
N = max(N1, N2)
shape = [None for _ in range(N)]
i = N - 1
while i >= 0:
n1 = N1 - N + i
if N1 - N + i >= 0:
d1 = shape1[n1]
else:
d1 = 1
n2 = N2 - N + i
if N2 - N + i >= 0:
d2 = shape2[n2]
else:
d2 = 1

if d1 == 1:
shape[i] = d2
elif d2 == 1:
shape[i] = d1
elif d1 == d2:
shape[i] = d1
else:
raise BroadcastError()

i = i - 1

return tuple(shape)


def broadcast_shapes(*shapes: Shape):
if len(shapes) == 0:
raise ValueError("shapes=[] must be non-empty")
elif len(shapes) == 1:
return shapes[0]
result = _broadcast_shapes(shapes[0], shapes[1])
for i in range(2, len(shapes)):
result = _broadcast_shapes(result, shapes[i])
return result


def normalise_axis(
Expand All @@ -17,13 +75,21 @@ def normalise_axis(
return axes


def ndindex(shape):
"""Iterator of n-D indices to an array
def ndindex(shape: Shape) -> Iterator[Index]:
"""Yield every index of a shape"""
return (indices[0] for indices in iter_indices(shape))


Yields tuples of integers to index every element of an array of shape
`shape`. Same as np.ndindex().
"""
return product(*[range(i) for i in shape])
def iter_indices(
*shapes: Shape, skip_axes: Tuple[int, ...] = ()
) -> Iterator[Tuple[Index, ...]]:
"""Wrapper for ndindex.iter_indices()"""
# Prevent iterations if any shape has 0-sides
for shape in shapes:
if 0 in shape:
return
for indices in _iter_indices(*shapes, skip_axes=skip_axes):
yield tuple(i.raw for i in indices) # type: ignore


def axis_ndindex(
Expand Down Expand Up @@ -60,7 +126,7 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
yield list(indices)


def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List]:
"""Reshape a flat sequence"""
if any(s == 0 for s in shape):
raise ValueError(
Expand All @@ -75,3 +141,33 @@ def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]
size = len(flat_seq)
n = math.prod(shape[1:])
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]


def fmt_i(i: AtomicIndex) -> str:
if isinstance(i, int):
return str(i)
elif isinstance(i, slice):
res = ""
if i.start is not None:
res += str(i.start)
res += ":"
if i.stop is not None:
res += str(i.stop)
if i.step is not None:
res += f":{i.step}"
return res
else:
return "..."


def fmt_idx(sym: str, idx: Index) -> str:
if idx == ():
return sym
res = f"{sym}["
_idx = idx if isinstance(idx, tuple) else (idx,)
if len(_idx) == 1:
res += fmt_i(_idx[0])
else:
res += ", ".join(fmt_i(i) for i in _idx)
res += "]"
return res
Loading