Skip to content

Commit 13f446f

Browse files
committed
Generate correct constraints for Union[...] :> t
The old code generated a conjunction, where we wanted a disjunction. Since we can't represent or solve a disjunction of constraints, need to approximate. Addresses part of #1241.
1 parent 33559c0 commit 13f446f

File tree

2 files changed

+84
-4
lines changed

2 files changed

+84
-4
lines changed

mypy/constraints.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from mypy.maptype import map_instance_to_supertype
1111
from mypy import nodes
1212
import mypy.subtypes
13+
from mypy.erasetype import erase_typevars
1314

1415

1516
SUBTYPE_OF = 0 # type: int
@@ -272,10 +273,35 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
272273
return []
273274

274275
def visit_union_type(self, template: UnionType) -> List[Constraint]:
275-
res = [] # type: List[Constraint]
276-
for item in template.items:
277-
res.extend(infer_constraints(item, self.actual, self.direction))
278-
return res
276+
if self.direction == SUBTYPE_OF:
277+
# A union is a subtype of T if all the items are subtypes of T.
278+
res = [] # type: List[Constraint]
279+
for item in template.items:
280+
res.extend(infer_constraints(item, self.actual, self.direction))
281+
return res
282+
else:
283+
# A union is a supertype of T if some item is a supertype
284+
# of T. We would like to take the disjunction of the
285+
# constraints for the items, but we can't represent
286+
# that. So we need to approximate.
287+
288+
# First, throw away any items of the union that definitely
289+
# can't be a supertype of T. (Note that using
290+
# erase_typevars here is slightly conservative--it loses
291+
# information about repeated instances of a single type
292+
# variable.)
293+
candidates = [item for item in template.items
294+
if mypy.subtypes.is_subtype(self.actual, erase_typevars(item))]
295+
# If there's just one item left, then use it.
296+
if len(candidates) == 1:
297+
return infer_constraints(candidates[0], self.actual, self.direction)
298+
299+
# Otherwise, give up and don't deduce any constraints.
300+
# TODO: If there are multiple candidates, but they all require
301+
# the same constraint on the same variable, we could return
302+
# that constraint. Or more generally, combine bounds for a given
303+
# variable and direction with meet/join.
304+
return []
279305

280306
def infer_against_any(self, types: List[Type]) -> List[Constraint]:
281307
res = [] # type: List[Constraint]

mypy/test/data/check-inference.test

+54
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,60 @@ l = lb # E: Incompatible types in assignment (expression has type List[bool], va
690690
[builtins fixtures/for.py]
691691

692692

693+
-- Generic function inference with unions
694+
-- --------------------------------------
695+
696+
697+
[case testUnionInference]
698+
from typing import TypeVar, Union, List
699+
T = TypeVar('T')
700+
U = TypeVar('U')
701+
def f(x: Union[T, int], y: T) -> T: pass
702+
f(1, 'a')() # E: "str" not callable
703+
f('a', 1)() # E: "object" not callable
704+
705+
def g(x: Union[T, List[T]]) -> List[T]: pass
706+
def h(x: List[str]) -> None: pass
707+
g('a')() # E: List[str] not callable
708+
g(['a']) # E: Argument 1 to "g" has incompatible type List[str]; expected "Union[None, List[None]]"
709+
h(g(['a']))
710+
711+
def i(x: Union[List[T], List[U]], y: List[T], z: List[U]) -> None: pass
712+
a = [1]
713+
b = ['b']
714+
i(a, a, b)
715+
i(b, a, b)
716+
i(a, b, b) # E: Argument 1 to "i" has incompatible type List[int]; expected "Union[List[str], List[str]]"
717+
[builtins fixtures/list.py]
718+
719+
720+
[case testUnionInferenceWithTypeVarValues]
721+
from typing import TypeVar, Union
722+
AnyStr = TypeVar('AnyStr', bytes, str)
723+
def f(x: Union[AnyStr, int], *a: AnyStr) -> None: pass
724+
f('foo')
725+
f('foo', 'bar')
726+
f('foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object"
727+
f(1)
728+
f(1, 'foo')
729+
f(1, 'foo', b'bar') # E: Type argument 1 of "f" has incompatible value "object"
730+
[builtins fixtures/primitives.py]
731+
732+
733+
[case testUnionTwoPassInference-skip]
734+
from typing import TypeVar, Union
735+
T = TypeVar('T')
736+
U = TypeVar('U')
737+
def j(x: Union[List[T], List[U]], y: List[T]) -> List[U]: pass
738+
739+
a = [1]
740+
b = ['b']
741+
# We could infer: Since List[str] <: List[T], we must have T = str.
742+
# Then since List[int] <: Union[List[str], List[U]], and List[int] is
743+
# not a subtype of List[str], we must have U = int.
744+
j(a, b)
745+
746+
693747
-- Literal expressions
694748
-- -------------------
695749

0 commit comments

Comments
 (0)