Skip to content

Commit e38480d

Browse files
crusaderkylucascolley
authored andcommitted
ENH: Array API 2024.12 binary ops vs. Python scalars (#131)
Co-authored-by: Lucas Colley <[email protected]>
1 parent 84e6430 commit e38480d

File tree

6 files changed

+246
-19
lines changed

6 files changed

+246
-19
lines changed

Diff for: docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
56+
"array-api": ("https://data-apis.org/array-api/draft", None),
5657
"jax": ("https://jax.readthedocs.io/en/latest", None),
5758
}
5859

Diff for: src/array_api_extra/_delegation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def isclose(
5252
5353
Parameters
5454
----------
55-
a, b : Array
56-
Input arrays to compare.
55+
a, b : Array | int | float | complex | bool
56+
Input objects to compare. At least one must be an array.
5757
rtol : array_like, optional
5858
The relative tolerance parameter (see Notes).
5959
atol : array_like, optional

Diff for: src/array_api_extra/_lib/_funcs.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ._at import at
1313
from ._utils import _compat, _helpers
1414
from ._utils._compat import array_namespace, is_jax_array
15+
from ._utils._helpers import asarrays
1516
from ._utils._typing import Array
1617

1718
__all__ = [
@@ -315,6 +316,7 @@ def isclose(
315316
xp: ModuleType,
316317
) -> Array: # numpydoc ignore=PR01,RT01
317318
"""See docstring in array_api_extra._delegation."""
319+
a, b = asarrays(a, b, xp=xp)
318320

319321
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
320322
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
@@ -356,8 +358,8 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
356358
357359
Parameters
358360
----------
359-
a, b : array
360-
Input arrays.
361+
a, b : Array | int | float | complex
362+
Input arrays or scalars. At least one must be an array.
361363
xp : array_namespace, optional
362364
The standard-compatible namespace for `a` and `b`. Default: infer.
363365
@@ -420,10 +422,10 @@ def kron(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
420422
"""
421423
if xp is None:
422424
xp = array_namespace(a, b)
425+
a, b = asarrays(a, b, xp=xp)
423426

424-
b = xp.asarray(b)
425427
singletons = (1,) * (b.ndim - a.ndim)
426-
a = xp.broadcast_to(xp.asarray(a), singletons + a.shape)
428+
a = xp.broadcast_to(a, singletons + a.shape)
427429

428430
nd_b, nd_a = b.ndim, a.ndim
429431
nd_max = max(nd_b, nd_a)
@@ -583,6 +585,7 @@ def setdiff1d(
583585
"""
584586
if xp is None:
585587
xp = array_namespace(x1, x2)
588+
x1, x2 = asarrays(x1, x2, xp=xp)
586589

587590
if assume_unique:
588591
x1 = xp.reshape(x1, (-1,))

Diff for: src/array_api_extra/_lib/_utils/_helpers.py

+84
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from __future__ import annotations
55

66
from types import ModuleType
7+
from typing import cast
78

89
from . import _compat
10+
from ._compat import is_array_api_obj, is_numpy_array
911
from ._typing import Array
1012

1113
__all__ = ["in1d", "mean"]
@@ -91,3 +93,85 @@ def mean(
9193
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
9294
return mean_real + (mean_imag * xp.asarray(1j))
9395
return xp.mean(x, axis=axis, keepdims=keepdims)
96+
97+
98+
def is_python_scalar(x: object) -> bool: # numpydoc ignore=PR01,RT01
99+
"""Return True if `x` is a Python scalar, False otherwise."""
100+
# isinstance(x, float) returns True for np.float64
101+
# isinstance(x, complex) returns True for np.complex128
102+
return isinstance(x, int | float | complex | bool) and not is_numpy_array(x)
103+
104+
105+
def asarrays(
106+
a: Array | int | float | complex | bool,
107+
b: Array | int | float | complex | bool,
108+
xp: ModuleType,
109+
) -> tuple[Array, Array]:
110+
"""
111+
Ensure both `a` and `b` are arrays.
112+
113+
If `b` is a python scalar, it is converted to the same dtype as `a`, and vice versa.
114+
115+
Behavior is not specified when mixing a Python ``float`` and an array with an
116+
integer data type; this may give ``float32``, ``float64``, or raise an exception.
117+
Behavior is implementation-specific.
118+
119+
Similarly, behavior is not specified when mixing a Python ``complex`` and an array
120+
with a real-valued data type; this may give ``complex64``, ``complex128``, or raise
121+
an exception. Behavior is implementation-specific.
122+
123+
Parameters
124+
----------
125+
a, b : Array | int | float | complex | bool
126+
Input arrays or scalars. At least one must be an array.
127+
xp : ModuleType
128+
The standard-compatible namespace for the returned arrays.
129+
130+
Returns
131+
-------
132+
Array, Array
133+
The input arrays, possibly converted to arrays if they were scalars.
134+
135+
See Also
136+
--------
137+
mixing-arrays-with-python-scalars : Array API specification for the behavior.
138+
"""
139+
a_scalar = is_python_scalar(a)
140+
b_scalar = is_python_scalar(b)
141+
if not a_scalar and not b_scalar:
142+
return a, b # This includes misc. malformed input e.g. str
143+
144+
swap = False
145+
if a_scalar:
146+
swap = True
147+
b, a = a, b
148+
149+
if is_array_api_obj(a):
150+
# a is an Array API object
151+
# b is a int | float | complex | bool
152+
153+
# pyright doesn't like it if you reuse the same variable name
154+
xa = cast(Array, a)
155+
156+
# https://data-apis.org/array-api/draft/API_specification/type_promotion.html#mixing-arrays-with-python-scalars
157+
same_dtype = {
158+
bool: "bool",
159+
int: ("integral", "real floating", "complex floating"),
160+
float: ("real floating", "complex floating"),
161+
complex: "complex floating",
162+
}
163+
kind = same_dtype[type(b)] # type: ignore[index]
164+
if xp.isdtype(xa.dtype, kind):
165+
xb = xp.asarray(b, dtype=xa.dtype)
166+
else:
167+
# Undefined behaviour. Let the function deal with it, if it can.
168+
xb = xp.asarray(b)
169+
170+
else:
171+
# Neither a nor b are Array API objects.
172+
# Note: we can only reach this point when one explicitly passes
173+
# xp=xp to the calling function; otherwise we fail earlier on
174+
# array_namespace(a, b).
175+
xa, xb = xp.asarray(a), xp.asarray(b)
176+
177+
return (xb, xa) if swap else (xa, xb)

Diff for: tests/test_funcs.py

+56-12
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,24 @@ def test_none_shape_bool(self, xp: ModuleType):
401401
a = a[a]
402402
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
403403

404+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
405+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="Array API 2024.12 support")
406+
def test_python_scalar(self, xp: ModuleType):
407+
a = xp.asarray([0.0, 0.1], dtype=xp.float32)
408+
xp_assert_equal(isclose(a, 0.0), xp.asarray([True, False]))
409+
xp_assert_equal(isclose(0.0, a), xp.asarray([True, False]))
410+
411+
a = xp.asarray([0, 1], dtype=xp.int16)
412+
xp_assert_equal(isclose(a, 0), xp.asarray([True, False]))
413+
xp_assert_equal(isclose(0, a), xp.asarray([True, False]))
414+
415+
xp_assert_equal(isclose(0, 0, xp=xp), xp.asarray(True))
416+
xp_assert_equal(isclose(0, 1, xp=xp), xp.asarray(False))
417+
418+
def test_all_python_scalars(self):
419+
with pytest.raises(TypeError, match="Unrecognized"):
420+
isclose(0, 0)
421+
404422
def test_xp(self, xp: ModuleType):
405423
a = xp.asarray([0.0, 0.0])
406424
b = xp.asarray([1e-9, 1e-4])
@@ -413,30 +431,22 @@ def test_basic(self, xp: ModuleType):
413431
# Using 0-dimensional array
414432
a = xp.asarray(1)
415433
b = xp.asarray([[1, 2], [3, 4]])
416-
k = xp.asarray([[1, 2], [3, 4]])
417-
xp_assert_equal(kron(a, b), k)
418-
a = xp.asarray([[1, 2], [3, 4]])
419-
b = xp.asarray(1)
420-
xp_assert_equal(kron(a, b), k)
434+
xp_assert_equal(kron(a, b), b)
435+
xp_assert_equal(kron(b, a), b)
421436

422437
# Using 1-dimensional array
423438
a = xp.asarray([3])
424439
b = xp.asarray([[1, 2], [3, 4]])
425440
k = xp.asarray([[3, 6], [9, 12]])
426441
xp_assert_equal(kron(a, b), k)
427-
a = xp.asarray([[1, 2], [3, 4]])
428-
b = xp.asarray([3])
429-
xp_assert_equal(kron(a, b), k)
442+
xp_assert_equal(kron(b, a), k)
430443

431444
# Using 3-dimensional array
432445
a = xp.asarray([[[1]], [[2]]])
433446
b = xp.asarray([[1, 2], [3, 4]])
434447
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
435448
xp_assert_equal(kron(a, b), k)
436-
a = xp.asarray([[1, 2], [3, 4]])
437-
b = xp.asarray([[[1]], [[2]]])
438-
k = xp.asarray([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])
439-
xp_assert_equal(kron(a, b), k)
449+
xp_assert_equal(kron(b, a), k)
440450

441451
def test_kron_smoke(self, xp: ModuleType):
442452
a = xp.ones((3, 3))
@@ -474,6 +484,18 @@ def test_kron_shape(
474484
k = kron(a, b)
475485
assert k.shape == expected_shape
476486

487+
def test_python_scalar(self, xp: ModuleType):
488+
a = 1
489+
# Test no dtype promotion to xp.asarray(a); use b.dtype
490+
b = xp.asarray([[1, 2], [3, 4]], dtype=xp.int16)
491+
xp_assert_equal(kron(a, b), b)
492+
xp_assert_equal(kron(b, a), b)
493+
xp_assert_equal(kron(1, 1, xp=xp), xp.asarray(1))
494+
495+
def test_all_python_scalars(self):
496+
with pytest.raises(TypeError, match="Unrecognized"):
497+
kron(1, 1)
498+
477499
def test_device(self, xp: ModuleType, device: Device):
478500
x1 = xp.asarray([1, 2, 3], device=device)
479501
x2 = xp.asarray([4, 5], device=device)
@@ -601,6 +623,28 @@ def test_shapes(
601623
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
602624
xp_assert_equal(actual, xp.empty((0,)))
603625

626+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
627+
@pytest.mark.parametrize("assume_unique", [True, False])
628+
def test_python_scalar(self, xp: ModuleType, assume_unique: bool):
629+
# Test no dtype promotion to xp.asarray(x2); use x1.dtype
630+
x1 = xp.asarray([3, 1, 2], dtype=xp.int16)
631+
x2 = 3
632+
actual = setdiff1d(x1, x2, assume_unique=assume_unique)
633+
xp_assert_equal(actual, xp.asarray([1, 2], dtype=xp.int16))
634+
635+
actual = setdiff1d(x2, x1, assume_unique=assume_unique)
636+
xp_assert_equal(actual, xp.asarray([], dtype=xp.int16))
637+
638+
xp_assert_equal(
639+
setdiff1d(0, 0, assume_unique=assume_unique, xp=xp),
640+
xp.asarray([0])[:0], # Default int dtype for backend
641+
)
642+
643+
@pytest.mark.parametrize("assume_unique", [True, False])
644+
def test_all_python_scalars(self, assume_unique: bool):
645+
with pytest.raises(TypeError, match="Unrecognized"):
646+
setdiff1d(0, 0, assume_unique=assume_unique)
647+
604648
def test_device(self, xp: ModuleType, device: Device):
605649
x1 = xp.asarray([3, 8, 20], device=device)
606650
x2 = xp.asarray([2, 3, 4], device=device)

Diff for: tests/test_utils.py

+96-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from types import ModuleType
22

3+
import numpy as np
34
import pytest
45

56
from array_api_extra._lib import Backend
67
from array_api_extra._lib._testing import xp_assert_equal
78
from array_api_extra._lib._utils._compat import device as get_device
8-
from array_api_extra._lib._utils._helpers import in1d
9+
from array_api_extra._lib._utils._helpers import asarrays, in1d
910
from array_api_extra._lib._utils._typing import Device
1011
from array_api_extra.testing import lazy_xp_function
1112

@@ -45,3 +46,97 @@ def test_xp(self, xp: ModuleType):
4546
expected = xp.asarray([True, False])
4647
actual = in1d(x1, x2, xp=xp)
4748
xp_assert_equal(actual, expected)
49+
50+
51+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no isdtype")
52+
@pytest.mark.parametrize(
53+
("dtype", "b", "defined"),
54+
[
55+
# Well-defined cases of dtype promotion from Python scalar to Array
56+
# bool vs. bool
57+
("bool", True, True),
58+
# int vs. xp.*int*, xp.float*, xp.complex*
59+
("int16", 1, True),
60+
("uint8", 1, True),
61+
("float32", 1, True),
62+
("float64", 1, True),
63+
("complex64", 1, True),
64+
("complex128", 1, True),
65+
# float vs. xp.float, xp.complex
66+
("float32", 1.0, True),
67+
("float64", 1.0, True),
68+
("complex64", 1.0, True),
69+
("complex128", 1.0, True),
70+
# complex vs. xp.complex
71+
("complex64", 1.0j, True),
72+
("complex128", 1.0j, True),
73+
# Undefined cases
74+
("bool", 1, False),
75+
("int64", 1.0, False),
76+
("float64", 1.0j, False),
77+
],
78+
)
79+
def test_asarrays_array_vs_scalar(
80+
dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
81+
):
82+
a = xp.asarray(1, dtype=getattr(xp, dtype))
83+
84+
xa, xb = asarrays(a, b, xp)
85+
assert xa.dtype == a.dtype
86+
if defined:
87+
assert xb.dtype == a.dtype
88+
else:
89+
assert xb.dtype == xp.asarray(b).dtype
90+
91+
xbr, xar = asarrays(b, a, xp)
92+
assert xar.dtype == xa.dtype
93+
assert xbr.dtype == xb.dtype
94+
95+
96+
def test_asarrays_scalar_vs_scalar(xp: ModuleType):
97+
a, b = asarrays(1, 2.2, xp=xp)
98+
assert a.dtype == xp.asarray(1).dtype # Default dtype
99+
assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted
100+
101+
102+
ALL_TYPES = (
103+
"int8",
104+
"int16",
105+
"int32",
106+
"int64",
107+
"uint8",
108+
"uint16",
109+
"uint32",
110+
"uint64",
111+
"float32",
112+
"float64",
113+
"complex64",
114+
"complex128",
115+
"bool",
116+
)
117+
118+
119+
@pytest.mark.parametrize("a_type", ALL_TYPES)
120+
@pytest.mark.parametrize("b_type", ALL_TYPES)
121+
def test_asarrays_array_vs_array(a_type: str, b_type: str, xp: ModuleType):
122+
"""
123+
Test that when both inputs of asarray are already Array API objects,
124+
they are returned unchanged.
125+
"""
126+
a = xp.asarray(1, dtype=getattr(xp, a_type))
127+
b = xp.asarray(1, dtype=getattr(xp, b_type))
128+
xa, xb = asarrays(a, b, xp)
129+
assert xa.dtype == a.dtype
130+
assert xb.dtype == b.dtype
131+
132+
133+
@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
134+
def test_asarrays_numpy_generics(dtype: type):
135+
"""
136+
Test special case of np.float64 and np.complex128,
137+
which are subclasses of float and complex.
138+
"""
139+
a = dtype(0)
140+
xa, xb = asarrays(a, 0, xp=np)
141+
assert xa.dtype == dtype
142+
assert xb.dtype == dtype

0 commit comments

Comments
 (0)