Skip to content

Commit dcc3b86

Browse files
gvanrossumddfisher
authored andcommitted
Implement 'async def' and friends ('await', 'async for', 'async with') (#1808)
* PEP 492 syntax: `async def` and `await`. * Fix type errors; add async for and async with (not fully fledged). * Dispose of Async{For,With}Stmt -- use is_async flag instead. * Basic `async for` is working. * Clear unneeded TODOs. * Fledgeling `async with` support. * Disallow `yield [from]` in `async def`. * Check Python version before looking up typing.Awaitable. Ensure 'async def' is a syntax error in Python 2 (at least with the fast parser). * Vast strides in accuracy for visit_await_expr(). * Add `@with_line` to PEP 492 visit function definitions. * Fix tests now that errors have line numbers. * Tweak tests for async/await a bit. * Get rid of remaining XXX issues. is_generator_return_type() now takes an extra is_coroutine flag. * Move PEP 492 nodes back where they belong. * Respond to code review. * Add tests expecting errors from async for/with. * Test that await <generator> is an error. * Verify that `yield from` does not accept coroutines. This revealed a spurious error "Function does not return a value", fixed that. * Disallow return value in generator declared as -> Iterator. * Fix typo in comment. * Refactor visit_with_stmt() into separate helper methods for async and regular. Also Use get_generator_return_type() instead of manually unpacking the value. * Fix lint error. Correct comment about default ts/tr. * Improve errors when __aenter__/__aexit__ are not async. With tests. * Refactor: move all extraction of T from Awaitable[T] to a single helper. * Follow __await__ to extract t from subclass of Awaitable[t]. * Make get_generator_return_type() default to AnyType() (i.e. as it was). * Fix test to match reverting get_generator_return_type() to default to Any. * Rename get_awaitable_return_type() to check_awaitable_expr(), update docstring.
1 parent 7774f40 commit dcc3b86

14 files changed

+607
-52
lines changed

mypy/checker.py

+172-37
Large diffs are not rendered by default.

mypy/fastparse.py

+41-7
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,12 @@
1414
UnaryExpr, FuncExpr, ComparisonExpr,
1515
StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension,
1616
SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument,
17+
AwaitExpr,
1718
ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2
1819
)
19-
from mypy.types import Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType
20+
from mypy.types import (
21+
Type, CallableType, FunctionLike, AnyType, UnboundType, TupleType, TypeList, EllipsisType,
22+
)
2023
from mypy import defaults
2124
from mypy import experiments
2225
from mypy.errors import Errors
@@ -242,6 +245,17 @@ def visit_Module(self, mod: ast35.Module) -> Node:
242245
# arg? kwarg, expr* defaults)
243246
@with_line
244247
def visit_FunctionDef(self, n: ast35.FunctionDef) -> Node:
248+
return self.do_func_def(n)
249+
250+
# AsyncFunctionDef(identifier name, arguments args,
251+
# stmt* body, expr* decorator_list, expr? returns, string? type_comment)
252+
@with_line
253+
def visit_AsyncFunctionDef(self, n: ast35.AsyncFunctionDef) -> Node:
254+
return self.do_func_def(n, is_coroutine=True)
255+
256+
def do_func_def(self, n: Union[ast35.FunctionDef, ast35.AsyncFunctionDef],
257+
is_coroutine: bool = False) -> Node:
258+
"""Helper shared between visit_FunctionDef and visit_AsyncFunctionDef."""
245259
args = self.transform_args(n.args, n.lineno)
246260

247261
arg_kinds = [arg.kind for arg in args]
@@ -285,6 +299,9 @@ def visit_FunctionDef(self, n: ast35.FunctionDef) -> Node:
285299
args,
286300
self.as_block(n.body, n.lineno),
287301
func_type)
302+
if is_coroutine:
303+
# A coroutine is also a generator, mostly for internal reasons.
304+
func_def.is_generator = func_def.is_coroutine = True
288305
if func_type is not None:
289306
func_type.definition = func_def
290307
func_type.line = n.lineno
@@ -345,9 +362,6 @@ def make_argument(arg: ast35.arg, default: Optional[ast35.expr], kind: int) -> A
345362

346363
return new_args
347364

348-
# TODO: AsyncFunctionDef(identifier name, arguments args,
349-
# stmt* body, expr* decorator_list, expr? returns, string? type_comment)
350-
351365
def stringify_name(self, n: ast35.AST) -> str:
352366
if isinstance(n, ast35.Name):
353367
return n.id
@@ -419,7 +433,16 @@ def visit_For(self, n: ast35.For) -> Node:
419433
self.as_block(n.body, n.lineno),
420434
self.as_block(n.orelse, n.lineno))
421435

422-
# TODO: AsyncFor(expr target, expr iter, stmt* body, stmt* orelse)
436+
# AsyncFor(expr target, expr iter, stmt* body, stmt* orelse)
437+
@with_line
438+
def visit_AsyncFor(self, n: ast35.AsyncFor) -> Node:
439+
r = ForStmt(self.visit(n.target),
440+
self.visit(n.iter),
441+
self.as_block(n.body, n.lineno),
442+
self.as_block(n.orelse, n.lineno))
443+
r.is_async = True
444+
return r
445+
423446
# While(expr test, stmt* body, stmt* orelse)
424447
@with_line
425448
def visit_While(self, n: ast35.While) -> Node:
@@ -441,7 +464,14 @@ def visit_With(self, n: ast35.With) -> Node:
441464
[self.visit(i.optional_vars) for i in n.items],
442465
self.as_block(n.body, n.lineno))
443466

444-
# TODO: AsyncWith(withitem* items, stmt* body)
467+
# AsyncWith(withitem* items, stmt* body)
468+
@with_line
469+
def visit_AsyncWith(self, n: ast35.AsyncWith) -> Node:
470+
r = WithStmt([self.visit(i.context_expr) for i in n.items],
471+
[self.visit(i.optional_vars) for i in n.items],
472+
self.as_block(n.body, n.lineno))
473+
r.is_async = True
474+
return r
445475

446476
# Raise(expr? exc, expr? cause)
447477
@with_line
@@ -628,7 +658,11 @@ def visit_GeneratorExp(self, n: ast35.GeneratorExp) -> GeneratorExpr:
628658
iters,
629659
ifs_list)
630660

631-
# TODO: Await(expr value)
661+
# Await(expr value)
662+
@with_line
663+
def visit_Await(self, n: ast35.Await) -> Node:
664+
v = self.visit(n.value)
665+
return AwaitExpr(v)
632666

633667
# Yield(expr? value)
634668
@with_line

mypy/messages.py

+6
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
INCOMPATIBLE_TYPES = 'Incompatible types'
4141
INCOMPATIBLE_TYPES_IN_ASSIGNMENT = 'Incompatible types in assignment'
4242
INCOMPATIBLE_REDEFINITION = 'Incompatible redefinition'
43+
INCOMPATIBLE_TYPES_IN_AWAIT = 'Incompatible types in await'
44+
INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER = 'Incompatible types in "async with" for __aenter__'
45+
INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT = 'Incompatible types in "async with" for __aexit__'
46+
INCOMPATIBLE_TYPES_IN_ASYNC_FOR = 'Incompatible types in "async for"'
47+
4348
INCOMPATIBLE_TYPES_IN_YIELD = 'Incompatible types in yield'
4449
INCOMPATIBLE_TYPES_IN_YIELD_FROM = 'Incompatible types in "yield from"'
4550
INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION = 'Incompatible types in string interpolation'
@@ -57,6 +62,7 @@
5762
INCOMPATIBLE_VALUE_TYPE = 'Incompatible dictionary value type'
5863
NEED_ANNOTATION_FOR_VAR = 'Need type annotation for variable'
5964
ITERABLE_EXPECTED = 'Iterable expected'
65+
ASYNC_ITERABLE_EXPECTED = 'AsyncIterable expected'
6066
INCOMPATIBLE_TYPES_IN_FOR = 'Incompatible types in for statement'
6167
INCOMPATIBLE_ARRAY_VAR_ARGS = 'Incompatible variable arguments in call'
6268
INVALID_SLICE_INDEX = 'Slice index must be an integer or None'

mypy/nodes.py

+17
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ class FuncItem(FuncBase):
416416
# Is this an overload variant of function with more than one overload variant?
417417
is_overload = False
418418
is_generator = False # Contains a yield statement?
419+
is_coroutine = False # Defined using 'async def' syntax?
419420
is_static = False # Uses @staticmethod?
420421
is_class = False # Uses @classmethod?
421422
# Variants of function with type variables with values expanded
@@ -486,6 +487,7 @@ def serialize(self) -> JsonDict:
486487
'is_property': self.is_property,
487488
'is_overload': self.is_overload,
488489
'is_generator': self.is_generator,
490+
'is_coroutine': self.is_coroutine,
489491
'is_static': self.is_static,
490492
'is_class': self.is_class,
491493
'is_decorated': self.is_decorated,
@@ -507,6 +509,7 @@ def deserialize(cls, data: JsonDict) -> 'FuncDef':
507509
ret.is_property = data['is_property']
508510
ret.is_overload = data['is_overload']
509511
ret.is_generator = data['is_generator']
512+
ret.is_coroutine = data['is_coroutine']
510513
ret.is_static = data['is_static']
511514
ret.is_class = data['is_class']
512515
ret.is_decorated = data['is_decorated']
@@ -798,6 +801,7 @@ class ForStmt(Statement):
798801
expr = None # type: Expression
799802
body = None # type: Block
800803
else_body = None # type: Block
804+
is_async = False # True if `async for ...` (PEP 492, Python 3.5)
801805

802806
def __init__(self, index: Expression, expr: Expression, body: Block,
803807
else_body: Block) -> None:
@@ -908,6 +912,7 @@ class WithStmt(Statement):
908912
expr = None # type: List[Expression]
909913
target = None # type: List[Expression]
910914
body = None # type: Block
915+
is_async = False # True if `async with ...` (PEP 492, Python 3.5)
911916

912917
def __init__(self, expr: List[Expression], target: List[Expression],
913918
body: Block) -> None:
@@ -1705,6 +1710,18 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
17051710
return visitor.visit__promote_expr(self)
17061711

17071712

1713+
class AwaitExpr(Node):
1714+
"""Await expression (await ...)."""
1715+
1716+
expr = None # type: Node
1717+
1718+
def __init__(self, expr: Node) -> None:
1719+
self.expr = expr
1720+
1721+
def accept(self, visitor: NodeVisitor[T]) -> T:
1722+
return visitor.visit_await_expr(self)
1723+
1724+
17081725
# Constants
17091726

17101727

mypy/parse.py

+4
Original file line numberDiff line numberDiff line change
@@ -957,6 +957,10 @@ def parse_statement(self) -> Tuple[Node, bool]:
957957
stmt = self.parse_exec_stmt()
958958
else:
959959
stmt = self.parse_expression_or_assignment()
960+
if ts == 'async' and self.current_str() == 'def':
961+
self.parse_error_at(self.current(),
962+
reason='Use --fast-parser to parse code using "async def"')
963+
raise ParseError()
960964
if stmt is not None:
961965
stmt.set_line(t)
962966
return stmt, is_simple

mypy/semanal.py

+25-5
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,16 @@
6262
ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases,
6363
YieldFromExpr, NamedTupleExpr, NonlocalDecl,
6464
SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr,
65-
YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, COVARIANT, CONTRAVARIANT,
65+
YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr,
6666
IntExpr, FloatExpr, UnicodeExpr,
67-
INVARIANT, UNBOUND_IMPORTED
67+
COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED,
6868
)
6969
from mypy.visitor import NodeVisitor
7070
from mypy.traverser import TraverserVisitor
7171
from mypy.errors import Errors, report_internal_error
7272
from mypy.types import (
7373
NoneTyp, CallableType, Overloaded, Instance, Type, TypeVarType, AnyType,
74-
FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef,
74+
FunctionLike, UnboundType, TypeList, ErrorType, TypeVarDef, Void,
7575
replace_leading_arg_type, TupleType, UnionType, StarType, EllipsisType
7676
)
7777
from mypy.nodes import function_type, implicit_module_attrs
@@ -314,6 +314,13 @@ def visit_func_def(self, defn: FuncDef) -> None:
314314
# Second phase of analysis for function.
315315
self.errors.push_function(defn.name())
316316
self.analyze_function(defn)
317+
if defn.is_coroutine and isinstance(defn.type, CallableType):
318+
# A coroutine defined as `async def foo(...) -> T: ...`
319+
# has external return type `Awaitable[T]`.
320+
defn.type = defn.type.copy_modified(
321+
ret_type=Instance(
322+
self.named_type_or_none('typing.Awaitable').type,
323+
[defn.type.ret_type]))
317324
self.errors.pop_function()
318325

319326
def prepare_method_signature(self, func: FuncDef) -> None:
@@ -1821,7 +1828,10 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> None:
18211828
if not self.is_func_scope(): # not sure
18221829
self.fail("'yield from' outside function", e, True, blocker=True)
18231830
else:
1824-
self.function_stack[-1].is_generator = True
1831+
if self.function_stack[-1].is_coroutine:
1832+
self.fail("'yield from' in async function", e, True, blocker=True)
1833+
else:
1834+
self.function_stack[-1].is_generator = True
18251835
if e.expr:
18261836
e.expr.accept(self)
18271837

@@ -2074,10 +2084,20 @@ def visit_yield_expr(self, expr: YieldExpr) -> None:
20742084
if not self.is_func_scope():
20752085
self.fail("'yield' outside function", expr, True, blocker=True)
20762086
else:
2077-
self.function_stack[-1].is_generator = True
2087+
if self.function_stack[-1].is_coroutine:
2088+
self.fail("'yield' in async function", expr, True, blocker=True)
2089+
else:
2090+
self.function_stack[-1].is_generator = True
20782091
if expr.expr:
20792092
expr.expr.accept(self)
20802093

2094+
def visit_await_expr(self, expr: AwaitExpr) -> None:
2095+
if not self.is_func_scope():
2096+
self.fail("'await' outside function", expr)
2097+
elif not self.function_stack[-1].is_coroutine:
2098+
self.fail("'await' outside coroutine ('async def')", expr)
2099+
expr.expr.accept(self)
2100+
20812101
#
20822102
# Helpers
20832103
#

mypy/strconv.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,10 @@ def visit_while_stmt(self, o):
199199
return self.dump(a, o)
200200

201201
def visit_for_stmt(self, o):
202-
a = [o.index]
203-
a.extend([o.expr, o.body])
202+
a = []
203+
if o.is_async:
204+
a.append(('Async', ''))
205+
a.extend([o.index, o.expr, o.body])
204206
if o.else_body:
205207
a.append(('Else', o.else_body.body))
206208
return self.dump(a, o)
@@ -243,6 +245,9 @@ def visit_yield_from_stmt(self, o):
243245
def visit_yield_expr(self, o):
244246
return self.dump([o.expr], o)
245247

248+
def visit_await_expr(self, o):
249+
return self.dump([o.expr], o)
250+
246251
def visit_del_stmt(self, o):
247252
return self.dump([o.expr], o)
248253

@@ -264,6 +269,8 @@ def visit_try_stmt(self, o):
264269

265270
def visit_with_stmt(self, o):
266271
a = []
272+
if o.is_async:
273+
a.append(('Async', ''))
267274
for i in range(len(o.expr)):
268275
a.append(('Expr', [o.expr[i]]))
269276
if o.target[i]:

mypy/test/testcheck.py

+1
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
'check-optional.test',
6464
'check-fastparse.test',
6565
'check-warnings.test',
66+
'check-async-await.test',
6667
]
6768

6869

mypy/treetransform.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
ComparisonExpr, TempNode, StarExpr,
2020
YieldFromExpr, NamedTupleExpr, NonlocalDecl, SetComprehension,
2121
DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr,
22-
YieldExpr, ExecStmt, Argument, BackquoteExpr
22+
YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr,
2323
)
2424
from mypy.types import Type, FunctionLike, Instance
2525
from mypy.visitor import NodeVisitor
@@ -339,6 +339,9 @@ def visit_yield_from_expr(self, node: YieldFromExpr) -> Node:
339339
def visit_yield_expr(self, node: YieldExpr) -> Node:
340340
return YieldExpr(self.node(node.expr))
341341

342+
def visit_await_expr(self, node: AwaitExpr) -> Node:
343+
return AwaitExpr(self.node(node.expr))
344+
342345
def visit_call_expr(self, node: CallExpr) -> Node:
343346
return CallExpr(self.node(node.callee),
344347
self.nodes(node.args),

mypy/visitor.py

+3
Original file line numberDiff line numberDiff line change
@@ -228,5 +228,8 @@ def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T:
228228
def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> T:
229229
pass
230230

231+
def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T:
232+
pass
233+
231234
def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T:
232235
pass

0 commit comments

Comments
 (0)