Skip to content

Commit 7f1f911

Browse files
tomwhite, pre-commit-ci[bot], dcherian
authored
More Array API changes (#7067)
* More Array API changes, including aggregation with nans, astype, where, stack.
* Add `reshape` to `duck_array_ops`
* Simplify `as_shared_dtype`
* Add `sum_where` to `duck_array_ops`
* Remove unused `_replace_nan` function
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update xarray/core/duck_array_ops.py

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
1 parent 2687536 commit 7f1f911

File tree

5 files changed

+95
-21
lines changed

5 files changed

+95
-21
lines changed

xarray/conventions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def maybe_encode_bools(var):
141141
):
142142
dims, data, attrs, encoding = _var_as_tuple(var)
143143
attrs["dtype"] = "bool"
144-
data = data.astype(dtype="i1", copy=True)
144+
data = duck_array_ops.astype(data, dtype="i1", copy=True)
145145
var = Variable(dims, data, attrs, encoding)
146146
return var
147147

xarray/core/duck_array_ops.py

+47-15
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,17 @@
1818
from numpy import zeros_like # noqa
1919
from numpy import around, broadcast_to # noqa
2020
from numpy import concatenate as _concatenate
21-
from numpy import einsum, gradient, isclose, isin, isnan, isnat # noqa
22-
from numpy import stack as _stack
23-
from numpy import take, tensordot, transpose, unravel_index # noqa
24-
from numpy import where as _where
21+
from numpy import ( # noqa
22+
einsum,
23+
gradient,
24+
isclose,
25+
isin,
26+
isnat,
27+
take,
28+
tensordot,
29+
transpose,
30+
unravel_index,
31+
)
2532
from numpy.lib.stride_tricks import sliding_window_view # noqa
2633

2734
from . import dask_array_ops, dtypes, nputils
@@ -36,6 +43,13 @@
3643
dask_array = None # type: ignore
3744

3845

46+
def get_array_namespace(x):
47+
if hasattr(x, "__array_namespace__"):
48+
return x.__array_namespace__()
49+
else:
50+
return np
51+
52+
3953
def _dask_or_eager_func(
4054
name,
4155
eager_module=np,
@@ -108,7 +122,8 @@ def isnull(data):
108122
return isnat(data)
109123
elif issubclass(scalar_type, np.inexact):
110124
# float types use NaN for null
111-
return isnan(data)
125+
xp = get_array_namespace(data)
126+
return xp.isnan(data)
112127
elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
113128
# these types cannot represent missing values
114129
return zeros_like(data, dtype=bool)
@@ -164,28 +179,31 @@ def cumulative_trapezoid(y, x, axis):
164179

165180

166181
def astype(data, dtype, **kwargs):
182+
if hasattr(data, "__array_namespace__"):
183+
xp = get_array_namespace(data)
184+
return xp.astype(data, dtype, **kwargs)
167185
return data.astype(dtype, **kwargs)
168186

169187

170188
def asarray(data, xp=np):
171189
return data if is_duck_array(data) else xp.asarray(data)
172190

173191

174-
def as_shared_dtype(scalars_or_arrays):
192+
def as_shared_dtype(scalars_or_arrays, xp=np):
175193
"""Cast a arrays to a shared dtype using xarray's type promotion rules."""
176194

177195
if any(isinstance(x, cupy_array_type) for x in scalars_or_arrays):
178196
import cupy as cp
179197

180198
arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
181199
else:
182-
arrays = [asarray(x) for x in scalars_or_arrays]
200+
arrays = [asarray(x, xp=xp) for x in scalars_or_arrays]
183201
# Pass arrays directly instead of dtypes to result_type so scalars
184202
# get handled properly.
185203
# Note that result_type() safely gets the dtype from dask arrays without
186204
# evaluating them.
187205
out_type = dtypes.result_type(*arrays)
188-
return [x.astype(out_type, copy=False) for x in arrays]
206+
return [astype(x, out_type, copy=False) for x in arrays]
189207

190208

191209
def lazy_array_equiv(arr1, arr2):
@@ -259,9 +277,20 @@ def count(data, axis=None):
259277
return np.sum(np.logical_not(isnull(data)), axis=axis)
260278

261279

280+
def sum_where(data, axis=None, dtype=None, where=None):
281+
xp = get_array_namespace(data)
282+
if where is not None:
283+
a = where_method(xp.zeros_like(data), where, data)
284+
else:
285+
a = data
286+
result = xp.sum(a, axis=axis, dtype=dtype)
287+
return result
288+
289+
262290
def where(condition, x, y):
263291
"""Three argument where() with better dtype promotion rules."""
264-
return _where(condition, *as_shared_dtype([x, y]))
292+
xp = get_array_namespace(condition)
293+
return xp.where(condition, *as_shared_dtype([x, y], xp=xp))
265294

266295

267296
def where_method(data, cond, other=dtypes.NA):
@@ -284,7 +313,13 @@ def concatenate(arrays, axis=0):
284313

285314
def stack(arrays, axis=0):
286315
"""stack() with better dtype promotion rules."""
287-
return _stack(as_shared_dtype(arrays), axis=axis)
316+
xp = get_array_namespace(arrays[0])
317+
return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis)
318+
319+
320+
def reshape(array, shape):
321+
xp = get_array_namespace(array)
322+
return xp.reshape(array, shape)
288323

289324

290325
@contextlib.contextmanager
@@ -323,11 +358,8 @@ def f(values, axis=None, skipna=None, **kwargs):
323358
if name in ["sum", "prod"]:
324359
kwargs.pop("min_count", None)
325360

326-
if hasattr(values, "__array_namespace__"):
327-
xp = values.__array_namespace__()
328-
func = getattr(xp, name)
329-
else:
330-
func = getattr(np, name)
361+
xp = get_array_namespace(values)
362+
func = getattr(xp, name)
331363

332364
try:
333365
with warnings.catch_warnings():

xarray/core/nanops.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66

77
from . import dtypes, nputils, utils
8-
from .duck_array_ops import count, fillna, isnull, where, where_method
8+
from .duck_array_ops import count, fillna, isnull, sum_where, where, where_method
99

1010

1111
def _maybe_null_out(result, axis, mask, min_count=1):
@@ -84,7 +84,7 @@ def nanargmax(a, axis=None):
8484

8585
def nansum(a, axis=None, dtype=None, out=None, min_count=None):
8686
mask = isnull(a)
87-
result = np.nansum(a, axis=axis, dtype=dtype)
87+
result = sum_where(a, axis=axis, dtype=dtype, where=mask)
8888
if min_count is not None:
8989
return _maybe_null_out(result, axis, mask, min_count)
9090
else:

xarray/core/variable.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1637,7 +1637,7 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable):
16371637
reordered = self.transpose(*dim_order)
16381638

16391639
new_shape = reordered.shape[: len(other_dims)] + (-1,)
1640-
new_data = reordered.data.reshape(new_shape)
1640+
new_data = duck_array_ops.reshape(reordered.data, new_shape)
16411641
new_dims = reordered.dims[: len(other_dims)] + (new_dim,)
16421642

16431643
return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True)

xarray/tests/test_array_api.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717

1818
@pytest.fixture
1919
def arrays() -> tuple[xr.DataArray, xr.DataArray]:
20-
np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
21-
xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
20+
np_arr = xr.DataArray(
21+
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
22+
dims=("x", "y"),
23+
coords={"x": [10, 20]},
24+
)
25+
xp_arr = xr.DataArray(
26+
xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
27+
dims=("x", "y"),
28+
coords={"x": [10, 20]},
29+
)
2230
assert isinstance(xp_arr.data, Array)
2331
return np_arr, xp_arr
2432

@@ -32,13 +40,30 @@ def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
3240

3341

3442
def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
43+
np_arr, xp_arr = arrays
44+
expected = np_arr.sum()
45+
actual = xp_arr.sum()
46+
assert isinstance(actual.data, Array)
47+
assert_equal(actual, expected)
48+
49+
50+
def test_aggregation_skipna(arrays) -> None:
3551
np_arr, xp_arr = arrays
3652
expected = np_arr.sum(skipna=False)
3753
actual = xp_arr.sum(skipna=False)
3854
assert isinstance(actual.data, Array)
3955
assert_equal(actual, expected)
4056

4157

58+
def test_astype(arrays) -> None:
59+
np_arr, xp_arr = arrays
60+
expected = np_arr.astype(np.int64)
61+
actual = xp_arr.astype(np.int64)
62+
assert actual.dtype == np.int64
63+
assert isinstance(actual.data, Array)
64+
assert_equal(actual, expected)
65+
66+
4267
def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
4368
np_arr, xp_arr = arrays
4469
expected = np_arr[:, 0]
@@ -59,3 +84,20 @@ def test_reorganizing_operation(arrays: tuple[xr.DataArray, xr.DataArray]) -> No
5984
actual = xp_arr.transpose()
6085
assert isinstance(actual.data, Array)
6186
assert_equal(actual, expected)
87+
88+
89+
def test_stack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None:
90+
np_arr, xp_arr = arrays
91+
expected = np_arr.stack(z=("x", "y"))
92+
actual = xp_arr.stack(z=("x", "y"))
93+
assert isinstance(actual.data, Array)
94+
assert_equal(actual, expected)
95+
96+
97+
def test_where() -> None:
98+
np_arr = xr.DataArray(np.array([1, 0]), dims="x")
99+
xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x")
100+
expected = xr.where(np_arr, 1, 0)
101+
actual = xr.where(xp_arr, 1, 0)
102+
assert isinstance(actual.data, Array)
103+
assert_equal(actual, expected)

0 commit comments

Comments
 (0)