1
+ from collections .abc import Mapping
1
2
from functools import lru_cache
2
- from typing import NamedTuple , Tuple , Union
3
+ from typing import Any , NamedTuple , Sequence , Tuple , Union
3
4
from warnings import warn
4
5
5
6
from . import _array_module as xp
36
37
]
37
38
38
39
40
+ class EqualityMapping (Mapping ):
41
+ """
42
+ Mapping that uses equality for indexing
43
+
44
+ Typical mappings (e.g. the built-in dict) use hashing for indexing. This
45
+ isn't ideal for the Array API, as no __hash__() method is specified for
46
+ dtype objects - but __eq__() is!
47
+
48
+ See https://data-apis.org/array-api/latest/API_specification/data_types.html#data-type-objects
49
+ """
50
+
51
+ def __init__ (self , key_value_pairs : Sequence [Tuple [Any , Any ]]):
52
+ keys = [k for k , _ in key_value_pairs ]
53
+ for i , key in enumerate (keys ):
54
+ if not (key == key ): # specifically checking __eq__, not __neq__
55
+ raise ValueError ("Key {key!r} does not have equality with itself" )
56
+ other_keys = keys [:]
57
+ other_keys .pop (i )
58
+ for other_key in other_keys :
59
+ if key == other_key :
60
+ raise ValueError ("Key {key!r} has equality with key {other_key!r}" )
61
+ self ._key_value_pairs = key_value_pairs
62
+
63
+ def __getitem__ (self , key ):
64
+ for k , v in self ._key_value_pairs :
65
+ if key == k :
66
+ return v
67
+ else :
68
+ raise KeyError (f"{ key !r} not found" )
69
+
70
+ def __iter__ (self ):
71
+ return (k for k , _ in self ._key_value_pairs )
72
+
73
+ def __len__ (self ):
74
+ return len (self ._key_value_pairs )
75
+
76
+ def __str__ (self ):
77
+ return "{" + ", " .join (f"{ k !r} : { v !r} " for k , v in self ._key_value_pairs ) + "}"
78
+
79
+ def __repr__ (self ):
80
+ return f"EqualityMapping({ self } )"
81
+
82
+
39
83
_uint_names = ("uint8" , "uint16" , "uint32" , "uint64" )
40
84
_int_names = ("int8" , "int16" , "int32" , "int64" )
41
85
_float_names = ("float32" , "float64" )
51
95
bool_and_all_int_dtypes = (xp .bool ,) + all_int_dtypes
52
96
53
97
54
- dtype_to_name = { getattr (xp , name ): name for name in _dtype_names }
98
+ dtype_to_name = EqualityMapping ([( getattr (xp , name ), name ) for name in _dtype_names ])
55
99
56
100
57
- dtype_to_scalars = {
58
- xp .bool : [bool ],
59
- ** {d : [int ] for d in all_int_dtypes },
60
- ** {d : [int , float ] for d in float_dtypes },
61
- }
101
+ dtype_to_scalars = EqualityMapping (
102
+ [
103
+ (xp .bool , [bool ]),
104
+ * [(d , [int ]) for d in all_int_dtypes ],
105
+ * [(d , [int , float ]) for d in float_dtypes ],
106
+ ]
107
+ )
62
108
63
109
64
110
def is_int_dtype (dtype ):
@@ -90,31 +136,32 @@ class MinMax(NamedTuple):
90
136
max : Union [int , float ]
91
137
92
138
93
- dtype_ranges = {
94
- xp .int8 : MinMax (- 128 , + 127 ),
95
- xp .int16 : MinMax (- 32_768 , + 32_767 ),
96
- xp .int32 : MinMax (- 2_147_483_648 , + 2_147_483_647 ),
97
- xp .int64 : MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 ),
98
- xp .uint8 : MinMax (0 , + 255 ),
99
- xp .uint16 : MinMax (0 , + 65_535 ),
100
- xp .uint32 : MinMax (0 , + 4_294_967_295 ),
101
- xp .uint64 : MinMax (0 , + 18_446_744_073_709_551_615 ),
102
- xp .float32 : MinMax (- 3.4028234663852886e+38 , 3.4028234663852886e+38 ),
103
- xp .float64 : MinMax (- 1.7976931348623157e+308 , 1.7976931348623157e+308 ),
104
- }
139
+ dtype_ranges = EqualityMapping (
140
+ [
141
+ (xp .int8 , MinMax (- 128 , + 127 )),
142
+ (xp .int16 , MinMax (- 32_768 , + 32_767 )),
143
+ (xp .int32 , MinMax (- 2_147_483_648 , + 2_147_483_647 )),
144
+ (xp .int64 , MinMax (- 9_223_372_036_854_775_808 , + 9_223_372_036_854_775_807 )),
145
+ (xp .uint8 , MinMax (0 , + 255 )),
146
+ (xp .uint16 , MinMax (0 , + 65_535 )),
147
+ (xp .uint32 , MinMax (0 , + 4_294_967_295 )),
148
+ (xp .uint64 , MinMax (0 , + 18_446_744_073_709_551_615 )),
149
+ (xp .float32 , MinMax (- 3.4028234663852886e38 , 3.4028234663852886e38 )),
150
+ (xp .float64 , MinMax (- 1.7976931348623157e308 , 1.7976931348623157e308 )),
151
+ ]
152
+ )
105
153
106
- dtype_nbits = {
107
- ** { d : 8 for d in [xp .int8 , xp .uint8 ]},
108
- ** { d : 16 for d in [xp .int16 , xp .uint16 ]},
109
- ** { d : 32 for d in [xp .int32 , xp .uint32 , xp .float32 ]},
110
- ** { d : 64 for d in [xp .int64 , xp .uint64 , xp .float64 ]},
111
- }
154
+ dtype_nbits = EqualityMapping (
155
+ [( d , 8 ) for d in [xp .int8 , xp .uint8 ]]
156
+ + [( d , 16 ) for d in [xp .int16 , xp .uint16 ]]
157
+ + [( d , 32 ) for d in [xp .int32 , xp .uint32 , xp .float32 ]]
158
+ + [( d , 64 ) for d in [xp .int64 , xp .uint64 , xp .float64 ]]
159
+ )
112
160
113
161
114
- dtype_signed = {
115
- ** {d : True for d in int_dtypes },
116
- ** {d : False for d in uint_dtypes },
117
- }
162
+ dtype_signed = EqualityMapping (
163
+ [(d , True ) for d in int_dtypes ] + [(d , False ) for d in uint_dtypes ]
164
+ )
118
165
119
166
120
167
if isinstance (xp .asarray , _UndefinedStub ):
@@ -137,52 +184,51 @@ class MinMax(NamedTuple):
137
184
default_uint = xp .uint64
138
185
139
186
140
- _numeric_promotions = {
187
+ _numeric_promotions = [
141
188
# ints
142
- (xp .int8 , xp .int8 ): xp .int8 ,
143
- (xp .int8 , xp .int16 ): xp .int16 ,
144
- (xp .int8 , xp .int32 ): xp .int32 ,
145
- (xp .int8 , xp .int64 ): xp .int64 ,
146
- (xp .int16 , xp .int16 ): xp .int16 ,
147
- (xp .int16 , xp .int32 ): xp .int32 ,
148
- (xp .int16 , xp .int64 ): xp .int64 ,
149
- (xp .int32 , xp .int32 ): xp .int32 ,
150
- (xp .int32 , xp .int64 ): xp .int64 ,
151
- (xp .int64 , xp .int64 ): xp .int64 ,
189
+ (( xp .int8 , xp .int8 ), xp .int8 ) ,
190
+ (( xp .int8 , xp .int16 ), xp .int16 ) ,
191
+ (( xp .int8 , xp .int32 ), xp .int32 ) ,
192
+ (( xp .int8 , xp .int64 ), xp .int64 ) ,
193
+ (( xp .int16 , xp .int16 ), xp .int16 ) ,
194
+ (( xp .int16 , xp .int32 ), xp .int32 ) ,
195
+ (( xp .int16 , xp .int64 ), xp .int64 ) ,
196
+ (( xp .int32 , xp .int32 ), xp .int32 ) ,
197
+ (( xp .int32 , xp .int64 ), xp .int64 ) ,
198
+ (( xp .int64 , xp .int64 ), xp .int64 ) ,
152
199
# uints
153
- (xp .uint8 , xp .uint8 ): xp .uint8 ,
154
- (xp .uint8 , xp .uint16 ): xp .uint16 ,
155
- (xp .uint8 , xp .uint32 ): xp .uint32 ,
156
- (xp .uint8 , xp .uint64 ): xp .uint64 ,
157
- (xp .uint16 , xp .uint16 ): xp .uint16 ,
158
- (xp .uint16 , xp .uint32 ): xp .uint32 ,
159
- (xp .uint16 , xp .uint64 ): xp .uint64 ,
160
- (xp .uint32 , xp .uint32 ): xp .uint32 ,
161
- (xp .uint32 , xp .uint64 ): xp .uint64 ,
162
- (xp .uint64 , xp .uint64 ): xp .uint64 ,
200
+ (( xp .uint8 , xp .uint8 ), xp .uint8 ) ,
201
+ (( xp .uint8 , xp .uint16 ), xp .uint16 ) ,
202
+ (( xp .uint8 , xp .uint32 ), xp .uint32 ) ,
203
+ (( xp .uint8 , xp .uint64 ), xp .uint64 ) ,
204
+ (( xp .uint16 , xp .uint16 ), xp .uint16 ) ,
205
+ (( xp .uint16 , xp .uint32 ), xp .uint32 ) ,
206
+ (( xp .uint16 , xp .uint64 ), xp .uint64 ) ,
207
+ (( xp .uint32 , xp .uint32 ), xp .uint32 ) ,
208
+ (( xp .uint32 , xp .uint64 ), xp .uint64 ) ,
209
+ (( xp .uint64 , xp .uint64 ), xp .uint64 ) ,
163
210
# ints and uints (mixed sign)
164
- (xp .int8 , xp .uint8 ): xp .int16 ,
165
- (xp .int8 , xp .uint16 ): xp .int32 ,
166
- (xp .int8 , xp .uint32 ): xp .int64 ,
167
- (xp .int16 , xp .uint8 ): xp .int16 ,
168
- (xp .int16 , xp .uint16 ): xp .int32 ,
169
- (xp .int16 , xp .uint32 ): xp .int64 ,
170
- (xp .int32 , xp .uint8 ): xp .int32 ,
171
- (xp .int32 , xp .uint16 ): xp .int32 ,
172
- (xp .int32 , xp .uint32 ): xp .int64 ,
173
- (xp .int64 , xp .uint8 ): xp .int64 ,
174
- (xp .int64 , xp .uint16 ): xp .int64 ,
175
- (xp .int64 , xp .uint32 ): xp .int64 ,
211
+ (( xp .int8 , xp .uint8 ), xp .int16 ) ,
212
+ (( xp .int8 , xp .uint16 ), xp .int32 ) ,
213
+ (( xp .int8 , xp .uint32 ), xp .int64 ) ,
214
+ (( xp .int16 , xp .uint8 ), xp .int16 ) ,
215
+ (( xp .int16 , xp .uint16 ), xp .int32 ) ,
216
+ (( xp .int16 , xp .uint32 ), xp .int64 ) ,
217
+ (( xp .int32 , xp .uint8 ), xp .int32 ) ,
218
+ (( xp .int32 , xp .uint16 ), xp .int32 ) ,
219
+ (( xp .int32 , xp .uint32 ), xp .int64 ) ,
220
+ (( xp .int64 , xp .uint8 ), xp .int64 ) ,
221
+ (( xp .int64 , xp .uint16 ), xp .int64 ) ,
222
+ (( xp .int64 , xp .uint32 ), xp .int64 ) ,
176
223
# floats
177
- (xp .float32 , xp .float32 ): xp .float32 ,
178
- (xp .float32 , xp .float64 ): xp .float64 ,
179
- (xp .float64 , xp .float64 ): xp .float64 ,
180
- }
181
- promotion_table = {
182
- (xp .bool , xp .bool ): xp .bool ,
183
- ** _numeric_promotions ,
184
- ** {(d2 , d1 ): res for (d1 , d2 ), res in _numeric_promotions .items ()},
185
- }
224
+ ((xp .float32 , xp .float32 ), xp .float32 ),
225
+ ((xp .float32 , xp .float64 ), xp .float64 ),
226
+ ((xp .float64 , xp .float64 ), xp .float64 ),
227
+ ]
228
+ _numeric_promotions += [((d2 , d1 ), res ) for (d1 , d2 ), res in _numeric_promotions ]
229
+ _promotion_table = list (set (_numeric_promotions ))
230
+ _promotion_table .insert (0 , ((xp .bool , xp .bool ), xp .bool ))
231
+ promotion_table = EqualityMapping (_promotion_table )
186
232
187
233
188
234
def result_type (* dtypes : DataType ):
0 commit comments