@@ -620,7 +620,7 @@ def _init_tree(
620
620
self : Any ,
621
621
strict : bool = True ,
622
622
where : str = _WHERE_DEFAULT ,
623
- index : Callable [[xa .Dataset ], type [ xa .Index ] ] | None = None ,
623
+ index : Callable [[xa .Dataset ], xa .Index ] | None = None ,
624
624
) -> None :
625
625
"""
626
626
Initialize a `DataTree` for an instance of a `xattree`-decorated class.
@@ -664,12 +664,10 @@ def _yield_children() -> Iterator[tuple[str, Any]]:
664
664
raise TypeError (f"Bad child collection field '{ xat .name } '" )
665
665
666
666
def _yield_attrs () -> Iterator [tuple [str , Any ]]:
667
- for xat_name , xat in xatspec .dims .items ():
668
- if xat .coord :
667
+ for xat_name , xat in chain ( xatspec .dims .items (), xatspec . attrs . items () ):
668
+ if isinstance ( xat , _Dim ) and xat .coord :
669
669
continue
670
670
yield (xat_name , self .__dict__ .pop (xat_name , xat .default ))
671
- for xat_name , xat in xatspec .attrs .items ():
672
- yield (xat_name , self .__dict__ .pop (xat_name , xat .default ))
673
671
674
672
children = dict (list (_yield_children ()))
675
673
attributes = dict (list (_yield_attrs ()))
@@ -1107,7 +1105,7 @@ def fields(cls, extra: bool = False) -> list[Attribute]:
1107
1105
def xattree (
1108
1106
* ,
1109
1107
where : str = _WHERE_DEFAULT ,
1110
- index : Callable [[xa .Dataset ], type [ xa .Index ] ] | None = None ,
1108
+ index : Callable [[xa .Dataset ], xa .Index ] | None = None ,
1111
1109
) -> Callable [[type [T ]], type [T ]]: ...
1112
1110
1113
1111
@@ -1120,9 +1118,25 @@ def xattree(
1120
1118
maybe_cls : Optional [type [Any ]] = None ,
1121
1119
* ,
1122
1120
where : str = _WHERE_DEFAULT ,
1123
- index : Callable [[xa .Dataset ], type [ xa .Index ] ] | None = None ,
1121
+ index : Callable [[xa .Dataset ], xa .Index ] | None = None ,
1124
1122
) -> type [T ] | Callable [[type [T ]], type [T ]]:
1125
- """Make an `attrs`-based class a (node in a) `xattree`."""
1123
+ """
1124
+ Make an `attrs`-based class a (node in a) `xattree`.
1125
+
1126
+ Parameters
1127
+ ----------
1128
+ maybe_cls : type, optional
1129
+ The class to be decorated. If not provided, the decorator
1130
+ is returned as a callable that can be used to decorate
1131
+ a class later.
1132
+ where : str, optional
1133
+ The name of the attribute that will hold the `xattree`.
1134
+ Default is "data".
1135
+ index : Callable, optional
1136
+ A function that takes a `xarray.Dataset` and returns
1137
+ an `xarray.Index`. If provided, the index built will
1138
+ be assigned as coordinates to the dataset.
1139
+ """
1126
1140
1127
1141
def wrap (cls ):
1128
1142
if has_xats (cls ):
0 commit comments