Skip to content

Commit 57e6cd0

Browse files
committed
ENH Test tools for jax.jit and dask
1 parent 6ee70c0 commit 57e6cd0

File tree

11 files changed

+457
-28
lines changed

11 files changed

+457
-28
lines changed

Diff for: docs/api-reference.md

+12
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,15 @@
1818
setdiff1d
1919
sinc
2020
```
21+
22+
## Testing utilities
23+
24+
```{eval-rst}
25+
.. currentmodule:: array_api_extra.testing
26+
.. autosummary::
27+
:nosignatures:
28+
:toctree: generated
29+
30+
lazy_xp_function
31+
patch_lazy_xp_functions
32+
```

Diff for: pixi.lock

+13-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+11-2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ furo = ">=2023.08.17"
9898
myst-parser = ">=0.13"
9999
sphinx-copybutton = "*"
100100
sphinx-autodoc-typehints = "*"
101+
# Needed to import parsed modules with autodoc
102+
pytest = "*"
101103

102104
[tool.pixi.feature.docs.tasks]
103105
docs = { cmd = "sphinx-build . build/", cwd = "docs" }
@@ -180,8 +182,10 @@ markers = ["skip_xp_backend(library, *, reason=None): Skip test for a specific b
180182

181183
[tool.coverage]
182184
run.source = ["array_api_extra"]
183-
report.exclude_also = ['\.\.\.']
184-
185+
report.exclude_also = [
186+
'\.\.\.',
187+
'if TYPE_CHECKING:',
188+
]
185189

186190
# mypy
187191

@@ -221,6 +225,8 @@ reportMissingImports = false
221225
reportMissingTypeStubs = false
222226
# false positives for input validation
223227
reportUnreachable = false
228+
# ruff handles this
229+
reportUnusedParameter = false
224230

225231
executionEnvironments = [
226232
{ root = "tests", reportPrivateUsage = false },
@@ -282,7 +288,10 @@ messages_control.disable = [
282288
"design", # ignore heavily opinionated design checks
283289
"fixme", # allow FIXME comments
284290
"line-too-long", # ruff handles this
291+
"unused-argument", # ruff handles this
285292
"missing-function-docstring", # numpydoc handles this
293+
"import-error", # mypy handles this
294+
"import-outside-toplevel", # optional dependencies
286295
]
287296

288297

Diff for: src/array_api_extra/_lib/_at.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
188188
def _update_common(
189189
self,
190190
at_op: _AtOp,
191-
y: Array,
191+
y: Array | object,
192192
/,
193193
copy: bool | None,
194194
xp: ModuleType | None,
@@ -253,7 +253,7 @@ def _update_common(
253253

254254
def set(
255255
self,
256-
y: Array,
256+
y: Array | object,
257257
/,
258258
copy: bool | None = None,
259259
xp: ModuleType | None = None,
@@ -269,8 +269,8 @@ def set(
269269
def _iop(
270270
self,
271271
at_op: _AtOp,
272-
elwise_op: Callable[[Array, Array], Array],
273-
y: Array,
272+
elwise_op: Callable[[Array, Array | object], Array],
273+
y: Array | object,
274274
/,
275275
copy: bool | None,
276276
xp: ModuleType | None,
@@ -294,7 +294,7 @@ def _iop(
294294

295295
def add(
296296
self,
297-
y: Array,
297+
y: Array | object,
298298
/,
299299
copy: bool | None = None,
300300
xp: ModuleType | None = None,
@@ -308,7 +308,7 @@ def add(
308308

309309
def subtract(
310310
self,
311-
y: Array,
311+
y: Array | object,
312312
/,
313313
copy: bool | None = None,
314314
xp: ModuleType | None = None,
@@ -318,7 +318,7 @@ def subtract(
318318

319319
def multiply(
320320
self,
321-
y: Array,
321+
y: Array | object,
322322
/,
323323
copy: bool | None = None,
324324
xp: ModuleType | None = None,
@@ -328,7 +328,7 @@ def multiply(
328328

329329
def divide(
330330
self,
331-
y: Array,
331+
y: Array | object,
332332
/,
333333
copy: bool | None = None,
334334
xp: ModuleType | None = None,
@@ -338,7 +338,7 @@ def divide(
338338

339339
def power(
340340
self,
341-
y: Array,
341+
y: Array | object,
342342
/,
343343
copy: bool | None = None,
344344
xp: ModuleType | None = None,
@@ -348,7 +348,7 @@ def power(
348348

349349
def min(
350350
self,
351-
y: Array,
351+
y: Array | object,
352352
/,
353353
copy: bool | None = None,
354354
xp: ModuleType | None = None,
@@ -361,7 +361,7 @@ def min(
361361

362362
def max(
363363
self,
364-
y: Array,
364+
y: Array | object,
365365
/,
366366
copy: bool | None = None,
367367
xp: ModuleType | None = None,

Diff for: src/array_api_extra/_lib/_testing.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Testing utilities.
33
44
Note that this is private API; don't expect it to be stable.
5+
See also ..testing for public testing utilities.
56
"""
67

78
import math

0 commit comments

Comments
 (0)