Skip to content

Improve typing of inference functions #2166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
1 change: 0 additions & 1 deletion astroid/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,6 @@ def infer_name(
return bases._infer_stmts(stmts, context, frame)


# pylint: disable=no-value-for-parameter
# The order of the decorators here is important
# See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5
nodes.Name._infer = decorators.raise_if_nothing_inferred(
Expand Down
51 changes: 27 additions & 24 deletions astroid/inference_tip.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,65 +6,67 @@

from __future__ import annotations

import sys
from collections.abc import Callable, Iterator
from collections.abc import Generator
from typing import Any, TypeVar

from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
from astroid.typing import InferenceResult, InferFn

if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
from astroid.typing import (
InferenceResult,
InferFn,
TransformFn,
)

_cache: dict[
tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult]
tuple[InferFn[Any], NodeNG, InferenceContext | None], list[InferenceResult]
] = {}

_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set()
_CURRENTLY_INFERRING: set[tuple[InferFn[Any], NodeNG]] = set()

_NodesT = TypeVar("_NodesT", bound=NodeNG)


def clear_inference_tip_cache() -> None:
    """Empty the global cache of inference-tip results."""
    _cache.clear()


def _inference_tip_cached(
func: Callable[_P, Iterator[InferenceResult]],
) -> Callable[_P, Iterator[InferenceResult]]:
def _inference_tip_cached(func: InferFn[_NodesT]) -> InferFn[_NodesT]:
"""Cache decorator used for inference tips."""

def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
node = args[0]
context = args[1]
def inner(
node: _NodesT,
context: InferenceContext | None = None,
**kwargs: Any,
) -> Generator[InferenceResult, None, None]:
partial_cache_key = (func, node)
if partial_cache_key in _CURRENTLY_INFERRING:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, but isn't this what the path wrapper does? Can we use that here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, but at a glance, it looks like that one is sensitive to specific InferenceContexts. Here, we're not ready to unleash recursive inference with every slightly different context.

# If through recursion we end up trying to infer the same
# func + node we raise here.
raise UseInferenceDefault
try:
return _cache[func, node, context]
yield from _cache[func, node, context]
return
except KeyError:
# Recursion guard with a partial cache key.
# Using the full key causes a recursion error on PyPy.
# It's a pragmatic compromise to avoid so much recursive inference
# with slightly different contexts while still passing the simple
# test cases included with this commit.
_CURRENTLY_INFERRING.add(partial_cache_key)
result = _cache[func, node, context] = list(func(*args, **kwargs))
result = _cache[func, node, context] = list(func(node, context, **kwargs))
# Remove recursion guard.
_CURRENTLY_INFERRING.remove(partial_cache_key)

return iter(result)
yield from result

return inner


def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
def inference_tip(
infer_function: InferFn[_NodesT], raise_on_overwrite: bool = False
) -> TransformFn[_NodesT]:
"""Given an instance specific inference function, return a function to be
given to AstroidManager().register_transform to set this inference function.

Expand All @@ -86,7 +88,9 @@ def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) ->
excess overwrites.
"""

def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
def transform(
node: _NodesT, infer_function: InferFn[_NodesT] = infer_function
) -> _NodesT:
if (
raise_on_overwrite
and node._explicit_inference is not None
Expand All @@ -100,7 +104,6 @@ def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
node=node,
)
)
# pylint: disable=no-value-for-parameter
node._explicit_inference = _inference_tip_cached(infer_function)
return node

Expand Down
21 changes: 18 additions & 3 deletions astroid/nodes/node_ng.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import pprint
import sys
import warnings
from collections.abc import Generator, Iterator
from functools import cached_property
Expand Down Expand Up @@ -37,6 +38,12 @@
from astroid.nodes.utils import Position
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


if TYPE_CHECKING:
from astroid import nodes

Expand Down Expand Up @@ -80,7 +87,7 @@ class NodeNG:
_other_other_fields: ClassVar[tuple[str, ...]] = ()
"""Attributes that contain AST-dependent fields."""
# instance specific inference function infer(node, context)
_explicit_inference: InferFn | None = None
_explicit_inference: InferFn[Self] | None = None

def __init__(
self,
Expand Down Expand Up @@ -137,9 +144,17 @@ def infer(
# explicit_inference is not bound, give it self explicitly
try:
if context is None:
yield from self._explicit_inference(self, context, **kwargs)
yield from self._explicit_inference(
self, # type: ignore[arg-type]
context,
**kwargs,
)
return
for result in self._explicit_inference(self, context, **kwargs):
for result in self._explicit_inference(
self, # type: ignore[arg-type]
context,
**kwargs,
):
context.nodes_inferred += 1
yield result
return
Expand Down
11 changes: 4 additions & 7 deletions astroid/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,14 @@
from typing import TYPE_CHECKING, List, Optional, Tuple, TypeVar, Union, cast, overload

from astroid.context import _invalidate_cache
from astroid.typing import SuccessfulInferenceResult
from astroid.typing import SuccessfulInferenceResult, TransformFn

if TYPE_CHECKING:
from astroid import nodes

_SuccessfulInferenceResultT = TypeVar(
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
)
_Transform = Callable[
[_SuccessfulInferenceResultT], Optional[SuccessfulInferenceResult]
]
_Predicate = Optional[Callable[[_SuccessfulInferenceResultT], bool]]

_Vistables = Union[
Expand Down Expand Up @@ -52,7 +49,7 @@ def __init__(self) -> None:
type[SuccessfulInferenceResult],
list[
tuple[
_Transform[SuccessfulInferenceResult],
TransformFn[SuccessfulInferenceResult],
_Predicate[SuccessfulInferenceResult],
]
],
Expand Down Expand Up @@ -123,7 +120,7 @@ def _visit_generic(self, node: _Vistables) -> _VisitReturns:
def register_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
transform: _Transform[_SuccessfulInferenceResultT],
transform: TransformFn[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Register `transform(node)` function to be applied on the given node.
Expand All @@ -139,7 +136,7 @@ def register_transform(
def unregister_transform(
self,
node_class: type[_SuccessfulInferenceResultT],
transform: _Transform[_SuccessfulInferenceResultT],
transform: TransformFn[_SuccessfulInferenceResultT],
predicate: _Predicate[_SuccessfulInferenceResultT] | None = None,
) -> None:
"""Unregister the given transform."""
Expand Down
42 changes: 35 additions & 7 deletions astroid/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,24 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generator,
Generic,
Protocol,
TypedDict,
TypeVar,
Union,
)

if TYPE_CHECKING:
from astroid import bases, exceptions, nodes, transforms, util
from astroid.context import InferenceContext
from astroid.interpreter._import import spec


_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG")


class InferenceErrorInfo(TypedDict):
"""Store additional Inference error information
raised with StopIteration exception.
Expand All @@ -24,9 +31,6 @@ class InferenceErrorInfo(TypedDict):
context: InferenceContext | None


InferFn = Callable[..., Any]


class AstroidManagerBrain(TypedDict):
"""Dictionary to store relevant information for a AstroidManager class."""

Expand All @@ -46,6 +50,11 @@ class AstroidManagerBrain(TypedDict):
_SuccessfulInferenceResultT = TypeVar(
"_SuccessfulInferenceResultT", bound=SuccessfulInferenceResult
)
_SuccessfulInferenceResultT_contra = TypeVar(
"_SuccessfulInferenceResultT_contra",
bound=SuccessfulInferenceResult,
contravariant=True,
)

ConstFactoryResult = Union[
"nodes.List",
Expand All @@ -67,3 +76,22 @@ class AstroidManagerBrain(TypedDict):
],
Generator[InferenceResult, None, None],
]


class InferFn(Protocol, Generic[_SuccessfulInferenceResultT_contra]):
    """Callable protocol for instance-specific inference functions.

    The node type parameter is declared contravariant, so an ``InferFn``
    that accepts a broader node type is usable where one accepting a
    narrower subtype is expected.
    """

    def __call__(
        self,
        node: _SuccessfulInferenceResultT_contra,
        context: InferenceContext | None = None,
        **kwargs: Any,
    ) -> Generator[InferenceResult, None, None]:
        ...  # pragma: no cover


class TransformFn(Protocol, Generic[_SuccessfulInferenceResultT]):
    """Callable protocol for node transform functions.

    Takes a node (and optionally an inference function) and returns a
    node of the same type or ``None``; NOTE(review): a ``None`` return
    presumably means "no replacement" — confirm against the transform
    visitor's handling.
    """

    def __call__(
        self,
        node: _SuccessfulInferenceResultT,
        infer_function: InferFn[_SuccessfulInferenceResultT] = ...,
    ) -> _SuccessfulInferenceResultT | None:
        ...  # pragma: no cover