|
| 1 | +import math |
| 2 | + |
| 3 | +import pytest |
| 4 | +import numpy as np |
| 5 | +import array |
| 6 | +from numpy.testing import assert_allclose |
| 7 | + |
1 | 8 | from array_api_compat import ( # noqa: F401
|
2 | 9 | is_numpy_array, is_cupy_array, is_torch_array,
|
3 | 10 | is_dask_array, is_jax_array, is_pydata_sparse_array,
|
4 | 11 | is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
|
5 | 12 | is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
|
6 | 13 | )
|
7 | 14 |
|
8 |
| -from array_api_compat import device, is_array_api_obj, is_writeable_array, to_device |
9 |
| - |
| 15 | +from array_api_compat import ( |
| 16 | + device, is_array_api_obj, is_lazy_array, is_writeable_array, to_device, size |
| 17 | +) |
10 | 18 | from ._helpers import import_, wrapped_libraries, all_libraries
|
11 | 19 |
|
12 |
| -import pytest |
13 |
| -import numpy as np |
14 |
| -import array |
15 |
| -from numpy.testing import assert_allclose |
16 |
| - |
17 | 20 | is_array_functions = {
|
18 | 21 | 'numpy': 'is_numpy_array',
|
19 | 22 | 'cupy': 'is_cupy_array',
|
@@ -92,6 +95,62 @@ def test_is_writeable_array_numpy():
|
92 | 95 | assert not is_writeable_array(x)
|
93 | 96 |
|
94 | 97 |
|
| 98 | +@pytest.mark.parametrize("library", all_libraries) |
| 99 | +def test_is_lazy_array(library): |
| 100 | + lib = import_(library) |
| 101 | + x = lib.asarray([1, 2, 3]) |
| 102 | + assert isinstance(is_lazy_array(x), bool) |
| 103 | + |
| 104 | + |
| 105 | +@pytest.mark.parametrize("shape", [(math.nan,), (1, math.nan)]) |
| 106 | +def test_is_lazy_array_nan_size(shape, monkeypatch): |
| 107 | + """Test is_lazy_array() on an unknown Array API compliant object |
| 108 | + with NaN(s) in its shape |
| 109 | + """ |
| 110 | + xp = import_("array_api_strict") |
| 111 | + x = xp.asarray(1) |
| 112 | + assert not is_lazy_array(x) |
| 113 | + monkeypatch.setattr(type(x), "shape", shape) |
| 114 | + assert math.isnan(size(x)) |
| 115 | + assert is_lazy_array(x) |
| 116 | + |
| 117 | + |
| 118 | +@pytest.mark.parametrize("exc", [TypeError, AssertionError]) |
| 119 | +def test_is_lazy_array_bool_raises(exc, monkeypatch): |
| 120 | + """Test is_lazy_array() on an unknown Array API compliant object |
| 121 | + where calling bool() raises: |
| 122 | + - TypeError: e.g. like jitted JAX. This is the proper exception which |
| 123 | + lazy arrays should raise as per the Array API specification |
| 124 | + - something else: e.g. like Dask, where bool() triggers compute() |
| 125 | + which can cause in any kind of exception to be raised |
| 126 | + """ |
| 127 | + xp = import_("array_api_strict") |
| 128 | + x = xp.asarray(1) |
| 129 | + assert not is_lazy_array(x) |
| 130 | + |
| 131 | + def __bool__(self): |
| 132 | + raise exc("Hello world") |
| 133 | + |
| 134 | + monkeypatch.setattr(type(x), "__bool__", __bool__) |
| 135 | + with pytest.raises(exc, match="Hello world"): |
| 136 | + bool(x) |
| 137 | + assert is_lazy_array(x) |
| 138 | + |
| 139 | + |
| 140 | +@pytest.mark.parametrize("library", all_libraries) |
| 141 | +def test_size(library): |
| 142 | + xp = import_(library) |
| 143 | + x = xp.arange(10) |
| 144 | + assert size(x) == 10 |
| 145 | + |
| 146 | + |
| 147 | +def test_nan_size(): |
| 148 | + xp = import_("dask.array") |
| 149 | + x = xp.arange(10) |
| 150 | + x = x[x < 5] |
| 151 | + assert math.isnan(size(x)) |
| 152 | + |
| 153 | + |
95 | 154 | @pytest.mark.parametrize("library", all_libraries)
|
96 | 155 | def test_device(library):
|
97 | 156 | xp = import_(library, wrapper=True)
|
@@ -149,6 +208,7 @@ def test_asarray_cross_library(source_library, target_library, request):
|
149 | 208 |
|
150 | 209 | assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
|
151 | 210 |
|
| 211 | + |
152 | 212 | @pytest.mark.parametrize("library", wrapped_libraries)
|
153 | 213 | def test_asarray_copy(library):
|
154 | 214 | # Note, we have this test here because the test suite currently doesn't
|
|
0 commit comments