51
51
from mypy .checkexpr import has_any_type
52
52
53
53
from mypy .join import join_type_list
54
+ from mypy .meet import meet_types
54
55
from mypy .sametypes import is_same_type
55
56
from mypy .typeops import make_simplified_union
56
57
@@ -240,6 +241,12 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType:
240
241
AnyType (TypeOfAny .special_form ),
241
242
self .builtin_type ('builtins.function' ))
242
243
244
+ def get_starting_type (self , fdef : FuncDef ) -> CallableType :
245
+ if isinstance (fdef .type , CallableType ):
246
+ return fdef .type
247
+ else :
248
+ return self .get_trivial_type (fdef )
249
+
243
250
def get_args (self , is_method : bool ,
244
251
base : CallableType , defaults : List [Optional [Type ]],
245
252
callsites : List [Callsite ]) -> List [List [Type ]]:
@@ -294,11 +301,12 @@ def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Option
294
301
"""
295
302
options = self .get_args (is_method , base , defaults , callsites )
296
303
options = [self .add_adjustments (tps ) for tps in options ]
297
- return [base .copy_modified (arg_types = list (x )) for x in itertools .product (* options )]
304
+ return [merge_callables (base , base .copy_modified (arg_types = list (x )))
305
+ for x in itertools .product (* options )]
298
306
299
307
def get_callsites (self , func : FuncDef ) -> Tuple [List [Callsite ], List [str ]]:
300
308
"""Find all call sites of a function."""
301
- new_type = self .get_trivial_type (func )
309
+ new_type = self .get_starting_type (func )
302
310
303
311
collector_plugin = SuggestionPlugin (func .fullname ())
304
312
@@ -350,7 +358,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
350
358
with strict_optional_set (graph [mod ].options .strict_optional ):
351
359
guesses = self .get_guesses (
352
360
is_method ,
353
- self .get_trivial_type (node ),
361
+ self .get_starting_type (node ),
354
362
self .get_default_arg_types (graph [mod ], node ),
355
363
callsites )
356
364
guesses = self .filter_options (guesses , is_method )
@@ -367,7 +375,7 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature:
367
375
else :
368
376
ret_types = [NoneType ()]
369
377
370
- guesses = [best .copy_modified (ret_type = t ) for t in ret_types ]
378
+ guesses = [merge_callables ( best , best .copy_modified (ret_type = t ) ) for t in ret_types ]
371
379
guesses = self .filter_options (guesses , is_method )
372
380
best , errors = self .find_best (node , guesses )
373
381
@@ -528,8 +536,9 @@ def try_type(self, func: FuncDef, typ: ProperType) -> List[str]:
528
536
"""
529
537
old = func .unanalyzed_type
530
538
# During reprocessing, unanalyzed_type gets copied to type (by aststrip).
531
- # We don't modify type because it isn't necessary and it
532
- # would mess up the snapshotting.
539
+ # We set type to None to ensure that the type always changes during
540
+ # reprocessing.
541
+ func .type = None
533
542
func .unanalyzed_type = typ
534
543
try :
535
544
res = self .fgmanager .trigger (func .fullname ())
@@ -778,6 +787,25 @@ def count_errors(msgs: List[str]) -> int:
778
787
T = TypeVar ('T' )
779
788
780
789
790
+ def merge_callables (t : CallableType , s : CallableType ) -> CallableType :
791
+ """Merge two callable types in a way that prefers dropping Anys.
792
+
793
+ This is implemented by doing a meet on both the arguments and the return type,
794
+ since meet(t, Any) == t.
795
+
796
+ This won't do perfectly with complex compound types (like
797
+ callables nested inside), but it does pretty well.
798
+ """
799
+
800
+ # We don't want to ever squash away optionals while doing this, so set
801
+ # strict optional to be true always
802
+ with strict_optional_set (True ):
803
+ arg_types = [] # type: List[Type]
804
+ for i in range (len (t .arg_types )):
805
+ arg_types .append (meet_types (t .arg_types [i ], s .arg_types [i ]))
806
+ return t .copy_modified (arg_types = arg_types , ret_type = meet_types (t .ret_type , s .ret_type ))
807
+
808
+
781
809
def dedup (old : List [T ]) -> List [T ]:
782
810
new = [] # type: List[T]
783
811
for x in old :
0 commit comments