Skip to content

Commit deb6573

Browse files
committed
Preserve parent CallContext when inferring nested functions
1 parent dfd88f5 commit deb6573

File tree

7 files changed

+38
-45
lines changed

7 files changed

+38
-45
lines changed

ChangeLog

+4-1
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@ Release date: TBA
1414

1515
* Fix issues with ``typing_extensions.TypeVar``.
1616

17-
1817
* Fix ``ClassDef.fromlino`` for PyPy 3.8 (v7.3.11) if class is wrapped by a decorator.
1918

19+
* Preserve parent CallContext when inferring nested functions.
20+
21+
Closes PyCQA/pylint#8074
22+
2023

2124
What's New in astroid 2.13.3?
2225
=============================

astroid/brain/brain_typing.py

-34
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
Const,
2929
JoinedStr,
3030
Name,
31-
NodeNG,
3231
Subscript,
3332
Tuple,
3433
)
@@ -380,36 +379,6 @@ def infer_special_alias(
380379
return iter([class_def])
381380

382381

383-
def _looks_like_typing_cast(node: Call) -> bool:
384-
return isinstance(node, Call) and (
385-
isinstance(node.func, Name)
386-
and node.func.name == "cast"
387-
or isinstance(node.func, Attribute)
388-
and node.func.attrname == "cast"
389-
)
390-
391-
392-
def infer_typing_cast(
393-
node: Call, ctx: context.InferenceContext | None = None
394-
) -> Iterator[NodeNG]:
395-
"""Infer call to cast() returning same type as casted-from var."""
396-
if not isinstance(node.func, (Name, Attribute)):
397-
raise UseInferenceDefault
398-
399-
try:
400-
func = next(node.func.infer(context=ctx))
401-
except (InferenceError, StopIteration) as exc:
402-
raise UseInferenceDefault from exc
403-
if (
404-
not isinstance(func, FunctionDef)
405-
or func.qname() != "typing.cast"
406-
or len(node.args) != 2
407-
):
408-
raise UseInferenceDefault
409-
410-
return node.args[1].infer(context=ctx)
411-
412-
413382
AstroidManager().register_transform(
414383
Call,
415384
inference_tip(infer_typing_typevar_or_newtype),
@@ -418,9 +387,6 @@ def infer_typing_cast(
418387
AstroidManager().register_transform(
419388
Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript
420389
)
421-
AstroidManager().register_transform(
422-
Call, inference_tip(infer_typing_cast), _looks_like_typing_cast
423-
)
424390

425391
if PY39_PLUS:
426392
AstroidManager().register_transform(

astroid/context.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,14 @@ def __str__(self) -> str:
161161
class CallContext:
162162
"""Holds information for a call site."""
163163

164-
__slots__ = ("args", "keywords", "callee")
164+
__slots__ = ("args", "keywords", "callee", "parent_call_context")
165165

166166
def __init__(
167167
self,
168168
args: list[NodeNG],
169169
keywords: list[Keyword] | None = None,
170170
callee: NodeNG | None = None,
171+
parent_call_context: CallContext | None = None,
171172
):
172173
self.args = args # Call positional arguments
173174
if keywords:
@@ -176,6 +177,9 @@ def __init__(
176177
arg_value_pairs = []
177178
self.keywords = arg_value_pairs # Call keyword arguments
178179
self.callee = callee # Function being called
180+
self.parent_call_context = (
181+
parent_call_context # Parent CallContext for nested calls
182+
)
179183

180184

181185
def copy_context(context: InferenceContext | None) -> InferenceContext:

astroid/inference.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,10 @@ def infer_call(
273273
try:
274274
if hasattr(callee, "infer_call_result"):
275275
callcontext.callcontext = CallContext(
276-
args=self.args, keywords=self.keywords, callee=callee
276+
args=self.args,
277+
keywords=self.keywords,
278+
callee=callee,
279+
parent_call_context=callcontext.callcontext,
277280
)
278281
yield from callee.infer_call_result(caller=self, context=callcontext)
279282
except InferenceError:

astroid/protocols.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,7 @@ def arguments_assigned_stmts(
470470
# reset call context/name
471471
callcontext = context.callcontext
472472
context = copy_context(context)
473-
context.callcontext = None
473+
context.callcontext = callcontext.parent_call_context
474474
args = arguments.CallSite(callcontext, context=context)
475475
return args.infer_argument(self.parent, node_name, context)
476476
return _arguments_infer_argname(self, node_name, context)

tests/unittest_brain.py

+22-4
Original file line numberDiff line numberDiff line change
@@ -2132,8 +2132,7 @@ class A:
21322132
pass
21332133
21342134
b = 42
2135-
a = cast(A, b)
2136-
a
2135+
cast(A, b)
21372136
"""
21382137
)
21392138
inferred = next(node.infer())
@@ -2148,14 +2147,33 @@ class A:
21482147
pass
21492148
21502149
b = 42
2151-
a = typing.cast(A, b)
2152-
a
2150+
typing.cast(A, b)
21532151
"""
21542152
)
21552153
inferred = next(node.infer())
21562154
assert isinstance(inferred, nodes.Const)
21572155
assert inferred.value == 42
21582156

2157+
def test_typing_cast_multiple_inference_calls(self) -> None:
2158+
ast_nodes = builder.extract_node(
2159+
"""
2160+
from typing import TypeVar, cast
2161+
T = TypeVar("T")
2162+
def ident(var: T) -> T:
2163+
return cast(T, var)
2164+
2165+
ident(2) #@
2166+
ident("Hello") #@
2167+
"""
2168+
)
2169+
i0 = next(ast_nodes[0].infer())
2170+
assert isinstance(i0, nodes.Const)
2171+
assert i0.value == 2
2172+
2173+
i1 = next(ast_nodes[1].infer())
2174+
assert isinstance(i1, nodes.Const)
2175+
assert i1.value == "Hello"
2176+
21592177

21602178
@pytest.mark.skipif(
21612179
not HAS_TYPING_EXTENSIONS,

tests/unittest_inference_calls.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -146,8 +146,6 @@ def g(y):
146146
def test_inner_call_with_dynamic_argument() -> None:
147147
"""Test function where return value is the result of a separate function call,
148148
with a dynamic value passed to the inner function.
149-
150-
Currently, this is Uninferable.
151149
"""
152150
node = builder.extract_node(
153151
"""
@@ -163,7 +161,8 @@ def g(y):
163161
assert isinstance(node, nodes.NodeNG)
164162
inferred = node.inferred()
165163
assert len(inferred) == 1
166-
assert inferred[0] is Uninferable
164+
assert isinstance(inferred[0], nodes.Const)
165+
assert inferred[0].value == 3
167166

168167

169168
def test_method_const_instance_attr() -> None:

0 commit comments

Comments
 (0)