@@ -1687,10 +1687,12 @@ def analyze_async_iterable_item_type(self, expr: Node) -> Type:
1687
1687
awaitable = echk .check_call (method , [], [], expr )[0 ]
1688
1688
method = echk .analyze_external_member_access ('__await__' , awaitable , expr )
1689
1689
generator = echk .check_call (method , [], [], expr )[0 ]
1690
+ # XXX TODO Use get_generator_return_type()?
1690
1691
if (isinstance (generator , Instance ) and len (generator .args ) == 3
1691
1692
and generator .type .fullname () == 'typing.Generator' ):
1692
1693
return generator .args [2 ]
1693
1694
else :
1695
+ # XXX TODO What if it's a subclass of Awaitable?
1694
1696
return AnyType ()
1695
1697
1696
1698
def analyze_iterable_item_type (self , expr : Node ) -> Type :
@@ -1801,15 +1803,32 @@ def check_incompatible_property_override(self, e: Decorator) -> None:
1801
1803
1802
1804
def visit_with_stmt (self , s : WithStmt ) -> Type :
1803
1805
echk = self .expr_checker
1806
+ if s .is_async :
1807
+ m_enter = '__aenter__'
1808
+ m_exit = '__aexit__'
1809
+ else :
1810
+ m_enter = '__enter__'
1811
+ m_exit = '__exit__'
1804
1812
for expr , target in zip (s .expr , s .target ):
1805
1813
ctx = self .accept (expr )
1806
- enter = echk .analyze_external_member_access ('__enter__' , ctx , expr )
1814
+ enter = echk .analyze_external_member_access (m_enter , ctx , expr )
1807
1815
obj = echk .check_call (enter , [], [], expr )[0 ]
1816
+ if s .is_async :
1817
+ self .check_subtype (obj , self .named_type ('typing.Awaitable' ), expr )
1808
1818
if target :
1819
+ if s .is_async :
1820
+ # XXX TODO What if it's a subclass of Awaitable?
1821
+ if (isinstance (obj , Instance ) and len (obj .args ) == 1
1822
+ and obj .type .fullname () == 'typing.Awaitable' ):
1823
+ obj = obj .args [0 ]
1824
+ else :
1825
+ obj = AnyType ()
1809
1826
self .check_assignment (target , self .temp_node (obj , expr ))
1810
- exit = echk .analyze_external_member_access ('__exit__' , ctx , expr )
1827
+ exit = echk .analyze_external_member_access (m_exit , ctx , expr )
1811
1828
arg = self .temp_node (AnyType (), expr )
1812
- echk .check_call (exit , [arg ] * 3 , [nodes .ARG_POS ] * 3 , expr )
1829
+ res = echk .check_call (exit , [arg ] * 3 , [nodes .ARG_POS ] * 3 , expr )[0 ]
1830
+ if s .is_async :
1831
+ self .check_subtype (res , self .named_type ('typing.Awaitable' ), expr )
1813
1832
self .accept (s .body )
1814
1833
1815
1834
def visit_print_stmt (self , s : PrintStmt ) -> Type :
@@ -2004,12 +2023,14 @@ def visit_await_expr(self, e: AwaitExpr) -> Type:
2004
2023
if isinstance (actual_type , AnyType ):
2005
2024
return any_type
2006
2025
if is_subtype (actual_type , generator_type ):
2007
- if isinstance (actual_type , Instance ) and len (actual_type .args ) == 3 :
2026
+ if (isinstance (actual_type , Instance ) and len (actual_type .args ) == 3
2027
+ and actual_type .type .fullname () == 'typing.Generator' ):
2008
2028
return actual_type .args [2 ]
2009
2029
else :
2010
2030
return any_type # Must've been unparameterized Generator.
2011
2031
elif is_subtype (actual_type , awaitable_type ):
2012
- if isinstance (actual_type , Instance ) and len (actual_type .args ) == 1 :
2032
+ if (isinstance (actual_type , Instance ) and len (actual_type .args ) == 1
2033
+ and actual_type .type .fullname () == 'typing.Awaitable' ):
2013
2034
return actual_type .args [0 ]
2014
2035
else :
2015
2036
return any_type # Must've been unparameterized Awaitable.
0 commit comments