From 9fa4c122e102159dd99615135141f9b40f63da72 Mon Sep 17 00:00:00 2001 From: Reid Barton Date: Mon, 2 May 2016 09:22:07 -0700 Subject: [PATCH 1/2] Simplify unions when expanding types --- mypy/expandtype.py | 4 +++- mypy/test/data/check-unions.test | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mypy/expandtype.py b/mypy/expandtype.py index d08a11fb8a4c..60730b563bd6 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -87,7 +87,9 @@ def visit_tuple_type(self, t: TupleType) -> Type: return TupleType(self.expand_types(t.items), t.fallback, t.line) def visit_union_type(self, t: UnionType) -> Type: - return UnionType(self.expand_types(t.items), t.line) + # After substituting for type variables in t.items, + # some of the resulting types might be subtypes of others. + return UnionType.make_simplified_union(self.expand_types(t.items), t.line) def visit_partial_type(self, t: PartialType) -> Type: return t diff --git a/mypy/test/data/check-unions.test b/mypy/test/data/check-unions.test index 3e05cc9299eb..81b7d4451877 100644 --- a/mypy/test/data/check-unions.test +++ b/mypy/test/data/check-unions.test @@ -110,3 +110,22 @@ def f(x: Optional[int]) -> None: pass f(1) f(None) f('') # E: Argument 1 to "f" has incompatible type "str"; expected "int" + +[case testUnionSimplificationGenericFunction] +from typing import TypeVar, Union, List +T = TypeVar('T') +def f(x: List[T]) -> Union[T, int]: pass +def g(y: str) -> None: pass +a = f([1]) +g(a) # E: Argument 1 to "g" has incompatible type "int"; expected "str" +[builtins fixtures/list.py] + +[case testUnionSimplificationGenericClass] +from typing import TypeVar, Union, Generic +T = TypeVar('T') +U = TypeVar('U') +class C(Generic[T, U]): + def f(self, x: str) -> Union[T, U]: pass +a = C() # type: C[int, int] +b = a.f('a') +a.f(b) # E: Argument 1 to "f" of "C" has incompatible type "int"; expected "str" From 212a5fe041983c9861a3acf2d6118725806a6d8d Mon Sep 17 00:00:00 2001 From: Reid Barton Date: Thu, 14 Apr 2016 20:40:38 -0700 Subject: [PATCH 2/2] Generate correct constraints for Unions The old code always generated a conjunction, where we sometimes want a disjunction. Since we can't represent or solve a disjunction of constraints, need to approximate. In the presence of Unions, the constraints inferred from (template, actual, direction) are not necessarily the reverses of the constraints inferred from (template, actual, neg_op(direction)). The visit_instance and visit_callable cases were updated accordingly. Test testGenericFunctionSubtypingWithUnions fails without this change. Fixes #1458 and the remaining part of #1241. --- mypy/constraints.py | 114 ++++++++++++++++++++++------ mypy/test/data/check-inference.test | 92 ++++++++++++++++++++++ 2 files changed, 182 insertions(+), 24 deletions(-) diff --git a/mypy/constraints.py b/mypy/constraints.py index 542ddab54e09..4ea64f43bb65 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -10,6 +10,7 @@ from mypy.maptype import map_instance_to_supertype from mypy import nodes import mypy.subtypes +from mypy.erasetype import erase_typevars SUBTYPE_OF = 0 # type: int @@ -119,9 +120,82 @@ def infer_constraints(template: Type, actual: Type, The constraints are represented as Constraint objects. """ + # If the template is simply a type variable, emit a Constraint directly. + # We need to handle this case before handling Unions for two reasons: + # 1. "T <: Union[U1, U2]" is not equivalent to "T <: U1 or T <: U2", + # because T can itself be a union (notably, Union[U1, U2] itself). + # 2. "T :> Union[U1, U2]" is logically equivalent to "T :> U1 and + # T :> U2", but they are not equivalent to the constraint solver, + # which never introduces new Union types (it uses join() instead). + if isinstance(template, TypeVarType): + return [Constraint(template.id, direction, actual)] + + # Now handle the case of either template or actual being a Union. + # For a Union to be a subtype of another type, every item of the Union + # must be a subtype of that type, so concatenate the constraints. + if direction == SUBTYPE_OF and isinstance(template, UnionType): + res = [] + for t_item in template.items: + res.extend(infer_constraints(t_item, actual, direction)) + return res + if direction == SUPERTYPE_OF and isinstance(actual, UnionType): + res = [] + for a_item in actual.items: + res.extend(infer_constraints(template, a_item, direction)) + return res + + # Now the potential subtype is known not to be a Union or a type + # variable that we are solving for. In that case, for a Union to + # be a supertype of the potential subtype, some item of the Union + # must be a supertype of it. + if direction == SUBTYPE_OF and isinstance(actual, UnionType): + return any_constraints( + [infer_constraints_if_possible(template, a_item, direction) + for a_item in actual.items]) + if direction == SUPERTYPE_OF and isinstance(template, UnionType): + return any_constraints( + [infer_constraints_if_possible(t_item, actual, direction) + for t_item in template.items]) + + # Remaining cases are handled by ConstraintBuilderVisitor. return template.accept(ConstraintBuilderVisitor(actual, direction)) +def infer_constraints_if_possible(template: Type, actual: Type, + direction: int) -> Optional[List[Constraint]]: + """Like infer_constraints, but return None if the input relation is + known to be unsatisfiable, for example if template=List[T] and actual=int. + (In this case infer_constraints would return [], just like it would for + an automatically satisfied relation like template=List[T] and actual=object.) + """ + if (direction == SUBTYPE_OF and + not mypy.subtypes.is_subtype(erase_typevars(template), actual)): + return None + if (direction == SUPERTYPE_OF and + not mypy.subtypes.is_subtype(actual, erase_typevars(template))): + return None + return infer_constraints(template, actual, direction) + + +def any_constraints(options: List[Optional[List[Constraint]]]) -> List[Constraint]: + """Deduce what we can from a collection of constraint lists given that + at least one of the lists must be satisfied. A None element in the + list of options represents an unsatisfiable constraint and is ignored. + """ + valid_options = [option for option in options if option is not None] + if len(valid_options) == 1: + return valid_options[0] + # Otherwise, there are either no valid options or multiple valid options. + # Give up and deduce nothing. + return [] + + # TODO: In the latter case, it could happen that every valid + # option requires the same constraint on the same variable. Then + # we could include that that constraint in the result. Or more + # generally, if a given (variable, direction) pair appears in + # every option, combine the bounds with meet/join. + + class ConstraintBuilderVisitor(TypeVisitor[List[Constraint]]): """Visitor class for inferring type constraints.""" @@ -163,7 +237,8 @@ def visit_partial_type(self, template: PartialType) -> List[Constraint]: # Non-trivial leaf type def visit_type_var(self, template: TypeVarType) -> List[Constraint]: - return [Constraint(template.id, self.direction, self.actual)] + assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor" + " (should have been handled in infer_constraints)") # Non-leaf types @@ -177,12 +252,12 @@ def visit_instance(self, template: Instance) -> List[Constraint]: mapped = map_instance_to_supertype(template, instance.type) for i in range(len(instance.args)): # The constraints for generic type parameters are - # invariant. Include the default constraint and its - # negation to achieve the effect. - cb = infer_constraints(mapped.args[i], instance.args[i], - self.direction) - res.extend(cb) - res.extend(negate_constraints(cb)) + # invariant. Include constraints from both directions + # to achieve the effect. + res.extend(infer_constraints( + mapped.args[i], instance.args[i], self.direction)) + res.extend(infer_constraints( + mapped.args[i], instance.args[i], neg_op(self.direction))) return res elif (self.direction == SUPERTYPE_OF and instance.type.has_base(template.type.fullname())): @@ -190,10 +265,10 @@ def visit_instance(self, template: Instance) -> List[Constraint]: for j in range(len(template.args)): # The constraints for generic type parameters are # invariant. - cb = infer_constraints(template.args[j], mapped.args[j], - self.direction) - res.extend(cb) - res.extend(negate_constraints(cb)) + res.extend(infer_constraints( + template.args[j], mapped.args[j], self.direction)) + res.extend(infer_constraints( + template.args[j], mapped.args[j], neg_op(self.direction))) return res if isinstance(actual, AnyType): # IDEA: Include both ways, i.e. add negation as well? @@ -222,8 +297,8 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: if not template.is_ellipsis_args: # The lengths should match, but don't crash (it will error elsewhere). for t, a in zip(template.arg_types, cactual.arg_types): - # Negate constraints due function argument type contravariance. - res.extend(negate_constraints(infer_constraints(t, a, self.direction))) + # Negate direction due to function argument type contravariance. + res.extend(infer_constraints(t, a, neg_op(self.direction))) res.extend(infer_constraints(template.ret_type, cactual.ret_type, self.direction)) return res @@ -264,10 +339,8 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]: return [] def visit_union_type(self, template: UnionType) -> List[Constraint]: - res = [] # type: List[Constraint] - for item in template.items: - res.extend(infer_constraints(item, self.actual, self.direction)) - return res + assert False, ("Unexpected UnionType in ConstraintBuilderVisitor" + " (should have been handled in infer_constraints)") def infer_against_any(self, types: List[Type]) -> List[Constraint]: res = [] # type: List[Constraint] @@ -282,13 +355,6 @@ def visit_overloaded(self, type: Overloaded) -> List[Constraint]: return res -def negate_constraints(constraints: List[Constraint]) -> List[Constraint]: - res = [] # type: List[Constraint] - for c in constraints: - res.append(Constraint(c.type_var, neg_op(c.op), c.target)) - return res - - def neg_op(op: int) -> int: """Map SubtypeOf to SupertypeOf and vice versa.""" diff --git a/mypy/test/data/check-inference.test b/mypy/test/data/check-inference.test index 0d8db680307c..2c5ef5e24b8d 100644 --- a/mypy/test/data/check-inference.test +++ b/mypy/test/data/check-inference.test @@ -690,6 +690,98 @@ l = lb # E: Incompatible types in assignment (expression has type List[bool], va [builtins fixtures/for.py] +-- Generic function inference with unions +-- -------------------------------------- + + +[case testUnionInference] +from typing import TypeVar, Union, List +T = TypeVar('T') +U = TypeVar('U') +def f(x: Union[T, int], y: T) -> T: pass +f(1, 'a')() # E: "str" not callable +f('a', 1)() # E: "object" not callable +f('a', 'a')() # E: "str" not callable +f(1, 1)() # E: "int" not callable + +def g(x: Union[T, List[T]]) -> List[T]: pass +def h(x: List[str]) -> None: pass +g('a')() # E: List[str] not callable + +# The next line is a case where there are multiple ways to satisfy a constraint +# involving a Union. Either T = List[str] or T = str would turn out to be valid, +# but mypy doesn't know how to branch on these two options (and potentially have +# to backtrack later) and defaults to T = None. The result is an awkward error +# message. Either a better error message, or simply accepting the call, would be +# preferable here. +g(['a']) # E: Argument 1 to "g" has incompatible type List[str]; expected List[None] + +h(g(['a'])) + +def i(x: Union[List[T], List[U]], y: List[T], z: List[U]) -> None: pass +a = [1] +b = ['b'] +i(a, a, b) +i(b, a, b) +i(a, b, b) # E: Argument 1 to "i" has incompatible type List[int]; expected List[str] +[builtins fixtures/list.py] + + +[case testUnionInferenceWithTypeVarValues] +from typing import TypeVar, Union +AnyStr = TypeVar('AnyStr', bytes, str) +def f(x: Union[AnyStr, int], *a: AnyStr) -> None: pass +f('foo') +f('foo', 'bar') +f('foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object" +f(1) +f(1, 'foo') +f(1, 'foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object" +[builtins fixtures/primitives.py] + + +[case testUnionTwoPassInference-skip] +from typing import TypeVar, Union, List +T = TypeVar('T') +U = TypeVar('U') +def j(x: Union[List[T], List[U]], y: List[T]) -> List[U]: pass + +a = [1] +b = ['b'] +# We could infer: Since List[str] <: List[T], we must have T = str. +# Then since List[int] <: Union[List[str], List[U]], and List[int] is +# not a subtype of List[str], we must have U = int. +# This is not currently implemented. +j(a, b) +[builtins fixtures/list.py] + + +[case testUnionContext] +from typing import TypeVar, Union, List +T = TypeVar('T') +def f() -> List[T]: pass +d1 = f() # type: Union[List[int], str] +d2 = f() # type: Union[int, str] # E: Incompatible types in assignment (expression has type List[None], variable has type "Union[int, str]") +def g(x: T) -> List[T]: pass +d3 = g(1) # type: Union[List[int], List[str]] +[builtins fixtures/list.py] + + +[case testGenericFunctionSubtypingWithUnions] +from typing import TypeVar, Union, List +T = TypeVar('T') +S = TypeVar('S') +def k1(x: int, y: List[T]) -> List[Union[T, int]]: pass +def k2(x: S, y: List[T]) -> List[Union[T, int]]: pass +a = k2 +a = k2 +a = k1 # E: Incompatible types in assignment (expression has type Callable[[int, List[T]], List[Union[T, int]]], variable has type Callable[[S, List[T]], List[Union[T, int]]]) +b = k1 +b = k1 +b = k2 +[builtins fixtures/list.py] + + -- Literal expressions -- -------------------