diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index cd8d939f..b1925492 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object: specification for more details. """ - wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue] - return wrapped_f # pyright: ignore[reportReturnType] + wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType] return inner diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..f7bfc44d 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -5,11 +5,13 @@ from __future__ import annotations import inspect -from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast +from collections.abc import Sequence +from types import NoneType +from typing import TYPE_CHECKING, Any, NamedTuple, cast from ._helpers import _check_device, array_namespace from ._helpers import device as _get_device -from ._helpers import is_cupy_namespace as _is_cupy_namespace +from ._helpers import is_cupy_namespace from ._typing import Array, Device, DType, Namespace if TYPE_CHECKING: @@ -381,8 +383,8 @@ def clip( # TODO: np.clip has other ufunc kwargs out: Array | None = None, ) -> Array: - def _isscalar(a: object) -> TypeIs[int | float | None]: - return isinstance(a, (int, float, type(None))) + def _isscalar(a: object) -> TypeIs[float | None]: + return isinstance(a, int | float | NoneType) min_shape = () if _isscalar(min) else min.shape max_shape = () if _isscalar(max) else max.shape @@ -450,7 +452,7 @@ def reshape( shape: tuple[int, ...], xp: Namespace, *, - copy: Optional[bool] = None, + copy: bool | None = None, **kwargs: object, ) -> Array: if copy is True: @@ -657,7 +659,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: out = xp.sign(x, **kwargs) # CuPy sign() does not propagate nans. See # https://github.com/data-apis/array-api-compat/issues/136 - if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): + if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): out[xp.isnan(x)] = xp.nan return out[()] @@ -720,7 +722,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "finfo", "iinfo", ] -_all_ignore = ["inspect", "array_namespace", "NamedTuple"] +_all_ignore = ["is_cupy_namespace", "inspect", "array_namespace", "NamedTuple"] def __dir__() -> list[str]: diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index db3e4cd7..c3b3a4f1 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,16 +12,14 @@ import math import sys import warnings -from collections.abc import Collection +from types import NoneType from typing import ( TYPE_CHECKING, Any, Final, Literal, - SupportsIndex, TypeAlias, TypeGuard, - TypeVar, cast, overload, ) @@ -29,39 +27,36 @@ from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace if TYPE_CHECKING: - + import cupy as cp import dask.array as da import jax import ndonnx as ndx import numpy as np import numpy.typing as npt - import sparse # pyright: ignore[reportMissingTypeStubs] + import sparse import torch # TODO: import from typing (requires Python >=3.13) - from typing_extensions import TypeIs, TypeVar - - _SizeT = TypeVar("_SizeT", bound = int | None) + from typing_extensions import TypeIs _ZeroGradientArray: TypeAlias = npt.NDArray[np.void] - _CupyArray: TypeAlias = Any # cupy has no py.typed _ArrayApiObj: TypeAlias = ( npt.NDArray[Any] + | cp.ndarray | da.Array | jax.Array | ndx.Array | sparse.SparseArray | torch.Tensor - | SupportsArrayNamespace[Any] - | _CupyArray + | SupportsArrayNamespace ) _API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) -def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: +def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. @@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: ) -def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: +def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: """ Return True if `x` is a NumPy array. @@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool: if "cupy" not in sys.modules: return False - import cupy as cp # pyright: ignore[reportMissingTypeStubs] + import cupy as cp # TODO: Should we reject ndarray subclasses? return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] @@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: if "sparse" not in sys.modules: return False - import sparse # pyright: ignore[reportMissingTypeStubs] + import sparse # TODO: Account for other backends. return isinstance(x, sparse.SparseArray) -def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] +def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: """ Return True if `x` is an array API compatible array object. @@ -587,7 +582,7 @@ def your_function(x, y): namespaces.add(cupy_namespace) else: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] + import cupy as cp namespaces.add(cp) elif is_torch_array(x): @@ -624,14 +619,14 @@ def your_function(x, y): if hasattr(jax.numpy, "__array_api_version__"): jnp = jax.numpy else: - import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] + import jax.experimental.array_api as jnp # type: ignore[no-redef] namespaces.add(jnp) elif is_pydata_sparse_array(x): if use_compat is True: _check_api_version(api_version) raise ValueError("`sparse` does not have an array-api-compat wrapper") else: - import sparse # pyright: ignore[reportMissingTypeStubs] + import sparse # `sparse` is already an array namespace. We do not have a wrapper # submodule for it. namespaces.add(sparse) @@ -640,9 +635,9 @@ def your_function(x, y): raise ValueError( "The given array does not have an array-api-compat wrapper" ) - x = cast("SupportsArrayNamespace[Any]", x) + x = cast(SupportsArrayNamespace, x) namespaces.add(x.__array_namespace__(api_version=api_version)) - elif isinstance(x, (bool, int, float, complex, type(None))): + elif isinstance(x, int | float | complex | NoneType): continue else: # TODO: Support Python scalars? @@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" elif is_dask_array(x): # Peek at the metadata of the Dask array to determine type - if is_numpy_array(x._meta): # pyright: ignore + if is_numpy_array(x._meta): # Must be on CPU since backed by numpy return "cpu" return _DASK_DEVICE @@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device: return "cpu" # Return the device of the constituent array return device(inner) # pyright: ignore - return x.device # pyright: ignore + return x.device # type: ignore # pyright: ignore # Prevent shadowing, used below @@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device: # Based on cupy.array_api.Array.to_device def _cupy_to_device( - x: _CupyArray, + x: cp.ndarray, device: Device, /, stream: int | Any | None = None, -) -> _CupyArray: - import cupy as cp # pyright: ignore[reportMissingTypeStubs] +) -> cp.ndarray: + import cupy as cp from cupy.cuda import Device as _Device # pyright: ignore from cupy.cuda import stream as stream_module # pyright: ignore from cupy_backends.cuda.api import runtime # pyright: ignore @@ -797,10 +792,10 @@ def _cupy_to_device( raise ValueError(f"Unsupported device {device!r}") else: # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] + prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] prev_stream = None if stream is not None: - prev_stream: Any = stream_module.get_current_stream() # pyright: ignore + prev_stream = stream_module.get_current_stream() # pyright: ignore # stream can be an int as specified in __dlpack__, or a CuPy stream if isinstance(stream, int): stream = cp.cuda.ExternalStream(stream) # pyright: ignore @@ -814,7 +809,7 @@ def _cupy_to_device( arr = x.copy() finally: runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] - if stream is not None: + if prev_stream is not None: prev_stream.use() return arr @@ -823,7 +818,7 @@ def _torch_to_device( x: torch.Tensor, device: torch.device | str | int, /, - stream: None = None, + stream: int | Any | None = None, ) -> torch.Tensor: if stream is not None: raise NotImplementedError @@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - # cupy does not yet have to_device return _cupy_to_device(x, device, stream=stream) elif is_torch_array(x): - return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] + return _torch_to_device(x, device, stream=stream) elif is_dask_array(x): if stream is not None: raise ValueError("The stream argument to to_device() is not supported") @@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - @overload -def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... +def size(x: HasShape[int]) -> int: ... @overload -def size(x: HasShape[Collection[None]]) -> None: ... +def size(x: HasShape[int | None]) -> int | None: ... @overload -def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... -def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: +def size(x: HasShape[float]) -> int | None: ... # Dask special case +def size(x: HasShape[float | None]) -> int | None: """ Return the total number of elements of x. @@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: # Lazy API compliant arrays, such as ndonnx, can contain None in their shape if None in x.shape: return None - out = math.prod(cast("Collection[SupportsIndex]", x.shape)) + out = math.prod(cast(tuple[float, ...], x.shape)) # dask.array.Array.shape can contain NaN - return None if math.isnan(out) else out + return None if math.isnan(out) else cast(int, out) -def is_writeable_array(x: object) -> bool: +def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. Return False if `x` is not an array API compatible object. @@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool: return is_array_api_obj(x) -def is_lazy_array(x: object) -> bool: +def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: """Return True if x is potentially a future or it may be otherwise impossible or expensive to eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``. @@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool: # on __bool__ (dask is one such example, which however is special-cased above). # Select a single point of the array - s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) + s = size(cast(HasShape, x)) if s is None: return True xp = array_namespace(x) @@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool: _all_ignore = ["sys", "math", "inspect", "warnings"] + def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7e002aed..f483af41 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -8,7 +8,7 @@ if np.__version__[0] == "2": from numpy.lib.array_utils import normalize_axis_tuple else: - from numpy.core.numeric import normalize_axis_tuple + from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef] from .._internal import get_xp from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot diff --git a/array_api_compat/common/_typing.py b/array_api_compat/common/_typing.py index d7deade1..c94f73fc 100644 --- a/array_api_compat/common/_typing.py +++ b/array_api_compat/common/_typing.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Mapping from types import ModuleType as Namespace from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, TypedDict, TypeVar @@ -26,13 +25,13 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ... def __len__(self, /) -> int: ... -class SupportsArrayNamespace(Protocol[_T_co]): - def __array_namespace__(self, /, *, api_version: str | None) -> _T_co: ... +class SupportsArrayNamespace(Protocol): + def __array_namespace__(self, /, *, api_version: str | None) -> Namespace: ... class HasShape(Protocol[_T_co]): @property - def shape(self, /) -> _T_co: ... + def shape(self, /) -> tuple[_T_co, ...]: ... # Return type of `__array_namespace_info__.default_dtypes` @@ -70,72 +69,11 @@ def shape(self, /) -> _T_co: ... DTypeKind: TypeAlias = _DTypeKind | tuple[_DTypeKind, ...] -# `__array_namespace_info__.dtypes(kind="bool")` -class DTypesBool(TypedDict): - bool: DType - - -# `__array_namespace_info__.dtypes(kind="signed integer")` -class DTypesSigned(TypedDict): - int8: DType - int16: DType - int32: DType - int64: DType - - -# `__array_namespace_info__.dtypes(kind="unsigned integer")` -class DTypesUnsigned(TypedDict): - uint8: DType - uint16: DType - uint32: DType - uint64: DType - - -# `__array_namespace_info__.dtypes(kind="integral")` -class DTypesIntegral(DTypesSigned, DTypesUnsigned): - pass - - -# `__array_namespace_info__.dtypes(kind="real floating")` -class DTypesReal(TypedDict): - float32: DType - float64: DType - - -# `__array_namespace_info__.dtypes(kind="complex floating")` -class DTypesComplex(TypedDict): - complex64: DType - complex128: DType - - -# `__array_namespace_info__.dtypes(kind="numeric")` -class DTypesNumeric(DTypesIntegral, DTypesReal, DTypesComplex): - pass - - -# `__array_namespace_info__.dtypes(kind=None)` (default) -class DTypesAll(DTypesBool, DTypesNumeric): - pass - - -# `__array_namespace_info__.dtypes(kind=?)` (fallback) -DTypesAny: TypeAlias = Mapping[str, DType] - - __all__ = [ "Array", "Capabilities", "DType", "DTypeKind", - "DTypesAny", - "DTypesAll", - "DTypesBool", - "DTypesNumeric", - "DTypesIntegral", - "DTypesSigned", - "DTypesUnsigned", - "DTypesReal", - "DTypesComplex", "DefaultDTypes", "Device", "HasShape", diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..da4be14b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import Optional +from builtins import bool as py_bool import cupy as cp - from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp @@ -69,18 +68,13 @@ # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - copy: Optional[bool] = _copy_default, - **kwargs, + dtype: DType | None = None, + device: Device | None = None, + copy: py_bool | None = _copy_default, + **kwargs: object, ) -> Array: """ Array API compatibility wrapper for asarray(). @@ -115,8 +109,8 @@ def astype( dtype: DType, /, *, - copy: bool = True, - device: Optional[Device] = None, + copy: py_bool = True, + device: Device | None = None, ) -> Array: if device is None: return x.astype(dtype=dtype, copy=copy) @@ -127,8 +121,8 @@ def astype( # cupy.count_nonzero does not have keepdims def count_nonzero( x: Array, - axis=None, - keepdims=False + axis: int | tuple[int, ...] | None = None, + keepdims: py_bool = False, ) -> Array: result = cp.count_nonzero(x, axis) if keepdims: @@ -161,4 +155,4 @@ def count_nonzero( 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] -_all_ignore = ['cp', 'get_xp'] +_all_ignore = ['cp', 'get_xp', 'py_bool'] diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..2bd11940 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -1,10 +1,11 @@ -from cupy.fft import * # noqa: F403 +from cupy.fft import * # noqa: F403 + # cupy.fft doesn't have __all__. If it is added, replace this with # # from cupy.fft import __all__ as linalg_all -_n = {} -exec('from cupy.fft import *', _n) -del _n['__builtins__'] +_n: dict[str, object] = {} +exec("from cupy.fft import *", _n) +del _n["__builtins__"] fft_all = list(_n) del _n diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..7bc3536e 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -2,7 +2,7 @@ # cupy.linalg doesn't have __all__. If it is added, replace this with # # from cupy.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from cupy.linalg import *', _n) del _n['__builtins__'] linalg_all = list(_n) diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1e47b960..6d2ea7cd 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -3,7 +3,7 @@ from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment] # noqa: F403 __array_api_version__: Final = "2024.12" diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9687a9cd..86870e9b 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -146,7 +146,7 @@ def arange( # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( - obj: complex | NestedSequence[complex] | Array | SupportsBufferProtocol, + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, diff --git a/array_api_compat/dask/array/_info.py b/array_api_compat/dask/array/_info.py index 9e4d736f..5e3e9018 100644 --- a/array_api_compat/dask/array/_info.py +++ b/array_api_compat/dask/array/_info.py @@ -12,9 +12,9 @@ from __future__ import annotations -from typing import Literal as L -from typing import TypeAlias, overload +from typing import Literal, TypeAlias +import dask.array as da from numpy import bool_ as bool from numpy import ( complex64, @@ -33,24 +33,10 @@ uint64, ) -from ...common._helpers import _DASK_DEVICE, _dask_device -from ...common._typing import ( - Capabilities, - DefaultDTypes, - DType, - DTypeKind, - DTypesAll, - DTypesAny, - DTypesBool, - DTypesComplex, - DTypesIntegral, - DTypesNumeric, - DTypesReal, - DTypesSigned, - DTypesUnsigned, -) +from ...common._helpers import _DASK_DEVICE, _check_device, _dask_device +from ...common._typing import Capabilities, DefaultDTypes, DType, DTypeKind -_Device: TypeAlias = L["cpu"] | _dask_device +Device: TypeAlias = Literal["cpu"] | _dask_device class __array_namespace_info__: @@ -142,7 +128,7 @@ def capabilities(self) -> Capabilities: "max dimensions": 64, } - def default_device(self) -> L["cpu"]: + def default_device(self) -> Device: """ The default device used for new Dask arrays. @@ -169,7 +155,7 @@ def default_device(self) -> L["cpu"]: """ return "cpu" - def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: + def default_dtypes(self, /, *, device: Device | None = None) -> DefaultDTypes: """ The default data types used for new Dask arrays. @@ -208,11 +194,7 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: 'indexing': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - f'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, ' - f"but received: {device!r}" - ) + _check_device(da, device) return { "real floating": dtype(float64), "complex floating": dtype(complex128), @@ -220,41 +202,9 @@ def default_dtypes(self, /, *, device: _Device | None = None) -> DefaultDTypes: "indexing": dtype(intp), } - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: None = None - ) -> DTypesAll: ... - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: L["bool"] - ) -> DTypesBool: ... - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: L["signed integer"] - ) -> DTypesSigned: ... - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: L["unsigned integer"] - ) -> DTypesUnsigned: ... - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: L["integral"] - ) -> DTypesIntegral: ... - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: L["real floating"] - ) -> DTypesReal: ... - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: L["complex floating"] - ) -> DTypesComplex: ... - @overload - def dtypes( - self, /, *, device: _Device | None = None, kind: L["numeric"] - ) -> DTypesNumeric: ... def dtypes( - self, /, *, device: _Device | None = None, kind: DTypeKind | None = None - ) -> DTypesAny: + self, /, *, device: Device | None = None, kind: DTypeKind | None = None + ) -> dict[str, DType]: """ The array API data types supported by Dask. @@ -308,11 +258,7 @@ def dtypes( 'int64': dask.int64} """ - if device not in ["cpu", _DASK_DEVICE, None]: - raise ValueError( - 'Device not understood. Only "cpu" or _DASK_DEVICE is allowed, but received:' - f" {device}" - ) + _check_device(da, device) if kind is None: return { "bool": dtype(bool), @@ -381,14 +327,14 @@ def dtypes( "complex64": dtype(complex64), "complex128": dtype(complex128), } - if isinstance(kind, tuple): # type: ignore[reportUnnecessaryIsinstanceCall] + if isinstance(kind, tuple): res: dict[str, DType] = {} for k in kind: res.update(self.dtypes(kind=k)) return res raise ValueError(f"unsupported kind: {kind!r}") - def devices(self) -> list[_Device]: + def devices(self) -> list[Device]: """ The devices supported by Dask. diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 3f40dffe..68c4280e 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -2,7 +2,7 @@ # dask.array.fft doesn't have __all__. If it is added, replace this with # # from dask.array.fft import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.fft import *', _n) for k in ("__builtins__", "Sequence", "annotations", "warnings"): _n.pop(k, None) diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 0825386e..06f596bc 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -4,21 +4,22 @@ import dask.array as da -# The `matmul` and `tensordot` functions are in both the main and linalg namespaces -from dask.array import matmul, outer, tensordot - # Exports from dask.array.linalg import * # noqa: F403 +from dask.array import outer +# The `matmul` and `tensordot` functions are in both the main and linalg namespaces +from dask.array import matmul, tensordot + from ..._internal import get_xp from ...common import _linalg -from ...common._typing import Array as _Array +from ...common._typing import Array from ._aliases import matrix_transpose, vecdot # dask.array.linalg doesn't have __all__. If it is added, replace this with # # from dask.array.linalg import __all__ as linalg_all -_n = {} +_n: dict[str, object] = {} exec('from dask.array.linalg import *', _n) for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): _n.pop(k, None) @@ -33,8 +34,8 @@ # supports the mode keyword on QR # https://github.com/dask/dask/issues/10388 #qr = get_xp(da)(_linalg.qr) -def qr( - x: _Array, +def qr( # type: ignore[no-redef] + x: Array, mode: Literal["reduced", "complete"] = "reduced", **kwargs: object, ) -> QRResult: @@ -50,12 +51,12 @@ def qr( # Wrap the svd functions to not pass full_matrices to dask # when full_matrices=False (as that is the default behavior for dask), # and dask doesn't have the full_matrices keyword -def svd(x: _Array, full_matrices: bool = True, **kwargs) -> SVDResult: +def svd(x: Array, full_matrices: bool = True, **kwargs: object) -> SVDResult: # type: ignore[no-redef] if full_matrices: raise ValueError("full_matrics=True is not supported by dask.") return da.linalg.svd(x, coerce_signs=False, **kwargs) -def svdvals(x: _Array) -> _Array: +def svdvals(x: Array) -> Array: # TODO: can't avoid computing U or V for dask _, s, _ = svd(x) return s diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index f7b558ba..ae9406a6 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -10,7 +10,7 @@ from numpy import round as round # These imports may overwrite names from the import * above. -from ._aliases import * # noqa: F403 +from ._aliases import * # type: ignore[assignment,no-redef] # noqa: F403 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -23,13 +23,6 @@ __import__(__package__ + ".fft") -from ..common._helpers import * # noqa: F403 -from .linalg import matrix_transpose, vecdot # noqa: F401 - -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass +from .linalg import matrix_transpose, vecdot # type: ignore[no-redef] # noqa: F401 __array_api_version__: Final = "2024.12" diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d8792611..0c75d47d 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -2,7 +2,7 @@ from __future__ import annotations from builtins import bool as py_bool -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast +from typing import Any, cast import numpy as np @@ -12,13 +12,6 @@ from ._info import __array_namespace_info__ from ._typing import Array, Device, DType -if TYPE_CHECKING: - from typing_extensions import Buffer, TypeIs - -# The values of the `_CopyMode` enum can be either `False`, `True`, or `2`: -# https://github.com/numpy/numpy/blob/5a8a6a79d9c2fff8f07dcab5d41e14f8508d673f/numpy/_globals.pyi#L7-L10 -_Copy: TypeAlias = py_bool | Literal[2] | np._CopyMode - bool = np.bool_ # Basic renames @@ -74,14 +67,6 @@ iinfo = get_xp(np)(_aliases.iinfo) -def _supports_buffer_protocol(obj: object) -> TypeIs[Buffer]: # pyright: ignore[reportUnusedFunction] - try: - memoryview(obj) # pyright: ignore[reportArgumentType] - except TypeError: - return False - return True - - # asarray also adds the copy keyword, which is not present in numpy 1.0. # asarray() is different enough between numpy, cupy, and dask, the logic # complicated enough that it's easier to define it separately for each module @@ -92,7 +77,7 @@ def asarray( *, dtype: DType | None = None, device: Device | None = None, - copy: _Copy | None = None, + copy: py_bool | None = None, **kwargs: Any, ) -> Array: """ @@ -103,14 +88,14 @@ def asarray( """ _helpers._check_device(np, device) + # None is unsupported in NumPy 1.0, but we can use an internal enum + # False in NumPy 1.0 means None in NumPy 2.0 and in the Array API if copy is None: - copy = np._CopyMode.IF_NEEDED + copy = np._CopyMode.IF_NEEDED # type: ignore[assignment,attr-defined] elif copy is False: - copy = np._CopyMode.NEVER - elif copy is True: - copy = np._CopyMode.ALWAYS + copy = np._CopyMode.NEVER # type: ignore[assignment,attr-defined] - return np.array(obj, copy=copy, dtype=dtype, **kwargs) # pyright: ignore + return np.array(obj, copy=copy, dtype=dtype, **kwargs) def astype( @@ -134,7 +119,7 @@ def count_nonzero( ) -> Array: # NOTE: this is currently incorrectly typed in numpy, but will be fixed in # numpy 2.2.5 and 2.3.0: https://github.com/numpy/numpy/pull/28750 - result = cast("Any", np.count_nonzero(x, axis=axis, keepdims=keepdims)) # pyright: ignore[reportArgumentType, reportCallIssue] + result = cast(Any, np.count_nonzero(x, axis=axis, keepdims=keepdims)) # type: ignore[arg-type] # pyright: ignore[reportArgumentType, reportCallIssue] if axis is None and not keepdims: return np.asarray(result) return result @@ -145,7 +130,7 @@ def count_nonzero( if hasattr(np, "vecdot"): vecdot = np.vecdot else: - vecdot = get_xp(np)(_aliases.vecdot) + vecdot = get_xp(np)(_aliases.vecdot) # type: ignore[assignment] if hasattr(np, "isdtype"): isdtype = np.isdtype diff --git a/array_api_compat/numpy/_info.py b/array_api_compat/numpy/_info.py index f307f62c..11126e5d 100644 --- a/array_api_compat/numpy/_info.py +++ b/array_api_compat/numpy/_info.py @@ -7,6 +7,7 @@ more details. """ + from __future__ import annotations from numpy import bool_ as bool @@ -27,6 +28,7 @@ uint64, ) +from ..common._typing import DefaultDTypes from ._typing import Device, DType @@ -62,7 +64,7 @@ class __array_namespace_info__: """ - __module__ = 'numpy' + __module__ = "numpy" def capabilities(self): """ @@ -139,7 +141,7 @@ def default_dtypes( self, *, device: Device | None = None, - ) -> dict[str, dtype[intp | float64 | complex128]]: + ) -> DefaultDTypes: """ The default data types used for new NumPy arrays. @@ -181,8 +183,7 @@ def default_dtypes( """ if device not in ["cpu", None]: raise ValueError( - 'Device not understood. Only "cpu" is allowed, but received:' - f' {device}' + f'Device not understood. Only "cpu" is allowed, but received: {device}' ) return { "real floating": dtype(float64), @@ -253,8 +254,7 @@ def dtypes( """ if device not in ["cpu", None]: raise ValueError( - 'Device not understood. Only "cpu" is allowed, but received:' - f' {device}' + f'Device not understood. Only "cpu" is allowed, but received: {device}' ) if kind is None: return { diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index e771c788..617cfb71 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -7,7 +7,6 @@ Device: TypeAlias = Literal["cpu"] if TYPE_CHECKING: - # NumPy 1.x on Python 3.10 fails to parse np.dtype[] DType: TypeAlias = np.dtype[ np.bool_ diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 2d3e731d..9a618be9 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -65,7 +65,7 @@ # https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43). def solve(x1: Array, x2: Array, /) -> Array: try: - from numpy.linalg._linalg import ( + from numpy.linalg._linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, @@ -74,7 +74,7 @@ def solve(x1: Array, x2: Array, /) -> Array: isComplexType, ) except ImportError: - from numpy.linalg.linalg import ( + from numpy.linalg.linalg import ( # type: ignore[attr-defined] _assert_stacked_2d, _assert_stacked_square, _commonType, diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 027a0261..5a7d1870 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,8 +1,9 @@ from __future__ import annotations +from collections.abc import Sequence from functools import reduce as _reduce, wraps as _wraps from builtins import all as _builtin_all, any as _builtin_any -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any import torch @@ -96,9 +97,7 @@ def _fix_promotion(x1, x2, only_scalar=True): _py_scalars = (bool, int, float, complex) -def result_type( - *arrays_and_dtypes: Array | DType | bool | int | float | complex -) -> DType: +def result_type(*arrays_and_dtypes: Array | DType | complex) -> DType: num = len(arrays_and_dtypes) if num == 0: @@ -129,10 +128,7 @@ def result_type( return _reduce(_result_type, others + scalars) -def _result_type( - x: Array | DType | bool | int | float | complex, - y: Array | DType | bool | int | float | complex, -) -> DType: +def _result_type(x: Array | DType | complex, y: Array | DType | complex) -> DType: if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)): xdt = x if isinstance(x, torch.dtype) else x.dtype ydt = y if isinstance(y, torch.dtype) else y.dtype @@ -150,7 +146,7 @@ def _result_type( return torch.result_type(x, y) -def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: +def can_cast(from_: DType | Array, to: DType, /) -> bool: if not isinstance(from_, torch.dtype): from_ = from_.dtype return torch.can_cast(from_, to) @@ -194,12 +190,7 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool: def asarray( - obj: ( - Array - | bool | int | float | complex - | NestedSequence[bool | int | float | complex] - | SupportsBufferProtocol - ), + obj: Array | complex | NestedSequence[complex] | SupportsBufferProtocol, /, *, dtype: DType | None = None, @@ -218,13 +209,13 @@ def asarray( # of 'axis'. # torch.min and torch.max return a tuple and don't support multiple axes https://github.com/pytorch/pytorch/issues/58745 -def max(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def max(x: Array, /, *, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) return torch.amax(x, axis, keepdims=keepdims) -def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims: bool = False) -> Array: +def min(x: Array, /, *, axis: int | tuple[int, ...] |None = None, keepdims: bool = False) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -240,7 +231,15 @@ def min(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep # torch.sort also returns a tuple # https://github.com/pytorch/pytorch/issues/70921 -def sort(x: Array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> Array: +def sort( + x: Array, + /, + *, + axis: int = -1, + descending: bool = False, + stable: bool = True, + **kwargs: object, +) -> Array: return torch.sort(x, dim=axis, descending=descending, stable=stable, **kwargs).values def _normalize_axes(axis, ndim): @@ -307,10 +306,10 @@ def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array: def prod(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -331,10 +330,10 @@ def prod(x: Array, def sum(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype: Optional[DType] = None, + axis: int | tuple[int, ...] | None = None, + dtype: DType | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return _sum_prod_no_axis(x, dtype) @@ -350,9 +349,9 @@ def sum(x: Array, def any(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -374,9 +373,9 @@ def any(x: Array, def all(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: if axis == (): return x.to(torch.bool) @@ -398,9 +397,9 @@ def all(x: Array, def mean(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # https://github.com/pytorch/pytorch/issues/29137 if axis == (): return torch.clone(x) @@ -415,10 +414,10 @@ def mean(x: Array, def std(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -446,10 +445,10 @@ def std(x: Array, def var(x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - correction: Union[int, float] = 0.0, + axis: int | tuple[int, ...] | None = None, + correction: float = 0.0, keepdims: bool = False, - **kwargs) -> Array: + **kwargs: object) -> Array: # Note, float correction is not supported # https://github.com/pytorch/pytorch/issues/61492. We don't try to # implement it here for now. @@ -472,11 +471,11 @@ def var(x: Array, # torch.concat doesn't support dim=None # https://github.com/pytorch/pytorch/issues/70925 -def concat(arrays: Union[Tuple[Array, ...], List[Array]], +def concat(arrays: tuple[Array, ...] | list[Array], /, *, - axis: Optional[int] = 0, - **kwargs) -> Array: + axis: int | None = 0, + **kwargs: object) -> Array: if axis is None: arrays = tuple(ar.flatten() for ar in arrays) axis = 0 @@ -485,7 +484,7 @@ def concat(arrays: Union[Tuple[Array, ...], List[Array]], # torch.squeeze only accepts int dim and doesn't require it # https://github.com/pytorch/pytorch/issues/70924. Support for tuple dim was # added at https://github.com/pytorch/pytorch/pull/89017. -def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: +def squeeze(x: Array, /, axis: int | tuple[int, ...]) -> Array: if isinstance(axis, int): axis = (axis,) for a in axis: @@ -499,27 +498,27 @@ def squeeze(x: Array, /, axis: Union[int, Tuple[int, ...]]) -> Array: return x # torch.broadcast_to uses size instead of shape -def broadcast_to(x: Array, /, shape: Tuple[int, ...], **kwargs) -> Array: +def broadcast_to(x: Array, /, shape: tuple[int, ...], **kwargs: object) -> Array: return torch.broadcast_to(x, shape, **kwargs) # torch.permute uses dims instead of axes -def permute_dims(x: Array, /, axes: Tuple[int, ...]) -> Array: +def permute_dims(x: Array, /, axes: tuple[int, ...]) -> Array: return torch.permute(x, axes) # The axis parameter doesn't work for flip() and roll() # https://github.com/pytorch/pytorch/issues/71210. Also torch.flip() doesn't # accept axis=None -def flip(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def flip(x: Array, /, *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: if axis is None: axis = tuple(range(x.ndim)) # torch.flip doesn't accept dim as an int but the method does # https://github.com/pytorch/pytorch/issues/18095 return x.flip(axis, **kwargs) -def roll(x: Array, /, shift: Union[int, Tuple[int, ...]], *, axis: Optional[Union[int, Tuple[int, ...]]] = None, **kwargs) -> Array: +def roll(x: Array, /, shift: int | tuple[int, ...], *, axis: int | tuple[int, ...] | None = None, **kwargs: object) -> Array: return torch.roll(x, shift, axis, **kwargs) -def nonzero(x: Array, /, **kwargs) -> Tuple[Array, ...]: +def nonzero(x: Array, /, **kwargs: object) -> tuple[Array, ...]: if x.ndim == 0: raise ValueError("nonzero() does not support zero-dimensional arrays") return torch.nonzero(x, as_tuple=True, **kwargs) @@ -532,8 +531,8 @@ def diff( *, axis: int = -1, n: int = 1, - prepend: Optional[Array] = None, - append: Optional[Array] = None, + prepend: Array | None = None, + append: Array | None = None, ) -> Array: return torch.diff(x, dim=axis, n=n, prepend=prepend, append=append) @@ -543,7 +542,7 @@ def count_nonzero( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, ) -> Array: result = torch.count_nonzero(x, dim=axis) @@ -560,12 +559,7 @@ def repeat(x: Array, repeats: int | Array, /, *, axis: int | None = None) -> Arr return torch.repeat_interleave(x, repeats, axis) -def where( - condition: Array, - x1: Array | bool | int | float | complex, - x2: Array | bool | int | float | complex, - /, -) -> Array: +def where(condition: Array, x1: Array | complex, x2: Array | complex, /) -> Array: x1, x2 = _fix_promotion(x1, x2) return torch.where(condition, x1, x2) @@ -573,10 +567,10 @@ def where( # torch.reshape doesn't have the copy keyword def reshape(x: Array, /, - shape: Tuple[int, ...], + shape: tuple[int, ...], *, - copy: Optional[bool] = None, - **kwargs) -> Array: + copy: bool | None = None, + **kwargs: object) -> Array: if copy is not None: raise NotImplementedError("torch.reshape doesn't yet support the copy keyword") return torch.reshape(x, shape, **kwargs) @@ -585,14 +579,14 @@ def reshape(x: Array, # (https://github.com/pytorch/pytorch/issues/70915), and doesn't support some # keyword argument combinations # (https://github.com/pytorch/pytorch/issues/70914) -def arange(start: Union[int, float], +def arange(start: float, /, - stop: Optional[Union[int, float]] = None, - step: Union[int, float] = 1, + stop: float | None = None, + step: float = 1, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if stop is None: start, stop = 0, start if step > 0 and stop <= start or step < 0 and stop >= start: @@ -607,13 +601,13 @@ def arange(start: Union[int, float], # torch.eye does not accept None as a default for the second argument and # doesn't support off-diagonals (https://github.com/pytorch/pytorch/issues/70910) def eye(n_rows: int, - n_cols: Optional[int] = None, + n_cols: int | None = None, /, *, k: int = 0, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if n_cols is None: n_cols = n_rows z = torch.zeros(n_rows, n_cols, dtype=dtype, device=device, **kwargs) @@ -622,52 +616,52 @@ def eye(n_rows: int, return z # torch.linspace doesn't have the endpoint parameter -def linspace(start: Union[int, float], - stop: Union[int, float], +def linspace(start: float, + stop: float, /, num: int, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, + dtype: DType | None = None, + device: Device | None = None, endpoint: bool = True, - **kwargs) -> Array: + **kwargs: object) -> Array: if not endpoint: return torch.linspace(start, stop, num+1, dtype=dtype, device=device, **kwargs)[:-1] return torch.linspace(start, stop, num, dtype=dtype, device=device, **kwargs) # torch.full does not accept an int size # https://github.com/pytorch/pytorch/issues/70906 -def full(shape: Union[int, Tuple[int, ...]], - fill_value: bool | int | float | complex, +def full(shape: int | tuple[int, ...], + fill_value: complex, *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: if isinstance(shape, int): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype, device=device, **kwargs) # ones, zeros, and empty do not accept shape as a keyword argument -def ones(shape: Union[int, Tuple[int, ...]], +def ones(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.ones(shape, dtype=dtype, device=device, **kwargs) -def zeros(shape: Union[int, Tuple[int, ...]], +def zeros(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.zeros(shape, dtype=dtype, device=device, **kwargs) -def empty(shape: Union[int, Tuple[int, ...]], +def empty(shape: int | tuple[int, ...], *, - dtype: Optional[DType] = None, - device: Optional[Device] = None, - **kwargs) -> Array: + dtype: DType | None = None, + device: Device | None = None, + **kwargs: object) -> Array: return torch.empty(shape, dtype=dtype, device=device, **kwargs) # tril and triu do not call the keyword argument k @@ -689,14 +683,14 @@ def astype( /, *, copy: bool = True, - device: Optional[Device] = None, + device: Device | None = None, ) -> Array: if device is not None: return x.to(device, dtype=dtype, copy=copy) return x.to(dtype=dtype, copy=copy) -def broadcast_arrays(*arrays: Array) -> List[Array]: +def broadcast_arrays(*arrays: Array) -> list[Array]: shape = torch.broadcast_shapes(*[a.shape for a in arrays]) return [torch.broadcast_to(a, shape) for a in arrays] @@ -734,7 +728,7 @@ def unique_inverse(x: Array) -> UniqueInverseResult: def unique_values(x: Array) -> Array: return torch.unique(x) -def matmul(x1: Array, x2: Array, /, **kwargs) -> Array: +def matmul(x1: Array, x2: Array, /, **kwargs: object) -> Array: # torch.matmul doesn't type promote (but differently from _fix_promotion) x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.matmul(x1, x2, **kwargs) @@ -752,8 +746,8 @@ def tensordot( x2: Array, /, *, - axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2, - **kwargs, + axes: int | tuple[Sequence[int], Sequence[int]] = 2, + **kwargs: object, ) -> Array: # Note: torch.tensordot fails with integer dtypes when there is only 1 # element in the axis (https://github.com/pytorch/pytorch/issues/84530). @@ -762,8 +756,10 @@ def tensordot( def isdtype( - dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]], - *, _tuple=True, # Disallow nested tuples + dtype: DType, + kind: DType | str | tuple[DType | str, ...], + *, + _tuple: bool = True, # Disallow nested tuples ) -> bool: """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. @@ -797,7 +793,7 @@ def isdtype( else: return dtype == kind -def take(x: Array, indices: Array, /, *, axis: Optional[int] = None, **kwargs) -> Array: +def take(x: Array, indices: Array, /, *, axis: int | None = None, **kwargs: object) -> Array: if axis is None: if x.ndim != 1: raise ValueError("axis must be specified when ndim > 1") diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 50e6a0d0..ddf87c65 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Union, Sequence, Literal +from collections.abc import Sequence +from typing import Literal import torch import torch.fft @@ -17,7 +18,7 @@ def fftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.fftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -28,7 +29,7 @@ def ifftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -39,7 +40,7 @@ def rfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.rfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -50,7 +51,7 @@ def irfftn( s: Sequence[int] = None, axes: Sequence[int] = None, norm: Literal["backward", "ortho", "forward"] = "backward", - **kwargs, + **kwargs: object, ) -> Array: return torch.fft.irfftn(x, s=s, dim=axes, norm=norm, **kwargs) @@ -58,8 +59,8 @@ def fftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.fftshift(x, dim=axes, **kwargs) @@ -67,8 +68,8 @@ def ifftshift( x: Array, /, *, - axes: Union[int, Sequence[int]] = None, - **kwargs, + axes: int | Sequence[int] = None, + **kwargs: object, ) -> Array: return torch.fft.ifftshift(x, dim=axes, **kwargs) diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 1ff7319d..490b7bd1 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,8 +1,6 @@ from __future__ import annotations import torch -from typing import Optional, Union, Tuple - from torch.linalg import * # noqa: F403 # torch.linalg doesn't define __all__ @@ -31,7 +29,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: x1, x2 = torch.broadcast_tensors(x1, x2) return torch_linalg.cross(x1, x2, dim=axis) -def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: +def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs: object) -> Array: from ._aliases import isdtype x1, x2 = _fix_promotion(x1, x2, only_scalar=False) @@ -53,7 +51,7 @@ def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: return res[..., 0, 0] return torch.linalg.vecdot(x1, x2, dim=axis, **kwargs) -def solve(x1: Array, x2: Array, /, **kwargs) -> Array: +def solve(x1: Array, x2: Array, /, **kwargs: object) -> Array: x1, x2 = _fix_promotion(x1, x2, only_scalar=False) # Torch tries to emulate NumPy 1 solve behavior by using batched 1-D solve # whenever @@ -74,7 +72,7 @@ def solve(x1: Array, x2: Array, /, **kwargs) -> Array: return torch.linalg.solve(x1, x2, **kwargs) # torch.trace doesn't support the offset argument and doesn't support stacking -def trace(x: Array, /, *, offset: int = 0, dtype: Optional[DType] = None) -> Array: +def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array: # Use our wrapped sum to make sure it does upcasting correctly return sum(torch.diagonal(x, offset=offset, dim1=-2, dim2=-1), axis=-1, dtype=dtype) @@ -82,11 +80,11 @@ def vector_norm( x: Array, /, *, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: int | tuple[int, ...] | None = None, keepdims: bool = False, # float stands for inf | -inf, which are not valid for Literal - ord: Union[int, float] = 2, - **kwargs, + ord: float = 2, + **kwargs: object, ) -> Array: # torch.vector_norm incorrectly treats axis=() the same as axis=None if axis == (): diff --git a/pyproject.toml b/pyproject.toml index aacebd11..86310358 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,11 +43,11 @@ dev = [ "array-api-strict", "dask[array]>=2024.9.0", "jax[cpu]", + "ndonnx", "numpy>=1.22", "pytest", "torch", "sparse>=0.15.1", - "ndonnx" ] [project.urls] @@ -61,7 +61,7 @@ version = {attr = "array_api_compat.__version__"} include = ["array_api_compat*"] namespaces = false -[toolint] +[tool.ruff.lint] preview = true select = [ # Defaults @@ -79,20 +79,44 @@ ignore = [ "E722" ] -[tool.ruff.lint] -preview = true -select = [ -# Defaults -"E4", "E7", "E9", "F", -# Undefined export -"F822", -# Useless import alias -"PLC0414" -] -ignore = [ - # Module import not at top of file - "E402", - # Do not use bare `except` - "E722" +[tool.mypy] +files = ["array_api_compat"] +python_version = "3.10" +disallow_incomplete_defs = true +disallow_untyped_decorators = true +disallow_untyped_defs = false # TODO +ignore_missing_imports = true +no_implicit_optional = true +show_error_codes = true +warn_redundant_casts = true +warn_unused_ignores = true +warn_unreachable = true + +[[tool.mypy.overrides]] +module = ["cupy.*", "dask.*", "jax.*", "ndonnx.*", "sparse.*", "torch.*"] +ignore_missing_imports = true + + +[tool.pyright] +include = ["src", "tests"] +pythonVersion = "3.10" +pythonPlatform = "All" + +reportAny = false +reportExplicitAny = false +# missing type stubs +reportAttributeAccessIssue = false +reportUnknownMemberType = false +reportUnknownVariableType = false +# Redundant with mypy checks +reportMissingImports = false +reportMissingTypeStubs = false +# false positives for input validation +reportUnreachable = false +# ruff handles this +reportUnusedParameter = false + +executionEnvironments = [ + { root = "array_api_compat" }, ]