Skip to content

Commit 835a9ca

Browse files
authored
Merge pull request #348 from ev-br/test_py_scalars
ENH: test binary functions with python scalars
2 parents 9a2f4cf + b63b89b commit 835a9ca

File tree

3 files changed

+204
-36
lines changed

3 files changed

+204
-36
lines changed

Diff for: array_api_tests/dtype_helpers.py

+1
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def get_scalar_type(dtype: DataType) -> ScalarType:
198198
def is_scalar(x):
199199
return isinstance(x, (int, float, complex, bool))
200200

201+
201202
def _make_dtype_mapping_from_names(mapping: Dict[str, Any]) -> EqualityMapping:
202203
dtype_value_pairs = []
203204
for name, value in mapping.items():

Diff for: array_api_tests/hypothesis_helpers.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from hypothesis import assume, reject
1111
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
12-
integers, just, lists, none, one_of,
12+
integers, complex_numbers, just, lists, none, one_of,
1313
sampled_from, shared, builds, nothing)
1414

1515
from . import _array_module as xp, api_version
@@ -19,7 +19,7 @@
1919
from . import xps
2020
from ._array_module import _UndefinedStub
2121
from ._array_module import bool as bool_dtype
22-
from ._array_module import broadcast_to, eye, float32, float64, full
22+
from ._array_module import broadcast_to, eye, float32, float64, full, complex64, complex128
2323
from .stubs import category_to_funcs
2424
from .pytest_helpers import nargs
2525
from .typing import Array, DataType, Scalar, Shape
@@ -462,6 +462,14 @@ def scalars(draw, dtypes, finite=False):
462462
if finite:
463463
return draw(floats(width=32, allow_nan=False, allow_infinity=False))
464464
return draw(floats(width=32))
465+
elif dtype == complex64:
466+
if finite:
467+
return draw(complex_numbers(width=32, allow_nan=False, allow_infinity=False))
468+
return draw(complex_numbers(width=32))
469+
elif dtype == complex128:
470+
if finite:
471+
return draw(complex_numbers(allow_nan=False, allow_infinity=False))
472+
return draw(complex_numbers())
465473
else:
466474
raise ValueError(f"Unrecognized dtype {dtype}")
467475

@@ -571,6 +579,20 @@ def two_mutual_arrays(
571579
)
572580
return arrays1, arrays2
573581

582+
583+
@composite
584+
def array_and_py_scalar(draw, dtypes):
585+
"""Draw a pair: (array, scalar) or (scalar, array)."""
586+
dtype = draw(sampled_from(dtypes))
587+
scalar_var = draw(scalars(just(dtype), finite=True))
588+
array_var = draw(arrays(dtype, shape=shapes(min_dims=1)))
589+
590+
if draw(booleans()):
591+
return scalar_var, array_var
592+
else:
593+
return array_var, scalar_var
594+
595+
574596
@composite
575597
def kwargs(draw, **kw):
576598
"""

Diff for: array_api_tests/test_operators_and_elementwise_functions.py

+179-34
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,40 @@ def binary_param_assert_against_refimpl(
690690
)
691691

692692

693+
def _convert_scalars_helper(x1, x2):
694+
"""Convert python scalar to arrays, record the shapes/dtypes of arrays.
695+
696+
For inputs being scalars or arrays, return the dtypes and shapes of array arguments,
697+
and all arguments converted to arrays.
698+
699+
dtypes are separate to help distinguishing between
700+
`py_scalar + f32_array -> f32_array` and `f64_array + f32_array -> f64_array`
701+
"""
702+
if dh.is_scalar(x1):
703+
in_dtypes = [x2.dtype]
704+
in_shapes = [x2.shape]
705+
x1a, x2a = xp.asarray(x1), x2
706+
elif dh.is_scalar(x2):
707+
in_dtypes = [x1.dtype]
708+
in_shapes = [x1.shape]
709+
x1a, x2a = x1, xp.asarray(x2)
710+
else:
711+
in_dtypes = [x1.dtype, x2.dtype]
712+
in_shapes = [x1.shape, x2.shape]
713+
x1a, x2a = x1, x2
714+
715+
return in_dtypes, in_shapes, (x1a, x2a)
716+
717+
718+
def _assert_correctness_binary(
719+
name, func, in_dtypes, in_shapes, in_arrs, out, expected_dtype=None, **kwargs
720+
):
721+
x1a, x2a = in_arrs
722+
ph.assert_dtype(name, in_dtype=in_dtypes, out_dtype=out.dtype, expected=expected_dtype)
723+
ph.assert_result_shape(name, in_shapes=in_shapes, out_shape=out.shape)
724+
binary_assert_against_refimpl(name, x1a, x2a, out, func, **kwargs)
725+
726+
693727
@pytest.mark.parametrize("ctx", make_unary_params("abs", dh.numeric_dtypes))
694728
@given(data=st.data())
695729
def test_abs(ctx, data):
@@ -789,10 +823,14 @@ def test_atan(x):
789823
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
790824
def test_atan2(x1, x2):
791825
out = xp.atan2(x1, x2)
792-
ph.assert_dtype("atan2", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
793-
ph.assert_result_shape("atan2", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
794-
refimpl = cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2
795-
binary_assert_against_refimpl("atan2", x1, x2, out, refimpl)
826+
_assert_correctness_binary(
827+
"atan",
828+
cmath.atan2 if x1.dtype in dh.complex_dtypes else math.atan2,
829+
in_dtypes=[x1.dtype, x2.dtype],
830+
in_shapes=[x1.shape, x2.shape],
831+
in_arrs=[x1, x2],
832+
out=out,
833+
)
796834

797835

798836
@given(hh.arrays(dtype=hh.all_floating_dtypes(), shape=hh.shapes()))
@@ -1258,10 +1296,14 @@ def test_greater_equal(ctx, data):
12581296
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
12591297
def test_hypot(x1, x2):
12601298
out = xp.hypot(x1, x2)
1261-
ph.assert_dtype("hypot", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1262-
ph.assert_result_shape("hypot", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1263-
binary_assert_against_refimpl("hypot", x1, x2, out, math.hypot)
1264-
1299+
_assert_correctness_binary(
1300+
"hypot",
1301+
math.hypot,
1302+
in_dtypes=[x1.dtype, x2.dtype],
1303+
in_shapes=[x1.shape, x2.shape],
1304+
in_arrs=[x1, x2],
1305+
out=out
1306+
)
12651307

12661308

12671309
@pytest.mark.min_version("2022.12")
@@ -1411,21 +1453,17 @@ def logaddexp_refimpl(l: float, r: float) -> float:
14111453
raise OverflowError
14121454

14131455

1456+
@pytest.mark.min_version("2023.12")
14141457
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14151458
def test_logaddexp(x1, x2):
14161459
out = xp.logaddexp(x1, x2)
1417-
ph.assert_dtype("logaddexp", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1418-
ph.assert_result_shape("logaddexp", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1419-
binary_assert_against_refimpl("logaddexp", x1, x2, out, logaddexp_refimpl)
1420-
1421-
1422-
@given(*hh.two_mutual_arrays([xp.bool]))
1423-
def test_logical_and(x1, x2):
1424-
out = xp.logical_and(x1, x2)
1425-
ph.assert_dtype("logical_and", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1426-
ph.assert_result_shape("logical_and", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1427-
binary_assert_against_refimpl(
1428-
"logical_and", x1, x2, out, operator.and_, expr_template="({} and {})={}"
1460+
_assert_correctness_binary(
1461+
"logaddexp",
1462+
logaddexp_refimpl,
1463+
in_dtypes=[x1.dtype, x2.dtype],
1464+
in_shapes=[x1.shape, x2.shape],
1465+
in_arrs=[x1, x2],
1466+
out=out
14291467
)
14301468

14311469

@@ -1439,42 +1477,64 @@ def test_logical_not(x):
14391477
)
14401478

14411479

1480+
@given(*hh.two_mutual_arrays([xp.bool]))
1481+
def test_logical_and(x1, x2):
1482+
out = xp.logical_and(x1, x2)
1483+
_assert_correctness_binary(
1484+
"logical_and",
1485+
operator.and_,
1486+
in_dtypes=[x1.dtype, x2.dtype],
1487+
in_shapes=[x1.shape, x2.shape],
1488+
in_arrs=[x1, x2],
1489+
out=out,
1490+
expr_template="({} and {})={}"
1491+
)
1492+
1493+
14421494
@given(*hh.two_mutual_arrays([xp.bool]))
14431495
def test_logical_or(x1, x2):
14441496
out = xp.logical_or(x1, x2)
1445-
ph.assert_dtype("logical_or", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1446-
ph.assert_result_shape("logical_or", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1447-
binary_assert_against_refimpl(
1448-
"logical_or", x1, x2, out, operator.or_, expr_template="({} or {})={}"
1497+
_assert_correctness_binary(
1498+
"logical_or",
1499+
operator.or_,
1500+
in_dtypes=[x1.dtype, x2.dtype],
1501+
in_shapes=[x1.shape, x2.shape],
1502+
in_arrs=[x1, x2],
1503+
out=out,
1504+
expr_template="({} or {})={}"
14491505
)
14501506

14511507

14521508
@given(*hh.two_mutual_arrays([xp.bool]))
14531509
def test_logical_xor(x1, x2):
14541510
out = xp.logical_xor(x1, x2)
1455-
ph.assert_dtype("logical_xor", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1456-
ph.assert_result_shape("logical_xor", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1457-
binary_assert_against_refimpl(
1458-
"logical_xor", x1, x2, out, operator.xor, expr_template="({} ^ {})={}"
1511+
_assert_correctness_binary(
1512+
"logical_xor",
1513+
operator.xor,
1514+
in_dtypes=[x1.dtype, x2.dtype],
1515+
in_shapes=[x1.shape, x2.shape],
1516+
in_arrs=[x1, x2],
1517+
out=out,
1518+
expr_template="({} ^ {})={}"
14591519
)
14601520

14611521

14621522
@pytest.mark.min_version("2023.12")
14631523
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14641524
def test_maximum(x1, x2):
14651525
out = xp.maximum(x1, x2)
1466-
ph.assert_dtype("maximum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1467-
ph.assert_result_shape("maximum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1468-
binary_assert_against_refimpl("maximum", x1, x2, out, max, strict_check=True)
1526+
_assert_correctness_binary(
1527+
"maximum", max, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
1528+
)
14691529

14701530

14711531
@pytest.mark.min_version("2023.12")
14721532
@given(*hh.two_mutual_arrays(dh.real_float_dtypes))
14731533
def test_minimum(x1, x2):
14741534
out = xp.minimum(x1, x2)
1475-
ph.assert_dtype("minimum", in_dtype=[x1.dtype, x2.dtype], out_dtype=out.dtype)
1476-
ph.assert_result_shape("minimum", in_shapes=[x1.shape, x2.shape], out_shape=out.shape)
1477-
binary_assert_against_refimpl("minimum", x1, x2, out, min, strict_check=True)
1535+
_assert_correctness_binary(
1536+
"minimum", min, [x1.dtype, x2.dtype], [x1.shape, x2.shape], (x1, x2), out, strict_check=True
1537+
)
14781538

14791539

14801540
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@@ -1719,3 +1779,88 @@ def test_trunc(x):
17191779
ph.assert_dtype("trunc", in_dtype=x.dtype, out_dtype=out.dtype)
17201780
ph.assert_shape("trunc", out_shape=out.shape, expected=x.shape)
17211781
unary_assert_against_refimpl("trunc", x, out, math.trunc, strict_check=True)
1782+
1783+
1784+
def _check_binary_with_scalars(func_data, x1x2):
1785+
x1, x2 = x1x2
1786+
func, name, refimpl, kwds, expected_dtype = func_data
1787+
out = func(x1, x2)
1788+
in_dtypes, in_shapes, (x1a, x2a) = _convert_scalars_helper(x1, x2)
1789+
_assert_correctness_binary(
1790+
name, refimpl, in_dtypes, in_shapes, (x1a, x2a), out, expected_dtype, **kwds
1791+
)
1792+
1793+
1794+
def _filter_zero(x):
1795+
return x != 0 if dh.is_scalar(x) else (not xp.any(x == 0))
1796+
1797+
1798+
@pytest.mark.min_version("2024.12")
1799+
@pytest.mark.parametrize('func_data',
1800+
# xp_func, name, refimpl, kwargs, expected_dtype
1801+
[
1802+
(xp.add, "add", operator.add, {}, None),
1803+
(xp.atan2, "atan2", math.atan2, {}, None),
1804+
(xp.copysign, "copysign", math.copysign, {}, None),
1805+
(xp.divide, "divide", operator.truediv, {"filter_": lambda s: s != 0}, None),
1806+
(xp.hypot, "hypot", math.hypot, {}, None),
1807+
(xp.logaddexp, "logaddexp", logaddexp_refimpl, {}, None),
1808+
(xp.maximum, "maximum", max, {'strict_check': True}, None),
1809+
(xp.minimum, "minimum", min, {'strict_check': True}, None),
1810+
(xp.multiply, "mul", operator.mul, {}, None),
1811+
(xp.subtract, "sub", operator.sub, {}, None),
1812+
1813+
(xp.equal, "equal", operator.eq, {}, xp.bool),
1814+
(xp.not_equal, "neq", operator.ne, {}, xp.bool),
1815+
(xp.less, "less", operator.lt, {}, xp.bool),
1816+
(xp.less_equal, "les_equal", operator.le, {}, xp.bool),
1817+
(xp.greater, "greater", operator.gt, {}, xp.bool),
1818+
(xp.greater_equal, "greater_equal", operator.ge, {}, xp.bool),
1819+
(xp.remainder, "remainder", operator.mod, {}, None),
1820+
(xp.floor_divide, "floor_divide", operator.floordiv, {}, None),
1821+
],
1822+
ids=lambda func_data: func_data[1] # use names for test IDs
1823+
)
1824+
@given(x1x2=hh.array_and_py_scalar(dh.real_float_dtypes))
1825+
def test_binary_with_scalars_real(func_data, x1x2):
1826+
1827+
if func_data[1] == "remainder":
1828+
assume(_filter_zero(x1x2[1]))
1829+
if func_data[1] == "floor_divide":
1830+
assume(_filter_zero(x1x2[0]) and _filter_zero(x1x2[1]))
1831+
1832+
_check_binary_with_scalars(func_data, x1x2)
1833+
1834+
1835+
@pytest.mark.min_version("2024.12")
1836+
@pytest.mark.parametrize('func_data',
1837+
# xp_func, name, refimpl, kwargs, expected_dtype
1838+
[
1839+
(xp.logical_and, "logical_and", operator.and_, {"expr_template": "({} or {})={}"}, None),
1840+
(xp.logical_or, "logical_or", operator.or_, {"expr_template": "({} or {})={}"}, None),
1841+
(xp.logical_xor, "logical_xor", operator.xor, {"expr_template": "({} or {})={}"}, None),
1842+
],
1843+
ids=lambda func_data: func_data[1] # use names for test IDs
1844+
)
1845+
@given(x1x2=hh.array_and_py_scalar([xp.bool]))
1846+
def test_binary_with_scalars_bool(func_data, x1x2):
1847+
_check_binary_with_scalars(func_data, x1x2)
1848+
1849+
1850+
@pytest.mark.min_version("2024.12")
1851+
@pytest.mark.parametrize('func_data',
1852+
# xp_func, name, refimpl, kwargs, expected_dtype
1853+
[
1854+
(xp.bitwise_and, "bitwise_and", operator.and_, {}, None),
1855+
(xp.bitwise_or, "bitwise_or", operator.or_, {}, None),
1856+
(xp.bitwise_xor, "bitwise_xor", operator.xor, {}, None),
1857+
],
1858+
ids=lambda func_data: func_data[1] # use names for test IDs
1859+
)
1860+
@given(x1x2=hh.array_and_py_scalar([xp.int32]))
1861+
def test_binary_with_scalars_bitwise(func_data, x1x2):
1862+
xp_func, name, refimpl, kwargs, expected = func_data
1863+
# repack the refimpl
1864+
refimpl_ = lambda l, r: mock_int_dtype(refimpl(l, r), xp.int32 )
1865+
_check_binary_with_scalars((xp_func, name, refimpl_, kwargs,expected), x1x2)
1866+

0 commit comments

Comments
 (0)