Skip to content

Commit 631aabf

Browse files
authored
Merge xarray-contrib/datatree#10 from TomNicholas/expose_dataset_ops
Expose dataset reduce operations
2 parents 4cd27bd + fecc2e0 commit 631aabf

File tree

2 files changed

+129
-50
lines changed

2 files changed

+129
-50
lines changed

datatree/datatree.py

+74-43
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22
import functools
33
import textwrap
4-
import inspect
54

65
from typing import Mapping, Hashable, Union, List, Any, Callable, Iterable, Dict
76

@@ -12,6 +11,9 @@
1211
from xarray.core.variable import Variable
1312
from xarray.core.combine import merge
1413
from xarray.core import dtypes, utils
14+
from xarray.core.common import DataWithCoords
15+
from xarray.core.arithmetic import DatasetArithmetic
16+
from xarray.core.ops import REDUCE_METHODS, NAN_REDUCE_METHODS, NAN_CUM_METHODS
1517

1618
from .treenode import TreeNode, PathType, _init_single_treenode
1719

@@ -46,7 +48,8 @@ def map_over_subtree(func):
4648
The function will be applied to any dataset stored in this node, as well as any dataset stored in any of the
4749
descendant nodes. The returned tree will have the same structure as the original subtree.
4850
49-
func needs to return a Dataset in order to rebuild the subtree.
51+
func needs to return a Dataset, DataArray, or None in order to be able to rebuild the subtree after mapping, as each
52+
result will be assigned to its respective node of the new tree via `DataTree.__setitem__`.
5053
5154
Parameters
5255
----------
@@ -132,7 +135,6 @@ def attrs(self):
132135
else:
133136
raise AttributeError("property is not defined for a node with no data")
134137

135-
136138
@property
137139
def nbytes(self) -> int:
138140
return sum(node.ds.nbytes for node in self.subtree_nodes)
@@ -203,10 +205,33 @@ def imag(self):
203205

204206
_MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
205207
"call the method on the Datasets stored in every node of the subtree. "
206-
"See the `map_over_subtree` decorator for more details.", width=117)
207-
208-
209-
def _wrap_then_attach_to_cls(cls_dict, methods_to_expose, wrap_func=None):
208+
"See the `map_over_subtree` function for more details.", width=117)
209+
210+
# TODO equals, broadcast_equals etc.
211+
# TODO do dask-related private methods need to be exposed?
212+
_DATASET_DASK_METHODS_TO_MAP = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
213+
_DATASET_METHODS_TO_MAP = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
214+
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
215+
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
216+
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
217+
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
218+
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
219+
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
220+
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
221+
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
222+
# TODO unsure if these are called by external functions or not?
223+
_DATASET_OPS_TO_MAP = ['_unary_op', '_binary_op', '_inplace_binary_op']
224+
_ALL_DATASET_METHODS_TO_MAP = _DATASET_DASK_METHODS_TO_MAP + _DATASET_METHODS_TO_MAP + _DATASET_OPS_TO_MAP
225+
226+
_DATA_WITH_COORDS_METHODS_TO_MAP = ['squeeze', 'clip', 'assign_coords', 'where', 'close', 'isnull', 'notnull',
227+
'isin', 'astype']
228+
229+
# TODO NUM_BINARY_OPS apparently aren't defined on DatasetArithmetic, and don't appear to be injected anywhere...
230+
#['__array_ufunc__'] \
231+
_ARITHMETIC_METHODS_TO_WRAP = REDUCE_METHODS + NAN_REDUCE_METHODS + NAN_CUM_METHODS
232+
233+
234+
def _wrap_then_attach_to_cls(target_cls_dict, source_cls, methods_to_set, wrap_func=None):
210235
"""
211236
Attach given methods on a class, and optionally wrap each method first. (i.e. with map_over_subtree)
212237
@@ -219,23 +244,32 @@ def method_name(self, *args, **kwargs):
219244
220245
Parameters
221246
----------
222-
cls_dict
223-
The __dict__ attribute of a class, which can also be accessed by calling vars() from within that classes'
224-
definition.
225-
methods_to_expose : Iterable[Tuple[str, callable]]
226-
The method names and definitions supplied as a list of (method_name_string, method) pairs.\
247+
target_cls_dict : MappingProxy
248+
The __dict__ attribute of the class which we want the methods to be added to. (The __dict__ attribute can also
249+
be accessed by calling vars() from within that class's definition.) This will be updated by this function.
250+
source_cls : class
251+
Class object from which we want to copy methods (and optionally wrap them). Should be the actual class object
252+
(or instance), not just the __dict__.
253+
methods_to_set : Iterable[str]
253+
The names of the methods to copy from source_cls, supplied as an iterable of method-name strings.
227255
Each name is looked up on source_cls via getattr().
256+
wrap_func : callable, optional
257+
Function to decorate each method with. Must have the same return type as the method.
228258
"""
229-
for method_name, method in methods_to_expose:
230-
wrapped_method = wrap_func(method) if wrap_func is not None else method
231-
cls_dict[method_name] = wrapped_method
232-
233-
# TODO do we really need this for ops like __add__?
234-
# Add a line to the method's docstring explaining how it's been mapped
235-
method_docstring = method.__doc__
236-
if method_docstring is not None:
237-
updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
238-
setattr(cls_dict[method_name], '__doc__', updated_method_docstring)
259+
for method_name in methods_to_set:
260+
orig_method = getattr(source_cls, method_name)
261+
wrapped_method = wrap_func(orig_method) if wrap_func is not None else orig_method
262+
target_cls_dict[method_name] = wrapped_method
263+
264+
if wrap_func is map_over_subtree:
265+
# Add a paragraph to the method's docstring explaining how it's been mapped
266+
orig_method_docstring = orig_method.__doc__
267+
if orig_method_docstring is not None:
268+
if '\n' in orig_method_docstring:
269+
new_method_docstring = orig_method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
270+
else:
271+
new_method_docstring = orig_method_docstring + f"\n\n{_MAPPED_DOCSTRING_ADDENDUM}"
272+
setattr(target_cls_dict[method_name], '__doc__', new_method_docstring)
239273

240274

241275
class MappedDatasetMethodsMixin:
@@ -244,33 +278,28 @@ class MappedDatasetMethodsMixin:
244278
245279
Every method wrapped here needs to have a return value of Dataset or DataArray in order to construct a new tree.
246280
"""
247-
# TODO equals, broadcast_equals etc.
248-
# TODO do dask-related private methods need to be exposed?
249-
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
250-
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', 'set_coords', 'reset_coords', 'info',
251-
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
252-
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
253-
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
254-
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
255-
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
256-
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
257-
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
258-
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
259-
# TODO unsure if these are called by external functions or not?
260-
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
261-
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
281+
__slots__ = ()
282+
_wrap_then_attach_to_cls(vars(), Dataset, _ALL_DATASET_METHODS_TO_MAP, wrap_func=map_over_subtree)
262283

263-
# TODO methods which should not or cannot act over the whole tree, such as .to_array
264284

265-
methods_to_expose = [(method_name, getattr(Dataset, method_name))
266-
for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
267-
_wrap_then_attach_to_cls(vars(), methods_to_expose, wrap_func=map_over_subtree)
285+
class MappedDataWithCoords(DataWithCoords):
286+
# TODO add mapped versions of groupby, weighted, rolling, rolling_exp, coarsen, resample
287+
# TODO re-implement AttrsAccessMixin stuff so that it includes access to child nodes
288+
_wrap_then_attach_to_cls(vars(), DataWithCoords, _DATA_WITH_COORDS_METHODS_TO_MAP, wrap_func=map_over_subtree)
268289

269290

270-
# TODO implement ArrayReduce type methods
291+
class DataTreeArithmetic(DatasetArithmetic):
292+
"""
293+
Mixin to add Dataset methods like __add__ and .mean()
294+
295+
Some of these methods must be wrapped to map over all nodes in the subtree. Others are fine unaltered (normally
296+
because they (a) only call dataset properties and (b) don't return a dataset that should be nested into a new
297+
tree) and some will get overridden by the class definition of DataTree.
298+
"""
299+
_wrap_then_attach_to_cls(vars(), DatasetArithmetic, _ARITHMETIC_METHODS_TO_WRAP, wrap_func=map_over_subtree)
271300

272301

273-
class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):
302+
class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin, MappedDataWithCoords, DataTreeArithmetic):
274303
"""
275304
A tree-like hierarchical collection of xarray objects.
276305
@@ -312,6 +341,8 @@ class DataTree(TreeNode, DatasetPropertiesMixin, MappedDatasetMethodsMixin):
312341

313342
# TODO currently allows self.ds = None, should we instead always store at least an empty Dataset?
314343

344+
# TODO dataset methods which should not or cannot act over the whole tree, such as .to_array
345+
315346
def __init__(
316347
self,
317348
data_objects: Dict[PathType, Union[Dataset, DataArray]] = None,

datatree/tests/test_dataset_api.py

+55-7
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def test_properties(self):
8080
assert dt.sizes == dt.ds.sizes
8181
assert dt.variables == dt.ds.variables
8282

83-
8483
def test_no_data_no_properties(self):
8584
dt = DataNode('root', data=None)
8685
with pytest.raises(AttributeError):
@@ -96,24 +95,73 @@ def test_no_data_no_properties(self):
9695

9796

9897
class TestDSMethodInheritance:
99-
def test_root(self):
98+
def test_dataset_method(self):
99+
# test root
100100
da = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
101101
dt = DataNode('root', data=da)
102102
expected_ds = da.to_dataset().isel(x=1)
103103
result_ds = dt.isel(x=1).ds
104104
assert_equal(result_ds, expected_ds)
105105

106-
def test_descendants(self):
107-
da = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
108-
dt = DataNode('root')
106+
# test descendant
109107
DataNode('results', parent=dt, data=da)
110-
expected_ds = da.to_dataset().isel(x=1)
111108
result_ds = dt.isel(x=1)['results'].ds
112109
assert_equal(result_ds, expected_ds)
113110

111+
def test_reduce_method(self):
112+
# test root
113+
da = xr.DataArray(name='a', data=[False, True, False], dims='x')
114+
dt = DataNode('root', data=da)
115+
expected_ds = da.to_dataset().any()
116+
result_ds = dt.any().ds
117+
assert_equal(result_ds, expected_ds)
118+
119+
# test descendant
120+
DataNode('results', parent=dt, data=da)
121+
result_ds = dt.any()['results'].ds
122+
assert_equal(result_ds, expected_ds)
123+
124+
def test_nan_reduce_method(self):
125+
# test root
126+
da = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
127+
dt = DataNode('root', data=da)
128+
expected_ds = da.to_dataset().mean()
129+
result_ds = dt.mean().ds
130+
assert_equal(result_ds, expected_ds)
131+
132+
# test descendant
133+
DataNode('results', parent=dt, data=da)
134+
result_ds = dt.mean()['results'].ds
135+
assert_equal(result_ds, expected_ds)
136+
137+
def test_cum_method(self):
138+
# test root
139+
da = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
140+
dt = DataNode('root', data=da)
141+
expected_ds = da.to_dataset().cumsum()
142+
result_ds = dt.cumsum().ds
143+
assert_equal(result_ds, expected_ds)
144+
145+
# test descendant
146+
DataNode('results', parent=dt, data=da)
147+
result_ds = dt.cumsum()['results'].ds
148+
assert_equal(result_ds, expected_ds)
149+
114150

115151
class TestOps:
116-
...
152+
@pytest.mark.xfail
153+
def test_binary_op(self):
154+
ds1 = xr.Dataset({'a': [5], 'b': [3]})
155+
ds2 = xr.Dataset({'x': [0.1, 0.2], 'y': [10, 20]})
156+
dt = DataNode('root', data=ds1)
157+
DataNode('subnode', data=ds2, parent=dt)
158+
159+
expected_root = DataNode('root', data=ds1*ds1)
160+
expected_descendant = DataNode('subnode', data=ds2*ds2, parent=expected_root)
161+
result = dt * dt
162+
163+
assert_equal(result.ds, expected_root.ds)
164+
assert_equal(result['subnode'].ds, expected_descendant.ds)
117165

118166

119167
@pytest.mark.xfail

0 commit comments

Comments
 (0)