Skip to content

Commit 21e8484

Browse files
Gijommathausedcherian
authored
Make xr.corr and xr.map_blocks work without dask (#5731)
Co-authored-by: Mathias Hauser <[email protected]> Co-authored-by: dcherian <[email protected]> Co-authored-by: Deepak Cherian <[email protected]>
1 parent 5db4046 commit 21e8484

File tree

5 files changed

+36
-7
lines changed

5 files changed

+36
-7
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ Deprecations
3434

3535
Bug fixes
3636
~~~~~~~~~
37+
- :py:func:`xr.map_blocks` and :py:func:`xr.corr` now work when dask is not installed (:issue:`3391`, :issue:`5715`, :pull:`5731`).
38+
By `Gijom <https://github.com/Gijom>`_.
3739
- Fix plot.line crash for data of shape ``(1, N)`` in _title_for_slice on format_item (:pull:`5948`).
3840
By `Sebastian Weigand <https://github.com/s-weigand>`_.
3941
- Fix a regression in the removal of duplicate backend entrypoints (:issue:`5944`, :pull:`5959`)

xarray/core/computation.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,8 @@ def _get_valid_values(da, other):
13591359
da = da.where(~missing_vals)
13601360
return da
13611361
else:
1362-
return da
1362+
# ensure consistent return dtype
1363+
return da.astype(float)
13631364

13641365
da_a = da_a.map_blocks(_get_valid_values, args=[da_b])
13651366
da_b = da_b.map_blocks(_get_valid_values, args=[da_a])

xarray/core/parallel.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .alignment import align
2424
from .dataarray import DataArray
2525
from .dataset import Dataset
26+
from .pycompat import is_dask_collection
2627

2728
try:
2829
import dask
@@ -328,13 +329,13 @@ def _wrapper(
328329
raise TypeError("kwargs must be a mapping (for example, a dict)")
329330

330331
for value in kwargs.values():
331-
if dask.is_dask_collection(value):
332+
if is_dask_collection(value):
332333
raise TypeError(
333334
"Cannot pass dask collections in kwargs yet. Please compute or "
334335
"load values before passing to map_blocks."
335336
)
336337

337-
if not dask.is_dask_collection(obj):
338+
if not is_dask_collection(obj):
338339
return func(obj, *args, **kwargs)
339340

340341
all_args = [obj] + list(args)

xarray/core/pycompat.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,19 @@ def __init__(self, mod):
4343
self.available = duck_array_module is not None
4444

4545

46-
def is_duck_dask_array(x):
46+
def is_dask_collection(x):
4747
if DuckArrayModule("dask").available:
4848
from dask.base import is_dask_collection
4949

50-
return is_duck_array(x) and is_dask_collection(x)
50+
return is_dask_collection(x)
5151
else:
5252
return False
5353

5454

55+
def is_duck_dask_array(x):
56+
return is_duck_array(x) and is_dask_collection(x)
57+
58+
5559
dsk = DuckArrayModule("dask")
5660
dask_version = dsk.version
5761
dask_array_type = dsk.type

xarray/tests/test_computation.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@
2424

2525
from . import has_dask, raise_if_dask_computes, requires_dask
2626

27-
dask = pytest.importorskip("dask")
28-
2927

3028
def assert_identical(a, b):
3129
"""A version of this function which accepts numpy arrays"""
@@ -1420,6 +1418,7 @@ def arrays_w_tuples():
14201418
],
14211419
)
14221420
@pytest.mark.parametrize("dim", [None, "x", "time"])
1421+
@requires_dask
14231422
def test_lazy_corrcov(da_a, da_b, dim, ddof) -> None:
14241423
# GH 5284
14251424
from dask import is_dask_collection
@@ -1554,6 +1553,28 @@ def test_covcorr_consistency(da_a, da_b, dim) -> None:
15541553
assert_allclose(actual, expected)
15551554

15561555

1556+
@requires_dask
1557+
@pytest.mark.parametrize("da_a, da_b", arrays_w_tuples()[1])
1558+
@pytest.mark.parametrize("dim", [None, "time", "x"])
1559+
def test_corr_lazycorr_consistency(da_a, da_b, dim) -> None:
1560+
da_al = da_a.chunk()
1561+
da_bl = da_b.chunk()
1562+
c_abl = xr.corr(da_al, da_bl, dim=dim)
1563+
c_ab = xr.corr(da_a, da_b, dim=dim)
1564+
c_ab_mixed = xr.corr(da_a, da_bl, dim=dim)
1565+
assert_allclose(c_ab, c_abl)
1566+
assert_allclose(c_ab, c_ab_mixed)
1567+
1568+
1569+
@requires_dask
1570+
def test_corr_dtype_error():
1571+
da_a = xr.DataArray([[1, 2], [2, 1]], dims=["x", "time"])
1572+
da_b = xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"])
1573+
1574+
xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a.chunk(), da_b.chunk()))
1575+
xr.testing.assert_equal(xr.corr(da_a, da_b), xr.corr(da_a, da_b.chunk()))
1576+
1577+
15571578
@pytest.mark.parametrize(
15581579
"da_a",
15591580
arrays_w_tuples()[0],

0 commit comments

Comments
 (0)