Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Expose dataset methods #13

Merged
merged 5 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 82 additions & 35 deletions datatree/datatree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
import functools
import textwrap
import inspect

from typing import Mapping, Hashable, Union, List, Any, Callable, Iterable, Dict

Expand All @@ -11,6 +12,7 @@
from xarray.core.variable import Variable
from xarray.core.combine import merge
from xarray.core import dtypes, utils
from xarray.core._typed_ops import DatasetOpsMixin

from .treenode import TreeNode, PathType, _init_single_treenode

Expand All @@ -31,7 +33,7 @@
| | Variable("far_infrared")
|-- DataNode("topography")
| |-- DataNode("elevation")
| | |-- Variable("height_above_sea_level")
| | Variable("height_above_sea_level")
|-- DataNode("population")
"""

Expand Down Expand Up @@ -75,7 +77,6 @@ def _map_over_subtree(tree, *args, **kwargs):
"""Internal function which maps func over every node in tree, returning a tree of the results."""

# Recreate and act on root node
# TODO make this of class DataTree
out_tree = DataNode(name=tree.name, data=tree.ds)
if out_tree.has_data:
out_tree.ds = func(out_tree.ds, *args, **kwargs)
Expand Down Expand Up @@ -132,14 +133,82 @@ def attrs(self):
else:
raise AttributeError("property is not defined for a node with no data")

# TODO .loc

dims.__doc__ = Dataset.dims.__doc__
variables.__doc__ = Dataset.variables.__doc__
encoding.__doc__ = Dataset.encoding.__doc__
sizes.__doc__ = Dataset.sizes.__doc__
attrs.__doc__ = Dataset.attrs.__doc__


class DataTree(TreeNode, DatasetPropertiesMixin):
_MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
"call the method on the Datasets stored in every node of the subtree. "
"See the `map_over_subtree` decorator for more details.", width=117)


def _expose_methods_wrapped_to_map_over_subtree(obj, method_name, method):
    """
    Expose given method on node object, but wrapped to map over whole subtree, not just that node object.

    Result is like having written this in obj's class definition:

    ```
    @map_over_subtree
    def method_name(self, *args, **kwargs):
        return self.method(*args, **kwargs)
    ```

    Parameters
    ----------
    obj : object
        Node instance on which the wrapped method is exposed (set as an instance attribute).
    method_name : str
        Attribute name under which the wrapped method appears on ``obj``.
    method : callable
        Unbound ``xarray.Dataset`` method to wrap.
    """

    # Expose Dataset method, but wrapped to map over whole subtree when called
    # TODO should we be using functools.partialmethod here instead?
    mapped_over_tree = functools.partial(map_over_subtree(method), obj)

    # TODO do we really need this for ops like __add__?
    # Add a paragraph to the method's docstring explaining how it's been mapped.
    # Bug fix: the previous code replaced the first newline *with* the addendum
    # (`replace('\n', addendum, 1)`), deleting that newline and fusing the summary
    # line, the addendum, and the rest of the docstring into one run-on paragraph.
    # Insert the addendum as its own paragraph after the summary line instead.
    method_docstring = method.__doc__
    if method_docstring is not None:
        mapped_over_tree.__doc__ = method_docstring.replace(
            '\n', '\n\n' + _MAPPED_DOCSTRING_ADDENDUM + '\n\n', 1
        )

    setattr(obj, method_name, mapped_over_tree)


# TODO equals, broadcast_equals etc.
# TODO do dask-related private methods need to be exposed?
_DATASET_DASK_METHODS_TO_EXPOSE = ['load', 'compute', 'persist', 'unify_chunks', 'chunk', 'map_blocks']
_DATASET_METHODS_TO_EXPOSE = ['copy', 'as_numpy', '__copy__', '__deepcopy__', '__contains__', '__len__',
'__bool__', '__iter__', '__array__', 'set_coords', 'reset_coords', 'info',
'isel', 'sel', 'head', 'tail', 'thin', 'broadcast_like', 'reindex_like',
'reindex', 'interp', 'interp_like', 'rename', 'rename_dims', 'rename_vars',
'swap_dims', 'expand_dims', 'set_index', 'reset_index', 'reorder_levels', 'stack',
'unstack', 'update', 'merge', 'drop_vars', 'drop_sel', 'drop_isel', 'drop_dims',
'transpose', 'dropna', 'fillna', 'interpolate_na', 'ffill', 'bfill', 'combine_first',
'reduce', 'map', 'assign', 'diff', 'shift', 'roll', 'sortby', 'quantile', 'rank',
'differentiate', 'integrate', 'cumulative_integrate', 'filter_by_attrs', 'polyfit',
'pad', 'idxmin', 'idxmax', 'argmin', 'argmax', 'query', 'curvefit']
_DATASET_OPS_TO_EXPOSE = ['_unary_op', '_binary_op', '_inplace_binary_op']
_ALL_DATASET_METHODS_TO_EXPOSE = _DATASET_DASK_METHODS_TO_EXPOSE + _DATASET_METHODS_TO_EXPOSE + _DATASET_OPS_TO_EXPOSE

# TODO methods which should not or cannot act over the whole tree, such as .to_array


class DatasetMethodsMixin:
    """Mixin to add Dataset methods like .mean(), but wrapped to map over all nodes in the subtree."""

    # TODO is there a way to put this code in the class definition so we don't have to specifically call this method?
    def _add_dataset_methods(self):
        # Attach every whitelisted Dataset method to this node, wrapped so that
        # calling it maps over the whole subtree rather than just this node.
        for name in _ALL_DATASET_METHODS_TO_EXPOSE:
            _expose_methods_wrapped_to_map_over_subtree(self, name, getattr(Dataset, name))


# TODO implement ArrayReduce type methods


class DataTree(TreeNode, DatasetPropertiesMixin, DatasetMethodsMixin):
"""
A tree-like hierarchical collection of xarray objects.

Expand Down Expand Up @@ -178,14 +247,6 @@ class DataTree(TreeNode, DatasetPropertiesMixin):
# TODO do we need a watch out for if methods intended only for root nodes are called on non-root nodes?

# TODO add any other properties (maybe dask ones?)
_DS_PROPERTIES = ['variables', 'attrs', 'encoding', 'dims', 'sizes']

# TODO add all the other methods to dispatch
_DS_METHODS_TO_MAP_OVER_SUBTREES = ['isel', 'sel', 'min', 'max', 'mean', '__array_ufunc__']
_MAPPED_DOCSTRING_ADDENDUM = textwrap.fill("This method was copied from xarray.Dataset, but has been altered to "
"call the method on the Datasets stored in every node of the subtree. "
"See the datatree.map_over_subtree decorator for more details.",
width=117)

# TODO currently allows self.ds = None, should we instead always store at least an empty Dataset?

Expand Down Expand Up @@ -218,24 +279,14 @@ def __init__(
new_node = self.get_node(path)
new_node[path] = data

self._add_method_api()

def _add_method_api(self):
# Add methods defined in Dataset's class definition to this classes API, but wrapped to map over descendants too
for method_name in self._DS_METHODS_TO_MAP_OVER_SUBTREES:
# Expose Dataset method, but wrapped to map over whole subtree
ds_method = getattr(Dataset, method_name)
setattr(self, method_name, map_over_subtree(ds_method))

# Add a line to the method's docstring explaining how it's been mapped
ds_method_docstring = getattr(Dataset, f'{method_name}').__doc__
if ds_method_docstring is not None:
updated_method_docstring = ds_method_docstring.replace('\n', self._MAPPED_DOCSTRING_ADDENDUM, 1)
setattr(self, f'{method_name}.__doc__', updated_method_docstring)
# TODO this has to be
self._add_all_dataset_api()

# TODO wrap methods for ops too, such as those in DatasetOpsMixin
def _add_all_dataset_api(self):
# Add methods like .mean(), but wrapped to map over subtrees
self._add_dataset_methods()

# TODO map applied ufuncs over all leaves
# TODO add dataset ops here

@property
def ds(self) -> Dataset:
Expand All @@ -257,7 +308,7 @@ def has_data(self):
def _init_single_datatree_node(
cls,
name: Hashable,
data: Dataset = None,
data: Union[Dataset, DataArray] = None,
parent: TreeNode = None,
children: List[TreeNode] = None,
):
Expand Down Expand Up @@ -285,6 +336,9 @@ def _init_single_datatree_node(
obj = object.__new__(cls)
obj = _init_single_treenode(obj, name=name, parent=parent, children=children)
obj.ds = data

obj._add_all_dataset_api()

return obj

def __str__(self):
Expand Down Expand Up @@ -559,13 +613,6 @@ def get_any(self, *tags: Hashable) -> DataTree:
if any(tag in c.tags for tag in tags)}
return DataTree(data_objects=matching_children)

@property
def chunks(self):
raise NotImplementedError

def chunk(self):
raise NotImplementedError

def merge(self, datatree: DataTree) -> DataTree:
"""Merge all the leaves of a second DataTree into this one."""
raise NotImplementedError
Expand Down
39 changes: 34 additions & 5 deletions datatree/tests/test_dataset_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest

import numpy as np

import xarray as xr
from xarray.testing import assert_equal

Expand Down Expand Up @@ -93,12 +95,39 @@ def test_no_data_no_properties(self):


class TestDSMethodInheritance:
    """Check Dataset methods (e.g. .isel) exposed on a tree act on node Datasets."""

    def test_root(self):
        arr = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
        tree = DataNode('root', data=arr)
        expected = arr.to_dataset().isel(x=1)
        assert_equal(tree.isel(x=1).ds, expected)

    def test_descendants(self):
        arr = xr.DataArray(name='a', data=[1, 2, 3], dims='x')
        tree = DataNode('root')
        DataNode('results', parent=tree, data=arr)
        expected = arr.to_dataset().isel(x=1)
        assert_equal(tree.isel(x=1)['results'].ds, expected)


class TestOps:
    # TODO: fill in once mapped arithmetic operators are implemented on DataTree.
    pass


class TestBinaryOps:
    # TODO: fill in once mapped binary operators are implemented on DataTree.
    pass


@pytest.mark.xfail
class TestUFuncs:
...
def test_root(self):
da = xr.DataArray(name='a', data=[1, 2, 3])
dt = DataNode('root', data=da)
expected_ds = np.sin(da.to_dataset())
result_ds = np.sin(dt).ds
assert_equal(result_ds, expected_ds)

def test_descendants(self):
da = xr.DataArray(name='a', data=[1, 2, 3])
dt = DataNode('root')
DataNode('results', parent=dt, data=da)
expected_ds = np.sin(da.to_dataset())
result_ds = np.sin(dt)['results'].ds
assert_equal(result_ds, expected_ds)