Skip to content

Commit b24de77

Browse files
rwbartonGuido van Rossum
authored and
Guido van Rossum
committed
Allow inferring lambda type from Callable[..., T] context (#1522)
Fixes #1517.
1 parent da18dd1 commit b24de77

File tree

2 files changed

+26
-0
lines changed

2 files changed

+26
-0
lines changed

mypy/checkexpr.py

+8
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,14 @@ def infer_lambda_type_using_context(self, e: FuncExpr) -> CallableType:
12991299

13001300
arg_kinds = [arg.kind for arg in e.arguments]
13011301

1302+
if callable_ctx.is_ellipsis_args:
1303+
# Fill in Any arguments to match the arguments of the lambda.
1304+
callable_ctx = callable_ctx.copy_modified(
1305+
is_ellipsis_args=False,
1306+
arg_types=[AnyType()] * len(arg_kinds),
1307+
arg_kinds=arg_kinds
1308+
)
1309+
13021310
if callable_ctx.arg_kinds != arg_kinds:
13031311
# Incompatible context; cannot use it to infer types.
13041312
self.chk.fail(messages.CANNOT_INFER_LAMBDA_TYPE, e)

mypy/test/data/check-inference-context.test

+18
Original file line numberDiff line numberDiff line change
@@ -593,6 +593,24 @@ main:2: error: Cannot infer type of lambda
593593
main:3: error: Cannot infer type of lambda
594594
main:3: error: Incompatible types in assignment (expression has type Callable[[], A], variable has type Callable[[A], A])
595595

596+
[case testEllipsisContextForLambda]
597+
from typing import Callable
598+
f1 = lambda x: 1 # type: Callable[..., int]
599+
f2 = lambda: 1 # type: Callable[..., int]
600+
f3 = lambda *args, **kwargs: 1 # type: Callable[..., int]
601+
f4 = lambda x: x # type: Callable[..., int]
602+
g = lambda x: 1 # type: Callable[..., str]
603+
[builtins fixtures/dict.py]
604+
[out]
605+
main:6: error: Incompatible return value type: expected builtins.str, got builtins.int
606+
main:6: error: Incompatible types in assignment (expression has type Callable[[Any], int], variable has type Callable[..., str])
607+
608+
[case testEllipsisContextForLambda2]
609+
from typing import TypeVar, Callable
610+
T = TypeVar('T')
611+
def foo(arg: Callable[..., T]) -> None: pass
612+
foo(lambda: 1)
613+
596614

597615
-- Overloads + generic functions
598616
-- -----------------------------

0 commit comments

Comments
 (0)