
Commit 7a0ab8d

TomNicholas, dcherian, and Illviljan authored and committed
Rely on NEP-18 to dispatch to dask in duck_array_ops (pydata#5571)
* basic test for the mean
* minimum to get mean working
* don't even need to call dask specifically
* remove reference to dask when dispatching to modules
* fixed special case of pandas vs dask isnull
* removed _dask_or_eager_func completely
* noqa
* pre-commit
* what's new
* linting
* properly import dask for test
* fix iris conversion error by rolling back treatment of np.ma.masked_invalid
* linting
* Update xarray/core/duck_array_ops.py
  Co-authored-by: Illviljan <[email protected]>
* Update xarray/core/duck_array_ops.py
  Co-authored-by: Illviljan <[email protected]>
* Update xarray/core/duck_array_ops.py
  Co-authored-by: Illviljan <[email protected]>

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Illviljan <[email protected]>
1 parent 46c67ec commit 7a0ab8d
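
A minimal sketch (not part of the commit) of the NEP-18 behaviour this change relies on, assuming numpy >= 1.17 and dask are installed: numpy functions called on dask arrays dispatch to dask's implementations via __array_function__, so xarray no longer needs to look up the dask function itself.

import dask.array
import numpy as np

x = dask.array.ones((1000, 1000), chunks=(100, 100))
result = np.mean(x)      # dispatches to dask.array.mean via __array_function__
print(type(result))      # <class 'dask.array.core.Array'> -- still lazy
print(result.compute())  # 1.0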

File tree: 5 files changed (+91 −87 lines changed)

  doc/whats-new.rst
  xarray/core/duck_array_ops.py
  xarray/core/nanops.py
  xarray/tests/test_units.py
  xarray/ufuncs.py

doc/whats-new.rst (+4)

@@ -206,6 +206,10 @@ Internal Changes
   pandas-specific implementation into ``PandasIndex.query()`` and
   ``PandasMultiIndex.query()`` (:pull:`5322`).
   By `Benoit Bovy <https://github.com/benbovy>`_.
+- Refactor `xarray.core.duck_array_ops` to no longer special-case dispatching to
+  dask versions of functions when acting on dask arrays, instead relying on numpy
+  and dask's adherence to NEP-18 to dispatch automatically. (:pull:`5571`)
+  By `Tom Nicholas <https://github.com/TomNicholas>`_.
 
 .. _whats-new.0.18.2:
 
xarray/core/duck_array_ops.py (+56 −59)

@@ -11,6 +11,15 @@
 
 import numpy as np
 import pandas as pd
+from numpy import all as array_all  # noqa
+from numpy import any as array_any  # noqa
+from numpy import zeros_like  # noqa
+from numpy import around, broadcast_to  # noqa
+from numpy import concatenate as _concatenate
+from numpy import einsum, isclose, isin, isnan, isnat, pad  # noqa
+from numpy import stack as _stack
+from numpy import take, tensordot, transpose, unravel_index  # noqa
+from numpy import where as _where
 
 from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils
 from .nputils import nanfirst, nanlast

@@ -34,31 +43,15 @@ def _dask_or_eager_func(
     name,
     eager_module=np,
     dask_module=dask_array,
-    list_of_args=False,
-    array_args=slice(1),
-    requires_dask=None,
 ):
     """Create a function that dispatches to dask for dask array inputs."""
-    if dask_module is not None:
-
-        def f(*args, **kwargs):
-            if list_of_args:
-                dispatch_args = args[0]
-            else:
-                dispatch_args = args[array_args]
-            if any(is_duck_dask_array(a) for a in dispatch_args):
-                try:
-                    wrapped = getattr(dask_module, name)
-                except AttributeError as e:
-                    raise AttributeError(f"{e}: requires dask >={requires_dask}")
-            else:
-                wrapped = getattr(eager_module, name)
-            return wrapped(*args, **kwargs)
 
-    else:
-
-        def f(*args, **kwargs):
-            return getattr(eager_module, name)(*args, **kwargs)
+    def f(*args, **kwargs):
+        if any(is_duck_dask_array(a) for a in args):
+            wrapped = getattr(dask_module, name)
+        else:
+            wrapped = getattr(eager_module, name)
+        return wrapped(*args, **kwargs)
 
     return f
 
@@ -72,16 +65,40 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
         raise NotImplementedError(msg % func_name)
 
 
-around = _dask_or_eager_func("around")
-isclose = _dask_or_eager_func("isclose")
-
+# Requires special-casing because pandas won't automatically dispatch to dask.isnull via NEP-18
+pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module=dask_array)
 
-isnat = np.isnat
-isnan = _dask_or_eager_func("isnan")
-zeros_like = _dask_or_eager_func("zeros_like")
-
-
-pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd)
+# np.around has failing doctests, overwrite it so they pass:
+# https://github.com/numpy/numpy/issues/19759
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0.,  2.])",
+    "array([0., 2.])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0.,  2.])",
+    "array([0., 2.])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0.4,  1.6])",
+    "array([0.4, 1.6])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    "array([0.,  2.,  2.,  4.,  4.])",
+    "array([0., 2., 2., 4., 4.])",
+)
+around.__doc__ = str.replace(
+    around.__doc__ or "",
+    (
+        ' .. [2] "How Futile are Mindless Assessments of\n'
+        '        Roundoff in Floating-Point Computation?", William Kahan,\n'
+        "        https://people.eecs.berkeley.edu/~wkahan/Mindless.pdf\n"
+    ),
+    "",
+)
 
 
 def isnull(data):
@@ -114,21 +131,10 @@ def notnull(data):
     return ~isnull(data)
 
 
-transpose = _dask_or_eager_func("transpose")
-_where = _dask_or_eager_func("where", array_args=slice(3))
-isin = _dask_or_eager_func("isin", array_args=slice(2))
-take = _dask_or_eager_func("take")
-broadcast_to = _dask_or_eager_func("broadcast_to")
-pad = _dask_or_eager_func("pad", dask_module=dask_array_compat)
-
-_concatenate = _dask_or_eager_func("concatenate", list_of_args=True)
-_stack = _dask_or_eager_func("stack", list_of_args=True)
-
-array_all = _dask_or_eager_func("all")
-array_any = _dask_or_eager_func("any")
-
-tensordot = _dask_or_eager_func("tensordot", array_args=slice(2))
-einsum = _dask_or_eager_func("einsum", array_args=slice(1, None))
+# TODO replace with simply np.ma.masked_invalid once numpy/numpy#16022 is fixed
+masked_invalid = _dask_or_eager_func(
+    "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None)
+)
 
 
 def gradient(x, coord, axis, edge_order):

@@ -166,11 +172,6 @@ def cumulative_trapezoid(y, x, axis):
     return cumsum(integrand, axis=axis, skipna=False)
 
 
-masked_invalid = _dask_or_eager_func(
-    "masked_invalid", eager_module=np.ma, dask_module=getattr(dask_array, "ma", None)
-)
-
-
 def astype(data, dtype, **kwargs):
     if (
         isinstance(data, sparse_array_type)

@@ -317,9 +318,7 @@ def _ignore_warnings_if(condition):
         yield
 
 
-def _create_nan_agg_method(
-    name, dask_module=dask_array, coerce_strings=False, invariant_0d=False
-):
+def _create_nan_agg_method(name, coerce_strings=False, invariant_0d=False):
     from . import nanops
 
     def f(values, axis=None, skipna=None, **kwargs):

@@ -344,7 +343,8 @@ def f(values, axis=None, skipna=None, **kwargs):
         else:
             if name in ["sum", "prod"]:
                 kwargs.pop("min_count", None)
-            func = _dask_or_eager_func(name, dask_module=dask_module)
+
+            func = getattr(np, name)
 
         try:
             with warnings.catch_warnings():

@@ -378,9 +378,7 @@ def f(values, axis=None, skipna=None, **kwargs):
 std.numeric_only = True
 var = _create_nan_agg_method("var")
 var.numeric_only = True
-median = _create_nan_agg_method(
-    "median", dask_module=dask_array_compat, invariant_0d=True
-)
+median = _create_nan_agg_method("median", invariant_0d=True)
 median.numeric_only = True
 prod = _create_nan_agg_method("prod", invariant_0d=True)
 prod.numeric_only = True

@@ -389,7 +387,6 @@ def f(values, axis=None, skipna=None, **kwargs):
 cumprod_1d.numeric_only = True
 cumsum_1d = _create_nan_agg_method("cumsum", invariant_0d=True)
 cumsum_1d.numeric_only = True
-unravel_index = _dask_or_eager_func("unravel_index")
 
 
 _mean = _create_nan_agg_method("mean", invariant_0d=True)
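
The diff above keeps one explicit dispatcher because, as its comment notes, pandas won't automatically dispatch to dask.isnull via NEP-18. A standalone sketch of that remaining special case follows; it is illustrative only and not xarray's exact code (the real version uses is_duck_dask_array rather than an isinstance check).

import dask.array as dask_array
import numpy as np
import pandas as pd


def _dask_or_eager_func(name, eager_module=np, dask_module=dask_array):
    """Create a function that dispatches to dask for dask array inputs."""

    def f(*args, **kwargs):
        # pd.isnull would not dispatch on dask arrays by itself, so pick the module by hand.
        if any(isinstance(a, dask_array.Array) for a in args):
            wrapped = getattr(dask_module, name)
        else:
            wrapped = getattr(eager_module, name)
        return wrapped(*args, **kwargs)

    return f


pandas_isnull = _dask_or_eager_func("isnull", eager_module=pd, dask_module=dask_array)

eager = pandas_isnull(np.array([1.0, np.nan]))                        # uses pd.isnull
lazy = pandas_isnull(dask_array.from_array(np.array([1.0, np.nan])))  # uses dask.array.isnull
print(type(eager), type(lazy))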

xarray/core/nanops.py (+11 −26)

@@ -3,14 +3,7 @@
 import numpy as np
 
 from . import dtypes, nputils, utils
-from .duck_array_ops import (
-    _dask_or_eager_func,
-    count,
-    fillna,
-    isnull,
-    where,
-    where_method,
-)
+from .duck_array_ops import count, fillna, isnull, where, where_method
 from .pycompat import dask_array_type
 
 try:

@@ -53,7 +46,7 @@ def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs):
     """
     valid_count = count(value, axis=axis)
     value = fillna(value, fill_value)
-    data = _dask_or_eager_func(func)(value, axis=axis, **kwargs)
+    data = getattr(np, func)(value, axis=axis, **kwargs)
 
     # TODO This will evaluate dask arrays and might be costly.
     if (valid_count == 0).any():

@@ -111,7 +104,7 @@ def nanargmax(a, axis=None):
 
 def nansum(a, axis=None, dtype=None, out=None, min_count=None):
     a, mask = _replace_nan(a, 0)
-    result = _dask_or_eager_func("sum")(a, axis=axis, dtype=dtype)
+    result = np.sum(a, axis=axis, dtype=dtype)
     if min_count is not None:
         return _maybe_null_out(result, axis, mask, min_count)
     else:

@@ -120,7 +113,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None):
 
 def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
     """In house nanmean. ddof argument will be used in _nanvar method"""
-    from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method
+    from .duck_array_ops import count, fillna, where_method
 
     valid_count = count(value, axis=axis)
     value = fillna(value, 0)

@@ -129,7 +122,7 @@ def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs):
     if dtype is None and value.dtype.kind == "O":
         dtype = value.dtype if value.dtype.kind in ["cf"] else float
 
-    data = _dask_or_eager_func("sum")(value, axis=axis, dtype=dtype, **kwargs)
+    data = np.sum(value, axis=axis, dtype=dtype, **kwargs)
     data = data / (valid_count - ddof)
     return where_method(data, valid_count != 0)

@@ -155,7 +148,7 @@ def nanmedian(a, axis=None, out=None):
     # possibly blow memory
     if axis is not None and len(np.atleast_1d(axis)) == a.ndim:
         axis = None
-    return _dask_or_eager_func("nanmedian", eager_module=nputils)(a, axis=axis)
+    return nputils.nanmedian(a, axis=axis)
 
 
 def _nanvar_object(value, axis=None, ddof=0, keepdims=False, **kwargs):

@@ -170,33 +163,25 @@ def nanvar(a, axis=None, dtype=None, out=None, ddof=0):
     if a.dtype.kind == "O":
         return _nanvar_object(a, axis=axis, dtype=dtype, ddof=ddof)
 
-    return _dask_or_eager_func("nanvar", eager_module=nputils)(
-        a, axis=axis, dtype=dtype, ddof=ddof
-    )
+    return nputils.nanvar(a, axis=axis, dtype=dtype, ddof=ddof)
 
 
 def nanstd(a, axis=None, dtype=None, out=None, ddof=0):
-    return _dask_or_eager_func("nanstd", eager_module=nputils)(
-        a, axis=axis, dtype=dtype, ddof=ddof
-    )
+    return nputils.nanstd(a, axis=axis, dtype=dtype, ddof=ddof)
 
 
 def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
     a, mask = _replace_nan(a, 1)
-    result = _dask_or_eager_func("nanprod")(a, axis=axis, dtype=dtype, out=out)
+    result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out)
     if min_count is not None:
         return _maybe_null_out(result, axis, mask, min_count)
     else:
         return result
 
 
 def nancumsum(a, axis=None, dtype=None, out=None):
-    return _dask_or_eager_func("nancumsum", eager_module=nputils)(
-        a, axis=axis, dtype=dtype
-    )
+    return nputils.nancumsum(a, axis=axis, dtype=dtype)
 
 
 def nancumprod(a, axis=None, dtype=None, out=None):
-    return _dask_or_eager_func("nancumprod", eager_module=nputils)(
-        a, axis=axis, dtype=dtype
-    )
+    return nputils.nancumprod(a, axis=axis, dtype=dtype)
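
The nanops changes above replace _dask_or_eager_func("sum") and friends with plain np.sum or getattr(np, name). A minimal sketch of the protocol that makes this sufficient: any array type implementing __array_function__ (NEP-18) intercepts numpy functions called on it. The LoggingArray class below is purely illustrative and does not exist in xarray.

import numpy as np


class LoggingArray:
    """Toy duck array that forwards numpy functions to its wrapped data."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # Unwrap any LoggingArray arguments and delegate to numpy's implementation.
        unwrapped = tuple(a.data if isinstance(a, LoggingArray) else a for a in args)
        print(f"dispatching {func.__name__} via __array_function__")
        return func(*unwrapped, **kwargs)


x = LoggingArray([1.0, 2.0, 3.0])
print(np.sum(x))  # prints the dispatch message, then 6.0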

xarray/tests/test_units.py (+19)

@@ -13,6 +13,7 @@
     assert_duckarray_allclose,
     assert_equal,
     assert_identical,
+    requires_dask,
     requires_matplotlib,
 )
 from .test_plot import PlotTestCase

@@ -5579,6 +5580,24 @@ def test_merge(self, variant, unit, error, dtype):
         assert_equal(expected, actual)
 
 
+@requires_dask
+class TestPintWrappingDask:
+    def test_duck_array_ops(self):
+        import dask.array
+
+        d = dask.array.array([1, 2, 3])
+        q = pint.Quantity(d, units="m")
+        da = xr.DataArray(q, dims="x")
+
+        actual = da.mean().compute()
+        actual.name = None
+        expected = xr.DataArray(pint.Quantity(np.array(2.0), units="m"))
+
+        assert_units_equal(expected, actual)
+        # Don't use isinstance b/c we don't want to allow subclasses through
+        assert type(expected.data) == type(actual.data)  # noqa
+
+
 @requires_matplotlib
 class TestPlots(PlotTestCase):
     def test_units_in_line_plot_labels(self):
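
A hedged usage sketch of what the new TestPintWrappingDask test exercises, assuming pint and dask are installed: both libraries support NEP-18, so a reduction on a pint-wrapping-dask DataArray can be computed without xarray dispatching to dask explicitly.

import dask.array
import pint
import xarray as xr

d = dask.array.array([1, 2, 3])
q = pint.Quantity(d, units="m")   # pint Quantity wrapping a dask array
da = xr.DataArray(q, dims="x")

actual = da.mean().compute()
print(actual.data)  # a pint Quantity equal to 2.0 metres, backed by numpy after compute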

xarray/ufuncs.py (+1 −2)

@@ -20,7 +20,6 @@
 
 from .core.dataarray import DataArray as _DataArray
 from .core.dataset import Dataset as _Dataset
-from .core.duck_array_ops import _dask_or_eager_func
 from .core.groupby import GroupBy as _GroupBy
 from .core.pycompat import dask_array_type as _dask_array_type
 from .core.variable import Variable as _Variable

@@ -71,7 +70,7 @@ def __call__(self, *args, **kwargs):
             new_args = tuple(reversed(args))
 
         if res is _UNDEFINED:
-            f = _dask_or_eager_func(self._name, array_args=slice(len(args)))
+            f = getattr(_np, self._name)
             res = f(*new_args, **kwargs)
         if res is NotImplemented:
             raise TypeError(
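
For context on the ufuncs change (a sketch, not part of the commit): numpy functions dispatch on xarray objects and on NEP-18 duck arrays directly, so the wrapper's fallback can simply look up getattr(_np, self._name) and let numpy's dispatch machinery do the rest. In practice, plain numpy calls work without xarray.ufuncs at all:

import numpy as np
import xarray as xr

da = xr.DataArray(np.array([1.0, 4.0, 9.0]), dims="x")
print(np.sqrt(da))  # returns a DataArray; no dask- or xarray-specific lookup needed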
