Skip to content

ENH: Array API 2024.12 binary ops vs. Python scalars #131

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"array-api": ("https://data-apis.org/array-api/draft", None),
"jax": ("https://jax.readthedocs.io/en/latest", None),
}

Expand Down
4 changes: 2 additions & 2 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def isclose(

Parameters
----------
a, b : Array
Input arrays to compare.
a, b : Array | int | float | complex | bool
Input objects to compare. At least one must be an array.
rtol : array_like, optional
The relative tolerance parameter (see Notes).
atol : array_like, optional
Expand Down
11 changes: 7 additions & 4 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ._at import at
from ._utils import _compat, _helpers
from ._utils._compat import array_namespace, is_jax_array
from ._utils._helpers import asarrays
from ._utils._typing import Array

__all__ = [
Expand Down Expand Up @@ -315,6 +316,7 @@ def isclose(
xp: ModuleType,
) -> Array: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
a, b = asarrays(a, b, xp=xp)

a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
Expand Down Expand Up @@ -356,8 +358,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:

Parameters
----------
a, b : array
Input arrays.
a, b : Array | int | float | complex
Input arrays or scalars. At least one must be an array.
xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Expand Down Expand Up @@ -420,10 +422,10 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
if xp is None:
xp = array_namespace(a, b)
a, b = asarrays(a, b, xp=xp)

b = xp.asarray(b)
singletons = (1,) * (b.ndim - a.ndim)
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)
a = xp.broadcast_to(a, singletons + a.shape)

nd_b, nd_a = b.ndim, a.ndim
nd_max = max(nd_b, nd_a)
Expand Down Expand Up @@ -583,6 +585,7 @@ def setdiff1d(
"""
if xp is None:
xp = array_namespace(x1, x2)
x1, x2 = asarrays(x1, x2, xp=xp)

if assume_unique:
x1 = xp.reshape(x1, (-1,))
Expand Down
84 changes: 84 additions & 0 deletions src/array_api_extra/_lib/_utils/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from __future__ import annotations

from types import ModuleType
from typing import cast

from . import _compat
from ._compat import is_array_api_obj, is_numpy_array
from ._typing import Array

__all__ = ["in1d", "mean"]
Expand Down Expand Up @@ -91,3 +93,85 @@ def mean(
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
return mean_real + (mean_imag * xp.asarray(1j))
return xp.mean(x, axis=axis, keepdims=keepdims)


def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
"""Return True if `x` is a Python scalar, False otherwise."""
# isinstance(x, float) returns True for np.float64
# isinstance(x, complex) returns True for np.complex128
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)


def asarrays(
a: Array | int | float | complex | bool,
b: Array | int | float | complex | bool,
xp: ModuleType,
) -> tuple[Array, Array]:
"""
Ensure both `a` and `b` are arrays.

If `b` is a python scalar, it is converted to the same dtype as `a`, and vice versa.

Behavior is not specified when mixing a Python ``float`` and an array with an
integer data type; this may give ``float32``, ``float64``, or raise an exception.
Behavior is implementation-specific.

Similarly, behavior is not specified when mixing a Python ``complex`` and an array
with a real-valued data type; this may give ``complex64``, ``complex128``, or raise
an exception. Behavior is implementation-specific.

Parameters
----------
a, b : Array | int | float | complex | bool
Input arrays or scalars. At least one must be an array.
xp : ModuleType
The standard-compatible namespace for the returned arrays.

Returns
-------
Array, Array
The input arrays, possibly converted to arrays if they were scalars.

See Also
--------
mixing-arrays-with-python-scalars : Array API specification for the behavior.
"""
a_scalar = is_python_scalar(a)
b_scalar = is_python_scalar(b)
if not a_scalar and not b_scalar:
return a, b # This includes misc. malformed input e.g. str

swap = False
if a_scalar:
swap = True
b, a = a, b

if is_array_api_obj(a):
# a is an Array API object
# b is a int | float | complex | bool

# pyright doesn't like it if you reuse the same variable name
xa = cast(Array, a)

# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
same_dtype = {
bool: "bool",
int: ("integral", "real floating", "complex floating"),
float: ("real floating", "complex floating"),
complex: "complex floating",
}
kind = same_dtype[type(b)] # type: ignore[index]
if xp.isdtype(xa.dtype, kind):
xb = xp.asarray(b, dtype=xa.dtype)
else:
# Undefined behaviour. Let the function deal with it, if it can.
xb = xp.asarray(b)

else:
# Neither a nor b are Array API objects.
# Note: we can only reach this point when one explicitly passes
# xp=xp to the calling function; otherwise we fail earlier on
# array_namespace(a, b).
xa, xb = xp.asarray(a), xp.asarray(b)

return (xb, xa) if swap else (xa, xb)
68 changes: 56 additions & 12 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,24 @@ def test_none_shape_bool(self, xp: ModuleType):
a = a[a]
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
def test_python_scalar(self, xp: ModuleType):
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))
xp_assert_equal(isclose(0.0, a), xp.asarray([True, False]))

a = xp.asarray([0, 1], dtype=xp.int16)
xp_assert_equal(isclose(a, 0), xp.asarray([True, False]))
xp_assert_equal(isclose(0, a), xp.asarray([True, False]))

xp_assert_equal(isclose(0, 0, xp=xp), xp.asarray(True))
xp_assert_equal(isclose(0, 1, xp=xp), xp.asarray(False))

def test_all_python_scalars(self):
with pytest.raises(TypeError, match="Unrecognized"):
isclose(0, 0)

def test_xp(self, xp: ModuleType):
a = xp.asarray([0.0, 0.0])
b = xp.asarray([1e-9, 1e-4])
Expand All @@ -413,30 +431,22 @@ def test_basic(self, xp: ModuleType):
# Using 0-dimensional array
a = xp.asarray(1)
b = xp.asarray([[1, 2], [3, 4]])
k = xp.asarray([[1, 2], [3, 4]])
xp_assert_equal(kron(a, b), k)
a = xp.asarray([[1, 2], [3, 4]])
b = xp.asarray(1)
xp_assert_equal(kron(a, b), k)
xp_assert_equal(kron(a, b), b)
xp_assert_equal(kron(b, a), b)

# Using 1-dimensional array
a = xp.asarray([3])
b = xp.asarray([[1, 2], [3, 4]])
k = xp.asarray([[3, 6], [9, 12]])
xp_assert_equal(kron(a, b), k)
a = xp.asarray([[1, 2], [3, 4]])
b = xp.asarray([3])
xp_assert_equal(kron(a, b), k)
xp_assert_equal(kron(b, a), k)

# Using 3-dimensional array
a = xp.asarray([[[1]], [[2]]])
b = xp.asarray([[1, 2], [3, 4]])
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
xp_assert_equal(kron(a, b), k)
a = xp.asarray([[1, 2], [3, 4]])
b = xp.asarray([[[1]], [[2]]])
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
xp_assert_equal(kron(a, b), k)
xp_assert_equal(kron(b, a), k)

def test_kron_smoke(self, xp: ModuleType):
a = xp.ones((3, 3))
Expand Down Expand Up @@ -474,6 +484,18 @@ def test_kron_shape(
k = kron(a, b)
assert k.shape == expected_shape

def test_python_scalar(self, xp: ModuleType):
a = 1
# Test no dtype promotion to xp.asarray(a); use b.dtype
b = xp.asarray([[1, 2], [3, 4]], dtype=xp.int16)
xp_assert_equal(kron(a, b), b)
xp_assert_equal(kron(b, a), b)
xp_assert_equal(kron(1, 1, xp=xp), xp.asarray(1))

def test_all_python_scalars(self):
with pytest.raises(TypeError, match="Unrecognized"):
kron(1, 1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does NOT fail if xp=jax.numpy, because xp_lazy_function converts everything to jax.


def test_device(self, xp: ModuleType, device: Device):
x1 = xp.asarray([1, 2, 3], device=device)
x2 = xp.asarray([4, 5], device=device)
Expand Down Expand Up @@ -601,6 +623,28 @@ def test_shapes(
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.empty((0,)))

@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
@pytest.mark.parametrize("assume_unique", [True, False])
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
x2 = 3
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))

actual = setdiff1d(x2, x1, assume_unique=assume_unique)
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))

xp_assert_equal(
setdiff1d(0, 0, assume_unique=assume_unique, xp=xp),
xp.asarray([0])[:0], # Default int dtype for backend
)

@pytest.mark.parametrize("assume_unique", [True, False])
def test_all_python_scalars(self, assume_unique: bool):
with pytest.raises(TypeError, match="Unrecognized"):
setdiff1d(0, 0, assume_unique=assume_unique)

def test_device(self, xp: ModuleType, device: Device):
x1 = xp.asarray([3, 8, 20], device=device)
x2 = xp.asarray([2, 3, 4], device=device)
Expand Down
97 changes: 96 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from types import ModuleType

import numpy as np
import pytest

from array_api_extra._lib import Backend
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import device as get_device
from array_api_extra._lib._utils._helpers import in1d
from array_api_extra._lib._utils._helpers import asarrays, in1d
from array_api_extra._lib._utils._typing import Device
from array_api_extra.testing import lazy_xp_function

Expand Down Expand Up @@ -45,3 +46,97 @@ def test_xp(self, xp: ModuleType):
expected = xp.asarray([True, False])
actual = in1d(x1, x2, xp=xp)
xp_assert_equal(actual, expected)


@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
@pytest.mark.parametrize(
("dtype", "b", "defined"),
[
# Well-defined cases of dtype promotion from Python scalar to Array
# bool vs. bool
("bool", True, True),
# int vs. xp.*int*, xp.float*, xp.complex*
("int16", 1, True),
("uint8", 1, True),
("float32", 1, True),
("float64", 1, True),
("complex64", 1, True),
("complex128", 1, True),
# float vs. xp.float, xp.complex
("float32", 1.0, True),
("float64", 1.0, True),
("complex64", 1.0, True),
("complex128", 1.0, True),
# complex vs. xp.complex
("complex64", 1.0j, True),
("complex128", 1.0j, True),
# Undefined cases
("bool", 1, False),
("int64", 1.0, False),
("float64", 1.0j, False),
],
)
def test_asarrays_array_vs_scalar(
dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
):
a = xp.asarray(1, dtype=getattr(xp, dtype))

xa, xb = asarrays(a, b, xp)
assert xa.dtype == a.dtype
if defined:
assert xb.dtype == a.dtype
else:
assert xb.dtype == xp.asarray(b).dtype

xbr, xar = asarrays(b, a, xp)
assert xar.dtype == xa.dtype
assert xbr.dtype == xb.dtype


def test_asarrays_scalar_vs_scalar(xp: ModuleType):
a, b = asarrays(1, 2.2, xp=xp)
assert a.dtype == xp.asarray(1).dtype # Default dtype
assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted


ALL_TYPES = (
"int8",
"int16",
"int32",
"int64",
"uint8",
"uint16",
"uint32",
"uint64",
"float32",
"float64",
"complex64",
"complex128",
"bool",
)


@pytest.mark.parametrize("a_type", ALL_TYPES)
@pytest.mark.parametrize("b_type", ALL_TYPES)
def test_asarrays_array_vs_array(a_type: str, b_type: str, xp: ModuleType):
"""
Test that when both inputs of asarray are already Array API objects,
they are returned unchanged.
"""
a = xp.asarray(1, dtype=getattr(xp, a_type))
b = xp.asarray(1, dtype=getattr(xp, b_type))
xa, xb = asarrays(a, b, xp)
assert xa.dtype == a.dtype
assert xb.dtype == b.dtype


@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
def test_asarrays_numpy_generics(dtype: type):
"""
Test special case of np.float64 and np.complex128,
which are subclasses of float and complex.
"""
a = dtype(0)
xa, xb = asarrays(a, 0, xp=np)
assert xa.dtype == dtype
assert xb.dtype == dtype