Skip to content

Commit 3bfca83

Browse files
committed
Make sure throw/send handle StopAsyncIteration correctly
1 parent 1010e3b commit 3bfca83

File tree

2 files changed

+211
-27
lines changed

2 files changed

+211
-27
lines changed

Lib/test/test_asyncgen.py

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
asyncio = import_module("asyncio")
88

99

10+
_no_default = object()
11+
12+
1013
class AwaitException(Exception):
1114
pass
1215

@@ -45,6 +48,37 @@ async def iterate():
4548
return run_until_complete(iterate())
4649

4750

51+
def py_anext(iterator, default=_no_default):
52+
"""Pure-Python implementation of anext() for testing purposes.
53+
54+
Closely matches the builtin anext() C implementation.
55+
Can be used to compare the built-in implementation of the inner
56+
coroutines machinery to C-implementation of __anext__() and send()
57+
or throw() on the returned generator.
58+
"""
59+
60+
try:
61+
__anext__ = type(iterator).__anext__
62+
except AttributeError:
63+
raise TypeError(f'{iterator!r} is not an async iterator')
64+
65+
if default is _no_default:
66+
return __anext__(iterator)
67+
68+
async def anext_impl():
69+
try:
70+
# The C code is way more low-level than this, as it implements
71+
# all methods of the iterator protocol. In this implementation
72+
# we're relying on higher-level coroutine concepts, but that's
73+
# exactly what we want -- crosstest pure-Python high-level
74+
# implementation and low-level C anext() iterators.
75+
return await __anext__(iterator)
76+
except StopAsyncIteration:
77+
return default
78+
79+
return anext_impl()
80+
81+
4882
class AsyncGenSyntaxTest(unittest.TestCase):
4983

5084
def test_async_gen_syntax_01(self):
@@ -374,6 +408,12 @@ def tearDown(self):
374408
asyncio.set_event_loop_policy(None)
375409

376410
def check_async_iterator_anext(self, ait_class):
411+
with self.subTest(anext="pure-Python"):
412+
self._check_async_iterator_anext(ait_class, py_anext)
413+
with self.subTest(anext="builtin"):
414+
self._check_async_iterator_anext(ait_class, anext)
415+
416+
def _check_async_iterator_anext(self, ait_class, anext):
377417
g = ait_class()
378418
async def consume():
379419
results = []
@@ -409,11 +449,9 @@ async def test_2():
409449
def test_send():
410450
p = ait_class()
411451
obj = anext(p, "completed")
412-
try:
452+
with self.assertRaises(StopIteration):
413453
with contextlib.closing(obj.__await__()) as g:
414454
g.send(None)
415-
except StopIteration:
416-
pass
417455

418456
test_send()
419457

@@ -589,6 +627,119 @@ async def do_test():
589627
result = self.loop.run_until_complete(do_test())
590628
self.assertEqual(result, "completed")
591629

630+
def test_anext_iter(self):
631+
@types.coroutine
632+
def _async_yield(v):
633+
return (yield v)
634+
635+
class MyError(Exception):
636+
pass
637+
638+
async def agenfn():
639+
try:
640+
await _async_yield(1)
641+
except MyError:
642+
await _async_yield(2)
643+
return
644+
yield
645+
646+
def test1(anext):
647+
agen = agenfn()
648+
with contextlib.closing(anext(agen, "default").__await__()) as g:
649+
self.assertEqual(g.send(None), 1)
650+
self.assertEqual(g.throw(MyError, MyError(), None), 2)
651+
try:
652+
g.send(None)
653+
except StopIteration as e:
654+
err = e
655+
else:
656+
self.fail('StopIteration was not raised')
657+
self.assertEqual(err.value, "default")
658+
659+
def test2(anext):
660+
agen = agenfn()
661+
with contextlib.closing(anext(agen, "default").__await__()) as g:
662+
self.assertEqual(g.send(None), 1)
663+
self.assertEqual(g.throw(MyError, MyError(), None), 2)
664+
with self.assertRaises(MyError):
665+
g.throw(MyError, MyError(), None)
666+
667+
def test3(anext):
668+
agen = agenfn()
669+
with contextlib.closing(anext(agen, "default").__await__()) as g:
670+
self.assertEqual(g.send(None), 1)
671+
g.close()
672+
with self.assertRaisesRegex(RuntimeError, 'cannot reuse'):
673+
self.assertEqual(g.send(None), 1)
674+
675+
def test4(anext):
676+
@types.coroutine
677+
def _async_yield(v):
678+
yield v * 10
679+
return (yield (v * 10 + 1))
680+
681+
async def agenfn():
682+
try:
683+
await _async_yield(1)
684+
except MyError:
685+
await _async_yield(2)
686+
return
687+
yield
688+
689+
agen = agenfn()
690+
with contextlib.closing(anext(agen, "default").__await__()) as g:
691+
self.assertEqual(g.send(None), 10)
692+
self.assertEqual(g.throw(MyError, MyError(), None), 20)
693+
with self.assertRaisesRegex(MyError, 'val'):
694+
g.throw(MyError, MyError('val'), None)
695+
696+
def test5(anext):
697+
@types.coroutine
698+
def _async_yield(v):
699+
yield v * 10
700+
return (yield (v * 10 + 1))
701+
702+
async def agenfn():
703+
try:
704+
await _async_yield(1)
705+
except MyError:
706+
return
707+
yield 'aaa'
708+
709+
agen = agenfn()
710+
with contextlib.closing(anext(agen, "default").__await__()) as g:
711+
self.assertEqual(g.send(None), 10)
712+
with self.assertRaisesRegex(StopIteration, 'default'):
713+
g.throw(MyError, MyError(), None)
714+
715+
def test6(anext):
716+
@types.coroutine
717+
def _async_yield(v):
718+
yield v * 10
719+
return (yield (v * 10 + 1))
720+
721+
async def agenfn():
722+
await _async_yield(1)
723+
yield 'aaa'
724+
725+
agen = agenfn()
726+
with contextlib.closing(anext(agen, "default").__await__()) as g:
727+
with self.assertRaises(MyError):
728+
g.throw(MyError, MyError(), None)
729+
730+
def run_test(test):
731+
with self.subTest('pure-Python anext()'):
732+
test(py_anext)
733+
with self.subTest('builtin anext()'):
734+
test(anext)
735+
736+
run_test(test1)
737+
run_test(test2)
738+
run_test(test3)
739+
run_test(test4)
740+
run_test(test5)
741+
run_test(test6)
742+
592743
def test_aiter_bad_args(self):
593744
async def gen():
594745
yield 1

Objects/iterobject.c

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,36 @@ anextawaitable_traverse(anextawaitableobject *obj, visitproc visit, void *arg)
313313
return 0;
314314
}
315315

316+
static PyObject *
317+
anextawaitable_getiter(anextawaitableobject *obj)
318+
{
319+
assert(obj->wrapped != NULL);
320+
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
321+
if (awaitable == NULL) {
322+
return NULL;
323+
}
324+
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
325+
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
326+
* or an iterator. Of these, only coroutines lack tp_iternext.
327+
*/
328+
assert(PyCoro_CheckExact(awaitable));
329+
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
330+
PyObject *new_awaitable = getter(awaitable);
331+
if (new_awaitable == NULL) {
332+
Py_DECREF(awaitable);
333+
return NULL;
334+
}
335+
Py_SETREF(awaitable, new_awaitable);
336+
if (!PyIter_Check(awaitable)) {
337+
PyErr_SetString(PyExc_TypeError,
338+
"__await__ returned a non-iterable");
339+
Py_DECREF(awaitable);
340+
return NULL;
341+
}
342+
}
343+
return awaitable;
344+
}
345+
316346
static PyObject *
317347
anextawaitable_iternext(anextawaitableobject *obj)
318348
{
@@ -336,30 +366,10 @@ anextawaitable_iternext(anextawaitableobject *obj)
336366
* Then `await anext(gen)` can just call
337367
* gen.__anext__().__next__()
338368
*/
339-
assert(obj->wrapped != NULL);
340-
PyObject *awaitable = _PyCoro_GetAwaitableIter(obj->wrapped);
369+
PyObject *awaitable = anextawaitable_getiter(obj);
341370
if (awaitable == NULL) {
342371
return NULL;
343372
}
344-
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
345-
/* _PyCoro_GetAwaitableIter returns a Coroutine, a Generator,
346-
* or an iterator. Of these, only coroutines lack tp_iternext.
347-
*/
348-
assert(PyCoro_CheckExact(awaitable));
349-
unaryfunc getter = Py_TYPE(awaitable)->tp_as_async->am_await;
350-
PyObject *new_awaitable = getter(awaitable);
351-
if (new_awaitable == NULL) {
352-
Py_DECREF(awaitable);
353-
return NULL;
354-
}
355-
Py_SETREF(awaitable, new_awaitable);
356-
if (Py_TYPE(awaitable)->tp_iternext == NULL) {
357-
PyErr_SetString(PyExc_TypeError,
358-
"__await__ returned a non-iterable");
359-
Py_DECREF(awaitable);
360-
return NULL;
361-
}
362-
}
363373
PyObject *result = (*Py_TYPE(awaitable)->tp_iternext)(awaitable);
364374
Py_DECREF(awaitable);
365375
if (result != NULL) {
@@ -372,21 +382,44 @@ anextawaitable_iternext(anextawaitableobject *obj)
372382
}
373383

374384

385+
static PyObject *
386+
anextawaitable_proxy(anextawaitableobject *obj, char *meth, PyObject *arg) {
387+
PyObject *awaitable = anextawaitable_getiter(obj);
388+
if (awaitable == NULL) {
389+
return NULL;
390+
}
391+
PyObject *ret = PyObject_CallMethod(awaitable, meth, "O", arg);
392+
Py_DECREF(awaitable);
393+
if (ret != NULL) {
394+
return ret;
395+
}
396+
if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration)) {
397+
/* `anextawaitableobject` is only used by `anext()` when
398+
* a default value is provided. So when we have a StopAsyncIteration
399+
* exception we replace it with a `StopIteration(default)`, as if
400+
* it was the return value of `__anext__()` coroutine.
401+
*/
402+
_PyGen_SetStopIterationValue(obj->default_value);
403+
}
404+
return NULL;
405+
}
406+
407+
375408
static PyObject *
376409
anextawaitable_send(anextawaitableobject *obj, PyObject *arg) {
377-
return PyObject_CallMethod(obj->wrapped, "send", "O", arg);
410+
return anextawaitable_proxy(obj, "send", arg);
378411
}
379412

380413

381414
static PyObject *
382415
anextawaitable_throw(anextawaitableobject *obj, PyObject *arg) {
383-
return PyObject_CallMethod(obj->wrapped, "throw", "O", arg);
416+
return anextawaitable_proxy(obj, "throw", arg);
384417
}
385418

386419

387420
static PyObject *
388421
anextawaitable_close(anextawaitableobject *obj, PyObject *arg) {
389-
return PyObject_CallMethod(obj->wrapped, "close", "O", arg);
422+
return anextawaitable_proxy(obj, "close", arg);
390423
}
391424

392425

0 commit comments

Comments
 (0)