Skip to content

Commit 44ebd51

Browse files
committed
self-review
1 parent d6ef302 commit 44ebd51

10 files changed

+28
-24
lines changed

array_api_strict/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
consuming libraries to test their array API usage.
1616
1717
"""
18+
1819
from types import ModuleType
1920

2021
__all__ = []
@@ -123,7 +124,7 @@
123124
"bool",
124125
]
125126

126-
from ._elementwise_functions import ( # type: ignore[attr-defined]
127+
from ._elementwise_functions import (
127128
abs,
128129
acos,
129130
acosh,

array_api_strict/_array_object.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
except ImportError:
4848
EllipsisType = type(Ellipsis)
4949

50+
5051
class Device:
5152
_device: str
5253
__slots__ = ("_device", "__weakref__")
@@ -91,6 +92,7 @@ class Array:
9192
functions, such as asarray().
9293
9394
"""
95+
9496
_array: npt.NDArray[Any]
9597
_dtype: DType
9698
_device: Device

array_api_strict/_creation_functions.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# Circular import
1616
from ._array_object import Array, Device
1717

18+
1819
class _Default(Enum):
1920
DEFAULT = 0
2021

@@ -140,7 +141,7 @@ def arange(
140141
_check_device(device)
141142

142143
return Array._new(
143-
np.arange(start, stop, step, dtype=_np_dtype(dtype)),
144+
np.arange(start, stop, step, dtype=_np_dtype(dtype)),
144145
device=device,
145146
)
146147

@@ -201,7 +202,9 @@ def eye(
201202
_check_valid_dtype(dtype)
202203
_check_device(device)
203204

204-
return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device)
205+
return Array._new(
206+
np.eye(n_rows, M=n_cols, k=k, dtype=_np_dtype(dtype)), device=device
207+
)
205208

206209

207210
def from_dlpack(

array_api_strict/_data_type_functions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
from ._flags import get_array_api_strict_flags
2222

2323

24-
# _default is used to emulate the asarray(device) argument not existing in 2022.12
2524
# Note: astype is a function, not an array method as in NumPy.
2625
def astype(
2726
x: Array,
2827
dtype: DType,
2928
/,
3029
*,
3130
copy: bool = True,
31+
# _default is used to emulate the device argument not existing in 2022.12
3232
device: Device | _Default | None = _default,
3333
) -> Array:
3434
if device is not _default:
@@ -164,7 +164,7 @@ def isdtype(dtype: DType, kind: DType | str | tuple[DType | str, ...]) -> bool:
164164
for more details
165165
"""
166166
if not isinstance(dtype, DType):
167-
raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}")
167+
raise TypeError(f"'dtype' must be a dtype, not a {type(dtype)!r}")
168168

169169
if isinstance(kind, tuple):
170170
# Disallow nested tuples

array_api_strict/_dtypes.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
py_bool = bool
1313

14+
1415
class DType:
1516
_np_dtype: np.dtype[Any]
1617
__slots__ = ("_np_dtype", "__weakref__")

array_api_strict/_flags.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import functools
1818
import os
1919
import warnings
20-
from collections.abc import Callable, Collection
20+
from collections.abc import Callable
2121
from types import TracebackType
22-
from typing import TYPE_CHECKING, Any, TypeVar
22+
from typing import TYPE_CHECKING, Any, Collection, TypeVar
2323

2424
import array_api_strict
2525

@@ -28,8 +28,7 @@
2828
from typing_extensions import ParamSpec
2929

3030
P = ParamSpec("P")
31-
else:
32-
P = object # Sphinx hack
31+
3332
T = TypeVar("T")
3433

3534

array_api_strict/_helpers.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
"""Private helper routines.
2-
"""
1+
"""Private helper routines."""
2+
33
from __future__ import annotations
44

55
from ._array_object import Array
@@ -11,11 +11,10 @@
1111

1212
def _maybe_normalize_py_scalars(
1313
x1: Array | complex,
14-
x2: Array | complex,
14+
x2: Array | complex,
1515
dtype_category: str,
1616
func_name: str,
1717
) -> tuple[Array, Array]:
18-
1918
flags = get_array_api_strict_flags()
2019
if flags["api_version"] < "2024.12":
2120
# scalars will fail at the call site

array_api_strict/_linalg.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -320,12 +320,8 @@ def slogdet(x: Array, /) -> SlogdetResult:
320320
def _solve(a: np.ndarray, b: np.ndarray) -> np.ndarray:
321321
try:
322322
from numpy.linalg._linalg import ( # type: ignore[attr-defined]
323-
_assert_stacked_2d,
324-
_assert_stacked_square,
325-
_commonType,
326-
_makearray,
327-
_raise_linalgerror_singular,
328-
isComplexType,
323+
_makearray, _assert_stacked_2d, _assert_stacked_square,
324+
_commonType, isComplexType, _raise_linalgerror_singular
329325
)
330326
except ImportError:
331327
from numpy.linalg.linalg import ( # type: ignore[attr-defined]
@@ -412,7 +408,8 @@ def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array:
412408

413409
# Note: trace always operates on the last two axes, whereas np.trace
414410
# operates on the first two axes by default
415-
return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=np_dtype)), device=x.device)
411+
res = np.trace(x._array, offset=offset, axis1=-2, axis2=-1, dtype=np_dtype)
412+
return Array._new(np.asarray(res), device=x.device)
416413

417414
# Note: the name here is different from norm(). The array API norm is split
418415
# into matrix_norm and vector_norm().

array_api_strict/_typing.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def __dlpack__(self, /, *, stream: None = ...) -> PyCapsule: ...
3535

3636
Capabilities = TypedDict(
3737
"Capabilities",
38-
{"boolean indexing": bool, "data-dependent shapes": bool, "max dimensions": int},
38+
{
39+
"boolean indexing": bool,
40+
"data-dependent shapes": bool,
41+
"max dimensions": int,
42+
},
3943
)
4044

4145
DefaultDataTypes = TypedDict(

array_api_strict/tests/test_validation.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
from typing import Callable
2-
31
import pytest
42

53
import array_api_strict as xp
64

75

8-
def p(func: Callable, *args, **kwargs):
6+
def p(func, *args, **kwargs):
97
f_sig = ", ".join(
108
[str(a) for a in args] + [f"{k}={v}" for k, v in kwargs.items()]
119
)

0 commit comments

Comments
 (0)