Dask friendly check in .weighted() #4559


Merged: 23 commits, Nov 9, 2020
Changes from 8 commits
22 changes: 18 additions & 4 deletions xarray/core/weighted.py
@@ -1,7 +1,10 @@
 from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload

+import numpy as np
+
 from .computation import dot
 from .options import _get_keep_attrs
+from .pycompat import is_duck_dask_array

 if TYPE_CHECKING:
     from .dataarray import DataArray, Dataset
@@ -100,11 +103,22 @@ def __init__(self, obj, weights):
         if not isinstance(weights, DataArray):
             raise ValueError("`weights` must be a DataArray")

-        if weights.isnull().any():
-            raise ValueError(
-                "`weights` cannot contain missing values. "
-                "Missing values can be replaced by `weights.fillna(0)`."
-            )
+        def _weight_check(w):
+            if np.isnan(w).any():
Collaborator

.isnull() does a bit more than that: np.isnan won't detect NaT. @mathause, how likely is it to get datetime-like arrays here? They don't make much sense as weights, but as far as I can tell we don't check (I might be missing something, though)
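
For reference, a minimal illustration of the distinction (np.isnat is numpy's NaT-aware check; the exact behaviour of np.isnan on datetime-like input depends on the numpy version):

```python
import numpy as np

t = np.array(["2020-01-01", "NaT"], dtype="datetime64[ns]")
print(np.isnat(t))  # [False  True] -- the NaT is detected

# np.isnan(t) typically raises TypeError for datetime64 input rather than
# flagging the NaT, which is why .isnull() is the more general check.
```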

Collaborator

There is no check for that. A TimeDelta may make some sense as weights. DateTime not so much. I think we can get away with using np.isnan. A Date* array as weights containing NaT should be super uncommon.

Contributor Author

We could still operate on the DataArray instead of the dask/numpy array, but as @dcherian suggested, that would be less efficient. I would be curious what penalty we actually pay by using weights.map_blocks compared to dask.array.map_blocks.

Contributor

  1. Could we use duck_array_ops.isnull to account for timedelta64? It is weird to have it as a weight, though. Does that work?
  2. Re map_blocks: the xarray version adds tasks that create xarray objects wrapping every block in a dask array. That adds overhead which is totally unnecessary here.
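
For context, a rough sketch of the two map_blocks variants being compared (illustrative only; the helper name and sizes are made up for this example, and the actual overhead depends on the graph and chunking):

```python
import dask.array as dsa
import numpy as np
import xarray as xr

weights = xr.DataArray(np.ones(1_000_000), dims="x").chunk({"x": 100_000})

def _check(block):
    # hypothetical stand-in for _weight_check, operating on a bare array block
    if np.isnan(block).any():
        raise ValueError("`weights` cannot contain missing values.")
    return block

# dask.array.map_blocks: each task receives a plain numpy block.
checked = dsa.map_blocks(_check, weights.data, dtype=weights.dtype)

# DataArray.map_blocks: each task first rebuilds an xarray object (dims,
# coords, attrs) around the block before calling the function -- extra graph
# tasks and per-chunk overhead that this NaN check does not need.
checked_xr = weights.map_blocks(lambda da: da.copy(data=_check(da.data)))
```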

+                raise ValueError(
+                    "`weights` cannot contain missing values. "
+                    "Missing values can be replaced by `weights.fillna(0)`."
+                )
+            return w.data
+
+        if is_duck_dask_array(weights.data):
+            import dask.array as dsa
+
+            weights.data = dsa.map_blocks(
Contributor

@dcherian dcherian Nov 3, 2020

Suggested change:
-            weights.data = dsa.map_blocks(
+            weights = weights.copy(data=dsa.map_blocks(

so we don't modify the original object. Could even do weights.data.map_blocks(...) to save some typing...
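
A sketch of that non-mutating variant (assuming the _weight_check helper from this diff):

```python
# Build a checked copy instead of assigning into weights.data in place.
weights = weights.copy(
    data=weights.data.map_blocks(_weight_check, dtype=weights.dtype)
)
```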

+                _weight_check, weights.data, dtype=weights.dtype
+            )
+        else:
+            weights.data = _weight_check(weights.data)

         self.obj = obj
         self.weights = weights
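With dask-backed weights, the map_blocks check above is lazy: constructing the weighted object does not trigger a compute, and the ValueError only surfaces once the result is evaluated, which is exactly what the new test below verifies. A minimal usage sketch (assuming a NaN in the chunked weights):

```python
import numpy as np
import xarray as xr

data = xr.DataArray([1, 2]).chunk({"dim_0": -1})
weights = xr.DataArray([np.nan, 2]).chunk({"dim_0": -1})

weighted = data.weighted(weights)  # lazy: no compute, no error yet
weighted.sum().load()              # raises ValueError: `weights` cannot contain missing values.
```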
20 changes: 20 additions & 0 deletions xarray/tests/test_weighted.py
@@ -5,6 +5,8 @@
 from xarray import DataArray
 from xarray.tests import assert_allclose, assert_equal, raises_regex

+from . import raise_if_dask_computes, requires_dask
+

 @pytest.mark.parametrize("as_dataset", (True, False))
 def test_weighted_non_DataArray_weights(as_dataset):
@@ -29,6 +31,24 @@ def test_weighted_weights_nan_raises(as_dataset, weights):
         data.weighted(DataArray(weights))


+@requires_dask
+@pytest.mark.parametrize("as_dataset", (True, False))
+@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
+def test_weighted_weights_nan_raises_dask(as_dataset, weights):
+
+    data = DataArray([1, 2]).chunk({"dim_0": -1})
+    if as_dataset:
+        data = data.to_dataset(name="data")
+
+    weights = DataArray(weights).chunk({"dim_0": -1})
+
+    with raise_if_dask_computes():
+        weighted = data.weighted(weights)
+
+    with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
+        weighted.sum().load()
+
+
 @pytest.mark.parametrize(
     ("weights", "expected"),
     (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),