Skip to content

Commit 5a60354

Browse files
jbusecke, dcherian, max-sixty, mathause
authored
Dask friendly check in .weighted() (#4559)
* Use map_blocks for weighted init checks
* added dask test
* Update xarray/core/weighted.py
  Co-authored-by: Deepak Cherian <[email protected]>
* Update xarray/core/weighted.py
  Co-authored-by: Maximilian Roos <[email protected]>
* implement requires_dask
* Implement raise_if_dask_computes
* Added logic to check for dask arrays
* applied isort
* Update xarray/core/weighted.py
  Co-authored-by: Maximilian Roos <[email protected]>
* Refactor dask mapping
* Try duck_array_ops.isnull
* black formatting
* Add whatsnew
* Remove numpy
* apply isort
* Update xarray/core/weighted.py
* Update xarray/core/weighted.py
  Co-authored-by: Mathias Hauser <[email protected]>
* black formatting
* Update xarray/core/weighted.py
* Update xarray/core/weighted.py

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Maximilian Roos <[email protected]>
Co-authored-by: Mathias Hauser <[email protected]>
Co-authored-by: Mathias Hauser <[email protected]>
1 parent 1735892 commit 5a60354

File tree

3 files changed

+40
-4
lines changed

3 files changed

+40
-4
lines changed

doc/whats-new.rst

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ New Features
4343
By `Stephan Hoyer <https://github.com/shoyer>`_.
4444
- Added typehints in :py:func:`align` to reflect that the same type received in ``objects`` arg will be returned (:pull:`4522`).
4545
By `Michal Baumgartner <https://github.com/m1so>`_.
46+
- :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` now perform their value checks lazily when the weights are provided as dask arrays (:issue:`4541`, :pull:`4559`).
47+
By `Julius Busecke <https://github.com/jbusecke>`_.
4648

4749
Bug fixes
4850
~~~~~~~~~

xarray/core/weighted.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload
22

3+
from . import duck_array_ops
34
from .computation import dot
45
from .options import _get_keep_attrs
6+
from .pycompat import is_duck_dask_array
57

68
if TYPE_CHECKING:
79
from .dataarray import DataArray, Dataset
@@ -100,12 +102,24 @@ def __init__(self, obj, weights):
100102
if not isinstance(weights, DataArray):
101103
raise ValueError("`weights` must be a DataArray")
102104

103-
if weights.isnull().any():
104-
raise ValueError(
105-
"`weights` cannot contain missing values. "
106-
"Missing values can be replaced by `weights.fillna(0)`."
105+
def _weight_check(w):
    """Validate that ``w`` holds no missing values and return it unchanged.

    The input is returned as-is on success so the function can be mapped
    over the blocks of a dask array (see the ``map_blocks`` call in the
    surrounding code) as well as called directly on in-memory data.

    Raises
    ------
    ValueError
        If ``duck_array_ops.isnull`` reports any missing value in ``w``.
    """
    # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670
    has_missing = duck_array_ops.isnull(w).any()
    if not has_missing:
        return w
    raise ValueError(
        "`weights` cannot contain missing values. "
        "Missing values can be replaced by `weights.fillna(0)`."
    )
113+
114+
if is_duck_dask_array(weights.data):
115+
# assign to copy - else the check is not triggered
116+
weights = weights.copy(
117+
data=weights.data.map_blocks(_weight_check, dtype=weights.dtype)
107118
)
108119

120+
else:
121+
_weight_check(weights.data)
122+
109123
self.obj = obj
110124
self.weights = weights
111125

xarray/tests/test_weighted.py

+20
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from xarray import DataArray
66
from xarray.tests import assert_allclose, assert_equal, raises_regex
77

8+
from . import raise_if_dask_computes, requires_dask
9+
810

911
@pytest.mark.parametrize("as_dataset", (True, False))
1012
def test_weighted_non_DataArray_weights(as_dataset):
@@ -29,6 +31,24 @@ def test_weighted_weights_nan_raises(as_dataset, weights):
2931
data.weighted(DataArray(weights))
3032

3133

34+
@requires_dask
@pytest.mark.parametrize("as_dataset", (True, False))
@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
def test_weighted_weights_nan_raises_dask(as_dataset, weights):
    """Chunked weights containing NaN raise only when the result is loaded."""
    obj = DataArray([1, 2]).chunk({"dim_0": -1})
    obj = obj.to_dataset(name="data") if as_dataset else obj

    lazy_weights = DataArray(weights).chunk({"dim_0": -1})

    # Constructing the weighted object runs under raise_if_dask_computes,
    # so the missing-value check must stay deferred at this point.
    with raise_if_dask_computes():
        weighted = obj.weighted(lazy_weights)

    # The error is expected to surface once the reduction is loaded.
    with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
        total = weighted.sum()
        total.load()
50+
51+
3252
@pytest.mark.parametrize(
3353
("weights", "expected"),
3454
(([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),

0 commit comments

Comments
 (0)