Improve typing of inference functions #2166

Merged
Changes from 3 commits
1 change: 0 additions & 1 deletion astroid/inference.py
@@ -254,7 +254,6 @@ def infer_name(
return bases._infer_stmts(stmts, context, frame)


# pylint: disable=no-value-for-parameter
# The order of the decorators here is important
# See https://github.com/pylint-dev/astroid/commit/0a8a75db30da060a24922e05048bc270230f5
nodes.Name._infer = decorators.raise_if_nothing_inferred(
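The assignment above is cut off by the collapsed diff region. For readers unfamiliar with the pattern the decorator-order comment refers to, here is a hedged sketch of how such an assignment is typically completed in astroid; the collapsed lines are not shown in this diff, so treat the wrapped function and import path as illustrative.

from astroid import decorators, nodes
from astroid.inference import infer_name  # assumed import path for this sketch

# raise_if_nothing_inferred is applied outermost, around path_wrapper; that
# ordering is what the comment in the hunk above refers to.
nodes.Name._infer = decorators.raise_if_nothing_inferred(
    decorators.path_wrapper(infer_name)
)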
33 changes: 18 additions & 15 deletions astroid/inference_tip.py
@@ -6,20 +6,19 @@

from __future__ import annotations

import sys
from collections.abc import Callable, Iterator
from typing import TYPE_CHECKING

from astroid.context import InferenceContext
from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault
from astroid.nodes import NodeNG
from astroid.typing import InferenceResult, InferFn

if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
from astroid.typing import (
_P,
InferenceResult,
InferFn,
InferFnExplicit,
InferFnTransform,
)

_cache: dict[
tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult]
@@ -35,12 +34,15 @@ def clear_inference_tip_cache() -> None:

def _inference_tip_cached(
func: Callable[_P, Iterator[InferenceResult]],
) -> Callable[_P, Iterator[InferenceResult]]:
) -> InferFnExplicit:
"""Cache decorator used for inference tips."""

def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
node = args[0]
context = args[1]
def inner(
*args: _P.args, **kwargs: _P.kwargs
) -> Iterator[InferenceResult] | list[InferenceResult]:
node: NodeNG = args[0]
context: InferenceContext | None = args[1]

partial_cache_key = (func, node)
if partial_cache_key in _CURRENTLY_INFERRING:
Collaborator:

Unrelated, but isn't this what the path wrapper does? Can we use that here?

Member Author:

Perhaps, but at a glance it looks like that one is sensitive to specific InferenceContexts. Here, we're not ready to unleash recursive inference with every slightly different context.

# If through recursion we end up trying to infer the same
@@ -64,7 +66,9 @@ def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]:
return inner
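To make the control flow above easier to follow outside the diff, here is a stripped-down sketch of the guard-and-cache pattern; it uses simplified types and an ordinary exception in place of astroid's UseInferenceDefault, and is not the exact code in this PR.

from __future__ import annotations

from collections.abc import Callable, Iterator
from typing import Any

_cache: dict[tuple[Callable, Any, Any], list[Any]] = {}
_CURRENTLY_INFERRING: set[tuple[Callable, Any]] = set()


def cached_tip(func: Callable[..., Iterator[Any]]) -> Callable[..., list[Any]]:
    def inner(node: Any, context: Any = None) -> list[Any]:
        partial_key = (func, node)  # the recursion guard ignores the context
        if partial_key in _CURRENTLY_INFERRING:
            # Re-entered while already inferring this func/node pair: bail out
            # instead of recursing further.
            raise RuntimeError(f"recursive inference for {node!r}")
        cache_key = (func, node, context)  # results are cached per context
        if cache_key not in _cache:
            _CURRENTLY_INFERRING.add(partial_key)
            try:
                _cache[cache_key] = list(func(node, context))
            finally:
                _CURRENTLY_INFERRING.discard(partial_key)
        return _cache[cache_key]

    return inner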


def inference_tip(infer_function: InferFn, raise_on_overwrite: bool = False) -> InferFn:
def inference_tip(
infer_function: InferFn, raise_on_overwrite: bool = False
) -> InferFnTransform:
"""Given an instance specific inference function, return a function to be
given to AstroidManager().register_transform to set this inference function.

@@ -100,7 +104,6 @@ def transform(node: NodeNG, infer_function: InferFn = infer_function) -> NodeNG:
node=node,
)
)
# pylint: disable=no-value-for-parameter
node._explicit_inference = _inference_tip_cached(infer_function)
return node

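For context on how the transform returned here (now typed as InferFnTransform) is consumed downstream, a hedged usage sketch of the usual brain-plugin pattern follows; the predicate and the inferred constant are invented for illustration.

from __future__ import annotations

from collections.abc import Iterator

from astroid import nodes
from astroid.context import InferenceContext
from astroid.inference_tip import inference_tip
from astroid.manager import AstroidManager


def _looks_like_answer_call(node: nodes.Call) -> bool:
    # Illustrative predicate: only fire for calls written literally as `answer()`.
    return isinstance(node.func, nodes.Name) and node.func.name == "answer"


def _infer_answer_call(
    node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[nodes.Const]:
    # Illustrative tip: pretend the call is statically known to return 42.
    yield nodes.Const(42)


# inference_tip() wraps the tip and returns the transform that
# register_transform() calls to attach _explicit_inference to matching nodes.
AstroidManager().register_transform(
    nodes.Call,
    inference_tip(_infer_answer_call),
    _looks_like_answer_call,
)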
4 changes: 2 additions & 2 deletions astroid/nodes/node_ng.py
@@ -35,7 +35,7 @@
from astroid.nodes.as_string import AsStringVisitor
from astroid.nodes.const import OP_PRECEDENCE
from astroid.nodes.utils import Position
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFn
from astroid.typing import InferenceErrorInfo, InferenceResult, InferFnExplicit

if TYPE_CHECKING:
from astroid import nodes
@@ -80,7 +80,7 @@ class NodeNG:
_other_other_fields: ClassVar[tuple[str, ...]] = ()
"""Attributes that contain AST-dependent fields."""
# instance specific inference function infer(node, context)
_explicit_inference: InferFn | None = None
_explicit_inference: InferFnExplicit | None = None

def __init__(
self,
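The attribute retyped above is the hook that inference tips populate. Below is a toy mirror of how a node's infer() typically consults it; this is a simplified stand-in, not astroid's actual method, which also handles caching, contexts, and error reporting.

from __future__ import annotations

from collections.abc import Iterator


class UseInferenceDefault(Exception):
    """Stand-in for astroid.exceptions.UseInferenceDefault."""


class MiniNode:
    # Mirrors NodeNG._explicit_inference: a per-node override, None by default.
    _explicit_inference = None

    def _infer(self, context=None) -> Iterator[object]:
        yield "default result"

    def infer(self, context=None) -> Iterator[object]:
        if self._explicit_inference is not None:
            try:
                # An explicit inference function wins over the default _infer();
                # it may return an iterator or a list (hence InferFnExplicit).
                yield from self._explicit_inference(self, context)
                return
            except UseInferenceDefault:
                pass
        yield from self._infer(context)


node = MiniNode()
node._explicit_inference = lambda node, context: iter(["tip result"])
assert list(node.infer()) == ["tip result"]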
23 changes: 19 additions & 4 deletions astroid/typing.py
@@ -4,14 +4,28 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Generator, TypedDict, TypeVar, Union
import sys
from typing import (
TYPE_CHECKING,
Callable,
Generator,
Iterator,
TypedDict,
TypeVar,
Union,
)

if TYPE_CHECKING:
from astroid import bases, exceptions, nodes, transforms, util
from astroid.context import InferenceContext
from astroid.interpreter._import import spec

if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_NodesT = TypeVar("_NodesT", bound="nodes.NodeNG")


@@ -24,9 +38,6 @@ class InferenceErrorInfo(TypedDict):
context: InferenceContext | None


InferFn = Callable[..., Any]


class AstroidManagerBrain(TypedDict):
"""Dictionary to store relevant information for a AstroidManager class."""

@@ -67,3 +78,7 @@ class AstroidManagerBrain(TypedDict):
],
Generator[InferenceResult, None, None],
]

InferFn = Callable[..., Iterator[InferenceResult]]
InferFnExplicit = Callable[_P, Union[Iterator[InferenceResult], list[InferenceResult]]]
InferFnTransform = Callable[[_NodesT, InferFn], _NodesT]
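
To make the intent of the three aliases concrete, here is a short hedged sketch of functions matching each shape; the names and bodies are illustrative, not astroid code.

from __future__ import annotations

from collections.abc import Iterator

from astroid import nodes
from astroid.context import InferenceContext
from astroid.typing import InferenceResult


def a_tip(
    node: nodes.Call, context: InferenceContext | None = None
) -> Iterator[InferenceResult]:
    # Shape of InferFn: loosely typed parameters, yields InferenceResult.
    yield nodes.Const(1)


def an_explicit_inference(
    node: nodes.NodeNG, context: InferenceContext | None
) -> list[InferenceResult]:
    # Shape of InferFnExplicit: may return a list (what _inference_tip_cached
    # produces) instead of an iterator.
    return [nodes.Const(1)]


def a_transform(node: nodes.Call, infer_function=a_tip) -> nodes.Call:
    # Shape of InferFnTransform: takes a node and an InferFn, returns the node,
    # as the inner transform() in inference_tip() does.
    node._explicit_inference = an_explicit_inference
    return node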