Skip to content

Commit 7c37519

Browse files
rwbartonJukkaL
authored andcommitted
Generate correct constraints for Unions (#1408)
* Simplify unions when expanding types * Generate correct constraints for Unions The old code always generated a conjunction, where we sometimes want a disjunction. Since we can't represent or solve a disjunction of constraints, need to approximate. In the presence of Unions, the constraints inferred from (template, actual, direction) are not necessarily the reverses of the constraints inferred from (template, actual, neg_op(direction)). The visit_instance and visit_callable cases were updated accordingly. Test testGenericFunctionSubtypingWithUnions fails without this change. Fixes #1458 and the remaining part of #1241.
1 parent 55e13ff commit 7c37519

File tree

4 files changed

+204
-25
lines changed

4 files changed

+204
-25
lines changed

mypy/constraints.py

+90-24
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy import nodes
1212
import mypy.subtypes
13+
from mypy.erasetype import erase_typevars
1314

1415

1516
SUBTYPE_OF = 0 # type: int
@@ -119,9 +120,82 @@ def infer_constraints(template: Type, actual: Type,
119120
The constraints are represented as Constraint objects.
120121
"""
121122

123+
# If the template is simply a type variable, emit a Constraint directly.
124+
# We need to handle this case before handling Unions for two reasons:
125+
# 1. "T <: Union[U1, U2]" is not equivalent to "T <: U1 or T <: U2",
126+
# because T can itself be a union (notably, Union[U1, U2] itself).
127+
# 2. "T :> Union[U1, U2]" is logically equivalent to "T :> U1 and
128+
# T :> U2", but they are not equivalent to the constraint solver,
129+
# which never introduces new Union types (it uses join() instead).
130+
if isinstance(template, TypeVarType):
131+
return [Constraint(template.id, direction, actual)]
132+
133+
# Now handle the case of either template or actual being a Union.
134+
# For a Union to be a subtype of another type, every item of the Union
135+
# must be a subtype of that type, so concatenate the constraints.
136+
if direction == SUBTYPE_OF and isinstance(template, UnionType):
137+
res = []
138+
for t_item in template.items:
139+
res.extend(infer_constraints(t_item, actual, direction))
140+
return res
141+
if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
142+
res = []
143+
for a_item in actual.items:
144+
res.extend(infer_constraints(template, a_item, direction))
145+
return res
146+
147+
# Now the potential subtype is known not to be a Union or a type
148+
# variable that we are solving for. In that case, for a Union to
149+
# be a supertype of the potential subtype, some item of the Union
150+
# must be a supertype of it.
151+
if direction == SUBTYPE_OF and isinstance(actual, UnionType):
152+
return any_constraints(
153+
[infer_constraints_if_possible(template, a_item, direction)
154+
for a_item in actual.items])
155+
if direction == SUPERTYPE_OF and isinstance(template, UnionType):
156+
return any_constraints(
157+
[infer_constraints_if_possible(t_item, actual, direction)
158+
for t_item in template.items])
159+
160+
# Remaining cases are handled by ConstraintBuilderVisitor.
122161
return template.accept(ConstraintBuilderVisitor(actual, direction))
123162

124163

164+
def infer_constraints_if_possible(template: Type, actual: Type,
165+
direction: int) -> Optional[List[Constraint]]:
166+
"""Like infer_constraints, but return None if the input relation is
167+
known to be unsatisfiable, for example if template=List[T] and actual=int.
168+
(In this case infer_constraints would return [], just like it would for
169+
an automatically satisfied relation like template=List[T] and actual=object.)
170+
"""
171+
if (direction == SUBTYPE_OF and
172+
not mypy.subtypes.is_subtype(erase_typevars(template), actual)):
173+
return None
174+
if (direction == SUPERTYPE_OF and
175+
not mypy.subtypes.is_subtype(actual, erase_typevars(template))):
176+
return None
177+
return infer_constraints(template, actual, direction)
178+
179+
180+
def any_constraints(options: List[Optional[List[Constraint]]]) -> List[Constraint]:
181+
"""Deduce what we can from a collection of constraint lists given that
182+
at least one of the lists must be satisfied. A None element in the
183+
list of options represents an unsatisfiable constraint and is ignored.
184+
"""
185+
valid_options = [option for option in options if option is not None]
186+
if len(valid_options) == 1:
187+
return valid_options[0]
188+
# Otherwise, there are either no valid options or multiple valid options.
189+
# Give up and deduce nothing.
190+
return []
191+
192+
# TODO: In the latter case, it could happen that every valid
193+
# option requires the same constraint on the same variable. Then
194+
# we could include that that constraint in the result. Or more
195+
# generally, if a given (variable, direction) pair appears in
196+
# every option, combine the bounds with meet/join.
197+
198+
125199
class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
126200
"""Visitor class for inferring type constraints."""
127201

@@ -163,7 +237,8 @@ def visit_partial_type(self, template: PartialType) -> List[Constraint]:
163237
# Non-trivial leaf type
164238

165239
def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
166-
return [Constraint(template.id, self.direction, self.actual)]
240+
assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor"
241+
" (should have been handled in infer_constraints)")
167242

168243
# Non-leaf types
169244

@@ -177,23 +252,23 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
177252
mapped = map_instance_to_supertype(template, instance.type)
178253
for i in range(len(instance.args)):
179254
# The constraints for generic type parameters are
180-
# invariant. Include the default constraint and its
181-
# negation to achieve the effect.
182-
cb = infer_constraints(mapped.args[i], instance.args[i],
183-
self.direction)
184-
res.extend(cb)
185-
res.extend(negate_constraints(cb))
255+
# invariant. Include constraints from both directions
256+
# to achieve the effect.
257+
res.extend(infer_constraints(
258+
mapped.args[i], instance.args[i], self.direction))
259+
res.extend(infer_constraints(
260+
mapped.args[i], instance.args[i], neg_op(self.direction)))
186261
return res
187262
elif (self.direction == SUPERTYPE_OF and
188263
instance.type.has_base(template.type.fullname())):
189264
mapped = map_instance_to_supertype(instance, template.type)
190265
for j in range(len(template.args)):
191266
# The constraints for generic type parameters are
192267
# invariant.
193-
cb = infer_constraints(template.args[j], mapped.args[j],
194-
self.direction)
195-
res.extend(cb)
196-
res.extend(negate_constraints(cb))
268+
res.extend(infer_constraints(
269+
template.args[j], mapped.args[j], self.direction))
270+
res.extend(infer_constraints(
271+
template.args[j], mapped.args[j], neg_op(self.direction)))
197272
return res
198273
if isinstance(actual, AnyType):
199274
# IDEA: Include both ways, i.e. add negation as well?
@@ -222,8 +297,8 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
222297
if not template.is_ellipsis_args:
223298
# The lengths should match, but don't crash (it will error elsewhere).
224299
for t, a in zip(template.arg_types, cactual.arg_types):
225-
# Negate constraints due function argument type contravariance.
226-
res.extend(negate_constraints(infer_constraints(t, a, self.direction)))
300+
# Negate direction due to function argument type contravariance.
301+
res.extend(infer_constraints(t, a, neg_op(self.direction)))
227302
res.extend(infer_constraints(template.ret_type, cactual.ret_type,
228303
self.direction))
229304
return res
@@ -264,10 +339,8 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
264339
return []
265340

266341
def visit_union_type(self, template: UnionType) -> List[Constraint]:
267-
res = [] # type: List[Constraint]
268-
for item in template.items:
269-
res.extend(infer_constraints(item, self.actual, self.direction))
270-
return res
342+
assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
343+
" (should have been handled in infer_constraints)")
271344

272345
def infer_against_any(self, types: List[Type]) -> List[Constraint]:
273346
res = [] # type: List[Constraint]
@@ -282,13 +355,6 @@ def visit_overloaded(self, type: Overloaded) -> List[Constraint]:
282355
return res
283356

284357

285-
def negate_constraints(constraints: List[Constraint]) -> List[Constraint]:
286-
res = [] # type: List[Constraint]
287-
for c in constraints:
288-
res.append(Constraint(c.type_var, neg_op(c.op), c.target))
289-
return res
290-
291-
292358
def neg_op(op: int) -> int:
293359
"""Map SubtypeOf to SupertypeOf and vice versa."""
294360

mypy/expandtype.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:
8787
return TupleType(self.expand_types(t.items), t.fallback, t.line)
8888

8989
def visit_union_type(self, t: UnionType) -> Type:
90-
return UnionType(self.expand_types(t.items), t.line)
90+
# After substituting for type variables in t.items,
91+
# some of the resulting types might be subtypes of others.
92+
return UnionType.make_simplified_union(self.expand_types(t.items), t.line)
9193

9294
def visit_partial_type(self, t: PartialType) -> Type:
9395
return t

mypy/test/data/check-inference.test

+92
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,98 @@ l = lb # E: Incompatible types in assignment (expression has type List[bool], va
690690
[builtins fixtures/for.py]
691691

692692

693+
-- Generic function inference with unions
694+
-- --------------------------------------
695+
696+
697+
[case testUnionInference]
698+
from typing import TypeVar, Union, List
699+
T = TypeVar('T')
700+
U = TypeVar('U')
701+
def f(x: Union[T, int], y: T) -> T: pass
702+
f(1, 'a')() # E: "str" not callable
703+
f('a', 1)() # E: "object" not callable
704+
f('a', 'a')() # E: "str" not callable
705+
f(1, 1)() # E: "int" not callable
706+
707+
def g(x: Union[T, List[T]]) -> List[T]: pass
708+
def h(x: List[str]) -> None: pass
709+
g('a')() # E: List[str] not callable
710+
711+
# The next line is a case where there are multiple ways to satisfy a constraint
712+
# involving a Union. Either T = List[str] or T = str would turn out to be valid,
713+
# but mypy doesn't know how to branch on these two options (and potentially have
714+
# to backtrack later) and defaults to T = None. The result is an awkward error
715+
# message. Either a better error message, or simply accepting the call, would be
716+
# preferable here.
717+
g(['a']) # E: Argument 1 to "g" has incompatible type List[str]; expected List[None]
718+
719+
h(g(['a']))
720+
721+
def i(x: Union[List[T], List[U]], y: List[T], z: List[U]) -> None: pass
722+
a = [1]
723+
b = ['b']
724+
i(a, a, b)
725+
i(b, a, b)
726+
i(a, b, b) # E: Argument 1 to "i" has incompatible type List[int]; expected List[str]
727+
[builtins fixtures/list.py]
728+
729+
730+
[case testUnionInferenceWithTypeVarValues]
731+
from typing import TypeVar, Union
732+
AnyStr = TypeVar('AnyStr', bytes, str)
733+
def f(x: Union[AnyStr, int], *a: AnyStr) -> None: pass
734+
f('foo')
735+
f('foo', 'bar')
736+
f('foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object"
737+
f(1)
738+
f(1, 'foo')
739+
f(1, 'foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object"
740+
[builtins fixtures/primitives.py]
741+
742+
743+
[case testUnionTwoPassInference-skip]
744+
from typing import TypeVar, Union, List
745+
T = TypeVar('T')
746+
U = TypeVar('U')
747+
def j(x: Union[List[T], List[U]], y: List[T]) -> List[U]: pass
748+
749+
a = [1]
750+
b = ['b']
751+
# We could infer: Since List[str] <: List[T], we must have T = str.
752+
# Then since List[int] <: Union[List[str], List[U]], and List[int] is
753+
# not a subtype of List[str], we must have U = int.
754+
# This is not currently implemented.
755+
j(a, b)
756+
[builtins fixtures/list.py]
757+
758+
759+
[case testUnionContext]
760+
from typing import TypeVar, Union, List
761+
T = TypeVar('T')
762+
def f() -> List[T]: pass
763+
d1 = f() # type: Union[List[int], str]
764+
d2 = f() # type: Union[int, str] # E: Incompatible types in assignment (expression has type List[None], variable has type "Union[int, str]")
765+
def g(x: T) -> List[T]: pass
766+
d3 = g(1) # type: Union[List[int], List[str]]
767+
[builtins fixtures/list.py]
768+
769+
770+
[case testGenericFunctionSubtypingWithUnions]
771+
from typing import TypeVar, Union, List
772+
T = TypeVar('T')
773+
S = TypeVar('S')
774+
def k1(x: int, y: List[T]) -> List[Union[T, int]]: pass
775+
def k2(x: S, y: List[T]) -> List[Union[T, int]]: pass
776+
a = k2
777+
a = k2
778+
a = k1 # E: Incompatible types in assignment (expression has type Callable[[int, List[T]], List[Union[T, int]]], variable has type Callable[[S, List[T]], List[Union[T, int]]])
779+
b = k1
780+
b = k1
781+
b = k2
782+
[builtins fixtures/list.py]
783+
784+
693785
-- Literal expressions
694786
-- -------------------
695787

mypy/test/data/check-unions.test

+19
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,22 @@ def f(x: Optional[int]) -> None: pass
110110
f(1)
111111
f(None)
112112
f('') # E: Argument 1 to "f" has incompatible type "str"; expected "int"
113+
114+
[case testUnionSimplificationGenericFunction]
115+
from typing import TypeVar, Union, List
116+
T = TypeVar('T')
117+
def f(x: List[T]) -> Union[T, int]: pass
118+
def g(y: str) -> None: pass
119+
a = f([1])
120+
g(a) # E: Argument 1 to "g" has incompatible type "int"; expected "str"
121+
[builtins fixtures/list.py]
122+
123+
[case testUnionSimplificationGenericClass]
124+
from typing import TypeVar, Union, Generic
125+
T = TypeVar('T')
126+
U = TypeVar('U')
127+
class C(Generic[T, U]):
128+
def f(self, x: str) -> Union[T, U]: pass
129+
a = C() # type: C[int, int]
130+
b = a.f('a')
131+
a.f(b) # E: Argument 1 to "f" of "C" has incompatible type "int"; expected "str"

0 commit comments

Comments
 (0)