Skip to content

Commit 9232b03

Browse files
committed
allow where to receive a callable
1 parent a333a5c commit 9232b03

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

xarray/core/common.py

+3
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,9 @@ def where(self, cond, other=dtypes.NA, drop: bool = False):
11521152
from .dataarray import DataArray
11531153
from .dataset import Dataset
11541154

1155+
if callable(cond):
1156+
return self.where(cond(self), other=other, drop=drop)
1157+
11551158
if drop:
11561159
if other is not dtypes.NA:
11571160
raise ValueError("cannot set `other` if drop=True")

xarray/tests/test_dataarray.py

+6
Original file line numberDiff line numberDiff line change
@@ -2215,6 +2215,12 @@ def test_where(self):
22152215
actual = arr.where(arr.x < 2, drop=True)
22162216
assert_identical(actual, expected)
22172217

2218+
def test_where_lambda(self):
2219+
arr = DataArray(np.arange(4), dims="y")
2220+
expected = arr.sel(y=slice(2))
2221+
actual = arr.where(lambda x: x.y < 2, drop=True)
2222+
assert_identical(actual, expected)
2223+
22182224
def test_where_string(self):
22192225
array = DataArray(["a", "b"])
22202226
expected = DataArray(np.array(["a", np.nan], dtype=object))

xarray/tests/test_dataset.py

+8
Original file line numberDiff line numberDiff line change
@@ -4349,13 +4349,21 @@ def test_where(self):
43494349
assert actual.a.name == "a"
43504350
assert actual.a.attrs == ds.a.attrs
43514351

4352+
ds = Dataset({"a": ("x", range(5))})
4353+
expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])})
4354+
actual = ds.where(lambda x: x > 1)
4355+
assert_identical(expected, actual)
4356+
43524357
def test_where_other(self):
43534358
ds = Dataset({"a": ("x", range(5))}, {"x": range(5)})
43544359
expected = Dataset({"a": ("x", [-1, -1, 2, 3, 4])}, {"x": range(5)})
43554360
actual = ds.where(ds > 1, -1)
43564361
assert_equal(expected, actual)
43574362
assert actual.a.dtype == int
43584363

4364+
actual = ds.where(lambda x: x > 1, -1)
4365+
assert_equal(expected, actual)
4366+
43594367
with raises_regex(ValueError, "cannot set"):
43604368
ds.where(ds > 1, other=0, drop=True)
43614369

0 commit comments

Comments
 (0)