Skip to content

Generate correct constraints for Unions #1408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 90 additions & 24 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mypy.maptype import map_instance_to_supertype
from mypy import nodes
import mypy.subtypes
from mypy.erasetype import erase_typevars


SUBTYPE_OF = 0 # type: int
Expand Down Expand Up @@ -119,9 +120,82 @@ def infer_constraints(template: Type, actual: Type,
The constraints are represented as Constraint objects.
"""

# If the template is simply a type variable, emit a Constraint directly.
# We need to handle this case before handling Unions for two reasons:
# 1. "T <: Union[U1, U2]" is not equivalent to "T <: U1 or T <: U2",
# because T can itself be a union (notably, Union[U1, U2] itself).
# 2. "T :> Union[U1, U2]" is logically equivalent to "T :> U1 and
# T :> U2", but they are not equivalent to the constraint solver,
# which never introduces new Union types (it uses join() instead).
if isinstance(template, TypeVarType):
return [Constraint(template.id, direction, actual)]

# Now handle the case of either template or actual being a Union.
# For a Union to be a subtype of another type, every item of the Union
# must be a subtype of that type, so concatenate the constraints.
if direction == SUBTYPE_OF and isinstance(template, UnionType):
res = []
for t_item in template.items:
res.extend(infer_constraints(t_item, actual, direction))
return res
if direction == SUPERTYPE_OF and isinstance(actual, UnionType):
res = []
for a_item in actual.items:
res.extend(infer_constraints(template, a_item, direction))
return res

# Now the potential subtype is known not to be a Union or a type
# variable that we are solving for. In that case, for a Union to
# be a supertype of the potential subtype, some item of the Union
# must be a supertype of it.
if direction == SUBTYPE_OF and isinstance(actual, UnionType):
return any_constraints(
[infer_constraints_if_possible(template, a_item, direction)
for a_item in actual.items])
if direction == SUPERTYPE_OF and isinstance(template, UnionType):
return any_constraints(
[infer_constraints_if_possible(t_item, actual, direction)
for t_item in template.items])

# Remaining cases are handled by ConstraintBuilderVisitor.
return template.accept(ConstraintBuilderVisitor(actual, direction))


def infer_constraints_if_possible(template: Type, actual: Type,
direction: int) -> Optional[List[Constraint]]:
"""Like infer_constraints, but return None if the input relation is
known to be unsatisfiable, for example if template=List[T] and actual=int.
(In this case infer_constraints would return [], just like it would for
an automatically satisfied relation like template=List[T] and actual=object.)
"""
if (direction == SUBTYPE_OF and
not mypy.subtypes.is_subtype(erase_typevars(template), actual)):
return None
if (direction == SUPERTYPE_OF and
not mypy.subtypes.is_subtype(actual, erase_typevars(template))):
return None
return infer_constraints(template, actual, direction)


def any_constraints(options: List[Optional[List[Constraint]]]) -> List[Constraint]:
"""Deduce what we can from a collection of constraint lists given that
at least one of the lists must be satisfied. A None element in the
list of options represents an unsatisfiable constraint and is ignored.
"""
valid_options = [option for option in options if option is not None]
if len(valid_options) == 1:
return valid_options[0]
# Otherwise, there are either no valid options or multiple valid options.
# Give up and deduce nothing.
return []

# TODO: In the latter case, it could happen that every valid
# option requires the same constraint on the same variable. Then
# we could include that that constraint in the result. Or more
# generally, if a given (variable, direction) pair appears in
# every option, combine the bounds with meet/join.


class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]):
"""Visitor class for inferring type constraints."""

Expand Down Expand Up @@ -163,7 +237,8 @@ def visit_partial_type(self, template: PartialType) -> List[Constraint]:
# Non-trivial leaf type

def visit_type_var(self, template: TypeVarType) -> List[Constraint]:
return [Constraint(template.id, self.direction, self.actual)]
assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor"
" (should have been handled in infer_constraints)")

# Non-leaf types

Expand All @@ -177,23 +252,23 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
mapped = map_instance_to_supertype(template, instance.type)
for i in range(len(instance.args)):
# The constraints for generic type parameters are
# invariant. Include the default constraint and its
# negation to achieve the effect.
cb = infer_constraints(mapped.args[i], instance.args[i],
self.direction)
res.extend(cb)
res.extend(negate_constraints(cb))
# invariant. Include constraints from both directions
# to achieve the effect.
res.extend(infer_constraints(
mapped.args[i], instance.args[i], self.direction))
res.extend(infer_constraints(
mapped.args[i], instance.args[i], neg_op(self.direction)))
return res
elif (self.direction == SUPERTYPE_OF and
instance.type.has_base(template.type.fullname())):
mapped = map_instance_to_supertype(instance, template.type)
for j in range(len(template.args)):
# The constraints for generic type parameters are
# invariant.
cb = infer_constraints(template.args[j], mapped.args[j],
self.direction)
res.extend(cb)
res.extend(negate_constraints(cb))
res.extend(infer_constraints(
template.args[j], mapped.args[j], self.direction))
res.extend(infer_constraints(
template.args[j], mapped.args[j], neg_op(self.direction)))
return res
if isinstance(actual, AnyType):
# IDEA: Include both ways, i.e. add negation as well?
Expand Down Expand Up @@ -222,8 +297,8 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
if not template.is_ellipsis_args:
# The lengths should match, but don't crash (it will error elsewhere).
for t, a in zip(template.arg_types, cactual.arg_types):
# Negate constraints due function argument type contravariance.
res.extend(negate_constraints(infer_constraints(t, a, self.direction)))
# Negate direction due to function argument type contravariance.
res.extend(infer_constraints(t, a, neg_op(self.direction)))
res.extend(infer_constraints(template.ret_type, cactual.ret_type,
self.direction))
return res
Expand Down Expand Up @@ -264,10 +339,8 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
return []

def visit_union_type(self, template: UnionType) -> List[Constraint]:
res = [] # type: List[Constraint]
for item in template.items:
res.extend(infer_constraints(item, self.actual, self.direction))
return res
assert False, ("Unexpected UnionType in ConstraintBuilderVisitor"
" (should have been handled in infer_constraints)")

def infer_against_any(self, types: List[Type]) -> List[Constraint]:
res = [] # type: List[Constraint]
Expand All @@ -282,13 +355,6 @@ def visit_overloaded(self, type: Overloaded) -> List[Constraint]:
return res


def negate_constraints(constraints: List[Constraint]) -> List[Constraint]:
res = [] # type: List[Constraint]
for c in constraints:
res.append(Constraint(c.type_var, neg_op(c.op), c.target))
return res


def neg_op(op: int) -> int:
"""Map SubtypeOf to SupertypeOf and vice versa."""

Expand Down
4 changes: 3 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def visit_tuple_type(self, t: TupleType) -> Type:
return TupleType(self.expand_types(t.items), t.fallback, t.line)

def visit_union_type(self, t: UnionType) -> Type:
return UnionType(self.expand_types(t.items), t.line)
# After substituting for type variables in t.items,
# some of the resulting types might be subtypes of others.
return UnionType.make_simplified_union(self.expand_types(t.items), t.line)

def visit_partial_type(self, t: PartialType) -> Type:
return t
Expand Down
92 changes: 92 additions & 0 deletions mypy/test/data/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,98 @@ l = lb # E: Incompatible types in assignment (expression has type List[bool], va
[builtins fixtures/for.py]


-- Generic function inference with unions
-- --------------------------------------


[case testUnionInference]
from typing import TypeVar, Union, List
T = TypeVar('T')
U = TypeVar('U')
def f(x: Union[T, int], y: T) -> T: pass
f(1, 'a')() # E: "str" not callable
f('a', 1)() # E: "object" not callable
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test f('a', 'a')() and f(1, 1).

f('a', 'a')() # E: "str" not callable
f(1, 1)() # E: "int" not callable

def g(x: Union[T, List[T]]) -> List[T]: pass
def h(x: List[str]) -> None: pass
g('a')() # E: List[str] not callable

# The next line is a case where there are multiple ways to satisfy a constraint
# involving a Union. Either T = List[str] or T = str would turn out to be valid,
# but mypy doesn't know how to branch on these two options (and potentially have
# to backtrack later) and defaults to T = None. The result is an awkward error
# message. Either a better error message, or simply accepting the call, would be
# preferable here.
g(['a']) # E: Argument 1 to "g" has incompatible type List[str]; expected List[None]

h(g(['a']))

def i(x: Union[List[T], List[U]], y: List[T], z: List[U]) -> None: pass
a = [1]
b = ['b']
i(a, a, b)
i(b, a, b)
i(a, b, b) # E: Argument 1 to "i" has incompatible type List[int]; expected List[str]
[builtins fixtures/list.py]


[case testUnionInferenceWithTypeVarValues]
from typing import TypeVar, Union
AnyStr = TypeVar('AnyStr', bytes, str)
def f(x: Union[AnyStr, int], *a: AnyStr) -> None: pass
f('foo')
f('foo', 'bar')
f('foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object"
f(1)
f(1, 'foo')
f(1, 'foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object"
[builtins fixtures/primitives.py]


[case testUnionTwoPassInference-skip]
from typing import TypeVar, Union, List
T = TypeVar('T')
U = TypeVar('U')
def j(x: Union[List[T], List[U]], y: List[T]) -> List[U]: pass

a = [1]
b = ['b']
# We could infer: Since List[str] <: List[T], we must have T = str.
# Then since List[int] <: Union[List[str], List[U]], and List[int] is
# not a subtype of List[str], we must have U = int.
# This is not currently implemented.
j(a, b)
[builtins fixtures/list.py]


[case testUnionContext]
from typing import TypeVar, Union, List
T = TypeVar('T')
def f() -> List[T]: pass
d1 = f() # type: Union[List[int], str]
d2 = f() # type: Union[int, str] # E: Incompatible types in assignment (expression has type List[None], variable has type "Union[int, str]")
def g(x: T) -> List[T]: pass
d3 = g(1) # type: Union[List[int], List[str]]
[builtins fixtures/list.py]


[case testGenericFunctionSubtypingWithUnions]
from typing import TypeVar, Union, List
T = TypeVar('T')
S = TypeVar('S')
def k1(x: int, y: List[T]) -> List[Union[T, int]]: pass
def k2(x: S, y: List[T]) -> List[Union[T, int]]: pass
a = k2
a = k2
a = k1 # E: Incompatible types in assignment (expression has type Callable[[int, List[T]], List[Union[T, int]]], variable has type Callable[[S, List[T]], List[Union[T, int]]])
b = k1
b = k1
b = k2
[builtins fixtures/list.py]


-- Literal expressions
-- -------------------

Expand Down
19 changes: 19 additions & 0 deletions mypy/test/data/check-unions.test
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,22 @@ def f(x: Optional[int]) -> None: pass
f(1)
f(None)
f('') # E: Argument 1 to "f" has incompatible type "str"; expected "int"

[case testUnionSimplificationGenericFunction]
from typing import TypeVar, Union, List
T = TypeVar('T')
def f(x: List[T]) -> Union[T, int]: pass
def g(y: str) -> None: pass
a = f([1])
g(a) # E: Argument 1 to "g" has incompatible type "int"; expected "str"
[builtins fixtures/list.py]

[case testUnionSimplificationGenericClass]
from typing import TypeVar, Union, Generic
T = TypeVar('T')
U = TypeVar('U')
class C(Generic[T, U]):
def f(self, x: str) -> Union[T, U]: pass
a = C() # type: C[int, int]
b = a.f('a')
a.f(b) # E: Argument 1 to "f" of "C" has incompatible type "int"; expected "str"