Skip to content

Commit 81f38f3

Browse files
Add overloads to get_axis_num (#8547)
Co-authored-by: Anderson Banihirwe <[email protected]>
1 parent b0b5b2f commit 81f38f3

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

xarray/core/common.py

+8
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ def __iter__(self: Any) -> Iterator[Any]:
199199
raise TypeError("iteration over a 0-d array")
200200
return self._iter()
201201

202+
@overload
203+
def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]:
204+
...
205+
206+
@overload
207+
def get_axis_num(self, dim: Hashable) -> int:
208+
...
209+
202210
def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
203211
"""Return axis number(s) corresponding to dimension(s) in this array.
204212

xarray/namedarray/core.py

+8
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,14 @@ def _dask_finalize(
648648
data = array_func(results, *args, **kwargs)
649649
return type(self)(self._dims, data, attrs=self._attrs)
650650

651+
@overload
652+
def get_axis_num(self, dim: Iterable[Hashable]) -> tuple[int, ...]:
653+
...
654+
655+
@overload
656+
def get_axis_num(self, dim: Hashable) -> int:
657+
...
658+
651659
def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
652660
"""Return axis number(s) corresponding to dimension(s) in this array.
653661

0 commit comments

Comments
 (0)