Skip to content

Commit 3df1b06

Browse files
committed
Adds support for basic union math with overloads
This commit adds support for very basic and simple union math when calling overloaded functions, resolving python#4576. As a side effect, this change also fixes a bug where calling overloaded functions can sometimes silently infer a return type of 'Any' and slightly modifies the semantics of how mypy handles overlaps in overloaded functions. Details on specific changes made: 1. The new algorithm works by modifying checkexpr.overload_call_targets to return all possible matches, rather then just one. We start by trying the first matching signature. If there was some error, we (conservatively) attempt to union all of the matching signatures together and repeat the typechecking process. If it doesn't seem like it's possible to combine the matching signatures in a sound way, we end and just output the errors we obtained from typechecking the first match. The "signature-unioning" code is currently deliberately very conservative. I figured it was better to start small and attempt to handle only basic cases like python#1943 and relax the restrictions later as needed. For more details on this algorithm, see the comments in checkexpr.union_overload_matches. 2. This change incidentally resolves any bugs related to how calling an overloaded function can sometimes silently infer a return type of Any. Previously, if a function call caused an overload to be less precise then a previous one, we gave up and returned a silent Any. This change removes this case altogether and only infers Any if either (a) the caller arguments explicitly contains Any or (b) if there was some error. For example, see python#3295 and python#1322 -- I believe this pull request touches on and maybe resolves (??) those two issues. 3. As a result, this caused a few errors in mypy where code was relying on this "silently infer Any" behavior -- see the changes in checker.py and semanal.py. Both files were using expressions of the form `zip(*iterable)`, which ended up having a type of `Any` under the old algorithm. The new algorithm will instead infer `Iterable[Tuple[Any, ...]]` which actually matches the stubs in typeshed. 4. Many of the attrs tests were also relying on the same behavior. Specifically, these changes cause the attr stubs in `test-data/unit/lib-stub` to no longer work. It seemed that expressions of the form `a = attr.ib()` were evaluated to 'Any' not because of a stub, but because of the 'silent Any' bug. I couldn't find a clean way of fixing the stubs to infer the correct thing under this new behavior, so just gave up and removed the overloads altogether. I think this is fine though -- it seems like the attrs plugin infers the correct type for us anyways, regardless of what the stubs say. If this pull request is accepted, I plan on submitting a similar pull request to the stubs in typeshed. 4. This pull request also probably touches on python/typing#253. We still require the overloads to be written from the most narrow to general and disallow overlapping signatures. However, if a *call* now causes overlaps, we try the "union" algorithm described above and default to selecting the first matching overload instead of giving up.
1 parent 18a77cf commit 3df1b06

File tree

9 files changed

+278
-85
lines changed

9 files changed

+278
-85
lines changed

mypy/checker.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1801,8 +1801,8 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E
18011801
expr = expr.expr
18021802
types, declared_types = zip(*items)
18031803
self.binder.assign_type(expr,
1804-
UnionType.make_simplified_union(types),
1805-
UnionType.make_simplified_union(declared_types),
1804+
UnionType.make_simplified_union(list(types)),
1805+
UnionType.make_simplified_union(list(declared_types)),
18061806
False)
18071807
for union, lv in zip(union_types, self.flatten_lvalues(lvalues)):
18081808
# Properly store the inferred types.

mypy/checkexpr.py

+151-58
Original file line numberDiff line numberDiff line change
@@ -611,10 +611,63 @@ def check_call(self, callee: Type, args: List[Expression],
611611
arg_types = self.infer_arg_types_in_context(None, args)
612612
self.msg.enable_errors()
613613

614-
target = self.overload_call_target(arg_types, arg_kinds, arg_names,
615-
callee, context,
616-
messages=arg_messages)
617-
return self.check_call(target, args, arg_kinds, context, arg_names,
614+
overload_messages = arg_messages.copy()
615+
targets = self.overload_call_targets(arg_types, arg_kinds, arg_names,
616+
callee, context,
617+
messages=overload_messages)
618+
619+
# If there are multiple targets, that means that there were
620+
# either multiple possible matches or the types were overlapping in some
621+
# way. In either case, we default to picking the first match and
622+
# see what happens if we try using it.
623+
#
624+
# Note: if we pass in an argument that inherits from two overloaded
625+
# types, we default to picking the first match. For example:
626+
#
627+
# class A: pass
628+
# class B: pass
629+
# class C(A, B): pass
630+
#
631+
# @overload
632+
# def f(x: A) -> int: ...
633+
# @overload
634+
# def f(x: B) -> str: ...
635+
# def f(x): ...
636+
#
637+
# reveal_type(f(C())) # Will be 'int', not 'Union[int, str]'
638+
#
639+
# It's unclear if this is really the best thing to do, but multiple
640+
# inheritance is rare. See the docstring of mypy.meet.is_overlapping_types
641+
# for more about this.
642+
643+
original_output = self.check_call(targets[0], args, arg_kinds, context, arg_names,
644+
arg_messages=overload_messages,
645+
callable_name=callable_name,
646+
object_type=object_type)
647+
648+
if not overload_messages.is_errors() or len(targets) == 1:
649+
# If there were no errors or if there was only one match, we can end now.
650+
#
651+
# Note that if we have only one target, there's nothing else we
652+
# can try doing. In that case, we just give up and return early
653+
# and skip the below steps.
654+
arg_messages.add_errors(overload_messages)
655+
return original_output
656+
657+
# Otherwise, we attempt to synthesize together a new callable by combining
658+
# together the different matches by union-ing together their arguments
659+
# and return type.
660+
661+
targets = cast(List[CallableType], targets)
662+
unioned_callable = self.union_overload_matches(targets)
663+
if unioned_callable is None:
664+
# If it was not possible to actually combine together the
665+
# callables in a sound way, we give up and return the original
666+
# error message.
667+
arg_messages.add_errors(overload_messages)
668+
return original_output
669+
670+
return self.check_call(unioned_callable, args, arg_kinds, context, arg_names,
618671
arg_messages=arg_messages,
619672
callable_name=callable_name,
620673
object_type=object_type)
@@ -1089,83 +1142,123 @@ def check_arg(self, caller_type: Type, original_caller_type: Type,
10891142
(callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and
10901143
# ...except for classmethod first argument
10911144
not caller_type.is_classmethod_class):
1092-
self.msg.concrete_only_call(callee_type, context)
1145+
messages.concrete_only_call(callee_type, context)
10931146
elif not is_subtype(caller_type, callee_type):
10941147
if self.chk.should_suppress_optional_error([caller_type, callee_type]):
10951148
return
10961149
messages.incompatible_argument(n, m, callee, original_caller_type,
10971150
caller_kind, context)
10981151
if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and
10991152
isinstance(callee_type, Instance) and callee_type.type.is_protocol):
1100-
self.msg.report_protocol_problems(original_caller_type, callee_type, context)
1153+
messages.report_protocol_problems(original_caller_type, callee_type, context)
11011154
if (isinstance(callee_type, CallableType) and
11021155
isinstance(original_caller_type, Instance)):
11031156
call = find_member('__call__', original_caller_type, original_caller_type)
11041157
if call:
1105-
self.msg.note_call(original_caller_type, call, context)
1106-
1107-
def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int],
1108-
arg_names: Optional[Sequence[Optional[str]]],
1109-
overload: Overloaded, context: Context,
1110-
messages: Optional[MessageBuilder] = None) -> Type:
1111-
"""Infer the correct overload item to call with given argument types.
1112-
1113-
The return value may be CallableType or AnyType (if an unique item
1114-
could not be determined).
1158+
messages.note_call(original_caller_type, call, context)
1159+
1160+
def overload_call_targets(self, arg_types: List[Type], arg_kinds: List[int],
1161+
arg_names: Optional[Sequence[Optional[str]]],
1162+
overload: Overloaded, context: Context,
1163+
messages: Optional[MessageBuilder] = None) -> Sequence[Type]:
1164+
"""Infer all possible overload targets to call with given argument types.
1165+
The list is guaranteed be one of the following:
1166+
1167+
1. A List[CallableType] of length 1 if we were able to find an
1168+
unambiguous best match.
1169+
2. A List[AnyType] of length 1 if we were unable to find any match
1170+
or discovered the match was ambiguous due to conflicting Any types.
1171+
3. A List[CallableType] of length 2 or more if there were multiple
1172+
plausible matches. The matches are returned in the order they
1173+
were defined.
11151174
"""
11161175
messages = messages or self.msg
1117-
# TODO: For overlapping signatures we should try to get a more precise
1118-
# result than 'Any'.
11191176
match = [] # type: List[CallableType]
11201177
best_match = 0
11211178
for typ in overload.items():
11221179
similarity = self.erased_signature_similarity(arg_types, arg_kinds, arg_names,
11231180
typ, context=context)
11241181
if similarity > 0 and similarity >= best_match:
1125-
if (match and not is_same_type(match[-1].ret_type,
1126-
typ.ret_type) and
1127-
(not mypy.checker.is_more_precise_signature(match[-1], typ)
1128-
or (any(isinstance(arg, AnyType) for arg in arg_types)
1129-
and any_arg_causes_overload_ambiguity(
1130-
match + [typ], arg_types, arg_kinds, arg_names)))):
1131-
# Ambiguous return type. Either the function overload is
1132-
# overlapping (which we don't handle very well here) or the
1133-
# caller has provided some Any argument types; in either
1134-
# case we'll fall back to Any. It's okay to use Any types
1135-
# in calls.
1136-
#
1137-
# Overlapping overload items are generally fine if the
1138-
# overlapping is only possible when there is multiple
1139-
# inheritance, as this is rare. See docstring of
1140-
# mypy.meet.is_overlapping_types for more about this.
1141-
#
1142-
# Note that there is no ambiguity if the items are
1143-
# covariant in both argument types and return types with
1144-
# respect to type precision. We'll pick the best/closest
1145-
# match.
1146-
#
1147-
# TODO: Consider returning a union type instead if the
1148-
# overlapping is NOT due to Any types?
1149-
return AnyType(TypeOfAny.special_form)
1150-
else:
1151-
match.append(typ)
1182+
if (match and not is_same_type(match[-1].ret_type, typ.ret_type)
1183+
and any(isinstance(arg, AnyType) for arg in arg_types)
1184+
and any_arg_causes_overload_ambiguity(
1185+
match + [typ], arg_types, arg_kinds, arg_names)):
1186+
# Ambiguous return type. The caller has provided some
1187+
# Any argument types (which are okay to use in calls),
1188+
# so we fall back to returning 'Any'.
1189+
return [AnyType(TypeOfAny.special_form)]
1190+
match.append(typ)
11521191
best_match = max(best_match, similarity)
1153-
if not match:
1192+
1193+
if len(match) == 0:
11541194
if not self.chk.should_suppress_optional_error(arg_types):
11551195
messages.no_variant_matches_arguments(overload, arg_types, context)
1156-
return AnyType(TypeOfAny.from_error)
1196+
return [AnyType(TypeOfAny.from_error)]
1197+
elif len(match) == 1:
1198+
return match
11571199
else:
1158-
if len(match) == 1:
1159-
return match[0]
1160-
else:
1161-
# More than one signature matches. Pick the first *non-erased*
1162-
# matching signature, or default to the first one if none
1163-
# match.
1164-
for m in match:
1165-
if self.match_signature_types(arg_types, arg_kinds, arg_names, m,
1166-
context=context):
1167-
return m
1168-
return match[0]
1200+
# More than one signature matches or the signatures are
1201+
# overlapping. In either case, we return all of the matching
1202+
# signatures and let the caller decide what to do with them.
1203+
out = [m for m in match if self.match_signature_types(
1204+
arg_types, arg_kinds, arg_names, m, context=context)]
1205+
return out if len(out) >= 1 else match
1206+
1207+
def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]:
1208+
"""Accepts a list of overload signatures and attempts to combine them together into a
1209+
new CallableType consisting of the union of all of the given arguments and return types.
1210+
1211+
Returns None if it is not possible to combine the different callables together in a
1212+
sound manner."""
1213+
1214+
new_args: List[List[Type]] = [[] for _ in range(len(callables[0].arg_types))]
1215+
1216+
expected_names = callables[0].arg_names
1217+
expected_kinds = callables[0].arg_kinds
1218+
1219+
for target in callables:
1220+
if target.arg_names != expected_names or target.arg_kinds != expected_kinds:
1221+
# We conservatively end if the overloads do not have the exact same signature.
1222+
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
1223+
return None
1224+
1225+
for i, arg in enumerate(target.arg_types):
1226+
new_args[i].append(arg)
1227+
1228+
union_count = 0
1229+
final_args = []
1230+
for args in new_args:
1231+
new_type = UnionType.make_simplified_union(args)
1232+
union_count += 1 if isinstance(new_type, UnionType) else 0
1233+
final_args.append(new_type)
1234+
1235+
# TODO: Modify this check to be less conservative.
1236+
#
1237+
# Currently, we permit only one union union in the arguments because if we allow
1238+
# multiple, we can't always guarantee the synthesized callable will be correct.
1239+
#
1240+
# For example, suppose we had the following two overloads:
1241+
#
1242+
# @overload
1243+
# def f(x: A, y: B) -> None: ...
1244+
# @overload
1245+
# def f(x: B, y: A) -> None: ...
1246+
#
1247+
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
1248+
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
1249+
# be rejected.
1250+
#
1251+
# However, that means we'll also give up if the original overloads contained
1252+
# any unions. This is likely unnecessary -- we only really need to give up if
1253+
# there are more then one *synthesized* union arguments.
1254+
if union_count >= 2:
1255+
return None
1256+
1257+
return callables[0].copy_modified(
1258+
arg_types=final_args,
1259+
ret_type=UnionType.make_simplified_union([t.ret_type for t in callables]),
1260+
implicit=True,
1261+
from_overloads=True)
11691262

11701263
def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int],
11711264
arg_names: Optional[Sequence[Optional[str]]],

mypy/messages.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -629,8 +629,19 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
629629
expected_type = callee.arg_types[m - 1]
630630
except IndexError: # Varargs callees
631631
expected_type = callee.arg_types[-1]
632+
632633
arg_type_str, expected_type_str = self.format_distinctly(
633634
arg_type, expected_type, bare=True)
635+
expected_type_str = self.quote_type_string(expected_type_str)
636+
637+
if callee.from_overloads and isinstance(expected_type, UnionType):
638+
expected_formatted = []
639+
for e in expected_type.items:
640+
type_str = self.format_distinctly(arg_type, e, bare=True)[1]
641+
expected_formatted.append(self.quote_type_string(type_str))
642+
expected_type_str = 'one of {} based on available overloads'.format(
643+
', '.join(expected_formatted))
644+
634645
if arg_kind == ARG_STAR:
635646
arg_type_str = '*' + arg_type_str
636647
elif arg_kind == ARG_STAR2:
@@ -645,8 +656,7 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
645656
arg_label = '"{}"'.format(arg_name)
646657

647658
msg = 'Argument {} {}has incompatible type {}; expected {}'.format(
648-
arg_label, target, self.quote_type_string(arg_type_str),
649-
self.quote_type_string(expected_type_str))
659+
arg_label, target, self.quote_type_string(arg_type_str), expected_type_str)
650660
if isinstance(arg_type, Instance) and isinstance(expected_type, Instance):
651661
notes = append_invariance_notes(notes, arg_type, expected_type)
652662
self.fail(msg, context)

mypy/semanal.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2204,7 +2204,7 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression,
22042204
# about the length mismatch in type-checking.
22052205
elementwise_assignments = zip(rval.items, *[v.items for v in seq_lvals])
22062206
for rv, *lvs in elementwise_assignments:
2207-
self.process_module_assignment(lvs, rv, ctx)
2207+
self.process_module_assignment(list(lvs), rv, ctx)
22082208
elif isinstance(rval, RefExpr):
22092209
rnode = self.lookup_type_node(rval)
22102210
if rnode and rnode.kind == MODULE_REF:

mypy/types.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,8 @@ class CallableType(FunctionLike):
661661
special_sig = None # type: Optional[str]
662662
# Was this callable generated by analyzing Type[...] instantiation?
663663
from_type_type = False # type: bool
664+
# Was this callable generated by synthesizing multiple overloads?
665+
from_overloads = False # type: bool
664666

665667
bound_args = None # type: List[Optional[Type]]
666668

@@ -680,6 +682,7 @@ def __init__(self,
680682
is_classmethod_class: bool = False,
681683
special_sig: Optional[str] = None,
682684
from_type_type: bool = False,
685+
from_overloads: bool = False,
683686
bound_args: Optional[List[Optional[Type]]] = None,
684687
) -> None:
685688
assert len(arg_types) == len(arg_kinds) == len(arg_names)
@@ -704,6 +707,7 @@ def __init__(self,
704707
self.is_classmethod_class = is_classmethod_class
705708
self.special_sig = special_sig
706709
self.from_type_type = from_type_type
710+
self.from_overloads = from_overloads
707711
self.bound_args = bound_args or []
708712
super().__init__(line, column)
709713

@@ -719,8 +723,10 @@ def copy_modified(self,
719723
line: int = _dummy,
720724
column: int = _dummy,
721725
is_ellipsis_args: bool = _dummy,
726+
implicit: bool = _dummy,
722727
special_sig: Optional[str] = _dummy,
723728
from_type_type: bool = _dummy,
729+
from_overloads: bool = _dummy,
724730
bound_args: List[Optional[Type]] = _dummy) -> 'CallableType':
725731
return CallableType(
726732
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
@@ -735,10 +741,11 @@ def copy_modified(self,
735741
column=column if column is not _dummy else self.column,
736742
is_ellipsis_args=(
737743
is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args),
738-
implicit=self.implicit,
744+
implicit=implicit if implicit is not _dummy else self.implicit,
739745
is_classmethod_class=self.is_classmethod_class,
740746
special_sig=special_sig if special_sig is not _dummy else self.special_sig,
741747
from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
748+
from_overloads=from_overloads if from_overloads is not _dummy else self.from_overloads,
742749
bound_args=bound_args if bound_args is not _dummy else self.bound_args,
743750
)
744751

@@ -890,6 +897,7 @@ def serialize(self) -> JsonDict:
890897
'is_ellipsis_args': self.is_ellipsis_args,
891898
'implicit': self.implicit,
892899
'is_classmethod_class': self.is_classmethod_class,
900+
'from_overloads': self.from_overloads,
893901
'bound_args': [(None if t is None else t.serialize())
894902
for t in self.bound_args],
895903
}
@@ -908,6 +916,7 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
908916
is_ellipsis_args=data['is_ellipsis_args'],
909917
implicit=data['implicit'],
910918
is_classmethod_class=data['is_classmethod_class'],
919+
from_overloads=data['from_overloads'],
911920
bound_args=[(None if t is None else deserialize_type(t))
912921
for t in data['bound_args']],
913922
)

0 commit comments

Comments
 (0)