diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index f13c7e8d2c6..21bddf44e4e 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -597,12 +597,12 @@ def _dask_finalize( data = array_func(results, *args, **kwargs) return type(self)(self._dims, data, attrs=self._attrs) - def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]: + def _get_axis_nums(self, dims: _Dims) -> tuple[int, ...]: """Return axis number(s) corresponding to dimension(s) in this array. Parameters ---------- - dim : str or iterable of str + dim : tuple of str Dimension name(s) for which to lookup axes. Returns @@ -610,14 +610,12 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, . int or tuple of int Axis number or numbers corresponding to the given dimensions. """ - if not isinstance(dim, str) and isinstance(dim, Iterable): - return tuple(self._get_axis_num(d) for d in dim) - else: - return self._get_axis_num(dim) + return tuple(self._get_axis_num(d) for d in dims) - def _get_axis_num(self: Any, dim: Hashable) -> int: + def _get_axis_num(self: Any, dim: _Dim) -> int: try: - return self.dims.index(dim) # type: ignore[no-any-return] + out: int = self.dims.index(dim) + return out except ValueError: raise ValueError(f"{dim!r} not found in array dimensions {self.dims!r}") @@ -710,7 +708,7 @@ def reduce( raise ValueError("cannot supply both 'axis' and 'dim' arguments") if dim is not None: - axis = self.get_axis_num(dim) + axis = self._get_axis_nums(dim) with warnings.catch_warnings(): warnings.filterwarnings(