1
- from typing import TYPE_CHECKING , Hashable , Iterable , Optional , Union , overload
1
+ from typing import TYPE_CHECKING , Generic , Hashable , Iterable , Optional , TypeVar , Union
2
2
3
3
from . import duck_array_ops
4
4
from .computation import dot
5
- from .options import _get_keep_attrs
6
5
from .pycompat import is_duck_dask_array
7
6
8
7
if TYPE_CHECKING :
8
+ from .common import DataWithCoords # noqa: F401
9
9
from .dataarray import DataArray , Dataset
10
10
11
+ T_DataWithCoords = TypeVar ("T_DataWithCoords" , bound = "DataWithCoords" )
12
+
13
+
11
14
_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
12
15
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
13
16
56
59
"""
57
60
58
61
59
- class Weighted :
62
+ class Weighted ( Generic [ T_DataWithCoords ]) :
60
63
"""An object that implements weighted operations.
61
64
62
65
You should create a Weighted object by using the ``DataArray.weighted`` or
@@ -70,15 +73,7 @@ class Weighted:
70
73
71
74
__slots__ = ("obj" , "weights" )
72
75
73
- @overload
74
- def __init__ (self , obj : "DataArray" , weights : "DataArray" ) -> None :
75
- ...
76
-
77
- @overload
78
- def __init__ (self , obj : "Dataset" , weights : "DataArray" ) -> None :
79
- ...
80
-
81
- def __init__ (self , obj , weights ):
76
+ def __init__ (self , obj : T_DataWithCoords , weights : "DataArray" ):
82
77
"""
83
78
Create a Weighted object
84
79
@@ -121,8 +116,8 @@ def _weight_check(w):
121
116
else :
122
117
_weight_check (weights .data )
123
118
124
- self .obj = obj
125
- self .weights = weights
119
+ self .obj : T_DataWithCoords = obj
120
+ self .weights : "DataArray" = weights
126
121
127
122
@staticmethod
128
123
def _reduce (
@@ -146,7 +141,6 @@ def _reduce(
146
141
147
142
# `dot` does not broadcast arrays, so this avoids creating a large
148
143
# DataArray (if `weights` has additional dimensions)
149
- # maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
150
144
return dot (da , weights , dims = dim )
151
145
152
146
def _sum_of_weights (
@@ -203,7 +197,7 @@ def sum_of_weights(
203
197
self ,
204
198
dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
205
199
keep_attrs : Optional [bool ] = None ,
206
- ) -> Union [ "DataArray" , "Dataset" ] :
200
+ ) -> T_DataWithCoords :
207
201
208
202
return self ._implementation (
209
203
self ._sum_of_weights , dim = dim , keep_attrs = keep_attrs
@@ -214,7 +208,7 @@ def sum(
214
208
dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
215
209
skipna : Optional [bool ] = None ,
216
210
keep_attrs : Optional [bool ] = None ,
217
- ) -> Union [ "DataArray" , "Dataset" ] :
211
+ ) -> T_DataWithCoords :
218
212
219
213
return self ._implementation (
220
214
self ._weighted_sum , dim = dim , skipna = skipna , keep_attrs = keep_attrs
@@ -225,7 +219,7 @@ def mean(
225
219
dim : Optional [Union [Hashable , Iterable [Hashable ]]] = None ,
226
220
skipna : Optional [bool ] = None ,
227
221
keep_attrs : Optional [bool ] = None ,
228
- ) -> Union [ "DataArray" , "Dataset" ] :
222
+ ) -> T_DataWithCoords :
229
223
230
224
return self ._implementation (
231
225
self ._weighted_mean , dim = dim , skipna = skipna , keep_attrs = keep_attrs
@@ -239,22 +233,15 @@ def __repr__(self):
239
233
return f"{ klass } with weights along dimensions: { weight_dims } "
240
234
241
235
242
- class DataArrayWeighted (Weighted ):
243
- def _implementation (self , func , dim , ** kwargs ):
244
-
245
- keep_attrs = kwargs .pop ("keep_attrs" )
246
- if keep_attrs is None :
247
- keep_attrs = _get_keep_attrs (default = False )
248
-
249
- weighted = func (self .obj , dim = dim , ** kwargs )
250
-
251
- if keep_attrs :
252
- weighted .attrs = self .obj .attrs
236
+ class DataArrayWeighted (Weighted ["DataArray" ]):
237
+ def _implementation (self , func , dim , ** kwargs ) -> "DataArray" :
253
238
254
- return weighted
239
+ dataset = self .obj ._to_temp_dataset ()
240
+ dataset = dataset .map (func , dim = dim , ** kwargs )
241
+ return self .obj ._from_temp_dataset (dataset )
255
242
256
243
257
- class DatasetWeighted (Weighted ):
244
+ class DatasetWeighted (Weighted [ "Dataset" ] ):
258
245
def _implementation (self , func , dim , ** kwargs ) -> "Dataset" :
259
246
260
247
return self .obj .map (func , dim = dim , ** kwargs )
0 commit comments