Skip to content

Commit 89431af

Browse files
committed
appease mypy
1 parent fdfa082 commit 89431af

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

xattree/__init__.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ def _init_tree(
620620
self: Any,
621621
strict: bool = True,
622622
where: str = _WHERE_DEFAULT,
623-
index: Callable[[xa.Dataset], type[xa.Index]] | None = None,
623+
index: Callable[[xa.Dataset], xa.Index] | None = None,
624624
) -> None:
625625
"""
626626
Initialize a `DataTree` for an instance of a `xattree`-decorated class.
@@ -664,12 +664,10 @@ def _yield_children() -> Iterator[tuple[str, Any]]:
664664
raise TypeError(f"Bad child collection field '{xat.name}'")
665665

666666
def _yield_attrs() -> Iterator[tuple[str, Any]]:
667-
for xat_name, xat in xatspec.dims.items():
668-
if xat.coord:
667+
for xat_name, xat in chain(xatspec.dims.items(), xatspec.attrs.items()):
668+
if isinstance(xat, _Dim) and xat.coord:
669669
continue
670670
yield (xat_name, self.__dict__.pop(xat_name, xat.default))
671-
for xat_name, xat in xatspec.attrs.items():
672-
yield (xat_name, self.__dict__.pop(xat_name, xat.default))
673671

674672
children = dict(list(_yield_children()))
675673
attributes = dict(list(_yield_attrs()))
@@ -1107,7 +1105,7 @@ def fields(cls, extra: bool = False) -> list[Attribute]:
11071105
def xattree(
11081106
*,
11091107
where: str = _WHERE_DEFAULT,
1110-
index: Callable[[xa.Dataset], type[xa.Index]] | None = None,
1108+
index: Callable[[xa.Dataset], xa.Index] | None = None,
11111109
) -> Callable[[type[T]], type[T]]: ...
11121110

11131111

@@ -1120,9 +1118,25 @@ def xattree(
11201118
maybe_cls: Optional[type[Any]] = None,
11211119
*,
11221120
where: str = _WHERE_DEFAULT,
1123-
index: Callable[[xa.Dataset], type[xa.Index]] | None = None,
1121+
index: Callable[[xa.Dataset], xa.Index] | None = None,
11241122
) -> type[T] | Callable[[type[T]], type[T]]:
1125-
"""Make an `attrs`-based class a (node in a) `xattree`."""
1123+
"""
1124+
Make an `attrs`-based class a (node in a) `xattree`.
1125+
1126+
Parameters
1127+
----------
1128+
maybe_cls : type, optional
1129+
The class to be decorated. If not provided, the decorator
1130+
is returned as a callable that can be used to decorate
1131+
a class later.
1132+
where : str, optional
1133+
The name of the attribute that will hold the `xattree`.
1134+
Default is "data".
1135+
index : Callable, optional
1136+
A function that takes a `xarray.Dataset` and returns
1137+
an `xarray.Index`. If provided, the index built will
1138+
be assigned as coordinates to the dataset.
1139+
"""
11261140

11271141
def wrap(cls):
11281142
if has_xats(cls):

0 commit comments

Comments
 (0)