Skip to content

Commit b235ff4

Browse files
committed
ENH: is_lazy_array()
1 parent beac55b commit b235ff4

File tree

3 files changed

+126
-7
lines changed

3 files changed

+126
-7
lines changed

Diff for: array_api_compat/common/_helpers.py

+58
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,63 @@ def is_writeable_array(x) -> bool:
819819
return True
820820

821821

822+
def is_lazy_array(x) -> bool:
823+
"""Return True if x is potentially a future or it may be otherwise impossible or
824+
expensive to eagerly read its contents, regardless of their size, e.g. by
825+
calling ``bool(x)`` or ``float(x)``.
826+
827+
Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
828+
cheap as long as the array has the right dtype.
829+
830+
Note
831+
----
832+
This function errs on the side of caution for array types that may or may not be
833+
lazy, e.g. JAX arrays, by always returning True for them.
834+
"""
835+
if (
836+
is_numpy_array(x)
837+
or is_cupy_array(x)
838+
or is_torch_array(x)
839+
or is_pydata_sparse_array(x)
840+
):
841+
return False
842+
843+
# **JAX note:** while it is possible to determine if you're inside or outside
844+
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
845+
# as we do below for unknown arrays, this is not recommended by JAX best practices.
846+
847+
# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
848+
# This behaviour, while impossible to change without breaking backwards
849+
# compatibility, is highly detrimental to performance as the whole graph will end
850+
# up being computed multiple times.
851+
852+
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
853+
return True
854+
855+
# Unknown Array API compatible object. Note that this test may have dire consequences
856+
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
857+
# on __bool__ (dask is one such example, which however is special-cased above).
858+
859+
# Select a single point of the array
860+
s = size(x)
861+
if math.isnan(s):
862+
return True
863+
xp = array_namespace(x)
864+
if s > 1:
865+
x = xp.reshape(x, (-1,))[0]
866+
# Cast to dtype=bool and deal with size 0 arrays
867+
x = xp.any(x)
868+
869+
try:
870+
bool(x)
871+
return False
872+
# The Array API standard dictactes that __bool__ should raise TypeError if the
873+
# output cannot be defined.
874+
# Here we're more lenient and also allow for e.g. NotImplementedError.
875+
except Exception:
876+
return True
877+
878+
822879
__all__ = [
823880
"array_namespace",
824881
"device",
@@ -840,6 +897,7 @@ def is_writeable_array(x) -> bool:
840897
"is_pydata_sparse_array",
841898
"is_pydata_sparse_namespace",
842899
"is_writeable_array",
900+
"is_lazy_array",
843901
"size",
844902
"to_device",
845903
]

Diff for: docs/helper-functions.rst

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ yet.
5252
.. autofunction:: is_pydata_sparse_array
5353
.. autofunction:: is_ndonnx_array
5454
.. autofunction:: is_writeable_array
55+
.. autofunction:: is_lazy_array
5556
.. autofunction:: is_numpy_namespace
5657
.. autofunction:: is_cupy_namespace
5758
.. autofunction:: is_torch_namespace

Diff for: tests/test_common.py

+67-7
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
import math
2+
3+
import pytest
4+
import numpy as np
5+
import array
6+
from numpy.testing import assert_allclose
7+
18
from array_api_compat import ( # noqa: F401
29
is_numpy_array, is_cupy_array, is_torch_array,
310
is_dask_array, is_jax_array, is_pydata_sparse_array,
411
is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
512
is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
613
)
714

8-
from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device
9-
15+
from array_api_compat import (
16+
device, is_array_api_obj, is_lazy_array, is_writeable_array, to_device, size
17+
)
1018
from ._helpers import import_, wrapped_libraries, all_libraries
1119

12-
import pytest
13-
import numpy as np
14-
import array
15-
from numpy.testing import assert_allclose
16-
1720
is_array_functions = {
1821
'numpy': 'is_numpy_array',
1922
'cupy': 'is_cupy_array',
@@ -92,6 +95,62 @@ def test_is_writeable_array_numpy():
9295
assert not is_writeable_array(x)
9396

9497

98+
@pytest.mark.parametrize("library", all_libraries)
99+
def test_is_lazy_array(library):
100+
lib = import_(library)
101+
x = lib.asarray([1, 2, 3])
102+
assert isinstance(is_lazy_array(x), bool)
103+
104+
105+
@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan)])
106+
def test_is_lazy_array_nan_size(shape, monkeypatch):
107+
"""Test is_lazy_array() on an unknown Array API compliant object
108+
with NaN(s) in its shape
109+
"""
110+
xp = import_("array_api_strict")
111+
x = xp.asarray(1)
112+
assert not is_lazy_array(x)
113+
monkeypatch.setattr(type(x), "shape", shape)
114+
assert math.isnan(size(x))
115+
assert is_lazy_array(x)
116+
117+
118+
@pytest.mark.parametrize("exc", [TypeError, AssertionError])
119+
def test_is_lazy_array_bool_raises(exc, monkeypatch):
120+
"""Test is_lazy_array() on an unknown Array API compliant object
121+
where calling bool() raises:
122+
- TypeError: e.g. like jitted JAX. This is the proper exception which
123+
lazy arrays should raise as per the Array API specification
124+
- something else: e.g. like Dask, where bool() triggers compute()
125+
which can cause in any kind of exception to be raised
126+
"""
127+
xp = import_("array_api_strict")
128+
x = xp.asarray(1)
129+
assert not is_lazy_array(x)
130+
131+
def __bool__(self):
132+
raise exc("Hello world")
133+
134+
monkeypatch.setattr(type(x), "__bool__", __bool__)
135+
with pytest.raises(exc, match="Hello world"):
136+
bool(x)
137+
assert is_lazy_array(x)
138+
139+
140+
@pytest.mark.parametrize("library", all_libraries)
141+
def test_size(library):
142+
xp = import_(library)
143+
x = xp.arange(10)
144+
assert size(x) == 10
145+
146+
147+
def test_nan_size():
148+
xp = import_("dask.array")
149+
x = xp.arange(10)
150+
x = x[x < 5]
151+
assert math.isnan(size(x))
152+
153+
95154
@pytest.mark.parametrize("library", all_libraries)
96155
def test_device(library):
97156
xp = import_(library, wrapper=True)
@@ -149,6 +208,7 @@ def test_asarray_cross_library(source_library, target_library, request):
149208

150209
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
151210

211+
152212
@pytest.mark.parametrize("library", wrapped_libraries)
153213
def test_asarray_copy(library):
154214
# Note, we have this test here because the test suite currently doesn't

0 commit comments

Comments
 (0)