Skip to content

Commit 415a30f

Browse files
authored
Merge pull request numpy#28529 from jorenham/numtype/235
TYP: fix stubtest errors in ``numpy.lib._index_tricks_impl``
2 parents f925baf + e020b23 commit 415a30f

File tree

6 files changed

+190
-210
lines changed

6 files changed

+190
-210
lines changed

numpy/__init__.pyi

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,8 @@ from numpy.lib._histograms_impl import (
515515
)
516516

517517
from numpy.lib._index_tricks_impl import (
518+
ndenumerate,
519+
ndindex,
518520
ravel_multi_index,
519521
unravel_index,
520522
mgrid,
@@ -4951,50 +4953,6 @@ class errstate:
49514953
) -> None: ...
49524954
def __call__(self, func: _CallableT) -> _CallableT: ...
49534955

4954-
class ndenumerate(Generic[_SCT_co]):
4955-
@property
4956-
def iter(self) -> flatiter[NDArray[_SCT_co]]: ...
4957-
4958-
@overload
4959-
def __new__(
4960-
cls, arr: _FiniteNestedSequence[_SupportsArray[dtype[_SCT]]],
4961-
) -> ndenumerate[_SCT]: ...
4962-
@overload
4963-
def __new__(cls, arr: str | _NestedSequence[str]) -> ndenumerate[str_]: ...
4964-
@overload
4965-
def __new__(cls, arr: bytes | _NestedSequence[bytes]) -> ndenumerate[bytes_]: ...
4966-
@overload
4967-
def __new__(cls, arr: builtins.bool | _NestedSequence[builtins.bool]) -> ndenumerate[np.bool]: ...
4968-
@overload
4969-
def __new__(cls, arr: int | _NestedSequence[int]) -> ndenumerate[int_]: ...
4970-
@overload
4971-
def __new__(cls, arr: float | _NestedSequence[float]) -> ndenumerate[float64]: ...
4972-
@overload
4973-
def __new__(cls, arr: complex | _NestedSequence[complex]) -> ndenumerate[complex128]: ...
4974-
@overload
4975-
def __new__(cls, arr: object) -> ndenumerate[object_]: ...
4976-
4977-
# The first overload is a (semi-)workaround for a mypy bug (tested with v1.10 and v1.11)
4978-
@overload
4979-
def __next__(
4980-
self: ndenumerate[np.bool | datetime64 | timedelta64 | number[Any] | flexible],
4981-
/,
4982-
) -> tuple[_Shape, _SCT_co]: ...
4983-
@overload
4984-
def __next__(self: ndenumerate[object_], /) -> tuple[_Shape, Any]: ...
4985-
@overload
4986-
def __next__(self, /) -> tuple[_Shape, _SCT_co]: ...
4987-
4988-
def __iter__(self) -> Self: ...
4989-
4990-
class ndindex:
4991-
@overload
4992-
def __init__(self, shape: tuple[SupportsIndex, ...], /) -> None: ...
4993-
@overload
4994-
def __init__(self, *shape: SupportsIndex) -> None: ...
4995-
def __iter__(self) -> Self: ...
4996-
def __next__(self) -> _Shape: ...
4997-
49984956
# TODO: The type of each `__next__` and `iters` return-type depends
49994957
# on the length and dtype of `args`; we can't describe this behavior yet
50004958
# as we lack variadics (PEP 646).

numpy/_core/fromnumeric.pyi

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ from collections.abc import Sequence
33
from typing import (
44
Any,
55
Literal,
6-
NoReturn,
76
Protocol,
87
SupportsIndex,
98
TypeAlias,
109
TypeVar,
1110
overload,
1211
type_check_only,
1312
)
13+
14+
from _typeshed import Incomplete
1415
from typing_extensions import Never, deprecated
1516

1617
import numpy as np
@@ -551,9 +552,6 @@ def ravel(
551552
@overload
552553
def ravel(a: ArrayLike, order: _OrderKACF = "C") -> np.ndarray[tuple[int], np.dtype[Any]]: ...
553554

554-
@overload
555-
def nonzero(a: np.generic | np.ndarray[tuple[()], Any]) -> NoReturn: ...
556-
@overload
557555
def nonzero(a: _ArrayLike[Any]) -> tuple[NDArray[intp], ...]: ...
558556

559557
# this prevents `Any` from being returned with Pyright
@@ -813,7 +811,7 @@ def all(
813811
keepdims: _BoolLike_co | _NoValueType = ...,
814812
*,
815813
where: _ArrayLikeBool_co | _NoValueType = ...,
816-
) -> np.bool | NDArray[np.bool]: ...
814+
) -> Incomplete: ...
817815
@overload
818816
def all(
819817
a: ArrayLike,
@@ -850,7 +848,7 @@ def any(
850848
keepdims: _BoolLike_co | _NoValueType = ...,
851849
*,
852850
where: _ArrayLikeBool_co | _NoValueType = ...,
853-
) -> np.bool | NDArray[np.bool]: ...
851+
) -> Incomplete: ...
854852
@overload
855853
def any(
856854
a: ArrayLike,
@@ -1443,10 +1441,10 @@ def mean(
14431441
keepdims: Literal[False] | _NoValueType = ...,
14441442
*,
14451443
where: _ArrayLikeBool_co | _NoValueType = ...,
1446-
) -> complexfloating[Any, Any]: ...
1444+
) -> complexfloating[Any]: ...
14471445
@overload
14481446
def mean(
1449-
a: _ArrayLikeTD64_co,
1447+
a: _ArrayLike[np.timedelta64],
14501448
axis: None = ...,
14511449
dtype: None = ...,
14521450
out: None = ...,
@@ -1457,23 +1455,33 @@ def mean(
14571455
@overload
14581456
def mean(
14591457
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
1460-
axis: _ShapeLike | None = ...,
1461-
dtype: None = ...,
1462-
out: None = ...,
1458+
axis: _ShapeLike | None,
1459+
dtype: DTypeLike,
1460+
out: _ArrayT,
14631461
keepdims: bool | _NoValueType = ...,
14641462
*,
14651463
where: _ArrayLikeBool_co | _NoValueType = ...,
1466-
) -> Any: ...
1464+
) -> _ArrayT: ...
1465+
@overload
1466+
def mean(
1467+
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
1468+
axis: _ShapeLike | None = ...,
1469+
dtype: DTypeLike | None = ...,
1470+
*,
1471+
out: _ArrayT,
1472+
keepdims: bool | _NoValueType = ...,
1473+
where: _ArrayLikeBool_co | _NoValueType = ...,
1474+
) -> _ArrayT: ...
14671475
@overload
14681476
def mean(
14691477
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
14701478
axis: None,
14711479
dtype: _DTypeLike[_SCT],
14721480
out: None = ...,
1473-
keepdims: bool | _NoValueType = ...,
1481+
keepdims: Literal[False] | _NoValueType = ...,
14741482
*,
14751483
where: _ArrayLikeBool_co | _NoValueType = ...,
1476-
) -> _SCT | NDArray[_SCT]: ...
1484+
) -> _SCT: ...
14771485
@overload
14781486
def mean(
14791487
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
@@ -1487,43 +1495,43 @@ def mean(
14871495
@overload
14881496
def mean(
14891497
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
1490-
axis: None = ...,
1491-
*,
1498+
axis: _ShapeLike | None,
14921499
dtype: _DTypeLike[_SCT],
1493-
out: None = ...,
1494-
keepdims: bool | _NoValueType = ...,
1500+
out: None,
1501+
keepdims: Literal[True, 1],
1502+
*,
14951503
where: _ArrayLikeBool_co | _NoValueType = ...,
1496-
) -> _SCT | NDArray[_SCT]: ...
1504+
) -> NDArray[_SCT]: ...
14971505
@overload
14981506
def mean(
14991507
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
1500-
axis: _ShapeLike | None = ...,
1501-
dtype: DTypeLike = ...,
1508+
axis: _ShapeLike | None,
1509+
dtype: _DTypeLike[_SCT],
15021510
out: None = ...,
1503-
keepdims: bool | _NoValueType = ...,
15041511
*,
1512+
keepdims: bool | _NoValueType = ...,
15051513
where: _ArrayLikeBool_co | _NoValueType = ...,
1506-
) -> Any: ...
1514+
) -> _SCT | NDArray[_SCT]: ...
15071515
@overload
15081516
def mean(
15091517
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
1510-
axis: _ShapeLike | None,
1511-
dtype: DTypeLike,
1512-
out: _ArrayT,
1513-
keepdims: bool | _NoValueType = ...,
1518+
axis: _ShapeLike | None = ...,
15141519
*,
1520+
dtype: _DTypeLike[_SCT],
1521+
out: None = ...,
1522+
keepdims: bool | _NoValueType = ...,
15151523
where: _ArrayLikeBool_co | _NoValueType = ...,
1516-
) -> _ArrayT: ...
1524+
) -> _SCT | NDArray[_SCT]: ...
15171525
@overload
15181526
def mean(
15191527
a: _ArrayLikeComplex_co | _ArrayLikeObject_co,
15201528
axis: _ShapeLike | None = ...,
1521-
dtype: DTypeLike = ...,
1522-
*,
1523-
out: _ArrayT,
1529+
dtype: DTypeLike | None = ...,
1530+
out: None = ...,
15241531
keepdims: bool | _NoValueType = ...,
1532+
*,
15251533
where: _ArrayLikeBool_co | _NoValueType = ...,
1526-
) -> _ArrayT: ...
1534+
) -> Incomplete: ...
15271535

15281536
@overload
15291537
def std(

0 commit comments

Comments
 (0)