From b40e62ab120ccec90e56e1517990fee4ca2ae0a3 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 14:45:33 +0200 Subject: [PATCH 01/36] initial groupby support --- xarray/core/alignment.py | 2 +- xarray/core/common.py | 31 +++++--- xarray/core/dataarray.py | 166 ++++++++++++++++++++++++++++++++++++++- xarray/core/dataset.py | 150 +++++++++++++++++++++++++++++++++-- xarray/core/groupby.py | 139 +++++++++++++++++++------------- xarray/core/resample.py | 6 +- xarray/core/utils.py | 2 +- xarray/core/variable.py | 8 +- 8 files changed, 419 insertions(+), 85 deletions(-) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index df8b3c24a91..aed41e05777 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -215,7 +215,7 @@ def _normalize_indexes( normalized_index_vars = {} for idx, index_vars in Indexes(xr_indexes, xr_variables).group_by_index(): coord_names_and_dims = [] - all_dims = set() + all_dims: set[Hashable] = set() for name, var in index_vars.items(): dims = var.dims diff --git a/xarray/core/common.py b/xarray/core/common.py index 3c328f42e98..8fb6dc20875 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -39,9 +39,10 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset + from .groupby import GroupBy from .indexes import Index from .types import ScalarOrArray, T_DataWithCoords, T_Xarray - from .variable import Variable + from .variable import IndexVariable, Variable from .weighted import Weighted @@ -355,6 +356,7 @@ class DataWithCoords(AttrAccessMixin): _close: Callable[[], None] | None _indexes: dict[Hashable, Index] + _groupby_cls: type[GroupBy] __slots__ = ("_close",) @@ -749,26 +751,29 @@ def pipe( return func(self, *args, **kwargs) def groupby( - self, group: Any, squeeze: bool = True, restore_coord_dims: bool | None = None - ): + self, + group: Hashable | DataArray | IndexVariable, + squeeze: bool = True, + restore_coord_dims: bool = False, + ) -> GroupBy: """Returns a GroupBy object for performing grouped operations. Parameters ---------- - group : str, DataArray or IndexVariable + group : Hashable, DataArray or IndexVariable Array whose unique values should be used to group this array. If a string, must be the name of a variable contained in this dataset. squeeze : bool, default: True If "group" is a dimension of any arrays in this dataset, `squeeze` controls whether the subarrays have a dimension of length 1 along that dimension or if the dimension is squeezed out. - restore_coord_dims : bool, optional + restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. Returns ------- - grouped + grouped : GroupBy A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. @@ -816,15 +821,15 @@ def groupby( def groupby_bins( self, - group, + group: Hashable | DataArray | IndexVariable, bins, right: bool = True, labels=None, precision: int = 3, include_lowest: bool = False, squeeze: bool = True, - restore_coord_dims: bool = None, - ): + restore_coord_dims: bool = False, + ) -> GroupBy: """Returns a GroupBy object for performing grouped operations. Rather than using all unique values of `group`, the values are discretized @@ -832,7 +837,7 @@ def groupby_bins( Parameters ---------- - group : str, DataArray or IndexVariable + group : Hashable, DataArray or IndexVariable Array whose binned values should be used to group this array. 
If a string, must be the name of a variable contained in this dataset. bins : int or array-like @@ -849,15 +854,15 @@ def groupby_bins( Used as labels for the resulting bins. Must be of the same length as the resulting bins. If False, string bin labels are assigned by `pandas.cut`. - precision : int + precision : int, default: 3 The precision at which to store and display the bins labels. - include_lowest : bool + include_lowest : bool, default: False Whether the first interval should be left-inclusive or not. squeeze : bool, default: True If "group" is a dimension of any arrays in this dataset, `squeeze` controls whether the subarrays have a dimension of length 1 along that dimension or if the dimension is squeezed out. - restore_coord_dims : bool, optional + restore_coord_dims : bool, default: False If True, also restore the dimension order of multi-dimensional coordinates. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index fc4ff2d5783..8d9a7868d37 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -28,7 +28,6 @@ alignment, computation, dtypes, - groupby, indexing, ops, resample, @@ -83,6 +82,7 @@ iris_Cube = None from ..backends.api import T_NetcdfEngine, T_NetcdfTypes + from .groupby import DataArrayGroupBy from .types import ( DatetimeUnitOptions, ErrorOptions, @@ -367,7 +367,6 @@ class DataArray( "__weakref__", ) - _groupby_cls = groupby.DataArrayGroupBy _rolling_cls = rolling.DataArrayRolling _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample @@ -3517,7 +3516,9 @@ def _binary_op( f: Callable, reflexive: bool = False, ) -> T_DataArray: - if isinstance(other, (Dataset, groupby.GroupBy)): + from .groupby import GroupBy + + if isinstance(other, (Dataset, GroupBy)): return NotImplemented if isinstance(other, DataArray): align_type = OPTIONS["arithmetic_join"] @@ -3536,7 +3537,9 @@ def _binary_op( return self._replace(variable, coords, name, indexes=indexes) def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArray: - if isinstance(other, groupby.GroupBy): + from .groupby import GroupBy + + if isinstance(other, GroupBy): raise TypeError( "in-place operations between a DataArray and " "a grouped object are not permitted" @@ -5305,6 +5308,161 @@ def interp_calendar( """ return interp_calendar(self, target, dim=dim) + def groupby( + self, + group: Hashable | DataArray | IndexVariable, + squeeze: bool = True, + restore_coord_dims: bool = False, + ) -> DataArrayGroupBy: + """Returns a DataArrayGroupBy object for performing grouped operations. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DataArrayGroupBy + A `DataArrayGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + + Examples + -------- + Calculate daily anomalies for daily data: + + >>> da = xr.DataArray( + ... np.linspace(0, 1826, num=1827), + ... coords=[pd.date_range("1/1/2000", "31/12/2004", freq="D")], + ... 
dims="time", + ... ) + >>> da + + array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03, + 1.826e+03]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31 + >>> da.groupby("time.dayofyear") - da.groupby("time.dayofyear").mean("time") + + array([-730.8, -730.8, -730.8, ..., 730.2, 730.2, 730.5]) + Coordinates: + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31 + dayofyear (time) int64 1 2 3 4 5 6 7 8 ... 359 360 361 362 363 364 365 366 + + See Also + -------- + DataArray.groupby_bins + Dataset.groupby + core.groupby.DataArrayGroupBy + pandas.DataFrame.groupby + """ + from .groupby import DataArrayGroupBy + + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be True or False, but {squeeze} was supplied" + ) + + return DataArrayGroupBy( + self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) + + def groupby_bins( + self, + group: Hashable | DataArray | IndexVariable, + bins, + right: bool = True, + labels=None, + precision: int = 3, + include_lowest: bool = False, + squeeze: bool = True, + restore_coord_dims: bool = False, + ) -> DataArrayGroupBy: + """Returns a DataArrayGroupBy object for performing grouped operations. + + Rather than using all unique values of `group`, the values are discretized + first by applying `pandas.cut` [1]_ to `group`. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose binned values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + bins : int or array-like + If bins is an int, it defines the number of equal-width bins in the + range of x. However, in this case, the range of x is extended by .1% + on each side to include the min or max values of x. If bins is a + sequence it defines the bin edges allowing for non-uniform bin + width. No extension of the range of x is done in this case. + right : bool, default: True + Indicates whether the bins include the rightmost edge or not. If + right == True (the default), then the bins [1,2,3,4] indicate + (1,2], (2,3], (3,4]. + labels : array-like or bool, default: None + Used as labels for the resulting bins. Must be of the same length as + the resulting bins. If False, string bin labels are assigned by + `pandas.cut`. + precision : int, default: 3 + The precision at which to store and display the bins labels. + include_lowest : bool, default: False + Whether the first interval should be left-inclusive or not. + squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DataArrayGroupBy + A `DataArrayGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + The name of the group has the added suffix `_bins` in order to + distinguish it from the original variable. 
+ + See Also + -------- + DataArray.groupby + Dataset.groupby_bins + core.groupby.DataArrayGroupBy + pandas.DataFrame.groupby + + References + ---------- + .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html + """ + from .groupby import DataArrayGroupBy + + return DataArrayGroupBy( + self, + group, + squeeze=squeeze, + bins=bins, + restore_coord_dims=restore_coord_dims, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor["DataArray"]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4e2caf8e1cb..15fef68302e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -41,7 +41,6 @@ duck_array_ops, formatting, formatting_html, - groupby, ops, resample, rolling, @@ -106,6 +105,7 @@ from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .coordinates import Coordinates from .dataarray import DataArray + from .groupby import DatasetGroupBy from .merge import CoercibleMapping from .types import ( CFCalendar, @@ -563,7 +563,6 @@ class Dataset( "__weakref__", ) - _groupby_cls = groupby.DatasetGroupBy _rolling_cls = rolling.DatasetRolling _coarsen_cls = rolling.DatasetCoarsen _resample_cls = resample.DatasetResample @@ -2078,7 +2077,7 @@ def info(self, buf: IO | None = None) -> None: lines.append(f"\t{name} = {size} ;") lines.append("\nvariables:") for name, da in self.variables.items(): - dims = ", ".join(da.dims) + dims = ", ".join(map(str, da.dims)) lines.append(f"\t{da.dtype} {name}({dims}) ;") for k, v in da.attrs.items(): lines.append(f"\t\t{name}:{k} = {v} ;") @@ -5542,7 +5541,7 @@ def reduce( if len(reduce_dims) == 1: # unpack dimensions for the benefit of functions # like np.argmin which can't handle tuple arguments - (reduce_dims,) = reduce_dims + (reduce_dims,) = reduce_dims # type: ignore[assignment] elif len(reduce_dims) == var.ndim: # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because @@ -6241,8 +6240,9 @@ def _unary_op(self: T_Dataset, f, *args, **kwargs) -> T_Dataset: def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: from .dataarray import DataArray + from .groupby import GroupBy - if isinstance(other, groupby.GroupBy): + if isinstance(other, GroupBy): return NotImplemented align_type = OPTIONS["arithmetic_join"] if join is None else join if isinstance(other, (DataArray, Dataset)): @@ -6253,8 +6253,9 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset: def _inplace_binary_op(self: T_Dataset, other, f) -> T_Dataset: from .dataarray import DataArray + from .groupby import GroupBy - if isinstance(other, groupby.GroupBy): + if isinstance(other, GroupBy): raise TypeError( "in-place operations between a Dataset and " "a grouped object are not permitted" @@ -6930,7 +6931,9 @@ def differentiate( dim = coord_var.dims[0] if _contains_datetime_like_objects(coord_var): if coord_var.dtype.kind in "mM" and datetime_unit is None: - datetime_unit, _ = np.datetime_data(coord_var.dtype) + datetime_unit = cast( + "DatetimeUnitOptions", np.datetime_data(coord_var.dtype)[0] + ) elif datetime_unit is None: datetime_unit = "s" # Default to seconds for cftime objects coord_var = coord_var._to_numeric(datetime_unit=datetime_unit) @@ -8517,3 +8520,136 @@ def interp_calendar( The source interpolated 
on the decimal years of target, """ return interp_calendar(self, target, dim=dim) + + def groupby( + self, + group: Hashable | DataArray | IndexVariable, + squeeze: bool = True, + restore_coord_dims: bool = False, + ) -> DatasetGroupBy: + """Returns a DatasetGroupBy object for performing grouped operations. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose unique values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DatasetGroupBy + A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + + See Also + -------- + Dataset.groupby_bins + DataArray.groupby + core.groupby.DatasetGroupBy + pandas.DataFrame.groupby + """ + from .groupby import DatasetGroupBy + + # While we don't generally check the type of every arg, passing + # multiple dimensions as multiple arguments is common enough, and the + # consequences hidden enough (strings evaluate as true) to warrant + # checking here. + # A future version could make squeeze kwarg only, but would face + # backward-compat issues. + if not isinstance(squeeze, bool): + raise TypeError( + f"`squeeze` must be True or False, but {squeeze} was supplied" + ) + + return DatasetGroupBy( + self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims + ) + + def groupby_bins( + self, + group: Hashable | DataArray | IndexVariable, + bins, + right: bool = True, + labels=None, + precision: int = 3, + include_lowest: bool = False, + squeeze: bool = True, + restore_coord_dims: bool = False, + ) -> DatasetGroupBy: + """Returns a DatasetGroupBy object for performing grouped operations. + + Rather than using all unique values of `group`, the values are discretized + first by applying `pandas.cut` [1]_ to `group`. + + Parameters + ---------- + group : Hashable, DataArray or IndexVariable + Array whose binned values should be used to group this array. If a + string, must be the name of a variable contained in this dataset. + bins : int or array-like + If bins is an int, it defines the number of equal-width bins in the + range of x. However, in this case, the range of x is extended by .1% + on each side to include the min or max values of x. If bins is a + sequence it defines the bin edges allowing for non-uniform bin + width. No extension of the range of x is done in this case. + right : bool, default: True + Indicates whether the bins include the rightmost edge or not. If + right == True (the default), then the bins [1,2,3,4] indicate + (1,2], (2,3], (3,4]. + labels : array-like or bool, default: None + Used as labels for the resulting bins. Must be of the same length as + the resulting bins. If False, string bin labels are assigned by + `pandas.cut`. + precision : int, default: 3 + The precision at which to store and display the bins labels. + include_lowest : bool, default: False + Whether the first interval should be left-inclusive or not. 
+ squeeze : bool, default: True + If "group" is a dimension of any arrays in this dataset, `squeeze` + controls whether the subarrays have a dimension of length 1 along + that dimension or if the dimension is squeezed out. + restore_coord_dims : bool, default: False + If True, also restore the dimension order of multi-dimensional + coordinates. + + Returns + ------- + grouped : DatasetGroupBy + A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be + iterated over in the form of `(unique_value, grouped_array)` pairs. + The name of the group has the added suffix `_bins` in order to + distinguish it from the original variable. + + See Also + -------- + Dataset.groupby + DataArray.groupby_bins + core.groupby.DatasetGroupBy + pandas.DataFrame.groupby + + References + ---------- + .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html + """ + from .groupby import DatasetGroupBy + + return DatasetGroupBy( + self, + group, + squeeze=squeeze, + bins=bins, + restore_coord_dims=restore_coord_dims, + cut_kwargs={ + "right": right, + "labels": labels, + "precision": precision, + "include_lowest": include_lowest, + }, + ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 54499c16e90..80724b2e4eb 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -2,7 +2,18 @@ import datetime import warnings -from typing import Any, Callable, Hashable, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Hashable, + Literal, + Sequence, + TypeVar, + Union, + cast, +) import numpy as np import pandas as pd @@ -15,6 +26,7 @@ from .indexes import create_default_index_implicit, filter_indexes_from_coords from .options import _get_keep_attrs from .pycompat import integer_types +from .types import T_Xarray from .utils import ( either_dict_or_kwargs, hashable, @@ -25,6 +37,11 @@ ) from .variable import IndexVariable, Variable +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + from .utils import Frozen + def check_reduce_dims(reduce_dims, dimensions): @@ -38,14 +55,16 @@ def check_reduce_dims(reduce_dims, dimensions): ) -def unique_value_groups(ar, sort=True): +def unique_value_groups( + ar, sort: bool = True +) -> tuple[np.ndarray | pd.Index, list[list[int]]]: """Group an array by its unique values. Parameters ---------- ar : array-like Input array. This will be flattened if it is not already 1-D. - sort : bool, optional + sort : bool, default: True Whether or not to sort unique values. Returns @@ -59,7 +78,7 @@ def unique_value_groups(ar, sort=True): inverse, values = pd.factorize(ar, sort=sort) if isinstance(values, pd.MultiIndex): values.names = ar.names - groups = [[] for _ in range(len(values))] + groups: list[list[int]] = [[] for _ in range(len(values))] for n, g in enumerate(inverse): if g >= 0: # pandas uses -1 to mark NaN, but doesn't include them in values @@ -68,10 +87,9 @@ def unique_value_groups(ar, sort=True): def _dummy_copy(xarray_obj): - from .dataarray import DataArray - from .dataset import Dataset + from . 
import dataarray, dataset - if isinstance(xarray_obj, Dataset): + if isinstance(xarray_obj, dataset.Dataset): res = Dataset( { k: dtypes.get_fill_value(v.dtype) @@ -84,7 +102,7 @@ def _dummy_copy(xarray_obj): }, xarray_obj.attrs, ) - elif isinstance(xarray_obj, DataArray): + elif isinstance(xarray_obj, dataarray.DataArray): res = DataArray( dtypes.get_fill_value(xarray_obj.dtype), { @@ -158,17 +176,17 @@ class _DummyGroup: __slots__ = ("name", "coords", "size") - def __init__(self, obj, name, coords): + def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None: self.name = name self.coords = coords self.size = obj.sizes[name] @property - def dims(self): + def dims(self) -> tuple[Hashable]: return (self.name,) @property - def ndim(self): + def ndim(self) -> Literal[1]: return 1 @property @@ -176,7 +194,7 @@ def values(self): return range(self.size) @property - def shape(self): + def shape(self) -> tuple[int]: return (self.size,) def __getitem__(self, key): @@ -185,21 +203,31 @@ def __getitem__(self, key): return self.values[key] -def _ensure_1d(group, obj): - if group.ndim != 1: +T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup]) + + +def _ensure_1d( + group: T_Group, obj: T_Xarray +) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]: + # 1D cases: do nothing + if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: + return group, obj, None, [] + + if isinstance(group, DataArray): # try to stack the dims of the group into a single dim orig_dims = group.dims - stacked_dim = "stacked_" + "_".join(orig_dims) + stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) # these dimensions get created by the stack operation inserted_dims = [dim for dim in group.dims if dim not in group.coords] # the copy is necessary here, otherwise read only array raises error # in pandas: https://github.com/pydata/pandas/issues/12813 - group = group.stack(**{stacked_dim: orig_dims}).copy() - obj = obj.stack(**{stacked_dim: orig_dims}) - else: - stacked_dim = None - inserted_dims = [] - return group, obj, stacked_dim, inserted_dims + newgroup = group.stack({stacked_dim: orig_dims}).copy() + newobj = obj.stack({stacked_dim: orig_dims}) + return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims + + raise TypeError( + f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}." + ) def _unique_and_monotonic(group): @@ -235,7 +263,7 @@ def _apply_loffset(grouper, result): grouper.loffset = None -class GroupBy: +class GroupBy(Generic[T_Xarray]): """A object that implements the split-apply-combine pattern. Modeled after `pandas.GroupBy`. The `GroupBy` object can be iterated over @@ -270,26 +298,27 @@ class GroupBy: "_unstacked_group", "_bins", ) + _obj: T_Xarray def __init__( self, - obj, - group, - squeeze=False, - grouper=None, + obj: T_Xarray, + group: Hashable | DataArray | IndexVariable, + squeeze: bool = False, + grouper: pd.Grouper | None = None, bins=None, - restore_coord_dims=True, + restore_coord_dims: bool = True, cut_kwargs=None, - ): + ) -> None: """Create a GroupBy object Parameters ---------- obj : Dataset or DataArray Object to group. - group : DataArray - Array with the group values. - squeeze : bool, optional + group : Hashable, DataArray or Index + Array with the group values or name of the variable. + squeeze : bool, default: False If "group" is a coordinate of object, `squeeze` controls whether the subarrays have a dimension of length 1 along that coordinate or if the dimension is squeezed out. 
@@ -330,7 +359,7 @@ def __init__( if getattr(group, "name", None) is None: group.name = "group" - self._original_obj = obj + self._original_obj: T_Xarray = obj self._unstacked_group = group self._bins = bins @@ -351,10 +380,12 @@ def __init__( if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") binned, bins = pd.cut(group.values, bins, **cut_kwargs, retbins=True) - new_dim_name = group.name + "_bins" - group = DataArray(binned, group.coords, name=new_dim_name) + new_dim_name = str(group.name) + "_bins" + group = DataArray(binned, getattr(group, "coords", None), name=new_dim_name) full_index = binned.categories + group_indices: list[slice] | list[list[int]] | np.ndarray + unique_coord: DataArray | IndexVariable | _DummyGroup if grouper is not None: index = safe_cast_to_index(group) if not index.is_monotonic_increasing: @@ -375,7 +406,7 @@ def __init__( group_indices = [slice(i, i + 1) for i in group_indices] unique_coord = group else: - if group.isnull().any(): + if isinstance(group, DataArray) and group.isnull().any(): # drop any NaN valued groups. # also drop obj values where group was NaN # Use where instead of reindex to account for duplicate coordinate labels. @@ -401,7 +432,7 @@ def __init__( ) # specification for the groupby operation - self._obj = obj + self._obj: T_Xarray = obj self._group = group self._group_dim = group_dim self._group_indices = group_indices @@ -414,20 +445,18 @@ def __init__( self._squeeze = squeeze # cached attributes - self._groups = None - self._dims = None + self._groups: dict[Any, slice | int | list[int]] | None = None + self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None @property - def dims(self): + def dims(self) -> tuple[Hashable, ...] | Frozen[Hashable, int]: if self._dims is None: - self._dims = self._obj.isel( - **{self._group_dim: self._group_indices[0]} - ).dims + self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims return self._dims @property - def groups(self): + def groups(self) -> dict[Any, slice | int | list[int]]: """ Mapping from group labels to indices. The indices can be used to index the underlying object. """ @@ -489,8 +518,7 @@ def _infer_concat_args(self, applied_example): return coord, dim, positions def _binary_op(self, other, f, reflexive=False): - from .dataarray import DataArray - from .dataset import Dataset + from . import dataarray, dataset g = f if not reflexive else lambda x, y: f(y, x) @@ -501,7 +529,7 @@ def _binary_op(self, other, f, reflexive=False): group = obj[dim] name = group.name - if not isinstance(other, (Dataset, DataArray)): + if not isinstance(other, (dataset.Dataset, dataarray.DataArray)): raise TypeError( "GroupBy objects only support binary ops " "when the other argument is a Dataset or " @@ -583,7 +611,7 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs): """Adaptor function that translates our groupby API to that of flox.""" from flox.xarray import xarray_reduce - from .dataset import Dataset + from . import dataset obj = self._original_obj @@ -693,7 +721,10 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs): # Fix dimension order when binning a dimension coordinate # Needed as long as we do a separate code path for pint; # For some reason Datasets and DataArrays behave differently! - if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims: + if ( + isinstance(self._obj, dataset.Dataset) + and self._group_dim in self._obj.dims + ): result = result.transpose(self._group.name, ...) 
return result @@ -921,7 +952,7 @@ def _maybe_reorder(xarray_obj, dim, positions): return xarray_obj[{dim: order}] -class DataArrayGroupByBase(GroupBy, DataArrayGroupbyArithmetic): +class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic): """GroupBy object specialized to grouping DataArray objects""" __slots__ = () @@ -945,7 +976,7 @@ def _concat_shortcut(self, applied, dim, positions=None): reordered = _maybe_reorder(stacked, dim, positions) return self._obj._replace_maybe_drop_dims(reordered) - def _restore_dim_order(self, stacked): + def _restore_dim_order(self, stacked: DataArray) -> DataArray: def lookup_order(dimension): if dimension == self._group.name: (dimension,) = self._group.dims @@ -1098,11 +1129,12 @@ def reduce_array(ar): return self.map(reduce_array, shortcut=shortcut) -class DataArrayGroupBy(DataArrayGroupByBase, DataArrayGroupByReductions): +# https://github.com/python/mypy/issues/9031 +class DataArrayGroupBy(DataArrayGroupByBase, DataArrayGroupByReductions): # type: ignore[misc] __slots__ = () -class DatasetGroupByBase(GroupBy, DatasetGroupbyArithmetic): +class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): __slots__ = () @@ -1236,5 +1268,6 @@ def assign(self, **kwargs): return self.map(lambda ds: ds.assign(**kwargs)) -class DatasetGroupBy(DatasetGroupByBase, DatasetGroupByReductions): +# https://github.com/python/mypy/issues/9031 +class DatasetGroupBy(DatasetGroupByBase, DatasetGroupByReductions): # type: ignore[misc] __slots__ = () diff --git a/xarray/core/resample.py b/xarray/core/resample.py index e38deb3e440..15955038010 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -185,7 +185,8 @@ def _interpolate(self, kind="linear"): ) -class DataArrayResample(Resample, DataArrayGroupByBase, DataArrayResampleReductions): +# https://github.com/python/mypy/issues/9031 +class DataArrayResample(Resample, DataArrayGroupByBase, DataArrayResampleReductions): # type: ignore[misc] """DataArrayGroupBy object specialized to time resampling operations over a specified dimension """ @@ -276,7 +277,8 @@ def apply(self, func, args=(), shortcut=None, **kwargs): return self.map(func=func, shortcut=shortcut, args=args, **kwargs) -class DatasetResample(Resample, DatasetGroupByBase, DatasetResampleReductions): +# https://github.com/python/mypy/issues/9031 +class DatasetResample(Resample, DatasetGroupByBase, DatasetResampleReductions): # type: ignore[misc] """DatasetGroupBy object specialized to resampling a specified dimension""" def __init__(self, *args, dim=None, resample_dim=None, **kwargs): diff --git a/xarray/core/utils.py b/xarray/core/utils.py index b253f1661ae..a87beafaf19 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -828,7 +828,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: def drop_dims_from_indexers( indexers: Mapping[Any, Any], - dims: list | Mapping[Any, int], + dims: Iterable[Hashable] | Mapping[Any, int], missing_dims: ErrorOptionsWithWarn, ) -> Mapping[Hashable, Any]: """Depending on the setting of missing_dims, drop any dimensions from indexers that diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 2d115ff0ed9..6c6d2aca90c 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -5,7 +5,7 @@ import numbers import warnings from datetime import timedelta -from typing import TYPE_CHECKING, Any, Hashable, Literal, Mapping, Sequence +from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, Mapping, Sequence import numpy as np 
import pandas as pd @@ -552,15 +552,15 @@ def to_dict(self, data: bool = True, encoding: bool = False) -> dict: return item @property - def dims(self): + def dims(self) -> tuple[Hashable, ...]: """Tuple of dimension names with which this variable is associated.""" return self._dims @dims.setter - def dims(self, value): + def dims(self, value: str | Iterable[Hashable]) -> None: self._dims = self._parse_dimensions(value) - def _parse_dimensions(self, dims): + def _parse_dimensions(self, dims: str | Iterable[Hashable]) -> tuple[Hashable, ...]: if isinstance(dims, str): dims = (dims,) dims = tuple(dims) From a4f2cd3ee61bddad975b4452334ba642f1276c6d Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 15:26:50 +0200 Subject: [PATCH 02/36] more typing of groupby funcs --- xarray/core/dataarray.py | 4 +- xarray/core/dataset.py | 2 +- xarray/core/groupby.py | 73 +++++++++++++++++++++++---------- xarray/tests/test_groupby.py | 78 ++++++++++++++++++------------------ 4 files changed, 93 insertions(+), 64 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8d9a7868d37..978d48c1705 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -3879,7 +3879,7 @@ def sortby( def quantile( self: T_DataArray, q: ArrayLike, - dim: str | Sequence[Hashable] | None = None, + dim: str | Iterable[Hashable] | None = None, method: QUANTILE_METHODS = "linear", keep_attrs: bool | None = None, skipna: bool | None = None, @@ -3893,7 +3893,7 @@ def quantile( ---------- q : float or array-like of float Quantile to compute, which must be between 0 and 1 inclusive. - dim : Hashable or sequence of Hashable, optional + dim : str or Iterable of Hashable, optional Dimension(s) over which to apply quantile. method : str, default: "linear" This optional parameter specifies the interpolation method to use when the diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 15fef68302e..bab238c6ec4 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -6670,7 +6670,7 @@ def quantile( ---------- q : float or array-like of float Quantile to compute, which must be between 0 and 1 inclusive. - dim : str or sequence of str, optional + dim : str or Iterable of Hashable, optional Dimension(s) over which to apply quantile. 
method : str, default: "linear" This optional parameter specifies the interpolation method to use when the diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 80724b2e4eb..10d2e398e96 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -8,6 +8,8 @@ Callable, Generic, Hashable, + Iterable, + Iterator, Literal, Sequence, TypeVar, @@ -24,6 +26,7 @@ from .concat import concat from .formatting import format_array_flat from .indexes import create_default_index_implicit, filter_indexes_from_coords +from .npcompat import QUANTILE_METHODS, ArrayLike from .options import _get_keep_attrs from .pycompat import integer_types from .types import T_Xarray @@ -190,7 +193,7 @@ def ndim(self) -> Literal[1]: return 1 @property - def values(self): + def values(self) -> range: return range(self.size) @property @@ -230,7 +233,7 @@ def _ensure_1d( ) -def _unique_and_monotonic(group): +def _unique_and_monotonic(group: T_Group) -> bool: if isinstance(group, _DummyGroup): return True index = safe_cast_to_index(group) @@ -263,6 +266,9 @@ def _apply_loffset(grouper, result): grouper.loffset = None +T_GroupBy = TypeVar("T_GroupBy", bound="GroupBy") + + class GroupBy(Generic[T_Xarray]): """A object that implements the split-apply-combine pattern. @@ -448,6 +454,15 @@ def __init__( self._groups: dict[Any, slice | int | list[int]] | None = None self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None + def map( + self, + func: Callable, + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> T_Xarray: + raise NotImplementedError() + @property def dims(self) -> tuple[Hashable, ...] | Frozen[Hashable, int]: if self._dims is None: @@ -465,19 +480,19 @@ def groups(self) -> dict[Any, slice | int | list[int]]: self._groups = dict(zip(self._unique_coord.values, self._group_indices)) return self._groups - def __getitem__(self, key): + def __getitem__(self, key: Any) -> T_Xarray: """ Get DataArray or Dataset corresponding to a particular group label. """ return self._obj.isel({self._group_dim: self.groups[key]}) - def __len__(self): + def __len__(self) -> int: return self._unique_coord.size - def __iter__(self): + def __iter__(self) -> Iterator[tuple[Any, T_Xarray]]: return zip(self._unique_coord.values, self._iter_grouped()) - def __repr__(self): + def __repr__(self) -> str: return "{}, grouped over {!r}\n{!r} groups with labels {}.".format( self.__class__.__name__, self._unique_coord.name, @@ -729,7 +744,7 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs): return result - def fillna(self, value): + def fillna(self, value: Any) -> T_Xarray: """Fill missing values in this object by group. This operation follows the normal broadcasting and alignment rules that @@ -757,13 +772,13 @@ def fillna(self, value): def quantile( self, - q, - dim=None, - method="linear", - keep_attrs=None, - skipna=None, - interpolation=None, - ): + q: ArrayLike, + dim: str | Iterable[Hashable] | None = None, + method: QUANTILE_METHODS = "linear", + keep_attrs: bool | None = None, + skipna: bool | None = None, + interpolation: QUANTILE_METHODS | None = None, + ) -> T_Xarray: """Compute the qth quantile over each array in the groups and concatenate them together into a new array. @@ -772,7 +787,7 @@ def quantile( q : float or sequence of float Quantile to compute, which must be between 0 and 1 inclusive. - dim : ..., str or sequence of str, optional + dim : str or Iterable of Hashable, optional Dimension(s) over which to apply quantile. Defaults to the grouped dimension. 
method : str, default: "linear" @@ -802,8 +817,11 @@ def quantile( an asterix require numpy version 1.22 or newer. The "method" argument was previously called "interpolation", renamed in accordance with numpy version 1.22.0. - - skipna : bool, optional + keep_attrs : bool or None, default: None + If True, the dataarray's attributes (`attrs`) will be copied from + the original object to the new one. If False, the new + object will be returned without attributes. + skipna : bool or None, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been @@ -879,9 +897,9 @@ def quantile( The American Statistician, 50(4), pp. 361-365, 1996 """ if dim is None: - dim = self._group_dim + dim = (self._group_dim,) - out = self.map( + return self.map( self._obj.__class__.quantile, shortcut=False, q=q, @@ -891,7 +909,6 @@ def quantile( skipna=skipna, interpolation=interpolation, ) - return out def where(self, cond, other=dtypes.NA): """Return elements from `self` or `other` depending on `cond`. @@ -989,7 +1006,13 @@ def lookup_order(dimension): new_order = sorted(stacked.dims, key=lookup_order) return stacked.transpose(*new_order, transpose_coords=self._restore_coord_dims) - def map(self, func, shortcut=False, args=(), **kwargs): + def map( + self, + func: Callable, + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> DataArray: """Apply a function to each array in the group and concatenate them together into a new array. @@ -1138,7 +1161,13 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): __slots__ = () - def map(self, func, args=(), shortcut=None, **kwargs): + def map( + self, + func: Callable, + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> Dataset: """Apply a function to each Dataset in the group and concatenate them together into a new Dataset. diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index f0b16bc42c7..7d1932283b2 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -235,13 +235,13 @@ def test_da_groupby_quantile() -> None: dims=("x", "y"), ) - actual_x = array.groupby("x").quantile(0, dim=...) + actual_x = array.groupby("x").quantile(0, dim=...) # type: ignore[arg-type] expected_x = xr.DataArray( data=[1, 4], coords={"x": [1, 2], "quantile": 0}, dims="x" ) assert_identical(expected_x, actual_x) - actual_y = array.groupby("y").quantile(0, dim=...) + actual_y = array.groupby("y").quantile(0, dim=...) # type: ignore[arg-type] expected_y = xr.DataArray( data=[1, 22], coords={"y": [0, 1], "quantile": 0}, dims="y" ) @@ -272,7 +272,7 @@ def test_da_groupby_quantile() -> None: ) g = foo.groupby(foo.time.dt.month) - actual = g.quantile(0, dim=...) + actual = g.quantile(0, dim=...) # type: ignore[arg-type] expected = xr.DataArray( data=[ 0.0, @@ -356,11 +356,11 @@ def test_ds_groupby_quantile() -> None: coords={"x": [1, 1, 1, 2, 2], "y": [0, 0, 1]}, ) - actual_x = ds.groupby("x").quantile(0, dim=...) + actual_x = ds.groupby("x").quantile(0, dim=...) # type: ignore[arg-type] expected_x = xr.Dataset({"a": ("x", [1, 4])}, coords={"x": [1, 2], "quantile": 0}) assert_identical(expected_x, actual_x) - actual_y = ds.groupby("y").quantile(0, dim=...) + actual_y = ds.groupby("y").quantile(0, dim=...) 
# type: ignore[arg-type] expected_y = xr.Dataset({"a": ("y", [1, 22])}, coords={"y": [0, 1], "quantile": 0}) assert_identical(expected_y, actual_y) @@ -386,7 +386,7 @@ def test_ds_groupby_quantile() -> None: ) g = foo.groupby(foo.time.dt.month) - actual = g.quantile(0, dim=...) + actual = g.quantile(0, dim=...) # type: ignore[arg-type] expected = xr.Dataset( { "a": ( @@ -522,17 +522,17 @@ def test_groupby_drops_nans() -> None: grouped = ds.groupby(ds.id) # non reduction operation - expected = ds.copy() - expected.variable.values[0, 0, :] = np.nan - expected.variable.values[-1, -1, :] = np.nan - expected.variable.values[3, 0, :] = np.nan - actual = grouped.map(lambda x: x).transpose(*ds.variable.dims) - assert_identical(actual, expected) + expected1 = ds.copy() + expected1.variable.values[0, 0, :] = np.nan + expected1.variable.values[-1, -1, :] = np.nan + expected1.variable.values[3, 0, :] = np.nan + actual1 = grouped.map(lambda x: x).transpose(*ds.variable.dims) + assert_identical(actual1, expected1) # reduction along grouped dimension - actual = grouped.mean() + actual2 = grouped.mean() stacked = ds.stack({"xy": ["lat", "lon"]}) - expected = ( + expected2 = ( stacked.variable.where(stacked.id.notnull()) .rename({"xy": "id"}) .to_dataset() @@ -540,21 +540,21 @@ def test_groupby_drops_nans() -> None: .drop_vars(["lon", "lat"]) .assign(id=stacked.id.values) .dropna("id") - .transpose(*actual.dims) + .transpose(*actual2.dims) ) - assert_identical(actual, expected) + assert_identical(actual2, expected2) # reduction operation along a different dimension - actual = grouped.mean("time") - expected = ds.mean("time").where(ds.id.notnull()) - assert_identical(actual, expected) + actual3 = grouped.mean("time") + expected3 = ds.mean("time").where(ds.id.notnull()) + assert_identical(actual3, expected3) # NaN in non-dimensional coordinate array = xr.DataArray([1, 2, 3], [("x", [1, 2, 3])]) array["x1"] = ("x", [1, 1, np.nan]) - expected_da = xr.DataArray(3, [("x1", [1])]) - actual = array.groupby("x1").sum() - assert_equal(expected_da, actual) + expected4 = xr.DataArray(3, [("x1", [1])]) + actual4 = array.groupby("x1").sum() + assert_equal(expected4, actual4) # NaT in non-dimensional coordinate array["t"] = ( @@ -565,15 +565,15 @@ def test_groupby_drops_nans() -> None: np.datetime64("NaT"), ], ) - expected_da = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])]) - actual = array.groupby("t").sum() - assert_equal(expected_da, actual) + expected5 = xr.DataArray(3, [("t", [np.datetime64("2001-01-01")])]) + actual5 = array.groupby("t").sum() + assert_equal(expected5, actual5) # test for repeated coordinate labels array = xr.DataArray([0, 1, 2, 4, 3, 4], [("x", [np.nan, 1, 1, np.nan, 2, np.nan])]) - expected_da = xr.DataArray([3, 3], [("x", [1, 2])]) - actual = array.groupby("x").sum() - assert_equal(expected_da, actual) + expected6 = xr.DataArray([3, 3], [("x", [1, 2])]) + actual6 = array.groupby("x").sum() + assert_equal(expected6, actual6) def test_groupby_grouping_errors() -> None: @@ -679,28 +679,28 @@ def test_groupby_dataset() -> None: ("b", data.isel(x=1)), ("c", data.isel(x=2)), ] - for actual, expected in zip(groupby, expected_items): - assert actual[0] == expected[0] - assert_equal(actual[1], expected[1]) + for actual1, expected1 in zip(groupby, expected_items): + assert actual1[0] == expected1[0] + assert_equal(actual1[1], expected1[1]) def identity(x): return x for k in ["x", "c", "y"]: - actual = data.groupby(k, squeeze=False).map(identity) - assert_equal(data, actual) + actual2 = 
data.groupby(k, squeeze=False).map(identity) + assert_equal(data, actual2) def test_groupby_dataset_returns_new_type() -> None: data = Dataset({"z": (["x", "y"], np.random.randn(3, 5))}) - actual = data.groupby("x").map(lambda ds: ds["z"]) - expected = data["z"] - assert_identical(expected, actual) + actual1 = data.groupby("x").map(lambda ds: ds["z"]) + expected1 = data["z"] + assert_identical(expected1, actual1) - actual = data["z"].groupby("x").map(lambda x: x.to_dataset()) - expected_ds = data - assert_identical(expected_ds, actual) + actual2 = data["z"].groupby("x").map(lambda x: x.to_dataset()) + expected2 = data + assert_identical(expected2, actual2) def test_groupby_dataset_iter() -> None: From 2d06e25b933d188f22b5027df58369e199efa35d Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 15:55:39 +0200 Subject: [PATCH 03/36] type all groupby methods --- xarray/core/dataarray.py | 4 ++-- xarray/core/groupby.py | 50 +++++++++++++++++++++++++--------------- xarray/core/resample.py | 7 +++--- xarray/core/variable.py | 35 +++++++++++++++++----------- 4 files changed, 60 insertions(+), 36 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 978d48c1705..1c157f9ba12 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -2877,7 +2877,7 @@ def combine_first(self: T_DataArray, other: T_DataArray) -> T_DataArray: def reduce( self: T_DataArray, func: Callable[..., Any], - dim: None | Hashable | Sequence[Hashable] = None, + dim: None | Hashable | Iterable[Hashable] = None, *, axis: None | int | Sequence[int] = None, keep_attrs: bool | None = None, @@ -2892,7 +2892,7 @@ def reduce( Function which can be called in the form `f(x, axis=axis, **kwargs)` to return the result of reducing an np.ndarray over an integer valued axis. - dim : Hashable or sequence of Hashable, optional + dim : Hashable or Iterable of Hashable, optional Dimension(s) over which to apply `func`. axis : int or sequence of int, optional Axis(es) over which to repeatedly apply `func`. Only one of the diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 10d2e398e96..6e6a7e20e34 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -463,6 +463,19 @@ def map( ) -> T_Xarray: raise NotImplementedError() + def reduce( + self, + func: Callable[..., Any], + dim: None | Hashable | Iterable[Hashable] = None, + *, + axis: None | int | Sequence[int] = None, + keep_attrs: bool | None = None, + keepdims: bool = False, + shortcut: bool = True, + **kwargs: Any, + ) -> T_Xarray: + raise NotImplementedError() + @property def dims(self) -> tuple[Hashable, ...] | Frozen[Hashable, int]: if self._dims is None: @@ -514,10 +527,10 @@ def _get_index_and_items(self, index, grouper): first_items = first_items.dropna() return full_index, first_items - def _iter_grouped(self): + def _iter_grouped(self) -> Iterator[T_Xarray]: """Iterate over each element in this group""" for indices in self._group_indices: - yield self._obj.isel(**{self._group_dim: indices}) + yield self._obj.isel({self._group_dim: indices}) def _infer_concat_args(self, applied_example): if self._group_dim in applied_example.dims: @@ -910,7 +923,7 @@ def quantile( interpolation=interpolation, ) - def where(self, cond, other=dtypes.NA): + def where(self, cond, other=dtypes.NA) -> T_Xarray: """Return elements from `self` or `other` depending on `cond`. 
Parameters @@ -940,11 +953,11 @@ def _first_or_last(self, op, skipna, keep_attrs): keep_attrs = _get_keep_attrs(default=True) return self.reduce(op, self._group_dim, skipna=skipna, keep_attrs=keep_attrs) - def first(self, skipna=None, keep_attrs=None): + def first(self, skipna: bool | None = None, keep_attrs: bool | None = None): """Return the first element of each group along the group dimension""" return self._first_or_last(duck_array_ops.first, skipna, keep_attrs) - def last(self, skipna=None, keep_attrs=None): + def last(self, skipna: bool | None = None, keep_attrs: bool | None = None): """Return the last element of each group along the group dimension""" return self._first_or_last(duck_array_ops.last, skipna, keep_attrs) @@ -1008,7 +1021,7 @@ def lookup_order(dimension): def map( self, - func: Callable, + func: Callable[..., DataArray], args: tuple[Any, ...] = (), shortcut: bool | None = None, **kwargs: Any, @@ -1051,7 +1064,7 @@ def map( Returns ------- - applied : DataArray or DataArray + applied : DataArray The result of splitting, applying and combining this array. """ grouped = self._iter_grouped_shortcut() if shortcut else self._iter_grouped() @@ -1098,14 +1111,14 @@ def _combine(self, applied, shortcut=False): def reduce( self, func: Callable[..., Any], - dim: None | Hashable | Sequence[Hashable] = None, + dim: None | Hashable | Iterable[Hashable] = None, *, axis: None | int | Sequence[int] = None, keep_attrs: bool = None, keepdims: bool = False, shortcut: bool = True, **kwargs: Any, - ): + ) -> DataArray: """Reduce the items in this group by applying `func` along some dimension(s). @@ -1137,7 +1150,7 @@ def reduce( if dim is None: dim = self._group_dim - def reduce_array(ar): + def reduce_array(ar: DataArray) -> DataArray: return ar.reduce( func=func, dim=dim, @@ -1163,7 +1176,7 @@ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): def map( self, - func: Callable, + func: Callable[..., Dataset], args: tuple[Any, ...] = (), shortcut: bool | None = None, **kwargs: Any, @@ -1194,7 +1207,7 @@ def map( Returns ------- - applied : Dataset or DataArray + applied : Dataset The result of splitting, applying and combining this dataset. """ # ignore shortcut if set (for now) @@ -1235,13 +1248,14 @@ def _combine(self, applied): def reduce( self, func: Callable[..., Any], - dim: None | Hashable | Sequence[Hashable] = None, + dim: None | Hashable | Iterable[Hashable] = None, *, axis: None | int | Sequence[int] = None, keep_attrs: bool = None, keepdims: bool = False, + shortcut: bool = True, **kwargs: Any, - ): + ) -> Dataset: """Reduce the items in this group by applying `func` along some dimension(s). @@ -1251,7 +1265,7 @@ def reduce( Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. - dim : ..., str or sequence of str, optional + dim : ..., str or Iterable of Hashable, optional Dimension(s) over which to apply `func`. axis : int or sequence of int, optional Axis(es) over which to apply `func`. Only one of the 'dimension' @@ -1266,14 +1280,14 @@ def reduce( Returns ------- - reduced : Array + reduced : Dataset Array with summarized data and the indicated dimension(s) removed. 
""" if dim is None: dim = self._group_dim - def reduce_dataset(ds): + def reduce_dataset(ds: Dataset) -> Dataset: return ds.reduce( func=func, dim=dim, @@ -1287,7 +1301,7 @@ def reduce_dataset(ds): return self.map(reduce_dataset) - def assign(self, **kwargs): + def assign(self, **kwargs: Any) -> Dataset: """Assign data variables by group. See Also diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 15955038010..fb6082dc7a9 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import Any, Callable, Hashable, Sequence +from typing import Any, Callable, Hashable, Iterable, Sequence import numpy as np @@ -349,11 +349,12 @@ def apply(self, func, args=(), shortcut=None, **kwargs): def reduce( self, func: Callable[..., Any], - dim: None | Hashable | Sequence[Hashable] = None, + dim: None | Hashable | Iterable[Hashable] = None, *, axis: None | int | Sequence[int] = None, keep_attrs: bool = None, keepdims: bool = False, + shortcut: bool = True, **kwargs: Any, ): """Reduce the items in this group by applying `func` along the @@ -365,7 +366,7 @@ def reduce( Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. - dim : str or sequence of str, optional + dim : Hashable or Iterable of Hashable, optional Dimension(s) over which to apply `func`. keep_attrs : bool, optional If True, the datasets's attributes (`attrs`) will be copied from diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 6c6d2aca90c..4e13e8bae73 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -5,7 +5,16 @@ import numbers import warnings from datetime import timedelta -from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, Mapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Hashable, + Iterable, + Literal, + Mapping, + Sequence, +) import numpy as np import pandas as pd @@ -1780,13 +1789,13 @@ def clip(self, min=None, max=None): def reduce( self, - func, - dim=None, - axis=None, - keep_attrs=None, - keepdims=False, + func: Callable[..., Any], + dim: Hashable | Iterable[Hashable] | None = None, + axis: int | Sequence[int] | None = None, + keep_attrs: bool | None = None, + keepdims: bool = False, **kwargs, - ): + ) -> Variable: """Reduce this array by applying `func` along some dimension(s). Parameters @@ -1795,9 +1804,9 @@ def reduce( Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of reducing an np.ndarray over an integer valued axis. - dim : str or sequence of str, optional + dim : Hashable or Iterable of Hashable, optional Dimension(s) over which to apply `func`. - axis : int or sequence of int, optional + axis : int or Iterable of int, optional Axis(es) over which to apply `func`. Only one of the 'dim' and 'axis' arguments can be supplied. 
If neither are supplied, then
             the reduction is calculated over the flattened array (by calling
             `func(x)` without an axis argument).
@@ -1838,8 +1847,8 @@ def reduce(
         if getattr(data, "shape", ()) == self.shape:
             dims = self.dims
         else:
-            removed_axes = (
-                range(self.ndim) if axis is None else np.atleast_1d(axis) % self.ndim
+            removed_axes: Sequence[int] = (
+                range(self.ndim) if axis is None else np.atleast_1d(axis) % self.ndim  # type: ignore[assignment]
             )
             if keepdims:
                 # Insert np.newaxis for removed dims
                 slices = tuple(
                     np.newaxis if i in removed_axes else slice(None, None)
                     for i in range(self.ndim)
                 )
                 data = data[slices]
                 dims = self.dims
             else:
-                dims = [
+                dims = tuple(
                     adim for n, adim in enumerate(self.dims) if n not in removed_axes
-                ]
+                )
 
         if keep_attrs is None:
             keep_attrs = _get_keep_attrs(default=False)

From c8f07bdfde89547af5eed2673de17eba0b151a6e Mon Sep 17 00:00:00 2001
From: Michael Niklas
Date: Fri, 17 Jun 2022 15:59:50 +0200
Subject: [PATCH 04/36] remove old groupby methods

---
 xarray/core/common.py | 146 +-----------------------------------------
 1 file changed, 1 insertion(+), 145 deletions(-)

diff --git a/xarray/core/common.py b/xarray/core/common.py
index 8fb6dc20875..6c1421f581a 100644
--- a/xarray/core/common.py
+++ b/xarray/core/common.py
@@ -39,10 +39,9 @@
 if TYPE_CHECKING:
     from .dataarray import DataArray
     from .dataset import Dataset
-    from .groupby import GroupBy
     from .indexes import Index
     from .types import ScalarOrArray, T_DataWithCoords, T_Xarray
-    from .variable import IndexVariable, Variable
+    from .variable import Variable
     from .weighted import Weighted
 
 
@@ -356,7 +355,6 @@ class DataWithCoords(AttrAccessMixin):
 
     _close: Callable[[], None] | None
     _indexes: dict[Hashable, Index]
-    _groupby_cls: type[GroupBy]
 
     __slots__ = ("_close",)
 
@@ -750,148 +748,6 @@ def pipe(
         else:
             return func(self, *args, **kwargs)
 
-    def groupby(
-        self,
-        group: Hashable | DataArray | IndexVariable,
-        squeeze: bool = True,
-        restore_coord_dims: bool = False,
-    ) -> GroupBy:
-        """Returns a GroupBy object for performing grouped operations.
-
-        Parameters
-        ----------
-        group : Hashable, DataArray or IndexVariable
-            Array whose unique values should be used to group this array. If a
-            string, must be the name of a variable contained in this dataset.
-        squeeze : bool, default: True
-            If "group" is a dimension of any arrays in this dataset, `squeeze`
-            controls whether the subarrays have a dimension of length 1 along
-            that dimension or if the dimension is squeezed out.
-        restore_coord_dims : bool, default: False
-            If True, also restore the dimension order of multi-dimensional
-            coordinates.
-
-        Returns
-        -------
-        grouped : GroupBy
-            A `GroupBy` object patterned after `pandas.GroupBy` that can be
-            iterated over in the form of `(unique_value, grouped_array)` pairs.
-
-        Examples
-        --------
-        Calculate daily anomalies for daily data:
-
-        >>> da = xr.DataArray(
-        ...     np.linspace(0, 1826, num=1827),
-        ...     coords=[pd.date_range("1/1/2000", "31/12/2004", freq="D")],
-        ...     dims="time",
-        ... )
-        >>> da
-        <xarray.DataArray (time: 1827)>
-        array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03,
-               1.826e+03])
-        Coordinates:
-          * time     (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31
-        >>> da.groupby("time.dayofyear") - da.groupby("time.dayofyear").mean("time")
-        <xarray.DataArray (time: 1827)>
-        array([-730.8, -730.8, -730.8, ...,  730.2,  730.2,  730.5])
-        Coordinates:
-          * time       (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31
-            dayofyear  (time) int64 1 2 3 4 5 6 7 8 ... 
359 360 361 362 363 364 365 366 - - See Also - -------- - core.groupby.DataArrayGroupBy - core.groupby.DatasetGroupBy - """ - # While we don't generally check the type of every arg, passing - # multiple dimensions as multiple arguments is common enough, and the - # consequences hidden enough (strings evaluate as true) to warrant - # checking here. - # A future version could make squeeze kwarg only, but would face - # backward-compat issues. - if not isinstance(squeeze, bool): - raise TypeError( - f"`squeeze` must be True or False, but {squeeze} was supplied" - ) - - return self._groupby_cls( - self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims - ) - - def groupby_bins( - self, - group: Hashable | DataArray | IndexVariable, - bins, - right: bool = True, - labels=None, - precision: int = 3, - include_lowest: bool = False, - squeeze: bool = True, - restore_coord_dims: bool = False, - ) -> GroupBy: - """Returns a GroupBy object for performing grouped operations. - - Rather than using all unique values of `group`, the values are discretized - first by applying `pandas.cut` [1]_ to `group`. - - Parameters - ---------- - group : Hashable, DataArray or IndexVariable - Array whose binned values should be used to group this array. If a - string, must be the name of a variable contained in this dataset. - bins : int or array-like - If bins is an int, it defines the number of equal-width bins in the - range of x. However, in this case, the range of x is extended by .1% - on each side to include the min or max values of x. If bins is a - sequence it defines the bin edges allowing for non-uniform bin - width. No extension of the range of x is done in this case. - right : bool, default: True - Indicates whether the bins include the rightmost edge or not. If - right == True (the default), then the bins [1,2,3,4] indicate - (1,2], (2,3], (3,4]. - labels : array-like or bool, default: None - Used as labels for the resulting bins. Must be of the same length as - the resulting bins. If False, string bin labels are assigned by - `pandas.cut`. - precision : int, default: 3 - The precision at which to store and display the bins labels. - include_lowest : bool, default: False - Whether the first interval should be left-inclusive or not. - squeeze : bool, default: True - If "group" is a dimension of any arrays in this dataset, `squeeze` - controls whether the subarrays have a dimension of length 1 along - that dimension or if the dimension is squeezed out. - restore_coord_dims : bool, default: False - If True, also restore the dimension order of multi-dimensional - coordinates. - - Returns - ------- - grouped - A `GroupBy` object patterned after `pandas.GroupBy` that can be - iterated over in the form of `(unique_value, grouped_array)` pairs. - The name of the group has the added suffix `_bins` in order to - distinguish it from the original variable. - - References - ---------- - .. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html - """ - return self._groupby_cls( - self, - group, - squeeze=squeeze, - bins=bins, - restore_coord_dims=restore_coord_dims, - cut_kwargs={ - "right": right, - "labels": labels, - "precision": precision, - "include_lowest": include_lowest, - }, - ) - def weighted(self: T_DataWithCoords, weights: DataArray) -> Weighted[T_Xarray]: """ Weighted operations. 
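Patch 04 above deletes the generic `groupby`/`groupby_bins` implementations from the shared `DataWithCoords` base class; the typed replacements introduced in patch 01 live directly on `DataArray` and `Dataset`. A minimal sketch of the user-facing behavior this refactor preserves, assuming an illustrative toy series (the data below is hypothetical and not part of the patch):

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Hypothetical example data: six daily values.
    da = xr.DataArray(
        np.arange(6.0),
        coords=[pd.date_range("2000-01-01", periods=6, freq="D")],
        dims="time",
    )

    # The call site is unchanged by the refactor; only the static return
    # type narrows from a generic GroupBy to DataArrayGroupBy.
    grouped = da.groupby("time.dayofweek")
    print(grouped.mean("time"))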
From ddb51f33b9266782d4275e0f505568d319b9126a Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 16:01:03 +0200 Subject: [PATCH 05/36] remove unused typevar --- xarray/core/groupby.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 6e6a7e20e34..3e786d2150d 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -266,9 +266,6 @@ def _apply_loffset(grouper, result): grouper.loffset = None -T_GroupBy = TypeVar("T_GroupBy", bound="GroupBy") - - class GroupBy(Generic[T_Xarray]): """A object that implements the split-apply-combine pattern. From 3d867a242c0f2f6d79a9af7c369a786d006a027c Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 16:15:43 +0200 Subject: [PATCH 06/36] more consistent groupby.dims typing and new groupby.sizes --- xarray/core/common.py | 2 +- xarray/core/dataarray.py | 5 +++++ xarray/core/dataset.py | 5 +++++ xarray/core/groupby.py | 44 ++++++++++++++++++++++++++++++++-------- 4 files changed, 47 insertions(+), 9 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 6c1421f581a..543763e2a07 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -198,7 +198,7 @@ def _get_axis_num(self: Any, dim: Hashable) -> int: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") @property - def sizes(self: Any) -> Mapping[Hashable, int]: + def sizes(self: Any) -> Frozen[Hashable, int]: """Ordered mapping from dimension names to lengths. Immutable. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 1c157f9ba12..83221b10eea 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -742,6 +742,11 @@ def dims(self) -> tuple[Hashable, ...]: Note that the type of this property is inconsistent with `Dataset.dims`. See `Dataset.sizes` and `DataArray.sizes` for consistently named properties. + + See Also + -------- + DataArray.sizes + Dataset.dims """ return self.variable.dims diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index bab238c6ec4..3c6502bd275 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -658,6 +658,11 @@ def dims(self) -> Frozen[Hashable, int]: Note that type of this object differs from `DataArray.dims`. See `Dataset.sizes` and `DataArray.sizes` for consistently named properties. + + See Also + -------- + Dataset.sizes + DataArray.dims """ return Frozen(self._dims) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3e786d2150d..1dfb7d76638 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -449,7 +449,26 @@ def __init__( # cached attributes self._groups: dict[Any, slice | int | list[int]] | None = None - self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None + self._dims = None + self._sizes: Frozen[Hashable, int] | None = None + + @property + def sizes(self) -> Frozen[Hashable, int]: + """Ordered mapping from dimension names to lengths. + + Immutable. + + See Also + -------- + DataArray.sizes + Dataset.sizes + """ + if self._sizes is None: + self._sizes = self._obj.isel( + {self._group_dim: self._group_indices[0]} + ).sizes + + return self._sizes def map( self, @@ -473,13 +492,6 @@ def reduce( ) -> T_Xarray: raise NotImplementedError() - @property - def dims(self) -> tuple[Hashable, ...] 
| Frozen[Hashable, int]: - if self._dims is None: - self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims - - return self._dims - @property def groups(self) -> dict[Any, slice | int | list[int]]: """ @@ -983,6 +995,14 @@ class DataArrayGroupByBase(GroupBy["DataArray"], DataArrayGroupbyArithmetic): """GroupBy object specialized to grouping DataArray objects""" __slots__ = () + _dims: tuple[Hashable, ...] | None + + @property + def dims(self) -> tuple[Hashable, ...]: + if self._dims is None: + self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims + + return self._dims def _iter_grouped_shortcut(self): """Fast version of `_iter_grouped` that yields Variables without @@ -1170,6 +1190,14 @@ class DataArrayGroupBy(DataArrayGroupByBase, DataArrayGroupByReductions): # typ class DatasetGroupByBase(GroupBy["Dataset"], DatasetGroupbyArithmetic): __slots__ = () + _dims: Frozen[Hashable, int] | None + + @property + def dims(self) -> Frozen[Hashable, int]: + if self._dims is None: + self._dims = self._obj.isel({self._group_dim: self._group_indices[0]}).dims + + return self._dims def map( self, From 33e9d8d46abdf8e6d986f6562d66f4e3348e6b58 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 17:52:59 +0200 Subject: [PATCH 07/36] typing for coarsen objects --- xarray/core/common.py | 178 +------------------------------ xarray/core/dataarray.py | 197 ++++++++++++++++++++++++++++++++--- xarray/core/dataset.py | 120 ++++++++++++++++++++- xarray/core/groupby.py | 1 + xarray/core/rolling.py | 52 +++++---- xarray/core/rolling_exp.py | 12 +-- xarray/core/types.py | 3 + xarray/core/weighted.py | 6 +- xarray/tests/test_coarsen.py | 30 +++--- 9 files changed, 364 insertions(+), 235 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 543763e2a07..feb6724cbbb 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -24,7 +24,6 @@ from .npcompat import DTypeLike, DTypeLikeSave from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array -from .rolling_exp import RollingExp from .utils import Frozen, either_dict_or_kwargs, is_scalar try: @@ -40,9 +39,9 @@ from .dataarray import DataArray from .dataset import Dataset from .indexes import Index - from .types import ScalarOrArray, T_DataWithCoords, T_Xarray + from .rolling_exp import RollingExp + from .types import ScalarOrArray, T_DataWithCoords from .variable import Variable - from .weighted import Weighted C = TypeVar("C") @@ -748,105 +747,12 @@ def pipe( else: return func(self, *args, **kwargs) - def weighted(self: T_DataWithCoords, weights: DataArray) -> Weighted[T_Xarray]: - """ - Weighted operations. - - Parameters - ---------- - weights : DataArray - An array of weights associated with the values in this Dataset. - Each value in the data contributes to the reduction operation - according to its associated weight. - - Notes - ----- - ``weights`` must be a DataArray and cannot contain missing values. - Missing values can be replaced by ``weights.fillna(0)``. - """ - - return self._weighted_cls(self, weights) - - def rolling( - self, - dim: Mapping[Any, int] = None, - min_periods: int = None, - center: bool | Mapping[Any, bool] = False, - **window_kwargs: int, - ): - """ - Rolling window object. - - Parameters - ---------- - dim : dict, optional - Mapping from the dimension name to create the rolling iterator - along (e.g. `time`) to its moving window size. 
- min_periods : int, default: None - Minimum number of observations in window required to have a value - (otherwise result is NA). The default, None, is equivalent to - setting min_periods equal to the size of the window. - center : bool or mapping, default: False - Set the labels at the center of the window. - **window_kwargs : optional - The keyword arguments form of ``dim``. - One of dim or window_kwargs must be provided. - - Returns - ------- - core.rolling.DataArrayRolling or core.rolling.DatasetRolling - A rolling object (``DataArrayRolling`` for ``DataArray``, - ``DatasetRolling`` for ``Dataset``) - - Examples - -------- - Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON: - - >>> da = xr.DataArray( - ... np.linspace(0, 11, num=12), - ... coords=[ - ... pd.date_range( - ... "1999-12-15", - ... periods=12, - ... freq=pd.DateOffset(months=1), - ... ) - ... ], - ... dims="time", - ... ) - >>> da - - array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) - Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 - >>> da.rolling(time=3, center=True).mean() - - array([nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., nan]) - Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 - - Remove the NaNs using ``dropna()``: - - >>> da.rolling(time=3, center=True).mean().dropna("time") - - array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) - Coordinates: - * time (time) datetime64[ns] 2000-01-15 2000-02-15 ... 2000-10-15 - - See Also - -------- - core.rolling.DataArrayRolling - core.rolling.DatasetRolling - """ - - dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") - return self._rolling_cls(self, dim, min_periods=min_periods, center=center) - def rolling_exp( - self, + self: T_DataWithCoords, window: Mapping[Any, int] = None, window_type: str = "span", **window_kwargs, - ): + ) -> RollingExp[T_DataWithCoords]: """ Exponentially-weighted moving window. Similar to EWM in pandas @@ -882,82 +788,6 @@ def rolling_exp( return RollingExp(self, window, window_type) - def coarsen( - self, - dim: Mapping[Any, int] = None, - boundary: str = "exact", - side: str | Mapping[Any, str] = "left", - coord_func: str = "mean", - **window_kwargs: int, - ): - """ - Coarsen object. - - Parameters - ---------- - dim : mapping of hashable to int, optional - Mapping from the dimension name to the window size. - boundary : {"exact", "trim", "pad"}, default: "exact" - If 'exact', a ValueError will be raised if dimension size is not a - multiple of the window size. If 'trim', the excess entries are - dropped. If 'pad', NA will be padded. - side : {"left", "right"} or mapping of str to {"left", "right"} - coord_func : str or mapping of hashable to str, default: "mean" - function (name) that is applied to the coordinates, - or a mapping from coordinate name to function (name). - - Returns - ------- - core.rolling.DataArrayCoarsen or core.rolling.DatasetCoarsen - A coarsen object (``DataArrayCoarsen`` for ``DataArray``, - ``DatasetCoarsen`` for ``Dataset``) - - Examples - -------- - Coarsen the long time series by averaging over every four days. - - >>> da = xr.DataArray( - ... np.linspace(0, 364, num=364), - ... dims="time", - ... coords={"time": pd.date_range("1999-12-15", periods=364)}, - ... ) - >>> da # +doctest: ELLIPSIS - - array([ 0. , 1.00275482, 2.00550964, 3.00826446, - 4.01101928, 5.0137741 , 6.01652893, 7.01928375, - 8.02203857, 9.02479339, 10.02754821, 11.03030303, - ... 
- 356.98071625, 357.98347107, 358.9862259 , 359.98898072, - 360.99173554, 361.99449036, 362.99724518, 364. ]) - Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-12-12 - >>> da.coarsen(time=3, boundary="trim").mean() # +doctest: ELLIPSIS - - array([ 1.00275482, 4.01101928, 7.01928375, 10.02754821, - 13.03581267, 16.04407713, 19.0523416 , 22.06060606, - 25.06887052, 28.07713499, 31.08539945, 34.09366391, - ... - 349.96143251, 352.96969697, 355.97796143, 358.9862259 , - 361.99449036]) - Coordinates: - * time (time) datetime64[ns] 1999-12-16 1999-12-19 ... 2000-12-10 - >>> - - See Also - -------- - core.rolling.DataArrayCoarsen - core.rolling.DatasetCoarsen - """ - - dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") - return self._coarsen_cls( - self, - dim, - boundary=boundary, - side=side, - coord_func=coord_func, - ) - def resample( self, indexer: Mapping[Any, str] = None, diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 83221b10eea..0cb72cd2f54 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -24,17 +24,7 @@ from ..coding.cftimeindex import CFTimeIndex from ..plot.plot import _PlotMethods from ..plot.utils import _get_units_from_attrs -from . import ( - alignment, - computation, - dtypes, - indexing, - ops, - resample, - rolling, - utils, - weighted, -) +from . import alignment, computation, dtypes, indexing, ops, resample, utils from ._reductions import DataArrayReductions from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor @@ -83,7 +73,9 @@ from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .groupby import DataArrayGroupBy + from .rolling import DataArrayCoarsen, DataArrayRolling from .types import ( + CoarsenBoundaryOptions, DatetimeUnitOptions, ErrorOptions, ErrorOptionsWithWarn, @@ -93,9 +85,11 @@ QueryEngineOptions, QueryParserOptions, ReindexMethodOptions, + SideOptions, T_DataArray, T_Xarray, ) + from .weighted import DataArrayWeighted T_XarrayOther = TypeVar("T_XarrayOther", bound=Union["DataArray", Dataset]) @@ -367,10 +361,7 @@ class DataArray( "__weakref__", ) - _rolling_cls = rolling.DataArrayRolling - _coarsen_cls = rolling.DataArrayCoarsen _resample_cls = resample.DataArrayResample - _weighted_cls = weighted.DataArrayWeighted dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor["DataArray"]) @@ -5468,6 +5459,184 @@ def groupby_bins( }, ) + def weighted(self, weights: DataArray) -> DataArrayWeighted: + """ + Weighted DataArray operations. + + Parameters + ---------- + weights : DataArray + An array of weights associated with the values in this Dataset. + Each value in the data contributes to the reduction operation + according to its associated weight. + + Notes + ----- + ``weights`` must be a DataArray and cannot contain missing values. + Missing values can be replaced by ``weights.fillna(0)``. + + Returns + ------- + core.weighted.DataArrayWeighted + + See Also + -------- + Dataset.weighted + """ + from . import weighted + + return weighted.DataArrayWeighted(self, weights) + + def rolling( + self, + dim: Mapping[Any, int] | None = None, + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + **window_kwargs: int, + ) -> DataArrayRolling: + """ + Rolling window object for DataArrays. + + Parameters + ---------- + dim : dict, optional + Mapping from the dimension name to create the rolling iterator + along (e.g. `time`) to its moving window size. 
+        min_periods : int or None, default: None
+            Minimum number of observations in window required to have a value
+            (otherwise result is NA). The default, None, is equivalent to
+            setting min_periods equal to the size of the window.
+        center : bool or Mapping to bool, default: False
+            Set the labels at the center of the window.
+        **window_kwargs : optional
+            The keyword arguments form of ``dim``.
+            One of dim or window_kwargs must be provided.
+
+        Returns
+        -------
+        core.rolling.DataArrayRolling
+
+        Examples
+        --------
+        Create rolling seasonal average of monthly data e.g. DJF, JFM, ..., SON:
+
+        >>> da = xr.DataArray(
+        ...     np.linspace(0, 11, num=12),
+        ...     coords=[
+        ...         pd.date_range(
+        ...             "1999-12-15",
+        ...             periods=12,
+        ...             freq=pd.DateOffset(months=1),
+        ...         )
+        ...     ],
+        ...     dims="time",
+        ... )
+        >>> da
+        <xarray.DataArray (time: 12)>
+        array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
+        Coordinates:
+          * time     (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15
+        >>> da.rolling(time=3, center=True).mean()
+        <xarray.DataArray (time: 12)>
+        array([nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., nan])
+        Coordinates:
+          * time     (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15
+
+        Remove the NaNs using ``dropna()``:
+
+        >>> da.rolling(time=3, center=True).mean().dropna("time")
+        <xarray.DataArray (time: 10)>
+        array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
+        Coordinates:
+          * time     (time) datetime64[ns] 2000-01-15 2000-02-15 ... 2000-10-15
+
+        See Also
+        --------
+        core.rolling.DataArrayRolling
+        Dataset.rolling
+        """
+        from . import rolling
+
+        dim = either_dict_or_kwargs(dim, window_kwargs, "rolling")
+        return rolling.DataArrayRolling(
+            self, dim, min_periods=min_periods, center=center
+        )
+
+    def coarsen(
+        self,
+        dim: Mapping[Any, int] | None = None,
+        boundary: CoarsenBoundaryOptions = "exact",
+        side: SideOptions | Mapping[Any, SideOptions] = "left",
+        coord_func: str | Callable | Mapping[Any, str | Callable] = "mean",
+        **window_kwargs: int,
+    ) -> DataArrayCoarsen:
+        """
+        Coarsen object for DataArrays.
+
+        Parameters
+        ----------
+        dim : mapping of hashable to int, optional
+            Mapping from the dimension name to the window size.
+        boundary : {"exact", "trim", "pad"}, default: "exact"
+            If 'exact', a ValueError will be raised if dimension size is not a
+            multiple of the window size. If 'trim', the excess entries are
+            dropped. If 'pad', NA will be padded.
+        side : {"left", "right"} or mapping of str to {"left", "right"}, default: "left"
+        coord_func : str or mapping of hashable to str, default: "mean"
+            function (name) that is applied to the coordinates,
+            or a mapping from coordinate name to function (name).
+
+        Returns
+        -------
+        core.rolling.DataArrayCoarsen
+
+        Examples
+        --------
+        Coarsen the long time series by averaging over every three days.
+
+        >>> da = xr.DataArray(
+        ...     np.linspace(0, 364, num=364),
+        ...     dims="time",
+        ...     coords={"time": pd.date_range("1999-12-15", periods=364)},
+        ... )
+        >>> da  # +doctest: ELLIPSIS
+        <xarray.DataArray (time: 364)>
+        array([ 0. , 1.00275482, 2.00550964, 3.00826446,
+            4.01101928, 5.0137741 , 6.01652893, 7.01928375,
+            8.02203857, 9.02479339, 10.02754821, 11.03030303,
+            ...
+            356.98071625, 357.98347107, 358.9862259 , 359.98898072,
+            360.99173554, 361.99449036, 362.99724518, 364. ])
+        Coordinates:
+          * time     (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-12-12
+        >>> da.coarsen(time=3, boundary="trim").mean()  # +doctest: ELLIPSIS
+        <xarray.DataArray (time: 121)>
+        array([ 1.00275482, 4.01101928, 7.01928375, 10.02754821,
+            13.03581267, 16.04407713, 19.0523416 , 22.06060606,
+            25.06887052, 28.07713499, 31.08539945, 34.09366391,
+            ...
+            349.96143251, 352.96969697, 355.97796143, 358.9862259 ,
+            361.99449036])
+        Coordinates:
+          * time     (time) datetime64[ns] 1999-12-16 1999-12-19 ... 2000-12-10
+        >>>
+
+        See Also
+        --------
+        core.rolling.DataArrayCoarsen
+        Dataset.coarsen
+        """
+        from . import rolling
+
+        dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen")
+        return rolling.DataArrayCoarsen(
+            self,
+            dim,
+            boundary=boundary,
+            side=side,
+            coord_func=coord_func,
+        )
+
     # this needs to be at the end, or mypy will confuse with `str`
     # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names
     str = utils.UncachedAccessor(StringAccessor["DataArray"])
diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py
index 3c6502bd275..484f4e2a8a3 100644
--- a/xarray/core/dataset.py
+++ b/xarray/core/dataset.py
@@ -43,9 +43,7 @@
     formatting_html,
     ops,
     resample,
-    rolling,
     utils,
-    weighted,
 )
 from ._reductions import DatasetReductions
 from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
@@ -107,8 +105,10 @@
     from .dataarray import DataArray
     from .groupby import DatasetGroupBy
     from .merge import CoercibleMapping
+    from .rolling import DatasetCoarsen, DatasetRolling
     from .types import (
         CFCalendar,
+        CoarsenBoundaryOptions,
         CombineAttrsOptions,
         CompatOptions,
         DatetimeUnitOptions,
@@ -121,8 +121,10 @@
         QueryEngineOptions,
         QueryParserOptions,
         ReindexMethodOptions,
+        SideOptions,
         T_Xarray,
     )
+    from .weighted import DatasetWeighted

 try:
     from dask.delayed import Delayed
@@ -563,10 +565,7 @@ class Dataset(
         "__weakref__",
     )

-    _rolling_cls = rolling.DatasetRolling
-    _coarsen_cls = rolling.DatasetCoarsen
     _resample_cls = resample.DatasetResample
-    _weighted_cls = weighted.DatasetWeighted

     def __init__(
         self,
@@ -8658,3 +8657,114 @@ def groupby_bins(
                 "include_lowest": include_lowest,
             },
         )
+
+    def weighted(self, weights: DataArray) -> DatasetWeighted:
+        """
+        Weighted Dataset operations.
+
+        Parameters
+        ----------
+        weights : DataArray
+            An array of weights associated with the values in this Dataset.
+            Each value in the data contributes to the reduction operation
+            according to its associated weight.
+
+        Notes
+        -----
+        ``weights`` must be a DataArray and cannot contain missing values.
+        Missing values can be replaced by ``weights.fillna(0)``.
+
+        Returns
+        -------
+        core.weighted.DatasetWeighted
+
+        See Also
+        --------
+        DataArray.weighted
+        """
+        from . import weighted
+
+        return weighted.DatasetWeighted(self, weights)
+
+    def rolling(
+        self,
+        dim: Mapping[Any, int] | None = None,
+        min_periods: int | None = None,
+        center: bool | Mapping[Any, bool] = False,
+        **window_kwargs: int,
+    ) -> DatasetRolling:
+        """
+        Rolling window object for Datasets.
+
+        Parameters
+        ----------
+        dim : dict, optional
+            Mapping from the dimension name to create the rolling iterator
+            along (e.g. `time`) to its moving window size.
+        min_periods : int or None, default: None
+            Minimum number of observations in window required to have a value
+            (otherwise result is NA). The default, None, is equivalent to
+            setting min_periods equal to the size of the window.
+        center : bool or Mapping to bool, default: False
+            Set the labels at the center of the window.
+        **window_kwargs : optional
+            The keyword arguments form of ``dim``.
+            One of dim or window_kwargs must be provided.
+
+        Returns
+        -------
+        core.rolling.DatasetRolling
+
+        See Also
+        --------
+        core.rolling.DatasetRolling
+        DataArray.rolling
+        """
+        from .
import rolling + + dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") + return rolling.DatasetRolling(self, dim, min_periods=min_periods, center=center) + + def coarsen( + self, + dim: Mapping[Any, int] | None = None, + boundary: CoarsenBoundaryOptions = "exact", + side: SideOptions | Mapping[Any, SideOptions] = "left", + coord_func: str | Callable | Mapping[Any, str | Callable] = "mean", + **window_kwargs: int, + ) -> DatasetCoarsen: + """ + Coarsen object for Datasets. + + Parameters + ---------- + dim : mapping of hashable to int, optional + Mapping from the dimension name to the window size. + boundary : {"exact", "trim", "pad"}, default: "exact" + If 'exact', a ValueError will be raised if dimension size is not a + multiple of the window size. If 'trim', the excess entries are + dropped. If 'pad', NA will be padded. + side : {"left", "right"} or mapping of str to {"left", "right"}, default: "left" + coord_func : str or mapping of hashable to str, default: "mean" + function (name) that is applied to the coordinates, + or a mapping from coordinate name to function (name). + + Returns + ------- + core.rolling.DatasetCoarsen + + See Also + -------- + core.rolling.DatasetCoarsen + DataArray.coarsen + """ + from . import rolling + + dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") + return rolling.DatasetCoarsen( + self, + dim, + boundary=boundary, + side=side, + coord_func=coord_func, + ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 1dfb7d76638..156109663d5 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -295,6 +295,7 @@ class GroupBy(Generic[T_Xarray]): "_stacked_dim", "_unique_coord", "_dims", + "_sizes", "_squeeze", # Save unstacked object for flox "_original_obj", diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index bfbf7c8f34e..e833cb23ace 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -3,7 +3,7 @@ import functools import itertools import warnings -from typing import Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, Mapping import numpy as np @@ -19,6 +19,10 @@ # use numpy methods instead bottleneck = None +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + from .types import CoarsenBoundaryOptions, SideOptions, T_Xarray _ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ Reduce this object's data windows by applying `{name}` along its dimension. @@ -737,7 +741,7 @@ def construct( ) -class Coarsen(CoarsenArithmetic): +class Coarsen(CoarsenArithmetic, Generic[T_Xarray]): """A object that implements the coarsen. See Also @@ -755,8 +759,16 @@ class Coarsen(CoarsenArithmetic): "trim_excess", ) _attributes = ("windows", "side", "trim_excess") + obj: T_Xarray - def __init__(self, obj, windows, boundary, side, coord_func): + def __init__( + self, + obj: T_Xarray, + windows: Mapping[Any, int], + boundary: CoarsenBoundaryOptions, + side: SideOptions | Mapping[Any, SideOptions], + coord_func: str | Callable | Mapping[Any, str | Callable], + ) -> None: """ Moving window object. @@ -767,12 +779,12 @@ def __init__(self, obj, windows, boundary, side, coord_func): windows : mapping of hashable to int A mapping from the name of the dimension to create the rolling exponential window along (e.g. `time`) to the size of the moving window. - boundary : 'exact' | 'trim' | 'pad' + boundary : {"exact", "trim", "pad"} If 'exact', a ValueError will be raised if dimension size is not a multiple of window size. If 'trim', the excess indexes are trimmed. 
             If 'pad', NA will be padded.
         side : 'left' or 'right' or mapping from dimension to 'left' or 'right'
-        coord_func : mapping from coordinate name to func.
+        coord_func : function (name) or mapping from coordinate name to function (name).

         Returns
         -------
@@ -789,11 +801,11 @@ def __init__(self, obj, windows, boundary, side, coord_func):
                 f"Dimensions {absent_dims!r} not found in {self.obj.__class__.__name__}."
             )
         if not utils.is_dict_like(coord_func):
-            coord_func = {d: coord_func for d in self.obj.dims}
+            coord_func = {d: coord_func for d in self.obj.dims}  # type: ignore[misc]
         for c in self.obj.coords:
             if c not in coord_func:
-                coord_func[c] = duck_array_ops.mean
-        self.coord_func = coord_func
+                coord_func[c] = duck_array_ops.mean  # type: ignore[index]
+        self.coord_func: Mapping[Hashable, str | Callable] = coord_func

     def _get_keep_attrs(self, keep_attrs):
         if keep_attrs is None:
@@ -801,7 +813,7 @@ def _get_keep_attrs(self, keep_attrs):

         return keep_attrs

-    def __repr__(self):
+    def __repr__(self) -> str:
         """provide a nice str repr of our coarsen object"""

         attrs = [
@@ -818,7 +830,7 @@ def construct(
         window_dim=None,
         keep_attrs=None,
         **window_dim_kwargs,
-    ):
+    ) -> T_Xarray:
         """
         Convert this Coarsen object to a DataArray or Dataset,
         where the coarsening dimension is split or reshaped to two
@@ -917,7 +929,7 @@ def construct(

         return result


-class DataArrayCoarsen(Coarsen):
+class DataArrayCoarsen(Coarsen["DataArray"]):
     __slots__ = ()

     _reduce_extra_args_docstring = """"""
@@ -925,7 +937,7 @@ class DataArrayCoarsen(Coarsen):
     @classmethod
     def _reduce_method(
         cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False
-    ):
+    ) -> Callable[..., DataArray]:
         """
         Return a wrapped function for injecting reduction methods.
         see ops.inject_reduce_methods
@@ -934,7 +946,9 @@ def _reduce_method(
         if include_skipna:
             kwargs["skipna"] = None

-        def wrapped_func(self, keep_attrs: bool = None, **kwargs):
+        def wrapped_func(
+            self: DataArrayCoarsen, keep_attrs: bool = None, **kwargs
+        ) -> DataArray:
             from .dataarray import DataArray

             keep_attrs = self._get_keep_attrs(keep_attrs)
@@ -964,7 +978,7 @@ def wrapped_func(self, keep_attrs: bool = None, **kwargs):

         return wrapped_func

-    def reduce(self, func: Callable, keep_attrs: bool = None, **kwargs):
+    def reduce(self, func: Callable, keep_attrs: bool = None, **kwargs) -> DataArray:
         """Reduce the items in this group by applying `func` along some
         dimension(s).

@@ -1001,7 +1015,7 @@ def reduce(self, func: Callable, keep_attrs: bool = None, **kwargs):

         return wrapped_func(self, keep_attrs=keep_attrs, **kwargs)


-class DatasetCoarsen(Coarsen):
+class DatasetCoarsen(Coarsen["Dataset"]):
     __slots__ = ()

     _reduce_extra_args_docstring = """"""
@@ -1009,7 +1023,7 @@ class DatasetCoarsen(Coarsen):
     @classmethod
     def _reduce_method(
         cls, func: Callable, include_skipna: bool = False, numeric_only: bool = False
-    ):
+    ) -> Callable[..., Dataset]:
         """
         Return a wrapped function for injecting reduction methods.
see ops.inject_reduce_methods @@ -1018,7 +1032,9 @@ def _reduce_method( if include_skipna: kwargs["skipna"] = None - def wrapped_func(self, keep_attrs: bool = None, **kwargs): + def wrapped_func( + self: DatasetCoarsen, keep_attrs: bool = None, **kwargs + ) -> Dataset: from .dataset import Dataset keep_attrs = self._get_keep_attrs(keep_attrs) @@ -1056,7 +1072,7 @@ def wrapped_func(self, keep_attrs: bool = None, **kwargs): return wrapped_func - def reduce(self, func: Callable, keep_attrs=None, **kwargs): + def reduce(self, func: Callable, keep_attrs=None, **kwargs) -> Dataset: """Reduce the items in this group by applying `func` along some dimension(s). diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 9fd097cd4dc..6033b061335 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -8,7 +8,7 @@ from .options import _get_keep_attrs from .pdcompat import count_not_none from .pycompat import is_duck_dask_array -from .types import T_Xarray +from .types import T_DataWithCoords def _get_alpha(com=None, span=None, halflife=None, alpha=None): @@ -76,7 +76,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -class RollingExp(Generic[T_Xarray]): +class RollingExp(Generic[T_DataWithCoords]): """ Exponentially-weighted moving window object. Similar to EWM in pandas @@ -100,16 +100,16 @@ class RollingExp(Generic[T_Xarray]): def __init__( self, - obj: T_Xarray, + obj: T_DataWithCoords, windows: Mapping[Any, int | float], window_type: str = "span", ): - self.obj: T_Xarray = obj + self.obj: T_DataWithCoords = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self, keep_attrs: bool = None) -> T_Xarray: + def mean(self, keep_attrs: bool | None = None) -> T_DataWithCoords: """ Exponentially weighted moving average. @@ -136,7 +136,7 @@ def mean(self, keep_attrs: bool = None) -> T_Xarray: move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs ) - def sum(self, keep_attrs: bool = None) -> T_Xarray: + def sum(self, keep_attrs: bool | None = None) -> T_DataWithCoords: """ Exponentially weighted moving sum. 
diff --git a/xarray/core/types.py b/xarray/core/types.py index 30e3653556a..5604c5365dd 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -90,6 +90,9 @@ "366_day", ] +CoarsenBoundaryOptions = Literal["exact", "trim", "pad"] +SideOptions = Literal["left", "right"] + # TODO: Wait until mypy supports recursive objects in combination with typevars _T = TypeVar("_T") NestedSequence = Union[ diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 2e944eab1e0..0e1d37f9a12 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -143,7 +143,7 @@ class Weighted(Generic[T_Xarray]): __slots__ = ("obj", "weights") - def __init__(self, obj: T_Xarray, weights: DataArray): + def __init__(self, obj: T_Xarray, weights: DataArray) -> None: """ Create a Weighted object @@ -525,11 +525,11 @@ def quantile( self._weighted_quantile, q=q, dim=dim, skipna=skipna, keep_attrs=keep_attrs ) - def __repr__(self): + def __repr__(self) -> str: """provide a nice str repr of our Weighted object""" klass = self.__class__.__name__ - weight_dims = ", ".join(self.weights.dims) + weight_dims = ", ".join(map(str, self.weights.dims)) return f"{klass} with weights along dimensions: {weight_dims}" diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index e465b92ccfb..b6fc21ab7b7 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -19,42 +19,42 @@ from .test_dataset import ds -def test_coarsen_absent_dims_error(ds) -> None: +def test_coarsen_absent_dims_error(ds: Dataset) -> None: with pytest.raises(ValueError, match=r"not found in Dataset."): ds.coarsen(foo=2) @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize(("boundary", "side"), [("trim", "left"), ("pad", "right")]) -def test_coarsen_dataset(ds, dask, boundary, side) -> None: +def test_coarsen_dataset(ds: Dataset, dask: bool, boundary, side) -> None: if dask and has_dask: ds = ds.chunk({"x": 4}) - actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() + actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() # type: ignore[attr-defined] assert_equal( - actual["z1"], ds["z1"].coarsen(x=3, boundary=boundary, side=side).max() + actual["z1"], ds["z1"].coarsen(x=3, boundary=boundary, side=side).max() # type: ignore[attr-defined] ) # coordinate should be mean by default assert_equal( - actual["time"], ds["time"].coarsen(time=2, boundary=boundary, side=side).mean() + actual["time"], ds["time"].coarsen(time=2, boundary=boundary, side=side).mean() # type: ignore[attr-defined] ) @pytest.mark.parametrize("dask", [True, False]) -def test_coarsen_coords(ds, dask) -> None: +def test_coarsen_coords(ds: Dataset, dask: bool) -> None: if dask and has_dask: ds = ds.chunk({"x": 4}) # check if coord_func works - actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() - assert_equal(actual["z1"], ds["z1"].coarsen(x=3, boundary="trim").max()) - assert_equal(actual["time"], ds["time"].coarsen(time=2, boundary="trim").max()) + actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() # type: ignore[attr-defined] + assert_equal(actual["z1"], ds["z1"].coarsen(x=3, boundary="trim").max()) # type: ignore[attr-defined] + assert_equal(actual["time"], ds["time"].coarsen(time=2, boundary="trim").max()) # type: ignore[attr-defined] # raise if exact with pytest.raises(ValueError): - ds.coarsen(x=3).mean() + ds.coarsen(x=3).mean() # type: ignore[attr-defined] # should be no error - ds.isel(x=slice(0, 3 * (len(ds["x"]) // 
3))).coarsen(x=3).mean() + ds.isel(x=slice(0, 3 * (len(ds["x"]) // 3))).coarsen(x=3).mean() # type: ignore[attr-defined] # working test with pd.time da = xr.DataArray( @@ -62,14 +62,14 @@ def test_coarsen_coords(ds, dask) -> None: dims="time", coords={"time": pd.date_range("1999-12-15", periods=364)}, ) - actual = da.coarsen(time=2).mean() + actual = da.coarsen(time=2).mean() # type: ignore[attr-defined] @requires_cftime def test_coarsen_coords_cftime() -> None: times = xr.cftime_range("2000", periods=6) da = xr.DataArray(range(6), [("time", times)]) - actual = da.coarsen(time=3).mean() + actual = da.coarsen(time=3).mean() # type: ignore[attr-defined] expected_times = xr.cftime_range("2000-01-02", freq="3D", periods=2) np.testing.assert_array_equal(actual.time, expected_times) @@ -159,7 +159,7 @@ def test_coarsen_keep_attrs(funcname, argument) -> None: @pytest.mark.parametrize("ds", (1, 2), indirect=True) @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "mean", "std", "var", "min", "max", "median")) -def test_coarsen_reduce(ds, window, name) -> None: +def test_coarsen_reduce(ds: Dataset, window, name) -> None: # Use boundary="trim" to accommodate all window sizes used in tests coarsen_obj = ds.coarsen(time=window, boundary="trim") @@ -253,7 +253,7 @@ def test_coarsen_da_reduce(da, window, name) -> None: @pytest.mark.parametrize("dask", [True, False]) -def test_coarsen_construct(dask) -> None: +def test_coarsen_construct(dask: bool) -> None: ds = Dataset( { From cd2f2b247cf6e1c68c7db6e92100898272fb0e1a Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 18:46:08 +0200 Subject: [PATCH 08/36] type resample --- xarray/core/common.py | 126 ++++++--------------------------------- xarray/core/dataarray.py | 126 ++++++++++++++++++++++++++++++++++++++- xarray/core/dataset.py | 88 +++++++++++++++++++++++---- xarray/core/groupby.py | 4 ++ xarray/core/resample.py | 117 ++++++++++++++++++++++-------------- 5 files changed, 293 insertions(+), 168 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index feb6724cbbb..815400f149b 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -36,14 +36,18 @@ if TYPE_CHECKING: + import datetime + from .dataarray import DataArray from .dataset import Dataset from .indexes import Index + from .resample import Resample from .rolling_exp import RollingExp - from .types import ScalarOrArray, T_DataWithCoords + from .types import ScalarOrArray, SideOptions, T_DataWithCoords from .variable import Variable +T_Resample = TypeVar("T_Resample", bound="Resample") C = TypeVar("C") T = TypeVar("T") @@ -788,111 +792,19 @@ def rolling_exp( return RollingExp(self, window, window_type) - def resample( + def _resample( self, - indexer: Mapping[Any, str] = None, - skipna=None, - closed: str = None, - label: str = None, - base: int = 0, - keep_attrs: bool = None, - loffset=None, - restore_coord_dims: bool = None, + resample_cls: type[T_Resample], + indexer: Mapping[Any, str] | None, + skipna: bool | None, + closed: SideOptions | None, + label: SideOptions | None, + base: int, + keep_attrs: bool | None, + loffset: datetime.timedelta | str | None, + restore_coord_dims: bool | None, **indexer_kwargs: str, - ): - """Returns a Resample object for performing resampling operations. - - Handles both downsampling and upsampling. The resampled - dimension must be a datetime-like coordinate. If any intervals - contain no values from the original object, they will be given - the value ``NaN``. 
- - Parameters - ---------- - indexer : {dim: freq}, optional - Mapping from the dimension name to resample frequency [1]_. The - dimension must be datetime-like. - skipna : bool, optional - Whether to skip missing values when aggregating in downsampling. - closed : {"left", "right"}, optional - Side of each interval to treat as closed. - label : {"left", "right"}, optional - Side of each interval to use for labeling. - base : int, optional - For frequencies that evenly subdivide 1 day, the "origin" of the - aggregated intervals. For example, for "24H" frequency, base could - range from 0 through 23. - loffset : timedelta or str, optional - Offset used to adjust the resampled time labels. Some pandas date - offset strings are supported. - restore_coord_dims : bool, optional - If True, also restore the dimension order of multi-dimensional - coordinates. - **indexer_kwargs : {dim: freq} - The keyword arguments form of ``indexer``. - One of indexer or indexer_kwargs must be provided. - - Returns - ------- - resampled : same type as caller - This object resampled. - - Examples - -------- - Downsample monthly time-series data to seasonal data: - - >>> da = xr.DataArray( - ... np.linspace(0, 11, num=12), - ... coords=[ - ... pd.date_range( - ... "1999-12-15", - ... periods=12, - ... freq=pd.DateOffset(months=1), - ... ) - ... ], - ... dims="time", - ... ) - >>> da - - array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) - Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 - >>> da.resample(time="QS-DEC").mean() - - array([ 1., 4., 7., 10.]) - Coordinates: - * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 - - Upsample monthly time-series data to daily data: - - >>> da.resample(time="1D").interpolate("linear") # +doctest: ELLIPSIS - - array([ 0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, - 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, - 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , - ... - 10.80645161, 10.83870968, 10.87096774, 10.90322581, 10.93548387, - 10.96774194, 11. ]) - Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 - - Limit scope of upsampling method - - >>> da.resample(time="1D").nearest(tolerance="1D") - - array([ 0., 0., nan, ..., nan, 11., 11.]) - Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 - - See Also - -------- - pandas.Series.resample - pandas.DataFrame.resample - - References - ---------- - .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases - """ + ) -> T_Resample: # TODO support non-string indexer after removing the old API. 
from ..coding.cftimeindex import CFTimeIndex @@ -923,7 +835,7 @@ def resample( raise ValueError("Resampling only supported along single dimensions.") dim, freq = next(iter(indexer.items())) - dim_name = dim + dim_name: Hashable = dim dim_coord = self[dim] # TODO: remove once pandas=1.1 is the minimum required version @@ -945,7 +857,7 @@ def resample( group = DataArray( dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM ) - resampler = self._resample_cls( + return resample_cls( self, group=group, dim=dim_name, @@ -954,8 +866,6 @@ def resample( restore_coord_dims=restore_coord_dims, ) - return resampler - def where( self: T_DataWithCoords, cond: Any, other: Any = dtypes.NA, drop: bool = False ) -> T_DataWithCoords: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 0cb72cd2f54..8d69e9dd8ab 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -24,7 +24,7 @@ from ..coding.cftimeindex import CFTimeIndex from ..plot.plot import _PlotMethods from ..plot.utils import _get_units_from_attrs -from . import alignment, computation, dtypes, indexing, ops, resample, utils +from . import alignment, computation, dtypes, indexing, ops, utils from ._reductions import DataArrayReductions from .accessor_dt import CombinedDatetimelikeAccessor from .accessor_str import StringAccessor @@ -73,6 +73,7 @@ from ..backends.api import T_NetcdfEngine, T_NetcdfTypes from .groupby import DataArrayGroupBy + from .resample import DataArrayResample from .rolling import DataArrayCoarsen, DataArrayRolling from .types import ( CoarsenBoundaryOptions, @@ -361,8 +362,6 @@ class DataArray( "__weakref__", ) - _resample_cls = resample.DataArrayResample - dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor["DataArray"]) def __init__( @@ -5637,6 +5636,127 @@ def coarsen( coord_func=coord_func, ) + def resample( + self, + indexer: Mapping[Any, str] | None = None, + skipna: bool | None = None, + closed: SideOptions | None = None, + label: SideOptions | None = None, + base: int = 0, + keep_attrs: bool | None = None, + loffset: datetime.timedelta | str | None = None, + restore_coord_dims: bool | None = None, + **indexer_kwargs: str, + ) -> DataArrayResample: + """Returns a Resample object for performing resampling operations. + + Handles both downsampling and upsampling. The resampled + dimension must be a datetime-like coordinate. If any intervals + contain no values from the original object, they will be given + the value ``NaN``. + + Parameters + ---------- + indexer : Mapping of Hashable to str, optional + Mapping from the dimension name to resample frequency [1]_. The + dimension must be datetime-like. + skipna : bool, optional + Whether to skip missing values when aggregating in downsampling. + closed : {"left", "right"}, optional + Side of each interval to treat as closed. + label : {"left", "right"}, optional + Side of each interval to use for labeling. + base : int, default = 0 + For frequencies that evenly subdivide 1 day, the "origin" of the + aggregated intervals. For example, for "24H" frequency, base could + range from 0 through 23. + loffset : timedelta or str, optional + Offset used to adjust the resampled time labels. Some pandas date + offset strings are supported. + restore_coord_dims : bool, optional + If True, also restore the dimension order of multi-dimensional + coordinates. + **indexer_kwargs : str + The keyword arguments form of ``indexer``. + One of indexer or indexer_kwargs must be provided. 
+ + Returns + ------- + resampled : core.resample.DataArrayResample + This object resampled. + + Examples + -------- + Downsample monthly time-series data to seasonal data: + + >>> da = xr.DataArray( + ... np.linspace(0, 11, num=12), + ... coords=[ + ... pd.date_range( + ... "1999-12-15", + ... periods=12, + ... freq=pd.DateOffset(months=1), + ... ) + ... ], + ... dims="time", + ... ) + >>> da + + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + Coordinates: + * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 + >>> da.resample(time="QS-DEC").mean() + + array([ 1., 4., 7., 10.]) + Coordinates: + * time (time) datetime64[ns] 1999-12-01 2000-03-01 2000-06-01 2000-09-01 + + Upsample monthly time-series data to daily data: + + >>> da.resample(time="1D").interpolate("linear") # +doctest: ELLIPSIS + + array([ 0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, + 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, + 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , + ... + 10.80645161, 10.83870968, 10.87096774, 10.90322581, 10.93548387, + 10.96774194, 11. ]) + Coordinates: + * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 + + Limit scope of upsampling method + + >>> da.resample(time="1D").nearest(tolerance="1D") + + array([ 0., 0., nan, ..., nan, 11., 11.]) + Coordinates: + * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 + + See Also + -------- + Dataset.resample + pandas.Series.resample + pandas.DataFrame.resample + + References + ---------- + .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases + """ + from . import resample + + return self._resample( + resample_cls=resample.DataArrayResample, + indexer=indexer, + skipna=skipna, + closed=closed, + label=label, + base=base, + keep_attrs=keep_attrs, + loffset=loffset, + restore_coord_dims=restore_coord_dims, + **indexer_kwargs, + ) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names str = utils.UncachedAccessor(StringAccessor["DataArray"]) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 484f4e2a8a3..475a908345e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -35,16 +35,7 @@ from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings from ..plot.dataset_plot import _Dataset_PlotMethods -from . import ( - alignment, - dtypes, - duck_array_ops, - formatting, - formatting_html, - ops, - resample, - utils, -) +from . 
import alignment, dtypes, duck_array_ops, formatting, formatting_html, ops, utils
 from ._reductions import DatasetReductions
 from .alignment import _broadcast_helper, _get_broadcast_dims_map_common_coords, align
 from .arithmetic import DatasetArithmetic
@@ -105,6 +96,7 @@
     from .dataarray import DataArray
     from .groupby import DatasetGroupBy
     from .merge import CoercibleMapping
+    from .resample import DatasetResample
     from .rolling import DatasetCoarsen, DatasetRolling
     from .types import (
         CFCalendar,
@@ -565,8 +557,6 @@ class Dataset(
         "__weakref__",
     )

-    _resample_cls = resample.DatasetResample
-
     def __init__(
         self,
         # could make a VariableArgs to use more generally, and refine these
@@ -8768,3 +8758,77 @@ def coarsen(
             side=side,
             coord_func=coord_func,
         )
+
+    def resample(
+        self,
+        indexer: Mapping[Any, str] | None = None,
+        skipna: bool | None = None,
+        closed: SideOptions | None = None,
+        label: SideOptions | None = None,
+        base: int = 0,
+        keep_attrs: bool | None = None,
+        loffset: datetime.timedelta | str | None = None,
+        restore_coord_dims: bool | None = None,
+        **indexer_kwargs: str,
+    ) -> DatasetResample:
+        """Returns a Resample object for performing resampling operations.
+
+        Handles both downsampling and upsampling. The resampled
+        dimension must be a datetime-like coordinate. If any intervals
+        contain no values from the original object, they will be given
+        the value ``NaN``.
+
+        Parameters
+        ----------
+        indexer : Mapping of Hashable to str, optional
+            Mapping from the dimension name to resample frequency [1]_. The
+            dimension must be datetime-like.
+        skipna : bool, optional
+            Whether to skip missing values when aggregating in downsampling.
+        closed : {"left", "right"}, optional
+            Side of each interval to treat as closed.
+        label : {"left", "right"}, optional
+            Side of each interval to use for labeling.
+        base : int, default = 0
+            For frequencies that evenly subdivide 1 day, the "origin" of the
+            aggregated intervals. For example, for "24H" frequency, base could
+            range from 0 through 23.
+        loffset : timedelta or str, optional
+            Offset used to adjust the resampled time labels. Some pandas date
+            offset strings are supported.
+        restore_coord_dims : bool, optional
+            If True, also restore the dimension order of multi-dimensional
+            coordinates.
+        **indexer_kwargs : str
+            The keyword arguments form of ``indexer``.
+            One of indexer or indexer_kwargs must be provided.
+
+        Returns
+        -------
+        resampled : core.resample.DatasetResample
+            This object resampled.
+
+        See Also
+        --------
+        DataArray.resample
+        pandas.Series.resample
+        pandas.DataFrame.resample
+
+        References
+        ----------
+        .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
+        """
+        from .
import resample + + return self._resample( + resample_cls=resample.DatasetResample, + indexer=indexer, + skipna=skipna, + closed=closed, + label=label, + base=base, + keep_attrs=keep_attrs, + loffset=loffset, + restore_coord_dims=restore_coord_dims, + **indexer_kwargs, + ) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 156109663d5..f0c165bf427 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -196,6 +196,10 @@ def ndim(self) -> Literal[1]: def values(self) -> range: return range(self.size) + @property + def data(self) -> range: + return range(self.size) + @property def shape(self) -> tuple[int]: return (self.size,) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index fb6082dc7a9..f624da4b4ff 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,17 +1,22 @@ from __future__ import annotations import warnings -from typing import Any, Callable, Hashable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Sequence import numpy as np from ._reductions import DataArrayResampleReductions, DatasetResampleReductions from .groupby import DataArrayGroupByBase, DatasetGroupByBase, GroupBy +from .types import T_Xarray + +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset RESAMPLE_DIM = "__resample_dim__" -class Resample(GroupBy): +class Resample(GroupBy[T_Xarray]): """An object that extends the `GroupBy` object with additional logic for handling specialized re-sampling operations. @@ -25,7 +30,36 @@ class Resample(GroupBy): """ - def _flox_reduce(self, dim, **kwargs): + def __init__( + self, + *args, + dim: Hashable | None = None, + resample_dim: Hashable | None = None, + **kwargs, + ) -> None: + + if dim == resample_dim: + raise ValueError( + "Proxy resampling dimension ('{}') " + "cannot have the same name as actual dimension " + "('{}')! ".format(resample_dim, dim) + ) + self._dim = dim + self._resample_dim = resample_dim + + super().__init__(*args, **kwargs) + + def mean( + self, + dim: None | Hashable | Sequence[Hashable] = None, + *, + skipna: bool | None = None, + keep_attrs: bool | None = None, + **kwargs: Any, + ) -> T_Xarray: + raise NotImplementedError() + + def _flox_reduce(self, dim, keep_attrs: bool | None = None, **kwargs) -> T_Xarray: from .dataarray import DataArray @@ -34,6 +68,7 @@ def _flox_reduce(self, dim, **kwargs): # now create a label DataArray since resample doesn't do that somehow repeats = [] for slicer in self._group_indices: + assert isinstance(slicer, slice) stop = ( slicer.stop if slicer.stop is not None @@ -43,12 +78,14 @@ def _flox_reduce(self, dim, **kwargs): labels = np.repeat(self._unique_coord.data, repeats) group = DataArray(labels, dims=(self._group_dim,), name=self._unique_coord.name) - result = super()._flox_reduce(dim=dim, group=group, **kwargs) + result = super()._flox_reduce( + dim=dim, group=group, keep_attrs=keep_attrs, **kwargs + ) result = self._maybe_restore_empty_groups(result) result = result.rename({RESAMPLE_DIM: self._group_dim}) return result - def _upsample(self, method, *args, **kwargs): + def _upsample(self, method, *args, **kwargs) -> T_Xarray: """Dispatch function to call appropriate up-sampling methods on data. 
@@ -84,8 +121,12 @@ def _upsample(self, method, *args, **kwargs): elif method in ["pad", "ffill", "backfill", "bfill", "nearest"]: kwargs = kwargs.copy() - kwargs.update(**{self._dim: upsampled_index}) - return self._obj.reindex(method=method, *args, **kwargs) + if not isinstance(self._dim, str): + raise NotImplementedError( + "Only str dimensions supported by now. Please raise an issue on Github." + ) + kwargs[self._dim] = upsampled_index + return self._obj.reindex(method=method, *args, **kwargs) # type: ignore[misc] elif method == "interpolate": return self._interpolate(*args, **kwargs) @@ -96,13 +137,13 @@ def _upsample(self, method, *args, **kwargs): '"asfreq", "ffill", "bfill", or "interpolate"'.format(method) ) - def asfreq(self): + def asfreq(self) -> T_Xarray: """Return values of original object at the new up-sampling frequency; essentially a re-index with new times set to NaN. """ return self._upsample("asfreq") - def pad(self, tolerance=None): + def pad(self, tolerance=None) -> T_Xarray: """Forward fill new values at up-sampled frequency. Parameters @@ -119,7 +160,7 @@ def pad(self, tolerance=None): ffill = pad - def backfill(self, tolerance=None): + def backfill(self, tolerance=None) -> T_Xarray: """Backward fill new values at up-sampled frequency. Parameters @@ -136,7 +177,7 @@ def backfill(self, tolerance=None): bfill = backfill - def nearest(self, tolerance=None): + def nearest(self, tolerance=None) -> T_Xarray: """Take new values from nearest original coordinate to up-sampled frequency coordinates. @@ -152,7 +193,7 @@ def nearest(self, tolerance=None): """ return self._upsample("nearest", tolerance=tolerance) - def interpolate(self, kind="linear"): + def interpolate(self, kind="linear") -> T_Xarray: """Interpolate up-sampled data using the original data as knots. @@ -169,7 +210,7 @@ def interpolate(self, kind="linear"): """ return self._interpolate(kind=kind) - def _interpolate(self, kind="linear"): + def _interpolate(self, kind="linear") -> T_Xarray: """Apply scipy.interpolate.interp1d along resampling dimension.""" # drop any existing non-dimension coordinates along the resampling # dimension @@ -178,33 +219,26 @@ def _interpolate(self, kind="linear"): if k != self._dim and self._dim in v.dims: dummy = dummy.drop_vars(k) return dummy.interp( + coords={self._dim: self._full_index}, assume_sorted=True, method=kind, kwargs={"bounds_error": False}, - **{self._dim: self._full_index}, ) # https://github.com/python/mypy/issues/9031 -class DataArrayResample(Resample, DataArrayGroupByBase, DataArrayResampleReductions): # type: ignore[misc] +class DataArrayResample(Resample["DataArray"], DataArrayGroupByBase, DataArrayResampleReductions): # type: ignore[misc] """DataArrayGroupBy object specialized to time resampling operations over a specified dimension """ - def __init__(self, *args, dim=None, resample_dim=None, **kwargs): - - if dim == resample_dim: - raise ValueError( - "Proxy resampling dimension ('{}') " - "cannot have the same name as actual dimension " - "('{}')! ".format(resample_dim, dim) - ) - self._dim = dim - self._resample_dim = resample_dim - - super().__init__(*args, **kwargs) - - def map(self, func, shortcut=False, args=(), **kwargs): + def map( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = False, + **kwargs: Any, + ) -> DataArray: """Apply a function to each array in the group and concatenate them together into a new array. 
@@ -278,23 +312,16 @@ def apply(self, func, args=(), shortcut=None, **kwargs): # https://github.com/python/mypy/issues/9031 -class DatasetResample(Resample, DatasetGroupByBase, DatasetResampleReductions): # type: ignore[misc] +class DatasetResample(Resample["Dataset"], DatasetGroupByBase, DatasetResampleReductions): # type: ignore[misc] """DatasetGroupBy object specialized to resampling a specified dimension""" - def __init__(self, *args, dim=None, resample_dim=None, **kwargs): - - if dim == resample_dim: - raise ValueError( - "Proxy resampling dimension ('{}') " - "cannot have the same name as actual dimension " - "('{}')! ".format(resample_dim, dim) - ) - self._dim = dim - self._resample_dim = resample_dim - - super().__init__(*args, **kwargs) - - def map(self, func, args=(), shortcut=None, **kwargs): + def map( + self, + func: Callable[..., Any], + args: tuple[Any, ...] = (), + shortcut: bool | None = None, + **kwargs: Any, + ) -> Dataset: """Apply a function over each Dataset in the groups generated for resampling and concatenate them together into a new Dataset. @@ -356,7 +383,7 @@ def reduce( keepdims: bool = False, shortcut: bool = True, **kwargs: Any, - ): + ) -> Dataset: """Reduce the items in this group by applying `func` along the pre-defined resampling dimension. From 0f7c3da9b8167d99de6729c0228a754ac51f2d0b Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 18:55:28 +0200 Subject: [PATCH 09/36] fix import in TYPE_CHECKING --- xarray/core/rolling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index e833cb23ace..9b9f1221020 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -11,6 +11,7 @@ from .arithmetic import CoarsenArithmetic from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array +from .types import CoarsenBoundaryOptions, SideOptions, T_Xarray from .utils import either_dict_or_kwargs try: @@ -22,7 +23,6 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import CoarsenBoundaryOptions, SideOptions, T_Xarray _ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ Reduce this object's data windows by applying `{name}` along its dimension. From 5d432858b1d29e98ac8420b5c1c170bbfa243684 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 18:59:42 +0200 Subject: [PATCH 10/36] add new typing to whats-new --- doc/whats-new.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3c53d3bfb04..9acc6004e0d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,6 +22,11 @@ v2022.06.0 (unreleased) New Features ~~~~~~~~~~~~ +- Initial typing support for :py:meth:`groupby`, :py:meth:`rolling`, :py:meth:`rolling_exp`, + :py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`, + (:pull:`6702`) + By `Michael Niklas `_. + Deprecations ~~~~~~~~~~~~ From 9200673b78218cb4733fd6867ae50f5404518bb4 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 19:02:08 +0200 Subject: [PATCH 11/36] fix some import issues --- xarray/core/groupby.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index f0c165bf427..aa7c8a7de17 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -93,7 +93,7 @@ def _dummy_copy(xarray_obj): from . 
import dataarray, dataset if isinstance(xarray_obj, dataset.Dataset): - res = Dataset( + res = dataset.Dataset( { k: dtypes.get_fill_value(v.dtype) for k, v in xarray_obj.data_vars.items() @@ -106,7 +106,7 @@ def _dummy_copy(xarray_obj): xarray_obj.attrs, ) elif isinstance(xarray_obj, dataarray.DataArray): - res = DataArray( + res = dataarray.DataArray( dtypes.get_fill_value(xarray_obj.dtype), { k: dtypes.get_fill_value(v.dtype) From 64505af5938003a8f82e5182d431cd9aa03cc6c6 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 19:09:06 +0200 Subject: [PATCH 12/36] fix some import issues --- xarray/core/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 815400f149b..f83a9cac8d0 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -780,6 +780,7 @@ def rolling_exp( -------- core.rolling_exp.RollingExp """ + from . import rolling_exp if "keep_attrs" in window_kwargs: warnings.warn( @@ -790,7 +791,7 @@ def rolling_exp( window = either_dict_or_kwargs(window, window_kwargs, "rolling_exp") - return RollingExp(self, window, window_type) + return rolling_exp.RollingExp(self, window, window_type) def _resample( self, From 345a056faf5f0a3bea529abe1e5b91dffc1d8657 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 19:17:29 +0200 Subject: [PATCH 13/36] more fixes --- xarray/core/groupby.py | 6 ++++-- xarray/core/resample.py | 12 +----------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index aa7c8a7de17..3a69f7ad0a6 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -217,10 +217,12 @@ def _ensure_1d( group: T_Group, obj: T_Xarray ) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]: # 1D cases: do nothing + from . import dataarray + if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: return group, obj, None, [] - if isinstance(group, DataArray): + if isinstance(group, dataarray.DataArray): # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) @@ -623,7 +625,7 @@ def _binary_op(self, other, f, reflexive=False): if set(obj[var].dims) < set(group.dims): result[var] = obj[var].reset_coords(drop=True).broadcast_like(group) - if isinstance(result, Dataset) and isinstance(obj, Dataset): + if isinstance(result, dataset.Dataset) and isinstance(obj, dataset.Dataset): for var in set(result): if dim not in obj[var].dims: result[var] = result[var].transpose(dim, ...) 
diff --git a/xarray/core/resample.py b/xarray/core/resample.py index f624da4b4ff..e954267988e 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -49,16 +49,6 @@ def __init__( super().__init__(*args, **kwargs) - def mean( - self, - dim: None | Hashable | Sequence[Hashable] = None, - *, - skipna: bool | None = None, - keep_attrs: bool | None = None, - **kwargs: Any, - ) -> T_Xarray: - raise NotImplementedError() - def _flox_reduce(self, dim, keep_attrs: bool | None = None, **kwargs) -> T_Xarray: from .dataarray import DataArray @@ -117,7 +107,7 @@ def _upsample(self, method, *args, **kwargs) -> T_Xarray: self._obj = self._obj.drop_vars(k) if method == "asfreq": - return self.mean(self._dim) + return self.mean(self._dim) # type: ignore[attr-defined] elif method in ["pad", "ffill", "backfill", "bfill", "nearest"]: kwargs = kwargs.copy() From 89f4b88434d8bc214a3b12f5b12a556cae789036 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Fri, 17 Jun 2022 19:22:53 +0200 Subject: [PATCH 14/36] add GroupKey synonym for Any --- xarray/core/groupby.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 3a69f7ad0a6..57f1b943f27 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -45,6 +45,8 @@ from .dataset import Dataset from .utils import Frozen + GroupKey = Any + def check_reduce_dims(reduce_dims, dimensions): @@ -455,7 +457,7 @@ def __init__( self._squeeze = squeeze # cached attributes - self._groups: dict[Any, slice | int | list[int]] | None = None + self._groups: dict[GroupKey, slice | int | list[int]] | None = None self._dims = None self._sizes: Frozen[Hashable, int] | None = None @@ -500,7 +502,7 @@ def reduce( raise NotImplementedError() @property - def groups(self) -> dict[Any, slice | int | list[int]]: + def groups(self) -> dict[GroupKey, slice | int | list[int]]: """ Mapping from group labels to indices. The indices can be used to index the underlying object. """ @@ -509,7 +511,7 @@ def groups(self) -> dict[Any, slice | int | list[int]]: self._groups = dict(zip(self._unique_coord.values, self._group_indices)) return self._groups - def __getitem__(self, key: Any) -> T_Xarray: + def __getitem__(self, key: GroupKey) -> T_Xarray: """ Get DataArray or Dataset corresponding to a particular group label. 
""" @@ -518,7 +520,7 @@ def __getitem__(self, key: Any) -> T_Xarray: def __len__(self) -> int: return self._unique_coord.size - def __iter__(self) -> Iterator[tuple[Any, T_Xarray]]: + def __iter__(self) -> Iterator[tuple[GroupKey, T_Xarray]]: return zip(self._unique_coord.values, self._iter_grouped()) def __repr__(self) -> str: From f8b6afa0400a8c675ae3f38452f2b13427b8829e Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 18 Jun 2022 11:50:02 +0200 Subject: [PATCH 15/36] no type checking for coarsen reductions --- xarray/tests/test_coarsen.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_coarsen.py b/xarray/tests/test_coarsen.py index b6fc21ab7b7..197d2db1f60 100644 --- a/xarray/tests/test_coarsen.py +++ b/xarray/tests/test_coarsen.py @@ -26,35 +26,35 @@ def test_coarsen_absent_dims_error(ds: Dataset) -> None: @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize(("boundary", "side"), [("trim", "left"), ("pad", "right")]) -def test_coarsen_dataset(ds: Dataset, dask: bool, boundary, side) -> None: +def test_coarsen_dataset(ds, dask, boundary, side): if dask and has_dask: ds = ds.chunk({"x": 4}) - actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() # type: ignore[attr-defined] + actual = ds.coarsen(time=2, x=3, boundary=boundary, side=side).max() assert_equal( - actual["z1"], ds["z1"].coarsen(x=3, boundary=boundary, side=side).max() # type: ignore[attr-defined] + actual["z1"], ds["z1"].coarsen(x=3, boundary=boundary, side=side).max() ) # coordinate should be mean by default assert_equal( - actual["time"], ds["time"].coarsen(time=2, boundary=boundary, side=side).mean() # type: ignore[attr-defined] + actual["time"], ds["time"].coarsen(time=2, boundary=boundary, side=side).mean() ) @pytest.mark.parametrize("dask", [True, False]) -def test_coarsen_coords(ds: Dataset, dask: bool) -> None: +def test_coarsen_coords(ds, dask): if dask and has_dask: ds = ds.chunk({"x": 4}) # check if coord_func works - actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() # type: ignore[attr-defined] - assert_equal(actual["z1"], ds["z1"].coarsen(x=3, boundary="trim").max()) # type: ignore[attr-defined] - assert_equal(actual["time"], ds["time"].coarsen(time=2, boundary="trim").max()) # type: ignore[attr-defined] + actual = ds.coarsen(time=2, x=3, boundary="trim", coord_func={"time": "max"}).max() + assert_equal(actual["z1"], ds["z1"].coarsen(x=3, boundary="trim").max()) + assert_equal(actual["time"], ds["time"].coarsen(time=2, boundary="trim").max()) # raise if exact with pytest.raises(ValueError): - ds.coarsen(x=3).mean() # type: ignore[attr-defined] + ds.coarsen(x=3).mean() # should be no error - ds.isel(x=slice(0, 3 * (len(ds["x"]) // 3))).coarsen(x=3).mean() # type: ignore[attr-defined] + ds.isel(x=slice(0, 3 * (len(ds["x"]) // 3))).coarsen(x=3).mean() # working test with pd.time da = xr.DataArray( @@ -62,14 +62,14 @@ def test_coarsen_coords(ds: Dataset, dask: bool) -> None: dims="time", coords={"time": pd.date_range("1999-12-15", periods=364)}, ) - actual = da.coarsen(time=2).mean() # type: ignore[attr-defined] + actual = da.coarsen(time=2).mean() @requires_cftime -def test_coarsen_coords_cftime() -> None: +def test_coarsen_coords_cftime(): times = xr.cftime_range("2000", periods=6) da = xr.DataArray(range(6), [("time", times)]) - actual = da.coarsen(time=3).mean() # type: ignore[attr-defined] + actual = da.coarsen(time=3).mean() expected_times = xr.cftime_range("2000-01-02", 
freq="3D", periods=2) np.testing.assert_array_equal(actual.time, expected_times) From 2a6402269d4e646794b9254b40902cc7ea983d3b Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 18 Jun 2022 11:53:14 +0200 Subject: [PATCH 16/36] remove type: ignore --- xarray/core/resample.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index e954267988e..80be785e4f8 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -116,7 +116,8 @@ def _upsample(self, method, *args, **kwargs) -> T_Xarray: "Only str dimensions supported by now. Please raise an issue on Github." ) kwargs[self._dim] = upsampled_index - return self._obj.reindex(method=method, *args, **kwargs) # type: ignore[misc] + kwargs["method"] = method + return self._obj.reindex(*args, **kwargs) elif method == "interpolate": return self._interpolate(*args, **kwargs) From 7034563fdd31ebdc5136c392248d14bd4cbd4ee1 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 18 Jun 2022 11:55:48 +0200 Subject: [PATCH 17/36] remove some more type: ignore --- xarray/core/dataset.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 475a908345e..3cbf9ae7ee6 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5532,18 +5532,21 @@ def reduce( or np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_) ): + reduce_maybe_single: Hashable | None | list[Hashable] if len(reduce_dims) == 1: # unpack dimensions for the benefit of functions # like np.argmin which can't handle tuple arguments - (reduce_dims,) = reduce_dims # type: ignore[assignment] + (reduce_maybe_single,) = reduce_dims elif len(reduce_dims) == var.ndim: # prefer to aggregate over axis=None rather than # axis=(0, 1) if they will be equivalent, because # the former is often more efficient - reduce_dims = None # type: ignore[assignment] + reduce_maybe_single = None + else: + reduce_maybe_single = reduce_dims variables[name] = var.reduce( func, - dim=reduce_dims, + dim=reduce_maybe_single, keep_attrs=keep_attrs, keepdims=keepdims, **kwargs, From 2564d571ef88eb3757b6f70929ad7bd220b7af94 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 18 Jun 2022 12:03:44 +0200 Subject: [PATCH 18/36] selected imports --- xarray/core/dataarray.py | 18 ++++++++---------- xarray/core/dataset.py | 16 ++++++++-------- xarray/core/groupby.py | 29 ++++++++++++++--------------- 3 files changed, 30 insertions(+), 33 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 8d69e9dd8ab..19b1ddeb154 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -5482,9 +5482,9 @@ def weighted(self, weights: DataArray) -> DataArrayWeighted: -------- Dataset.weighted """ - from . import weighted + from .weighted import DataArrayWeighted - return weighted.DataArrayWeighted(self, weights) + return DataArrayWeighted(self, weights) def rolling( self, @@ -5554,12 +5554,10 @@ def rolling( core.rolling.DataArrayRolling Dataset.rolling """ - from . import rolling + from .rolling import DataArrayRolling dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") - return rolling.DataArrayRolling( - self, dim, min_periods=min_periods, center=center - ) + return DataArrayRolling(self, dim, min_periods=min_periods, center=center) def coarsen( self, @@ -5625,10 +5623,10 @@ def coarsen( core.rolling.DataArrayCoarsen Dataset.coarsen """ - from . 
import rolling + from .rolling import DataArrayCoarsen dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") - return rolling.DataArrayCoarsen( + return DataArrayCoarsen( self, dim, boundary=boundary, @@ -5742,10 +5740,10 @@ def resample( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases """ - from . import resample + from .resample import DataArrayResample return self._resample( - resample_cls=resample.DataArrayResample, + resample_cls=DataArrayResample, indexer=indexer, skipna=skipna, closed=closed, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3cbf9ae7ee6..4f714f43252 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8675,9 +8675,9 @@ def weighted(self, weights: DataArray) -> DatasetWeighted: -------- DataArray.weighted """ - from . import weighted + from .weighted import DatasetWeighted - return weighted.DatasetWeighted(self, weights) + return DatasetWeighted(self, weights) def rolling( self, @@ -8713,10 +8713,10 @@ def rolling( core.rolling.DatasetRolling DataArray.rolling """ - from . import rolling + from .rolling import DatasetRolling dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") - return rolling.DatasetRolling(self, dim, min_periods=min_periods, center=center) + return DatasetRolling(self, dim, min_periods=min_periods, center=center) def coarsen( self, @@ -8751,10 +8751,10 @@ def coarsen( core.rolling.DatasetCoarsen DataArray.coarsen """ - from . import rolling + from .rolling import DatasetCoarsen dim = either_dict_or_kwargs(dim, window_kwargs, "coarsen") - return rolling.DatasetCoarsen( + return DatasetCoarsen( self, dim, boundary=boundary, @@ -8821,10 +8821,10 @@ def resample( ---------- .. [1] http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases """ - from . import resample + from .resample import DatasetResample return self._resample( - resample_cls=resample.DatasetResample, + resample_cls=DatasetResample, indexer=indexer, skipna=skipna, closed=closed, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 57f1b943f27..05f8eb770b3 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -92,10 +92,11 @@ def unique_value_groups( def _dummy_copy(xarray_obj): - from . import dataarray, dataset + from .dataarray import DataArray + from .dataset import Dataset - if isinstance(xarray_obj, dataset.Dataset): - res = dataset.Dataset( + if isinstance(xarray_obj, Dataset): + res = Dataset( { k: dtypes.get_fill_value(v.dtype) for k, v in xarray_obj.data_vars.items() @@ -107,8 +108,8 @@ def _dummy_copy(xarray_obj): }, xarray_obj.attrs, ) - elif isinstance(xarray_obj, dataarray.DataArray): - res = dataarray.DataArray( + elif isinstance(xarray_obj, DataArray): + res = DataArray( dtypes.get_fill_value(xarray_obj.dtype), { k: dtypes.get_fill_value(v.dtype) @@ -219,12 +220,12 @@ def _ensure_1d( group: T_Group, obj: T_Xarray ) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]: # 1D cases: do nothing - from . import dataarray + from .dataarray import DataArray if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1: return group, obj, None, [] - if isinstance(group, dataarray.DataArray): + if isinstance(group, DataArray): # try to stack the dims of the group into a single dim orig_dims = group.dims stacked_dim = "stacked_" + "_".join(map(str, orig_dims)) @@ -564,7 +565,8 @@ def _infer_concat_args(self, applied_example): return coord, dim, positions def _binary_op(self, other, f, reflexive=False): - from . 
import dataarray, dataset + from .dataarray import DataArray + from .dataset import Dataset g = f if not reflexive else lambda x, y: f(y, x) @@ -575,7 +577,7 @@ def _binary_op(self, other, f, reflexive=False): group = obj[dim] name = group.name - if not isinstance(other, (dataset.Dataset, dataarray.DataArray)): + if not isinstance(other, (Dataset, DataArray)): raise TypeError( "GroupBy objects only support binary ops " "when the other argument is a Dataset or " @@ -627,7 +629,7 @@ def _binary_op(self, other, f, reflexive=False): if set(obj[var].dims) < set(group.dims): result[var] = obj[var].reset_coords(drop=True).broadcast_like(group) - if isinstance(result, dataset.Dataset) and isinstance(obj, dataset.Dataset): + if isinstance(result, Dataset) and isinstance(obj, Dataset): for var in set(result): if dim not in obj[var].dims: result[var] = result[var].transpose(dim, ...) @@ -657,7 +659,7 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs): """Adaptor function that translates our groupby API to that of flox.""" from flox.xarray import xarray_reduce - from . import dataset + from .dataset import Dataset obj = self._original_obj @@ -767,10 +769,7 @@ def _flox_reduce(self, dim, keep_attrs=None, **kwargs): # Fix dimension order when binning a dimension coordinate # Needed as long as we do a separate code path for pint; # For some reason Datasets and DataArrays behave differently! - if ( - isinstance(self._obj, dataset.Dataset) - and self._group_dim in self._obj.dims - ): + if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims: result = result.transpose(self._group.name, ...) return result From 484d303318b8c381f1a3924f1a7612c5896ab248 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 18 Jun 2022 16:43:13 +0200 Subject: [PATCH 19/36] minor typing and docstring improvements --- xarray/core/dataarray.py | 10 +++++----- xarray/core/dataset.py | 4 ++-- xarray/core/groupby.py | 7 ++++--- xarray/core/weighted.py | 6 +++--- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 19b1ddeb154..43bd018aea7 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -5315,7 +5315,7 @@ def groupby( ---------- group : Hashable, DataArray or IndexVariable Array whose unique values should be used to group this array. If a - string, must be the name of a variable contained in this dataset. + Hashable, must be the name of a coordinate contained in this dataarray. squeeze : bool, default: True If "group" is a dimension of any arrays in this dataset, `squeeze` controls whether the subarrays have a dimension of length 1 along @@ -5379,9 +5379,9 @@ def groupby( def groupby_bins( self, group: Hashable | DataArray | IndexVariable, - bins, + bins: ArrayLike, right: bool = True, - labels=None, + labels: ArrayLike | Literal[False] | None = None, precision: int = 3, include_lowest: bool = False, squeeze: bool = True, @@ -5396,7 +5396,7 @@ def groupby_bins( ---------- group : Hashable, DataArray or IndexVariable Array whose binned values should be used to group this array. If a - string, must be the name of a variable contained in this dataset. + Hashable, must be the name of a coordinate contained in this dataarray. bins : int or array-like If bins is an int, it defines the number of equal-width bins in the range of x. However, in this case, the range of x is extended by .1% @@ -5407,7 +5407,7 @@ def groupby_bins( Indicates whether the bins include the rightmost edge or not. 
If right == True (the default), then the bins [1,2,3,4] indicate (1,2], (2,3], (3,4]. - labels : array-like or bool, default: None + labels : array-like, False or None, default: None Used as labels for the resulting bins. Must be of the same length as the resulting bins. If False, string bin labels are assigned by `pandas.cut`. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4f714f43252..467b46bb429 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8572,9 +8572,9 @@ def groupby( def groupby_bins( self, group: Hashable | DataArray | IndexVariable, - bins, + bins: ArrayLike, right: bool = True, - labels=None, + labels: ArrayLike | None = None, precision: int = 3, include_lowest: bool = False, squeeze: bool = True, diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 05f8eb770b3..5fa78ae76de 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -11,6 +11,7 @@ Iterable, Iterator, Literal, + Mapping, Sequence, TypeVar, Union, @@ -319,9 +320,9 @@ def __init__( group: Hashable | DataArray | IndexVariable, squeeze: bool = False, grouper: pd.Grouper | None = None, - bins=None, + bins: ArrayLike | None = None, restore_coord_dims: bool = True, - cut_kwargs=None, + cut_kwargs: Mapping[Any, Any] | None = None, ) -> None: """Create a GroupBy object @@ -343,7 +344,7 @@ def __init__( restore_coord_dims : bool, default: True If True, also restore the dimension order of multi-dimensional coordinates. - cut_kwargs : dict, optional + cut_kwargs : dict-like, optional Extra keyword arguments to pass to `pandas.cut` """ diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 0e1d37f9a12..730cf9eac8f 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -26,14 +26,14 @@ Parameters ---------- - dim : str or sequence of str, optional + dim : Hashable or Iterable of Hashable, optional Dimension(s) over which to apply the weighted ``{fcn}``. - skipna : bool, optional + skipna : bool or None, optional If True, skip missing values (as marked by NaN). By default, only skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - keep_attrs : bool, optional + keep_attrs : bool or None, optional If True, the attributes (``attrs``) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. From 14dd8dd5e2301f6877b59bd86d5e6277f320702d Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 18 Jun 2022 18:01:29 +0200 Subject: [PATCH 20/36] improve rolling typing --- xarray/core/rolling.py | 198 +++++++++++++++++++++++++---------------- 1 file changed, 121 insertions(+), 77 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 9b9f1221020..f982ebff98b 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -3,7 +3,16 @@ import functools import itertools import warnings -from typing import TYPE_CHECKING, Any, Callable, Generic, Hashable, Mapping +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Hashable, + Iterator, + Mapping, + TypeVar, +) import numpy as np @@ -24,6 +33,9 @@ from .dataarray import DataArray from .dataset import Dataset + RollingKey = Any + _T = TypeVar("_T") + _ROLLING_REDUCE_DOCSTRING_TEMPLATE = """\ Reduce this object's data windows by applying `{name}` along its dimension. 
@@ -43,7 +55,7 @@ """ -class Rolling: +class Rolling(Generic[T_Xarray]): """A object that implements the moving window pattern. See Also @@ -57,7 +69,13 @@ class Rolling: __slots__ = ("obj", "window", "min_periods", "center", "dim") _attributes = ("window", "min_periods", "center", "dim") - def __init__(self, obj, windows, min_periods=None, center=False): + def __init__( + self, + obj: T_Xarray, + windows: Mapping[Any, int], + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + ) -> None: """ Moving window object. @@ -68,18 +86,20 @@ def __init__(self, obj, windows, min_periods=None, center=False): windows : mapping of hashable to int A mapping from the name of the dimension to create the rolling window along (e.g. `time`) to the size of the moving window. - min_periods : int, default: None + min_periods : int or None, default: None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : bool, default: False - Set the labels at the center of the window. + center : bool or dict-like Hashable to bool, default: False + Set the labels at the center of the window. If dict-like, set this + property per rolling dimension. Returns ------- rolling : type of input argument """ - self.dim, self.window = [], [] + self.dim: list[Hashable] = [] + self.window: list[int] = [] for d, w in windows.items(): self.dim.append(d) if w <= 0: @@ -87,15 +107,17 @@ def __init__(self, obj, windows, min_periods=None, center=False): self.window.append(w) self.center = self._mapping_to_list(center, default=False) - self.obj = obj + self.obj: T_Xarray = obj # attributes if min_periods is not None and min_periods <= 0: raise ValueError("min_periods must be greater than zero or None") - self.min_periods = np.prod(self.window) if min_periods is None else min_periods + self.min_periods = ( + int(np.prod(self.window)) if min_periods is None else min_periods + ) - def __repr__(self): + def __repr__(self) -> str: """provide a nice str repr of our rolling object""" attrs = [ @@ -106,12 +128,16 @@ def __repr__(self): klass=self.__class__.__name__, attrs=",".join(attrs) ) - def __len__(self): - return self.obj.sizes[self.dim] + def __len__(self) -> int: + return int(np.prod([self.obj.sizes[d] for d in self.dim])) + + @property + def ndim(self) -> int: + return len(self.dim) def _reduce_method( # type: ignore[misc] - name: str, fillna, rolling_agg_func: Callable = None - ) -> Callable: + name: str, fillna: Any, rolling_agg_func: Callable | None = None + ) -> Callable[..., T_Xarray]: """Constructs reduction methods built on a numpy reduction function (e.g. sum), a bottleneck reduction function (e.g. 
move_sum), or a Rolling reduction (_mean).""" if rolling_agg_func: @@ -157,7 +183,10 @@ def _mean(self, keep_attrs, **kwargs): var = _reduce_method("var", None) median = _reduce_method("median", None) - def count(self, keep_attrs=None): + def _counts(self, keep_attrs: bool | None) -> T_Xarray: + raise NotImplementedError() + + def count(self, keep_attrs: bool | None = None) -> T_Xarray: keep_attrs = self._get_keep_attrs(keep_attrs) rolling_count = self._counts(keep_attrs=keep_attrs) enough_periods = rolling_count >= self.min_periods @@ -166,23 +195,24 @@ def count(self, keep_attrs=None): count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") def _mapping_to_list( - self, arg, default=None, allow_default=True, allow_allsame=True - ): + self, + arg: _T | Mapping[Any, _T], + default: _T | None = None, + allow_default: bool = True, + allow_allsame: bool = True, + ) -> list[_T]: if utils.is_dict_like(arg): if allow_default: return [arg.get(d, default) for d in self.dim] for d in self.dim: if d not in arg: - raise KeyError(f"argument has no key {d}.") + raise KeyError(f"Argument has no dimension key {d}.") return [arg[d] for d in self.dim] - elif allow_allsame: # for single argument - return [arg] * len(self.dim) - elif len(self.dim) == 1: - return [arg] - else: - raise ValueError( - f"Mapping argument is necessary for {len(self.dim)}d-rolling." - ) + if allow_allsame: # for single argument + return [arg] * self.ndim # type: ignore[list-item] # no check for negatives + if self.ndim == 1: + return [arg] # type: ignore[list-item] # no check for negatives + raise ValueError(f"Mapping argument is necessary for {self.ndim}d-rolling.") def _get_keep_attrs(self, keep_attrs): if keep_attrs is None: @@ -191,10 +221,16 @@ def _get_keep_attrs(self, keep_attrs): return keep_attrs -class DataArrayRolling(Rolling): +class DataArrayRolling(Rolling["DataArray"]): __slots__ = ("window_labels",) - def __init__(self, obj, windows, min_periods=None, center=False): + def __init__( + self, + obj: DataArray, + windows: Mapping[Any, int], + min_periods: int | None = None, + center: bool | Mapping[Any, bool] = False, + ) -> None: """ Moving window object for DataArray. 
You should use DataArray.rolling() method to construct this object @@ -230,14 +266,14 @@ def __init__(self, obj, windows, min_periods=None, center=False): # TODO legacy attribute self.window_labels = self.obj[self.dim[0]] - def __iter__(self): - if len(self.dim) > 1: + def __iter__(self) -> Iterator[tuple[RollingKey, DataArray]]: + if self.ndim > 1: raise ValueError("__iter__ is only supported for 1d-rolling") stops = np.arange(1, len(self.window_labels) + 1) starts = stops - int(self.window[0]) starts[: int(self.window[0])] = 0 for (label, start, stop) in zip(self.window_labels, starts, stops): - window = self.obj.isel(**{self.dim[0]: slice(start, stop)}) + window = self.obj.isel({self.dim[0]: slice(start, stop)}) counts = window.count(dim=self.dim[0]) window = window.where(counts >= self.min_periods) @@ -246,19 +282,19 @@ def __iter__(self): def construct( self, - window_dim=None, - stride=1, - fill_value=dtypes.NA, - keep_attrs=None, - **window_dim_kwargs, - ): + window_dim: Hashable | Mapping[Any, Hashable] | None = None, + stride: int | Mapping[Any, int] = 1, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + **window_dim_kwargs: Hashable, + ) -> DataArray: """ Convert this rolling object to xr.DataArray, where the window dimension is stacked as a new dimension Parameters ---------- - window_dim : str or mapping, optional + window_dim : Hashable or dict-like to Hashable, optional A mapping from dimension name to the new window dimension names. stride : int or mapping of int, default: 1 Size of stride for the rolling window. @@ -268,8 +304,8 @@ def construct( If True, the attributes (``attrs``) will be copied from the original object to the new one. If False, the new object will be returned without attributes. If None uses the global default. - **window_dim_kwargs : {dim: new_name, ...}, optional - The keyword arguments form of ``window_dim``. + **window_dim_kwargs : Hashable, optional + The keyword arguments form of ``window_dim`` {dim: new_name, ...}. Returns ------- @@ -321,13 +357,13 @@ def construct( def _construct( self, - obj, - window_dim=None, - stride=1, - fill_value=dtypes.NA, - keep_attrs=None, - **window_dim_kwargs, - ): + obj: DataArray, + window_dim: Hashable | Mapping[Any, Hashable] | None = None, + stride: int | Mapping[Any, int] = 1, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + **window_dim_kwargs: Hashable, + ) -> DataArray: from .dataarray import DataArray keep_attrs = self._get_keep_attrs(keep_attrs) @@ -337,31 +373,31 @@ def _construct( raise ValueError( "Either window_dim or window_dim_kwargs need to be specified." ) - window_dim = {d: window_dim_kwargs[d] for d in self.dim} + window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} - window_dim = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False + window_dims = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # ?? 
)
-        stride = self._mapping_to_list(stride, default=1)
+        strides = self._mapping_to_list(stride, default=1)
 
         window = obj.variable.rolling_window(
-            self.dim, self.window, window_dim, self.center, fill_value=fill_value
+            self.dim, self.window, window_dims, self.center, fill_value=fill_value
         )
 
         attrs = obj.attrs if keep_attrs else {}
         result = DataArray(
             window,
-            dims=obj.dims + tuple(window_dim),
+            dims=obj.dims + tuple(window_dims),
             coords=obj.coords,
             attrs=attrs,
             name=obj.name,
         )
-        return result.isel(
-            **{d: slice(None, None, s) for d, s in zip(self.dim, stride)}
-        )
+        return result.isel({d: slice(None, None, s) for d, s in zip(self.dim, strides)})
 
-    def reduce(self, func, keep_attrs=None, **kwargs):
+    def reduce(
+        self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any
+    ) -> DataArray:
         """Reduce the items in this group by applying `func` along some
         dimension(s).
 
@@ -439,7 +475,7 @@ def reduce(self, func, keep_attrs=None, **kwargs):
         counts = self._counts(keep_attrs=False)
         return result.where(counts >= self.min_periods)
 
-    def _counts(self, keep_attrs):
+    def _counts(self, keep_attrs: bool | None) -> DataArray:
         """Number of non-nan entries in each rolling window."""
 
         rolling_dim = {
@@ -453,8 +489,8 @@ def _counts(self, keep_attrs):
         counts = (
             self.obj.notnull(keep_attrs=keep_attrs)
             .rolling(
+                {d: w for d, w in zip(self.dim, self.window)},
                 center={d: self.center[i] for i, d in enumerate(self.dim)},
-                **{d: w for d, w in zip(self.dim, self.window)},
             )
             .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs)
             .sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs)
@@ -526,7 +562,7 @@ def _numpy_or_bottleneck_reduce(
             OPTIONS["use_bottleneck"]
             and bottleneck_move_func is not None
             and not is_duck_dask_array(self.obj.data)
-            and len(self.dim) == 1
+            and self.ndim == 1
         ):
             # TODO: renable bottleneck with dask after the issues
             # underlying https://github.com/pydata/xarray/issues/2940 are
@@ -547,10 +583,16 @@ def _numpy_or_bottleneck_reduce(
         return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs)
 
 
-class DatasetRolling(Rolling):
+class DatasetRolling(Rolling["Dataset"]):
     __slots__ = ("rollings",)
 
-    def __init__(self, obj, windows, min_periods=None, center=False):
+    def __init__(
+        self,
+        obj: Dataset,
+        windows: Mapping[Any, int],
+        min_periods: int | None = None,
+        center: bool | Mapping[Any, bool] = False,
+    ) -> None:
         """
         Moving window object for Dataset.
         You should use Dataset.rolling() method to construct this object
@@ -616,7 +658,9 @@ def _dataset_implementation(self, func, keep_attrs, **kwargs):
         attrs = self.obj.attrs if keep_attrs else {}
         return Dataset(reduced, coords=self.obj.coords, attrs=attrs)
 
-    def reduce(self, func, keep_attrs=None, **kwargs):
+    def reduce(
+        self, func: Callable, keep_attrs: bool | None = None, **kwargs: Any
+    ) -> Dataset:
         """Reduce the items in this group by applying `func` along some
         dimension(s).
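
The `_counts` hunk above captures the trick behind most rolling reductions here: view `notnull()` through an extra window dimension via `construct`, then `sum` over that dimension. The same idea in plain NumPy, as a rough sketch that ignores the edge padding `Rolling.construct` performs:

    import numpy as np

    data = np.array([1.0, np.nan, 3.0, 4.0, np.nan])
    window = 3

    # sliding_window_view plays the role of Rolling.construct(): it adds a
    # trailing window dimension instead of looping over positions
    views = np.lib.stride_tricks.sliding_window_view(~np.isnan(data), window)
    counts = views.sum(axis=-1)
    print(counts)  # [2 2 2] -> non-NaN entries per full window

Reducing over the constructed dimension is what lets `reduce(func)` accept arbitrary callables, at the price of a strided view rather than a dedicated moving-window kernel like bottleneck's `move_sum`.
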
@@ -644,7 +688,7 @@ def reduce(self, func, keep_attrs=None, **kwargs): **kwargs, ) - def _counts(self, keep_attrs): + def _counts(self, keep_attrs: bool | None) -> Dataset: return self._dataset_implementation( DataArrayRolling._counts, keep_attrs=keep_attrs ) @@ -670,12 +714,12 @@ def _numpy_or_bottleneck_reduce( def construct( self, - window_dim=None, - stride=1, - fill_value=dtypes.NA, - keep_attrs=None, - **window_dim_kwargs, - ): + window_dim: Hashable | Mapping[Any, Hashable] | None = None, + stride: int | Mapping[Any, int] = 1, + fill_value: Any = dtypes.NA, + keep_attrs: bool | None = None, + **window_dim_kwargs: Hashable, + ) -> Dataset: """ Convert this rolling object to xr.Dataset, where the window dimension is stacked as a new dimension @@ -706,20 +750,20 @@ def construct( raise ValueError( "Either window_dim or window_dim_kwargs need to be specified." ) - window_dim = {d: window_dim_kwargs[d] for d in self.dim} + window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} - window_dim = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False + window_dims = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # ?? ) - stride = self._mapping_to_list(stride, default=1) + strides = self._mapping_to_list(stride, default=1) dataset = {} for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on self.dim dims = [d for d in self.dim if d in da.dims] if dims: - wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims} - st = {d: stride[i] for i, d in enumerate(self.dim) if d in da.dims} + wi = {d: window_dims[i] for i, d in enumerate(self.dim) if d in da.dims} + st = {d: strides[i] for i, d in enumerate(self.dim) if d in da.dims} dataset[key] = self.rollings[key].construct( window_dim=wi, @@ -737,7 +781,7 @@ def construct( attrs = self.obj.attrs if keep_attrs else {} return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel( - **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} + {d: slice(None, None, s) for d, s in zip(self.dim, strides)} ) From 6b47c551f123ccadb9a54f041967819409114ae9 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:52:38 +0200 Subject: [PATCH 21/36] Update xarray/core/resample.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/resample.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 80be785e4f8..0e482d94087 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -40,9 +40,8 @@ def __init__( if dim == resample_dim: raise ValueError( - "Proxy resampling dimension ('{}') " - "cannot have the same name as actual dimension " - "('{}')! ".format(resample_dim, dim) + f"Proxy resampling dimension ('{resample_dim}') " + f"cannot have the same name as actual dimension ('{dim}')!" 
) self._dim = dim self._resample_dim = resample_dim From 265870b294a488fb324d98fd8fd2a961a509e3f0 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:54:17 +0200 Subject: [PATCH 22/36] Update xarray/core/rolling.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/rolling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index f982ebff98b..f579590c2ed 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -376,7 +376,7 @@ def _construct( window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} window_dims = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # ?? + window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506 ) strides = self._mapping_to_list(stride, default=1) From 0d7e12ce587d0ad17efa8a9c25d5f0f7dc9005ba Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:54:22 +0200 Subject: [PATCH 23/36] Update xarray/tests/test_groupby.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 7d1932283b2..3bf8513505b 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -241,7 +241,7 @@ def test_da_groupby_quantile() -> None: ) assert_identical(expected_x, actual_x) - actual_y = array.groupby("y").quantile(0, dim=...) # type: ignore[arg-type] + actual_y = array.groupby("y").quantile(0, dim=...) # type: ignore[arg-type] # https://github.com/python/mypy/issues/7818 expected_y = xr.DataArray( data=[1, 22], coords={"y": [0, 1], "quantile": 0}, dims="y" ) From 0dc1532be8b53cd5a44d7df55b025feab20df5c4 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:54:55 +0200 Subject: [PATCH 24/36] Update xarray/core/rolling.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/rolling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index f579590c2ed..29c602b7b48 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -753,7 +753,7 @@ def construct( window_dim = {d: window_dim_kwargs[str(d)] for d in self.dim} window_dims = self._mapping_to_list( - window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # ?? 
+ window_dim, allow_default=False, allow_allsame=False # type: ignore[arg-type] # https://github.com/python/mypy/issues/12506 ) strides = self._mapping_to_list(stride, default=1) From c37fa70df9086570b489669e9ab56401e9f3be72 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:55:01 +0200 Subject: [PATCH 25/36] Update xarray/tests/test_groupby.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 3bf8513505b..a0bc3d79f98 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -356,7 +356,7 @@ def test_ds_groupby_quantile() -> None: coords={"x": [1, 1, 1, 2, 2], "y": [0, 0, 1]}, ) - actual_x = ds.groupby("x").quantile(0, dim=...) # type: ignore[arg-type] + actual_x = ds.groupby("x").quantile(0, dim=...) # type: ignore[arg-type] # https://github.com/python/mypy/issues/7818 expected_x = xr.Dataset({"a": ("x", [1, 4])}, coords={"x": [1, 2], "quantile": 0}) assert_identical(expected_x, actual_x) From 6008daef2bb90c7d1ceb241da956b3fd841c28dc Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:55:07 +0200 Subject: [PATCH 26/36] Update xarray/tests/test_groupby.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index a0bc3d79f98..790b811dfb5 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -360,7 +360,7 @@ def test_ds_groupby_quantile() -> None: expected_x = xr.Dataset({"a": ("x", [1, 4])}, coords={"x": [1, 2], "quantile": 0}) assert_identical(expected_x, actual_x) - actual_y = ds.groupby("y").quantile(0, dim=...) # type: ignore[arg-type] + actual_y = ds.groupby("y").quantile(0, dim=...) # type: ignore[arg-type] # https://github.com/python/mypy/issues/7818 expected_y = xr.Dataset({"a": ("y", [1, 22])}, coords={"y": [0, 1], "quantile": 0}) assert_identical(expected_y, actual_y) From 305716d6df80b9caf0022183033d1e8c0d64e984 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:55:12 +0200 Subject: [PATCH 27/36] Update xarray/tests/test_groupby.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 790b811dfb5..ca6bb2ee721 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -272,7 +272,7 @@ def test_da_groupby_quantile() -> None: ) g = foo.groupby(foo.time.dt.month) - actual = g.quantile(0, dim=...) # type: ignore[arg-type] + actual = g.quantile(0, dim=...) 
# type: ignore[arg-type] # https://github.com/python/mypy/issues/7818 expected = xr.DataArray( data=[ 0.0, From 4d37705f8dae7bd8e2947624c49028f9b113f233 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:55:17 +0200 Subject: [PATCH 28/36] Update xarray/tests/test_groupby.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index ca6bb2ee721..70fbfbb05f3 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -386,7 +386,7 @@ def test_ds_groupby_quantile() -> None: ) g = foo.groupby(foo.time.dt.month) - actual = g.quantile(0, dim=...) # type: ignore[arg-type] + actual = g.quantile(0, dim=...) # type: ignore[arg-type] # https://github.com/python/mypy/issues/7818 expected = xr.Dataset( { "a": ( From 608002b561402f87b718e0e9798bfb1ef3d4d8b4 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:56:41 +0200 Subject: [PATCH 29/36] Update xarray/core/variable.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/variable.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4e13e8bae73..048942867db 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1847,9 +1847,11 @@ def reduce( if getattr(data, "shape", ()) == self.shape: dims = self.dims else: - removed_axes: Sequence[int] = ( - range(self.ndim) if axis is None else np.atleast_1d(axis) % self.ndim # type: ignore[assignment] - ) + removed_axes: Iterable[int] + if axis is None: + removed_axes = range(self.ndim) + else: + removed_axes = np.atleast_1d(axis) % self.ndim if keepdims: # Insert np.newaxis for removed dims slices = tuple( From 21c662b8fbfa9abe54323c9dc8b7c8e51beebc95 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:56:48 +0200 Subject: [PATCH 30/36] Update xarray/core/variable.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 048942867db..b19e42bf891 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1806,7 +1806,7 @@ def reduce( np.ndarray over an integer valued axis. dim : Hashable or Iterable of Hashable, optional Dimension(s) over which to apply `func`. - axis : int or Iterable of int, optional + axis : int or Sequence of int, optional Axis(es) over which to apply `func`. Only one of the 'dim' and 'axis' arguments can be supplied. 
If neither are supplied, then the reduction is calculated over the flattened array (by calling From 7e1b21e52abbd29a5e8618e09dd1d09484e1c713 Mon Sep 17 00:00:00 2001 From: Mick <43316012+headtr1ck@users.noreply.github.com> Date: Sun, 19 Jun 2022 15:59:07 +0200 Subject: [PATCH 31/36] Update xarray/tests/test_groupby.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_groupby.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 70fbfbb05f3..a006e54468a 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -235,7 +235,7 @@ def test_da_groupby_quantile() -> None: dims=("x", "y"), ) - actual_x = array.groupby("x").quantile(0, dim=...) # type: ignore[arg-type] + actual_x = array.groupby("x").quantile(0, dim=...) # type: ignore[arg-type] # https://github.com/python/mypy/issues/7818 expected_x = xr.DataArray( data=[1, 4], coords={"x": [1, 2], "quantile": 0}, dims="x" ) From b89d9dc1e69319db110dc33fdd96ca254921f47c Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 25 Jun 2022 14:02:29 +0200 Subject: [PATCH 32/36] reorder whats-new --- doc/whats-new.rst | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c8747e3ee7c..f1e645b1afc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -22,15 +22,14 @@ v2022.06.0 (unreleased) New Features ~~~~~~~~~~~~ -- Initial typing support for :py:meth:`groupby`, :py:meth:`rolling`, :py:meth:`rolling_exp`, - :py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`, - (:pull:`6702`) - By `Michael Niklas `_. - - Add :py:meth:`Dataset.dtypes`, :py:meth:`DatasetCoordinates.dtypes`, :py:meth:`DataArrayCoordinates.dtypes` properties: Mapping from variable names to dtypes. (:pull:`6706`) By `Michael Niklas `_. +- Initial typing support for :py:meth:`groupby`, :py:meth:`rolling`, :py:meth:`rolling_exp`, + :py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`, + (:pull:`6702`) + By `Michael Niklas `_. 
Deprecations ~~~~~~~~~~~~ From 2a6c48f7419b69fc958140ec0b76b6cd2b670520 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 25 Jun 2022 19:40:19 +0200 Subject: [PATCH 33/36] restructure Resample funcs --- xarray/core/coordinates.py | 2 +- xarray/core/dataarray.py | 4 +- xarray/core/resample.py | 152 +++++++++++++++++-------------------- 3 files changed, 74 insertions(+), 84 deletions(-) diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index cd80d8e2fb0..65949a24369 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -21,7 +21,7 @@ _THIS_ARRAY = ReprObject("") -class Coordinates(Mapping[Any, "DataArray"]): +class Coordinates(Mapping[Hashable, "DataArray"]): __slots__ = () def __getitem__(self, key: Hashable) -> DataArray: diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 225d02e5205..b4acdad9f1c 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1690,7 +1690,7 @@ def reindex( self: T_DataArray, indexers: Mapping[Any, Any] = None, method: ReindexMethodOptions = None, - tolerance: int | float | Iterable[int | float] | None = None, + tolerance: float | Iterable[float] | None = None, copy: bool = True, fill_value=dtypes.NA, **indexers_kwargs: Any, @@ -1720,7 +1720,7 @@ def reindex( - backfill / bfill: propagate next valid index value backward - nearest: use nearest valid index value - tolerance : optional + tolerance : float | Iterable[float] | None, default: None Maximum distance between original and new labels for inexact matches. The values of the index at the matching locations must satisfy the equation ``abs(index[indexer] - target) <= tolerance``. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 0e482d94087..5f5c234cfd5 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -7,7 +7,7 @@ from ._reductions import DataArrayResampleReductions, DatasetResampleReductions from .groupby import DataArrayGroupByBase, DatasetGroupByBase, GroupBy -from .types import T_Xarray +from .types import InterpOptions, T_Xarray if TYPE_CHECKING: from .dataarray import DataArray @@ -74,127 +74,121 @@ def _flox_reduce(self, dim, keep_attrs: bool | None = None, **kwargs) -> T_Xarra result = result.rename({RESAMPLE_DIM: self._group_dim}) return result - def _upsample(self, method, *args, **kwargs) -> T_Xarray: - """Dispatch function to call appropriate up-sampling methods on - data. - - This method should not be called directly; instead, use one of the - wrapper functions supplied by `Resample`. - - Parameters - ---------- - method : {"asfreq", "pad", "ffill", "backfill", "bfill", "nearest", \ - "interpolate"} - Method to use for up-sampling - - See Also - -------- - Resample.asfreq - Resample.pad - Resample.backfill - Resample.interpolate - - """ - - upsampled_index = self._full_index - - # Drop non-dimension coordinates along the resampled dimension - for k, v in self._obj.coords.items(): - if k == self._dim: - continue - if self._dim in v.dims: - self._obj = self._obj.drop_vars(k) - - if method == "asfreq": - return self.mean(self._dim) # type: ignore[attr-defined] - - elif method in ["pad", "ffill", "backfill", "bfill", "nearest"]: - kwargs = kwargs.copy() - if not isinstance(self._dim, str): - raise NotImplementedError( - "Only str dimensions supported by now. Please raise an issue on Github." 
- ) - kwargs[self._dim] = upsampled_index - kwargs["method"] = method - return self._obj.reindex(*args, **kwargs) - - elif method == "interpolate": - return self._interpolate(*args, **kwargs) - - else: - raise ValueError( - 'Specified method was "{}" but must be one of' - '"asfreq", "ffill", "bfill", or "interpolate"'.format(method) - ) + def _drop_coords(self) -> T_Xarray: + """Drop non-dimension coordinates along the resampled dimension.""" + obj = self._obj + for k, v in obj.coords.items(): + if k != self._dim and self._dim in v.dims: + obj = obj.drop_vars(k) + return obj def asfreq(self) -> T_Xarray: """Return values of original object at the new up-sampling frequency; essentially a re-index with new times set to NaN. + + Returns + ------- + resampled : DataArray or Dataset """ - return self._upsample("asfreq") + self._obj = self._drop_coords() + return self.mean(self._dim) # type: ignore[attr-defined] - def pad(self, tolerance=None) -> T_Xarray: + def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: """Forward fill new values at up-sampled frequency. Parameters ---------- - tolerance : optional + tolerance : float | Iterable[float] | None, default: None Maximum distance between original and new labels to limit the up-sampling method. Up-sampled data with indices that satisfy the equation ``abs(index[indexer] - target) <= tolerance`` are filled by new values. Data with indices that are outside the given - tolerance are filled with ``NaN`` s + tolerance are filled with ``NaN`` s. + + Returns + ------- + padded : DataArray or Dataset """ - return self._upsample("pad", tolerance=tolerance) + obj = self._drop_coords() + return obj.reindex( + {self._dim: self._full_index}, method="pad", tolerance=tolerance + ) ffill = pad - def backfill(self, tolerance=None) -> T_Xarray: + def backfill(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: """Backward fill new values at up-sampled frequency. Parameters ---------- - tolerance : optional + tolerance : float | Iterable[float] | None, default: None Maximum distance between original and new labels to limit the up-sampling method. Up-sampled data with indices that satisfy the equation ``abs(index[indexer] - target) <= tolerance`` are filled by new values. Data with indices that are outside the given - tolerance are filled with ``NaN`` s + tolerance are filled with ``NaN`` s. + + Returns + ------- + backfilled : DataArray or Dataset """ - return self._upsample("backfill", tolerance=tolerance) + obj = self._drop_coords() + return obj.reindex( + {self._dim: self._full_index}, method="backfill", tolerance=tolerance + ) bfill = backfill - def nearest(self, tolerance=None) -> T_Xarray: + def nearest(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: """Take new values from nearest original coordinate to up-sampled frequency coordinates. Parameters ---------- - tolerance : optional + tolerance : float | Iterable[float] | None, default: None Maximum distance between original and new labels to limit the up-sampling method. Up-sampled data with indices that satisfy the equation ``abs(index[indexer] - target) <= tolerance`` are filled by new values. Data with indices that are outside the given - tolerance are filled with ``NaN`` s + tolerance are filled with ``NaN`` s. 
+ + Returns + ------- + upsampled : DataArray or Dataset """ - return self._upsample("nearest", tolerance=tolerance) + obj = self._drop_coords() + return obj.reindex( + {self._dim: self._full_index}, method="nearest", tolerance=tolerance + ) - def interpolate(self, kind="linear") -> T_Xarray: - """Interpolate up-sampled data using the original data - as knots. + def interpolate(self, kind: InterpOptions = "linear") -> T_Xarray: + """Interpolate up-sampled data using the original data as knots. Parameters ---------- kind : {"linear", "nearest", "zero", "slinear", \ - "quadratic", "cubic"}, default: "linear" - Interpolation scheme to use + "quadratic", "cubic", "polynomial"}, default: "linear" + The method used to interpolate. The method should be supported by + the scipy interpolator: + + - ``interp1d``: {"linear", "nearest", "zero", "slinear", + "quadratic", "cubic", "polynomial"} + - ``interpn``: {"linear", "nearest"} + + If ``"polynomial"`` is passed, the ``order`` keyword argument must + also be provided. + + Returns + ------- + interpolated : DataArray or Dataset See Also -------- + DataArray.interp + Dataset.interp scipy.interpolate.interp1d """ @@ -202,13 +196,8 @@ def interpolate(self, kind="linear") -> T_Xarray: def _interpolate(self, kind="linear") -> T_Xarray: """Apply scipy.interpolate.interp1d along resampling dimension.""" - # drop any existing non-dimension coordinates along the resampling - # dimension - dummy = self._obj.copy() - for k, v in self._obj.coords.items(): - if k != self._dim and self._dim in v.dims: - dummy = dummy.drop_vars(k) - return dummy.interp( + obj = self._drop_coords() + return obj.interp( coords={self._dim: self._full_index}, assume_sorted=True, method=kind, @@ -267,7 +256,7 @@ def map( Returns ------- - applied : DataArray or DataArray + applied : DataArray The result of splitting, applying and combining this array. """ # TODO: the argument order for Resample doesn't match that for its parent, @@ -338,7 +327,7 @@ def map( Returns ------- - applied : Dataset or DataArray + applied : Dataset The result of splitting, applying and combining this dataset. """ # ignore shortcut if set (for now) @@ -394,7 +383,7 @@ def reduce( Returns ------- - reduced : Array + reduced : Dataset Array with summarized data and the indicated dimension(s) removed. 
""" @@ -404,5 +393,6 @@ def reduce( axis=axis, keep_attrs=keep_attrs, keepdims=keepdims, + shortcut=shortcut, **kwargs, ) From 880eb77e8e90de3b33bfd325e0942062f1f6c943 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sat, 25 Jun 2022 19:48:00 +0200 Subject: [PATCH 34/36] change to math.prod --- xarray/core/rolling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 29c602b7b48..aef290f6d7f 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -2,6 +2,7 @@ import functools import itertools +import math import warnings from typing import ( TYPE_CHECKING, @@ -114,7 +115,7 @@ def __init__( raise ValueError("min_periods must be greater than zero or None") self.min_periods = ( - int(np.prod(self.window)) if min_periods is None else min_periods + math.prod(self.window) if min_periods is None else min_periods ) def __repr__(self) -> str: @@ -129,7 +130,7 @@ def __repr__(self) -> str: ) def __len__(self) -> int: - return int(np.prod([self.obj.sizes[d] for d in self.dim])) + return math.prod(self.obj.sizes[d] for d in self.dim) @property def ndim(self) -> int: From 87d36ed0a9d8122cc9300ca21a445d142e094896 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 26 Jun 2022 10:06:30 +0200 Subject: [PATCH 35/36] move asfreq Data{Array,set}Resample --- xarray/core/resample.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 5f5c234cfd5..5f340a1fbbe 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -90,8 +90,8 @@ def asfreq(self) -> T_Xarray: ------- resampled : DataArray or Dataset """ - self._obj = self._drop_coords() - return self.mean(self._dim) # type: ignore[attr-defined] + # requires mean, which is only available for Reduction Mixins + raise NotImplementedError def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: """Forward fill new values at up-sampled frequency. @@ -289,6 +289,17 @@ def apply(self, func, args=(), shortcut=None, **kwargs): ) return self.map(func=func, shortcut=shortcut, args=args, **kwargs) + def asfreq(self) -> DataArray: + """Return values of original object at the new up-sampling frequency; + essentially a re-index with new times set to NaN. + + Returns + ------- + resampled : DataArray + """ + self._obj = self._drop_coords() + return self.mean(self._dim) + # https://github.com/python/mypy/issues/9031 class DatasetResample(Resample["Dataset"], DatasetGroupByBase, DatasetResampleReductions): # type: ignore[misc] @@ -396,3 +407,14 @@ def reduce( shortcut=shortcut, **kwargs, ) + + def asfreq(self) -> Dataset: + """Return values of original object at the new up-sampling frequency; + essentially a re-index with new times set to NaN. 
+ + Returns + ------- + resampled : Dataset + """ + self._obj = self._drop_coords() + return self.mean(self._dim) From 4c0aa8d30c041e3462e87493ce83fbe427e0f9f5 Mon Sep 17 00:00:00 2001 From: Michael Niklas Date: Sun, 26 Jun 2022 12:24:06 +0200 Subject: [PATCH 36/36] remove abstract asfreq method --- xarray/core/resample.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 5f340a1fbbe..bf9c9f7501a 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -82,17 +82,6 @@ def _drop_coords(self) -> T_Xarray: obj = obj.drop_vars(k) return obj - def asfreq(self) -> T_Xarray: - """Return values of original object at the new up-sampling frequency; - essentially a re-index with new times set to NaN. - - Returns - ------- - resampled : DataArray or Dataset - """ - # requires mean, which is only available for Reduction Mixins - raise NotImplementedError - def pad(self, tolerance: float | Iterable[float] | None = None) -> T_Xarray: """Forward fill new values at up-sampled frequency.
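
A quick usage sketch of the upsampling path PATCH 33 restructures (illustrative
only, not part of the patch series; the array values and frequencies are made
up). After the restructuring, ``pad``/``backfill``/``nearest`` call
``_drop_coords`` and ``reindex`` directly instead of dispatching through the
removed ``_upsample`` helper, so the public behaviour stays the same::

    import numpy as np
    import pandas as pd
    import xarray as xr

    # Daily data upsampled to 12-hourly.
    da = xr.DataArray(
        np.arange(4.0),
        coords=[pd.date_range("2000-01-01", periods=4, freq="D")],
        dims="time",
    )

    up = da.resample(time="12H")
    nan_filled = up.asfreq()                # new times filled with NaN
    padded = up.pad()                       # forward fill; ``up.ffill()`` is an alias
    nearest = up.nearest(tolerance=None)    # nearest original label, no distance limit
    interp = up.interpolate(kind="linear")  # scipy interp1d along "time"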
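PATCH 34 replaces ``np.prod`` with ``math.prod`` for the rolling-window size
computations. A small sketch of the difference (the values below are assumed
for illustration): ``math.prod`` accepts plain iterables and generators,
returns a native arbitrary-precision ``int``, and drops the ``int(...)`` cast
that the NumPy scalar result required::

    import math
    import numpy as np

    window = [3, 4, 5]
    math.prod(window)     # 60, already a Python int
    int(np.prod(window))  # 60 too, but via a NumPy scalar that needs casting

    # Generators work directly, matching the new ``__len__`` implementation:
    sizes = {"x": 10, "y": 20}
    math.prod(sizes[d] for d in ("x", "y"))  # 200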
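PATCHES 35 and 36 move ``asfreq`` from the shared ``Resample`` base class onto
``DataArrayResample`` and ``DatasetResample``: ``mean`` is supplied by the
generated reduction mixins, so only the concrete classes can call it without a
``type: ignore``. A minimal sketch of that pattern, with hypothetical class
names standing in for xarray's real hierarchy::

    class ResampleBase:
        """Shared helpers only -- no reductions, so no ``mean`` here."""

        def _drop_coords(self) -> None:
            pass  # stands in for dropping non-dimension coordinates

    class MeanMixin:
        """Stands in for the generated ``*ResampleReductions`` mixins."""

        def mean(self, dim: str) -> str:
            return f"mean over {dim}"

    class ConcreteResample(ResampleBase, MeanMixin):
        def asfreq(self) -> str:
            # ``mean`` is guaranteed here because the mixin is in the MRO,
            # so no abstract stub or ``NotImplementedError`` is needed.
            self._drop_coords()
            return self.mean("time")

    print(ConcreteResample().asfreq())  # -> "mean over time"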