diff --git a/doc/whats-new.rst b/doc/whats-new.rst index b66c99d0bcb..eb61cd154cf 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,8 @@ Deprecations Bug fixes ~~~~~~~~~ +- :py:func:`xr.map_blocks` and :py:func:`xr.corr` now work when dask is not installed (:issue:`3391`, :issue:`5715`, :pull:`5731`). + By `Gijom <https://github.com/Gijom>`_. - Fix plot.line crash for data of shape ``(1, N)`` in _title_for_slice on format_item (:pull:`5948`). By `Sebastian Weigand <https://github.com/s-weigand>`_. - Fix a regression in the removal of duplicate backend entrypoints (:issue:`5944`, :pull:`5959`) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0c21ca07744..04fda5a7cb3 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1359,7 +1359,8 @@ def _get_valid_values(da, other): da = da.where(~missing_vals) return da else: - return da + # ensure consistent return dtype + return da.astype(float) da_a = da_a.map_blocks(_get_valid_values, args=[da_b]) da_b = da_b.map_blocks(_get_valid_values, args=[da_a]) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 4917714a9c2..f20256346da 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -23,6 +23,7 @@ from .alignment import align from .dataarray import DataArray from .dataset import Dataset +from .pycompat import is_dask_collection try: import dask @@ -328,13 +329,13 @@ def _wrapper( raise TypeError("kwargs must be a mapping (for example, a dict)") for value in kwargs.values(): - if dask.is_dask_collection(value): + if is_dask_collection(value): raise TypeError( "Cannot pass dask collections in kwargs yet. Please compute or " "load values before passing to map_blocks." 
) - if not dask.is_dask_collection(obj): + if not is_dask_collection(obj): return func(obj, *args, **kwargs) all_args = [obj] + list(args) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index d1649235006..d95dced9ddf 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -43,15 +43,19 @@ def __init__(self, mod): self.available = duck_array_module is not None -def is_duck_dask_array(x): +def is_dask_collection(x): if DuckArrayModule("dask").available: from dask.base import is_dask_collection - return is_duck_array(x) and is_dask_collection(x) + return is_dask_collection(x) else: return False +def is_duck_dask_array(x): + return is_duck_array(x) and is_dask_collection(x) + + dsk = DuckArrayModule("dask") dask_version = dsk.version dask_array_type = dsk.type diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 22a3efce999..8af7604cae5 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -24,8 +24,6 @@ from . 
import has_dask, raise_if_dask_computes, requires_dask -dask = pytest.importorskip("dask") - def assert_identical(a, b): """A version of this function which accepts numpy arrays""" @@ -1420,6 +1418,7 @@ def arrays_w_tuples(): ], ) @pytest.mark.parametrize("dim", [None, "x", "time"]) +@requires_dask def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None: # GH 5284 from dask import is_dask_collection @@ -1554,6 +1553,28 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None: assert_allclose(actual, expected) +@requires_dask +@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1]) +@pytest.mark.parametrize("dim", [None, "time", "x"]) +def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None: + da_al = da_a.chunk() + da_bl = da_b.chunk() + c_abl = xr.corr(da_al, da_bl, dim=dim) + c_ab = xr.corr(da_a, da_b, dim=dim) + c_ab_mixed = xr.corr(da_a, da_bl, dim=dim) + assert_allclose(c_ab, c_abl) + assert_allclose(c_ab, c_ab_mixed) + + +@requires_dask +def test_corr_dtype_error(): + da_a = xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"]) + da_b = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]) + + xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a.chunk(), da_b.chunk())) + xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk())) + + @pytest.mark.parametrize( "da_a", arrays_w_tuples()[0],