Skip to content

Commit 9fea799

Browse files
authored
weighted: small improvements (#4818)
* weighted: small improvements * use T_DataWithCoords
1 parent a4bb7e1 commit 9fea799

File tree

2 files changed

+28
-32
lines changed

2 files changed

+28
-32
lines changed

xarray/core/common.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from html import escape
44
from textwrap import dedent
55
from typing import (
6+
TYPE_CHECKING,
67
Any,
78
Callable,
89
Dict,
@@ -32,6 +33,12 @@
3233
ALL_DIMS = ...
3334

3435

36+
if TYPE_CHECKING:
37+
from .dataarray import DataArray
38+
from .weighted import Weighted
39+
40+
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
41+
3542
C = TypeVar("C")
3643
T = TypeVar("T")
3744

@@ -772,7 +779,9 @@ def groupby_bins(
772779
},
773780
)
774781

775-
def weighted(self, weights):
782+
def weighted(
783+
self: T_DataWithCoords, weights: "DataArray"
784+
) -> "Weighted[T_DataWithCoords]":
776785
"""
777786
Weighted operations.
778787

xarray/core/weighted.py

+18-31
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload
1+
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union
22

33
from . import duck_array_ops
44
from .computation import dot
5-
from .options import _get_keep_attrs
65
from .pycompat import is_duck_dask_array
76

87
if TYPE_CHECKING:
8+
from .common import DataWithCoords # noqa: F401
99
from .dataarray import DataArray, Dataset
1010

11+
T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
12+
13+
1114
_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
1215
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
1316
@@ -56,7 +59,7 @@
5659
"""
5760

5861

59-
class Weighted:
62+
class Weighted(Generic[T_DataWithCoords]):
6063
"""An object that implements weighted operations.
6164
6265
You should create a Weighted object by using the ``DataArray.weighted`` or
@@ -70,15 +73,7 @@ class Weighted:
7073

7174
__slots__ = ("obj", "weights")
7275

73-
@overload
74-
def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
75-
...
76-
77-
@overload
78-
def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
79-
...
80-
81-
def __init__(self, obj, weights):
76+
def __init__(self, obj: T_DataWithCoords, weights: "DataArray"):
8277
"""
8378
Create a Weighted object
8479
@@ -121,8 +116,8 @@ def _weight_check(w):
121116
else:
122117
_weight_check(weights.data)
123118

124-
self.obj = obj
125-
self.weights = weights
119+
self.obj: T_DataWithCoords = obj
120+
self.weights: "DataArray" = weights
126121

127122
@staticmethod
128123
def _reduce(
@@ -146,7 +141,6 @@ def _reduce(
146141

147142
# `dot` does not broadcast arrays, so this avoids creating a large
148143
# DataArray (if `weights` has additional dimensions)
149-
# maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
150144
return dot(da, weights, dims=dim)
151145

152146
def _sum_of_weights(
@@ -203,7 +197,7 @@ def sum_of_weights(
203197
self,
204198
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
205199
keep_attrs: Optional[bool] = None,
206-
) -> Union["DataArray", "Dataset"]:
200+
) -> T_DataWithCoords:
207201

208202
return self._implementation(
209203
self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
@@ -214,7 +208,7 @@ def sum(
214208
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
215209
skipna: Optional[bool] = None,
216210
keep_attrs: Optional[bool] = None,
217-
) -> Union["DataArray", "Dataset"]:
211+
) -> T_DataWithCoords:
218212

219213
return self._implementation(
220214
self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
@@ -225,7 +219,7 @@ def mean(
225219
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
226220
skipna: Optional[bool] = None,
227221
keep_attrs: Optional[bool] = None,
228-
) -> Union["DataArray", "Dataset"]:
222+
) -> T_DataWithCoords:
229223

230224
return self._implementation(
231225
self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
@@ -239,22 +233,15 @@ def __repr__(self):
239233
return f"{klass} with weights along dimensions: {weight_dims}"
240234

241235

242-
class DataArrayWeighted(Weighted):
243-
def _implementation(self, func, dim, **kwargs):
244-
245-
keep_attrs = kwargs.pop("keep_attrs")
246-
if keep_attrs is None:
247-
keep_attrs = _get_keep_attrs(default=False)
248-
249-
weighted = func(self.obj, dim=dim, **kwargs)
250-
251-
if keep_attrs:
252-
weighted.attrs = self.obj.attrs
236+
class DataArrayWeighted(Weighted["DataArray"]):
237+
def _implementation(self, func, dim, **kwargs) -> "DataArray":
253238

254-
return weighted
239+
dataset = self.obj._to_temp_dataset()
240+
dataset = dataset.map(func, dim=dim, **kwargs)
241+
return self.obj._from_temp_dataset(dataset)
255242

256243

257-
class DatasetWeighted(Weighted):
244+
class DatasetWeighted(Weighted["Dataset"]):
258245
def _implementation(self, func, dim, **kwargs) -> "Dataset":
259246

260247
return self.obj.map(func, dim=dim, **kwargs)

0 commit comments

Comments
 (0)