Skip to content

Commit 14e2531

Browse files
committed
fresh baked indices hot from the oven
1 parent 89431af commit 14e2531

File tree

3 files changed

+107
-40
lines changed

3 files changed

+107
-40
lines changed

test/test_index.py

-25
This file was deleted.

test/test_indices.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
from numpy.typing import NDArray
3+
from xarray.core.indexes import Index, PandasIndex
4+
from xarray.core.indexing import merge_sel_results
5+
6+
from xattree import ROOT, Indices, array, dim, xattree
7+
8+
9+
@xattree(index=lambda ds: Indices.alias_dim(ds, "n", "i"))
10+
class Foo:
11+
n: int = dim(default=3, coord=False)
12+
a: NDArray[np.float64] = array(default=0.0, dims=("n",))
13+
14+
15+
def test_simple_index():
16+
foo = Foo()
17+
assert foo.n == 3
18+
assert foo.a.shape == (3,)
19+
assert foo.data.a.shape == (3,)
20+
assert foo.data.i.shape == (3,)
21+
assert "i" in foo.data.coords
22+
assert "n" not in foo.data.coords
23+
24+
25+
class GridIndex(Index):
26+
def __init__(self, indices):
27+
dims = [idx.dim for idx in indices.values()]
28+
assert len(dims) == 2
29+
assert dims[0] != dims[1]
30+
self._indices = indices
31+
32+
@classmethod
33+
def from_variables(cls, variables):
34+
assert len(variables) == 2
35+
return {k: PandasIndex.from_variables({k: v}) for k, v in variables.items()}
36+
37+
def create_variables(self, variables=None):
38+
idx_vars = {}
39+
for index in self._indices.values():
40+
idx_vars.update(index.create_variables(variables))
41+
return idx_vars
42+
43+
def sel(self, labels):
44+
results = []
45+
for k, index in self._indices.items():
46+
if k in labels:
47+
results.append(index.sel({k: labels[k]}))
48+
return merge_sel_results(results)
49+
50+
51+
@xattree(
52+
index=lambda ds: GridIndex(
53+
{"i": Indices.alias_dim(ds, "rows", "i"), "j": Indices.alias_dim(ds, "cols", "j")}
54+
)
55+
)
56+
class Grid:
57+
rows: int = dim(scope=ROOT, default=3, coord=False)
58+
cols: int = dim(scope=ROOT, default=3, coord=False)
59+
nodes: int = dim(scope=ROOT, init=False)
60+
a: NDArray[np.float64] = array(default=0.0, dims=("rows", "cols"))
61+
aa: NDArray[np.float64] = array(default=0.0, dims=("nodes",))
62+
63+
def __attrs_post_init__(self):
64+
self.nodes = self.rows * self.cols
65+
66+
67+
def test_grid_index():
68+
grid = Grid()
69+
assert grid.rows == 3
70+
assert grid.cols == 3
71+
assert grid.nodes == 9
72+
assert grid.data.i.shape == (3,)
73+
assert grid.data.j.shape == (3,)
74+
assert "i" in grid.data.coords
75+
assert "j" in grid.data.coords
76+
assert "rows" not in grid.data.coords
77+
assert "cols" not in grid.data.coords
78+
assert grid.data.a.shape == (3, 3)
79+
assert grid.data.aa.shape == (9,)

xattree/__init__.py

+28-15
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
)
2525

2626
import numpy as np
27-
import xarray as xa
27+
import pandas as pd
28+
import xarray as xr
2829
from attrs import NOTHING, Attribute, Converter, Factory, cmp_using, define, evolve
2930
from attrs import (
3031
field as attrs_field,
@@ -36,12 +37,13 @@
3637
has as attrs_has,
3738
)
3839
from numpy.typing import ArrayLike, NDArray
40+
from xarray.core.indexes import PandasIndex
3941
from xarray.core.types import Self
4042

4143
_PKG_NAME = "xattree"
4244

4345

44-
class _XatTree(xa.DataTree):
46+
class _XatTree(xr.DataTree):
4547
"""Monkey-patch `DataTree` with a reference to a host object."""
4648

4749
# DataTree is not yet a proper slotted class, it still has `__dict__`.
@@ -71,13 +73,13 @@ def copy(self, *, inherit: bool = True, deep: bool = False) -> Self:
7173
return new
7274

7375

74-
xa.DataTree = _XatTree # type: ignore
76+
xr.DataTree = _XatTree # type: ignore
7577

7678

7779
class _XatList(MutableSequence):
7880
"""Proxy a `DataTree`'s children of a given type through a list-like interface."""
7981

80-
def __init__(self, tree: xa.DataTree, xat: "_Xattribute", where: str):
82+
def __init__(self, tree: xr.DataTree, xat: "_Xattribute", where: str):
8183
self._tree = tree
8284
self._xat = xat
8385
self._where = where
@@ -159,7 +161,7 @@ def insert(self, index: int, value: Any):
159161
class _XatDict(MutableMapping):
160162
"""Proxy a `DataTree`'s children of a given type through a dict-like interface."""
161163

162-
def __init__(self, tree: xa.DataTree, xat: "_Xattribute", where: str):
164+
def __init__(self, tree: xr.DataTree, xat: "_Xattribute", where: str):
163165
self._tree = tree
164166
self._xat = xat
165167
self._where = where
@@ -620,7 +622,7 @@ def _init_tree(
620622
self: Any,
621623
strict: bool = True,
622624
where: str = _WHERE_DEFAULT,
623-
index: Callable[[xa.Dataset], xa.Index] | None = None,
625+
index: Callable[[xr.Dataset], xr.Index] | None = None,
624626
) -> None:
625627
"""
626628
Initialize a `DataTree` for an instance of a `xattree`-decorated class.
@@ -779,7 +781,7 @@ def _find_dim_or_coord(
779781
def _yield_coords() -> Iterator[tuple[str, tuple[str, NDArray]]]:
780782
# register inherited dimension sizes so we can expand arrays
781783
if parent:
782-
parent_tree: xa.DataTree = getattr(parent, where)
784+
parent_tree: xr.DataTree = getattr(parent, where)
783785
for dim_or_coord in parent_tree.coords.values():
784786
dimensions[dim_or_coord.dims[0]] = dim_or_coord.data.size
785787

@@ -805,7 +807,7 @@ def _yield_coords() -> Iterator[tuple[str, tuple[str, NDArray]]]:
805807
if isinstance(value, _Scalar):
806808
match type(value):
807809
case builtins.int | builtins.float | np.number:
808-
# todo customizable step/start? via xarray range index?
810+
# todo customizable step/start?
809811
step = 1
810812
start = 0
811813
case _:
@@ -835,14 +837,13 @@ def _yield_arrays() -> Iterator[tuple[str, NDArray | tuple[tuple[str, ...], NDAr
835837
yield (xat.name, array)
836838

837839
arrays = dict(list(_yield_arrays()))
838-
839-
dataset = xa.Dataset(
840+
dataset = xr.Dataset(
840841
data_vars=arrays,
841842
coords=coordinates,
842843
attrs={n: a for n, a in attributes.items()},
843844
)
844845
if index:
845-
dataset = dataset.assign_coords(xa.Coordinates.from_xindex(index(dataset)))
846+
dataset = dataset.assign_coords(xr.Coordinates.from_xindex(index(dataset)))
846847

847848
setattr(
848849
self,
@@ -863,7 +864,7 @@ def _getattr(self: Any, name: str) -> Any:
863864
raise AttributeError
864865
if name == _XATTREE_READY:
865866
return False
866-
tree = cast(xa.DataTree, getattr(self, where, None))
867+
tree = cast(xr.DataTree, getattr(self, where, None))
867868
if get_xattr := _XTRA_GETTERS.get(name, None):
868869
return get_xattr(tree)
869870
spec = _get_xatspec(cls)
@@ -936,7 +937,7 @@ def _setattr(self: Any, name: str, value: Any):
936937
if getattr(value, "parent", None) is not None:
937938
raise AttributeError(f"Child '{name}' already has a parent, can't set it.")
938939

939-
def drop_matching_children(node: xa.DataTree) -> xa.DataTree:
940+
def drop_matching_children(node: xr.DataTree) -> xr.DataTree:
940941
return node.filter(lambda c: not issubclass(type(c._host), xat.type)) # type: ignore
941942

942943
# DataTree.assign() replaces only the entries you provide it,
@@ -1105,7 +1106,7 @@ def fields(cls, extra: bool = False) -> list[Attribute]:
11051106
def xattree(
11061107
*,
11071108
where: str = _WHERE_DEFAULT,
1108-
index: Callable[[xa.Dataset], xa.Index] | None = None,
1109+
index: Callable[[xr.Dataset], xr.Index] | None = None,
11091110
) -> Callable[[type[T]], type[T]]: ...
11101111

11111112

@@ -1118,7 +1119,7 @@ def xattree(
11181119
maybe_cls: Optional[type[Any]] = None,
11191120
*,
11201121
where: str = _WHERE_DEFAULT,
1121-
index: Callable[[xa.Dataset], xa.Index] | None = None,
1122+
index: Callable[[xr.Dataset], xr.Index] | None = None,
11221123
) -> type[T] | Callable[[type[T]], type[T]]:
11231124
"""
11241125
Make an `attrs`-based class a (node in a) `xattree`.
@@ -1287,3 +1288,15 @@ def _transform_field(field: Attribute) -> Attribute:
12871288
return wrap
12881289

12891290
return wrap(maybe_cls)
1291+
1292+
1293+
class Indices:
1294+
"""
1295+
A collection of static functions for creating indices from datasets.
1296+
These can be used as the `index` argument in the `xattree` decorator.
1297+
"""
1298+
1299+
@staticmethod
1300+
def alias_dim(dataset: xr.Dataset, src_name: str, tgt_name: str) -> PandasIndex:
1301+
"""Alias a dimension field as a dimension coordinate variable with a different name."""
1302+
return PandasIndex(pd.RangeIndex(dataset.sizes[src_name], name=tgt_name), dim=src_name)

0 commit comments

Comments
 (0)