Skip to content

Commit bbd25c6

Browse files
nvictusshoyer
authored andcommitted
Support for __array_function__ implementers (sparse arrays) [WIP] (#3117)
* Support for __array_function__ implementers * Pep8 * Consistent naming * Check for NEP18 enabled and nep18 non-numpy arrays * Replace .values with .data * Add initial test for nep18 * Fix linting issues * Add parameterized tests * Internal clean-up of isnull() to avoid relying on pandas This version should be much more compatible out of the box with duck typing. * Add sparse to ci requirements * Moar tests * Two more patches for __array_function__ duck-arrays * Don't use coords attribute from duck-arrays that aren't derived from DataWithCoords * Improve checking for coords, and autopep8 * Skip tests if NEP-18 envvar is not set * flake8 * Update xarray/core/dataarray.py Co-Authored-By: Stephan Hoyer <[email protected]> * Fix coords parsing * More tests * Add align tests * Replace nep18 tests with more extensive tests on pydata/sparse * Add xfails for missing np.result_type (fixed by pydata/sparse/pull/261) * Fix xpasses * Revert isnull/notnull * Fix as_like_arrays by coercing dense arrays to COO if any sparse * Make Variable.load a no-op for non-dask duck arrays * Add additional method tests * Fix utils.as_scalar to handle duck arrays with ndim>0
1 parent 50f8970 commit bbd25c6

File tree

10 files changed

+775
-19
lines changed

10 files changed

+775
-19
lines changed

ci/requirements/py37.yml

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- pip
2222
- scipy
2323
- seaborn
24+
- sparse
2425
- toolz
2526
- rasterio
2627
- boto3

xarray/core/dataarray.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -260,15 +260,18 @@ def __init__(
260260
else:
261261
# try to fill in arguments from data if they weren't supplied
262262
if coords is None:
263-
coords = getattr(data, 'coords', None)
264-
if isinstance(data, pd.Series):
263+
264+
if isinstance(data, DataArray):
265+
coords = data.coords
266+
elif isinstance(data, pd.Series):
265267
coords = [data.index]
266268
elif isinstance(data, pd.DataFrame):
267269
coords = [data.index, data.columns]
268270
elif isinstance(data, (pd.Index, IndexVariable)):
269271
coords = [data]
270272
elif isinstance(data, pdcompat.Panel):
271273
coords = [data.items, data.major_axis, data.minor_axis]
274+
272275
if dims is None:
273276
dims = getattr(data, 'dims', getattr(coords, 'dims', None))
274277
if name is None:

xarray/core/duck_array_ops.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from . import dask_array_ops, dtypes, npcompat, nputils
1515
from .nputils import nanfirst, nanlast
16-
from .pycompat import dask_array_type
16+
from .pycompat import dask_array_type, sparse_array_type
1717

1818
try:
1919
import dask.array as dask_array
@@ -64,6 +64,7 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
6464
around = _dask_or_eager_func('around')
6565
isclose = _dask_or_eager_func('isclose')
6666

67+
6768
if hasattr(np, 'isnat') and (
6869
dask_array is None or hasattr(dask_array_type, '__array_ufunc__')):
6970
# np.isnat is available since NumPy 1.13, so __array_ufunc__ is always
@@ -153,7 +154,11 @@ def trapz(y, x, axis):
153154

154155

155156
def asarray(data):
156-
return data if isinstance(data, dask_array_type) else np.asarray(data)
157+
return (
158+
data if (isinstance(data, dask_array_type)
159+
or hasattr(data, '__array_function__'))
160+
else np.asarray(data)
161+
)
157162

158163

159164
def as_shared_dtype(scalars_or_arrays):
@@ -170,6 +175,9 @@ def as_shared_dtype(scalars_or_arrays):
170175
def as_like_arrays(*data):
171176
if all(isinstance(d, dask_array_type) for d in data):
172177
return data
178+
elif any(isinstance(d, sparse_array_type) for d in data):
179+
from sparse import COO
180+
return tuple(COO(d) for d in data)
173181
else:
174182
return tuple(np.asarray(d) for d in data)
175183

xarray/core/formatting.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,10 @@ def set_numpy_options(*args, **kwargs):
357357

358358

359359
def short_array_repr(array):
360-
array = np.asarray(array)
360+
361+
if not hasattr(array, '__array_function__'):
362+
array = np.asarray(array)
363+
361364
# default to lower precision so a full (abbreviated) line can fit on
362365
# one line with the default display_width
363366
options = {
@@ -394,7 +397,7 @@ def short_data_repr(array):
394397
if isinstance(getattr(array, 'variable', array)._data, dask_array_type):
395398
return short_dask_repr(array)
396399
elif array._in_memory or array.size < 1e5:
397-
return short_array_repr(array.values)
400+
return short_array_repr(array.data)
398401
else:
399402
return u'[{} values with dtype={}]'.format(array.size, array.dtype)
400403

xarray/core/indexing.py

+13
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,9 @@ def as_indexable(array):
657657
return PandasIndexAdapter(array)
658658
if isinstance(array, dask_array_type):
659659
return DaskIndexingAdapter(array)
660+
if hasattr(array, '__array_function__'):
661+
return NdArrayLikeIndexingAdapter(array)
662+
660663
raise TypeError('Invalid array type: {}'.format(type(array)))
661664

662665

@@ -1189,6 +1192,16 @@ def __setitem__(self, key, value):
11891192
raise
11901193

11911194

1195+
class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter):
1196+
def __init__(self, array):
1197+
if not hasattr(array, '__array_function__'):
1198+
raise TypeError(
1199+
'NdArrayLikeIndexingAdapter must wrap an object that '
1200+
'implements the __array_function__ protocol'
1201+
)
1202+
self.array = array
1203+
1204+
11921205
class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
11931206
"""Wrap a dask array to support explicit indexing."""
11941207

xarray/core/npcompat.py

+15
Original file line numberDiff line numberDiff line change
@@ -357,3 +357,18 @@ def moveaxis(a, source, destination):
357357
# https://github.com/numpy/numpy/issues/7370
358358
# https://github.com/numpy/numpy-stubs/
359359
DTypeLike = Union[np.dtype, str]
360+
361+
362+
# from dask/array/utils.py
363+
def _is_nep18_active():
364+
class A:
365+
def __array_function__(self, *args, **kwargs):
366+
return True
367+
368+
try:
369+
return np.concatenate([A()])
370+
except ValueError:
371+
return False
372+
373+
374+
IS_NEP18_ACTIVE = _is_nep18_active()

xarray/core/pycompat.py

+7
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@
88
dask_array_type = (dask.array.Array,)
99
except ImportError: # pragma: no cover
1010
dask_array_type = ()
11+
12+
try:
13+
# solely for isinstance checks
14+
import sparse
15+
sparse_array_type = (sparse.SparseArray,)
16+
except ImportError: # pragma: no cover
17+
sparse_array_type = ()

xarray/core/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ def is_scalar(value: Any) -> bool:
243243
return (
244244
getattr(value, 'ndim', None) == 0 or
245245
isinstance(value, (str, bytes)) or not
246-
isinstance(value, (Iterable, ) + dask_array_type))
246+
(isinstance(value, (Iterable, ) + dask_array_type) or
247+
hasattr(value, '__array_function__'))
248+
)
247249

248250

249251
def is_valid_numpy_dtype(dtype: Any) -> bool:

xarray/core/variable.py

+27-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
as_indexable)
1818
from .options import _get_keep_attrs
1919
from .pycompat import dask_array_type, integer_types
20+
from .npcompat import IS_NEP18_ACTIVE
2021
from .utils import (
2122
OrderedSet, decode_numpy_dict_values, either_dict_or_kwargs,
2223
ensure_us_time_resolution)
@@ -179,6 +180,18 @@ def as_compatible_data(data, fastpath=False):
179180
else:
180181
data = np.asarray(data)
181182

183+
if not isinstance(data, np.ndarray):
184+
if hasattr(data, '__array_function__'):
185+
if IS_NEP18_ACTIVE:
186+
return data
187+
else:
188+
raise TypeError(
189+
'Got an NumPy-like array type providing the '
190+
'__array_function__ protocol but NEP18 is not enabled. '
191+
'Check that numpy >= v1.16 and that the environment '
192+
'variable "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION" is set to '
193+
'"1"')
194+
182195
# validate whether the data is valid data types
183196
data = np.asarray(data)
184197

@@ -288,7 +301,7 @@ def _in_memory(self):
288301

289302
@property
290303
def data(self):
291-
if isinstance(self._data, dask_array_type):
304+
if hasattr(self._data, '__array_function__'):
292305
return self._data
293306
else:
294307
return self.values
@@ -320,7 +333,7 @@ def load(self, **kwargs):
320333
"""
321334
if isinstance(self._data, dask_array_type):
322335
self._data = as_compatible_data(self._data.compute(**kwargs))
323-
elif not isinstance(self._data, np.ndarray):
336+
elif not hasattr(self._data, '__array_function__'):
324337
self._data = np.asarray(self._data)
325338
return self
326339

@@ -705,8 +718,8 @@ def __setitem__(self, key, value):
705718

706719
if new_order:
707720
value = duck_array_ops.asarray(value)
708-
value = value[(len(dims) - value.ndim) * (np.newaxis,) +
709-
(Ellipsis,)]
721+
value = value[(len(dims) - value.ndim) * (np.newaxis,)
722+
+ (Ellipsis,)]
710723
value = duck_array_ops.moveaxis(
711724
value, new_order, range(len(new_order)))
712725

@@ -805,7 +818,8 @@ def copy(self, deep=True, data=None):
805818
data = indexing.MemoryCachedArray(data.array)
806819

807820
if deep:
808-
if isinstance(data, dask_array_type):
821+
if (hasattr(data, '__array_function__')
822+
or isinstance(data, dask_array_type)):
809823
data = data.copy()
810824
elif not isinstance(data, PandasIndexAdapter):
811825
# pandas.Index is immutable
@@ -1494,9 +1508,10 @@ def equals(self, other, equiv=duck_array_ops.array_equiv):
14941508
"""
14951509
other = getattr(other, 'variable', other)
14961510
try:
1497-
return (self.dims == other.dims and
1498-
(self._data is other._data or
1499-
equiv(self.data, other.data)))
1511+
return (
1512+
self.dims == other.dims and
1513+
(self._data is other._data or equiv(self.data, other.data))
1514+
)
15001515
except (TypeError, AttributeError):
15011516
return False
15021517

@@ -1517,8 +1532,8 @@ def identical(self, other):
15171532
"""Like equals, but also checks attributes.
15181533
"""
15191534
try:
1520-
return (utils.dict_equiv(self.attrs, other.attrs) and
1521-
self.equals(other))
1535+
return (utils.dict_equiv(self.attrs, other.attrs)
1536+
and self.equals(other))
15221537
except (TypeError, AttributeError):
15231538
return False
15241539

@@ -1959,8 +1974,8 @@ def equals(self, other, equiv=None):
19591974
# otherwise use the native index equals, rather than looking at _data
19601975
other = getattr(other, 'variable', other)
19611976
try:
1962-
return (self.dims == other.dims and
1963-
self._data_equals(other))
1977+
return (self.dims == other.dims
1978+
and self._data_equals(other))
19641979
except (TypeError, AttributeError):
19651980
return False
19661981

0 commit comments

Comments
 (0)