Skip to content

Commit 4278dfb

Browse files
committed
TYP: dask.array typing fixes and improvements
1 parent ba0b4e5 commit 4278dfb

File tree

4 files changed

+188
-100
lines changed

4 files changed

+188
-100
lines changed

array_api_compat/dask/array/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
from dask.array import * # noqa: F403
1+
from typing import Final
2+
3+
from dask.array import * # noqa: F403
24

35
# These imports may overwrite names from the import * above.
4-
from ._aliases import * # noqa: F403
6+
from ._aliases import * # noqa: F403
57

6-
__array_api_version__ = '2024.12'
8+
__array_api_version__: Final = "2024.12"
79

810
# See the comment in the numpy __init__.py
911
__import__(__package__ + '.linalg')

array_api_compat/dask/array/_aliases.py

Lines changed: 92 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,38 @@
1+
# pyright: reportPrivateUsage=false
2+
# pyright: reportUnknownArgumentType=false
3+
# pyright: reportUnknownMemberType=false
4+
# pyright: reportUnknownVariableType=false
5+
16
from __future__ import annotations
27

3-
from typing import Callable, Optional, Union
8+
from builtins import bool as py_bool
9+
from collections.abc import Callable
10+
from typing import TYPE_CHECKING, Any
11+
12+
if TYPE_CHECKING:
13+
from typing_extensions import TypeIs
414

15+
import dask.array as da
516
import numpy as np
17+
from numpy import bool_ as bool
618
from numpy import (
7-
# dtypes
8-
bool_ as bool,
19+
can_cast,
20+
complex64,
21+
complex128,
922
float32,
1023
float64,
1124
int8,
1225
int16,
1326
int32,
1427
int64,
28+
result_type,
1529
uint8,
1630
uint16,
1731
uint32,
1832
uint64,
19-
complex64,
20-
complex128,
21-
can_cast,
22-
result_type,
2333
)
24-
import dask.array as da
2534

35+
from ..._internal import get_xp
2636
from ...common import _aliases, _helpers, array_namespace
2737
from ...common._typing import (
2838
Array,
@@ -31,7 +41,6 @@
3141
NestedSequence,
3242
SupportsBufferProtocol,
3343
)
34-
from ..._internal import get_xp
3544
from ._info import __array_namespace_info__
3645

3746
isdtype = get_xp(np)(_aliases.isdtype)
@@ -44,8 +53,8 @@ def astype(
4453
dtype: DType,
4554
/,
4655
*,
47-
copy: bool = True,
48-
device: Optional[Device] = None,
56+
copy: py_bool = True,
57+
device: Device | None = None,
4958
) -> Array:
5059
"""
5160
Array API compatibility wrapper for astype().
@@ -69,14 +78,14 @@ def astype(
6978
# not pass stop/step as keyword arguments, which will cause
7079
# an error with dask
7180
def arange(
72-
start: Union[int, float],
81+
start: float,
7382
/,
74-
stop: Optional[Union[int, float]] = None,
75-
step: Union[int, float] = 1,
83+
stop: float | None = None,
84+
step: float = 1,
7685
*,
77-
dtype: Optional[DType] = None,
78-
device: Optional[Device] = None,
79-
**kwargs,
86+
dtype: DType | None = None,
87+
device: Device | None = None,
88+
**kwargs: object,
8089
) -> Array:
8190
"""
8291
Array API compatibility wrapper for arange().
@@ -87,7 +96,7 @@ def arange(
8796
# TODO: respect device keyword?
8897
_helpers._check_device(da, device)
8998

90-
args = [start]
99+
args: list[Any] = [start]
91100
if stop is not None:
92101
args.append(stop)
93102
else:
@@ -137,18 +146,13 @@ def arange(
137146

138147
# asarray also adds the copy keyword, which is not present in numpy 1.0.
139148
def asarray(
140-
obj: (
141-
Array
142-
| bool | int | float | complex
143-
| NestedSequence[bool | int | float | complex]
144-
| SupportsBufferProtocol
145-
),
149+
obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol,
146150
/,
147151
*,
148-
dtype: Optional[DType] = None,
149-
device: Optional[Device] = None,
150-
copy: Optional[bool] = None,
151-
**kwargs,
152+
dtype: DType | None = None,
153+
device: Device | None = None,
154+
copy: py_bool | None = None,
155+
**kwargs: object,
152156
) -> Array:
153157
"""
154158
Array API compatibility wrapper for asarray().
@@ -164,7 +168,7 @@ def asarray(
164168
if copy is False:
165169
raise ValueError("Unable to avoid copy when changing dtype")
166170
obj = obj.astype(dtype)
167-
return obj.copy() if copy else obj
171+
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
168172

169173
if copy is False:
170174
raise NotImplementedError(
@@ -177,22 +181,21 @@ def asarray(
177181
return da.from_array(obj)
178182

179183

180-
from dask.array import (
181-
# Element wise aliases
182-
arccos as acos,
183-
arccosh as acosh,
184-
arcsin as asin,
185-
arcsinh as asinh,
186-
arctan as atan,
187-
arctan2 as atan2,
188-
arctanh as atanh,
189-
left_shift as bitwise_left_shift,
190-
right_shift as bitwise_right_shift,
191-
invert as bitwise_invert,
192-
power as pow,
193-
# Other
194-
concatenate as concat,
195-
)
184+
# Element wise aliases
185+
from dask.array import arccos as acos
186+
from dask.array import arccosh as acosh
187+
from dask.array import arcsin as asin
188+
from dask.array import arcsinh as asinh
189+
from dask.array import arctan as atan
190+
from dask.array import arctan2 as atan2
191+
from dask.array import arctanh as atanh
192+
193+
# Other
194+
from dask.array import concatenate as concat
195+
from dask.array import invert as bitwise_invert
196+
from dask.array import left_shift as bitwise_left_shift
197+
from dask.array import power as pow
198+
from dask.array import right_shift as bitwise_right_shift
196199

197200

198201
# dask.array.clip does not work unless all three arguments are provided.
@@ -202,8 +205,8 @@ def asarray(
202205
def clip(
203206
x: Array,
204207
/,
205-
min: Optional[Union[int, float, Array]] = None,
206-
max: Optional[Union[int, float, Array]] = None,
208+
min: float | Array | None = None,
209+
max: float | Array | None = None,
207210
) -> Array:
208211
"""
209212
Array API compatibility wrapper for clip().
@@ -212,8 +215,8 @@ def clip(
212215
specification for more details.
213216
"""
214217

215-
def _isscalar(a):
216-
return isinstance(a, (int, float, type(None)))
218+
def _isscalar(a: float | Array | None, /) -> TypeIs[float | None]:
219+
return a is None or isinstance(a, (int, float))
217220

218221
min_shape = () if _isscalar(min) else min.shape
219222
max_shape = () if _isscalar(max) else max.shape
@@ -266,7 +269,12 @@ def _ensure_single_chunk(x: Array, axis: int) -> tuple[Array, Callable[[Array],
266269

267270

268271
def sort(
269-
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
272+
x: Array,
273+
/,
274+
*,
275+
axis: int = -1,
276+
descending: py_bool = False,
277+
stable: py_bool = True,
270278
) -> Array:
271279
"""
272280
Array API compatibility layer around the lack of sort() in Dask.
@@ -296,7 +304,12 @@ def sort(
296304

297305

298306
def argsort(
299-
x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True
307+
x: Array,
308+
/,
309+
*,
310+
axis: int = -1,
311+
descending: py_bool = False,
312+
stable: py_bool = True,
300313
) -> Array:
301314
"""
302315
Array API compatibility layer around the lack of argsort() in Dask.
@@ -330,25 +343,34 @@ def argsort(
330343
# dask.array.count_nonzero does not have keepdims
331344
def count_nonzero(
332345
x: Array,
333-
axis=None,
334-
keepdims=False
346+
axis: int | None = None,
347+
keepdims: py_bool = False,
335348
) -> Array:
336-
result = da.count_nonzero(x, axis)
337-
if keepdims:
338-
if axis is None:
339-
return da.reshape(result, [1]*x.ndim)
340-
return da.expand_dims(result, axis)
341-
return result
342-
343-
349+
result = da.count_nonzero(x, axis)
350+
if keepdims:
351+
if axis is None:
352+
return da.reshape(result, [1] * x.ndim)
353+
return da.expand_dims(result, axis)
354+
return result
355+
356+
357+
__all__ = [
358+
"__array_namespace_info__",
359+
"count_nonzero",
360+
"bool",
361+
"int8", "int16", "int32", "int64",
362+
"uint8", "uint16", "uint32", "uint64",
363+
"float32", "float64",
364+
"complex64", "complex128",
365+
"asarray", "astype", "can_cast", "result_type",
366+
"pow",
367+
"concat",
368+
"acos", "acosh", "asin", "asinh", "atan", "atan2", "atanh",
369+
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
370+
] # fmt: skip
371+
__all__ += _aliases.__all__
372+
_all_ignore = ["array_namespace", "get_xp", "da", "np"]
344373

345-
__all__ = _aliases.__all__ + [
346-
'__array_namespace_info__', 'asarray', 'astype', 'acos',
347-
'acosh', 'asin', 'asinh', 'atan', 'atan2',
348-
'atanh', 'bitwise_left_shift', 'bitwise_invert',
349-
'bitwise_right_shift', 'concat', 'pow', 'can_cast',
350-
'result_type', 'bool', 'float32', 'float64', 'int8', 'int16', 'int32', 'int64',
351-
'uint8', 'uint16', 'uint32', 'uint64', 'complex64', 'complex128',
352-
'can_cast', 'count_nonzero', 'result_type']
353374

354-
_all_ignore = ["array_namespace", "get_xp", "da", "np"]
375+
def __dir__() -> list[str]:
376+
return __all__

0 commit comments

Comments
 (0)