Skip to content

Commit 00a443f

Browse files
committed
Add sum_where to duck_array_ops
1 parent beb333b commit 00a443f

File tree

3 files changed

+23
-18
lines changed

3 files changed

+23
-18
lines changed

xarray/core/duck_array_ops.py

+10
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,16 @@ def count(data, axis=None):
281281
return np.sum(np.logical_not(isnull(data)), axis=axis)
282282

283283

284+
def sum_where(data, axis=None, dtype=None, where=None):
285+
xp = get_array_namespace(data)
286+
if where is not None:
287+
a = where_method(xp.zeros_like(data), where, data)
288+
else:
289+
a = data
290+
result = xp.sum(a, axis=axis, dtype=dtype)
291+
return result
292+
293+
284294
def where(condition, x, y):
285295
"""Three argument where() with better dtype promotion rules."""
286296
xp = get_array_namespace(condition)

xarray/core/nanops.py

+3-16
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,7 @@
55
import numpy as np
66

77
from . import dtypes, nputils, utils
8-
from .duck_array_ops import (
9-
count,
10-
fillna,
11-
get_array_namespace,
12-
isnull,
13-
where,
14-
where_method,
15-
)
8+
from .duck_array_ops import count, fillna, isnull, sum_where, where, where_method
169

1710

1811
def _replace_nan(a, val):
@@ -24,7 +17,6 @@ def _replace_nan(a, val):
2417
return where_method(val, mask, a), mask
2518

2619

27-
2820
def _maybe_null_out(result, axis, mask, min_count=1):
2921
"""
3022
xarray version of pandas.core.nanops._maybe_null_out
@@ -100,13 +92,8 @@ def nanargmax(a, axis=None):
10092

10193

10294
def nansum(a, axis=None, dtype=None, out=None, min_count=None):
103-
if hasattr(a, "__array_namespace__"):
104-
a, mask = _replace_nan(a, 0)
105-
xp = get_array_namespace(a)
106-
result = xp.sum(a, axis=axis, dtype=dtype)
107-
else:
108-
mask = isnull(a)
109-
result = np.nansum(a, axis=axis, dtype=dtype)
95+
mask = isnull(a)
96+
result = sum_where(a, axis=axis, dtype=dtype, where=mask)
11097
if min_count is not None:
11198
return _maybe_null_out(result, axis, mask, min_count)
11299
else:

xarray/tests/test_array_api.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,16 @@
1717

1818
@pytest.fixture
1919
def arrays() -> tuple[xr.DataArray, xr.DataArray]:
20-
np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
21-
xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]})
20+
np_arr = xr.DataArray(
21+
np.array([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
22+
dims=("x", "y"),
23+
coords={"x": [10, 20]},
24+
)
25+
xp_arr = xr.DataArray(
26+
xp.asarray([[1.0, 2.0, 3.0], [4.0, 5.0, np.nan]]),
27+
dims=("x", "y"),
28+
coords={"x": [10, 20]},
29+
)
2230
assert isinstance(xp_arr.data, Array)
2331
return np_arr, xp_arr
2432

0 commit comments

Comments
 (0)