from __future__ import annotations
import functools
import textwrap
+ import inspect

from typing import Mapping, Hashable, Union, List, Any, Callable, Iterable, Dict

from xarray.core.variable import Variable
from xarray.core.combine import merge
from xarray.core import dtypes, utils
+ from xarray.core._typed_ops import DatasetOpsMixin

from .treenode import TreeNode, PathType, _init_single_treenode
|   |   Variable("far_infrared")
|-- DataNode("topography")
|   |-- DataNode("elevation")
- |   |   |-- Variable("height_above_sea_level")
+ |   |   Variable("height_above_sea_level")
|-- DataNode("population")
"""
@@ -75,7 +77,6 @@ def _map_over_subtree(tree, *args, **kwargs):
    """Internal function which maps func over every node in tree, returning a tree of the results."""

    # Recreate and act on root node
-    # TODO make this of class DataTree
    out_tree = DataNode(name=tree.name, data=tree.ds)
    if out_tree.has_data:
        out_tree.ds = func(out_tree.ds, *args, **kwargs)
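
For orientation, a minimal sketch of how the `map_over_subtree` decorator backed by this helper is meant to be used. The tree and variables below are made up, and the import line assumes both names are exported by the package:

```
import xarray as xr
from datatree import DataNode, map_over_subtree  # assumed public exports

# Hypothetical two-node tree
dt = DataNode(name="root", data=xr.Dataset({"a": ("x", [1, 2, 3])}))
DataNode(name="child", data=xr.Dataset({"b": ("x", [4, 5, 6])}), parent=dt)

@map_over_subtree
def standardise(ds):
    # applied to the Dataset stored in every node that has data
    return (ds - ds.mean()) / ds.std()

result = standardise(dt)  # a new tree of the same structure, per the docstring above
```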
@@ -132,14 +133,82 @@ def attrs(self):
        else:
            raise AttributeError("property is not defined for a node with no data")

+     # TODO .loc
+
    dims.__doc__ = Dataset.dims.__doc__
    variables.__doc__ = Dataset.variables.__doc__
    encoding.__doc__ = Dataset.encoding.__doc__
    sizes.__doc__ = Dataset.sizes.__doc__
    attrs.__doc__ = Dataset.attrs.__doc__


- class DataTree(TreeNode, DatasetPropertiesMixin):
+ _MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
+                                            "call the method on the Datasets stored in every node of the subtree. "
+                                            "See the `map_over_subtree` decorator for more details.", width=117)
+
+
+ def _expose_methods_wrapped_to_map_over_subtree(obj, method_name, method):
+     """
+     Expose given method on node object, but wrapped to map over whole subtree, not just that node object.
+
+     Result is like having written this in obj's class definition:
+
+     ```
+     @map_over_subtree
+     def method_name(self, *args, **kwargs):
+         return self.method(*args, **kwargs)
+     ```
+     """
+
+     # Expose Dataset method, but wrapped to map over whole subtree when called
+     # TODO should we be using functools.partialmethod here instead?
+     mapped_over_tree = functools.partial(map_over_subtree(method), obj)
+     setattr(obj, method_name, mapped_over_tree)
+
+     # TODO do we really need this for ops like __add__?
+     # Add a line to the method's docstring explaining how it's been mapped
+     method_docstring = method.__doc__
+     if method_docstring is not None:
+         updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
+         obj_method = getattr(obj, method_name)
+         setattr(obj_method, '__doc__', updated_method_docstring)
+
+
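
For illustration, a rough sketch of what this helper effectively does for one method on one node; the node and data are hypothetical, and the snippet assumes the module's own imports (`Dataset`, `functools`, the helpers above). It simply binds a `functools.partial` of the mapped method onto the instance:

```
import xarray as xr

node = DataNode(name="root", data=xr.Dataset({"a": ("x", [1, 2, 3])}))

# after this, node.isel is functools.partial(map_over_subtree(Dataset.isel), node)
_expose_methods_wrapped_to_map_over_subtree(node, "isel", Dataset.isel)

# so calling it applies Dataset.isel to the Dataset in every node of node's subtree
subset_tree = node.isel(x=0)
```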
+ # TODO equals, broadcast_equals etc.
+ # TODO do dask-related private methods need to be exposed?
+ _DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
+ _DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', '__contains__', '__len__',
+                               '__bool__', '__iter__', '__array__', 'set_coords', 'reset_coords', 'info',
+                               'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
+                               'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
+                               'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
+                               'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
+                               'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
+                               'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
+                               'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
+                               'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
+ _DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
+ _ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
+
+ # TODO methods which should not or cannot act over the whole tree, such as .to_array
+
+
+ class DatasetMethodsMixin:
+     """Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree."""
+
+     # TODO is there a way to put this code in the class definition so we don't have to specifically call this method?
+     def _add_dataset_methods(self):
+         methods_to_expose = [(method_name, getattr(Dataset, method_name))
+                              for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
+
+         for method_name, method in methods_to_expose:
+             _expose_methods_wrapped_to_map_over_subtree(self, method_name, method)
+
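
The intended end-user effect, as a hedged sketch: once `_add_dataset_methods` has run on a tree, the listed Dataset methods are callable on the tree and act node-by-node. The tree `dt` and its dimensions here are invented:

```
# dt: a DataTree whose nodes hold Datasets with 'time' and 'x' dims (hypothetical)
first_step = dt.isel(time=0)      # Dataset.isel applied to every node's Dataset
window = dt.sel(x=slice(0, 10))   # likewise for Dataset.sel
chunked = dt.chunk({"time": 10})  # dask methods from _DATASET_DASK_METHODS_TO_EXPOSE are wrapped too
```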
+
+ # TODO implement ArrayReduce type methods
+
+
+ class DataTree(TreeNode, DatasetPropertiesMixin, DatasetMethodsMixin):
     """
     A tree-like hierarchical collection of xarray objects.

@@ -178,14 +247,6 @@ class DataTree(TreeNode, DatasetPropertiesMixin):
    # TODO do we need a watch out for if methods intended only for root nodes are called on non-root nodes?

    # TODO add any other properties (maybe dask ones?)
-    _DS_PROPERTIES = ['variables', 'attrs', 'encoding', 'dims', 'sizes']
-
-    # TODO add all the other methods to dispatch
-    _DS_METHODS_TO_MAP_OVER_SUBTREES = ['isel', 'sel', 'min', 'max', 'mean', '__array_ufunc__']
-    _MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
-                                               "call the method on the Datasets stored in every node of the subtree. "
-                                               "See the datatree.map_over_subtree decorator for more details.",
-                                               width=117)

    # TODO currently allows self.ds = None, should we instead always store at least an empty Dataset?

@@ -218,24 +279,14 @@ def __init__(
            new_node = self.get_node(path)
            new_node[path] = data

-        self._add_method_api()
-
-    def _add_method_api(self):
-        # Add methods defined in Dataset's class definition to this classes API, but wrapped to map over descendants too
-        for method_name in self._DS_METHODS_TO_MAP_OVER_SUBTREES:
-            # Expose Dataset method, but wrapped to map over whole subtree
-            ds_method = getattr(Dataset, method_name)
-            setattr(self, method_name, map_over_subtree(ds_method))
-
-            # Add a line to the method's docstring explaining how it's been mapped
-            ds_method_docstring = getattr(Dataset, f'{method_name}').__doc__
-            if ds_method_docstring is not None:
-                updated_method_docstring = ds_method_docstring.replace('\n', self._MAPPED_DOCSTRING_ADDENDUM, 1)
-                setattr(self, f'{method_name}.__doc__', updated_method_docstring)
+        # TODO this has to be
+        self._add_all_dataset_api()

-    # TODO wrap methods for ops too, such as those in DatasetOpsMixin
+    def _add_all_dataset_api(self):
+        # Add methods like .mean(), but wrapped to map over subtrees
+        self._add_dataset_methods()

-    # TODO map applied ufuncs over all leaves
+        # TODO add dataset ops here

    @property
    def ds(self) -> Dataset:
@@ -257,7 +308,7 @@ def has_data(self):
    def _init_single_datatree_node(
        cls,
        name: Hashable,
-        data: Dataset = None,
+        data: Union[Dataset, DataArray] = None,
        parent: TreeNode = None,
        children: List[TreeNode] = None,
    ):
@@ -285,6 +336,9 @@ def _init_single_datatree_node(
        obj = object.__new__(cls)
        obj = _init_single_treenode(obj, name=name, parent=parent, children=children)
        obj.ds = data
+
+        obj._add_all_dataset_api()
+
        return obj

    def __str__(self):
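
Since `data` is now typed `Union[Dataset, DataArray]`, a node could presumably be created straight from a DataArray. A speculative sketch only: the exact DataArray handling is not shown in this hunk, and the objects below are made up:

```
import xarray as xr

da = xr.DataArray([10.0, 20.0], dims="x", name="height_above_sea_level")
topo = DataNode(name="topography")                        # node with no data
elev = DataNode(name="elevation", data=da, parent=topo)   # DataArray accepted by the new signature
```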
@@ -559,13 +613,6 @@ def get_any(self, *tags: Hashable) -> DataTree:
                             if any(tag in c.tags for tag in tags)}
        return DataTree(data_objects=matching_children)

-    @property
-    def chunks(self):
-        raise NotImplementedError
-
-    def chunk(self):
-        raise NotImplementedError
-
    def merge(self, datatree: DataTree) -> DataTree:
        """Merge all the leaves of a second DataTree into this one."""
        raise NotImplementedError