Skip to content

Commit bf20f49

Browse files
committed
Fixes to union simplification, isinstance and more (#3025)
The main change is that unions containing Any are no longer simplified to just Any. Also union simplification now has a deterministic result unlike previously, when result could depend on the order of items in a union (this is true modulo remaining bugs). This required changes in various other places to keep the existing semantics, and resulted in some fixes to existing test cases. I also had to fix some tangentially related minor bugs that were triggered by the other changes. We generally don't have fully constructed TypeInfos so we can't do proper union simplification during semantic analysis. Just implement simple-minded simplification that deals with the cases we care about. Fixes #2978. Fixes #1914.
1 parent 428a536 commit bf20f49

20 files changed

+696
-126
lines changed

mypy/checker.py

+39-38
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType,
2929
Instance, NoneTyp, ErrorType, strip_type, TypeType,
3030
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarDef,
31-
true_only, false_only, function_type, is_named_instance
31+
true_only, false_only, function_type, is_named_instance, union_items
3232
)
3333
from mypy.sametypes import is_same_type, is_same_types
3434
from mypy.messages import MessageBuilder
@@ -812,44 +812,45 @@ def check_overlapping_op_methods(self,
812812
# of x in __radd__ would not be A, the methods could be
813813
# non-overlapping.
814814

815-
if isinstance(forward_type, CallableType):
816-
# TODO check argument kinds
817-
if len(forward_type.arg_types) < 1:
818-
# Not a valid operator method -- can't succeed anyway.
819-
return
815+
for forward_item in union_items(forward_type):
816+
if isinstance(forward_item, CallableType):
817+
# TODO check argument kinds
818+
if len(forward_item.arg_types) < 1:
819+
# Not a valid operator method -- can't succeed anyway.
820+
return
820821

821-
# Construct normalized function signatures corresponding to the
822-
# operator methods. The first argument is the left operand and the
823-
# second operand is the right argument -- we switch the order of
824-
# the arguments of the reverse method.
825-
forward_tweaked = CallableType(
826-
[forward_base, forward_type.arg_types[0]],
827-
[nodes.ARG_POS] * 2,
828-
[None] * 2,
829-
forward_type.ret_type,
830-
forward_type.fallback,
831-
name=forward_type.name)
832-
reverse_args = reverse_type.arg_types
833-
reverse_tweaked = CallableType(
834-
[reverse_args[1], reverse_args[0]],
835-
[nodes.ARG_POS] * 2,
836-
[None] * 2,
837-
reverse_type.ret_type,
838-
fallback=self.named_type('builtins.function'),
839-
name=reverse_type.name)
840-
841-
if is_unsafe_overlapping_signatures(forward_tweaked,
842-
reverse_tweaked):
843-
self.msg.operator_method_signatures_overlap(
844-
reverse_class.name(), reverse_name,
845-
forward_base.type.name(), forward_name, context)
846-
elif isinstance(forward_type, Overloaded):
847-
for item in forward_type.items():
848-
self.check_overlapping_op_methods(
849-
reverse_type, reverse_name, reverse_class,
850-
item, forward_name, forward_base, context)
851-
elif not isinstance(forward_type, AnyType):
852-
self.msg.forward_operator_not_callable(forward_name, context)
822+
# Construct normalized function signatures corresponding to the
823+
# operator methods. The first argument is the left operand and the
824+
# second operand is the right argument -- we switch the order of
825+
# the arguments of the reverse method.
826+
forward_tweaked = CallableType(
827+
[forward_base, forward_item.arg_types[0]],
828+
[nodes.ARG_POS] * 2,
829+
[None] * 2,
830+
forward_item.ret_type,
831+
forward_item.fallback,
832+
name=forward_item.name)
833+
reverse_args = reverse_type.arg_types
834+
reverse_tweaked = CallableType(
835+
[reverse_args[1], reverse_args[0]],
836+
[nodes.ARG_POS] * 2,
837+
[None] * 2,
838+
reverse_type.ret_type,
839+
fallback=self.named_type('builtins.function'),
840+
name=reverse_type.name)
841+
842+
if is_unsafe_overlapping_signatures(forward_tweaked,
843+
reverse_tweaked):
844+
self.msg.operator_method_signatures_overlap(
845+
reverse_class.name(), reverse_name,
846+
forward_base.type.name(), forward_name, context)
847+
elif isinstance(forward_item, Overloaded):
848+
for item in forward_item.items():
849+
self.check_overlapping_op_methods(
850+
reverse_type, reverse_name, reverse_class,
851+
item, forward_name, forward_base, context)
852+
elif not isinstance(forward_item, AnyType):
853+
self.msg.forward_operator_not_callable(forward_name, context)
853854

854855
def check_inplace_operator_method(self, defn: FuncBase) -> None:
855856
"""Check an inplace operator method such as __iadd__.

mypy/checkexpr.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from mypy import messages
3333
from mypy.infer import infer_type_arguments, infer_function_type_arguments
3434
from mypy import join
35-
from mypy.meet import meet_simple
35+
from mypy.meet import narrow_declared_type
3636
from mypy.maptype import map_instance_to_supertype
3737
from mypy.subtypes import is_subtype, is_equivalent
3838
from mypy import applytype
@@ -2221,7 +2221,7 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
22212221
if expr.literal >= LITERAL_TYPE:
22222222
restriction = self.chk.binder.get(expr)
22232223
if restriction:
2224-
ans = meet_simple(known_type, restriction)
2224+
ans = narrow_declared_type(known_type, restriction)
22252225
return ans
22262226
return known_type
22272227

mypy/erasetype.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type:
7676
return t.fallback.accept(self)
7777

7878
def visit_union_type(self, t: UnionType) -> Type:
79-
return AnyType() # XXX: return underlying type if only one?
79+
erased_items = [erase_type(item) for item in t.items]
80+
return UnionType.make_simplified_union(erased_items)
8081

8182
def visit_type_type(self, t: TypeType) -> Type:
8283
return TypeType(t.item.accept(self), line=t.line)

mypy/join.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
UninhabitedType, TypeType, true_or_false
1111
)
1212
from mypy.maptype import map_instance_to_supertype
13-
from mypy.subtypes import is_subtype, is_equivalent, is_subtype_ignoring_tvars
13+
from mypy.subtypes import is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype
1414

1515
from mypy import experiments
1616

@@ -29,10 +29,10 @@ def join_simple(declaration: Type, s: Type, t: Type) -> Type:
2929
if isinstance(s, ErasedType):
3030
return t
3131

32-
if is_subtype(s, t):
32+
if is_proper_subtype(s, t):
3333
return t
3434

35-
if is_subtype(t, s):
35+
if is_proper_subtype(t, s):
3636
return s
3737

3838
if isinstance(declaration, UnionType):

mypy/meet.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,26 @@ def meet_types(s: Type, t: Type) -> Type:
2525
return t.accept(TypeMeetVisitor(s))
2626

2727

28-
def meet_simple(s: Type, t: Type, default_right: bool = True) -> Type:
29-
if s == t:
30-
return s
31-
if isinstance(s, UnionType):
32-
return UnionType.make_simplified_union([meet_types(x, t) for x in s.items])
33-
elif not is_overlapping_types(s, t, use_promotions=True):
28+
def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
29+
"""Return the declared type narrowed down to another type."""
30+
if declared == narrowed:
31+
return declared
32+
if isinstance(declared, UnionType):
33+
return UnionType.make_simplified_union([narrow_declared_type(x, narrowed)
34+
for x in declared.items])
35+
elif not is_overlapping_types(declared, narrowed, use_promotions=True):
3436
if experiments.STRICT_OPTIONAL:
3537
return UninhabitedType()
3638
else:
3739
return NoneTyp()
38-
else:
39-
if default_right:
40-
return t
41-
else:
42-
return s
40+
elif isinstance(narrowed, UnionType):
41+
return UnionType.make_simplified_union([narrow_declared_type(declared, x)
42+
for x in narrowed.items])
43+
elif isinstance(narrowed, AnyType):
44+
return narrowed
45+
elif isinstance(declared, (Instance, TupleType)):
46+
return meet_types(declared, narrowed)
47+
return narrowed
4348

4449

4550
def is_overlapping_types(t: Type, s: Type, use_promotions: bool = False) -> bool:
@@ -249,6 +254,10 @@ def visit_tuple_type(self, t: TupleType) -> Type:
249254
elif (isinstance(self.s, Instance) and
250255
self.s.type.fullname() == 'builtins.tuple' and self.s.args):
251256
return t.copy_modified(items=[meet_types(it, self.s.args[0]) for it in t.items])
257+
elif (isinstance(self.s, Instance) and t.fallback.type == self.s.type):
258+
# Uh oh, a broken named tuple type (https://github.com/python/mypy/issues/3016).
259+
# Do something reasonable until that bug is fixed.
260+
return t
252261
else:
253262
return self.default(self.s)
254263

0 commit comments

Comments
 (0)