Skip to content

Commit d9f19dd

Browse files
committed
[suggest] Support refining existing type annotations
1 parent 22a5a4f commit d9f19dd

File tree

2 files changed

+71
-6
lines changed

2 files changed

+71
-6
lines changed

mypy/suggestions.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from mypy.checkexpr import has_any_type
5252

5353
from mypy.join import join_type_list
54+
from mypy.meet import meet_types
5455
from mypy.sametypes import is_same_type
5556
from mypy.typeops import make_simplified_union
5657

@@ -240,6 +241,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
240241
AnyType(TypeOfAny.special_form),
241242
self.builtin_type('builtins.function'))
242243

244+
def get_starting_type(self, fdef: FuncDef) -> CallableType:
245+
if isinstance(fdef.type, CallableType):
246+
return fdef.type
247+
else:
248+
return self.get_trivial_type(fdef)
249+
243250
def get_args(self, is_method: bool,
244251
base: CallableType, defaults: List[Optional[Type]],
245252
callsites: List[Callsite]) -> List[List[Type]]:
@@ -294,11 +301,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
294301
"""
295302
options = self.get_args(is_method, base, defaults, callsites)
296303
options = [self.add_adjustments(tps) for tps in options]
297-
return [base.copy_modified(arg_types=list(x)) for x in itertools.product(*options)]
304+
return [merge_callables(base, base.copy_modified(arg_types=list(x)))
305+
for x in itertools.product(*options)]
298306

299307
def get_callsites(self, func: FuncDef) -> Tuple[List[Callsite], List[str]]:
300308
"""Find all call sites of a function."""
301-
new_type = self.get_trivial_type(func)
309+
new_type = self.get_starting_type(func)
302310

303311
collector_plugin = SuggestionPlugin(func.fullname())
304312

@@ -350,7 +358,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
350358
with strict_optional_set(graph[mod].options.strict_optional):
351359
guesses = self.get_guesses(
352360
is_method,
353-
self.get_trivial_type(node),
361+
self.get_starting_type(node),
354362
self.get_default_arg_types(graph[mod], node),
355363
callsites)
356364
guesses = self.filter_options(guesses, is_method)
@@ -367,7 +375,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
367375
else:
368376
ret_types = [NoneType()]
369377

370-
guesses = [best.copy_modified(ret_type=t) for t in ret_types]
378+
guesses = [merge_callables(best, best.copy_modified(ret_type=t)) for t in ret_types]
371379
guesses = self.filter_options(guesses, is_method)
372380
best, errors = self.find_best(node, guesses)
373381

@@ -528,8 +536,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
528536
"""
529537
old = func.unanalyzed_type
530538
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
531-
# We don't modify type because it isn't necessary and it
532-
# would mess up the snapshotting.
539+
# We set type to None to ensure that the type always changes during
540+
# reprocessing.
541+
func.type = None
533542
func.unanalyzed_type = typ
534543
try:
535544
res = self.fgmanager.trigger(func.fullname())
@@ -778,6 +787,25 @@ def count_errors(msgs: List[str]) -> int:
778787
T = TypeVar('T')
779788

780789

790+
def merge_callables(t: CallableType, s: CallableType) -> CallableType:
791+
"""Merge two callable types in a way that prefers dropping Anys.
792+
793+
This is implemented by doing a meet on both the arguments and the return type,
794+
since meet(t, Any) == t.
795+
796+
This won't do perfectly with complex compound types (like
797+
callables nested inside), but it does pretty well.
798+
"""
799+
800+
# We don't want to ever squash away optionals while doing this, so set
801+
# strict optional to be true always
802+
with strict_optional_set(True):
803+
arg_types = [] # type: List[Type]
804+
for i in range(len(t.arg_types)):
805+
arg_types.append(meet_types(t.arg_types[i], s.arg_types[i]))
806+
return t.copy_modified(arg_types=arg_types, ret_type=meet_types(t.ret_type, s.ret_type))
807+
808+
781809
def dedup(old: List[T]) -> List[T]:
782810
new = [] # type: List[T]
783811
for x in old:

test-data/unit/fine-grained-suggest.test

+37
Original file line numberDiff line numberDiff line change
@@ -829,3 +829,40 @@ Command 'suggest' is only valid after a 'check' command (that produces no parse
829829
==
830830
foo.py:4: error: unexpected EOF while parsing
831831
-- )
832+
833+
[case testSuggestRefine]
834+
# suggest: foo.foo
835+
# suggest: foo.spam
836+
# suggest: foo.eggs
837+
# suggest: foo.take_l
838+
[file foo.py]
839+
from typing import Any, List
840+
841+
def bar():
842+
return 10
843+
844+
def foo(x: int, y):
845+
return x + y
846+
847+
def spam(x: int, y: Any) -> Any:
848+
return x + y
849+
850+
def eggs(x: int) -> List[Any]:
851+
a = [x]
852+
return a
853+
854+
def take_l(x: List[Any]) -> Any:
855+
return x[0]
856+
857+
858+
foo(bar(), 10)
859+
spam(bar(), 20)
860+
test = [10, 20]
861+
take_l(test)
862+
[builtins fixtures/isinstancelist.pyi]
863+
[out]
864+
(int, int) -> int
865+
(int, int) -> int
866+
(int) -> foo.List[int]
867+
(foo.List[int]) -> int
868+
==

0 commit comments

Comments
 (0)