Skip to content

Commit 7aa207b

Browse files
authored
Improved typing of align & broadcast (#8234)
* add overloads to align
* add overloads to broadcast as well
* add some more typing
* remove unused ignore
1 parent e8be4bb commit 7aa207b

File tree

8 files changed

+216
-37
lines changed

8 files changed

+216
-37
lines changed

xarray/core/alignment.py

Lines changed: 185 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict
66
from collections.abc import Hashable, Iterable, Mapping
77
from contextlib import suppress
8-
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
8+
from typing import TYPE_CHECKING, Any, Callable, Final, Generic, TypeVar, cast, overload
99

1010
import numpy as np
1111
import pandas as pd
@@ -26,7 +26,13 @@
2626
if TYPE_CHECKING:
2727
from xarray.core.dataarray import DataArray
2828
from xarray.core.dataset import Dataset
29-
from xarray.core.types import JoinOptions, T_DataArray, T_Dataset, T_DuckArray
29+
from xarray.core.types import (
30+
Alignable,
31+
JoinOptions,
32+
T_DataArray,
33+
T_Dataset,
34+
T_DuckArray,
35+
)
3036

3137

3238
def reindex_variables(
@@ -128,7 +134,7 @@ def __init__(
128134
objects: Iterable[T_Alignable],
129135
join: str = "inner",
130136
indexes: Mapping[Any, Any] | None = None,
131-
exclude_dims: Iterable = frozenset(),
137+
exclude_dims: str | Iterable[Hashable] = frozenset(),
132138
exclude_vars: Iterable[Hashable] = frozenset(),
133139
method: str | None = None,
134140
tolerance: int | float | Iterable[int | float] | None = None,
@@ -576,12 +582,111 @@ def align(self) -> None:
576582
self.reindex_all()
577583

578584

585+
T_Obj1 = TypeVar("T_Obj1", bound="Alignable")
586+
T_Obj2 = TypeVar("T_Obj2", bound="Alignable")
587+
T_Obj3 = TypeVar("T_Obj3", bound="Alignable")
588+
T_Obj4 = TypeVar("T_Obj4", bound="Alignable")
589+
T_Obj5 = TypeVar("T_Obj5", bound="Alignable")
590+
591+
592+
@overload
593+
def align(
594+
obj1: T_Obj1,
595+
/,
596+
*,
597+
join: JoinOptions = "inner",
598+
copy: bool = True,
599+
indexes=None,
600+
exclude: str | Iterable[Hashable] = frozenset(),
601+
fill_value=dtypes.NA,
602+
) -> tuple[T_Obj1]:
603+
...
604+
605+
606+
@overload
607+
def align( # type: ignore[misc]
608+
obj1: T_Obj1,
609+
obj2: T_Obj2,
610+
/,
611+
*,
612+
join: JoinOptions = "inner",
613+
copy: bool = True,
614+
indexes=None,
615+
exclude: str | Iterable[Hashable] = frozenset(),
616+
fill_value=dtypes.NA,
617+
) -> tuple[T_Obj1, T_Obj2]:
618+
...
619+
620+
621+
@overload
622+
def align( # type: ignore[misc]
623+
obj1: T_Obj1,
624+
obj2: T_Obj2,
625+
obj3: T_Obj3,
626+
/,
627+
*,
628+
join: JoinOptions = "inner",
629+
copy: bool = True,
630+
indexes=None,
631+
exclude: str | Iterable[Hashable] = frozenset(),
632+
fill_value=dtypes.NA,
633+
) -> tuple[T_Obj1, T_Obj2, T_Obj3]:
634+
...
635+
636+
637+
@overload
638+
def align( # type: ignore[misc]
639+
obj1: T_Obj1,
640+
obj2: T_Obj2,
641+
obj3: T_Obj3,
642+
obj4: T_Obj4,
643+
/,
644+
*,
645+
join: JoinOptions = "inner",
646+
copy: bool = True,
647+
indexes=None,
648+
exclude: str | Iterable[Hashable] = frozenset(),
649+
fill_value=dtypes.NA,
650+
) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]:
651+
...
652+
653+
654+
@overload
655+
def align( # type: ignore[misc]
656+
obj1: T_Obj1,
657+
obj2: T_Obj2,
658+
obj3: T_Obj3,
659+
obj4: T_Obj4,
660+
obj5: T_Obj5,
661+
/,
662+
*,
663+
join: JoinOptions = "inner",
664+
copy: bool = True,
665+
indexes=None,
666+
exclude: str | Iterable[Hashable] = frozenset(),
667+
fill_value=dtypes.NA,
668+
) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]:
669+
...
670+
671+
672+
@overload
579673
def align(
580674
*objects: T_Alignable,
581675
join: JoinOptions = "inner",
582676
copy: bool = True,
583677
indexes=None,
584-
exclude=frozenset(),
678+
exclude: str | Iterable[Hashable] = frozenset(),
679+
fill_value=dtypes.NA,
680+
) -> tuple[T_Alignable, ...]:
681+
...
682+
683+
684+
def align( # type: ignore[misc]
685+
*objects: T_Alignable,
686+
join: JoinOptions = "inner",
687+
copy: bool = True,
688+
indexes=None,
689+
exclude: str | Iterable[Hashable] = frozenset(),
585690
fill_value=dtypes.NA,
586691
) -> tuple[T_Alignable, ...]:
587692
"""
@@ -620,7 +725,7 @@ def align(
620725
indexes : dict-like, optional
621726
Any indexes explicitly provided with the `indexes` argument should be
622727
used in preference to the aligned indexes.
623-
exclude : sequence of str, optional
728+
exclude : str, iterable of hashable or None, optional
624729
Dimensions that must be excluded from alignment
625730
fill_value : scalar or dict-like, optional
626731
Value to use for newly missing values. If a dict-like, maps
@@ -787,12 +892,12 @@ def align(
787892
def deep_align(
788893
objects: Iterable[Any],
789894
join: JoinOptions = "inner",
790-
copy=True,
895+
copy: bool = True,
791896
indexes=None,
792-
exclude=frozenset(),
793-
raise_on_invalid=True,
897+
exclude: str | Iterable[Hashable] = frozenset(),
898+
raise_on_invalid: bool = True,
794899
fill_value=dtypes.NA,
795-
):
900+
) -> list[Any]:
796901
"""Align objects for merging, recursing into dictionary values.
797902
798903
This function is not public API.
@@ -807,12 +912,12 @@ def deep_align(
807912
def is_alignable(obj):
808913
return isinstance(obj, (Coordinates, DataArray, Dataset))
809914

810-
positions = []
811-
keys = []
812-
out = []
813-
targets = []
814-
no_key = object()
815-
not_replaced = object()
915+
positions: list[int] = []
916+
keys: list[type[object] | Hashable] = []
917+
out: list[Any] = []
918+
targets: list[Alignable] = []
919+
no_key: Final = object()
920+
not_replaced: Final = object()
816921
for position, variables in enumerate(objects):
817922
if is_alignable(variables):
818923
positions.append(position)
@@ -857,7 +962,7 @@ def is_alignable(obj):
857962
if key is no_key:
858963
out[position] = aligned_obj
859964
else:
860-
out[position][key] = aligned_obj # type: ignore[index] # maybe someone can fix this?
965+
out[position][key] = aligned_obj
861966

862967
return out
863968

@@ -988,9 +1093,69 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
9881093
raise ValueError("all input must be Dataset or DataArray objects")
9891094

9901095

991-
# TODO: this typing is too restrictive since it cannot deal with mixed
992-
# DataArray and Dataset types...? Is this a problem?
993-
def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]:
1096+
@overload
1097+
def broadcast(
1098+
obj1: T_Obj1, /, *, exclude: str | Iterable[Hashable] | None = None
1099+
) -> tuple[T_Obj1]:
1100+
...
1101+
1102+
1103+
@overload
1104+
def broadcast( # type: ignore[misc]
1105+
obj1: T_Obj1, obj2: T_Obj2, /, *, exclude: str | Iterable[Hashable] | None = None
1106+
) -> tuple[T_Obj1, T_Obj2]:
1107+
...
1108+
1109+
1110+
@overload
1111+
def broadcast( # type: ignore[misc]
1112+
obj1: T_Obj1,
1113+
obj2: T_Obj2,
1114+
obj3: T_Obj3,
1115+
/,
1116+
*,
1117+
exclude: str | Iterable[Hashable] | None = None,
1118+
) -> tuple[T_Obj1, T_Obj2, T_Obj3]:
1119+
...
1120+
1121+
1122+
@overload
1123+
def broadcast( # type: ignore[misc]
1124+
obj1: T_Obj1,
1125+
obj2: T_Obj2,
1126+
obj3: T_Obj3,
1127+
obj4: T_Obj4,
1128+
/,
1129+
*,
1130+
exclude: str | Iterable[Hashable] | None = None,
1131+
) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4]:
1132+
...
1133+
1134+
1135+
@overload
1136+
def broadcast( # type: ignore[misc]
1137+
obj1: T_Obj1,
1138+
obj2: T_Obj2,
1139+
obj3: T_Obj3,
1140+
obj4: T_Obj4,
1141+
obj5: T_Obj5,
1142+
/,
1143+
*,
1144+
exclude: str | Iterable[Hashable] | None = None,
1145+
) -> tuple[T_Obj1, T_Obj2, T_Obj3, T_Obj4, T_Obj5]:
1146+
...
1147+
1148+
1149+
@overload
1150+
def broadcast(
1151+
*args: T_Alignable, exclude: str | Iterable[Hashable] | None = None
1152+
) -> tuple[T_Alignable, ...]:
1153+
...
1154+
1155+
1156+
def broadcast( # type: ignore[misc]
1157+
*args: T_Alignable, exclude: str | Iterable[Hashable] | None = None
1158+
) -> tuple[T_Alignable, ...]:
9941159
"""Explicitly broadcast any number of DataArray or Dataset objects against
9951160
one another.
9961161
@@ -1004,7 +1169,7 @@ def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]:
10041169
----------
10051170
*args : DataArray or Dataset
10061171
Arrays to broadcast against each other.
1007-
exclude : sequence of str, optional
1172+
exclude : str, iterable of hashable or None, optional
10081173
Dimensions that must not be broadcasted
10091174
10101175
Returns

xarray/core/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ def where(self, cond: Any, other: Any = dtypes.NA, drop: bool = False) -> Self:
11631163
f"cond argument is {cond!r} but must be a {Dataset!r} or {DataArray!r} (or a callable than returns one)."
11641164
)
11651165

1166-
self, cond = align(self, cond) # type: ignore[assignment]
1166+
self, cond = align(self, cond)
11671167

11681168
def _dataarray_indexer(dim: Hashable) -> DataArray:
11691169
return cond.any(dim=(d for d in cond.dims if d != dim))

xarray/core/computation.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,14 @@ def apply_dataarray_vfunc(
289289
from xarray.core.dataarray import DataArray
290290

291291
if len(args) > 1:
292-
args = deep_align(
293-
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
292+
args = tuple(
293+
deep_align(
294+
args,
295+
join=join,
296+
copy=False,
297+
exclude=exclude_dims,
298+
raise_on_invalid=False,
299+
)
294300
)
295301

296302
objs = _all_of_type(args, DataArray)
@@ -506,8 +512,14 @@ def apply_dataset_vfunc(
506512
objs = _all_of_type(args, Dataset)
507513

508514
if len(args) > 1:
509-
args = deep_align(
510-
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
515+
args = tuple(
516+
deep_align(
517+
args,
518+
join=join,
519+
copy=False,
520+
exclude=exclude_dims,
521+
raise_on_invalid=False,
522+
)
511523
)
512524

513525
list_of_coords, list_of_indexes = build_output_coords_and_indexes(

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4639,7 +4639,7 @@ def _binary_op(
46394639
return NotImplemented
46404640
if isinstance(other, DataArray):
46414641
align_type = OPTIONS["arithmetic_join"]
4642-
self, other = align(self, other, join=align_type, copy=False) # type: ignore[type-var,assignment]
4642+
self, other = align(self, other, join=align_type, copy=False)
46434643
other_variable_or_arraylike: DaCompatible = getattr(other, "variable", other)
46444644
other_coords = getattr(other, "coords", None)
46454645

xarray/core/dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7508,7 +7508,7 @@ def _binary_op(self, other, f, reflexive=False, join=None) -> Dataset:
75087508
return NotImplemented
75097509
align_type = OPTIONS["arithmetic_join"] if join is None else join
75107510
if isinstance(other, (DataArray, Dataset)):
7511-
self, other = align(self, other, join=align_type, copy=False) # type: ignore[assignment]
7511+
self, other = align(self, other, join=align_type, copy=False)
75127512
g = f if not reflexive else lambda x, y: f(y, x)
75137513
ds = self._calculate_binary_op(g, other, join=align_type)
75147514
keep_attrs = _get_keep_attrs(default=False)
@@ -7920,9 +7920,9 @@ def sortby(
79207920
else:
79217921
variables = variables
79227922
arrays = [v if isinstance(v, DataArray) else self[v] for v in variables]
7923-
aligned_vars = align(self, *arrays, join="left") # type: ignore[type-var]
7924-
aligned_self = cast(Self, aligned_vars[0])
7925-
aligned_other_vars: tuple[DataArray, ...] = aligned_vars[1:] # type: ignore[assignment]
7923+
aligned_vars = align(self, *arrays, join="left")
7924+
aligned_self = aligned_vars[0]
7925+
aligned_other_vars: tuple[DataArray, ...] = aligned_vars[1:]
79267926
vars_by_dim = defaultdict(list)
79277927
for data_array in aligned_other_vars:
79287928
if data_array.ndim != 1:

xarray/core/merge.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -474,10 +474,11 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik
474474
from xarray.core.dataarray import DataArray
475475
from xarray.core.dataset import Dataset
476476

477-
out = []
477+
out: list[DatasetLike] = []
478478
for obj in objects:
479+
variables: DatasetLike
479480
if isinstance(obj, (Dataset, Coordinates)):
480-
variables: DatasetLike = obj
481+
variables = obj
481482
else:
482483
variables = {}
483484
if isinstance(obj, PANDAS_TYPES):
@@ -491,7 +492,7 @@ def coerce_pandas_values(objects: Iterable[CoercibleMapping]) -> list[DatasetLik
491492

492493

493494
def _get_priority_vars_and_indexes(
494-
objects: list[DatasetLike],
495+
objects: Sequence[DatasetLike],
495496
priority_arg: int | None,
496497
compat: CompatOptions = "equals",
497498
) -> dict[Hashable, MergeElement]:
@@ -503,7 +504,7 @@ def _get_priority_vars_and_indexes(
503504
504505
Parameters
505506
----------
506-
objects : list of dict-like of Variable
507+
objects : sequence of dict-like of Variable
507508
Dictionaries in which to find the priority variables.
508509
priority_arg : int or None
509510
Integer object whose variable should take priority.

0 commit comments

Comments (0)