Commit 6b59d9a

Consolidate TypeVars in a single place (#5569)

Authored by max-sixty and Illviljan
1 parent befd1b9 · commit 6b59d9a
11 files changed (+116, -97 lines)

Commit message:

* Consolidate type bounds in a single place
* More consolidation
* Update xarray/core/types.py
  Co-authored-by: Illviljan <[email protected]>
* Update xarray/core/types.py
  Co-authored-by: Illviljan <[email protected]>
* Rename T_DSorDA to T_Xarray
* Update xarray/core/weighted.py
  Co-authored-by: Illviljan <[email protected]>
* Update xarray/core/rolling_exp.py
  Co-authored-by: Illviljan <[email protected]>
* .

Co-authored-by: Illviljan <[email protected]>

xarray/core/_typed_ops.pyi (+10 -11)

@@ -9,24 +9,23 @@ from .dataarray import DataArray
 from .dataset import Dataset
 from .groupby import DataArrayGroupBy, DatasetGroupBy, GroupBy
 from .npcompat import ArrayLike
+from .types import (
+    DaCompatible,
+    DsCompatible,
+    GroupByIncompatible,
+    ScalarOrArray,
+    T_DataArray,
+    T_Dataset,
+    T_Variable,
+    VarCompatible,
+)
 from .variable import Variable

 try:
     from dask.array import Array as DaskArray
 except ImportError:
     DaskArray = np.ndarray

-# DatasetOpsMixin etc. are parent classes of Dataset etc.
-T_Dataset = TypeVar("T_Dataset", bound="DatasetOpsMixin")
-T_DataArray = TypeVar("T_DataArray", bound="DataArrayOpsMixin")
-T_Variable = TypeVar("T_Variable", bound="VariableOpsMixin")
-
-ScalarOrArray = Union[ArrayLike, np.generic, np.ndarray, DaskArray]
-DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray]
-DaCompatible = Union[DataArray, Variable, DataArrayGroupBy, ScalarOrArray]
-VarCompatible = Union[Variable, ScalarOrArray]
-GroupByIncompatible = Union[Variable, GroupBy]
-
 class DatasetOpsMixin:
     __slots__ = ()
     def _binary_op(self, other, f, reflexive=...): ...

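For readers unfamiliar with how Union aliases such as DsCompatible are consumed by a stub like this one, here is a self-contained sketch (not part of the diff). The classes are throwaway stand-ins for the real xarray types, and the __add__ signature is illustrative rather than copied from the generated stub; the real ScalarOrArray also covers numpy and dask arrays, as the new types.py below shows.

from __future__ import annotations

from typing import Union

# Stand-in classes; the real xarray types are far richer.
class Variable: ...
class DataArray: ...
class GroupBy: ...

class Dataset:
    # Illustrative operator annotation: a Dataset can be combined with any
    # "compatible" operand, and the result is again a Dataset.
    def __add__(self, other: DsCompatible) -> Dataset:
        return self  # real arithmetic omitted

# Everything acceptable on the other side of a Dataset binary op.
ScalarOrArray = Union[int, float, complex]
DsCompatible = Union[Dataset, DataArray, Variable, GroupBy, ScalarOrArray]
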
xarray/core/common.py (+4 -4)

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import warnings
 from contextlib import suppress
 from html import escape
@@ -36,10 +38,10 @@
 if TYPE_CHECKING:
     from .dataarray import DataArray
     from .dataset import Dataset
+    from .types import T_DataWithCoords, T_Xarray
     from .variable import Variable
     from .weighted import Weighted

-T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

 C = TypeVar("C")
 T = TypeVar("T")
@@ -795,9 +797,7 @@ def groupby_bins(
             },
         )

-    def weighted(
-        self: T_DataWithCoords, weights: "DataArray"
-    ) -> "Weighted[T_DataWithCoords]":
+    def weighted(self: T_DataWithCoords, weights: "DataArray") -> Weighted[T_Xarray]:
        """
        Weighted operations.

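The weighted signature relies on a pattern used throughout this commit: annotating self with a TypeVar so the returned wrapper object remembers whether it was built from a Dataset or a DataArray. A minimal stand-alone sketch of that pattern, with placeholder classes instead of the real DataWithCoords, DataArray, Dataset and Weighted:

from __future__ import annotations

from typing import Generic, TypeVar

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

class DataWithCoords:
    # Annotating `self` with a TypeVar lets subclasses keep their own type:
    # Dataset().weighted(...) is seen as Weighted[Dataset] by a type checker.
    def weighted(self: T_DataWithCoords, weights: DataArray) -> Weighted[T_DataWithCoords]:
        return Weighted(self, weights)

class DataArray(DataWithCoords): ...
class Dataset(DataWithCoords): ...

class Weighted(Generic[T_DataWithCoords]):
    def __init__(self, obj: T_DataWithCoords, weights: DataArray) -> None:
        self.obj = obj
        self.weights = weights

w = Dataset().weighted(DataArray())  # inferred as Weighted[Dataset]
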
xarray/core/computation.py (+6 -9)

@@ -21,7 +21,6 @@
     Optional,
     Sequence,
     Tuple,
-    TypeVar,
     Union,
 )

@@ -36,11 +35,9 @@
 from .variable import Variable

 if TYPE_CHECKING:
-    from .coordinates import Coordinates  # noqa
-    from .dataarray import DataArray
+    from .coordinates import Coordinates
     from .dataset import Dataset
-
-    T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
+    from .types import T_Xarray

 _NO_FILL_VALUE = utils.ReprObject("<no-fill-value>")
 _DEFAULT_NAME = utils.ReprObject("<default-name>")
@@ -199,7 +196,7 @@ def result_name(objects: list) -> Any:
     return name


-def _get_coords_list(args) -> List["Coordinates"]:
+def _get_coords_list(args) -> List[Coordinates]:
     coords_list = []
     for arg in args:
         try:
@@ -400,8 +397,8 @@ def apply_dict_of_variables_vfunc(


 def _fast_dataset(
-    variables: Dict[Hashable, Variable], coord_variables: Mapping[Any, Variable]
-) -> "Dataset":
+    variables: Dict[Hashable, Variable], coord_variables: Mapping[Hashable, Variable]
+) -> Dataset:
     """Create a dataset as quickly as possible.

     Beware: the `variables` dict is modified INPLACE.
@@ -1729,7 +1726,7 @@ def _calc_idxminmax(
     return res


-def unify_chunks(*objects: T_DSorDA) -> Tuple[T_DSorDA, ...]:
+def unify_chunks(*objects: T_Xarray) -> Tuple[T_Xarray, ...]:
     """
     Given any number of Dataset and/or DataArray objects, returns
     new objects with unified chunk size along all chunked dimensions.

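unify_chunks is a good illustration of what the constrained T_Xarray buys in a variadic signature: a call made with DataArrays is typed as returning DataArrays, and one made with Datasets as returning Datasets. A small self-contained sketch of that behaviour (placeholder classes, no real chunk handling):

from typing import Tuple, TypeVar

class DataArray: ...
class Dataset: ...

# Constrained TypeVar: within one call it resolves to DataArray or Dataset.
T_Xarray = TypeVar("T_Xarray", DataArray, Dataset)

def unify_chunks(*objects: T_Xarray) -> Tuple[T_Xarray, ...]:
    # The real function aligns dask chunks; here we only show the type flow.
    return objects

arrays = unify_chunks(DataArray(), DataArray())  # Tuple[DataArray, ...]
datasets = unify_chunks(Dataset(), Dataset())    # Tuple[Dataset, ...]

A constrained TypeVar (rather than a bound one) fits here because only these two concrete classes are valid, and a checker resolves each call site to exactly one of them; as the docstring above notes, the runtime function itself accepts any mix of Dataset and DataArray objects.
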
xarray/core/dataarray.py (+6 -5)

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import datetime
 import warnings
 from typing import (
@@ -12,7 +14,6 @@
     Optional,
     Sequence,
     Tuple,
-    TypeVar,
     Union,
     cast,
 )
@@ -70,8 +71,6 @@
     assert_unique_multiindex_level_names,
 )

-T_DataArray = TypeVar("T_DataArray", bound="DataArray")
-T_DSorDA = TypeVar("T_DSorDA", "DataArray", Dataset)
 if TYPE_CHECKING:
     try:
         from dask.delayed import Delayed
@@ -86,6 +85,8 @@
     except ImportError:
         iris_Cube = None

+    from .types import T_DataArray, T_Xarray
+

 def _infer_coords_and_dims(
     shape, coords, dims
@@ -3698,11 +3699,11 @@ def unify_chunks(self) -> "DataArray":

     def map_blocks(
         self,
-        func: Callable[..., T_DSorDA],
+        func: Callable[..., T_Xarray],
         args: Sequence[Any] = (),
         kwargs: Mapping[str, Any] = None,
         template: Union["DataArray", "Dataset"] = None,
-    ) -> T_DSorDA:
+    ) -> T_Xarray:
         """
         Apply a function to each block of this DataArray.

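A recurring move in this commit, visible in the hunks above, is to add `from __future__ import annotations` and then import the TypeVars only under `if TYPE_CHECKING:`. A generic sketch of that idiom (the module name in the guarded import is made up; only the mechanics matter):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only, never at runtime, so it cannot
    # contribute to a circular import. "mypkg.types" is a hypothetical module.
    from mypkg.types import T_Xarray

def roundtrip(obj: T_Xarray) -> T_Xarray:
    # With postponed evaluation (PEP 563) the annotation is stored as the
    # string "T_Xarray", so the guarded import satisfies checkers and nothing
    # extra is needed at runtime.
    return obj
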
xarray/core/dataset.py (+3 -5)

@@ -25,7 +25,6 @@
     Sequence,
     Set,
     Tuple,
-    TypeVar,
     Union,
     cast,
     overload,
@@ -109,8 +108,7 @@
     from ..backends import AbstractDataStore, ZarrStore
     from .dataarray import DataArray
     from .merge import CoercibleMapping
-
-    T_DSorDA = TypeVar("T_DSorDA", DataArray, "Dataset")
+    from .types import T_Xarray

 try:
     from dask.delayed import Delayed
@@ -6630,11 +6628,11 @@ def unify_chunks(self) -> "Dataset":

     def map_blocks(
         self,
-        func: "Callable[..., T_DSorDA]",
+        func: "Callable[..., T_Xarray]",
         args: Sequence[Any] = (),
         kwargs: Mapping[str, Any] = None,
         template: Union["DataArray", "Dataset"] = None,
-    ) -> "T_DSorDA":
+    ) -> "T_Xarray":
         """
         Apply a function to each block of this Dataset.

xarray/core/parallel.py (+9 -6)

@@ -1,7 +1,10 @@
+from __future__ import annotations
+
 import collections
 import itertools
 import operator
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     DefaultDict,
@@ -12,7 +15,6 @@
     Mapping,
     Sequence,
     Tuple,
-    TypeVar,
     Union,
 )

@@ -32,7 +34,8 @@
     pass


-T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset)
+if TYPE_CHECKING:
+    from .types import T_Xarray


 def unzip(iterable):
@@ -122,8 +125,8 @@ def make_meta(obj):


 def infer_template(
-    func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs
-) -> T_DSorDA:
+    func: Callable[..., T_Xarray], obj: Union[DataArray, Dataset], *args, **kwargs
+) -> T_Xarray:
     """Infer return object by running the function on meta objects."""
     meta_args = [make_meta(arg) for arg in (obj,) + args]

@@ -162,12 +165,12 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping


 def map_blocks(
-    func: Callable[..., T_DSorDA],
+    func: Callable[..., T_Xarray],
     obj: Union[DataArray, Dataset],
     args: Sequence[Any] = (),
     kwargs: Mapping[str, Any] = None,
     template: Union[DataArray, Dataset] = None,
-) -> T_DSorDA:
+) -> T_Xarray:
     """Apply a function to each block of a DataArray or Dataset.

     .. warning::

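The map_blocks family threads T_Xarray from the user-supplied function through to the return annotation: whatever the callback is declared to return (DataArray or Dataset) is what map_blocks is declared to return. A minimal sketch of that propagation, with placeholder classes and a trivial body in place of the real dask machinery:

from typing import Any, Callable, Sequence, TypeVar

class DataArray: ...
class Dataset: ...

T_Xarray = TypeVar("T_Xarray", DataArray, Dataset)

def map_blocks(
    func: Callable[..., T_Xarray],
    obj: Any,
    args: Sequence[Any] = (),
) -> T_Xarray:
    # Placeholder: apply the function once instead of once per block.
    return func(obj, *args)

def block_mean(block: DataArray) -> DataArray:
    return block  # pretend per-block reduction

result = map_blocks(block_mean, DataArray())  # inferred as DataArray
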
xarray/core/rolling_exp.py (+9 -12)

@@ -1,17 +1,14 @@
+from __future__ import annotations
+
 from distutils.version import LooseVersion
-from typing import TYPE_CHECKING, Generic, Hashable, Mapping, TypeVar, Union
+from typing import Generic, Hashable, Mapping, Union

 import numpy as np

 from .options import _get_keep_attrs
 from .pdcompat import count_not_none
 from .pycompat import is_duck_dask_array
-
-if TYPE_CHECKING:
-    from .dataarray import DataArray  # noqa: F401
-    from .dataset import Dataset  # noqa: F401
-
-T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset")
+from .types import T_Xarray


 def _get_alpha(com=None, span=None, halflife=None, alpha=None):
@@ -79,7 +76,7 @@ def _get_center_of_mass(comass, span, halflife, alpha):
     return float(comass)


-class RollingExp(Generic[T_DSorDA]):
+class RollingExp(Generic[T_Xarray]):
     """
     Exponentially-weighted moving window object.
     Similar to EWM in pandas
@@ -103,16 +100,16 @@ class RollingExp(Generic[T_DSorDA]):

     def __init__(
         self,
-        obj: T_DSorDA,
+        obj: T_Xarray,
         windows: Mapping[Hashable, Union[int, float]],
         window_type: str = "span",
     ):
-        self.obj: T_DSorDA = obj
+        self.obj: T_Xarray = obj
         dim, window = next(iter(windows.items()))
         self.dim = dim
         self.alpha = _get_alpha(**{window_type: window})

-    def mean(self, keep_attrs: bool = None) -> T_DSorDA:
+    def mean(self, keep_attrs: bool = None) -> T_Xarray:
         """
         Exponentially weighted moving average.

@@ -139,7 +136,7 @@ def mean(self, keep_attrs: bool = None) -> T_DSorDA:
             move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs
         )

-    def sum(self, keep_attrs: bool = None) -> T_DSorDA:
+    def sum(self, keep_attrs: bool = None) -> T_Xarray:
         """
         Exponentially weighted moving sum.


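RollingExp shows the generic-class side of the same change: parametrizing the class with T_Xarray means a rolling-exp object built from a Dataset is a RollingExp[Dataset], and mean()/sum() are typed to give back a Dataset. A compact stand-alone sketch of that pattern with placeholder classes and a stubbed-out reduction:

from typing import Generic, Hashable, Mapping, TypeVar, Union

class DataArray: ...
class Dataset: ...

T_Xarray = TypeVar("T_Xarray", DataArray, Dataset)

class RollingExp(Generic[T_Xarray]):
    def __init__(self, obj: T_Xarray, windows: Mapping[Hashable, Union[int, float]]):
        self.obj: T_Xarray = obj
        self.dim, self.window = next(iter(windows.items()))

    def mean(self) -> T_Xarray:
        # Stand-in for the exponentially weighted moving average.
        return self.obj

r = RollingExp(Dataset(), {"time": 5})  # RollingExp[Dataset]
ds = r.mean()                           # typed as Dataset
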
xarray/core/types.py (+31)

@@ -0,0 +1,31 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypeVar, Union
+
+import numpy as np
+
+if TYPE_CHECKING:
+    from .common import DataWithCoords
+    from .dataarray import DataArray
+    from .dataset import Dataset
+    from .groupby import DataArrayGroupBy, GroupBy
+    from .npcompat import ArrayLike
+    from .variable import Variable
+
+    try:
+        from dask.array import Array as DaskArray
+    except ImportError:
+        DaskArray = np.ndarray
+
+T_Dataset = TypeVar("T_Dataset", bound="Dataset")
+T_DataArray = TypeVar("T_DataArray", bound="DataArray")
+T_Variable = TypeVar("T_Variable", bound="Variable")
+# Maybe we rename this to T_Data or something less Fortran-y?
+T_Xarray = TypeVar("T_Xarray", "DataArray", "Dataset")
+T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")
+
+ScalarOrArray = Union["ArrayLike", np.generic, np.ndarray, "DaskArray"]
+DsCompatible = Union["Dataset", "DataArray", "Variable", "GroupBy", "ScalarOrArray"]
+DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"]
+VarCompatible = Union["Variable", "ScalarOrArray"]
+GroupByIncompatible = Union["Variable", "GroupBy"]

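The new module mixes two kinds of TypeVars: bound ones (T_Dataset, T_DataArray, T_Variable, T_DataWithCoords), which accept any subclass and preserve it, and the constrained T_Xarray, which a checker must resolve to exactly DataArray or exactly Dataset. A brief illustration of the difference using throwaway classes rather than the real ones (CustomDataset is hypothetical):

from typing import TypeVar

class Dataset: ...
class CustomDataset(Dataset): ...  # hypothetical subclass
class DataArray: ...

# Bound TypeVar: any subclass of Dataset is accepted and the subclass survives.
T_Dataset = TypeVar("T_Dataset", bound=Dataset)

def copy_ds(ds: T_Dataset) -> T_Dataset:
    return ds

# Constrained TypeVar: each call resolves to exactly DataArray or exactly Dataset.
T_Xarray = TypeVar("T_Xarray", DataArray, Dataset)

def passthrough(obj: T_Xarray) -> T_Xarray:
    return obj

a = copy_ds(CustomDataset())      # inferred as CustomDataset
b = passthrough(CustomDataset())  # inferred as Dataset, the matching constraint
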
0 commit comments