Skip to content

Commit a5024a2

Browse files
GH-96764: rewrite asyncio.wait_for to use asyncio.timeout (#98518)
Changes `asyncio.wait_for` to use `asyncio.timeout` as its underlying implementation.
1 parent 226484e commit a5024a2

File tree

4 files changed

+133
-79
lines changed

4 files changed

+133
-79
lines changed

Lib/asyncio/tasks.py

+29-48
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from . import events
2525
from . import exceptions
2626
from . import futures
27+
from . import timeouts
2728
from .coroutines import _is_coroutine
2829

2930
# Helper to generate new task names
@@ -437,65 +438,44 @@ async def wait_for(fut, timeout):
437438
438439
If the wait is cancelled, the task is also cancelled.
439440
441+
If the task supresses the cancellation and returns a value instead,
442+
that value is returned.
443+
440444
This function is a coroutine.
441445
"""
442-
loop = events.get_running_loop()
446+
# The special case for timeout <= 0 is for the following case:
447+
#
448+
# async def test_waitfor():
449+
# func_started = False
450+
#
451+
# async def func():
452+
# nonlocal func_started
453+
# func_started = True
454+
#
455+
# try:
456+
# await asyncio.wait_for(func(), 0)
457+
# except asyncio.TimeoutError:
458+
# assert not func_started
459+
# else:
460+
# assert False
461+
#
462+
# asyncio.run(test_waitfor())
443463

444-
if timeout is None:
445-
return await fut
446464

447-
if timeout <= 0:
448-
fut = ensure_future(fut, loop=loop)
465+
if timeout is not None and timeout <= 0:
466+
fut = ensure_future(fut)
449467

450468
if fut.done():
451469
return fut.result()
452470

453-
await _cancel_and_wait(fut, loop=loop)
471+
await _cancel_and_wait(fut)
454472
try:
455473
return fut.result()
456474
except exceptions.CancelledError as exc:
457-
raise exceptions.TimeoutError() from exc
458-
459-
waiter = loop.create_future()
460-
timeout_handle = loop.call_later(timeout, _release_waiter, waiter)
461-
cb = functools.partial(_release_waiter, waiter)
462-
463-
fut = ensure_future(fut, loop=loop)
464-
fut.add_done_callback(cb)
465-
466-
try:
467-
# wait until the future completes or the timeout
468-
try:
469-
await waiter
470-
except exceptions.CancelledError:
471-
if fut.done():
472-
return fut.result()
473-
else:
474-
fut.remove_done_callback(cb)
475-
# We must ensure that the task is not running
476-
# after wait_for() returns.
477-
# See https://bugs.python.org/issue32751
478-
await _cancel_and_wait(fut, loop=loop)
479-
raise
480-
481-
if fut.done():
482-
return fut.result()
483-
else:
484-
fut.remove_done_callback(cb)
485-
# We must ensure that the task is not running
486-
# after wait_for() returns.
487-
# See https://bugs.python.org/issue32751
488-
await _cancel_and_wait(fut, loop=loop)
489-
# In case task cancellation failed with some
490-
# exception, we should re-raise it
491-
# See https://bugs.python.org/issue40607
492-
try:
493-
return fut.result()
494-
except exceptions.CancelledError as exc:
495-
raise exceptions.TimeoutError() from exc
496-
finally:
497-
timeout_handle.cancel()
475+
raise TimeoutError from exc
498476

477+
async with timeouts.timeout(timeout):
478+
return await fut
499479

500480
async def _wait(fs, timeout, return_when, loop):
501481
"""Internal helper for wait().
@@ -541,9 +521,10 @@ def _on_completion(f):
541521
return done, pending
542522

543523

544-
async def _cancel_and_wait(fut, loop):
524+
async def _cancel_and_wait(fut):
545525
"""Cancel the *fut* future or task and wait until it completes."""
546526

527+
loop = events.get_running_loop()
547528
waiter = loop.create_future()
548529
cb = functools.partial(_release_waiter, waiter)
549530
fut.add_done_callback(cb)

Lib/test/test_asyncio/test_futures2.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,9 @@ async def test_recursive_repr_for_pending_tasks(self):
8686
async def func():
8787
return asyncio.all_tasks()
8888

89-
# The repr() call should not raise RecursiveError at first.
90-
# The check for returned string is not very reliable but
91-
# exact comparison for the whole string is even weaker.
92-
self.assertIn('...', repr(await asyncio.wait_for(func(), timeout=10)))
89+
# The repr() call should not raise RecursionError at first.
90+
waiter = await asyncio.wait_for(asyncio.Task(func()),timeout=10)
91+
self.assertIn('...', repr(waiter))
9392

9493

9594
if __name__ == '__main__':

Lib/test/test_asyncio/test_waitfor.py

+100-27
Original file line numberDiff line numberDiff line change
@@ -237,33 +237,6 @@ async def inner():
237237
with self.assertRaises(FooException):
238238
await foo()
239239

240-
async def test_wait_for_self_cancellation(self):
241-
async def inner():
242-
try:
243-
await asyncio.sleep(0.3)
244-
except asyncio.CancelledError:
245-
try:
246-
await asyncio.sleep(0.3)
247-
except asyncio.CancelledError:
248-
await asyncio.sleep(0.3)
249-
250-
return 42
251-
252-
inner_task = asyncio.create_task(inner())
253-
254-
wait = asyncio.wait_for(inner_task, timeout=0.1)
255-
256-
# Test that wait_for itself is properly cancellable
257-
# even when the initial task holds up the initial cancellation.
258-
task = asyncio.create_task(wait)
259-
await asyncio.sleep(0.2)
260-
task.cancel()
261-
262-
with self.assertRaises(asyncio.CancelledError):
263-
await task
264-
265-
self.assertEqual(await inner_task, 42)
266-
267240
async def _test_cancel_wait_for(self, timeout):
268241
loop = asyncio.get_running_loop()
269242

@@ -289,6 +262,106 @@ async def test_cancel_blocking_wait_for(self):
289262
async def test_cancel_wait_for(self):
290263
await self._test_cancel_wait_for(60.0)
291264

265+
async def test_wait_for_cancel_suppressed(self):
266+
# GH-86296: Supressing CancelledError is discouraged
267+
# but if a task subpresses CancelledError and returns a value,
268+
# `wait_for` should return the value instead of raising CancelledError.
269+
# This is the same behavior as `asyncio.timeout`.
270+
271+
async def return_42():
272+
try:
273+
await asyncio.sleep(10)
274+
except asyncio.CancelledError:
275+
return 42
276+
277+
res = await asyncio.wait_for(return_42(), timeout=0.1)
278+
self.assertEqual(res, 42)
279+
280+
281+
async def test_wait_for_issue86296(self):
282+
# GH-86296: The task should get cancelled and not run to completion.
283+
# inner completes in one cycle of the event loop so it
284+
# completes before the task is cancelled.
285+
286+
async def inner():
287+
return 'done'
288+
289+
inner_task = asyncio.create_task(inner())
290+
reached_end = False
291+
292+
async def wait_for_coro():
293+
await asyncio.wait_for(inner_task, timeout=100)
294+
await asyncio.sleep(1)
295+
nonlocal reached_end
296+
reached_end = True
297+
298+
task = asyncio.create_task(wait_for_coro())
299+
self.assertFalse(task.done())
300+
# Run the task
301+
await asyncio.sleep(0)
302+
task.cancel()
303+
with self.assertRaises(asyncio.CancelledError):
304+
await task
305+
self.assertTrue(inner_task.done())
306+
self.assertEqual(await inner_task, 'done')
307+
self.assertFalse(reached_end)
308+
309+
310+
class WaitForShieldTests(unittest.IsolatedAsyncioTestCase):
311+
312+
async def test_zero_timeout(self):
313+
# `asyncio.shield` creates a new task which wraps the passed in
314+
# awaitable and shields it from cancellation so with timeout=0
315+
# the task returned by `asyncio.shield` aka shielded_task gets
316+
# cancelled immediately and the task wrapped by it is scheduled
317+
# to run.
318+
319+
async def coro():
320+
await asyncio.sleep(0.01)
321+
return 'done'
322+
323+
task = asyncio.create_task(coro())
324+
with self.assertRaises(asyncio.TimeoutError):
325+
shielded_task = asyncio.shield(task)
326+
await asyncio.wait_for(shielded_task, timeout=0)
327+
328+
# Task is running in background
329+
self.assertFalse(task.done())
330+
self.assertFalse(task.cancelled())
331+
self.assertTrue(shielded_task.cancelled())
332+
333+
# Wait for the task to complete
334+
await asyncio.sleep(0.1)
335+
self.assertTrue(task.done())
336+
337+
338+
async def test_none_timeout(self):
339+
# With timeout=None the timeout is disabled so it
340+
# runs till completion.
341+
async def coro():
342+
await asyncio.sleep(0.1)
343+
return 'done'
344+
345+
task = asyncio.create_task(coro())
346+
await asyncio.wait_for(asyncio.shield(task), timeout=None)
347+
348+
self.assertTrue(task.done())
349+
self.assertEqual(await task, "done")
350+
351+
async def test_shielded_timeout(self):
352+
# shield prevents the task from being cancelled.
353+
async def coro():
354+
await asyncio.sleep(0.1)
355+
return 'done'
356+
357+
task = asyncio.create_task(coro())
358+
with self.assertRaises(asyncio.TimeoutError):
359+
await asyncio.wait_for(asyncio.shield(task), timeout=0.01)
360+
361+
self.assertFalse(task.done())
362+
self.assertFalse(task.cancelled())
363+
self.assertEqual(await task, "done")
364+
292365

293366
if __name__ == '__main__':
294367
unittest.main()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
:func:`asyncio.wait_for` now uses :func:`asyncio.timeout` as its underlying implementation. Patch by Kumar Aditya.

0 commit comments

Comments
 (0)