From ef91cf0b2afcba0795c77b3f5022d9ead6b82da9 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 2 Feb 2024 16:50:48 -0700 Subject: [PATCH 01/15] groupby: Dispatch median, quantile to flox. --- doc/whats-new.rst | 2 + xarray/core/_aggregations.py | 226 +++++++++++---------------- xarray/core/groupby.py | 34 ++-- xarray/util/generate_aggregations.py | 30 ++-- 4 files changed, 137 insertions(+), 155 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index dc4fb7ae722..cbc74cdb8c4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,6 +38,8 @@ New Features for the calculation of nanquantiles (i.e., `skipna=True`) if it is installed. This is currently limited to the linear interpolation method (`method='linear'`). (:issue:`7377`, :pull:`8684`) By `Marco Wolsza `_. +- Grouped and resampled median and quantile calculations now use ``flox>=0.9.1`` if present. + By `Deepak Cherian `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index bee6afd5a19..12e95f281af 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -2392,8 +2392,6 @@ def count( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2490,8 +2488,6 @@ def all( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2588,8 +2584,6 @@ def any( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2692,8 +2686,6 @@ def max( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2808,8 +2800,6 @@ def min( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -2924,8 +2914,6 @@ def mean( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3049,8 +3037,6 @@ def prod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3186,8 +3172,6 @@ def sum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3320,8 +3304,6 @@ def std( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3454,8 +3436,6 @@ def var( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3584,8 +3564,6 @@ def median( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3628,14 +3606,30 @@ def median( Data variables: da (labels) float64 24B nan 2.0 1.5 """ - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if ( + flox_available + and OPTIONS["use_flox"] + and module_available("flox", minversion="0.9.1") + and contains_only_chunked_or_numpy(self._obj) + ): + return self._flox_reduce( + func="median", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, @@ -3687,8 +3681,6 @@ def cumsum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3788,8 +3780,6 @@ def cumprod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -3919,8 +3909,6 @@ def count( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4017,8 +4005,6 @@ def all( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4115,8 +4101,6 @@ def any( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4219,8 +4203,6 @@ def max( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4335,8 +4317,6 @@ def min( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -4451,8 +4431,6 @@ def mean( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4576,8 +4554,6 @@ def prod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4713,8 +4689,6 @@ def sum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4847,8 +4821,6 @@ def std( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -4981,8 +4953,6 @@ def var( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5111,8 +5081,6 @@ def median( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5155,14 +5123,30 @@ def median( Data variables: da (time) float64 24B 1.0 2.0 nan """ - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if ( + flox_available + and OPTIONS["use_flox"] + and module_available("flox", minversion="0.9.1") + and contains_only_chunked_or_numpy(self._obj) + ): + return self._flox_reduce( + func="median", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, @@ -5214,8 +5198,6 @@ def cumsum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5315,8 +5297,6 @@ def cumprod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -5446,8 +5426,6 @@ def count( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5537,8 +5515,6 @@ def all( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5628,8 +5604,6 @@ def any( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5725,8 +5699,6 @@ def max( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5832,8 +5804,6 @@ def min( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Examples @@ -5939,8 +5909,6 @@ def mean( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6055,8 +6023,6 @@ def prod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6181,8 +6147,6 @@ def sum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6304,8 +6268,6 @@ def std( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6427,8 +6389,6 @@ def var( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6546,8 +6506,6 @@ def median( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6583,13 +6541,28 @@ def median( Coordinates: * labels (labels) object 24B 'a' 'b' 'c' """ - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if ( + flox_available + and OPTIONS["use_flox"] + and module_available("flox", minversion="0.9.1") + and contains_only_chunked_or_numpy(self._obj) + ): + return self._flox_reduce( + func="median", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, @@ -6641,8 +6614,6 @@ def cumsum( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6738,8 +6709,6 @@ def cumprod( Use the ``flox`` package to significantly speed up groupby computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - other methods might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -6865,8 +6834,6 @@ def count( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -6956,8 +6923,6 @@ def all( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7047,8 +7012,6 @@ def any( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7144,8 +7107,6 @@ def max( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7251,8 +7212,6 @@ def min( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Examples @@ -7358,8 +7317,6 @@ def mean( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7474,8 +7431,6 @@ def prod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7600,8 +7555,6 @@ def sum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7723,8 +7676,6 @@ def std( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7846,8 +7797,6 @@ def var( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -7965,8 +7914,6 @@ def median( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -8002,13 +7949,28 @@ def median( Coordinates: * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if ( + flox_available + and OPTIONS["use_flox"] + and module_available("flox", minversion="0.9.1") + and contains_only_chunked_or_numpy(self._obj) + ): + return self._flox_reduce( + func="median", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, @@ -8060,8 +8022,6 @@ def cumsum( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. @@ -8157,8 +8117,6 @@ def cumprod( Use the ``flox`` package to significantly speed up resampling computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. - The default choice is ``method="cohorts"`` which generalizes the best, - ``method="blockwise"`` might work better for your problem. See the `flox documentation `_ for more. Non-numeric variables will be removed prior to reducing. diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3aabf618a20..e99308ebddc 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -26,15 +26,17 @@ filter_indexes_from_coords, safe_cast_to_index, ) -from xarray.core.options import _get_keep_attrs +from xarray.core.options import OPTIONS, _get_keep_attrs from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray from xarray.core.utils import ( FrozenMappingWarningOnValuesAccess, + contains_only_chunked_or_numpy, either_dict_or_kwargs, emit_user_level_warning, hashable, is_scalar, maybe_wrap_array, + module_available, peek_at, ) from xarray.core.variable import IndexVariable, Variable @@ -1267,16 +1269,26 @@ def quantile( (grouper,) = self.groupers dim = grouper.group1d.dims - return self.map( - self._obj.__class__.quantile, - shortcut=False, - q=q, - dim=dim, - method=method, - keep_attrs=keep_attrs, - skipna=skipna, - interpolation=interpolation, - ) + if ( + method == "linear" + and OPTIONS["use_flox"] + and contains_only_chunked_or_numpy(self._obj) + and module_available("flox", minversion="0.9.1") + ): + return self._flox_reduce( + func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna + ) + else: + return self.map( + self._obj.__class__.quantile, + shortcut=False, + q=q, + dim=dim, + method=method, + keep_attrs=keep_attrs, + skipna=skipna, + interpolation=interpolation, + ) def where(self, cond, other=dtypes.NA) -> T_Xarray: """Return elements from `self` or `other` depending on `cond`. diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index 3462af28663..d2ae1655ad2 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -19,6 +19,7 @@ MODULE_PREAMBLE = '''\ """Mixin classes with reduction operations.""" + # This file was generated using xarray.util.generate_aggregations. Do not edit manually. from __future__ import annotations @@ -245,13 +246,9 @@ def {method}( _FLOX_NOTES_TEMPLATE = """Use the ``flox`` package to significantly speed up {kind} computations, especially with dask arrays. Xarray will use flox by default if installed. Pass flox-specific keyword arguments in ``**kwargs``. -The default choice is ``method="cohorts"`` which generalizes the best, -{recco} might work better for your problem. See the `flox documentation `_ for more.""" -_FLOX_GROUPBY_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="groupby", recco="other methods") -_FLOX_RESAMPLE_NOTES = _FLOX_NOTES_TEMPLATE.format( - kind="resampling", recco='``method="blockwise"``' -) +_FLOX_GROUPBY_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="groupby") +_FLOX_RESAMPLE_NOTES = _FLOX_NOTES_TEMPLATE.format(kind="resampling") ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example") skipna = ExtraKwarg( @@ -300,11 +297,13 @@ def __init__( extra_kwargs=tuple(), numeric_only=False, see_also_modules=("numpy", "dask.array"), + min_flox_version=None, ): self.name = name self.extra_kwargs = extra_kwargs self.numeric_only = numeric_only self.see_also_modules = see_also_modules + self.min_flox_version = min_flox_version if bool_reduce: self.array_method = f"array_{name}" self.np_example_array = """ @@ -445,7 +444,7 @@ def generate_code(self, method, has_keep_attrs): # numpy_groupies & flox do not support median # https://github.com/ml31415/numpy-groupies/issues/43 - method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod") + method_is_not_flox_supported = method.name in ("cumsum", "cumprod") if method_is_not_flox_supported: indent = 12 else: @@ -466,10 +465,18 @@ def generate_code(self, method, has_keep_attrs): )""" else: - return f"""\ + return ( + """\ if ( flox_available - and OPTIONS["use_flox"] + and OPTIONS["use_flox"]""" + + ( + f""" + and module_available("flox", minversion="{method.min_flox_version}")""" + if method.min_flox_version is not None + else "" + ) + + f""" and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( @@ -486,6 +493,7 @@ def generate_code(self, method, has_keep_attrs): keep_attrs=keep_attrs, **kwargs, )""" + ) class GenericAggregationGenerator(AggregationGenerator): @@ -522,7 +530,9 @@ def generate_code(self, method, has_keep_attrs): Method("sum", extra_kwargs=(skipna, min_count), numeric_only=True), Method("std", extra_kwargs=(skipna, ddof), numeric_only=True), Method("var", extra_kwargs=(skipna, ddof), numeric_only=True), - Method("median", extra_kwargs=(skipna,), numeric_only=True), + Method( + "median", extra_kwargs=(skipna,), numeric_only=True, min_flox_version="0.9.1" + ), # Cumulatives: Method("cumsum", extra_kwargs=(skipna,), numeric_only=True), Method("cumprod", extra_kwargs=(skipna,), numeric_only=True), From 314820919593490efa95f9d2b39c08eb0e8573c0 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 8 Feb 2024 13:00:56 -0700 Subject: [PATCH 02/15] bump min flox version --- xarray/core/groupby.py | 2 +- xarray/util/generate_aggregations.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index e99308ebddc..3d0b25149a6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1273,7 +1273,7 @@ def quantile( method == "linear" and OPTIONS["use_flox"] and contains_only_chunked_or_numpy(self._obj) - and module_available("flox", minversion="0.9.1") + and module_available("flox", minversion="0.9.2") ): return self._flox_reduce( func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index d2ae1655ad2..a5117751625 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -531,7 +531,7 @@ def generate_code(self, method, has_keep_attrs): Method("std", extra_kwargs=(skipna, ddof), numeric_only=True), Method("var", extra_kwargs=(skipna, ddof), numeric_only=True), Method( - "median", extra_kwargs=(skipna,), numeric_only=True, min_flox_version="0.9.1" + "median", extra_kwargs=(skipna,), numeric_only=True, min_flox_version="0.9.2" ), # Cumulatives: Method("cumsum", extra_kwargs=(skipna,), numeric_only=True), From a1e176bcc9d43a102e7efd3c564947b96e5282af Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 12 Feb 2024 20:24:00 -0700 Subject: [PATCH 03/15] Add test for chunked dataarrays --- xarray/tests/test_groupby.py | 43 ++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b65c01fe76d..7b0acafedf4 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -245,6 +245,49 @@ def test_da_groupby_empty() -> None: empty_array.groupby("dim") +def test_dask_da_groupby_quantile() -> None: + # Only works when the grouped reduction can run blockwise + # Scalar quantile + expected = xr.DataArray( + data=[2, 5], coords={"x": [1, 2], "quantile": 0.5}, dims="x" + ) + array = xr.DataArray( + data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + with pytest.raises(ValueError): + array.chunk(x=1).groupby("x").quantile(0.5) + + # will work blockwise with flox + actual = array.chunk(x=3).groupby("x").quantile(0.5) + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=-1).groupby("x").quantile(0.5) + assert_identical(expected, actual) + + +def test_dask_da_groupby_median() -> None: + expected = xr.DataArray(data=[2, 5], coords={"x": [1, 2]}, dims="x") + array = xr.DataArray( + data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" + ) + with xr.set_options(use_flox=False): + actual = array.chunk(x=1).groupby("x").median() + assert_identical(expected, actual) + + with xr.set_options(use_flox=True): + actual = array.chunk(x=1).groupby("x").median() + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=3).groupby("x").median() + assert_identical(expected, actual) + + # will work blockwise with flox + actual = array.chunk(x=-1).groupby("x").median() + assert_identical(expected, actual) + + def test_da_groupby_quantile() -> None: array = xr.DataArray( data=[1, 2, 3, 4, 5, 6], coords={"x": [1, 1, 1, 2, 2, 2]}, dims="x" From cc94a94509153e7886bf7c658ea963efe836c05a Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Mar 2024 20:03:37 -0600 Subject: [PATCH 04/15] Cleanup --- xarray/core/_aggregations.py | 8 ++++---- xarray/util/generate_aggregations.py | 19 ++++++++----------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 12e95f281af..751b08507fb 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -3609,7 +3609,7 @@ def median( if ( flox_available and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.1") + and module_available("flox", minversion="0.9.2") and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( @@ -5126,7 +5126,7 @@ def median( if ( flox_available and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.1") + and module_available("flox", minversion="0.9.2") and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( @@ -6544,7 +6544,7 @@ def median( if ( flox_available and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.1") + and module_available("flox", minversion="0.9.2") and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( @@ -7952,7 +7952,7 @@ def median( if ( flox_available and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.1") + and module_available("flox", minversion="0.9.2") and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index a5117751625..3b195905787 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -464,19 +464,16 @@ def generate_code(self, method, has_keep_attrs): **kwargs, )""" - else: - return ( - """\ + min_version_check = f""" + and module_available("flox", minversion="{method.min_flox_version}")""" + + return ( + """\ if ( flox_available and OPTIONS["use_flox"]""" - + ( - f""" - and module_available("flox", minversion="{method.min_flox_version}")""" - if method.min_flox_version is not None - else "" - ) - + f""" + + (min_version_check if method.min_flox_version is not None else "") + + f""" and contains_only_chunked_or_numpy(self._obj) ): return self._flox_reduce( @@ -493,7 +490,7 @@ def generate_code(self, method, has_keep_attrs): keep_attrs=keep_attrs, **kwargs, )""" - ) + ) class GenericAggregationGenerator(AggregationGenerator): From 7a9933aaab178ef0394bd8b702a74fef45d33f61 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Mar 2024 20:06:14 -0600 Subject: [PATCH 05/15] Disable median for now. --- xarray/core/_aggregations.py | 122 +++++++-------------------- xarray/util/generate_aggregations.py | 6 +- 2 files changed, 33 insertions(+), 95 deletions(-) diff --git a/xarray/core/_aggregations.py b/xarray/core/_aggregations.py index 751b08507fb..96f860b3209 100644 --- a/xarray/core/_aggregations.py +++ b/xarray/core/_aggregations.py @@ -3606,30 +3606,14 @@ def median( Data variables: da (labels) float64 24B nan 2.0 1.5 """ - if ( - flox_available - and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.2") - and contains_only_chunked_or_numpy(self._obj) - ): - return self._flox_reduce( - func="median", - dim=dim, - skipna=skipna, - numeric_only=True, - # fill_value=fill_value, - keep_attrs=keep_attrs, - **kwargs, - ) - else: - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, @@ -5123,30 +5107,14 @@ def median( Data variables: da (time) float64 24B 1.0 2.0 nan """ - if ( - flox_available - and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.2") - and contains_only_chunked_or_numpy(self._obj) - ): - return self._flox_reduce( - func="median", - dim=dim, - skipna=skipna, - numeric_only=True, - # fill_value=fill_value, - keep_attrs=keep_attrs, - **kwargs, - ) - else: - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, @@ -6541,28 +6509,13 @@ def median( Coordinates: * labels (labels) object 24B 'a' 'b' 'c' """ - if ( - flox_available - and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.2") - and contains_only_chunked_or_numpy(self._obj) - ): - return self._flox_reduce( - func="median", - dim=dim, - skipna=skipna, - # fill_value=fill_value, - keep_attrs=keep_attrs, - **kwargs, - ) - else: - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, @@ -7949,28 +7902,13 @@ def median( Coordinates: * time (time) datetime64[ns] 24B 2001-01-31 2001-04-30 2001-07-31 """ - if ( - flox_available - and OPTIONS["use_flox"] - and module_available("flox", minversion="0.9.2") - and contains_only_chunked_or_numpy(self._obj) - ): - return self._flox_reduce( - func="median", - dim=dim, - skipna=skipna, - # fill_value=fill_value, - keep_attrs=keep_attrs, - **kwargs, - ) - else: - return self._reduce_without_squeeze_warn( - duck_array_ops.median, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + return self._reduce_without_squeeze_warn( + duck_array_ops.median, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def cumsum( self, diff --git a/xarray/util/generate_aggregations.py b/xarray/util/generate_aggregations.py index 3b195905787..b59dc36c108 100644 --- a/xarray/util/generate_aggregations.py +++ b/xarray/util/generate_aggregations.py @@ -442,9 +442,9 @@ def generate_code(self, method, has_keep_attrs): if self.datastructure.numeric_only: extra_kwargs.append(f"numeric_only={method.numeric_only},") - # numpy_groupies & flox do not support median - # https://github.com/ml31415/numpy-groupies/issues/43 - method_is_not_flox_supported = method.name in ("cumsum", "cumprod") + # median isn't enabled yet, because it would break if a single group was present in multiple + # chunks. The non-flox code path will just rechunk every group to a single chunk and execute the median + method_is_not_flox_supported = method.name in ("median", "cumsum", "cumprod") if method_is_not_flox_supported: indent = 12 else: From 6fa25a72ca32d7cf7e707711fc1b0271c19963f2 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Mar 2024 20:06:56 -0600 Subject: [PATCH 06/15] update whats-new --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index cbc74cdb8c4..d4c928e1533 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -38,7 +38,7 @@ New Features for the calculation of nanquantiles (i.e., `skipna=True`) if it is installed. This is currently limited to the linear interpolation method (`method='linear'`). (:issue:`7377`, :pull:`8684`) By `Marco Wolsza `_. -- Grouped and resampled median and quantile calculations now use ``flox>=0.9.1`` if present. +- Grouped and resampling quantile calculations now use ``flox>=0.9.2`` if present. By `Deepak Cherian `_. Breaking changes From 7f7eb3cb5af52eab05a87ba2c68a68ce47de0e71 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Mar 2024 20:08:52 -0600 Subject: [PATCH 07/15] update whats-new --- doc/whats-new.rst | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index dc913de9f49..0089af2abcb 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,22 +23,7 @@ v2024.03.0 (unreleased) New Features ~~~~~~~~~~~~ -- Added a simple `nbytes` representation in DataArrays and Dataset `repr`. - (:issue:`8690`, :pull:`8702`). - By `Etienne Schalk `_. -- Allow negative frequency strings (e.g. ``"-1YE"``). These strings are for example used - in :py:func:`date_range`, and :py:func:`cftime_range` (:pull:`8651`). - By `Mathias Hauser `_. -- Add :py:meth:`NamedArray.expand_dims`, :py:meth:`NamedArray.permute_dims` and :py:meth:`NamedArray.broadcast_to` - (:pull:`8380`) By `Anderson Banihirwe `_. -- Xarray now defers to flox's `heuristics `_ - to set default `method` for groupby problems. This only applies to ``flox>=0.9``. - By `Deepak Cherian `_. -- All `quantile` methods (e.g. :py:meth:`DataArray.quantile`) now use `numbagg` - for the calculation of nanquantiles (i.e., `skipna=True`) if it is installed. - This is currently limited to the linear interpolation method (`method='linear'`). - (:issue:`7377`, :pull:`8684`) By `Marco Wolsza `_. -- Grouped and resampling quantile calculations now use ``flox>=0.9.2`` if present. +- Grouped and resampling quantile calculations now use the vectorized algorithm in ``flox>=0.9.2`` if present. By `Deepak Cherian `_. - Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False`` (:issue:`6806`, :pull:`8784`). From c30494ff94e5db2ee254855e607faeb40b43704e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Mar 2024 22:25:00 -0600 Subject: [PATCH 08/15] bump min flox version --- xarray/core/groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0bc45f92bf1..0bb49db8b81 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1382,7 +1382,7 @@ def quantile( method == "linear" and OPTIONS["use_flox"] and contains_only_chunked_or_numpy(self._obj) - and module_available("flox", minversion="0.9.2") + and module_available("flox", minversion="0.9.4") ): return self._flox_reduce( func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna From 07d75cdc46c3fe18cf384fdc6cae1cb908453be7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Fri, 15 Mar 2024 22:29:48 -0600 Subject: [PATCH 09/15] add requires_dask --- xarray/tests/test_groupby.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 0ab37a75db9..045e1223b7d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -245,6 +245,7 @@ def test_da_groupby_empty() -> None: empty_array.groupby("dim") +@requires_dask def test_dask_da_groupby_quantile() -> None: # Only works when the grouped reduction can run blockwise # Scalar quantile @@ -266,6 +267,7 @@ def test_dask_da_groupby_quantile() -> None: assert_identical(expected, actual) +@requires_dask def test_dask_da_groupby_median() -> None: expected = xr.DataArray(data=[2, 5], coords={"x": [1, 2]}, dims="x") array = xr.DataArray( From 09935f48a40146859f6ced5fc41b969320221d35 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Mar 2024 22:22:11 -0600 Subject: [PATCH 10/15] restore dim order --- xarray/core/groupby.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 0bb49db8b81..34832814aad 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1374,6 +1374,8 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ + from xarray.core.dataarray import DataArray + if dim is None: (grouper,) = self.groupers dim = grouper.group1d.dims @@ -1384,9 +1386,13 @@ def quantile( and contains_only_chunked_or_numpy(self._obj) and module_available("flox", minversion="0.9.4") ): - return self._flox_reduce( + + result = self._flox_reduce( func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna ) + if isinstance(result, DataArray) and not is_scalar(q): + result = self._restore_dim_order(result) + return result else: return self.map( self._obj.__class__.quantile, From 734404178418bb2ee1e634efc32d05e40c3c1d45 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Mon, 18 Mar 2024 22:23:31 -0600 Subject: [PATCH 11/15] fix whats-new --- doc/whats-new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0089af2abcb..781dac558e4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,7 +23,7 @@ v2024.03.0 (unreleased) New Features ~~~~~~~~~~~~ -- Grouped and resampling quantile calculations now use the vectorized algorithm in ``flox>=0.9.2`` if present. +- Grouped and resampling quantile calculations now use the vectorized algorithm in ``flox>=0.9.4`` if present. By `Deepak Cherian `_. - Do not broadcast in arithmetic operations when global option ``arithmetic_broadcast=False`` (:issue:`6806`, :pull:`8784`). From 4806f5a5b9eb855e2981c1361073b8450fc84657 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Mar 2024 07:46:31 -0600 Subject: [PATCH 12/15] Fix doctest --- xarray/core/groupby.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 34832814aad..3a7a325b49d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1380,13 +1380,15 @@ def quantile( (grouper,) = self.groupers dim = grouper.group1d.dims + # Dataset.quantile does this, do it for flox to ensure same output. + q = np.asarray(q, dtype=np.float64) + if ( method == "linear" and OPTIONS["use_flox"] and contains_only_chunked_or_numpy(self._obj) and module_available("flox", minversion="0.9.4") ): - result = self._flox_reduce( func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna ) From 1047263ee1f20b56e687c1057ace9d6c238a586e Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Mar 2024 07:49:23 -0600 Subject: [PATCH 13/15] cleanup --- xarray/core/groupby.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3a7a325b49d..ef88e47feb5 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1211,13 +1211,9 @@ def _flox_reduce( (result.sizes[grouper.name],) + var.shape, ) - if isbin: - # Fix dimension order when binning a dimension coordinate - # Needed as long as we do a separate code path for pint; - # For some reason Datasets and DataArrays behave differently! - (group_dim,) = grouper.dims - if isinstance(self._obj, Dataset) and group_dim in self._obj.dims: - result = result.transpose(grouper.name, ...) + if not isinstance(result, Dataset): + # only restore dimension order for arrays + result = self._restore_dim_order(result) return result @@ -1374,7 +1370,6 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - from xarray.core.dataarray import DataArray if dim is None: (grouper,) = self.groupers @@ -1392,8 +1387,6 @@ def quantile( result = self._flox_reduce( func="quantile", q=q, dim=dim, keep_attrs=keep_attrs, skipna=skipna ) - if isinstance(result, DataArray) and not is_scalar(q): - result = self._restore_dim_order(result) return result else: return self.map( From b6c95976174a8fb9e8d0ab6c3c9d2bf9e575cfce Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Mar 2024 07:50:02 -0600 Subject: [PATCH 14/15] Update xarray/core/groupby.py --- xarray/core/groupby.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ef88e47feb5..2e15adae208 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1370,7 +1370,6 @@ def quantile( "Sample quantiles in statistical packages," The American Statistician, 50(4), pp. 361-365, 1996 """ - if dim is None: (grouper,) = self.groupers dim = grouper.group1d.dims From fe676d2e4b161d5f8d24e3e04842cf89d5ab3d94 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Mar 2024 08:30:44 -0600 Subject: [PATCH 15/15] Fix mypy --- xarray/core/groupby.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 2e15adae208..5966c32df92 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -1077,6 +1077,9 @@ def _binary_op(self, other, f, reflexive=False): result[var] = result[var].transpose(d, ...) return result + def _restore_dim_order(self, stacked): + raise NotImplementedError + def _maybe_restore_empty_groups(self, combined): """Our index contained empty groups (e.g., from a resampling or binning). If we reduced on that dimension, we want to restore the full index.