Skip to content

Commit adc80a5

Browse files
bpo-44963: Implement send() and throw() methods for anext_awaitable objects (GH-27955)
Co-authored-by: Yury Selivanov <[email protected]> (cherry picked from commit 533e725) Co-authored-by: Pablo Galindo Salgado <[email protected]>
1 parent af8c781 commit adc80a5

File tree

3 files changed

+270
-22
lines changed

3 files changed

+270
-22
lines changed

Lib/test/test_asyncgen.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import inspect
22
import types
33
import unittest
4+
import contextlib
45

56
from test.support.import_helper import import_module
67
from test.support import gc_collect
78
asyncio = import_module("asyncio")
89

910

11+
_no_default = object()
12+
13+
1014
class AwaitException(Exception):
1115
pass
1216

@@ -45,6 +49,37 @@ async def iterate():
4549
return run_until_complete(iterate())
4650

4751

52+
def py_anext(iterator, default=_no_default):
53+
"""Pure-Python implementation of anext() for testing purposes.
54+
55+
Closely matches the builtin anext() C implementation.
56+
Can be used to compare the built-in implementation of the inner
57+
coroutines machinery to C-implementation of __anext__() and send()
58+
or throw() on the returned generator.
59+
"""
60+
61+
try:
62+
__anext__ = type(iterator).__anext__
63+
except AttributeError:
64+
raise TypeError(f'{iterator!r} is not an async iterator')
65+
66+
if default is _no_default:
67+
return __anext__(iterator)
68+
69+
async def anext_impl():
70+
try:
71+
# The C code is way more low-level than this, as it implements
72+
# all methods of the iterator protocol. In this implementation
73+
# we're relying on higher-level coroutine concepts, but that's
74+
# exactly what we want -- crosstest pure-Python high-level
75+
# implementation and low-level C anext() iterators.
76+
return await __anext__(iterator)
77+
except StopAsyncIteration:
78+
return default
79+
80+
return anext_impl()
81+
82+
4883
class AsyncGenSyntaxTest(unittest.TestCase):
4984

5085
def test_async_gen_syntax_01(self):
@@ -374,6 +409,12 @@ def tearDown(self):
374409
asyncio.set_event_loop_policy(None)
375410

376411
def check_async_iterator_anext(self, ait_class):
412+
with self.subTest(anext="pure-Python"):
413+
self._check_async_iterator_anext(ait_class, py_anext)
414+
with self.subTest(anext="builtin"):
415+
self._check_async_iterator_anext(ait_class, anext)
416+
417+
def _check_async_iterator_anext(self, ait_class, anext):
377418
g = ait_class()
378419
async def consume():
379420
results = []
@@ -406,6 +447,24 @@ async def test_2():
406447
result = self.loop.run_until_complete(test_2())
407448
self.assertEqual(result, "completed")
408449

450+
def test_send():
451+
p = ait_class()
452+
obj = anext(p, "completed")
453+
with self.assertRaises(StopIteration):
454+
with contextlib.closing(obj.__await__()) as g:
455+
g.send(None)
456+
457+
test_send()
458+
459+
async def test_throw():
460+
p = ait_class()
461+
obj = anext(p, "completed")
462+
self.assertRaises(SyntaxError, obj.throw, SyntaxError)
463+
return "completed"
464+
465+
result = self.loop.run_until_complete(test_throw())
466+
self.assertEqual(result, "completed")
467+
409468
def test_async_generator_anext(self):
410469
async def agen():
411470
yield 1
@@ -569,6 +628,119 @@ async def do_test():
569628
result = self.loop.run_until_complete(do_test())
570629
self.assertEqual(result, "completed")
571630

631+
def test_anext_iter(self):
632+
@types.coroutine
633+
def _async_yield(v):
634+
return (yield v)
635+
636+
class MyError(Exception):
637+
pass
638+
639+
async def agenfn():
640+
try:
641+
await _async_yield(1)
642+
except MyError:
643+
await _async_yield(2)
644+
return
645+
yield
646+
647+
def test1(anext):
648+
agen = agenfn()
649+
with contextlib.closing(anext(agen, "default").__await__()) as g:
650+
self.assertEqual(g.send(None), 1)
651+
self.assertEqual(g.throw(MyError, MyError(), None), 2)
652+
try:
653+
g.send(None)
654+
except StopIteration as e:
655+
err = e
656+
else:
657+
self.fail('StopIteration was not raised')
658+
self.assertEqual(err.value, "default")
659+
660+
def test2(anext):
661+
agen = agenfn()
662+
with contextlib.closing(anext(agen, "default").__await__()) as g:
663+
self.assertEqual(g.send(None), 1)
664+
self.assertEqual(g.throw(MyError, MyError(), None), 2)
665+
with self.assertRaises(MyError):
666+
g.throw(MyError, MyError(), None)
667+
668+
def test3(anext):
669+
agen = agenfn()
670+
with contextlib.closing(anext(agen, "default").__await__()) as g:
671+
self.assertEqual(g.send(None), 1)
672+
g.close()
673+
with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
674+
self.assertEqual(g.send(None), 1)
675+
676+
def test4(anext):
677+
@types.coroutine
678+
def _async_yield(v):
679+
yield v * 10
680+
return (yield (v * 10 + 1))
681+
682+
async def agenfn():
683+
try:
684+
await _async_yield(1)
685+
except MyError:
686+
await _async_yield(2)
687+
return
688+
yield
689+
690+
agen = agenfn()
691+
with contextlib.closing(anext(agen, "default").__await__()) as g:
692+
self.assertEqual(g.send(None), 10)
693+
self.assertEqual(g.throw(MyError, MyError(), None), 20)
694+
with self.assertRaisesRegex(MyError, 'val'):
695+
g.throw(MyError, MyError('val'), None)
696+
697+
def test5(anext):
698+
@types.coroutine
699+
def _async_yield(v):
700+
yield v * 10
701+
return (yield (v * 10 + 1))
702+
703+
async def agenfn():
704+
try:
705+
await _async_yield(1)
706+
except MyError:
707+
return
708+
yield 'aaa'
709+
710+
agen = agenfn()
711+
with contextlib.closing(anext(agen, "default").__await__()) as g:
712+
self.assertEqual(g.send(None), 10)
713+
with self.assertRaisesRegex(StopIteration, 'default'):
714+
g.throw(MyError, MyError(), None)
715+
716+
def test6(anext):
717+
@types.coroutine
718+
def _async_yield(v):
719+
yield v * 10
720+
return (yield (v * 10 + 1))
721+
722+
async def agenfn():
723+
await _async_yield(1)
724+
yield 'aaa'
725+
726+
agen = agenfn()
727+
with contextlib.closing(anext(agen, "default").__await__()) as g:
728+
with self.assertRaises(MyError):
729+
g.throw(MyError, MyError(), None)
730+
731+
def run_test(test):
732+
with self.subTest('pure-Python anext()'):
733+
test(py_anext)
734+
with self.subTest('builtin anext()'):
735+
test(anext)
736+
737+
run_test(test1)
738+
run_test(test2)
739+
run_test(test3)
740+
run_test(test4)
741+
run_test(test5)
742+
run_test(test6)
743+
572744
def test_aiter_bad_args(self):
573745
async def gen():
574746
yield 1
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Implement ``send()`` and ``throw()`` methods for ``anext_awaitable``
2+
objects. Patch by Pablo Galindo.

Objects/iterobject.c

Lines changed: 96 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,36 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
313313
return 0;
314314
}
315315

316+
static PyObject *
317+
anextawaitable_getiter(anextawaitableobject *obj)
318+
{
319+
assert(obj->wrapped != NULL);
320+
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
321+
if (awaitable == NULL) {
322+
return NULL;
323+
}
324+
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
325+
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
326+
* or an iterator. Of these, only coroutines lack tp_iternext.
327+
*/
328+
assert(PyCoro_CheckExact(awaitable));
329+
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
330+
PyObject *new_awaitable = getter(awaitable);
331+
if (new_awaitable == NULL) {
332+
Py_DECREF(awaitable);
333+
return NULL;
334+
}
335+
Py_SETREF(awaitable, new_awaitable);
336+
if (!PyIter_Check(awaitable)) {
337+
PyErr_SetString(PyExc_TypeError,
338+
"__await__ returned a non-iterable");
339+
Py_DECREF(awaitable);
340+
return NULL;
341+
}
342+
}
343+
return awaitable;
344+
}
345+
316346
static PyObject *
317347
anextawaitable_iternext(anextawaitableobject *obj)
318348
{
@@ -336,30 +366,10 @@ anextawaitable_iternext(anextawaitableobject *obj)
336366
* Then `await anext(gen)` can just call
337367
* gen.__anext__().__next__()
338368
*/
339-
assert(obj->wrapped != NULL);
340-
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
369+
PyObject *awaitable = anextawaitable_getiter(obj);
341370
if (awaitable == NULL) {
342371
return NULL;
343372
}
344-
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
345-
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
346-
* or an iterator. Of these, only coroutines lack tp_iternext.
347-
*/
348-
assert(PyCoro_CheckExact(awaitable));
349-
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
350-
PyObject *new_awaitable = getter(awaitable);
351-
if (new_awaitable == NULL) {
352-
Py_DECREF(awaitable);
353-
return NULL;
354-
}
355-
Py_SETREF(awaitable, new_awaitable);
356-
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
357-
PyErr_SetString(PyExc_TypeError,
358-
"__await__ returned a non-iterable");
359-
Py_DECREF(awaitable);
360-
return NULL;
361-
}
362-
}
363373
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
364374
Py_DECREF(awaitable);
365375
if (result != NULL) {
@@ -371,6 +381,70 @@ anextawaitable_iternext(anextawaitableobject *obj)
371381
return NULL;
372382
}
373383

384+
385+
static PyObject *
386+
anextawaitable_proxy(anextawaitableobject *obj, char *meth, PyObject *arg) {
387+
PyObject *awaitable = anextawaitable_getiter(obj);
388+
if (awaitable == NULL) {
389+
return NULL;
390+
}
391+
PyObject *ret = PyObject_CallMethod(awaitable, meth, "O", arg);
392+
Py_DECREF(awaitable);
393+
if (ret != NULL) {
394+
return ret;
395+
}
396+
if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
397+
/* `anextawaitableobject` is only used by `anext()` when
398+
* a default value is provided. So when we have a StopAsyncIteration
399+
* exception we replace it with a `StopIteration(default)`, as if
400+
* it was the return value of `__anext__()` coroutine.
401+
*/
402+
_PyGen_SetStopIterationValue(obj->default_value);
403+
}
404+
return NULL;
405+
}
406+
407+
408+
static PyObject *
409+
anextawaitable_send(anextawaitableobject *obj, PyObject *arg) {
410+
return anextawaitable_proxy(obj, "send", arg);
411+
}
412+
413+
414+
static PyObject *
415+
anextawaitable_throw(anextawaitableobject *obj, PyObject *arg) {
416+
return anextawaitable_proxy(obj, "throw", arg);
417+
}
418+
419+
420+
static PyObject *
421+
anextawaitable_close(anextawaitableobject *obj, PyObject *arg) {
422+
return anextawaitable_proxy(obj, "close", arg);
423+
}
424+
425+
426+
PyDoc_STRVAR(send_doc,
427+
"send(arg) -> send 'arg' into the wrapped iterator,\n\
428+
return next yielded value or raise StopIteration.");
429+
430+
431+
PyDoc_STRVAR(throw_doc,
432+
"throw(typ[,val[,tb]]) -> raise exception in the wrapped iterator,\n\
433+
return next yielded value or raise StopIteration.");
434+
435+
436+
PyDoc_STRVAR(close_doc,
437+
"close() -> raise GeneratorExit inside generator.");
438+
439+
440+
static PyMethodDef anextawaitable_methods[] = {
441+
{"send",(PyCFunction)anextawaitable_send, METH_O, send_doc},
442+
{"throw",(PyCFunction)anextawaitable_throw, METH_VARARGS, throw_doc},
443+
{"close",(PyCFunction)anextawaitable_close, METH_VARARGS, close_doc},
444+
{NULL, NULL} /* Sentinel */
445+
};
446+
447+
374448
static PyAsyncMethods anextawaitable_as_async = {
375449
PyObject_SelfIter, /* am_await */
376450
0, /* am_aiter */
@@ -407,7 +481,7 @@ PyTypeObject _PyAnextAwaitable_Type = {
407481
0, /* tp_weaklistoffset */
408482
PyObject_SelfIter, /* tp_iter */
409483
(unaryfunc)anextawaitable_iternext, /* tp_iternext */
410-
0, /* tp_methods */
484+
anextawaitable_methods, /* tp_methods */
411485
};
412486

413487
PyObject *

0 commit comments

Comments
 (0)