diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e73a1a7fa62..c7f2d1aca26 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,9 +23,12 @@ New Features ~~~~~~~~~~~~ - :py:meth:`DataArray.where` & :py:meth:`Dataset.where` accept a callable for - the ``other`` parameter, passing the object as the first argument. Previously, + the ``other`` parameter, passing the object as the only argument. Previously, this was only valid for the ``cond`` parameter. (:issue:`8255`) By `Maximilian Roos `_. +- :py:meth:`DataArray.sortby` & :py:meth:`Dataset.sortby` accept a callable for + the ``variables`` parameter, passing the object as the only argument. + By `Maximilian Roos `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/common.py b/xarray/core/common.py index 2a4c4c200d4..f571576850c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1073,7 +1073,8 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self: ---------- cond : DataArray, Dataset, or callable Locations at which to preserve this object's values. dtype must be `bool`. - If a callable, it must expect this object as its only parameter. + If a callable, the callable is passed this object, and the result is used as + the value for cond. other : scalar, DataArray, Dataset, or callable, optional Value to use for locations in this object where ``cond`` is False. By default, these locations are filled with NA. If a callable, it must diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ef4389f3c6c..8b3e999b78c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -4921,7 +4921,10 @@ def dot( def sortby( self, - variables: Hashable | DataArray | Sequence[Hashable | DataArray], + variables: Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | Sequence[Hashable | DataArray]], ascending: bool = True, ) -> Self: """Sort object by labels or values (along an axis). @@ -4942,9 +4945,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or sequence of Hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords whose values are used to sort this array. + variables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -4964,22 +4968,33 @@ def sortby( Examples -------- >>> da = xr.DataArray( - ... np.random.rand(5), + ... np.arange(5, 0, -1), ... coords=[pd.date_range("1/1/2000", periods=5)], ... dims="time", ... ) >>> da - array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ]) + array([5, 4, 3, 2, 1]) Coordinates: * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 >>> da.sortby(da) - array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937]) + array([1, 2, 3, 4, 5]) Coordinates: - * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02 + * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01 + + >>> da.sortby(lambda x: x) + + array([1, 2, 3, 4, 5]) + Coordinates: + * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-01 """ + # We need to convert the callable here rather than pass it through to the + # dataset method, since otherwise the dataset method would try to call the + # callable with the dataset as the object + if callable(variables): + variables = variables(self) ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 459e2f3fce7..533eed2c848 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7824,7 +7824,10 @@ def roll( def sortby( self, - variables: Hashable | DataArray | list[Hashable | DataArray], + variables: Hashable + | DataArray + | Sequence[Hashable | DataArray] + | Callable[[Self], Hashable | DataArray | list[Hashable | DataArray]], ascending: bool = True, ) -> Self: """ @@ -7846,9 +7849,10 @@ def sortby( Parameters ---------- - variables : Hashable, DataArray, or list of hashable or DataArray - 1D DataArray objects or name(s) of 1D variable(s) in - coords/data_vars whose values are used to sort the dataset. + kariables : Hashable, DataArray, sequence of Hashable or DataArray, or Callable + 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are + used to sort this array. If a callable, the callable is passed this object, + and the result is used as the value for cond. ascending : bool, default: True Whether to sort by ascending or descending order. @@ -7874,8 +7878,7 @@ def sortby( ... }, ... coords={"x": ["b", "a"], "y": [1, 0]}, ... ) - >>> ds = ds.sortby("x") - >>> ds + >>> ds.sortby("x") Dimensions: (x: 2, y: 2) Coordinates: @@ -7884,9 +7887,20 @@ def sortby( Data variables: A (x, y) int64 3 4 1 2 B (x, y) int64 7 8 5 6 + >>> ds.sortby(lambda x: -x["y"]) + + Dimensions: (x: 2, y: 2) + Coordinates: + * x (x)