From 8a810da1716be217a5c97f93ca1c2778c115244e Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 5 Aug 2022 11:56:05 +0100 Subject: [PATCH 1/7] More Array API changes, including aggregation with nans, astype, where, stack. --- xarray/conventions.py | 2 +- xarray/core/duck_array_ops.py | 54 ++++++++++++++++++++++++---------- xarray/core/nanops.py | 28 ++++++++++++++++-- xarray/core/variable.py | 3 +- xarray/tests/test_array_api.py | 34 +++++++++++++++++++++ 5 files changed, 101 insertions(+), 20 deletions(-) diff --git a/xarray/conventions.py b/xarray/conventions.py index 8bd316d199f..695bed3b365 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -141,7 +141,7 @@ def maybe_encode_bools(var): ): dims, data, attrs, encoding = _var_as_tuple(var) attrs["dtype"] = "bool" - data = data.astype(dtype="i1", copy=True) + data = duck_array_ops.astype(data, dtype="i1", copy=True) var = Variable(dims, data, attrs, encoding) return var diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 8c8f2443967..b61ffb700a8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,12 +18,19 @@ from numpy import zeros_like # noqa from numpy import around, broadcast_to # noqa from numpy import concatenate as _concatenate -from numpy import einsum, gradient, isclose, isin, isnan, isnat # noqa -from numpy import stack as _stack -from numpy import take, tensordot, transpose, unravel_index # noqa -from numpy import where as _where +from numpy import ( # noqa + einsum, + isclose, + isin, + isnat, + take, + tensordot, + transpose, + unravel_index, +) from numpy.lib.stride_tricks import sliding_window_view # noqa + from . import dask_array_ops, dtypes, nputils from .nputils import nanfirst, nanlast from .pycompat import cupy_array_type, is_duck_dask_array @@ -36,6 +43,13 @@ dask_array = None # type: ignore +def get_array_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + else: + return np + + def _dask_or_eager_func( name, eager_module=np, @@ -108,7 +122,8 @@ def isnull(data): return isnat(data) elif issubclass(scalar_type, np.inexact): # float types use NaN for null - return isnan(data) + xp = get_array_namespace(data) + return xp.isnan(data) elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)): # these types cannot represent missing values return zeros_like(data, dtype=bool) @@ -164,6 +179,9 @@ def cumulative_trapezoid(y, x, axis): def astype(data, dtype, **kwargs): + if hasattr(data, "__array_namespace__"): + xp = get_array_namespace(data) + return xp.astype(data, dtype, **kwargs) return data.astype(dtype, **kwargs) @@ -171,21 +189,28 @@ def asarray(data, xp=np): return data if is_duck_array(data) else xp.asarray(data) -def as_shared_dtype(scalars_or_arrays): +def as_shared_dtype(scalars_or_arrays, xp=np): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" if any(isinstance(x, cupy_array_type) for x in scalars_or_arrays): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] + elif any(hasattr(x, "__array_namespace__") for x in scalars_or_arrays): + xp = [x for x in scalars_or_arrays if hasattr(x, "__array_namespace__")][ + 0 + ].__array_namespace__() + arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] + out_type = dtypes.result_type(*arrays) + return [xp.astype(x, out_type, copy=False) for x in arrays] else: - arrays = [asarray(x) for x in scalars_or_arrays] + arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] # Pass arrays directly instead of dtypes to result_type so scalars # get handled properly. # Note that result_type() safely gets the dtype from dask arrays without # evaluating them. out_type = dtypes.result_type(*arrays) - return [x.astype(out_type, copy=False) for x in arrays] + return [astype(x, out_type, copy=False) for x in arrays] def lazy_array_equiv(arr1, arr2): @@ -261,7 +286,8 @@ def count(data, axis=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" - return _where(condition, *as_shared_dtype([x, y])) + xp = get_array_namespace(condition) + return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) def where_method(data, cond, other=dtypes.NA): @@ -284,7 +310,8 @@ def concatenate(arrays, axis=0): def stack(arrays, axis=0): """stack() with better dtype promotion rules.""" - return _stack(as_shared_dtype(arrays), axis=axis) + xp = get_array_namespace(arrays[0]) + return xp.stack(as_shared_dtype(arrays), axis=axis) @contextlib.contextmanager @@ -323,11 +350,8 @@ def f(values, axis=None, skipna=None, **kwargs): if name in ["sum", "prod"]: kwargs.pop("min_count", None) - if hasattr(values, "__array_namespace__"): - xp = values.__array_namespace__() - func = getattr(xp, name) - else: - func = getattr(np, name) + xp = get_array_namespace(values) + func = getattr(xp, name) try: with warnings.catch_warnings(): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 4c71658b577..fe8b2e784a7 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -5,7 +5,24 @@ import numpy as np from . import dtypes, nputils, utils -from .duck_array_ops import count, fillna, isnull, where, where_method +from .duck_array_ops import ( + count, + fillna, + get_array_namespace, + isnull, + where, + where_method, +) + + +def _replace_nan(a, val): + """ + replace nan in a by val, and returns the replaced array and the nan + position + """ + mask = isnull(a) + return where_method(val, mask, a), mask + def _maybe_null_out(result, axis, mask, min_count=1): @@ -83,8 +100,13 @@ def nanargmax(a, axis=None): def nansum(a, axis=None, dtype=None, out=None, min_count=None): - mask = isnull(a) - result = np.nansum(a, axis=axis, dtype=dtype) + if hasattr(a, "__array_namespace__"): + a, mask = _replace_nan(a, 0) + xp = get_array_namespace(a) + result = xp.sum(a, axis=axis, dtype=dtype) + else: + mask = isnull(a) + result = np.nansum(a, axis=axis, dtype=dtype) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 796c178f2a0..e8fd0d592f2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1637,7 +1637,8 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable): reordered = self.transpose(*dim_order) new_shape = reordered.shape[: len(other_dims)] + (-1,) - new_data = reordered.data.reshape(new_shape) + xp = duck_array_ops.get_array_namespace(reordered.data) + new_data = xp.reshape(reordered.data, new_shape) new_dims = reordered.dims[: len(other_dims)] + (new_dim,) return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index a15492028dd..f7c87c88f84 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -32,6 +32,14 @@ def test_arithmetic(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr.sum() + actual = xp_arr.sum() + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_aggregation_skipna(arrays) -> None: np_arr, xp_arr = arrays expected = np_arr.sum(skipna=False) actual = xp_arr.sum(skipna=False) @@ -39,6 +47,15 @@ def test_aggregation(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: assert_equal(actual, expected) +def test_astype(arrays) -> None: + np_arr, xp_arr = arrays + expected = np_arr.astype(np.int64) + actual = xp_arr.astype(np.int64) + assert actual.dtype == np.int64 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + def test_indexing(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: np_arr, xp_arr = arrays expected = np_arr[:, 0] @@ -59,3 +76,20 @@ def test_reorganizing_operation(arrays: tuple[xr.DataArray, xr.DataArray]) -> No actual = xp_arr.transpose() assert isinstance(actual.data, Array) assert_equal(actual, expected) + + +def test_stack(arrays: tuple[xr.DataArray, xr.DataArray]) -> None: + np_arr, xp_arr = arrays + expected = np_arr.stack(z=("x", "y")) + actual = xp_arr.stack(z=("x", "y")) + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_where() -> None: + np_arr = xr.DataArray(np.array([1, 0]), dims="x") + xp_arr = xr.DataArray(xp.asarray([1, 0]), dims="x") + expected = xr.where(np_arr, 1, 0) + actual = xr.where(xp_arr, 1, 0) + assert isinstance(actual.data, Array) + assert_equal(actual, expected) From 0662c1ae7e2fa4b4f2daa79ed330b57254a28563 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 23 Sep 2022 14:24:42 +0100 Subject: [PATCH 2/7] Add `reshape` to `duck_array_ops` --- xarray/core/duck_array_ops.py | 5 +++++ xarray/core/variable.py | 3 +-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index b61ffb700a8..c05a25edd2f 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -314,6 +314,11 @@ def stack(arrays, axis=0): return xp.stack(as_shared_dtype(arrays), axis=axis) +def reshape(array, shape): + xp = get_array_namespace(array) + return xp.reshape(array, shape) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index e8fd0d592f2..c70cd45b502 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1637,8 +1637,7 @@ def _stack_once(self, dims: list[Hashable], new_dim: Hashable): reordered = self.transpose(*dim_order) new_shape = reordered.shape[: len(other_dims)] + (-1,) - xp = duck_array_ops.get_array_namespace(reordered.data) - new_data = xp.reshape(reordered.data, new_shape) + new_data = duck_array_ops.reshape(reordered.data, new_shape) new_dims = reordered.dims[: len(other_dims)] + (new_dim,) return Variable(new_dims, new_data, self._attrs, self._encoding, fastpath=True) From bb502d6f27eb5b939cf6911f0bd81b9b74e1cce9 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 23 Sep 2022 15:17:08 +0100 Subject: [PATCH 3/7] Simplify `as_shared_dtype` --- xarray/core/duck_array_ops.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index c05a25edd2f..cc76b37ab8e 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -196,13 +196,6 @@ def as_shared_dtype(scalars_or_arrays, xp=np): import cupy as cp arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] - elif any(hasattr(x, "__array_namespace__") for x in scalars_or_arrays): - xp = [x for x in scalars_or_arrays if hasattr(x, "__array_namespace__")][ - 0 - ].__array_namespace__() - arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] - out_type = dtypes.result_type(*arrays) - return [xp.astype(x, out_type, copy=False) for x in arrays] else: arrays = [asarray(x, xp=xp) for x in scalars_or_arrays] # Pass arrays directly instead of dtypes to result_type so scalars @@ -311,7 +304,7 @@ def concatenate(arrays, axis=0): def stack(arrays, axis=0): """stack() with better dtype promotion rules.""" xp = get_array_namespace(arrays[0]) - return xp.stack(as_shared_dtype(arrays), axis=axis) + return xp.stack(as_shared_dtype(arrays, xp=xp), axis=axis) def reshape(array, shape): From 0f22fc033c748d6c32cb14f71cf31320d698ed43 Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 23 Sep 2022 15:31:12 +0100 Subject: [PATCH 4/7] Add `sum_where` to `duck_array_ops` --- xarray/core/duck_array_ops.py | 10 ++++++++++ xarray/core/nanops.py | 19 +++---------------- xarray/tests/test_array_api.py | 12 ++++++++++-- 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cc76b37ab8e..d9c2a91d1c9 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -277,6 +277,16 @@ def count(data, axis=None): return np.sum(np.logical_not(isnull(data)), axis=axis) +def sum_where(data, axis=None, dtype=None, where=None): + xp = get_array_namespace(data) + if where is not None: + a = where_method(xp.zeros_like(data), where, data) + else: + a = data + result = xp.sum(a, axis=axis, dtype=dtype) + return result + + def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" xp = get_array_namespace(condition) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index fe8b2e784a7..4bea74051ff 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -5,14 +5,7 @@ import numpy as np from . import dtypes, nputils, utils -from .duck_array_ops import ( - count, - fillna, - get_array_namespace, - isnull, - where, - where_method, -) +from .duck_array_ops import count, fillna, isnull, sum_where, where, where_method def _replace_nan(a, val): @@ -24,7 +17,6 @@ def _replace_nan(a, val): return where_method(val, mask, a), mask - def _maybe_null_out(result, axis, mask, min_count=1): """ xarray version of pandas.core.nanops._maybe_null_out @@ -100,13 +92,8 @@ def nanargmax(a, axis=None): def nansum(a, axis=None, dtype=None, out=None, min_count=None): - if hasattr(a, "__array_namespace__"): - a, mask = _replace_nan(a, 0) - xp = get_array_namespace(a) - result = xp.sum(a, axis=axis, dtype=dtype) - else: - mask = isnull(a) - result = np.nansum(a, axis=axis, dtype=dtype) + mask = isnull(a) + result = sum_where(a, axis=axis, dtype=dtype, where=mask) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index f7c87c88f84..7940c979249 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -17,8 +17,16 @@ @pytest.fixture def arrays() -> tuple[xr.DataArray, xr.DataArray]: - np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) - xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) + np_arr = xr.DataArray( + np.array([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]), + dims=("x", "y"), + coords={"x": [10, 20]}, + ) + xp_arr = xr.DataArray( + xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]), + dims=("x", "y"), + coords={"x": [10, 20]}, + ) assert isinstance(xp_arr.data, Array) return np_arr, xp_arr From 3c4518090a4cd95197e19c9d65aab3443480d16d Mon Sep 17 00:00:00 2001 From: Tom White Date: Fri, 30 Sep 2022 10:12:16 +0100 Subject: [PATCH 5/7] Remove unused `_replace_nan` function --- xarray/core/nanops.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 4bea74051ff..651fd9aca17 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -8,15 +8,6 @@ from .duck_array_ops import count, fillna, isnull, sum_where, where, where_method -def _replace_nan(a, val): - """ - replace nan in a by val, and returns the replaced array and the nan - position - """ - mask = isnull(a) - return where_method(val, mask, a), mask - - def _maybe_null_out(result, axis, mask, min_count=1): """ xarray version of pandas.core.nanops._maybe_null_out From 163b1096c7d8b7f4911a942cc78ebab87cc0d356 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Oct 2022 15:57:00 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xarray/core/duck_array_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index d9c2a91d1c9..4ee5bf85246 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -30,7 +30,6 @@ ) from numpy.lib.stride_tricks import sliding_window_view # noqa - from . import dask_array_ops, dtypes, nputils from .nputils import nanfirst, nanlast from .pycompat import cupy_array_type, is_duck_dask_array From 177012c8565b3dbec5df99bbf71704c817dc3746 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 17 Oct 2022 09:59:54 -0600 Subject: [PATCH 7/7] Update xarray/core/duck_array_ops.py --- xarray/core/duck_array_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 4ee5bf85246..8c92bc4ee6e 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -20,6 +20,7 @@ from numpy import concatenate as _concatenate from numpy import ( # noqa einsum, + gradient, isclose, isin, isnat,