5
5
from collections import defaultdict
6
6
from collections .abc import Hashable , Iterable , Mapping
7
7
from contextlib import suppress
8
- from typing import TYPE_CHECKING , Any , Callable , Generic , cast
8
+ from typing import TYPE_CHECKING , Any , Callable , Final , Generic , TypeVar , cast , overload
9
9
10
10
import numpy as np
11
11
import pandas as pd
26
26
if TYPE_CHECKING :
27
27
from xarray .core .dataarray import DataArray
28
28
from xarray .core .dataset import Dataset
29
- from xarray .core .types import JoinOptions , T_DataArray , T_Dataset , T_DuckArray
29
+ from xarray .core .types import (
30
+ Alignable ,
31
+ JoinOptions ,
32
+ T_DataArray ,
33
+ T_Dataset ,
34
+ T_DuckArray ,
35
+ )
30
36
31
37
32
38
def reindex_variables (
@@ -128,7 +134,7 @@ def __init__(
128
134
objects : Iterable [T_Alignable ],
129
135
join : str = "inner" ,
130
136
indexes : Mapping [Any , Any ] | None = None ,
131
- exclude_dims : Iterable = frozenset (),
137
+ exclude_dims : str | Iterable [ Hashable ] = frozenset (),
132
138
exclude_vars : Iterable [Hashable ] = frozenset (),
133
139
method : str | None = None ,
134
140
tolerance : int | float | Iterable [int | float ] | None = None ,
@@ -576,12 +582,111 @@ def align(self) -> None:
576
582
self .reindex_all ()
577
583
578
584
585
+ T_Obj1 = TypeVar ("T_Obj1" , bound = "Alignable" )
586
+ T_Obj2 = TypeVar ("T_Obj2" , bound = "Alignable" )
587
+ T_Obj3 = TypeVar ("T_Obj3" , bound = "Alignable" )
588
+ T_Obj4 = TypeVar ("T_Obj4" , bound = "Alignable" )
589
+ T_Obj5 = TypeVar ("T_Obj5" , bound = "Alignable" )
590
+
591
+
592
+ @overload
593
+ def align (
594
+ obj1 : T_Obj1 ,
595
+ / ,
596
+ * ,
597
+ join : JoinOptions = "inner" ,
598
+ copy : bool = True ,
599
+ indexes = None ,
600
+ exclude : str | Iterable [Hashable ] = frozenset (),
601
+ fill_value = dtypes .NA ,
602
+ ) -> tuple [T_Obj1 ]:
603
+ ...
604
+
605
+
606
+ @overload
607
+ def align ( # type: ignore[misc]
608
+ obj1 : T_Obj1 ,
609
+ obj2 : T_Obj2 ,
610
+ / ,
611
+ * ,
612
+ join : JoinOptions = "inner" ,
613
+ copy : bool = True ,
614
+ indexes = None ,
615
+ exclude : str | Iterable [Hashable ] = frozenset (),
616
+ fill_value = dtypes .NA ,
617
+ ) -> tuple [T_Obj1 , T_Obj2 ]:
618
+ ...
619
+
620
+
621
+ @overload
622
+ def align ( # type: ignore[misc]
623
+ obj1 : T_Obj1 ,
624
+ obj2 : T_Obj2 ,
625
+ obj3 : T_Obj3 ,
626
+ / ,
627
+ * ,
628
+ join : JoinOptions = "inner" ,
629
+ copy : bool = True ,
630
+ indexes = None ,
631
+ exclude : str | Iterable [Hashable ] = frozenset (),
632
+ fill_value = dtypes .NA ,
633
+ ) -> tuple [T_Obj1 , T_Obj2 , T_Obj3 ]:
634
+ ...
635
+
636
+
637
+ @overload
638
+ def align ( # type: ignore[misc]
639
+ obj1 : T_Obj1 ,
640
+ obj2 : T_Obj2 ,
641
+ obj3 : T_Obj3 ,
642
+ obj4 : T_Obj4 ,
643
+ / ,
644
+ * ,
645
+ join : JoinOptions = "inner" ,
646
+ copy : bool = True ,
647
+ indexes = None ,
648
+ exclude : str | Iterable [Hashable ] = frozenset (),
649
+ fill_value = dtypes .NA ,
650
+ ) -> tuple [T_Obj1 , T_Obj2 , T_Obj3 , T_Obj4 ]:
651
+ ...
652
+
653
+
654
+ @overload
655
+ def align ( # type: ignore[misc]
656
+ obj1 : T_Obj1 ,
657
+ obj2 : T_Obj2 ,
658
+ obj3 : T_Obj3 ,
659
+ obj4 : T_Obj4 ,
660
+ obj5 : T_Obj5 ,
661
+ / ,
662
+ * ,
663
+ join : JoinOptions = "inner" ,
664
+ copy : bool = True ,
665
+ indexes = None ,
666
+ exclude : str | Iterable [Hashable ] = frozenset (),
667
+ fill_value = dtypes .NA ,
668
+ ) -> tuple [T_Obj1 , T_Obj2 , T_Obj3 , T_Obj4 , T_Obj5 ]:
669
+ ...
670
+
671
+
672
+ @overload
579
673
def align (
580
674
* objects : T_Alignable ,
581
675
join : JoinOptions = "inner" ,
582
676
copy : bool = True ,
583
677
indexes = None ,
584
- exclude = frozenset (),
678
+ exclude : str | Iterable [Hashable ] = frozenset (),
679
+ fill_value = dtypes .NA ,
680
+ ) -> tuple [T_Alignable , ...]:
681
+ ...
682
+
683
+
684
+ def align ( # type: ignore[misc]
685
+ * objects : T_Alignable ,
686
+ join : JoinOptions = "inner" ,
687
+ copy : bool = True ,
688
+ indexes = None ,
689
+ exclude : str | Iterable [Hashable ] = frozenset (),
585
690
fill_value = dtypes .NA ,
586
691
) -> tuple [T_Alignable , ...]:
587
692
"""
@@ -620,7 +725,7 @@ def align(
620
725
indexes : dict-like, optional
621
726
Any indexes explicitly provided with the `indexes` argument should be
622
727
used in preference to the aligned indexes.
623
- exclude : sequence of str , optional
728
+ exclude : str, iterable of hashable or None , optional
624
729
Dimensions that must be excluded from alignment
625
730
fill_value : scalar or dict-like, optional
626
731
Value to use for newly missing values. If a dict-like, maps
@@ -787,12 +892,12 @@ def align(
787
892
def deep_align (
788
893
objects : Iterable [Any ],
789
894
join : JoinOptions = "inner" ,
790
- copy = True ,
895
+ copy : bool = True ,
791
896
indexes = None ,
792
- exclude = frozenset (),
793
- raise_on_invalid = True ,
897
+ exclude : str | Iterable [ Hashable ] = frozenset (),
898
+ raise_on_invalid : bool = True ,
794
899
fill_value = dtypes .NA ,
795
- ):
900
+ ) -> list [ Any ] :
796
901
"""Align objects for merging, recursing into dictionary values.
797
902
798
903
This function is not public API.
@@ -807,12 +912,12 @@ def deep_align(
807
912
def is_alignable (obj ):
808
913
return isinstance (obj , (Coordinates , DataArray , Dataset ))
809
914
810
- positions = []
811
- keys = []
812
- out = []
813
- targets = []
814
- no_key = object ()
815
- not_replaced = object ()
915
+ positions : list [ int ] = []
916
+ keys : list [ type [ object ] | Hashable ] = []
917
+ out : list [ Any ] = []
918
+ targets : list [ Alignable ] = []
919
+ no_key : Final = object ()
920
+ not_replaced : Final = object ()
816
921
for position , variables in enumerate (objects ):
817
922
if is_alignable (variables ):
818
923
positions .append (position )
@@ -857,7 +962,7 @@ def is_alignable(obj):
857
962
if key is no_key :
858
963
out [position ] = aligned_obj
859
964
else :
860
- out [position ][key ] = aligned_obj # type: ignore[index] # maybe someone can fix this?
965
+ out [position ][key ] = aligned_obj
861
966
862
967
return out
863
968
@@ -988,9 +1093,69 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:
988
1093
raise ValueError ("all input must be Dataset or DataArray objects" )
989
1094
990
1095
991
- # TODO: this typing is too restrictive since it cannot deal with mixed
992
- # DataArray and Dataset types...? Is this a problem?
993
- def broadcast (* args : T_Alignable , exclude = None ) -> tuple [T_Alignable , ...]:
1096
+ @overload
1097
+ def broadcast (
1098
+ obj1 : T_Obj1 , / , * , exclude : str | Iterable [Hashable ] | None = None
1099
+ ) -> tuple [T_Obj1 ]:
1100
+ ...
1101
+
1102
+
1103
+ @overload
1104
+ def broadcast ( # type: ignore[misc]
1105
+ obj1 : T_Obj1 , obj2 : T_Obj2 , / , * , exclude : str | Iterable [Hashable ] | None = None
1106
+ ) -> tuple [T_Obj1 , T_Obj2 ]:
1107
+ ...
1108
+
1109
+
1110
+ @overload
1111
+ def broadcast ( # type: ignore[misc]
1112
+ obj1 : T_Obj1 ,
1113
+ obj2 : T_Obj2 ,
1114
+ obj3 : T_Obj3 ,
1115
+ / ,
1116
+ * ,
1117
+ exclude : str | Iterable [Hashable ] | None = None ,
1118
+ ) -> tuple [T_Obj1 , T_Obj2 , T_Obj3 ]:
1119
+ ...
1120
+
1121
+
1122
+ @overload
1123
+ def broadcast ( # type: ignore[misc]
1124
+ obj1 : T_Obj1 ,
1125
+ obj2 : T_Obj2 ,
1126
+ obj3 : T_Obj3 ,
1127
+ obj4 : T_Obj4 ,
1128
+ / ,
1129
+ * ,
1130
+ exclude : str | Iterable [Hashable ] | None = None ,
1131
+ ) -> tuple [T_Obj1 , T_Obj2 , T_Obj3 , T_Obj4 ]:
1132
+ ...
1133
+
1134
+
1135
+ @overload
1136
+ def broadcast ( # type: ignore[misc]
1137
+ obj1 : T_Obj1 ,
1138
+ obj2 : T_Obj2 ,
1139
+ obj3 : T_Obj3 ,
1140
+ obj4 : T_Obj4 ,
1141
+ obj5 : T_Obj5 ,
1142
+ / ,
1143
+ * ,
1144
+ exclude : str | Iterable [Hashable ] | None = None ,
1145
+ ) -> tuple [T_Obj1 , T_Obj2 , T_Obj3 , T_Obj4 , T_Obj5 ]:
1146
+ ...
1147
+
1148
+
1149
+ @overload
1150
+ def broadcast (
1151
+ * args : T_Alignable , exclude : str | Iterable [Hashable ] | None = None
1152
+ ) -> tuple [T_Alignable , ...]:
1153
+ ...
1154
+
1155
+
1156
+ def broadcast ( # type: ignore[misc]
1157
+ * args : T_Alignable , exclude : str | Iterable [Hashable ] | None = None
1158
+ ) -> tuple [T_Alignable , ...]:
994
1159
"""Explicitly broadcast any number of DataArray or Dataset objects against
995
1160
one another.
996
1161
@@ -1004,7 +1169,7 @@ def broadcast(*args: T_Alignable, exclude=None) -> tuple[T_Alignable, ...]:
1004
1169
----------
1005
1170
*args : DataArray or Dataset
1006
1171
Arrays to broadcast against each other.
1007
- exclude : sequence of str , optional
1172
+ exclude : str, iterable of hashable or None , optional
1008
1173
Dimensions that must not be broadcasted
1009
1174
1010
1175
Returns
0 commit comments