Skip to content

Commit 1af0273

Browse files
authored
Merge xarray-contrib/datatree#13 from TomNicholas/expose_dataset_methods
Expose the methods that are explicitly defined in the Dataset class definition
2 parents 1723261 + 2de7f2e commit 1af0273

File tree

2 files changed

+116
-40
lines changed

2 files changed

+116
-40
lines changed

datatree/datatree.py

+82-35
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22
import functools
33
import textwrap
4+
import inspect
45

56
from typing import Mapping, Hashable, Union, List, Any, Callable, Iterable, Dict
67

@@ -11,6 +12,7 @@
1112
from xarray.core.variable import Variable
1213
from xarray.core.combine import merge
1314
from xarray.core import dtypes, utils
15+
from xarray.core._typed_ops import DatasetOpsMixin
1416

1517
from .treenode import TreeNode, PathType, _init_single_treenode
1618

@@ -31,7 +33,7 @@
3133
| | Variable("far_infrared")
3234
|-- DataNode("topography")
3335
| |-- DataNode("elevation")
34-
| | |-- Variable("height_above_sea_level")
36+
| | Variable("height_above_sea_level")
3537
|-- DataNode("population")
3638
"""
3739

@@ -75,7 +77,6 @@ def _map_over_subtree(tree, *args, **kwargs):
7577
"""Internal function which maps func over every node in tree, returning a tree of the results."""
7678

7779
# Recreate and act on root node
78-
# TODO make this of class DataTree
7980
out_tree = DataNode(name=tree.name, data=tree.ds)
8081
if out_tree.has_data:
8182
out_tree.ds = func(out_tree.ds, *args, **kwargs)
@@ -132,14 +133,82 @@ def attrs(self):
132133
else:
133134
raise AttributeError("property is not defined for a node with no data")
134135

136+
# TODO .loc
137+
135138
dims.__doc__ = Dataset.dims.__doc__
136139
variables.__doc__ = Dataset.variables.__doc__
137140
encoding.__doc__ = Dataset.encoding.__doc__
138141
sizes.__doc__ = Dataset.sizes.__doc__
139142
attrs.__doc__ = Dataset.attrs.__doc__
140143

141144

142-
class DataTree(TreeNode, DatasetPropertiesMixin):
145+
_MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
146+
"call the method on the Datasets stored in every node of the subtree. "
147+
"See the `map_over_subtree` decorator for more details.", width=117)
148+
149+
150+
def _expose_methods_wrapped_to_map_over_subtree(obj, method_name, method):
151+
"""
152+
Expose given method on node object, but wrapped to map over whole subtree, not just that node object.
153+
154+
Result is like having written this in obj's class definition:
155+
156+
```
157+
@map_over_subtree
158+
def method_name(self, *args, **kwargs):
159+
return self.method(*args, **kwargs)
160+
```
161+
"""
162+
163+
# Expose Dataset method, but wrapped to map over whole subtree when called
164+
# TODO should we be using functools.partialmethod here instead?
165+
mapped_over_tree = functools.partial(map_over_subtree(method), obj)
166+
setattr(obj, method_name, mapped_over_tree)
167+
168+
# TODO do we really need this for ops like __add__?
169+
# Add a line to the method's docstring explaining how it's been mapped
170+
method_docstring = method.__doc__
171+
if method_docstring is not None:
172+
updated_method_docstring = method_docstring.replace('\n', _MAPPED_DOCSTRING_ADDENDUM, 1)
173+
obj_method = getattr(obj, method_name)
174+
setattr(obj_method, '__doc__', updated_method_docstring)
175+
176+
177+
# TODO equals, broadcast_equals etc.
178+
# TODO do dask-related private methods need to be exposed?
179+
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
180+
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', '__contains__', '__len__',
181+
'__bool__', '__iter__', '__array__', 'set_coords', 'reset_coords', 'info',
182+
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
183+
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
184+
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
185+
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
186+
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
187+
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
188+
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
189+
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
190+
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
191+
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE
192+
193+
# TODO methods which should not or cannot act over the whole tree, such as .to_array
194+
195+
196+
class DatasetMethodsMixin:
197+
"""Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree."""
198+
199+
# TODO is there a way to put this code in the class definition so we don't have to specifically call this method?
200+
def _add_dataset_methods(self):
201+
methods_to_expose = [(method_name, getattr(Dataset, method_name))
202+
for method_name in _ALL_DATASET_METHODS_TO_EXPOSE]
203+
204+
for method_name, method in methods_to_expose:
205+
_expose_methods_wrapped_to_map_over_subtree(self, method_name, method)
206+
207+
208+
# TODO implement ArrayReduce type methods
209+
210+
211+
class DataTree(TreeNode, DatasetPropertiesMixin, DatasetMethodsMixin):
143212
"""
144213
A tree-like hierarchical collection of xarray objects.
145214
@@ -178,14 +247,6 @@ class DataTree(TreeNode, DatasetPropertiesMixin):
178247
# TODO do we need a watch out for if methods intended only for root nodes are called on non-root nodes?
179248

180249
# TODO add any other properties (maybe dask ones?)
181-
_DS_PROPERTIES = ['variables', 'attrs', 'encoding', 'dims', 'sizes']
182-
183-
# TODO add all the other methods to dispatch
184-
_DS_METHODS_TO_MAP_OVER_SUBTREES = ['isel', 'sel', 'min', 'max', 'mean', '__array_ufunc__']
185-
_MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
186-
"call the method on the Datasets stored in every node of the subtree. "
187-
"See the datatree.map_over_subtree decorator for more details.",
188-
width=117)
189250

190251
# TODO currently allows self.ds = None, should we instead always store at least an empty Dataset?
191252

@@ -218,24 +279,14 @@ def __init__(
218279
new_node = self.get_node(path)
219280
new_node[path] = data
220281

221-
self._add_method_api()
222-
223-
def _add_method_api(self):
224-
# Add methods defined in Dataset's class definition to this classes API, but wrapped to map over descendants too
225-
for method_name in self._DS_METHODS_TO_MAP_OVER_SUBTREES:
226-
# Expose Dataset method, but wrapped to map over whole subtree
227-
ds_method = getattr(Dataset, method_name)
228-
setattr(self, method_name, map_over_subtree(ds_method))
229-
230-
# Add a line to the method's docstring explaining how it's been mapped
231-
ds_method_docstring = getattr(Dataset, f'{method_name}').__doc__
232-
if ds_method_docstring is not None:
233-
updated_method_docstring = ds_method_docstring.replace('\n', self._MAPPED_DOCSTRING_ADDENDUM, 1)
234-
setattr(self, f'{method_name}.__doc__', updated_method_docstring)
282+
# TODO this has to be
283+
self._add_all_dataset_api()
235284

236-
# TODO wrap methods for ops too, such as those in DatasetOpsMixin
285+
def _add_all_dataset_api(self):
286+
# Add methods like .mean(), but wrapped to map over subtrees
287+
self._add_dataset_methods()
237288

238-
# TODO map applied ufuncs over all leaves
289+
# TODO add dataset ops here
239290

240291
@property
241292
def ds(self) -> Dataset:
@@ -257,7 +308,7 @@ def has_data(self):
257308
def _init_single_datatree_node(
258309
cls,
259310
name: Hashable,
260-
data: Dataset = None,
311+
data: Union[Dataset, DataArray] = None,
261312
parent: TreeNode = None,
262313
children: List[TreeNode] = None,
263314
):
@@ -285,6 +336,9 @@ def _init_single_datatree_node(
285336
obj = object.__new__(cls)
286337
obj = _init_single_treenode(obj, name=name, parent=parent, children=children)
287338
obj.ds = data
339+
340+
obj._add_all_dataset_api()
341+
288342
return obj
289343

290344
def __str__(self):
@@ -559,13 +613,6 @@ def get_any(self, *tags: Hashable) -> DataTree:
559613
if any(tag in c.tags for tag in tags)}
560614
return DataTree(data_objects=matching_children)
561615

562-
@property
563-
def chunks(self):
564-
raise NotImplementedError
565-
566-
def chunk(self):
567-
raise NotImplementedError
568-
569616
def merge(self, datatree: DataTree) -> DataTree:
570617
"""Merge all the leaves of a second DataTree into this one."""
571618
raise NotImplementedError

datatree/tests/test_dataset_api.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import pytest
22

3+
import numpy as np
4+
35
import xarray as xr
46
from xarray.testing import assert_equal
57

@@ -93,12 +95,39 @@ def test_no_data_no_properties(self):
9395

9496

9597
class TestDSMethodInheritance:
98+
def test_root(self):
99+
da = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
100+
dt = DataNode('root', data=da)
101+
expected_ds = da.to_dataset().isel(x=1)
102+
result_ds = dt.isel(x=1).ds
103+
assert_equal(result_ds, expected_ds)
104+
105+
def test_descendants(self):
106+
da = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
107+
dt = DataNode('root')
108+
DataNode('results', parent=dt, data=da)
109+
expected_ds = da.to_dataset().isel(x=1)
110+
result_ds = dt.isel(x=1)['results'].ds
111+
assert_equal(result_ds, expected_ds)
112+
113+
114+
class TestOps:
96115
...
97116

98117

99-
class TestBinaryOps:
100-
...
101-
102-
118+
@pytest.mark.xfail
103119
class TestUFuncs:
104-
...
120+
def test_root(self):
121+
da = xr.DataArray(name='a', data=[1, 2, 3])
122+
dt = DataNode('root', data=da)
123+
expected_ds = np.sin(da.to_dataset())
124+
result_ds = np.sin(dt).ds
125+
assert_equal(result_ds, expected_ds)
126+
127+
def test_descendants(self):
128+
da = xr.DataArray(name='a', data=[1, 2, 3])
129+
dt = DataNode('root')
130+
DataNode('results', parent=dt, data=da)
131+
expected_ds = np.sin(da.to_dataset())
132+
result_ds = np.sin(dt)['results'].ds
133+
assert_equal(result_ds, expected_ds)

0 commit comments

Comments
 (0)