1
+ # pyright: reportPrivateUsage=false
2
+ # pyright: reportUnknownArgumentType=false
3
+ # pyright: reportUnknownMemberType=false
4
+ # pyright: reportUnknownVariableType=false
5
+
1
6
from __future__ import annotations
2
7
3
- from typing import Callable , Optional , Union
8
+ from builtins import bool as py_bool
9
+ from collections .abc import Callable
10
+ from typing import TYPE_CHECKING , Any
11
+
12
+ if TYPE_CHECKING :
13
+ from typing_extensions import TypeIs
4
14
15
+ import dask .array as da
5
16
import numpy as np
17
+ from numpy import bool_ as bool
6
18
from numpy import (
7
- # dtypes
8
- bool_ as bool ,
19
+ can_cast ,
20
+ complex64 ,
21
+ complex128 ,
9
22
float32 ,
10
23
float64 ,
11
24
int8 ,
12
25
int16 ,
13
26
int32 ,
14
27
int64 ,
28
+ result_type ,
15
29
uint8 ,
16
30
uint16 ,
17
31
uint32 ,
18
32
uint64 ,
19
- complex64 ,
20
- complex128 ,
21
- can_cast ,
22
- result_type ,
23
33
)
24
- import dask .array as da
25
34
35
+ from ..._internal import get_xp
26
36
from ...common import _aliases , _helpers , array_namespace
27
37
from ...common ._typing import (
28
38
Array ,
31
41
NestedSequence ,
32
42
SupportsBufferProtocol ,
33
43
)
34
- from ..._internal import get_xp
35
44
from ._info import __array_namespace_info__
36
45
37
46
isdtype = get_xp (np )(_aliases .isdtype )
@@ -44,8 +53,8 @@ def astype(
44
53
dtype : DType ,
45
54
/ ,
46
55
* ,
47
- copy : bool = True ,
48
- device : Optional [ Device ] = None ,
56
+ copy : py_bool = True ,
57
+ device : Device | None = None ,
49
58
) -> Array :
50
59
"""
51
60
Array API compatibility wrapper for astype().
@@ -69,14 +78,14 @@ def astype(
69
78
# not pass stop/step as keyword arguments, which will cause
70
79
# an error with dask
71
80
def arange (
72
- start : Union [ int , float ] ,
81
+ start : float ,
73
82
/ ,
74
- stop : Optional [ Union [ int , float ]] = None ,
75
- step : Union [ int , float ] = 1 ,
83
+ stop : float | None = None ,
84
+ step : float = 1 ,
76
85
* ,
77
- dtype : Optional [ DType ] = None ,
78
- device : Optional [ Device ] = None ,
79
- ** kwargs ,
86
+ dtype : DType | None = None ,
87
+ device : Device | None = None ,
88
+ ** kwargs : object ,
80
89
) -> Array :
81
90
"""
82
91
Array API compatibility wrapper for arange().
@@ -87,7 +96,7 @@ def arange(
87
96
# TODO: respect device keyword?
88
97
_helpers ._check_device (da , device )
89
98
90
- args = [start ]
99
+ args : list [ Any ] = [start ]
91
100
if stop is not None :
92
101
args .append (stop )
93
102
else :
@@ -137,18 +146,13 @@ def arange(
137
146
138
147
# asarray also adds the copy keyword, which is not present in numpy 1.0.
139
148
def asarray (
140
- obj : (
141
- Array
142
- | bool | int | float | complex
143
- | NestedSequence [bool | int | float | complex ]
144
- | SupportsBufferProtocol
145
- ),
149
+ obj : complex | NestedSequence [complex ] | Array | SupportsBufferProtocol ,
146
150
/ ,
147
151
* ,
148
- dtype : Optional [ DType ] = None ,
149
- device : Optional [ Device ] = None ,
150
- copy : Optional [ bool ] = None ,
151
- ** kwargs ,
152
+ dtype : DType | None = None ,
153
+ device : Device | None = None ,
154
+ copy : py_bool | None = None ,
155
+ ** kwargs : object ,
152
156
) -> Array :
153
157
"""
154
158
Array API compatibility wrapper for asarray().
@@ -164,7 +168,7 @@ def asarray(
164
168
if copy is False :
165
169
raise ValueError ("Unable to avoid copy when changing dtype" )
166
170
obj = obj .astype (dtype )
167
- return obj .copy () if copy else obj
171
+ return obj .copy () if copy else obj # pyright: ignore[reportAttributeAccessIssue]
168
172
169
173
if copy is False :
170
174
raise NotImplementedError (
@@ -177,22 +181,21 @@ def asarray(
177
181
return da .from_array (obj )
178
182
179
183
180
- from dask .array import (
181
- # Element wise aliases
182
- arccos as acos ,
183
- arccosh as acosh ,
184
- arcsin as asin ,
185
- arcsinh as asinh ,
186
- arctan as atan ,
187
- arctan2 as atan2 ,
188
- arctanh as atanh ,
189
- left_shift as bitwise_left_shift ,
190
- right_shift as bitwise_right_shift ,
191
- invert as bitwise_invert ,
192
- power as pow ,
193
- # Other
194
- concatenate as concat ,
195
- )
184
+ # Element wise aliases
185
+ from dask .array import arccos as acos
186
+ from dask .array import arccosh as acosh
187
+ from dask .array import arcsin as asin
188
+ from dask .array import arcsinh as asinh
189
+ from dask .array import arctan as atan
190
+ from dask .array import arctan2 as atan2
191
+ from dask .array import arctanh as atanh
192
+
193
+ # Other
194
+ from dask .array import concatenate as concat
195
+ from dask .array import invert as bitwise_invert
196
+ from dask .array import left_shift as bitwise_left_shift
197
+ from dask .array import power as pow
198
+ from dask .array import right_shift as bitwise_right_shift
196
199
197
200
198
201
# dask.array.clip does not work unless all three arguments are provided.
@@ -202,8 +205,8 @@ def asarray(
202
205
def clip (
203
206
x : Array ,
204
207
/ ,
205
- min : Optional [ Union [ int , float , Array ]] = None ,
206
- max : Optional [ Union [ int , float , Array ]] = None ,
208
+ min : float | Array | None = None ,
209
+ max : float | Array | None = None ,
207
210
) -> Array :
208
211
"""
209
212
Array API compatibility wrapper for clip().
@@ -212,8 +215,8 @@ def clip(
212
215
specification for more details.
213
216
"""
214
217
215
- def _isscalar (a ) :
216
- return isinstance (a , (int , float , type ( None ) ))
218
+ def _isscalar (a : float | Array | None , / ) -> TypeIs [ float | None ] :
219
+ return a is None or isinstance (a , (int , float ))
217
220
218
221
min_shape = () if _isscalar (min ) else min .shape
219
222
max_shape = () if _isscalar (max ) else max .shape
@@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array],
266
269
267
270
268
271
def sort (
269
- x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
272
+ x : Array ,
273
+ / ,
274
+ * ,
275
+ axis : int = - 1 ,
276
+ descending : py_bool = False ,
277
+ stable : py_bool = True ,
270
278
) -> Array :
271
279
"""
272
280
Array API compatibility layer around the lack of sort() in Dask.
@@ -296,7 +304,12 @@ def sort(
296
304
297
305
298
306
def argsort (
299
- x : Array , / , * , axis : int = - 1 , descending : bool = False , stable : bool = True
307
+ x : Array ,
308
+ / ,
309
+ * ,
310
+ axis : int = - 1 ,
311
+ descending : py_bool = False ,
312
+ stable : py_bool = True ,
300
313
) -> Array :
301
314
"""
302
315
Array API compatibility layer around the lack of argsort() in Dask.
@@ -330,25 +343,34 @@ def argsort(
330
343
# dask.array.count_nonzero does not have keepdims
331
344
def count_nonzero (
332
345
x : Array ,
333
- axis = None ,
334
- keepdims = False
346
+ axis : int | None = None ,
347
+ keepdims : py_bool = False ,
335
348
) -> Array :
336
- result = da .count_nonzero (x , axis )
337
- if keepdims :
338
- if axis is None :
339
- return da .reshape (result , [1 ]* x .ndim )
340
- return da .expand_dims (result , axis )
341
- return result
342
-
343
-
349
+ result = da .count_nonzero (x , axis )
350
+ if keepdims :
351
+ if axis is None :
352
+ return da .reshape (result , [1 ] * x .ndim )
353
+ return da .expand_dims (result , axis )
354
+ return result
355
+
356
+
357
+ __all__ = [
358
+ "__array_namespace_info__" ,
359
+ "count_nonzero" ,
360
+ "bool" ,
361
+ "int8" , "int16" , "int32" , "int64" ,
362
+ "uint8" , "uint16" , "uint32" , "uint64" ,
363
+ "float32" , "float64" ,
364
+ "complex64" , "complex128" ,
365
+ "asarray" , "astype" , "can_cast" , "result_type" ,
366
+ "pow" ,
367
+ "concat" ,
368
+ "acos" , "acosh" , "asin" , "asinh" , "atan" , "atan2" , "atanh" ,
369
+ "bitwise_left_shift" , "bitwise_right_shift" , "bitwise_invert" ,
370
+ ] # fmt: skip
371
+ __all__ += _aliases .__all__
372
+ _all_ignore = ["array_namespace" , "get_xp" , "da" , "np" ]
344
373
345
- __all__ = _aliases .__all__ + [
346
- '__array_namespace_info__' , 'asarray' , 'astype' , 'acos' ,
347
- 'acosh' , 'asin' , 'asinh' , 'atan' , 'atan2' ,
348
- 'atanh' , 'bitwise_left_shift' , 'bitwise_invert' ,
349
- 'bitwise_right_shift' , 'concat' , 'pow' , 'can_cast' ,
350
- 'result_type' , 'bool' , 'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' ,
351
- 'uint8' , 'uint16' , 'uint32' , 'uint64' , 'complex64' , 'complex128' ,
352
- 'can_cast' , 'count_nonzero' , 'result_type' ]
353
374
354
- _all_ignore = ["array_namespace" , "get_xp" , "da" , "np" ]
375
+ def __dir__ () -> list [str ]:
376
+ return __all__
0 commit comments