Skip to content

Commit dab01be

Browse files
authored
Merge pull request #82 from adonath/ruff_ci
Add ruff to ci setup
2 parents 9cb5a13 + 2db3d6a commit dab01be

27 files changed

+1140
-382
lines changed

Diff for: .github/workflows/ruff.yml

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
name: CI
2+
on: [push, pull_request]
3+
jobs:
4+
check-ruff:
5+
runs-on: ubuntu-latest
6+
continue-on-error: true
7+
steps:
8+
- uses: actions/checkout@v3
9+
- name: Install Python
10+
uses: actions/setup-python@v4
11+
with:
12+
python-version: "3.11"
13+
- name: Install dependencies
14+
run: |
15+
python -m pip install --upgrade pip
16+
pip install ruff
17+
# Update output format to enable automatic inline annotations.
18+
- name: Run Ruff
19+
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .

Diff for: array_api_compat/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@
1919
"""
2020
__version__ = '1.4.1'
2121

22-
from .common import *
22+
from .common import * # noqa: F401, F403

Diff for: array_api_compat/_internal.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from functools import wraps
66
from inspect import signature
77

8+
89
def get_xp(xp):
910
"""
1011
Decorator to automatically replace xp with the corresponding array module.
@@ -21,13 +22,16 @@ def func(x, /, xp, kwarg=None):
2122
arguments.
2223
2324
"""
25+
2426
def inner(f):
2527
@wraps(f)
2628
def wrapped_f(*args, **kwargs):
2729
return f(*args, xp=xp, **kwargs)
2830

2931
sig = signature(f)
30-
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
32+
new_sig = sig.replace(
33+
parameters=[sig.parameters[i] for i in sig.parameters if i != "xp"]
34+
)
3135

3236
if wrapped_f.__doc__ is None:
3337
wrapped_f.__doc__ = f"""\
@@ -41,3 +45,31 @@ def wrapped_f(*args, **kwargs):
4145
return wrapped_f
4246

4347
return inner
48+
49+
50+
def _get_all_public_members(module, exclude=None, extend_all=False):
51+
"""Get all public members of a module.
52+
53+
Parameters
54+
----------
55+
module : module
56+
The module to get members from.
57+
exclude : callable, optional
58+
A callable that takes a name and returns True if the name should be
59+
excluded from the list of members.
60+
extend_all : bool, optional
61+
If True, extend the module's __all__ attribute with the members of the
62+
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
63+
"""
64+
members = getattr(module, "__all__", [])
65+
66+
if members and not extend_all:
67+
return members
68+
69+
if exclude is None:
70+
exclude = lambda name: name.startswith("_") # noqa: E731
71+
72+
members = members + [_ for _ in dir(module) if not exclude(_)]
73+
74+
# remove duplicates
75+
return list(set(members))

Diff for: array_api_compat/common/__init__.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,17 @@
1-
from ._helpers import *
1+
from ._helpers import (
2+
array_namespace,
3+
device,
4+
get_namespace,
5+
is_array_api_obj,
6+
size,
7+
to_device,
8+
)
9+
10+
__all__ = [
11+
"array_namespace",
12+
"device",
13+
"get_namespace",
14+
"is_array_api_obj",
15+
"size",
16+
"to_device",
17+
]

Diff for: array_api_compat/common/_aliases.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from typing import TYPE_CHECKING
88
if TYPE_CHECKING:
9-
from typing import Optional, Sequence, Tuple, Union, List
9+
import numpy as np
10+
from typing import Optional, Sequence, Tuple, Union
1011
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
1112

1213
from typing import NamedTuple
@@ -544,11 +545,3 @@ def isdtype(
544545
# more strict here to match the type annotation? Note that the
545546
# numpy.array_api implementation will be very strict.
546547
return dtype == kind
547-
548-
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
549-
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
550-
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
551-
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
552-
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
553-
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
554-
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

Diff for: array_api_compat/common/_helpers.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
"""
88
from __future__ import annotations
99

10+
from typing import TYPE_CHECKING
11+
12+
if TYPE_CHECKING:
13+
from typing import Optional, Union, Any
14+
from ._typing import Array, Device
15+
1016
import sys
1117
import math
1218

@@ -142,7 +148,7 @@ def _check_device(xp, device):
142148
# wrapping or subclassing them. These helper functions can be used instead of
143149
# the wrapper functions for libraries that need to support both NumPy/CuPy and
144150
# other libraries that use devices.
145-
def device(x: "Array", /) -> "Device":
151+
def device(x: Array, /) -> Device:
146152
"""
147153
Hardware device the array data resides on.
148154
@@ -204,7 +210,7 @@ def _torch_to_device(x, device, /, stream=None):
204210
raise NotImplementedError
205211
return x.to(device)
206212

207-
def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
213+
def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
208214
"""
209215
Copy the array from the device on which it currently resides to the specified ``device``.
210216
@@ -252,5 +258,3 @@ def size(x):
252258
if None in x.shape:
253259
return None
254260
return math.prod(x.shape)
255-
256-
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device', 'to_device', 'size']

Diff for: array_api_compat/common/_linalg.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING, NamedTuple
44
if TYPE_CHECKING:
5-
from typing import Literal, Optional, Sequence, Tuple, Union
5+
from typing import Literal, Optional, Tuple, Union
66
from ._typing import ndarray
77

88
import numpy as np
@@ -11,7 +11,7 @@
1111
else:
1212
from numpy.core.numeric import normalize_axis_tuple
1313

14-
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
14+
from ._aliases import matrix_transpose, isdtype
1515
from .._internal import get_xp
1616

1717
# These are in the main NumPy namespace but not in numpy.linalg
@@ -149,10 +149,4 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
149149
dtype = xp.float64
150150
elif x.dtype == xp.complex64:
151151
dtype = xp.complex128
152-
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
153-
154-
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
155-
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
156-
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
157-
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
158-
'trace']
152+
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))

Diff for: array_api_compat/common/_typing.py

+3
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ def __getitem__(self, key: int, /) -> _T_co | NestedSequence[_T_co]: ...
1818
def __len__(self, /) -> int: ...
1919

2020
SupportsBufferProtocol = Any
21+
22+
Array = Any
23+
Device = Any

Diff for: array_api_compat/cupy/__init__.py

+144-7
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,153 @@
1-
from cupy import *
1+
import cupy as _cp
2+
from cupy import * # noqa: F401, F403
23

34
# from cupy import * doesn't overwrite these builtin names
45
from cupy import abs, max, min, round
56

7+
from .._internal import _get_all_public_members
8+
from ..common._helpers import (
9+
array_namespace,
10+
device,
11+
get_namespace,
12+
is_array_api_obj,
13+
size,
14+
to_device,
15+
)
16+
617
# These imports may overwrite names from the import * above.
7-
from ._aliases import *
18+
from ._aliases import (
19+
UniqueAllResult,
20+
UniqueCountsResult,
21+
UniqueInverseResult,
22+
acos,
23+
acosh,
24+
arange,
25+
argsort,
26+
asarray,
27+
asarray_cupy,
28+
asin,
29+
asinh,
30+
astype,
31+
atan,
32+
atan2,
33+
atanh,
34+
bitwise_invert,
35+
bitwise_left_shift,
36+
bitwise_right_shift,
37+
bool,
38+
ceil,
39+
concat,
40+
empty,
41+
empty_like,
42+
eye,
43+
floor,
44+
full,
45+
full_like,
46+
isdtype,
47+
linspace,
48+
matmul,
49+
matrix_transpose,
50+
nonzero,
51+
ones,
52+
ones_like,
53+
permute_dims,
54+
pow,
55+
prod,
56+
reshape,
57+
sort,
58+
std,
59+
sum,
60+
tensordot,
61+
trunc,
62+
unique_all,
63+
unique_counts,
64+
unique_inverse,
65+
unique_values,
66+
var,
67+
vecdot,
68+
zeros,
69+
zeros_like,
70+
)
871

9-
# See the comment in the numpy __init__.py
10-
__import__(__package__ + '.linalg')
72+
__all__ = []
73+
74+
__all__ += _get_all_public_members(_cp)
75+
76+
__all__ += [
77+
"abs",
78+
"max",
79+
"min",
80+
"round",
81+
]
1182

12-
from .linalg import matrix_transpose, vecdot
83+
__all__ += [
84+
"array_namespace",
85+
"device",
86+
"get_namespace",
87+
"is_array_api_obj",
88+
"size",
89+
"to_device",
90+
]
1391

14-
from ..common._helpers import *
92+
__all__ += [
93+
"UniqueAllResult",
94+
"UniqueCountsResult",
95+
"UniqueInverseResult",
96+
"acos",
97+
"acosh",
98+
"arange",
99+
"argsort",
100+
"asarray",
101+
"asarray_cupy",
102+
"asin",
103+
"asinh",
104+
"astype",
105+
"atan",
106+
"atan2",
107+
"atanh",
108+
"bitwise_invert",
109+
"bitwise_left_shift",
110+
"bitwise_right_shift",
111+
"bool",
112+
"ceil",
113+
"concat",
114+
"empty",
115+
"empty_like",
116+
"eye",
117+
"floor",
118+
"full",
119+
"full_like",
120+
"isdtype",
121+
"linspace",
122+
"matmul",
123+
"matrix_transpose",
124+
"nonzero",
125+
"ones",
126+
"ones_like",
127+
"permute_dims",
128+
"pow",
129+
"prod",
130+
"reshape",
131+
"sort",
132+
"std",
133+
"sum",
134+
"tensordot",
135+
"trunc",
136+
"unique_all",
137+
"unique_counts",
138+
"unique_inverse",
139+
"unique_values",
140+
"var",
141+
"zeros",
142+
"zeros_like",
143+
]
144+
145+
__all__ += [
146+
"matrix_transpose",
147+
"vecdot",
148+
]
149+
150+
# See the comment in the numpy __init__.py
151+
__import__(__package__ + ".linalg")
15152

16-
__array_api_version__ = '2022.12'
153+
__array_api_version__ = "2022.12"

0 commit comments

Comments
 (0)