-
Notifications
You must be signed in to change notification settings - Fork 33
TYP: Type annotations, part 4 #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,13 @@ | |
from __future__ import annotations | ||
|
||
import inspect | ||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast | ||
from collections.abc import Sequence | ||
from types import NoneType | ||
from typing import TYPE_CHECKING, Any, NamedTuple, cast | ||
|
||
from ._helpers import _check_device, array_namespace | ||
from ._helpers import device as _get_device | ||
from ._helpers import is_cupy_namespace as _is_cupy_namespace | ||
from ._helpers import is_cupy_namespace | ||
from ._typing import Array, Device, DType, Namespace | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -381,8 +383,8 @@ def clip( | |
# TODO: np.clip has other ufunc kwargs | ||
out: Array | None = None, | ||
) -> Array: | ||
def _isscalar(a: object) -> TypeIs[int | float | None]: | ||
return isinstance(a, (int, float, type(None))) | ||
def _isscalar(a: object) -> TypeIs[float | None]: | ||
return isinstance(a, int | float | NoneType) | ||
|
||
min_shape = () if _isscalar(min) else min.shape | ||
max_shape = () if _isscalar(max) else max.shape | ||
|
@@ -450,7 +452,7 @@ def reshape( | |
shape: tuple[int, ...], | ||
xp: Namespace, | ||
*, | ||
copy: Optional[bool] = None, | ||
copy: bool | None = None, | ||
**kwargs: object, | ||
) -> Array: | ||
if copy is True: | ||
|
@@ -657,7 +659,7 @@ def sign(x: Array, /, xp: Namespace, **kwargs: object) -> Array: | |
out = xp.sign(x, **kwargs) | ||
# CuPy sign() does not propagate nans. See | ||
# https://github.com/data-apis/array-api-compat/issues/136 | ||
if _is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): | ||
if is_cupy_namespace(xp) and isdtype(x.dtype, "real floating", xp=xp): | ||
out[xp.isnan(x)] = xp.nan | ||
return out[()] | ||
|
||
|
@@ -720,7 +722,7 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: | |
"finfo", | ||
"iinfo", | ||
] | ||
_all_ignore = ["inspect", "array_namespace", "NamedTuple"] | ||
_all_ignore = ["is_cupy_namespace", "inspect", "array_namespace", "NamedTuple"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that we have the |
||
|
||
|
||
def __dir__() -> list[str]: | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,56 +12,51 @@ | |||||
import math | ||||||
import sys | ||||||
import warnings | ||||||
from collections.abc import Collection | ||||||
from types import NoneType | ||||||
from typing import ( | ||||||
TYPE_CHECKING, | ||||||
Any, | ||||||
Final, | ||||||
Literal, | ||||||
SupportsIndex, | ||||||
TypeAlias, | ||||||
TypeGuard, | ||||||
TypeVar, | ||||||
cast, | ||||||
overload, | ||||||
) | ||||||
|
||||||
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
|
||||||
import cupy as cp | ||||||
import dask.array as da | ||||||
import jax | ||||||
import ndonnx as ndx | ||||||
import numpy as np | ||||||
import numpy.typing as npt | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
import torch | ||||||
|
||||||
# TODO: import from typing (requires Python >=3.13) | ||||||
from typing_extensions import TypeIs, TypeVar | ||||||
|
||||||
_SizeT = TypeVar("_SizeT", bound = int | None) | ||||||
from typing_extensions import TypeIs | ||||||
|
||||||
_ZeroGradientArray: TypeAlias = npt.NDArray[np.void] | ||||||
_CupyArray: TypeAlias = Any # cupy has no py.typed | ||||||
|
||||||
_ArrayApiObj: TypeAlias = ( | ||||||
npt.NDArray[Any] | ||||||
| cp.ndarray | ||||||
| da.Array | ||||||
| jax.Array | ||||||
| ndx.Array | ||||||
| sparse.SparseArray | ||||||
| torch.Tensor | ||||||
| SupportsArrayNamespace[Any] | ||||||
| _CupyArray | ||||||
| SupportsArrayNamespace | ||||||
) | ||||||
|
||||||
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) | ||||||
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) | ||||||
|
||||||
|
||||||
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | ||||||
def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The So it's better to revert this change (and the one below here) |
||||||
"""Return True if `x` is a zero-gradient array. | ||||||
|
||||||
These arrays are a design quirk of Jax that may one day be removed. | ||||||
|
@@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | |||||
) | ||||||
|
||||||
|
||||||
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: | ||||||
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: | ||||||
""" | ||||||
Return True if `x` is a NumPy array. | ||||||
|
||||||
|
@@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool: | |||||
if "cupy" not in sys.modules: | ||||||
return False | ||||||
|
||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
# TODO: Should we reject ndarray subclasses? | ||||||
return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] | ||||||
|
@@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: | |||||
if "sparse" not in sys.modules: | ||||||
return False | ||||||
|
||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
|
||||||
# TODO: Account for other backends. | ||||||
return isinstance(x, sparse.SparseArray) | ||||||
|
||||||
|
||||||
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] | ||||||
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given the current definition of _ArrayApiObj, TypeIs would cause downstream failures for all unknown array api compliant libraries. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't |
||||||
""" | ||||||
Return True if `x` is an array API compatible array object. | ||||||
|
||||||
|
@@ -587,7 +582,7 @@ def your_function(x, y): | |||||
|
||||||
namespaces.add(cupy_namespace) | ||||||
else: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
namespaces.add(cp) | ||||||
elif is_torch_array(x): | ||||||
|
@@ -624,14 +619,14 @@ def your_function(x, y): | |||||
if hasattr(jax.numpy, "__array_api_version__"): | ||||||
jnp = jax.numpy | ||||||
else: | ||||||
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] | ||||||
import jax.experimental.array_api as jnp # type: ignore[no-redef] | ||||||
namespaces.add(jnp) | ||||||
elif is_pydata_sparse_array(x): | ||||||
if use_compat is True: | ||||||
_check_api_version(api_version) | ||||||
raise ValueError("`sparse` does not have an array-api-compat wrapper") | ||||||
else: | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's no need for the |
||||||
# `sparse` is already an array namespace. We do not have a wrapper | ||||||
# submodule for it. | ||||||
namespaces.add(sparse) | ||||||
|
@@ -640,9 +635,9 @@ def your_function(x, y): | |||||
raise ValueError( | ||||||
"The given array does not have an array-api-compat wrapper" | ||||||
) | ||||||
x = cast("SupportsArrayNamespace[Any]", x) | ||||||
x = cast(SupportsArrayNamespace, x) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is why I quoted it: https://docs.astral.sh/ruff/rules/runtime-cast-value/ |
||||||
namespaces.add(x.__array_namespace__(api_version=api_version)) | ||||||
elif isinstance(x, (bool, int, float, complex, type(None))): | ||||||
elif isinstance(x, int | float | complex | NoneType): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
(I'll spare you the pseudo-philosophical rant this time) |
||||||
continue | ||||||
else: | ||||||
# TODO: Support Python scalars? | ||||||
|
@@ -738,7 +733,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
elif is_dask_array(x): | ||||||
# Peek at the metadata of the Dask array to determine type | ||||||
if is_numpy_array(x._meta): # pyright: ignore | ||||||
if is_numpy_array(x._meta): | ||||||
# Must be on CPU since backed by numpy | ||||||
return "cpu" | ||||||
return _DASK_DEVICE | ||||||
|
@@ -767,7 +762,7 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
return "cpu" | ||||||
# Return the device of the constituent array | ||||||
return device(inner) # pyright: ignore | ||||||
return x.device # pyright: ignore | ||||||
return x.device # type: ignore # pyright: ignore | ||||||
|
||||||
|
||||||
# Prevent shadowing, used below | ||||||
|
@@ -776,12 +771,12 @@ def device(x: _ArrayApiObj, /) -> Device: | |||||
|
||||||
# Based on cupy.array_api.Array.to_device | ||||||
def _cupy_to_device( | ||||||
x: _CupyArray, | ||||||
x: cp.ndarray, | ||||||
device: Device, | ||||||
/, | ||||||
stream: int | Any | None = None, | ||||||
) -> _CupyArray: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
) -> cp.ndarray: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. has this been fixed in cupy since last time or something? |
||||||
import cupy as cp | ||||||
from cupy.cuda import Device as _Device # pyright: ignore | ||||||
from cupy.cuda import stream as stream_module # pyright: ignore | ||||||
from cupy_backends.cuda.api import runtime # pyright: ignore | ||||||
|
@@ -797,10 +792,10 @@ def _cupy_to_device( | |||||
raise ValueError(f"Unsupported device {device!r}") | ||||||
else: | ||||||
# see cupy/cupy#5985 for the reason how we handle device/stream here | ||||||
prev_device: Any = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] | ||||||
prev_device: Device = runtime.getDevice() # pyright: ignore[reportUnknownMemberType] | ||||||
prev_stream = None | ||||||
if stream is not None: | ||||||
prev_stream: Any = stream_module.get_current_stream() # pyright: ignore | ||||||
prev_stream = stream_module.get_current_stream() # pyright: ignore | ||||||
# stream can be an int as specified in __dlpack__, or a CuPy stream | ||||||
if isinstance(stream, int): | ||||||
stream = cp.cuda.ExternalStream(stream) # pyright: ignore | ||||||
|
@@ -814,7 +809,7 @@ def _cupy_to_device( | |||||
arr = x.copy() | ||||||
finally: | ||||||
runtime.setDevice(prev_device) # pyright: ignore[reportUnknownMemberType] | ||||||
if stream is not None: | ||||||
if prev_stream is not None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. casual bugfix 🤔 ? |
||||||
prev_stream.use() | ||||||
return arr | ||||||
|
||||||
|
@@ -823,7 +818,7 @@ def _torch_to_device( | |||||
x: torch.Tensor, | ||||||
device: torch.device | str | int, | ||||||
/, | ||||||
stream: None = None, | ||||||
stream: int | Any | None = None, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps we should create a stream type-alias (unless there already is one) |
||||||
) -> torch.Tensor: | ||||||
if stream is not None: | ||||||
raise NotImplementedError | ||||||
|
@@ -889,7 +884,7 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
# cupy does not yet have to_device | ||||||
return _cupy_to_device(x, device, stream=stream) | ||||||
elif is_torch_array(x): | ||||||
return _torch_to_device(x, device, stream=stream) # pyright: ignore[reportArgumentType] | ||||||
return _torch_to_device(x, device, stream=stream) | ||||||
elif is_dask_array(x): | ||||||
if stream is not None: | ||||||
raise ValueError("The stream argument to to_device() is not supported") | ||||||
|
@@ -914,12 +909,12 @@ def to_device(x: Array, device: Device, /, *, stream: int | Any | None = None) - | |||||
|
||||||
|
||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex]]) -> int: ... | ||||||
def size(x: HasShape[int]) -> int: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[None]]) -> None: ... | ||||||
def size(x: HasShape[int | None]) -> int | None: ... | ||||||
@overload | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: ... | ||||||
def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | ||||||
def size(x: HasShape[float]) -> int | None: ... # Dask special case | ||||||
def size(x: HasShape[float | None]) -> int | None: | ||||||
Comment on lines
+912
to
+917
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why did you remove the overload that returns |
||||||
""" | ||||||
Return the total number of elements of x. | ||||||
|
||||||
|
@@ -934,12 +929,12 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: | |||||
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape | ||||||
if None in x.shape: | ||||||
return None | ||||||
out = math.prod(cast("Collection[SupportsIndex]", x.shape)) | ||||||
out = math.prod(cast(tuple[float, ...], x.shape)) | ||||||
# dask.array.Array.shape can contain NaN | ||||||
return None if math.isnan(out) else out | ||||||
return None if math.isnan(out) else cast(int, out) | ||||||
|
||||||
|
||||||
def is_writeable_array(x: object) -> bool: | ||||||
def is_writeable_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
""" | ||||||
Return False if ``x.__setitem__`` is expected to raise; True otherwise. | ||||||
Return False if `x` is not an array API compatible object. | ||||||
|
@@ -956,7 +951,7 @@ def is_writeable_array(x: object) -> bool: | |||||
return is_array_api_obj(x) | ||||||
|
||||||
|
||||||
def is_lazy_array(x: object) -> bool: | ||||||
def is_lazy_array(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Anything with a |
||||||
"""Return True if x is potentially a future or it may be otherwise impossible or | ||||||
expensive to eagerly read its contents, regardless of their size, e.g. by | ||||||
calling ``bool(x)`` or ``float(x)``. | ||||||
|
@@ -997,7 +992,7 @@ def is_lazy_array(x: object) -> bool: | |||||
# on __bool__ (dask is one such example, which however is special-cased above). | ||||||
|
||||||
# Select a single point of the array | ||||||
s = size(cast("HasShape[Collection[SupportsIndex | None]]", x)) | ||||||
s = size(cast(HasShape, x)) | ||||||
if s is None: | ||||||
return True | ||||||
xp = array_namespace(x) | ||||||
|
@@ -1044,5 +1039,6 @@ def is_lazy_array(x: object) -> bool: | |||||
|
||||||
_all_ignore = ["sys", "math", "inspect", "warnings"] | ||||||
|
||||||
|
||||||
def __dir__() -> list[str]: | ||||||
return __all__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With
we avoid the
types
import while simultaneously accentuating the violent dissonance between the Python runtime and its type-system, given that the sole purpose of a type-system is to accurately describe the runtime behavior...