Skip to content

Commit bcae393

Browse files
committed
(temp) switch to pyright
1 parent fe8fa8b commit bcae393

File tree

5 files changed

+50
-55
lines changed

5 files changed

+50
-55
lines changed

Diff for: pixi.lock

+20-39
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+19-9
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,16 @@ numpy = "*"
7878
pytest = "*"
7979

8080
[tool.pixi.feature.lint.pypi-dependencies]
81-
basedpyright = "*"
81+
# basedpyright = "*"
82+
pyright = "*"
8283

8384
[tool.pixi.feature.lint.tasks]
8485
pre-commit-install = { cmd = "pre-commit install" }
8586
pre-commit = { cmd = "pre-commit run -v --all-files --show-diff-on-failure" }
8687
mypy = { cmd = "mypy", cwd = "." }
8788
pylint = { cmd = ["pylint", "array_api_extra"], cwd = "src" }
88-
pyright = { cmd = "basedpyright", cwd = "." }
89+
# pyright = { cmd = "basedpyright", cwd = "." }
90+
pyright = { cmd = "pyright", cwd = "." }
8991
lint = { depends-on = ["pre-commit", "pylint", "mypy", "pyright"] }
9092

9193
[tool.pixi.feature.tests.dependencies]
@@ -180,17 +182,25 @@ disallow_incomplete_defs = true
180182

181183
# pyright
182184

183-
[tool.basedpyright]
185+
# [tool.basedpyright]
186+
# include = ["src", "tests"]
187+
# pythonVersion = "3.10"
188+
# pythonPlatform = "All"
189+
# typeCheckingMode = "all"
190+
191+
# # data-apis/array-api#589
192+
# reportAny = false
193+
# reportExplicitAny = false
194+
# # data-apis/array-api-strict#6
195+
# reportUnknownMemberType = false
196+
197+
[tool.pyright]
184198
include = ["src", "tests"]
185199
pythonVersion = "3.10"
186200
pythonPlatform = "All"
187-
typeCheckingMode = "all"
188-
189-
# data-apis/array-api#589
190-
reportAny = false
191-
reportExplicitAny = false
192-
# data-apis/array-api-strict#6
201+
typeCheckingMode = "strict"
193202
reportUnknownMemberType = false
203+
reportImportCycles = true
194204

195205

196206
# Ruff

Diff for: src/array_api_extra/_compat.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import sys
77
import typing
88

9+
from typing_extensions import override
10+
911
if typing.TYPE_CHECKING:
1012
from ._typing import Array, Device
1113

@@ -16,6 +18,7 @@
1618
# when the array backend is not the CPU.
1719
# (since it is not easy to tell which device a dask array is on)
1820
class _dask_device: # pylint: disable=invalid-name
21+
@override
1922
def __repr__(self) -> str:
2023
return "DASK_DEVICE"
2124

@@ -118,7 +121,7 @@ def _is_dask_array(x: Array) -> bool:
118121
return False
119122

120123
# pylint: disable=import-error, import-outside-toplevel
121-
import dask.array # type: ignore[import-not-found]
124+
import dask.array # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
122125

123126
return isinstance(x, dask.array.Array)
124127

@@ -133,10 +136,10 @@ def _is_jax_zero_gradient_array(x: Array) -> bool:
133136
return False
134137

135138
# pylint: disable=import-error, import-outside-toplevel
136-
import jax # type: ignore[import-not-found]
139+
import jax # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
137140
import numpy as np # pylint: disable=import-outside-toplevel
138141

139-
return isinstance(x, np.ndarray) and x.dtype == jax.float0
142+
return isinstance(x, np.ndarray) and x.dtype == jax.float0 # pyright: ignore[reportUnknownVariableType]
140143

141144

142145
def _is_jax_array(x: Array) -> bool:
@@ -146,7 +149,7 @@ def _is_jax_array(x: Array) -> bool:
146149
return False
147150

148151
# pylint: disable=import-error, import-outside-toplevel
149-
import jax
152+
import jax # pyright: ignore[reportMissingImports]
150153

151154
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
152155

@@ -159,7 +162,7 @@ def _is_pydata_sparse_array(x: Array) -> bool:
159162
return False
160163

161164
# pylint: disable=import-error, import-outside-toplevel
162-
import sparse # type: ignore[import-not-found]
165+
import sparse # type: ignore[import-not-found] # pyright: ignore[reportMissingImports]
163166

164167
# TODO: Account for other backends.
165168
return isinstance(x, sparse.SparseArray)

Diff for: src/array_api_extra/_typing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55

66
# To be changed to a Protocol later (see data-apis/array-api#589)
77
Array = Any # type: ignore[no-any-explicit]
8-
Device = Any
8+
Device = Any # type: ignore[no-any-explicit]
99

1010
__all__ = ["Array", "Device", "ModuleType"]

Diff for: src/array_api_extra/_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def in1d(
5858

5959
if assume_unique:
6060
return ret[: x1.shape[0]]
61+
# https://github.com/KotlinIsland/basedmypy/issues/826
6162
# https://github.com/pylint-dev/pylint/issues/10095
6263
# pylint: disable=possibly-used-before-assignment
63-
return xp.take(ret, rev_idx, axis=0)
64+
return xp.take(ret, rev_idx, axis=0) # type: ignore[possibly-undefined] # pyright: ignore[reportPossiblyUnboundVariable]

0 commit comments

Comments
 (0)