Skip to content

Commit 308fc1f

Browse files
authored
TYP: Full annotations for Array objects (#159)
* ENH: Fully annotate Array * Update src/array_api_extra/_lib/_funcs.py * More compact _typing.py
1 parent ce7342e commit 308fc1f

15 files changed

+324
-165
lines changed

Diff for: pixi.lock

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ enable_error_code = ["ignore-without-code", "truthy-bool"]
203203
# https://github.com/data-apis/array-api-typing
204204
disallow_any_expr = false
205205
# false positives with input validation
206-
disable_error_code = ["redundant-expr", "unreachable"]
206+
disable_error_code = ["redundant-expr", "unreachable", "no-any-return"]
207207

208208
[[tool.mypy.overrides]]
209209
# slow/unavailable on Windows; do not add to the lint env

Diff for: src/array_api_extra/_lib/_at.py

+34-25
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_jax_array,
1616
is_writeable_array,
1717
)
18-
from ._utils._typing import Array, Index
18+
from ._utils._typing import Array, SetIndex
1919

2020

2121
class _AtOp(Enum):
@@ -43,7 +43,13 @@ def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[
4343
return self.value
4444

4545

46-
_undef = object()
46+
class Undef(Enum):
47+
"""Sentinel for undefined values."""
48+
49+
UNDEF = 0
50+
51+
52+
_undef = Undef.UNDEF
4753

4854

4955
class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
@@ -188,16 +194,16 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
188194
"""
189195

190196
_x: Array
191-
_idx: Index
197+
_idx: SetIndex | Undef
192198
__slots__: ClassVar[tuple[str, ...]] = ("_idx", "_x")
193199

194200
def __init__(
195-
self, x: Array, idx: Index = _undef, /
201+
self, x: Array, idx: SetIndex | Undef = _undef, /
196202
) -> None: # numpydoc ignore=GL08
197203
self._x = x
198204
self._idx = idx
199205

200-
def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
206+
def __getitem__(self, idx: SetIndex, /) -> at: # numpydoc ignore=PR01,RT01
201207
"""
202208
Allow for the alternate syntax ``at(x)[start:stop:step]``.
203209
@@ -212,9 +218,9 @@ def __getitem__(self, idx: Index, /) -> at: # numpydoc ignore=PR01,RT01
212218
def _op(
213219
self,
214220
at_op: _AtOp,
215-
in_place_op: Callable[[Array, Array | object], Array] | None,
221+
in_place_op: Callable[[Array, Array | complex], Array] | None,
216222
out_of_place_op: Callable[[Array, Array], Array] | None,
217-
y: Array | object,
223+
y: Array | complex,
218224
/,
219225
copy: bool | None,
220226
xp: ModuleType | None,
@@ -226,7 +232,7 @@ def _op(
226232
----------
227233
at_op : _AtOp
228234
Method of JAX's Array.at[].
229-
in_place_op : Callable[[Array, Array | object], Array] | None
235+
in_place_op : Callable[[Array, Array | complex], Array] | None
230236
In-place operation to apply on mutable backends::
231237
232238
x[idx] = in_place_op(x[idx], y)
@@ -245,7 +251,7 @@ def _op(
245251
246252
x = xp.where(idx, y, x)
247253
248-
y : array or object
254+
y : array or complex
249255
Right-hand side of the operation.
250256
copy : bool or None
251257
Whether to copy the input array. See the class docstring for details.
@@ -260,7 +266,7 @@ def _op(
260266
x, idx = self._x, self._idx
261267
xp = array_namespace(x, y) if xp is None else xp
262268

263-
if idx is _undef:
269+
if isinstance(idx, Undef):
264270
msg = (
265271
"Index has not been set.\n"
266272
"Usage: either\n"
@@ -306,7 +312,10 @@ def _op(
306312
if copy or (copy is None and not writeable):
307313
if is_jax_array(x):
308314
# Use JAX's at[]
309-
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
315+
func = cast(
316+
Callable[[Array | complex], Array],
317+
getattr(x.at[idx], at_op.value), # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue,reportUnknownArgumentType]
318+
)
310319
out = func(y)
311320
# Undo int->float promotion on JAX after _AtOp.DIVIDE
312321
return xp.astype(out, x.dtype, copy=False)
@@ -315,10 +324,10 @@ def _op(
315324
# with a copy followed by an update
316325

317326
x = xp.asarray(x, copy=True)
318-
if writeable is False:
319-
# A copy of a read-only numpy array is writeable
320-
# Note: this assumes that a copy of a writeable array is writeable
321-
writeable = None
327+
# A copy of a read-only numpy array is writeable
328+
# Note: this assumes that a copy of a writeable array is writeable
329+
assert not writeable
330+
writeable = None
322331

323332
if writeable is None:
324333
writeable = is_writeable_array(x)
@@ -328,14 +337,14 @@ def _op(
328337
raise ValueError(msg)
329338

330339
if in_place_op: # add(), subtract(), ...
331-
x[self._idx] = in_place_op(x[self._idx], y)
340+
x[idx] = in_place_op(x[idx], y)
332341
else: # set()
333-
x[self._idx] = y
342+
x[idx] = y
334343
return x
335344

336345
def set(
337346
self,
338-
y: Array | object,
347+
y: Array | complex,
339348
/,
340349
copy: bool | None = None,
341350
xp: ModuleType | None = None,
@@ -345,7 +354,7 @@ def set(
345354

346355
def add(
347356
self,
348-
y: Array | object,
357+
y: Array | complex,
349358
/,
350359
copy: bool | None = None,
351360
xp: ModuleType | None = None,
@@ -359,7 +368,7 @@ def add(
359368

360369
def subtract(
361370
self,
362-
y: Array | object,
371+
y: Array | complex,
363372
/,
364373
copy: bool | None = None,
365374
xp: ModuleType | None = None,
@@ -371,7 +380,7 @@ def subtract(
371380

372381
def multiply(
373382
self,
374-
y: Array | object,
383+
y: Array | complex,
375384
/,
376385
copy: bool | None = None,
377386
xp: ModuleType | None = None,
@@ -383,7 +392,7 @@ def multiply(
383392

384393
def divide(
385394
self,
386-
y: Array | object,
395+
y: Array | complex,
387396
/,
388397
copy: bool | None = None,
389398
xp: ModuleType | None = None,
@@ -395,7 +404,7 @@ def divide(
395404

396405
def power(
397406
self,
398-
y: Array | object,
407+
y: Array | complex,
399408
/,
400409
copy: bool | None = None,
401410
xp: ModuleType | None = None,
@@ -405,7 +414,7 @@ def power(
405414

406415
def min(
407416
self,
408-
y: Array | object,
417+
y: Array | complex,
409418
/,
410419
copy: bool | None = None,
411420
xp: ModuleType | None = None,
@@ -417,7 +426,7 @@ def min(
417426

418427
def max(
419428
self,
420-
y: Array | object,
429+
y: Array | complex,
421430
/,
422431
copy: bool | None = None,
423432
xp: ModuleType | None = None,

Diff for: src/array_api_extra/_lib/_funcs.py

+24-35
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
import warnings
88
from collections.abc import Sequence
99
from types import ModuleType
10-
from typing import TYPE_CHECKING, cast
10+
from typing import cast
1111

1212
from ._at import at
1313
from ._utils import _compat, _helpers
1414
from ._utils._compat import array_namespace, is_jax_array
15-
from ._utils._helpers import asarrays, ndindex
15+
from ._utils._helpers import asarrays, eager_shape, ndindex
1616
from ._utils._typing import Array
1717

1818
__all__ = [
@@ -211,11 +211,13 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
211211
m = xp.astype(m, dtype)
212212

213213
avg = _helpers.mean(m, axis=1, xp=xp)
214-
fact = m.shape[1] - 1
214+
215+
m_shape = eager_shape(m)
216+
fact = m_shape[1] - 1
215217

216218
if fact <= 0:
217219
warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
218-
fact = 0.0
220+
fact = 0
219221

220222
m -= avg[:, None]
221223
m_transpose = m.T
@@ -274,8 +276,10 @@ def create_diagonal(
274276
if x.ndim == 0:
275277
err_msg = "`x` must be at least 1-dimensional."
276278
raise ValueError(err_msg)
277-
batch_dims = x.shape[:-1]
278-
n = x.shape[-1] + abs(offset)
279+
280+
x_shape = eager_shape(x)
281+
batch_dims = x_shape[:-1]
282+
n = x_shape[-1] + abs(offset)
279283
diag = xp.zeros((*batch_dims, n**2), dtype=x.dtype, device=_compat.device(x))
280284

281285
target_slice = slice(
@@ -385,10 +389,6 @@ def isclose(
385389
) -> Array: # numpydoc ignore=PR01,RT01
386390
"""See docstring in array_api_extra._delegation."""
387391
a, b = asarrays(a, b, xp=xp)
388-
# FIXME https://github.com/microsoft/pyright/issues/10085
389-
if TYPE_CHECKING: # pragma: nocover
390-
assert _compat.is_array_api_obj(a)
391-
assert _compat.is_array_api_obj(b)
392392

393393
a_inexact = xp.isdtype(a.dtype, ("real floating", "complex floating"))
394394
b_inexact = xp.isdtype(b.dtype, ("real floating", "complex floating"))
@@ -505,24 +505,17 @@ def kron(
505505
if xp is None:
506506
xp = array_namespace(a, b)
507507
a, b = asarrays(a, b, xp=xp)
508-
# FIXME https://github.com/microsoft/pyright/issues/10085
509-
if TYPE_CHECKING: # pragma: nocover
510-
assert _compat.is_array_api_obj(a)
511-
assert _compat.is_array_api_obj(b)
512508

513509
singletons = (1,) * (b.ndim - a.ndim)
514-
a = xp.broadcast_to(a, singletons + a.shape)
515-
# FIXME https://github.com/microsoft/pyright/issues/10085
516-
if TYPE_CHECKING: # pragma: nocover
517-
assert _compat.is_array_api_obj(a)
510+
a = cast(Array, xp.broadcast_to(a, singletons + a.shape))
518511

519512
nd_b, nd_a = b.ndim, a.ndim
520513
nd_max = max(nd_b, nd_a)
521514
if nd_a == 0 or nd_b == 0:
522515
return xp.multiply(a, b)
523516

524-
a_shape = a.shape
525-
b_shape = b.shape
517+
a_shape = eager_shape(a)
518+
b_shape = eager_shape(b)
526519

527520
# Equalise the shapes by prepending smaller one with 1s
528521
a_shape = (1,) * max(0, nd_b - nd_a) + a_shape
@@ -587,16 +580,14 @@ def pad(
587580
) -> Array: # numpydoc ignore=PR01,RT01
588581
"""See docstring in `array_api_extra._delegation.py`."""
589582
# make pad_width a list of length-2 tuples of ints
590-
x_ndim = cast(int, x.ndim)
591-
592583
if isinstance(pad_width, int):
593-
pad_width_seq = [(pad_width, pad_width)] * x_ndim
584+
pad_width_seq = [(pad_width, pad_width)] * x.ndim
594585
elif (
595586
isinstance(pad_width, tuple)
596587
and len(pad_width) == 2
597588
and all(isinstance(i, int) for i in pad_width)
598589
):
599-
pad_width_seq = [cast(tuple[int, int], pad_width)] * x_ndim
590+
pad_width_seq = [cast(tuple[int, int], pad_width)] * x.ndim
600591
else:
601592
pad_width_seq = cast(list[tuple[int, int]], list(pad_width))
602593

@@ -608,7 +599,8 @@ def pad(
608599
msg = f"expect a 2-tuple (before, after), got {w_tpl}."
609600
raise ValueError(msg)
610601

611-
sh = x.shape[ax]
602+
sh = eager_shape(x)[ax]
603+
612604
if w_tpl[0] == 0 and w_tpl[1] == 0:
613605
sl = slice(None, None, None)
614606
else:
@@ -674,20 +666,17 @@ def setdiff1d(
674666
"""
675667
if xp is None:
676668
xp = array_namespace(x1, x2)
677-
x1, x2 = asarrays(x1, x2, xp=xp)
669+
# https://github.com/microsoft/pyright/issues/10103
670+
x1_, x2_ = asarrays(x1, x2, xp=xp)
678671

679672
if assume_unique:
680-
x1 = xp.reshape(x1, (-1,))
681-
x2 = xp.reshape(x2, (-1,))
673+
x1_ = xp.reshape(x1_, (-1,))
674+
x2_ = xp.reshape(x2_, (-1,))
682675
else:
683-
x1 = xp.unique_values(x1)
684-
x2 = xp.unique_values(x2)
685-
686-
# FIXME https://github.com/microsoft/pyright/issues/10085
687-
if TYPE_CHECKING: # pragma: nocover
688-
assert _compat.is_array_api_obj(x1)
676+
x1_ = xp.unique_values(x1_)
677+
x2_ = xp.unique_values(x2_)
689678

690-
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)]
679+
return x1_[_helpers.in1d(x1_, x2_, assume_unique=True, invert=True, xp=xp)]
691680

692681

693682
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:

0 commit comments

Comments
 (0)