@@ -1,7 +1,6 @@
 from __future__ import annotations
 import functools
 import textwrap
-import inspect

 from typing import Mapping, Hashable, Union, List, Any, Callable, Iterable, Dict

@@ -12,6 +11,9 @@
 from xarray.core.variable import Variable
 from xarray.core.combine import merge
 from xarray.core import dtypes, utils
+from xarray.core.common import DataWithCoords
+from xarray.core.arithmetic import DatasetArithmetic
+from xarray.core.ops import REDUCE_METHODS, NAN_REDUCE_METHODS, NAN_CUM_METHODS

 from .treenode import TreeNode, PathType, _init_single_treenode

@@ -46,7 +48,8 @@ def map_over_subtree(func):
     The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
     descendant nodes. The returned tree will have the same structure as the original subtree.

-    func needs to return a Dataset in order to rebuild the subtree.
+    func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
+    result will be assigned to its respective node of the new tree via `DataTree.__setitem__`.

     Parameters
     ----------
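For context, here is a minimal usage sketch of the decorator whose docstring is amended above. The import path and node paths are illustrative assumptions, not part of the diff: any function mapping Dataset to Dataset can be lifted so that it runs on the dataset stored at every node and returns a new tree with the same structure.

import xarray as xr
from datatree import DataTree, map_over_subtree  # assumes this module is importable as `datatree`

@map_over_subtree
def double(ds):
    # called once per node; must return a Dataset, DataArray, or None
    return ds * 2

dt = DataTree(data_objects={
    "run1": xr.Dataset({"a": ("t", [1, 2, 3])}),  # hypothetical node paths
    "run2": xr.Dataset({"b": ("t", [4, 5, 6])}),
})
doubled = double(dt)  # new tree, same structure, each node's dataset doubled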
@@ -132,7 +135,6 @@ def attrs(self):
         else:
             raise AttributeError("property is not defined for a node with no data")

-
     @property
     def nbytes(self) -> int:
         return sum(node.ds.nbytes for node in self.subtree_nodes)
@@ -203,10 +205,33 @@ def imag(self):

 _MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
                                            "call the method on the Datasets stored in every node of the subtree. "
-                                           "See the `map_over_subtree` decorator for more details.", width=117)
-
-
-def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None):
+                                           "See the `map_over_subtree` function for more details.", width=117)
+
+# TODO equals, broadcast_equals etc.
+# TODO do dask-related private methods need to be exposed?
+_DATASET_DASK_METHODS_TO_MAP = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
+_DATASET_METHODS_TO_MAP = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
+                           'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
+                           'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
+                           'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
+                           'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
+                           'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
+                           'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
+                           'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
+                           'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
+# TODO unsure if these are called by external functions or not?
+_DATASET_OPS_TO_MAP = ['_unary_op', '_binary_op', '_inplace_binary_op']
+_ALL_DATASET_METHODS_TO_MAP = _DATASET_DASK_METHODS_TO_MAP + _DATASET_METHODS_TO_MAP + _DATASET_OPS_TO_MAP
+
+_DATA_WITH_COORDS_METHODS_TO_MAP = ['squeeze', 'clip', 'assign_coords', 'where', 'close', 'isnull', 'notnull',
+                                    'isin', 'astype']
+
+# TODO NUM_BINARY_OPS apparently aren't defined on DatasetArithmetic, and don't appear to be injected anywhere...
+#['__array_ufunc__'] \
+_ARITHMETIC_METHODS_TO_WRAP = REDUCE_METHODS + NAN_REDUCE_METHODS + NAN_CUM_METHODS
+
+
+def _wrap_then_attach_to_cls(target_cls_dict, source_cls, methods_to_set, wrap_func=None):
     """
     Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree)

@@ -219,23 +244,32 @@ def method_name(self, *args, **kwargs):

     Parameters
     ----------
-    cls_dict
-        The __dict__ attribute of a class, which can also be accessed by calling vars() from within that classes'
-        definition.
-    methods_to_expose : Iterable[Tuple[str, callable]]
-        The method names and definitions supplied as a list of (method_name_string, method) pairs.\
+    target_cls_dict : MappingProxy
+        The __dict__ attribute of the class which we want the methods to be added to. (The __dict__ attribute can also
+        be accessed by calling vars() from within that classes' definition.) This will be updated by this function.
+    source_cls : class
+        Class object from which we want to copy methods (and optionally wrap them). Should be the actual class object
+        (or instance), not just the __dict__.
+    methods_to_set : Iterable[Tuple[str, callable]]
+        The method names and definitions supplied as a list of (method_name_string, method) pairs.
         This format matches the output of inspect.getmembers().
+    wrap_func : callable, optional
+        Function to decorate each method with. Must have the same return type as the method.
     """
-    for method_name, method in methods_to_expose:
-        wrapped_method = wrap_func(method) if wrap_func is not None else method
-        cls_dict[method_name] = wrapped_method
-
-        # TODO do we really need this for ops like __add__?
-        # Add a line to the method's docstring explaining how it's been mapped
-        method_docstring = method.__doc__
-        if method_docstring is not None:
-            updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
-            setattr(cls_dict[method_name], '__doc__', updated_method_docstring)
+    for method_name in methods_to_set:
+        orig_method = getattr(source_cls, method_name)
+        wrapped_method = wrap_func(orig_method) if wrap_func is not None else orig_method
+        target_cls_dict[method_name] = wrapped_method
+
+        if wrap_func is map_over_subtree:
+            # Add a paragraph to the method's docstring explaining how it's been mapped
+            orig_method_docstring = orig_method.__doc__
+            if orig_method_docstring is not None:
+                if '\n' in orig_method_docstring:
+                    new_method_docstring = orig_method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
+                else:
+                    new_method_docstring = orig_method_docstring + f"\n\n{_MAPPED_DOCSTRING_ADDENDUM}"
+                setattr(target_cls_dict[method_name], '__doc__', new_method_docstring)


 class MappedDatasetMethodsMixin:
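The attach-and-wrap pattern introduced by the new _wrap_then_attach_to_cls can be shown in isolation. The sketch below uses made-up names (Source, Target, attach, shout) and is not code from the module; it only demonstrates that a helper called inside a class body receives that body's namespace via vars() and can populate it with (optionally decorated) methods copied from another class.

def shout(method):
    # stand-in for map_over_subtree: decorate a copied method
    def wrapped(self, *args, **kwargs):
        return str(method(self, *args, **kwargs)).upper()
    return wrapped

class Source:
    def greet(self):
        return "hello"

def attach(target_cls_dict, source_cls, method_names, wrap_func=None):
    # copy each named method from source_cls into the namespace of the class being defined
    for name in method_names:
        orig = getattr(source_cls, name)
        target_cls_dict[name] = wrap_func(orig) if wrap_func is not None else orig

class Target:
    # vars() inside a class body is the namespace dict that becomes Target.__dict__
    attach(vars(), Source, ["greet"], wrap_func=shout)

assert Target().greet() == "HELLO"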
@@ -244,33 +278,28 @@ class MappedDatasetMethodsMixin:

     Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree.
     """
-    # TODO equals, broadcast_equals etc.
-    # TODO do dask-related private methods need to be exposed?
-    _DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
-    _DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
-                                  'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
-                                  'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
-                                  'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
-                                  'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
-                                  'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
-                                  'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
-                                  'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
-                                  'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
-    # TODO unsure if these are called by external functions or not?
-    _DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
-    _ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
+    __slots__ = ()
+    _wrap_then_attach_to_cls(vars(), Dataset, _ALL_DATASET_METHODS_TO_MAP, wrap_func=map_over_subtree)

-    # TODO methods which should not or cannot act over the whole tree, such as .to_array

-    methods_to_expose = [(method_name, getattr(Dataset, method_name))
-                         for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
-    _wrap_then_attach_to_cls(vars(), methods_to_expose, wrap_func=map_over_subtree)
+class MappedDataWithCoords(DataWithCoords):
+    # TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample
+    # TODO re-implement AttrsAccessMixin stuff so that it includes access to child nodes
+    _wrap_then_attach_to_cls(vars(), DataWithCoords, _DATA_WITH_COORDS_METHODS_TO_MAP, wrap_func=map_over_subtree)


-# TODO implement ArrayReduce type methods
+class DataTreeArithmetic(DatasetArithmetic):
+    """
+    Mixin to add Dataset methods like __add__ and .mean()
+
+    Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine unaltered (normally
+    because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new
+    tree) and some will get overridden by the class definition of DataTree.
+    """
+    _wrap_then_attach_to_cls(vars(), DatasetArithmetic, _ARITHMETIC_METHODS_TO_WRAP, wrap_func=map_over_subtree)


-class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):
+class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmetic):
     """
     A tree-like hierarchical collection of xarray objects.

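The practical effect of the three mixins is that familiar Dataset calls now act over the whole tree. A hedged sketch, again assuming the module is importable as `datatree` (the exact behaviour of this prototype may differ):

import xarray as xr
from datatree import DataTree

dt = DataTree(data_objects={"node": xr.Dataset({"a": ("t", [1.0, 2.0, 3.0])})})

first = dt.isel(t=0)             # Dataset.isel, wrapped by map_over_subtree, applied at every node
renamed = dt.rename({"a": "b"})  # likewise for the other mapped Dataset methods
# dt.isel.__doc__ now carries the _MAPPED_DOCSTRING_ADDENDUM note about mapping over the subtree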
@@ -312,6 +341,8 @@ class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):

     # TODO currently allows self.ds = None, should we instead always store at least an empty Dataset?

+    # TODO dataset methods which should not or cannot act over the whole tree, such as .to_array
+
     def __init__(
         self,
         data_objects: Dict[PathType, Union[Dataset, DataArray]] = None,