9
9
import sys
10
10
from collections .abc import Callable , Iterator
11
11
12
+ from astroid .context import InferenceContext
12
13
from astroid .exceptions import InferenceOverwriteError , UseInferenceDefault
13
14
from astroid .nodes import NodeNG
14
15
from astroid .typing import InferenceResult , InferFn
20
21
21
22
_P = ParamSpec ("_P" )
22
23
23
- _cache : dict [tuple [InferFn , NodeNG ], list [InferenceResult ] | None ] = {}
24
+ _cache : dict [
25
+ tuple [InferFn , NodeNG , InferenceContext | None ], list [InferenceResult ]
26
+ ] = {}
27
+
28
+ _CURRENTLY_INFERRING : set [tuple [InferFn , NodeNG ]] = set ()
24
29
25
30
26
31
def clear_inference_tip_cache () -> None :
@@ -35,16 +40,25 @@ def _inference_tip_cached(
35
40
36
41
def inner (* args : _P .args , ** kwargs : _P .kwargs ) -> Iterator [InferenceResult ]:
37
42
node = args [0 ]
38
- try :
39
- result = _cache [func , node ]
43
+ context = args [1 ]
44
+ partial_cache_key = (func , node )
45
+ if partial_cache_key in _CURRENTLY_INFERRING :
40
46
# If through recursion we end up trying to infer the same
41
47
# func + node we raise here.
42
- if result is None :
43
- raise UseInferenceDefault ()
48
+ raise UseInferenceDefault
49
+ try :
50
+ return _cache [func , node , context ]
44
51
except KeyError :
45
- _cache [func , node ] = None
46
- result = _cache [func , node ] = list (func (* args , ** kwargs ))
47
- assert result
52
+ # Recursion guard with a partial cache key.
53
+ # Using the full key causes a recursion error on PyPy.
54
+ # It's a pragmatic compromise to avoid so much recursive inference
55
+ # with slightly different contexts while still passing the simple
56
+ # test cases included with this commit.
57
+ _CURRENTLY_INFERRING .add (partial_cache_key )
58
+ result = _cache [func , node , context ] = list (func (* args , ** kwargs ))
59
+ # Remove recursion guard.
60
+ _CURRENTLY_INFERRING .remove (partial_cache_key )
61
+
48
62
return iter (result )
49
63
50
64
return inner
0 commit comments