24
24
)
25
25
26
26
import numpy as np
27
- import xarray as xa
27
+ import pandas as pd
28
+ import xarray as xr
28
29
from attrs import NOTHING , Attribute , Converter , Factory , cmp_using , define , evolve
29
30
from attrs import (
30
31
field as attrs_field ,
36
37
has as attrs_has ,
37
38
)
38
39
from numpy .typing import ArrayLike , NDArray
40
+ from xarray .core .indexes import PandasIndex
39
41
from xarray .core .types import Self
40
42
41
43
_PKG_NAME = "xattree"
42
44
43
45
44
- class _XatTree (xa .DataTree ):
46
+ class _XatTree (xr .DataTree ):
45
47
"""Monkey-patch `DataTree` with a reference to a host object."""
46
48
47
49
# DataTree is not yet a proper slotted class, it still has `__dict__`.
@@ -71,13 +73,13 @@ def copy(self, *, inherit: bool = True, deep: bool = False) -> Self:
71
73
return new
72
74
73
75
74
- xa .DataTree = _XatTree # type: ignore
76
+ xr .DataTree = _XatTree # type: ignore
75
77
76
78
77
79
class _XatList (MutableSequence ):
78
80
"""Proxy a `DataTree`'s children of a given type through a list-like interface."""
79
81
80
- def __init__ (self , tree : xa .DataTree , xat : "_Xattribute" , where : str ):
82
+ def __init__ (self , tree : xr .DataTree , xat : "_Xattribute" , where : str ):
81
83
self ._tree = tree
82
84
self ._xat = xat
83
85
self ._where = where
@@ -159,7 +161,7 @@ def insert(self, index: int, value: Any):
159
161
class _XatDict (MutableMapping ):
160
162
"""Proxy a `DataTree`'s children of a given type through a dict-like interface."""
161
163
162
- def __init__ (self , tree : xa .DataTree , xat : "_Xattribute" , where : str ):
164
+ def __init__ (self , tree : xr .DataTree , xat : "_Xattribute" , where : str ):
163
165
self ._tree = tree
164
166
self ._xat = xat
165
167
self ._where = where
@@ -620,7 +622,7 @@ def _init_tree(
620
622
self : Any ,
621
623
strict : bool = True ,
622
624
where : str = _WHERE_DEFAULT ,
623
- index : Callable [[xa .Dataset ], xa .Index ] | None = None ,
625
+ index : Callable [[xr .Dataset ], xr .Index ] | None = None ,
624
626
) -> None :
625
627
"""
626
628
Initialize a `DataTree` for an instance of a `xattree`-decorated class.
@@ -779,7 +781,7 @@ def _find_dim_or_coord(
779
781
def _yield_coords () -> Iterator [tuple [str , tuple [str , NDArray ]]]:
780
782
# register inherited dimension sizes so we can expand arrays
781
783
if parent :
782
- parent_tree : xa .DataTree = getattr (parent , where )
784
+ parent_tree : xr .DataTree = getattr (parent , where )
783
785
for dim_or_coord in parent_tree .coords .values ():
784
786
dimensions [dim_or_coord .dims [0 ]] = dim_or_coord .data .size
785
787
@@ -805,7 +807,7 @@ def _yield_coords() -> Iterator[tuple[str, tuple[str, NDArray]]]:
805
807
if isinstance (value , _Scalar ):
806
808
match type (value ):
807
809
case builtins .int | builtins .float | np .number :
808
- # todo customizable step/start? via xarray range index?
810
+ # todo customizable step/start?
809
811
step = 1
810
812
start = 0
811
813
case _:
@@ -835,14 +837,13 @@ def _yield_arrays() -> Iterator[tuple[str, NDArray | tuple[tuple[str, ...], NDAr
835
837
yield (xat .name , array )
836
838
837
839
arrays = dict (list (_yield_arrays ()))
838
-
839
- dataset = xa .Dataset (
840
+ dataset = xr .Dataset (
840
841
data_vars = arrays ,
841
842
coords = coordinates ,
842
843
attrs = {n : a for n , a in attributes .items ()},
843
844
)
844
845
if index :
845
- dataset = dataset .assign_coords (xa .Coordinates .from_xindex (index (dataset )))
846
+ dataset = dataset .assign_coords (xr .Coordinates .from_xindex (index (dataset )))
846
847
847
848
setattr (
848
849
self ,
@@ -863,7 +864,7 @@ def _getattr(self: Any, name: str) -> Any:
863
864
raise AttributeError
864
865
if name == _XATTREE_READY :
865
866
return False
866
- tree = cast (xa .DataTree , getattr (self , where , None ))
867
+ tree = cast (xr .DataTree , getattr (self , where , None ))
867
868
if get_xattr := _XTRA_GETTERS .get (name , None ):
868
869
return get_xattr (tree )
869
870
spec = _get_xatspec (cls )
@@ -936,7 +937,7 @@ def _setattr(self: Any, name: str, value: Any):
936
937
if getattr (value , "parent" , None ) is not None :
937
938
raise AttributeError (f"Child '{ name } ' already has a parent, can't set it." )
938
939
939
- def drop_matching_children (node : xa .DataTree ) -> xa .DataTree :
940
+ def drop_matching_children (node : xr .DataTree ) -> xr .DataTree :
940
941
return node .filter (lambda c : not issubclass (type (c ._host ), xat .type )) # type: ignore
941
942
942
943
# DataTree.assign() replaces only the entries you provide it,
@@ -1105,7 +1106,7 @@ def fields(cls, extra: bool = False) -> list[Attribute]:
1105
1106
def xattree (
1106
1107
* ,
1107
1108
where : str = _WHERE_DEFAULT ,
1108
- index : Callable [[xa .Dataset ], xa .Index ] | None = None ,
1109
+ index : Callable [[xr .Dataset ], xr .Index ] | None = None ,
1109
1110
) -> Callable [[type [T ]], type [T ]]: ...
1110
1111
1111
1112
@@ -1118,7 +1119,7 @@ def xattree(
1118
1119
maybe_cls : Optional [type [Any ]] = None ,
1119
1120
* ,
1120
1121
where : str = _WHERE_DEFAULT ,
1121
- index : Callable [[xa .Dataset ], xa .Index ] | None = None ,
1122
+ index : Callable [[xr .Dataset ], xr .Index ] | None = None ,
1122
1123
) -> type [T ] | Callable [[type [T ]], type [T ]]:
1123
1124
"""
1124
1125
Make an `attrs`-based class a (node in a) `xattree`.
@@ -1287,3 +1288,15 @@ def _transform_field(field: Attribute) -> Attribute:
1287
1288
return wrap
1288
1289
1289
1290
return wrap (maybe_cls )
1291
+
1292
+
1293
+ class Indices :
1294
+ """
1295
+ A collection of static functions for creating indices from datasets.
1296
+ These can be used as the `index` argument in the `xattree` decorator.
1297
+ """
1298
+
1299
+ @staticmethod
1300
+ def alias_dim (dataset : xr .Dataset , src_name : str , tgt_name : str ) -> PandasIndex :
1301
+ """Alias a dimension field as a dimension coordinate variable with a different name."""
1302
+ return PandasIndex (pd .RangeIndex (dataset .sizes [src_name ], name = tgt_name ), dim = src_name )
0 commit comments