Skip to content

Commit a14d202

Browse files
authored
allow using __array_function__ as a fallback for missing Array API functions (#9530)
* don't pass along `out` for `nanprod` We don't do this anywhere elsewhere, so it doesn't make sense to do this only for `nanprod`. * add tests for `as_indexable` * allow using `__array_function__` as a fallback for missing array API funcs * also check dask * don't try to create a `dask` array if `dask` is not installed
1 parent cde720f commit a14d202

File tree

3 files changed

+71
-3
lines changed

3 files changed

+71
-3
lines changed

xarray/core/indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -878,10 +878,10 @@ def as_indexable(array):
878878
return PandasIndexingAdapter(array)
879879
if is_duck_dask_array(array):
880880
return DaskIndexingAdapter(array)
881-
if hasattr(array, "__array_function__"):
882-
return NdArrayLikeIndexingAdapter(array)
883881
if hasattr(array, "__array_namespace__"):
884882
return ArrayApiIndexingAdapter(array)
883+
if hasattr(array, "__array_function__"):
884+
return NdArrayLikeIndexingAdapter(array)
885885

886886
raise TypeError(f"Invalid array type: {type(array)}")
887887

xarray/core/nanops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0):
162162

163163
def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
164164
mask = isnull(a)
165-
result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out)
165+
result = nputils.nanprod(a, axis=axis, dtype=dtype)
166166
if min_count is not None:
167167
return _maybe_null_out(result, axis, mask, min_count)
168168
else:

xarray/tests/test_indexing.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,74 @@ def test_posify_mask_subindexer(indices, expected) -> None:
894894
np.testing.assert_array_equal(expected, actual)
895895

896896

897+
class ArrayWithNamespace:
898+
def __array_namespace__(self, version=None):
899+
pass
900+
901+
902+
class ArrayWithArrayFunction:
903+
def __array_function__(self, func, types, args, kwargs):
904+
pass
905+
906+
907+
class ArrayWithNamespaceAndArrayFunction:
908+
def __array_namespace__(self, version=None):
909+
pass
910+
911+
def __array_function__(self, func, types, args, kwargs):
912+
pass
913+
914+
915+
def as_dask_array(arr, chunks):
916+
try:
917+
import dask.array as da
918+
except ImportError:
919+
return None
920+
921+
return da.from_array(arr, chunks=chunks)
922+
923+
924+
@pytest.mark.parametrize(
925+
["array", "expected_type"],
926+
(
927+
pytest.param(
928+
indexing.CopyOnWriteArray(np.array([1, 2])),
929+
indexing.CopyOnWriteArray,
930+
id="ExplicitlyIndexed",
931+
),
932+
pytest.param(
933+
np.array([1, 2]), indexing.NumpyIndexingAdapter, id="numpy.ndarray"
934+
),
935+
pytest.param(
936+
pd.Index([1, 2]), indexing.PandasIndexingAdapter, id="pandas.Index"
937+
),
938+
pytest.param(
939+
as_dask_array(np.array([1, 2]), chunks=(1,)),
940+
indexing.DaskIndexingAdapter,
941+
id="dask.array",
942+
marks=requires_dask,
943+
),
944+
pytest.param(
945+
ArrayWithNamespace(), indexing.ArrayApiIndexingAdapter, id="array_api"
946+
),
947+
pytest.param(
948+
ArrayWithArrayFunction(),
949+
indexing.NdArrayLikeIndexingAdapter,
950+
id="array_like",
951+
),
952+
pytest.param(
953+
ArrayWithNamespaceAndArrayFunction(),
954+
indexing.ArrayApiIndexingAdapter,
955+
id="array_api_with_fallback",
956+
),
957+
),
958+
)
959+
def test_as_indexable(array, expected_type):
960+
actual = indexing.as_indexable(array)
961+
962+
assert isinstance(actual, expected_type)
963+
964+
897965
def test_indexing_1d_object_array() -> None:
898966
items = (np.arange(3), np.arange(6))
899967
arr = DataArray(np.array(items, dtype=object))

0 commit comments

Comments
 (0)