Skip to content

Make xr.corr and xr.map_blocks work without dask #5731

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 24, 2021
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ Deprecations

Bug fixes
~~~~~~~~~
- :py:func:`xr.map_blocks` and :py:func:`xr.corr` now work when dask is not installed (:issue:`3391`, :issue:`5715`, :pull:`5731`).
By `Gijom <https://github.com/Gijom>`_.
- Fix plot.line crash for data of shape ``(1, N)`` in _title_for_slice on format_item (:pull:`5948`).
By `Sebastian Weigand <https://github.com/s-weigand>`_.
- Fix a regression in the removal of duplicate backend entrypoints (:issue:`5944`, :pull:`5959`)
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,8 @@ def _get_valid_values(da, other):
da = da.where(~missing_vals)
return da
else:
return da
# ensure consistent return dtype
return da.astype(float)

da_a = da_a.map_blocks(_get_valid_values, args=[da_b])
da_b = da_b.map_blocks(_get_valid_values, args=[da_a])
Expand Down
5 changes: 3 additions & 2 deletions xarray/core/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .alignment import align
from .dataarray import DataArray
from .dataset import Dataset
from .pycompat import is_dask_collection

try:
import dask
Expand Down Expand Up @@ -328,13 +329,13 @@ def _wrapper(
raise TypeError("kwargs must be a mapping (for example, a dict)")

for value in kwargs.values():
if dask.is_dask_collection(value):
if is_dask_collection(value):
raise TypeError(
"Cannot pass dask collections in kwargs yet. Please compute or "
"load values before passing to map_blocks."
)

if not dask.is_dask_collection(obj):
if not is_dask_collection(obj):
return func(obj, *args, **kwargs)

all_args = [obj] + list(args)
Expand Down
8 changes: 6 additions & 2 deletions xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,19 @@ def __init__(self, mod):
self.available = duck_array_module is not None


def is_duck_dask_array(x):
def is_dask_collection(x):
if DuckArrayModule("dask").available:
from dask.base import is_dask_collection

return is_duck_array(x) and is_dask_collection(x)
return is_dask_collection(x)
else:
return False


def is_duck_dask_array(x):
return is_duck_array(x) and is_dask_collection(x)


dsk = DuckArrayModule("dask")
dask_version = dsk.version
dask_array_type = dsk.type
Expand Down
25 changes: 23 additions & 2 deletions xarray/tests/test_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

from . import has_dask, raise_if_dask_computes, requires_dask

dask = pytest.importorskip("dask")


def assert_identical(a, b):
"""A version of this function which accepts numpy arrays"""
Expand Down Expand Up @@ -1420,6 +1418,7 @@ def arrays_w_tuples():
],
)
@pytest.mark.parametrize("dim", [None, "x", "time"])
@requires_dask
def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
# GH 5284
from dask import is_dask_collection
Expand Down Expand Up @@ -1554,6 +1553,28 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:
assert_allclose(actual, expected)


@requires_dask
@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
@pytest.mark.parametrize("dim", [None, "time", "x"])
def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
da_al = da_a.chunk()
da_bl = da_b.chunk()
c_abl = xr.corr(da_al, da_bl, dim=dim)
c_ab = xr.corr(da_a, da_b, dim=dim)
c_ab_mixed = xr.corr(da_a, da_bl, dim=dim)
assert_allclose(c_ab, c_abl)
assert_allclose(c_ab, c_ab_mixed)


@requires_dask
def test_corr_dtype_error():
da_a = xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"])
da_b = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])

xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a.chunk(), da_b.chunk()))
xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk()))


@pytest.mark.parametrize(
"da_a",
arrays_w_tuples()[0],
Expand Down