Skip to content

Commit 77ce5e2

Browse files
ilevkivskyiJukkaL
authored andcommitted
Allow assignments to multiple targets from union types (#4067)
Fixes #3859 Fixes #3240 Fixes #1855 Fixes #1575 This is a simple fix of various bugs and a crash based on the idea originally proposed in #2219. The idea is to check assignment for every item in a union. However, in contrast to #2219, I think it will be more expected/consistent to construct a union of the resulting types instead of a join, for example: ``` x: Union[int, str] x1 = x reveal_type(x1) # Revealed type is 'Union[int, str]' y: Union[Tuple[int], Tuple[str]] (y1,) = y reveal_type(y1) # Revealed type is 'Union[int, str]' ``` @elazarg did the initial work on this.
1 parent bd1b3d7 commit 77ce5e2

File tree

7 files changed

+578
-11
lines changed

7 files changed

+578
-11
lines changed

mypy/binder.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
from typing import Dict, List, Set, Iterator, Union, Optional, cast
1+
from typing import Dict, List, Set, Iterator, Union, Optional, Tuple, cast
22
from contextlib import contextmanager
3+
from collections import defaultdict
4+
5+
MYPY = False
6+
if MYPY:
7+
from typing import DefaultDict
38

49
from mypy.types import Type, AnyType, PartialType, UnionType, TypeOfAny
510
from mypy.subtypes import is_subtype
@@ -37,6 +42,12 @@ def __init__(self) -> None:
3742
self.unreachable = False
3843

3944

45+
if MYPY:
46+
# This is the type of stored assignments for union type rvalues.
47+
# We use 'if MYPY: ...' since typing-3.5.1 does not have 'DefaultDict'
48+
Assigns = DefaultDict[Expression, List[Tuple[Type, Optional[Type]]]]
49+
50+
4051
class ConditionalTypeBinder:
4152
"""Keep track of conditional types of variables.
4253
@@ -57,6 +68,9 @@ class A:
5768
reveal_type(lst[0].a) # str
5869
```
5970
"""
71+
# Stored assignments for situations with tuple/list lvalue and rvalue of union type.
72+
# This maps an expression to a list of bound types for every item in the union type.
73+
type_assignments = None # type: Optional[Assigns]
6074

6175
def __init__(self) -> None:
6276
# The stack of frames currently used. These map
@@ -210,10 +224,30 @@ def pop_frame(self, can_skip: bool, fall_through: int) -> Frame:
210224

211225
return result
212226

227+
@contextmanager
228+
def accumulate_type_assignments(self) -> 'Iterator[Assigns]':
229+
"""Push a new map to collect assigned types in multiassign from union.
230+
231+
If this map is not None, actual binding is deferred until all items in
232+
the union are processed (a union of collected items is later bound
233+
manually by the caller).
234+
"""
235+
old_assignments = None
236+
if self.type_assignments is not None:
237+
old_assignments = self.type_assignments
238+
self.type_assignments = defaultdict(list)
239+
yield self.type_assignments
240+
self.type_assignments = old_assignments
241+
213242
def assign_type(self, expr: Expression,
214243
type: Type,
215244
declared_type: Optional[Type],
216245
restrict_any: bool = False) -> None:
246+
if self.type_assignments is not None:
247+
# We are in a multiassign from union, defer the actual binding,
248+
# just collect the types.
249+
self.type_assignments[expr].append((type, declared_type))
250+
return
217251
if not isinstance(expr, BindableTypes):
218252
return None
219253
if not literal(expr):

mypy/checker.py

+88-8
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
131131
# Used for collecting inferred attribute types so that they can be checked
132132
# for consistency.
133133
inferred_attribute_types = None # type: Optional[Dict[Var, Type]]
134+
# Don't infer partial None types if we are processing assignment from Union
135+
no_partial_types = False # type: bool
134136

135137
# The set of all dependencies (suppressed or not) that this module accesses, either
136138
# directly or indirectly.
@@ -1605,12 +1607,13 @@ def check_multi_assignment(self, lvalues: List[Lvalue],
16051607
rvalue: Expression,
16061608
context: Context,
16071609
infer_lvalue_type: bool = True,
1608-
msg: Optional[str] = None) -> None:
1610+
rv_type: Optional[Type] = None,
1611+
undefined_rvalue: bool = False) -> None:
16091612
"""Check the assignment of one rvalue to a number of lvalues."""
16101613

16111614
# Infer the type of an ordinary rvalue expression.
1612-
rvalue_type = self.expr_checker.accept(rvalue) # TODO maybe elsewhere; redundant
1613-
undefined_rvalue = False
1615+
# TODO: maybe elsewhere; redundant.
1616+
rvalue_type = rv_type or self.expr_checker.accept(rvalue)
16141617

16151618
if isinstance(rvalue_type, UnionType):
16161619
# If this is an Optional type in non-strict Optional code, unwrap it.
@@ -1628,10 +1631,71 @@ def check_multi_assignment(self, lvalues: List[Lvalue],
16281631
elif isinstance(rvalue_type, TupleType):
16291632
self.check_multi_assignment_from_tuple(lvalues, rvalue, rvalue_type,
16301633
context, undefined_rvalue, infer_lvalue_type)
1634+
elif isinstance(rvalue_type, UnionType):
1635+
self.check_multi_assignment_from_union(lvalues, rvalue, rvalue_type, context,
1636+
infer_lvalue_type)
16311637
else:
16321638
self.check_multi_assignment_from_iterable(lvalues, rvalue_type,
16331639
context, infer_lvalue_type)
16341640

1641+
def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: Expression,
1642+
rvalue_type: UnionType, context: Context,
1643+
infer_lvalue_type: bool) -> None:
1644+
"""Check assignment to multiple lvalue targets when rvalue type is a Union[...].
1645+
For example:
1646+
1647+
t: Union[Tuple[int, int], Tuple[str, str]]
1648+
x, y = t
1649+
reveal_type(x) # Union[int, str]
1650+
1651+
The idea in this case is to process the assignment for every item of the union.
1652+
Important note: the types are collected in two places, 'union_types' contains
1653+
inferred types for first assignments, 'assignments' contains the narrowed types
1654+
for binder.
1655+
"""
1656+
self.no_partial_types = True
1657+
transposed = tuple([] for _ in
1658+
self.flatten_lvalues(lvalues)) # type: Tuple[List[Type], ...]
1659+
# Notify binder that we want to defer bindings and instead collect types.
1660+
with self.binder.accumulate_type_assignments() as assignments:
1661+
for item in rvalue_type.items:
1662+
# Type check the assignment separately for each union item and collect
1663+
# the inferred lvalue types for each union item.
1664+
self.check_multi_assignment(lvalues, rvalue, context,
1665+
infer_lvalue_type=infer_lvalue_type,
1666+
rv_type=item, undefined_rvalue=True)
1667+
for t, lv in zip(transposed, self.flatten_lvalues(lvalues)):
1668+
t.append(self.type_map.pop(lv, AnyType(TypeOfAny.special_form)))
1669+
union_types = tuple(UnionType.make_simplified_union(col) for col in transposed)
1670+
for expr, items in assignments.items():
1671+
# Bind a union of types collected in 'assignments' to every expression.
1672+
if isinstance(expr, StarExpr):
1673+
expr = expr.expr
1674+
types, declared_types = zip(*items)
1675+
self.binder.assign_type(expr,
1676+
UnionType.make_simplified_union(types),
1677+
UnionType.make_simplified_union(declared_types),
1678+
False)
1679+
for union, lv in zip(union_types, self.flatten_lvalues(lvalues)):
1680+
# Properly store the inferred types.
1681+
_1, _2, inferred = self.check_lvalue(lv)
1682+
if inferred:
1683+
self.set_inferred_type(inferred, lv, union)
1684+
else:
1685+
self.store_type(lv, union)
1686+
self.no_partial_types = False
1687+
1688+
def flatten_lvalues(self, lvalues: List[Expression]) -> List[Expression]:
1689+
res = [] # type: List[Expression]
1690+
for lv in lvalues:
1691+
if isinstance(lv, (TupleExpr, ListExpr)):
1692+
res.extend(self.flatten_lvalues(lv.items))
1693+
if isinstance(lv, StarExpr):
1694+
# Unwrap StarExpr, since it is unwrapped by other helpers.
1695+
lv = lv.expr
1696+
res.append(lv)
1697+
return res
1698+
16351699
def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expression,
16361700
rvalue_type: TupleType, context: Context,
16371701
undefined_rvalue: bool,
@@ -1654,7 +1718,11 @@ def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expre
16541718
relevant_items = reinferred_rvalue_type.relevant_items()
16551719
if len(relevant_items) == 1:
16561720
reinferred_rvalue_type = relevant_items[0]
1657-
1721+
if isinstance(reinferred_rvalue_type, UnionType):
1722+
self.check_multi_assignment_from_union(lvalues, rvalue,
1723+
reinferred_rvalue_type, context,
1724+
infer_lvalue_type)
1725+
return
16581726
assert isinstance(reinferred_rvalue_type, TupleType)
16591727
rvalue_type = reinferred_rvalue_type
16601728

@@ -1716,7 +1784,7 @@ def split_around_star(self, items: List[T], star_index: int,
17161784
returns in: ([1,2], [3,4,5], [6,7])
17171785
"""
17181786
nr_right_of_star = length - star_index - 1
1719-
right_index = nr_right_of_star if -nr_right_of_star != 0 else len(items)
1787+
right_index = -nr_right_of_star if nr_right_of_star != 0 else len(items)
17201788
left = items[:star_index]
17211789
star = items[star_index:right_index]
17221790
right = items[right_index:]
@@ -1800,7 +1868,7 @@ def infer_variable_type(self, name: Var, lvalue: Lvalue,
18001868
"""Infer the type of initialized variables from initializer type."""
18011869
if isinstance(init_type, DeletedType):
18021870
self.msg.deleted_as_rvalue(init_type, context)
1803-
elif not is_valid_inferred_type(init_type):
1871+
elif not is_valid_inferred_type(init_type) and not self.no_partial_types:
18041872
# We cannot use the type of the initialization expression for full type
18051873
# inference (it's not specific enough), but we might be able to give
18061874
# partial type which will be made more specific later. A partial type
@@ -1897,7 +1965,7 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type,
18971965
rvalue: Expression, context: Context) -> Tuple[Type, bool]:
18981966
"""Type member assigment.
18991967
1900-
This is defers to check_simple_assignment, unless the member expression
1968+
This defers to check_simple_assignment, unless the member expression
19011969
is a descriptor, in which case this checks descriptor semantics as well.
19021970
19031971
Return the inferred rvalue_type and whether to infer anything about the attribute type
@@ -2697,7 +2765,19 @@ def iterable_item_type(self, instance: Instance) -> Type:
26972765
iterable = map_instance_to_supertype(
26982766
instance,
26992767
self.lookup_typeinfo('typing.Iterable'))
2700-
return iterable.args[0]
2768+
item_type = iterable.args[0]
2769+
if not isinstance(item_type, AnyType):
2770+
# This relies on 'map_instance_to_supertype' returning 'Iterable[Any]'
2771+
# in case there is no explicit base class.
2772+
return item_type
2773+
# Try also structural typing.
2774+
iter_type = find_member('__iter__', instance, instance)
2775+
if (iter_type and isinstance(iter_type, CallableType) and
2776+
isinstance(iter_type.ret_type, Instance)):
2777+
iterator = map_instance_to_supertype(iter_type.ret_type,
2778+
self.lookup_typeinfo('typing.Iterator'))
2779+
item_type = iterator.args[0]
2780+
return item_type
27012781

27022782
def function_type(self, func: FuncBase) -> FunctionLike:
27032783
return function_type(func, self.named_type('builtins.function'))

mypy/maptype.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ def map_instance_to_supertype(instance: Instance,
1010
"""Produce a supertype of `instance` that is an Instance
1111
of `superclass`, mapping type arguments up the chain of bases.
1212
13-
`superclass` is required to be a superclass of `instance.type`.
13+
If `superclass` is not a nominal superclass of `instance.type`,
14+
then all type arguments are mapped to 'Any'.
1415
"""
1516
if instance.type == superclass:
1617
# Fast path: `instance` already belongs to `superclass`.

0 commit comments

Comments
 (0)