Skip to content

TYP: Type annotations, part 4 #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
specification for more details.

"""
wrapped_f.__signature__ = new_sig # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # pyright: ignore[reportReturnType]
wrapped_f.__signature__ = new_sig # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
return wrapped_f # type: ignore[return-value] # pyright: ignore[reportReturnType]

return inner

Expand Down
16 changes: 9 additions & 7 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast
from collections.abc import Sequence
from types import NoneType
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from ._helpers import _check_device, array_namespace
from ._helpers import device as _get_device
from ._helpers import is_cupy_namespace as _is_cupy_namespace
from ._helpers import is_cupy_namespace
from ._typing import Array, Device, DType, Namespace

if TYPE_CHECKING:
Expand Down Expand Up @@ -381,8 +383,8 @@ def clip(
# TODO: np.clip has other ufunc kwargs
out: Array | None = None,
) -> Array:
def _isscalar(a: object) -> TypeIs[int | float | None]:
return isinstance(a, (int, float, type(None)))
def _isscalar(a: object) -> TypeIs[float | None]:
return isinstance(a, int | float | NoneType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With

Suggested change
return isinstance(a, int | float | NoneType)
return a is None or isinstance(a, int | float)

we avoid the types import while simultaneously accentuating the violent dissonance between the Python runtime and its type-system, given that the sole purpose of a type-system is to accurately describe the runtime behavior...


min_shape = () if _isscalar(min) else min.shape
max_shape = () if _isscalar(max) else max.shape
Expand Down Expand Up @@ -450,7 +452,7 @@ def reshape(
shape: tuple[int, ...],
xp: Namespace,
*,
copy: Optional[bool] = None,
copy: bool | None = None,
**kwargs: object,
) -> Array:
if copy is True:
Expand Down Expand Up @@ -657,7 +659,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array:
out = xp.sign(x, **kwargs)
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp):
out[xp.isnan(x)] = xp.nan
return out[()]

Expand Down Expand Up @@ -720,7 +722,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
"finfo",
"iinfo",
]
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]
_all_ignore = ["is_cupy_namespace", "inspect", "array_namespace", "NamedTuple"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have the __dir__ functions, are these _all_ignore lists still needed?



def __dir__() -> list[str]:
Expand Down
76 changes: 36 additions & 40 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,56 +12,51 @@
import math
import sys
import warnings
from collections.abc import Collection
from types import NoneType
from typing import (
TYPE_CHECKING,
Any,
Final,
Literal,
SupportsIndex,
TypeAlias,
TypeGuard,
TypeVar,
cast,
overload,
)

from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace

if TYPE_CHECKING:

import cupy as cp
import dask.array as da
import jax
import ndonnx as ndx
import numpy as np
import numpy.typing as npt
import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse
import torch

# TODO: import from typing (requires Python >=3.13)
from typing_extensions import TypeIs, TypeVar

_SizeT = TypeVar("_SizeT", bound = int | None)
from typing_extensions import TypeIs

_ZeroGradientArray: TypeAlias = npt.NDArray[np.void]
_CupyArray: TypeAlias = Any # cupy has no py.typed

_ArrayApiObj: TypeAlias = (
npt.NDArray[Any]
| cp.ndarray
| da.Array
| jax.Array
| ndx.Array
| sparse.SparseArray
| torch.Tensor
| SupportsArrayNamespace[Any]
| _CupyArray
| SupportsArrayNamespace
)

_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"})
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})


def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TypeGuard was intentional. Because even if x is a _ZeroGradientArray, and therefore an npt.NDArray[np.void], the function might still return False, in which case the TypeIs would narrow x to be not npt.NDArray[np.void], whereas a TypeGuard wouldn't.

So it's better to revert this change (and the one below here)

"""Return True if `x` is a zero-gradient array.

These arrays are a design quirk of Jax that may one day be removed.
Expand All @@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
)


def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]:
"""
Return True if `x` is a NumPy array.

Expand Down Expand Up @@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool:
if "cupy" not in sys.modules:
return False

import cupy as cp # pyright: ignore[reportMissingTypeStubs]
import cupy as cp

# TODO: Should we reject ndarray subclasses?
return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType]
Expand Down Expand Up @@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
if "sparse" not in sys.modules:
return False

import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse

# TODO: Account for other backends.
return isinstance(x, sparse.SparseArray)


def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the current definition of _ArrayApiObj, TypeIs would cause downstream failures for all unknown array api compliant libraries.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't SupportsArrayNamespace cover all downstream array types?

"""
Return True if `x` is an array API compatible array object.

Expand Down Expand Up @@ -587,7 +582,7 @@ def your_function(x, y):

namespaces.add(cupy_namespace)
else:
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
import cupy as cp

namespaces.add(cp)
elif is_torch_array(x):
Expand Down Expand Up @@ -624,14 +619,14 @@ def your_function(x, y):
if hasattr(jax.numpy, "__array_api_version__"):
jnp = jax.numpy
else:
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports]
import jax.experimental.array_api as jnp # type: ignore[no-redef]
namespaces.add(jnp)
elif is_pydata_sparse_array(x):
if use_compat is True:
_check_api_version(api_version)
raise ValueError("`sparse` does not have an array-api-compat wrapper")
else:
import sparse # pyright: ignore[reportMissingTypeStubs]
import sparse
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there's no need for the else: clause

# `sparse` is already an array namespace. We do not have a wrapper
# submodule for it.
namespaces.add(sparse)
Expand All @@ -640,9 +635,9 @@ def your_function(x, y):
raise ValueError(
"The given array does not have an array-api-compat wrapper"
)
x = cast("SupportsArrayNamespace[Any]", x)
x = cast(SupportsArrayNamespace, x)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

namespaces.add(x.__array_namespace__(api_version=api_version))
elif isinstance(x, (bool, int, float, complex, type(None))):
elif isinstance(x, int | float | complex | NoneType):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif isinstance(x, int | float | complex | NoneType):
elif x is None or isinstance(x, int | float | complex):

(I'll spare you the pseudo-philosophical rant this time)

continue
else:
# TODO: Support Python scalars?
Expand Down Expand Up @@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
elif is_dask_array(x):
# Peek at the metadata of the Dask array to determine type
if is_numpy_array(x._meta): # pyright: ignore
if is_numpy_array(x._meta):
# Must be on CPU since backed by numpy
return "cpu"
return _DASK_DEVICE
Expand Down Expand Up @@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device:
return "cpu"
# Return the device of the constituent array
return device(inner) # pyright: ignore
return x.device # pyright: ignore
return x.device # type: ignore # pyright: ignore


# Prevent shadowing, used below
Expand All @@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device:

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(
x: _CupyArray,
x: cp.ndarray,
device: Device,
/,
stream: int | Any | None = None,
) -> _CupyArray:
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
) -> cp.ndarray:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has this been fixed in cupy since last time or something?

import cupy as cp
from cupy.cuda import Device as _Device # pyright: ignore
from cupy.cuda import stream as stream_module # pyright: ignore
from cupy_backends.cuda.api import runtime # pyright: ignore
Expand All @@ -797,10 +792,10 @@ def _cupy_to_device(
raise ValueError(f"Unsupported device {device!r}")
else:
# see cupy/cupy#5985 for the reason how we handle device/stream here
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType]
prev_stream = None
if stream is not None:
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore
prev_stream = stream_module.get_current_stream() # pyright: ignore
# stream can be an int as specified in __dlpack__, or a CuPy stream
if isinstance(stream, int):
stream = cp.cuda.ExternalStream(stream) # pyright: ignore
Expand All @@ -814,7 +809,7 @@ def _cupy_to_device(
arr = x.copy()
finally:
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType]
if stream is not None:
if prev_stream is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casual bugfix 🤔 ?

prev_stream.use()
return arr

Expand All @@ -823,7 +818,7 @@ def _torch_to_device(
x: torch.Tensor,
device: torch.device | str | int,
/,
stream: None = None,
stream: int | Any | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we should create a stream type-alias (unless there already is one)

) -> torch.Tensor:
if stream is not None:
raise NotImplementedError
Expand Down Expand Up @@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -
# cupy does not yet have to_device
return _cupy_to_device(x, device, stream=stream)
elif is_torch_array(x):
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType]
return _torch_to_device(x, device, stream=stream)
elif is_dask_array(x):
if stream is not None:
raise ValueError("The stream argument to to_device() is not supported")
Expand All @@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) -


@overload
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ...
def size(x: HasShape[int]) -> int: ...
@overload
def size(x: HasShape[Collection[None]]) -> None: ...
def size(x: HasShape[int | None]) -> int | None: ...
@overload
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ...
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
def size(x: HasShape[float]) -> int | None: ... # Dask special case
def size(x: HasShape[float | None]) -> int | None:
Comment on lines +912 to +917
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why did you remove the overload that returns -> None?

"""
Return the total number of elements of x.

Expand All @@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
if None in x.shape:
return None
out = math.prod(cast("Collection[SupportsIndex]", x.shape))
out = math.prod(cast(tuple[float, ...], x.shape))
# dask.array.Array.shape can contain NaN
return None if math.isnan(out) else out
return None if math.isnan(out) else cast(int, out)


def is_writeable_array(x: object) -> bool:
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]:
"""
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
Return False if `x` is not an array API compatible object.
Expand All @@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool:
return is_array_api_obj(x)


def is_lazy_array(x: object) -> bool:
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anything with a shape: tuple that contains a None would return True here, so this isn't correct.

"""Return True if x is potentially a future or it may be otherwise impossible or
expensive to eagerly read its contents, regardless of their size, e.g. by
calling ``bool(x)`` or ``float(x)``.
Expand Down Expand Up @@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool:
# on __bool__ (dask is one such example, which however is special-cased above).

# Select a single point of the array
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x))
s = size(cast(HasShape, x))
if s is None:
return True
xp = array_namespace(x)
Expand Down Expand Up @@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool:

_all_ignore = ["sys", "math", "inspect", "warnings"]


def __dir__() -> list[str]:
return __all__
2 changes: 1 addition & 1 deletion array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if np.__version__[0] == "2":
from numpy.lib.array_utils import normalize_axis_tuple
else:
from numpy.core.numeric import normalize_axis_tuple
from numpy.core.numeric import normalize_axis_tuple # type: ignore[no-redef]

from .._internal import get_xp
from ._aliases import isdtype, matmul, matrix_transpose, tensordot, vecdot
Expand Down
Loading
Loading