From fb7d8b84f9a04a5c0f6eabac6fe7e06f2fdda205 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Tue, 3 Mar 2020 21:44:49 -0500 Subject: [PATCH 1/4] allow where to receive a callable --- xarray/core/common.py | 3 +++ xarray/tests/test_dataarray.py | 6 ++++++ xarray/tests/test_dataset.py | 9 +++++++++ 3 files changed, 18 insertions(+) diff --git a/xarray/core/common.py b/xarray/core/common.py index e3739d6d039..6e352f9bc4c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1152,6 +1152,9 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): from .dataarray import DataArray from .dataset import Dataset + if callable(cond): + return self.where(cond(self), other=other, drop=drop) + if drop: if other is not dtypes.NA: raise ValueError("cannot set `other` if drop=True") diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 0a622d279ba..b8a9c5edaf9 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2215,6 +2215,12 @@ def test_where(self): actual = arr.where(arr.x < 2, drop=True) assert_identical(actual, expected) + def test_where_lambda(self): + arr = DataArray(np.arange(4), dims="y") + expected = arr.sel(y=slice(2)) + actual = arr.where(lambda x: x.y < 2, drop=True) + assert_identical(actual, expected) + def test_where_string(self): array = DataArray(["a", "b"]) expected = DataArray(np.array(["a", np.nan], dtype=object)) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7bcf9379ae8..44ffafb23b1 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4349,6 +4349,12 @@ def test_where(self): assert actual.a.name == "a" assert actual.a.attrs == ds.a.attrs + # lambda + ds = Dataset({"a": ("x", range(5))}) + expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])}) + actual = ds.where(lambda x: x > 1) + assert_identical(expected, actual) + def test_where_other(self): ds = Dataset({"a": ("x", range(5))}, {"x": range(5)}) expected = Dataset({"a": ("x", [-1, -1, 2, 3, 4])}, {"x": range(5)}) @@ -4356,6 +4362,9 @@ def test_where_other(self): assert_equal(expected, actual) assert actual.a.dtype == int + actual = ds.where(lambda x: x > 1, -1) + assert_equal(expected, actual) + with raises_regex(ValueError, "cannot set"): ds.where(ds > 1, other=0, drop=True) From 6f83daa7385c803748473104d5ba8fa7a1bfed83 Mon Sep 17 00:00:00 2001 From: Maximilian Roos <5635139+max-sixty@users.noreply.github.com> Date: Wed, 4 Mar 2020 16:32:24 -0500 Subject: [PATCH 2/4] Update xarray/core/common.py Co-Authored-By: keewis --- xarray/core/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 6e352f9bc4c..0a77bea6aec 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1153,7 +1153,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): from .dataset import Dataset if callable(cond): - return self.where(cond(self), other=other, drop=drop) + cond = cond(self) if drop: if other is not dtypes.NA: From 849142302de5e471ac3667bff357bac0dd8a83c9 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Thu, 5 Mar 2020 21:54:38 -0500 Subject: [PATCH 3/4] docstring --- xarray/core/common.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/xarray/core/common.py b/xarray/core/common.py index 0a77bea6aec..c80cb24c5b5 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1119,6 +1119,15 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): >>> import numpy as np >>> a = xr.DataArray(np.arange(25).reshape(5, 5), dims=('x', 'y')) + >>> a + + array([[ 0, 1, 2, 3, 4], + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) + Dimensions without coordinates: x, y + >>> a.where(a.x + a.y < 4) array([[ 0., 1., 2., 3., nan], @@ -1127,6 +1136,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): [ 15., nan, nan, nan, nan], [ nan, nan, nan, nan, nan]]) Dimensions without coordinates: x, y + >>> a.where(a.x + a.y < 5, -1) array([[ 0, 1, 2, 3, 4], @@ -1135,6 +1145,7 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): [15, 16, -1, -1, -1], [20, -1, -1, -1, -1]]) Dimensions without coordinates: x, y + >>> a.where(a.x + a.y < 4, drop=True) array([[ 0., 1., 2., 3.], @@ -1143,6 +1154,14 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): [ 15., nan, nan, nan]]) Dimensions without coordinates: x, y + >>> a.where(lambda x: x.x + x.y < 4, drop=True) + + array([[ 0., 1., 2., 3.], + [ 5., 6., 7., nan], + [ 10., 11., nan, nan], + [ 15., nan, nan, nan]]) + Dimensions without coordinates: x, y + See also -------- numpy.where : corresponding numpy function From f3c9cfdbd304dc6b7aa3cf27ff533e2a9f6301b3 Mon Sep 17 00:00:00 2001 From: Maximilian Roos Date: Thu, 5 Mar 2020 21:58:28 -0500 Subject: [PATCH 4/4] whatsnew --- doc/whats-new.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4a6083522ba..bb7f3f1d473 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -43,7 +43,9 @@ New Features in 0.14.1) is now on by default. To disable, use ``xarray.set_options(display_style="text")``. By `Julia Signell `_. - +- :py:meth:`Dataset.where` and :py:meth:`DataArray.where` accept a lambda as a + first argument, which is then called on the input; replicating pandas' behavior. + By `Maximilian Roos `_ Bug fixes ~~~~~~~~~