Skip to content

bpo-44963: Implement send() and throw() methods for anext_awaitable objects #27955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 172 additions & 0 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import inspect
import types
import unittest
import contextlib

from test.support.import_helper import import_module
asyncio = import_module("asyncio")


_no_default = object()


class AwaitException(Exception):
pass

Expand Down Expand Up @@ -44,6 +48,37 @@ async def iterate():
return run_until_complete(iterate())


def py_anext(iterator, default=_no_default):
"""Pure-Python implementation of anext() for testing purposes.

Closely matches the builtin anext() C implementation.
Can be used to compare the built-in implementation of the inner
coroutines machinery to C-implementation of __anext__() and send()
or throw() on the returned generator.
"""

try:
__anext__ = type(iterator).__anext__
except AttributeError:
raise TypeError(f'{iterator!r} is not an async iterator')

if default is _no_default:
return __anext__(iterator)

async def anext_impl():
try:
# The C code is way more low-level than this, as it implements
# all methods of the iterator protocol. In this implementation
# we're relying on higher-level coroutine concepts, but that's
# exactly what we want -- crosstest pure-Python high-level
# implementation and low-level C anext() iterators.
return await __anext__(iterator)
except StopAsyncIteration:
return default

return anext_impl()


class AsyncGenSyntaxTest(unittest.TestCase):

def test_async_gen_syntax_01(self):
Expand Down Expand Up @@ -373,6 +408,12 @@ def tearDown(self):
asyncio.set_event_loop_policy(None)

def check_async_iterator_anext(self, ait_class):
with self.subTest(anext="pure-Python"):
self._check_async_iterator_anext(ait_class, py_anext)
with self.subTest(anext="builtin"):
self._check_async_iterator_anext(ait_class, anext)

def _check_async_iterator_anext(self, ait_class, anext):
g = ait_class()
async def consume():
results = []
Expand Down Expand Up @@ -405,6 +446,24 @@ async def test_2():
result = self.loop.run_until_complete(test_2())
self.assertEqual(result, "completed")

def test_send():
p = ait_class()
obj = anext(p, "completed")
with self.assertRaises(StopIteration):
with contextlib.closing(obj.__await__()) as g:
g.send(None)

test_send()

async def test_throw():
p = ait_class()
obj = anext(p, "completed")
self.assertRaises(SyntaxError, obj.throw, SyntaxError)
return "completed"

result = self.loop.run_until_complete(test_throw())
self.assertEqual(result, "completed")

def test_async_generator_anext(self):
async def agen():
yield 1
Expand Down Expand Up @@ -568,6 +627,119 @@ async def do_test():
result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed")

def test_anext_iter(self):
@types.coroutine
def _async_yield(v):
return (yield v)

class MyError(Exception):
pass

async def agenfn():
try:
await _async_yield(1)
except MyError:
await _async_yield(2)
return
yield

def test1(anext):
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 1)
self.assertEqual(g.throw(MyError, MyError(), None), 2)
try:
g.send(None)
except StopIteration as e:
err = e
else:
self.fail('StopIteration was not raised')
self.assertEqual(err.value, "default")

def test2(anext):
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 1)
self.assertEqual(g.throw(MyError, MyError(), None), 2)
with self.assertRaises(MyError):
g.throw(MyError, MyError(), None)

def test3(anext):
agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 1)
g.close()
with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
self.assertEqual(g.send(None), 1)

def test4(anext):
@types.coroutine
def _async_yield(v):
yield v * 10
return (yield (v * 10 + 1))

async def agenfn():
try:
await _async_yield(1)
except MyError:
await _async_yield(2)
return
yield

agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 10)
self.assertEqual(g.throw(MyError, MyError(), None), 20)
with self.assertRaisesRegex(MyError, 'val'):
g.throw(MyError, MyError('val'), None)

def test5(anext):
@types.coroutine
def _async_yield(v):
yield v * 10
return (yield (v * 10 + 1))

async def agenfn():
try:
await _async_yield(1)
except MyError:
return
yield 'aaa'

agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
self.assertEqual(g.send(None), 10)
with self.assertRaisesRegex(StopIteration, 'default'):
g.throw(MyError, MyError(), None)

def test6(anext):
@types.coroutine
def _async_yield(v):
yield v * 10
return (yield (v * 10 + 1))

async def agenfn():
await _async_yield(1)
yield 'aaa'

agen = agenfn()
with contextlib.closing(anext(agen, "default").__await__()) as g:
with self.assertRaises(MyError):
g.throw(MyError, MyError(), None)

def run_test(test):
with self.subTest('pure-Python anext()'):
test(py_anext)
with self.subTest('builtin anext()'):
test(anext)

run_test(test1)
run_test(test2)
run_test(test3)
run_test(test4)
run_test(test5)
run_test(test6)

def test_aiter_bad_args(self):
async def gen():
yield 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Implement ``send()`` and ``throw()`` methods for ``anext_awaitable``
objects. Patch by Pablo Galindo.
118 changes: 96 additions & 22 deletions Objects/iterobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,36 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
return 0;
}

static PyObject *
anextawaitable_getiter(anextawaitableobject *obj)
{
assert(obj->wrapped != NULL);
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
if (awaitable == NULL) {
return NULL;
}
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
* or an iterator. Of these, only coroutines lack tp_iternext.
*/
assert(PyCoro_CheckExact(awaitable));
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
PyObject *new_awaitable = getter(awaitable);
if (new_awaitable == NULL) {
Py_DECREF(awaitable);
return NULL;
}
Py_SETREF(awaitable, new_awaitable);
if (!PyIter_Check(awaitable)) {
PyErr_SetString(PyExc_TypeError,
"__await__ returned a non-iterable");
Py_DECREF(awaitable);
return NULL;
}
}
return awaitable;
}

static PyObject *
anextawaitable_iternext(anextawaitableobject *obj)
{
Expand All @@ -336,30 +366,10 @@ anextawaitable_iternext(anextawaitableobject *obj)
* Then `await anext(gen)` can just call
* gen.__anext__().__next__()
*/
assert(obj->wrapped != NULL);
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
PyObject *awaitable = anextawaitable_getiter(obj);
if (awaitable == NULL) {
return NULL;
}
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
* or an iterator. Of these, only coroutines lack tp_iternext.
*/
assert(PyCoro_CheckExact(awaitable));
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
PyObject *new_awaitable = getter(awaitable);
if (new_awaitable == NULL) {
Py_DECREF(awaitable);
return NULL;
}
Py_SETREF(awaitable, new_awaitable);
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
PyErr_SetString(PyExc_TypeError,
"__await__ returned a non-iterable");
Py_DECREF(awaitable);
return NULL;
}
}
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
Py_DECREF(awaitable);
if (result != NULL) {
Expand All @@ -371,6 +381,70 @@ anextawaitable_iternext(anextawaitableobject *obj)
return NULL;
}


static PyObject *
anextawaitable_proxy(anextawaitableobject *obj, char *meth, PyObject *arg) {
PyObject *awaitable = anextawaitable_getiter(obj);
if (awaitable == NULL) {
return NULL;
}
PyObject *ret = PyObject_CallMethod(awaitable, meth, "O", arg);
Py_DECREF(awaitable);
if (ret != NULL) {
return ret;
}
if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
/* `anextawaitableobject` is only used by `anext()` when
* a default value is provided. So when we have a StopAsyncIteration
* exception we replace it with a `StopIteration(default)`, as if
* it was the return value of `__anext__()` coroutine.
*/
_PyGen_SetStopIterationValue(obj->default_value);
}
return NULL;
}


static PyObject *
anextawaitable_send(anextawaitableobject *obj, PyObject *arg) {
return anextawaitable_proxy(obj, "send", arg);
}


static PyObject *
anextawaitable_throw(anextawaitableobject *obj, PyObject *arg) {
return anextawaitable_proxy(obj, "throw", arg);
}


static PyObject *
anextawaitable_close(anextawaitableobject *obj, PyObject *arg) {
return anextawaitable_proxy(obj, "close", arg);
}


PyDoc_STRVAR(send_doc,
"send(arg) -> send 'arg' into the wrapped iterator,\n\
return next yielded value or raise StopIteration.");


PyDoc_STRVAR(throw_doc,
"throw(typ[,val[,tb]]) -> raise exception in the wrapped iterator,\n\
return next yielded value or raise StopIteration.");


PyDoc_STRVAR(close_doc,
"close() -> raise GeneratorExit inside generator.");


static PyMethodDef anextawaitable_methods[] = {
{"send",(PyCFunction)anextawaitable_send, METH_O, send_doc},
{"throw",(PyCFunction)anextawaitable_throw, METH_VARARGS, throw_doc},
{"close",(PyCFunction)anextawaitable_close, METH_VARARGS, close_doc},
{NULL, NULL} /* Sentinel */
};


static PyAsyncMethods anextawaitable_as_async = {
PyObject_SelfIter, /* am_await */
0, /* am_aiter */
Expand Down Expand Up @@ -407,7 +481,7 @@ PyTypeObject _PyAnextAwaitable_Type = {
0, /* tp_weaklistoffset */
PyObject_SelfIter, /* tp_iter */
(unaryfunc)anextawaitable_iternext, /* tp_iternext */
0, /* tp_methods */
anextawaitable_methods, /* tp_methods */
};

PyObject *
Expand Down