@@ -131,6 +131,8 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
131
131
# Used for collecting inferred attribute types so that they can be checked
132
132
# for consistency.
133
133
inferred_attribute_types = None # type: Optional[Dict[Var, Type]]
134
+ # Don't infer partial None types if we are processing assignment from Union
135
+ no_partial_types = False # type: bool
134
136
135
137
# The set of all dependencies (suppressed or not) that this module accesses, either
136
138
# directly or indirectly.
@@ -1605,12 +1607,13 @@ def check_multi_assignment(self, lvalues: List[Lvalue],
1605
1607
rvalue : Expression ,
1606
1608
context : Context ,
1607
1609
infer_lvalue_type : bool = True ,
1608
- msg : Optional [str ] = None ) -> None :
1610
+ rv_type : Optional [Type ] = None ,
1611
+ undefined_rvalue : bool = False ) -> None :
1609
1612
"""Check the assignment of one rvalue to a number of lvalues."""
1610
1613
1611
1614
# Infer the type of an ordinary rvalue expression.
1612
- rvalue_type = self . expr_checker . accept ( rvalue ) # TODO maybe elsewhere; redundant
1613
- undefined_rvalue = False
1615
+ # TODO: maybe elsewhere; redundant.
1616
+ rvalue_type = rv_type or self . expr_checker . accept ( rvalue )
1614
1617
1615
1618
if isinstance (rvalue_type , UnionType ):
1616
1619
# If this is an Optional type in non-strict Optional code, unwrap it.
@@ -1628,10 +1631,71 @@ def check_multi_assignment(self, lvalues: List[Lvalue],
1628
1631
elif isinstance (rvalue_type , TupleType ):
1629
1632
self .check_multi_assignment_from_tuple (lvalues , rvalue , rvalue_type ,
1630
1633
context , undefined_rvalue , infer_lvalue_type )
1634
+ elif isinstance (rvalue_type , UnionType ):
1635
+ self .check_multi_assignment_from_union (lvalues , rvalue , rvalue_type , context ,
1636
+ infer_lvalue_type )
1631
1637
else :
1632
1638
self .check_multi_assignment_from_iterable (lvalues , rvalue_type ,
1633
1639
context , infer_lvalue_type )
1634
1640
1641
+ def check_multi_assignment_from_union (self , lvalues : List [Expression ], rvalue : Expression ,
1642
+ rvalue_type : UnionType , context : Context ,
1643
+ infer_lvalue_type : bool ) -> None :
1644
+ """Check assignment to multiple lvalue targets when rvalue type is a Union[...].
1645
+ For example:
1646
+
1647
+ t: Union[Tuple[int, int], Tuple[str, str]]
1648
+ x, y = t
1649
+ reveal_type(x) # Union[int, str]
1650
+
1651
+ The idea in this case is to process the assignment for every item of the union.
1652
+ Important note: the types are collected in two places, 'union_types' contains
1653
+ inferred types for first assignments, 'assignments' contains the narrowed types
1654
+ for binder.
1655
+ """
1656
+ self .no_partial_types = True
1657
+ transposed = tuple ([] for _ in
1658
+ self .flatten_lvalues (lvalues )) # type: Tuple[List[Type], ...]
1659
+ # Notify binder that we want to defer bindings and instead collect types.
1660
+ with self .binder .accumulate_type_assignments () as assignments :
1661
+ for item in rvalue_type .items :
1662
+ # Type check the assignment separately for each union item and collect
1663
+ # the inferred lvalue types for each union item.
1664
+ self .check_multi_assignment (lvalues , rvalue , context ,
1665
+ infer_lvalue_type = infer_lvalue_type ,
1666
+ rv_type = item , undefined_rvalue = True )
1667
+ for t , lv in zip (transposed , self .flatten_lvalues (lvalues )):
1668
+ t .append (self .type_map .pop (lv , AnyType (TypeOfAny .special_form )))
1669
+ union_types = tuple (UnionType .make_simplified_union (col ) for col in transposed )
1670
+ for expr , items in assignments .items ():
1671
+ # Bind a union of types collected in 'assignments' to every expression.
1672
+ if isinstance (expr , StarExpr ):
1673
+ expr = expr .expr
1674
+ types , declared_types = zip (* items )
1675
+ self .binder .assign_type (expr ,
1676
+ UnionType .make_simplified_union (types ),
1677
+ UnionType .make_simplified_union (declared_types ),
1678
+ False )
1679
+ for union , lv in zip (union_types , self .flatten_lvalues (lvalues )):
1680
+ # Properly store the inferred types.
1681
+ _1 , _2 , inferred = self .check_lvalue (lv )
1682
+ if inferred :
1683
+ self .set_inferred_type (inferred , lv , union )
1684
+ else :
1685
+ self .store_type (lv , union )
1686
+ self .no_partial_types = False
1687
+
1688
+ def flatten_lvalues (self , lvalues : List [Expression ]) -> List [Expression ]:
1689
+ res = [] # type: List[Expression]
1690
+ for lv in lvalues :
1691
+ if isinstance (lv , (TupleExpr , ListExpr )):
1692
+ res .extend (self .flatten_lvalues (lv .items ))
1693
+ if isinstance (lv , StarExpr ):
1694
+ # Unwrap StarExpr, since it is unwrapped by other helpers.
1695
+ lv = lv .expr
1696
+ res .append (lv )
1697
+ return res
1698
+
1635
1699
def check_multi_assignment_from_tuple (self , lvalues : List [Lvalue ], rvalue : Expression ,
1636
1700
rvalue_type : TupleType , context : Context ,
1637
1701
undefined_rvalue : bool ,
@@ -1654,7 +1718,11 @@ def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expre
1654
1718
relevant_items = reinferred_rvalue_type .relevant_items ()
1655
1719
if len (relevant_items ) == 1 :
1656
1720
reinferred_rvalue_type = relevant_items [0 ]
1657
-
1721
+ if isinstance (reinferred_rvalue_type , UnionType ):
1722
+ self .check_multi_assignment_from_union (lvalues , rvalue ,
1723
+ reinferred_rvalue_type , context ,
1724
+ infer_lvalue_type )
1725
+ return
1658
1726
assert isinstance (reinferred_rvalue_type , TupleType )
1659
1727
rvalue_type = reinferred_rvalue_type
1660
1728
@@ -1716,7 +1784,7 @@ def split_around_star(self, items: List[T], star_index: int,
1716
1784
returns in: ([1,2], [3,4,5], [6,7])
1717
1785
"""
1718
1786
nr_right_of_star = length - star_index - 1
1719
- right_index = nr_right_of_star if - nr_right_of_star != 0 else len (items )
1787
+ right_index = - nr_right_of_star if nr_right_of_star != 0 else len (items )
1720
1788
left = items [:star_index ]
1721
1789
star = items [star_index :right_index ]
1722
1790
right = items [right_index :]
@@ -1800,7 +1868,7 @@ def infer_variable_type(self, name: Var, lvalue: Lvalue,
1800
1868
"""Infer the type of initialized variables from initializer type."""
1801
1869
if isinstance (init_type , DeletedType ):
1802
1870
self .msg .deleted_as_rvalue (init_type , context )
1803
- elif not is_valid_inferred_type (init_type ):
1871
+ elif not is_valid_inferred_type (init_type ) and not self . no_partial_types :
1804
1872
# We cannot use the type of the initialization expression for full type
1805
1873
# inference (it's not specific enough), but we might be able to give
1806
1874
# partial type which will be made more specific later. A partial type
@@ -1897,7 +1965,7 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type,
1897
1965
rvalue : Expression , context : Context ) -> Tuple [Type , bool ]:
1898
1966
"""Type member assigment.
1899
1967
1900
- This is defers to check_simple_assignment, unless the member expression
1968
+ This defers to check_simple_assignment, unless the member expression
1901
1969
is a descriptor, in which case this checks descriptor semantics as well.
1902
1970
1903
1971
Return the inferred rvalue_type and whether to infer anything about the attribute type
@@ -2697,7 +2765,19 @@ def iterable_item_type(self, instance: Instance) -> Type:
2697
2765
iterable = map_instance_to_supertype (
2698
2766
instance ,
2699
2767
self .lookup_typeinfo ('typing.Iterable' ))
2700
- return iterable .args [0 ]
2768
+ item_type = iterable .args [0 ]
2769
+ if not isinstance (item_type , AnyType ):
2770
+ # This relies on 'map_instance_to_supertype' returning 'Iterable[Any]'
2771
+ # in case there is no explicit base class.
2772
+ return item_type
2773
+ # Try also structural typing.
2774
+ iter_type = find_member ('__iter__' , instance , instance )
2775
+ if (iter_type and isinstance (iter_type , CallableType ) and
2776
+ isinstance (iter_type .ret_type , Instance )):
2777
+ iterator = map_instance_to_supertype (iter_type .ret_type ,
2778
+ self .lookup_typeinfo ('typing.Iterator' ))
2779
+ item_type = iterator .args [0 ]
2780
+ return item_type
2701
2781
2702
2782
def function_type (self , func : FuncBase ) -> FunctionLike :
2703
2783
return function_type (func , self .named_type ('builtins.function' ))
0 commit comments