Skip to content

Commit dc2dd89

Browse files
Russell Manserdcheriankeewis
authored
Change isinstance checks to duck Dask Array checks #4208 (#4221)
* Change isinstance checks to duck Dask Array checks #4208 * Use is_dask_collection in is_duck_dask_array * Use is_dask_collection in is_duck_dask_array * Revert to isinstance checks according to review discussion * Move is_duck_dask_array to pycompat.py and use tokenize for comparisons * isort * Implement `is_duck_array` to replace `is_array_like` * Rename `is_array_like` to `is_duck_array` * `is_duck_array` checks for `__array_function__` and `__array_ufunc__` in addition to previous checks * Replace checks for `is_duck_dask_array` and `__array_function__` with `is_duck_array` * Skip numpy duck array tests when NEP18 is not active * Use utils.is_duck_array in xarray/core/formatting.py * Replace locally defined `is_duck_array` in _diff_mapping_repr * Replace `"__array_function__"` and `is_duck_dask_array` check in `short_data_repr` * Revert back to isinstance check for iris cube * Add is_duck_array_or_ndarray function to utils * Use is_duck_array_or_ndarray for duck array checks without NEP18 * Remove is_duck_dask_array_or_ndarray, replace checks with is_duck_array * Add explicit check for NumPy array to is_duck_array * Replace is_duck_array_or_ndarray checks with is_duck_array * Remove is_duck_array check for deep copy Co-authored-by: keewis <[email protected]> * Use is_duck_array check in load * Move duck dask array tokenize tests from test_units.py to test_dask.py * Use _importorskip to require pint >=0.15 instead of pytest.mark.skipif Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: keewis <[email protected]>
1 parent 9ee0f01 commit dc2dd89

23 files changed

+156
-103
lines changed

xarray/backends/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ..conventions import cf_encoder
88
from ..core import indexing
9-
from ..core.pycompat import dask_array_type
9+
from ..core.pycompat import is_duck_dask_array
1010
from ..core.utils import FrozenDict, NdimSizeLenMixin
1111

1212
# Create a logger object, but don't add any handlers. Leave that to user code.
@@ -134,7 +134,7 @@ def __init__(self, lock=None):
134134
self.lock = lock
135135

136136
def add(self, source, target, region=None):
137-
if isinstance(source, dask_array_type):
137+
if is_duck_dask_array(source):
138138
self.sources.append(source)
139139
self.targets.append(target)
140140
self.regions.append(region)

xarray/coding/strings.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from ..core import indexing
7-
from ..core.pycompat import dask_array_type
7+
from ..core.pycompat import is_duck_dask_array
88
from ..core.variable import Variable
99
from .variables import (
1010
VariableCoder,
@@ -130,7 +130,7 @@ def bytes_to_char(arr):
130130
if arr.dtype.kind != "S":
131131
raise ValueError("argument must have a fixed-width bytes dtype")
132132

133-
if isinstance(arr, dask_array_type):
133+
if is_duck_dask_array(arr):
134134
import dask.array as da
135135

136136
return da.map_blocks(
@@ -166,7 +166,7 @@ def char_to_bytes(arr):
166166
# can't make an S0 dtype
167167
return np.zeros(arr.shape[:-1], dtype=np.string_)
168168

169-
if isinstance(arr, dask_array_type):
169+
if is_duck_dask_array(arr):
170170
import dask.array as da
171171

172172
if len(arr.chunks[-1]) > 1:

xarray/coding/variables.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88

99
from ..core import dtypes, duck_array_ops, indexing
10-
from ..core.pycompat import dask_array_type
10+
from ..core.pycompat import is_duck_dask_array
1111
from ..core.variable import Variable
1212

1313

@@ -54,7 +54,7 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin):
5454
"""
5555

5656
def __init__(self, array, func, dtype):
57-
assert not isinstance(array, dask_array_type)
57+
assert not is_duck_dask_array(array)
5858
self.array = indexing.as_indexable(array)
5959
self.func = func
6060
self._dtype = dtype
@@ -91,7 +91,7 @@ def lazy_elemwise_func(array, func, dtype):
9191
-------
9292
Either a dask.array.Array or _ElementwiseFunctionArray.
9393
"""
94-
if isinstance(array, dask_array_type):
94+
if is_duck_dask_array(array):
9595
return array.map_blocks(func, dtype=dtype)
9696
else:
9797
return _ElementwiseFunctionArray(array, func, dtype)

xarray/conventions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .coding.variables import SerializationWarning, pop_to
99
from .core import duck_array_ops, indexing
1010
from .core.common import contains_cftime_datetimes
11-
from .core.pycompat import dask_array_type
11+
from .core.pycompat import is_duck_dask_array
1212
from .core.variable import IndexVariable, Variable, as_variable
1313

1414

@@ -178,7 +178,7 @@ def ensure_dtype_not_object(var, name=None):
178178
if var.dtype.kind == "O":
179179
dims, data, attrs, encoding = _var_as_tuple(var)
180180

181-
if isinstance(data, dask_array_type):
181+
if is_duck_dask_array(data):
182182
warnings.warn(
183183
"variable {} has data in the form of a dask array with "
184184
"dtype=object, which means it is being loaded into memory "
@@ -351,7 +351,7 @@ def decode_cf_variable(
351351
del attributes["dtype"]
352352
data = BoolTypeArray(data)
353353

354-
if not isinstance(data, dask_array_type):
354+
if not is_duck_dask_array(data):
355355
data = indexing.LazilyOuterIndexedArray(data)
356356

357357
return Variable(dimensions, data, attributes, encoding=encoding)

xarray/convert.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .core import duck_array_ops
1111
from .core.dataarray import DataArray
1212
from .core.dtypes import get_fill_value
13+
from .core.pycompat import dask_array_type
1314

1415
cdms2_ignored_attrs = {"name", "tileIndex"}
1516
iris_forbidden_keys = {
@@ -246,8 +247,6 @@ def from_iris(cube):
246247
"""Convert a Iris cube into an DataArray"""
247248
import iris.exceptions
248249

249-
from xarray.core.pycompat import dask_array_type
250-
251250
name = _name(cube)
252251
if name == "unknown":
253252
name = None

xarray/core/accessor_dt.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
is_np_datetime_like,
77
is_np_timedelta_like,
88
)
9-
from .pycompat import dask_array_type
9+
from .pycompat import is_duck_dask_array
1010

1111

1212
def _season_from_months(months):
@@ -69,7 +69,7 @@ def _get_date_field(values, name, dtype):
6969
else:
7070
access_method = _access_through_cftimeindex
7171

72-
if isinstance(values, dask_array_type):
72+
if is_duck_dask_array(values):
7373
from dask.array import map_blocks
7474

7575
return map_blocks(access_method, values, name, dtype=dtype)
@@ -114,7 +114,7 @@ def _round_field(values, name, freq):
114114
Array-like of datetime fields accessed for each element in values
115115
116116
"""
117-
if isinstance(values, dask_array_type):
117+
if is_duck_dask_array(values):
118118
from dask.array import map_blocks
119119

120120
dtype = np.datetime64 if is_np_datetime_like(values.dtype) else np.dtype("O")
@@ -151,7 +151,7 @@ def _strftime(values, date_format):
151151
access_method = _strftime_through_series
152152
else:
153153
access_method = _strftime_through_cftimeindex
154-
if isinstance(values, dask_array_type):
154+
if is_duck_dask_array(values):
155155
from dask.array import map_blocks
156156

157157
return map_blocks(access_method, values, date_format)

xarray/core/common.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .arithmetic import SupportsArithmetic
2424
from .npcompat import DTypeLike
2525
from .options import OPTIONS, _get_keep_attrs
26-
from .pycompat import dask_array_type
26+
from .pycompat import is_duck_dask_array
2727
from .rolling_exp import RollingExp
2828
from .utils import Frozen, either_dict_or_kwargs, is_scalar
2929

@@ -1507,7 +1507,7 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
15071507
if fill_value is dtypes.NA:
15081508
fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype)
15091509

1510-
if isinstance(other.data, dask_array_type):
1510+
if is_duck_dask_array(other.data):
15111511
import dask.array
15121512

15131513
if dtype is None:
@@ -1652,7 +1652,7 @@ def _contains_cftime_datetimes(array) -> bool:
16521652
else:
16531653
if array.dtype == np.dtype("O") and array.size > 0:
16541654
sample = array.ravel()[0]
1655-
if isinstance(sample, dask_array_type):
1655+
if is_duck_dask_array(sample):
16561656
sample = sample.compute()
16571657
if isinstance(sample, np.ndarray):
16581658
sample = sample.item()

xarray/core/computation.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from .alignment import align, deep_align
3030
from .merge import merge_coordinates_without_align
3131
from .options import OPTIONS
32-
from .pycompat import dask_array_type
32+
from .pycompat import is_duck_dask_array
3333
from .utils import is_dict_like
3434
from .variable import Variable
3535

@@ -610,7 +610,7 @@ def apply_variable_ufunc(
610610
for arg, core_dims in zip(args, signature.input_core_dims)
611611
]
612612

613-
if any(isinstance(array, dask_array_type) for array in input_data):
613+
if any(is_duck_dask_array(array) for array in input_data):
614614
if dask == "forbidden":
615615
raise ValueError(
616616
"apply_ufunc encountered a dask array on an "
@@ -726,7 +726,7 @@ def func(*arrays):
726726

727727
def apply_array_ufunc(func, *args, dask="forbidden"):
728728
"""Apply a ndarray level function over ndarray objects."""
729-
if any(isinstance(arg, dask_array_type) for arg in args):
729+
if any(is_duck_dask_array(arg) for arg in args):
730730
if dask == "forbidden":
731731
raise ValueError(
732732
"apply_ufunc encountered a dask array on an "
@@ -1604,7 +1604,7 @@ def _calc_idxminmax(
16041604
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)
16051605

16061606
# Handle dask arrays.
1607-
if isinstance(array.data, dask_array_type):
1607+
if is_duck_dask_array(array.data):
16081608
import dask.array
16091609

16101610
chunks = dict(zip(array.dims, array.chunks))

xarray/core/dask_array_compat.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from .pycompat import dask_array_type
7+
from .pycompat import is_duck_dask_array
88

99
try:
1010
import dask.array as da
@@ -39,7 +39,7 @@ def meta_from_array(x, ndim=None, dtype=None):
3939
"""
4040
# If using x._meta, x must be a Dask Array, some libraries (e.g. zarr)
4141
# implement a _meta attribute that are incompatible with Dask Array._meta
42-
if hasattr(x, "_meta") and isinstance(x, dask_array_type):
42+
if hasattr(x, "_meta") and is_duck_dask_array(x):
4343
x = x._meta
4444

4545
if dtype is None and x is None:

xarray/core/dataset.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
)
8181
from .missing import get_clean_interp_index
8282
from .options import OPTIONS, _get_keep_attrs
83-
from .pycompat import dask_array_type
83+
from .pycompat import is_duck_dask_array
8484
from .utils import (
8585
Default,
8686
Frozen,
@@ -645,9 +645,7 @@ def load(self, **kwargs) -> "Dataset":
645645
"""
646646
# access .data to coerce everything to numpy or dask arrays
647647
lazy_data = {
648-
k: v._data
649-
for k, v in self.variables.items()
650-
if isinstance(v._data, dask_array_type)
648+
k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data)
651649
}
652650
if lazy_data:
653651
import dask.array as da
@@ -815,9 +813,7 @@ def _persist_inplace(self, **kwargs) -> "Dataset":
815813
"""Persist all Dask arrays in memory"""
816814
# access .data to coerce everything to numpy or dask arrays
817815
lazy_data = {
818-
k: v._data
819-
for k, v in self.variables.items()
820-
if isinstance(v._data, dask_array_type)
816+
k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data)
821817
}
822818
if lazy_data:
823819
import dask
@@ -6043,7 +6039,7 @@ def polyfit(
60436039
if dim not in da.dims:
60446040
continue
60456041

6046-
if isinstance(da.data, dask_array_type) and (
6042+
if is_duck_dask_array(da.data) and (
60476043
rank != order or full or skipna is None
60486044
):
60496045
# Current algorithm with dask and skipna=False neither supports

xarray/core/duck_array_ops.py

+21-22
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,17 @@
1515

1616
from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils
1717
from .nputils import nanfirst, nanlast
18-
from .pycompat import cupy_array_type, dask_array_type, sparse_array_type
18+
from .pycompat import (
19+
cupy_array_type,
20+
dask_array_type,
21+
is_duck_dask_array,
22+
sparse_array_type,
23+
)
24+
from .utils import is_duck_array
1925

2026
try:
2127
import dask.array as dask_array
28+
from dask.base import tokenize
2229
except ImportError:
2330
dask_array = None # type: ignore
2431

@@ -39,7 +46,7 @@ def f(*args, **kwargs):
3946
dispatch_args = args[0]
4047
else:
4148
dispatch_args = args[array_args]
42-
if any(isinstance(a, dask_array_type) for a in dispatch_args):
49+
if any(is_duck_dask_array(a) for a in dispatch_args):
4350
try:
4451
wrapped = getattr(dask_module, name)
4552
except AttributeError as e:
@@ -57,7 +64,7 @@ def f(*args, **kwargs):
5764

5865

5966
def fail_on_dask_array_input(values, msg=None, func_name=None):
60-
if isinstance(values, dask_array_type):
67+
if is_duck_dask_array(values):
6168
if msg is None:
6269
msg = "%r is not yet a valid method on dask arrays"
6370
if func_name is None:
@@ -129,7 +136,7 @@ def notnull(data):
129136

130137

131138
def gradient(x, coord, axis, edge_order):
132-
if isinstance(x, dask_array_type):
139+
if is_duck_dask_array(x):
133140
return dask_array.gradient(x, coord, axis=axis, edge_order=edge_order)
134141
return np.gradient(x, coord, axis=axis, edge_order=edge_order)
135142

@@ -174,11 +181,7 @@ def astype(data, **kwargs):
174181

175182

176183
def asarray(data, xp=np):
177-
return (
178-
data
179-
if (isinstance(data, dask_array_type) or hasattr(data, "__array_function__"))
180-
else xp.asarray(data)
181-
)
184+
return data if is_duck_array(data) else xp.asarray(data)
182185

183186

184187
def as_shared_dtype(scalars_or_arrays):
@@ -200,24 +203,20 @@ def as_shared_dtype(scalars_or_arrays):
200203

201204
def lazy_array_equiv(arr1, arr2):
202205
"""Like array_equal, but doesn't actually compare values.
203-
Returns True when arr1, arr2 identical or their dask names are equal.
206+
Returns True when arr1, arr2 identical or their dask tokens are equal.
204207
Returns False when shapes are not equal.
205208
Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays;
206-
or their dask names are not equal
209+
or their dask tokens are not equal
207210
"""
208211
if arr1 is arr2:
209212
return True
210213
arr1 = asarray(arr1)
211214
arr2 = asarray(arr2)
212215
if arr1.shape != arr2.shape:
213216
return False
214-
if (
215-
dask_array
216-
and isinstance(arr1, dask_array_type)
217-
and isinstance(arr2, dask_array_type)
218-
):
219-
# GH3068
220-
if arr1.name == arr2.name:
217+
if dask_array and is_duck_dask_array(arr1) and is_duck_dask_array(arr2):
218+
# GH3068, GH4221
219+
if tokenize(arr1) == tokenize(arr2):
221220
return True
222221
else:
223222
return None
@@ -331,7 +330,7 @@ def f(values, axis=None, skipna=None, **kwargs):
331330
try:
332331
return func(values, axis=axis, **kwargs)
333332
except AttributeError:
334-
if not isinstance(values, dask_array_type):
333+
if not is_duck_dask_array(values):
335334
raise
336335
try: # dask/dask#3133 dask sometimes needs dtype argument
337336
# if func does not accept dtype, then raises TypeError
@@ -545,7 +544,7 @@ def mean(array, axis=None, skipna=None, **kwargs):
545544
+ offset
546545
)
547546
elif _contains_cftime_datetimes(array):
548-
if isinstance(array, dask_array_type):
547+
if is_duck_dask_array(array):
549548
raise NotImplementedError(
550549
"Computing the mean of an array containing "
551550
"cftime.datetime objects is not yet implemented on "
@@ -614,15 +613,15 @@ def rolling_window(array, axis, window, center, fill_value):
614613
Make an ndarray with a rolling window of axis-th dimension.
615614
The rolling dimension will be placed at the last dimension.
616615
"""
617-
if isinstance(array, dask_array_type):
616+
if is_duck_dask_array(array):
618617
return dask_array_ops.rolling_window(array, axis, window, center, fill_value)
619618
else: # np.ndarray
620619
return nputils.rolling_window(array, axis, window, center, fill_value)
621620

622621

623622
def least_squares(lhs, rhs, rcond=None, skipna=False):
624623
"""Return the coefficients and residuals of a least-squares fit."""
625-
if isinstance(rhs, dask_array_type):
624+
if is_duck_dask_array(rhs):
626625
return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
627626
else:
628627
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)

0 commit comments

Comments
 (0)