Skip to content

Commit 1580c2c

Browse files
authored
Clean up Dims type annotation (#8606)
1 parent 53fdfca commit 1580c2c

File tree

6 files changed

+32
-39
lines changed

6 files changed

+32
-39
lines changed

.github/workflows/ci-additional.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ jobs:
120120
python xarray/util/print_versions.py
121121
- name: Install mypy
122122
run: |
123-
python -m pip install "mypy<1.8" --force-reinstall
123+
python -m pip install "mypy<1.9" --force-reinstall
124124
125125
- name: Run mypy
126126
run: |
@@ -174,7 +174,7 @@ jobs:
174174
python xarray/util/print_versions.py
175175
- name: Install mypy
176176
run: |
177-
python -m pip install "mypy<1.8" --force-reinstall
177+
python -m pip install "mypy<1.9" --force-reinstall
178178
179179
- name: Run mypy
180180
run: |

xarray/core/computation.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from xarray.core.parallelcompat import get_chunked_array_type
2525
from xarray.core.pycompat import is_chunked_array, is_duck_dask_array
2626
from xarray.core.types import Dims, T_DataArray
27-
from xarray.core.utils import is_dict_like, is_scalar
27+
from xarray.core.utils import is_dict_like, is_scalar, parse_dims
2828
from xarray.core.variable import Variable
2929
from xarray.util.deprecation_helpers import deprecate_dims
3030

@@ -1875,16 +1875,14 @@ def dot(
18751875
einsum_axes = "abcdefghijklmnopqrstuvwxyz"
18761876
dim_map = {d: einsum_axes[i] for i, d in enumerate(all_dims)}
18771877

1878-
if dim is ...:
1879-
dim = all_dims
1880-
elif isinstance(dim, str):
1881-
dim = (dim,)
1882-
elif dim is None:
1883-
# find dimensions that occur more than one times
1878+
if dim is None:
1879+
# find dimensions that occur more than once
18841880
dim_counts: Counter = Counter()
18851881
for arr in arrays:
18861882
dim_counts.update(arr.dims)
18871883
dim = tuple(d for d, c in dim_counts.items() if c > 1)
1884+
else:
1885+
dim = parse_dims(dim, all_dims=tuple(all_dims))
18881886

18891887
dot_dims: set[Hashable] = set(dim)
18901888

xarray/core/types.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import datetime
44
import sys
5-
from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
5+
from collections.abc import Collection, Hashable, Iterator, Mapping, Sequence
66
from typing import (
77
TYPE_CHECKING,
88
Any,
@@ -182,8 +182,9 @@ def copy(
182182
DsCompatible = Union["Dataset", "DaCompatible"]
183183
GroupByCompatible = Union["Dataset", "DataArray"]
184184

185-
Dims = Union[str, Iterable[Hashable], "ellipsis", None]
186-
OrderedDims = Union[str, Sequence[Union[Hashable, "ellipsis"]], "ellipsis", None]
185+
# Don't change to Hashable | Collection[Hashable]
186+
# Read: https://github.com/pydata/xarray/issues/6142
187+
Dims = Union[str, Collection[Hashable], "ellipsis", None]
187188

188189
# FYI in some cases we don't allow `None`, which this doesn't take account of.
189190
T_ChunkDim: TypeAlias = Union[int, Literal["auto"], None, tuple[int, ...]]

xarray/core/utils.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
Mapping,
5858
MutableMapping,
5959
MutableSet,
60-
Sequence,
6160
ValuesView,
6261
)
6362
from enum import Enum
@@ -76,7 +75,7 @@
7675
import pandas as pd
7776

7877
if TYPE_CHECKING:
79-
from xarray.core.types import Dims, ErrorOptionsWithWarn, OrderedDims, T_DuckArray
78+
from xarray.core.types import Dims, ErrorOptionsWithWarn, T_DuckArray
8079

8180
K = TypeVar("K")
8281
V = TypeVar("V")
@@ -983,12 +982,9 @@ def drop_missing_dims(
983982
)
984983

985984

986-
T_None = TypeVar("T_None", None, "ellipsis")
987-
988-
989985
@overload
990986
def parse_dims(
991-
dim: str | Iterable[Hashable] | T_None,
987+
dim: Dims,
992988
all_dims: tuple[Hashable, ...],
993989
*,
994990
check_exists: bool = True,
@@ -999,12 +995,12 @@ def parse_dims(
999995

1000996
@overload
1001997
def parse_dims(
1002-
dim: str | Iterable[Hashable] | T_None,
998+
dim: Dims,
1003999
all_dims: tuple[Hashable, ...],
10041000
*,
10051001
check_exists: bool = True,
10061002
replace_none: Literal[False],
1007-
) -> tuple[Hashable, ...] | T_None:
1003+
) -> tuple[Hashable, ...] | None | ellipsis:
10081004
...
10091005

10101006

@@ -1051,7 +1047,7 @@ def parse_dims(
10511047

10521048
@overload
10531049
def parse_ordered_dims(
1054-
dim: str | Sequence[Hashable | ellipsis] | T_None,
1050+
dim: Dims,
10551051
all_dims: tuple[Hashable, ...],
10561052
*,
10571053
check_exists: bool = True,
@@ -1062,17 +1058,17 @@ def parse_ordered_dims(
10621058

10631059
@overload
10641060
def parse_ordered_dims(
1065-
dim: str | Sequence[Hashable | ellipsis] | T_None,
1061+
dim: Dims,
10661062
all_dims: tuple[Hashable, ...],
10671063
*,
10681064
check_exists: bool = True,
10691065
replace_none: Literal[False],
1070-
) -> tuple[Hashable, ...] | T_None:
1066+
) -> tuple[Hashable, ...] | None | ellipsis:
10711067
...
10721068

10731069

10741070
def parse_ordered_dims(
1075-
dim: OrderedDims,
1071+
dim: Dims,
10761072
all_dims: tuple[Hashable, ...],
10771073
*,
10781074
check_exists: bool = True,
@@ -1126,9 +1122,9 @@ def parse_ordered_dims(
11261122
)
11271123

11281124

1129-
def _check_dims(dim: set[Hashable | ellipsis], all_dims: set[Hashable]) -> None:
1130-
wrong_dims = dim - all_dims
1131-
if wrong_dims and wrong_dims != {...}:
1125+
def _check_dims(dim: set[Hashable], all_dims: set[Hashable]) -> None:
1126+
wrong_dims = (dim - all_dims) - {...}
1127+
if wrong_dims:
11321128
wrong_dims_str = ", ".join(f"'{d!s}'" for d in wrong_dims)
11331129
raise ValueError(
11341130
f"Dimension(s) {wrong_dims_str} do not exist. Expected one or more of {all_dims}"

xarray/tests/test_interp.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -838,8 +838,8 @@ def test_interpolate_chunk_1d(
838838
if chunked:
839839
dest[dim] = xr.DataArray(data=dest[dim], dims=[dim])
840840
dest[dim] = dest[dim].chunk(2)
841-
actual = da.interp(method=method, **dest, kwargs=kwargs) # type: ignore
842-
expected = da.compute().interp(method=method, **dest, kwargs=kwargs) # type: ignore
841+
actual = da.interp(method=method, **dest, kwargs=kwargs)
842+
expected = da.compute().interp(method=method, **dest, kwargs=kwargs)
843843

844844
assert_identical(actual, expected)
845845

xarray/tests/test_utils.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Hashable, Iterable, Sequence
3+
from collections.abc import Hashable
44

55
import numpy as np
66
import pandas as pd
@@ -257,17 +257,18 @@ def test_infix_dims_errors(supplied, all_):
257257
pytest.param("a", ("a",), id="str"),
258258
pytest.param(["a", "b"], ("a", "b"), id="list_of_str"),
259259
pytest.param(["a", 1], ("a", 1), id="list_mixed"),
260+
pytest.param(["a", ...], ("a", ...), id="list_with_ellipsis"),
260261
pytest.param(("a", "b"), ("a", "b"), id="tuple_of_str"),
261262
pytest.param(["a", ("b", "c")], ("a", ("b", "c")), id="list_with_tuple"),
262263
pytest.param((("b", "c"),), (("b", "c"),), id="tuple_of_tuple"),
264+
pytest.param({"a", 1}, tuple({"a", 1}), id="non_sequence_collection"),
265+
pytest.param((), (), id="empty_tuple"),
266+
pytest.param(set(), (), id="empty_collection"),
263267
pytest.param(None, None, id="None"),
264268
pytest.param(..., ..., id="ellipsis"),
265269
],
266270
)
267-
def test_parse_dims(
268-
dim: str | Iterable[Hashable] | None,
269-
expected: tuple[Hashable, ...],
270-
) -> None:
271+
def test_parse_dims(dim, expected):
271272
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
272273
actual = utils.parse_dims(dim, all_dims, replace_none=False)
273274
assert actual == expected
@@ -297,7 +298,7 @@ def test_parse_dims_replace_none(dim: None | ellipsis) -> None:
297298
pytest.param(["x", 2], id="list_missing_all"),
298299
],
299300
)
300-
def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
301+
def test_parse_dims_raises(dim):
301302
all_dims = ("a", "b", 1, ("b", "c")) # selection of different Hashables
302303
with pytest.raises(ValueError, match="'x'"):
303304
utils.parse_dims(dim, all_dims, check_exists=True)
@@ -313,10 +314,7 @@ def test_parse_dims_raises(dim: str | Iterable[Hashable]) -> None:
313314
pytest.param(["a", ..., "b"], ("a", "c", "b"), id="list_with_middle_ellipsis"),
314315
],
315316
)
316-
def test_parse_ordered_dims(
317-
dim: str | Sequence[Hashable | ellipsis],
318-
expected: tuple[Hashable, ...],
319-
) -> None:
317+
def test_parse_ordered_dims(dim, expected):
320318
all_dims = ("a", "b", "c")
321319
actual = utils.parse_ordered_dims(dim, all_dims)
322320
assert actual == expected

0 commit comments

Comments
 (0)