diff --git a/ChangeLog b/ChangeLog index a760172d10..faccda681d 100644 --- a/ChangeLog +++ b/ChangeLog @@ -17,6 +17,13 @@ Release date: TBA * Reduce file system access in ``ast_from_file()``. +* Fix incorrect cache keys for inference results, thereby correctly inferring types + for calls instantiating types dynamically. + + Closes #1828 + Closes pylint-dev/pylint#7464 + Closes pylint-dev/pylint#8074 + * ``nodes.FunctionDef`` no longer inherits from ``nodes.Lambda``. This is a breaking change but considered a bug fix as the nodes did not share the same API and were not interchangeable. diff --git a/astroid/inference_tip.py b/astroid/inference_tip.py index 5b855c9e77..92cb6b4fe1 100644 --- a/astroid/inference_tip.py +++ b/astroid/inference_tip.py @@ -9,6 +9,7 @@ import sys from collections.abc import Callable, Iterator +from astroid.context import InferenceContext from astroid.exceptions import InferenceOverwriteError, UseInferenceDefault from astroid.nodes import NodeNG from astroid.typing import InferenceResult, InferFn @@ -20,7 +21,11 @@ _P = ParamSpec("_P") -_cache: dict[tuple[InferFn, NodeNG], list[InferenceResult] | None] = {} +_cache: dict[ + tuple[InferFn, NodeNG, InferenceContext | None], list[InferenceResult] +] = {} + +_CURRENTLY_INFERRING: set[tuple[InferFn, NodeNG]] = set() def clear_inference_tip_cache() -> None: @@ -35,16 +40,25 @@ def _inference_tip_cached( def inner(*args: _P.args, **kwargs: _P.kwargs) -> Iterator[InferenceResult]: node = args[0] - try: - result = _cache[func, node] + context = args[1] + partial_cache_key = (func, node) + if partial_cache_key in _CURRENTLY_INFERRING: # If through recursion we end up trying to infer the same # func + node we raise here. - if result is None: - raise UseInferenceDefault() + raise UseInferenceDefault + try: + return _cache[func, node, context] except KeyError: - _cache[func, node] = None - result = _cache[func, node] = list(func(*args, **kwargs)) - assert result + # Recursion guard with a partial cache key. + # Using the full key causes a recursion error on PyPy. + # It's a pragmatic compromise to avoid so much recursive inference + # with slightly different contexts while still passing the simple + # test cases included with this commit. + _CURRENTLY_INFERRING.add(partial_cache_key) + result = _cache[func, node, context] = list(func(*args, **kwargs)) + # Remove recursion guard. + _CURRENTLY_INFERRING.remove(partial_cache_key) + return iter(result) return inner diff --git a/tests/brain/test_brain.py b/tests/brain/test_brain.py index 00a023ddb0..4a016868cb 100644 --- a/tests/brain/test_brain.py +++ b/tests/brain/test_brain.py @@ -930,13 +930,7 @@ class A: assert inferred.value == 42 def test_typing_cast_multiple_inference_calls(self) -> None: - """Inference of an outer function should not store the result for cast. - - https://github.com/pylint-dev/pylint/issues/8074 - - Possible solution caused RecursionErrors with Python 3.8 and CPython + PyPy. - https://github.com/pylint-dev/astroid/pull/1982 - """ + """Inference of an outer function should not store the result for cast.""" ast_nodes = builder.extract_node( """ from typing import TypeVar, cast @@ -954,7 +948,7 @@ def ident(var: T) -> T: i1 = next(ast_nodes[1].infer()) assert isinstance(i1, nodes.Const) - assert i1.value == 2 # should be "Hello"! + assert i1.value == "Hello" class ReBrainTest(unittest.TestCase): diff --git a/tests/test_regrtest.py b/tests/test_regrtest.py index 31d9e6b84b..59d344b954 100644 --- a/tests/test_regrtest.py +++ b/tests/test_regrtest.py @@ -336,6 +336,27 @@ def d(self): assert isinstance(inferred, Instance) assert inferred.qname() == ".A" + def test_inference_context_consideration(self) -> None: + """https://github.com/PyCQA/astroid/issues/1828""" + code = """ + class Base: + def return_type(self): + return type(self)() + class A(Base): + def method(self): + return self.return_type() + class B(Base): + def method(self): + return self.return_type() + A().method() #@ + B().method() #@ + """ + node1, node2 = extract_node(code) + inferred1 = next(node1.infer()) + assert inferred1.qname() == ".A" + inferred2 = next(node2.infer()) + assert inferred2.qname() == ".B" + class Whatever: a = property(lambda x: x, lambda x: x) # type: ignore[misc] diff --git a/tests/test_scoped_nodes.py b/tests/test_scoped_nodes.py index b8c55f67d3..86d69624d1 100644 --- a/tests/test_scoped_nodes.py +++ b/tests/test_scoped_nodes.py @@ -1771,9 +1771,7 @@ def __init__(self): "FinalClass", "ClassB", "MixinB", - # We don't recognize what 'cls' is at time of .format() call, only - # what it is at the end. - # "strMixin", + "strMixin", "ClassA", "MixinA", "intMixin",