Skip to content

Commit 900c546

Browse files
Improve typing of inference functions (#2166)
Co-authored-by: Daniël van Noord <[email protected]>
1 parent 0740a0d commit 900c546

File tree

5 files changed

+84
-42
lines changed

5 files changed

+84
-42
lines changed

astroid/inference.py

-1
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,6 @@ def infer_name(
254254
return bases._infer_stmts(stmts, context, frame)
255255

256256

257-
# pylint: disable=no-value-for-parameter
258257
# The order of the decorators here is important
259258
# See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5
260259
nodes.Name._infer = decorators.raise_if_nothing_inferred(

astroid/inference_tip.py

+27-24
Original file line numberDiff line numberDiff line change
@@ -6,65 +6,67 @@
66

77
from __future__ import annotations
88

9-
import sys
10-
from collections.abc import Callable, Iterator
9+
from collections.abc import Generator
10+
from typing import Any, TypeVar
1111

1212
from astroid.context import InferenceContext
1313
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
1414
from astroid.nodes import NodeNG
15-
from astroid.typing import InferenceResult, InferFn
16-
17-
if sys.version_info >= (3, 11):
18-
from typing import ParamSpec
19-
else:
20-
from typing_extensions import ParamSpec
21-
22-
_P = ParamSpec("_P")
15+
from astroid.typing import (
16+
InferenceResult,
17+
InferFn,
18+
TransformFn,
19+
)
2320

2421
_cache: dict[
25-
tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult]
22+
tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult]
2623
] = {}
2724

28-
_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set()
25+
_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set()
26+
27+
_NodesT = TypeVar("_NodesT", bound=NodeNG)
2928

3029

3130
def clear_inference_tip_cache() -> None:
3231
"""Clear the inference tips cache."""
3332
_cache.clear()
3433

3534

36-
def _inference_tip_cached(
37-
func: Callable[_P, Iterator[InferenceResult]],
38-
) -> Callable[_P, Iterator[InferenceResult]]:
35+
def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]:
3936
"""Cache decorator used for inference tips."""
4037

41-
def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
42-
node = args[0]
43-
context = args[1]
38+
def inner(
39+
node: _NodesT,
40+
context: InferenceContext | None = None,
41+
**kwargs: Any,
42+
) -> Generator[InferenceResult, None, None]:
4443
partial_cache_key = (func, node)
4544
if partial_cache_key in _CURRENTLY_INFERRING:
4645
# If through recursion we end up trying to infer the same
4746
# func + node we raise here.
4847
raise UseInferenceDefault
4948
try:
50-
return _cache[func, node, context]
49+
yield from _cache[func, node, context]
50+
return
5151
except KeyError:
5252
# Recursion guard with a partial cache key.
5353
# Using the full key causes a recursion error on PyPy.
5454
# It's a pragmatic compromise to avoid so much recursive inference
5555
# with slightly different contexts while still passing the simple
5656
# test cases included with this commit.
5757
_CURRENTLY_INFERRING.add(partial_cache_key)
58-
result = _cache[func, node, context] = list(func(*args, **kwargs))
58+
result = _cache[func, node, context] = list(func(node, context, **kwargs))
5959
# Remove recursion guard.
6060
_CURRENTLY_INFERRING.remove(partial_cache_key)
6161

62-
return iter(result)
62+
yield from result
6363

6464
return inner
6565

6666

67-
def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
67+
def inference_tip(
68+
infer_function: InferFn[_NodesT], raise_on_overwrite: bool = False
69+
) -> TransformFn[_NodesT]:
6870
"""Given an instance specific inference function, return a function to be
6971
given to AstroidManager().register_transform to set this inference function.
7072
@@ -86,7 +88,9 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) ->
8688
excess overwrites.
8789
"""
8890

89-
def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
91+
def transform(
92+
node: _NodesT, infer_function: InferFn[_NodesT] = infer_function
93+
) -> _NodesT:
9094
if (
9195
raise_on_overwrite
9296
and node._explicit_inference is not None
@@ -100,7 +104,6 @@ def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
100104
node=node,
101105
)
102106
)
103-
# pylint: disable=no-value-for-parameter
104107
node._explicit_inference = _inference_tip_cached(infer_function)
105108
return node
106109

astroid/nodes/node_ng.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import pprint
8+
import sys
89
import warnings
910
from collections.abc import Generator, Iterator
1011
from functools import cached_property
@@ -37,6 +38,12 @@
3738
from astroid.nodes.utils import Position
3839
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn
3940

41+
if sys.version_info >= (3, 11):
42+
from typing import Self
43+
else:
44+
from typing_extensions import Self
45+
46+
4047
if TYPE_CHECKING:
4148
from astroid import nodes
4249

@@ -80,7 +87,7 @@ class NodeNG:
8087
_other_other_fields: ClassVar[tuple[str, ...]] = ()
8188
"""Attributes that contain AST-dependent fields."""
8289
# instance specific inference function infer(node, context)
83-
_explicit_inference: InferFn | None = None
90+
_explicit_inference: InferFn[Self] | None = None
8491

8592
def __init__(
8693
self,
@@ -137,9 +144,17 @@ def infer(
137144
# explicit_inference is not bound, give it self explicitly
138145
try:
139146
if context is None:
140-
yield from self._explicit_inference(self, context, **kwargs)
147+
yield from self._explicit_inference(
148+
self, # type: ignore[arg-type]
149+
context,
150+
**kwargs,
151+
)
141152
return
142-
for result in self._explicit_inference(self, context, **kwargs):
153+
for result in self._explicit_inference(
154+
self, # type: ignore[arg-type]
155+
context,
156+
**kwargs,
157+
):
143158
context.nodes_inferred += 1
144159
yield result
145160
return

astroid/transforms.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,14 @@
99
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload
1010

1111
from astroid.context import _invalidate_cache
12-
from astroid.typing import SuccessfulInferenceResult
12+
from astroid.typing import SuccessfulInferenceResult, TransformFn
1313

1414
if TYPE_CHECKING:
1515
from astroid import nodes
1616

1717
_SuccessfulInferenceResultT = TypeVar(
1818
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
1919
)
20-
_Transform = Callable[
21-
[_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult]
22-
]
2320
_Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]]
2421

2522
_Vistables = Union[
@@ -52,7 +49,7 @@ def __init__(self) -> None:
5249
type[SuccessfulInferenceResult],
5350
list[
5451
tuple[
55-
_Transform[SuccessfulInferenceResult],
52+
TransformFn[SuccessfulInferenceResult],
5653
_Predicate[SuccessfulInferenceResult],
5754
]
5855
],
@@ -123,7 +120,7 @@ def _visit_generic(self, node: _Vistables) -> _VisitReturns:
123120
def register_transform(
124121
self,
125122
node_class: type[_SuccessfulInferenceResultT],
126-
transform: _Transform[_SuccessfulInferenceResultT],
123+
transform: TransformFn[_SuccessfulInferenceResultT],
127124
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
128125
) -> None:
129126
"""Register `transform(node)` function to be applied on the given node.
@@ -139,7 +136,7 @@ def register_transform(
139136
def unregister_transform(
140137
self,
141138
node_class: type[_SuccessfulInferenceResultT],
142-
transform: _Transform[_SuccessfulInferenceResultT],
139+
transform: TransformFn[_SuccessfulInferenceResultT],
143140
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
144141
) -> None:
145142
"""Unregister the given transform."""

astroid/typing.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,24 @@
44

55
from __future__ import annotations
66

7-
from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
Callable,
11+
Generator,
12+
Generic,
13+
Protocol,
14+
TypedDict,
15+
TypeVar,
16+
Union,
17+
)
818

919
if TYPE_CHECKING:
1020
from astroid import bases, exceptions, nodes, transforms, util
1121
from astroid.context import InferenceContext
1222
from astroid.interpreter._import import spec
1323

1424

15-
_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG")
16-
17-
1825
class InferenceErrorInfo(TypedDict):
1926
"""Store additional Inference error information
2027
raised with StopIteration exception.
@@ -24,9 +31,6 @@ class InferenceErrorInfo(TypedDict):
2431
context: InferenceContext | None
2532

2633

27-
InferFn = Callable[..., Any]
28-
29-
3034
class AstroidManagerBrain(TypedDict):
3135
"""Dictionary to store relevant information for a AstroidManager class."""
3236

@@ -46,6 +50,11 @@ class AstroidManagerBrain(TypedDict):
4650
_SuccessfulInferenceResultT = TypeVar(
4751
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
4852
)
53+
_SuccessfulInferenceResultT_contra = TypeVar(
54+
"_SuccessfulInferenceResultT_contra",
55+
bound=SuccessfulInferenceResult,
56+
contravariant=True,
57+
)
4958

5059
ConstFactoryResult = Union[
5160
"nodes.List",
@@ -67,3 +76,22 @@ class AstroidManagerBrain(TypedDict):
6776
],
6877
Generator[InferenceResult, None, None],
6978
]
79+
80+
81+
class InferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]):
82+
def __call__(
83+
self,
84+
node: _SuccessfulInferenceResultT_contra,
85+
context: InferenceContext | None = None,
86+
**kwargs: Any,
87+
) -> Generator[InferenceResult, None, None]:
88+
... # pragma: no cover
89+
90+
91+
class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]):
92+
def __call__(
93+
self,
94+
node: _SuccessfulInferenceResultT,
95+
infer_function: InferFn[_SuccessfulInferenceResultT] = ...,
96+
) -> _SuccessfulInferenceResultT | None:
97+
... # pragma: no cover

0 commit comments

Comments
 (0)