-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Improved duck array wrapping #9798
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 21 commits
fd6b339
893408c
f7866ce
90037fe
5ba1a2f
6225ae3
e2911c2
2ac37f9
1cc344b
69080a5
372439c
0eef2cb
6739504
9e6d6f8
e721011
1fe4131
205c199
7752088
c8d4e5e
e67a819
f306768
18ebdcd
f51e3fb
121af9e
472ae7e
5aa4a39
390df6f
f6074d2
561f21b
bfd6aeb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,21 +18,16 @@ | |
import pandas as pd | ||
from numpy import all as array_all # noqa: F401 | ||
from numpy import any as array_any # noqa: F401 | ||
from numpy import concatenate as _concatenate | ||
from numpy import ( # noqa: F401 | ||
full_like, | ||
gradient, | ||
isclose, | ||
isin, | ||
isnat, | ||
take, | ||
tensordot, | ||
transpose, | ||
unravel_index, | ||
) | ||
from pandas.api.types import is_extension_array_dtype | ||
|
||
from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils | ||
from xarray.core.array_api_compat import get_array_namespace | ||
from xarray.core.options import OPTIONS | ||
from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available | ||
from xarray.namedarray.parallelcompat import get_chunked_array_type | ||
|
@@ -52,28 +47,6 @@ | |
dask_available = module_available("dask") | ||
|
||
|
||
def get_array_namespace(*values): | ||
def _get_array_namespace(x): | ||
if hasattr(x, "__array_namespace__"): | ||
return x.__array_namespace__() | ||
else: | ||
return np | ||
|
||
namespaces = {_get_array_namespace(t) for t in values} | ||
non_numpy = namespaces - {np} | ||
|
||
if len(non_numpy) > 1: | ||
raise TypeError( | ||
"cannot deal with more than one type supporting the array API at the same time" | ||
) | ||
elif non_numpy: | ||
[xp] = non_numpy | ||
else: | ||
xp = np | ||
|
||
return xp | ||
|
||
|
||
def einsum(*args, **kwargs): | ||
from xarray.core.options import OPTIONS | ||
|
||
|
@@ -82,7 +55,23 @@ def einsum(*args, **kwargs): | |
|
||
return opt_einsum.contract(*args, **kwargs) | ||
else: | ||
return np.einsum(*args, **kwargs) | ||
xp = get_array_namespace(*args) | ||
return xp.einsum(*args, **kwargs) | ||
|
||
|
||
def tensordot(*args, **kwargs): | ||
xp = get_array_namespace(*args) | ||
return xp.tensordot(*args, **kwargs) | ||
|
||
|
||
def cross(*args, **kwargs): | ||
xp = get_array_namespace(*args) | ||
return xp.cross(*args, **kwargs) | ||
|
||
|
||
def gradient(f, *varargs, axis=None, edge_order=1): | ||
xp = get_array_namespace(f) | ||
return xp.gradient(f, *varargs, axis=axis, edge_order=edge_order) | ||
|
||
|
||
def _dask_or_eager_func( | ||
|
@@ -131,15 +120,20 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): | |
"masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" | ||
) | ||
|
||
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), | ||
# so we need to hand-code this. | ||
sliding_window_view = _dask_or_eager_func( | ||
"sliding_window_view", | ||
eager_module=np.lib.stride_tricks, | ||
dask_module=dask_array_compat, | ||
dask_only_kwargs=("automatic_rechunk",), | ||
numpy_only_kwargs=("subok", "writeable"), | ||
) | ||
|
||
def sliding_window_view(array, window_shape, axis=None, **kwargs): | ||
# TODO: some libraries (e.g. jax) don't have this, implement an alternative? | ||
xp = get_array_namespace(array) | ||
# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), | ||
# so we need to hand-code this. | ||
func = _dask_or_eager_func( | ||
"sliding_window_view", | ||
eager_module=xp.lib.stride_tricks, | ||
dask_module=dask_array_compat, | ||
dask_only_kwargs=("automatic_rechunk",), | ||
numpy_only_kwargs=("subok", "writeable"), | ||
) | ||
return func(array, window_shape, axis=axis, **kwargs) | ||
|
||
|
||
def round(array): | ||
|
@@ -172,7 +166,8 @@ def isnull(data): | |
) | ||
): | ||
# these types cannot represent missing values | ||
return full_like(data, dtype=bool, fill_value=False) | ||
dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I believe […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Will do, but also […] |
||
return full_like(data, dtype=dtype, fill_value=False) | ||
else: | ||
# at this point, array should have dtype=object | ||
if isinstance(data, np.ndarray) or is_extension_array_dtype(data): | ||
|
@@ -213,11 +208,23 @@ def cumulative_trapezoid(y, x, axis): | |
|
||
# Pad so that 'axis' has same length in result as it did in y | ||
pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] | ||
integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) | ||
|
||
xp = get_array_namespace(y, x) | ||
integrand = xp.pad(integrand, pads, mode="constant", constant_values=0.0) | ||
|
||
return cumsum(integrand, axis=axis, skipna=False) | ||
|
||
|
||
def full_like(a, fill_value, **kwargs): | ||
xp = get_array_namespace(a) | ||
return xp.full_like(a, fill_value, **kwargs) | ||
|
||
|
||
def empty_like(a, **kwargs): | ||
xp = get_array_namespace(a) | ||
return xp.empty_like(a, **kwargs) | ||
|
||
|
||
def astype(data, dtype, **kwargs): | ||
if hasattr(data, "__array_namespace__"): | ||
xp = get_array_namespace(data) | ||
|
@@ -348,7 +355,8 @@ def array_notnull_equiv(arr1, arr2): | |
|
||
def count(data, axis=None): | ||
"""Count the number of non-NA in this array along the given axis or axes""" | ||
return np.sum(np.logical_not(isnull(data)), axis=axis) | ||
xp = get_array_namespace(data) | ||
return xp.sum(xp.logical_not(isnull(data)), axis=axis) | ||
|
||
|
||
def sum_where(data, axis=None, dtype=None, where=None): | ||
|
@@ -363,7 +371,7 @@ def sum_where(data, axis=None, dtype=None, where=None): | |
|
||
def where(condition, x, y): | ||
"""Three argument where() with better dtype promotion rules.""" | ||
xp = get_array_namespace(condition) | ||
xp = get_array_namespace(condition, x, y) | ||
return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) | ||
|
||
|
||
|
@@ -380,15 +388,25 @@ def fillna(data, other): | |
return where(notnull(data), data, other) | ||
|
||
|
||
def logical_not(data): | ||
xp = get_array_namespace(data) | ||
return xp.logical_not(data) | ||
|
||
|
||
def clip(data, min=None, max=None): | ||
xp = get_array_namespace(data) | ||
return xp.clip(data, min, max) | ||
|
||
|
||
def concatenate(arrays, axis=0): | ||
"""concatenate() with better dtype promotion rules.""" | ||
# TODO: remove the additional check once `numpy` adds `concat` to its array namespace | ||
if hasattr(arrays[0], "__array_namespace__") and not isinstance( | ||
arrays[0], np.ndarray | ||
): | ||
xp = get_array_namespace(arrays[0]) | ||
# TODO: `concat` is the xp compliant name, but fallback to concatenate for | ||
# older numpy and for cupy | ||
xp = get_array_namespace(*arrays) | ||
if hasattr(xp, "concat"): | ||
return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) | ||
return _concatenate(as_shared_dtype(arrays), axis=axis) | ||
else: | ||
return xp.concatenate(as_shared_dtype(arrays, xp=xp), axis=axis) | ||
|
||
|
||
def stack(arrays, axis=0): | ||
|
@@ -406,6 +424,26 @@ def ravel(array): | |
return reshape(array, (-1,)) | ||
|
||
|
||
def transpose(array, axes=None): | ||
xp = get_array_namespace(array) | ||
return xp.transpose(array, axes) | ||
|
||
|
||
def moveaxis(array, source, destination): | ||
xp = get_array_namespace(array) | ||
return xp.moveaxis(array, source, destination) | ||
|
||
|
||
def pad(array, pad_width, **kwargs): | ||
xp = get_array_namespace(array) | ||
return xp.pad(array, pad_width, **kwargs) | ||
|
||
|
||
def quantile(array, q, axis=None, **kwargs): | ||
xp = get_array_namespace(array) | ||
return xp.quantile(array, q, axis=axis, **kwargs) | ||
|
||
|
||
@contextlib.contextmanager | ||
def _ignore_warnings_if(condition): | ||
if condition: | ||
|
@@ -747,6 +785,11 @@ def last(values, axis, skipna=None): | |
return take(values, -1, axis=axis) | ||
|
||
|
||
def isin(element, test_elements, **kwargs): | ||
xp = get_array_namespace(element, test_elements) | ||
return xp.isin(element, test_elements, **kwargs) | ||
|
||
|
||
def least_squares(lhs, rhs, rcond=None, skipna=False): | ||
"""Return the coefficients and residuals of a least-squares fit.""" | ||
if is_duck_dask_array(rhs): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is one of the biggest outstanding bummers of wrapping jax arrays. There is apparently openness to adding this as an API (even though it would not offer any performance benefit in XLA). But given that this is way outside the API standard, I wonder whether it makes sense to implement a general version within xarray that doesn't rely on stride tricks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could implement a version using "summed area tables" (basically run a single accumulator and then compute differences between the window edges); or convolutions I guess.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have something that works pretty well with this style of gather operation, but only in a `jit` context where XLA can work its magic. So I guess this is better left to the specific library to implement, or the user.