diff --git a/src/prompt_toolkit/application/_compat.py b/src/prompt_toolkit/application/_compat.py new file mode 100644 index 000000000..f689713d3 --- /dev/null +++ b/src/prompt_toolkit/application/_compat.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import asyncio +import sys +from collections.abc import Callable, Coroutine +from typing import Any, TypeVar + +__all__ = ["EventLoop", "asyncio_run"] + +_T = TypeVar("_T") + +if sys.version_info >= (3, 13): + from asyncio import EventLoop +elif sys.platform == "win32": + from asyncio import ProactorEventLoop as EventLoop +else: + from asyncio import SelectorEventLoop as EventLoop + +if sys.version_info >= (3, 12): + asyncio_run = asyncio.run +elif sys.version_info >= (3, 11): + + def asyncio_run( + main: Coroutine[Any, Any, _T], + *, + debug: bool = False, + loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None, + ) -> _T: + # asyncio.run from Python 3.12 + # https://docs.python.org/3/license.html#psf-license + with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner: + return runner.run(main) + +else: + # modified version of asyncio.run from Python 3.10 to add loop_factory kwarg + # https://docs.python.org/3/license.html#psf-license + def asyncio_run( + main: Coroutine[Any, Any, _T], + *, + debug: bool = False, + loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None, + ) -> _T: + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError( + "asyncio.run() cannot be called from a running event loop" + ) + + if not asyncio.iscoroutine(main): + raise ValueError(f"a coroutine was expected, got {main!r}") + + if loop_factory is None: + loop = asyncio.new_event_loop() + else: + loop = loop_factory() + try: + if loop_factory is None: + asyncio.set_event_loop(loop) + if debug is not None: + loop.set_debug(debug) + return loop.run_until_complete(main) + finally: + try: + _cancel_all_tasks(loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + if sys.version_info >= (3, 9): + loop.run_until_complete(loop.shutdown_default_executor()) + finally: + if loop_factory is None: + asyncio.set_event_loop(None) + loop.close() + + def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None: + to_cancel = asyncio.all_tasks(loop) + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) diff --git a/src/prompt_toolkit/application/application.py b/src/prompt_toolkit/application/application.py index d93c24398..9456c327e 100644 --- a/src/prompt_toolkit/application/application.py +++ b/src/prompt_toolkit/application/application.py @@ -86,6 +86,7 @@ ) from prompt_toolkit.utils import Event, in_main_thread +from ._compat import EventLoop, asyncio_run from .current import get_app_session, set_app from .run_in_terminal import in_terminal, run_in_terminal @@ -971,14 +972,9 @@ def _called_from_ipython() -> bool: return False if inputhook is not None: - # Create new event loop with given input hook and run the app. - # In Python 3.12, we can use asyncio.run(loop_factory=...) - # For now, use `run_until_complete()`. - loop = new_eventloop_with_inputhook(inputhook) - result = loop.run_until_complete(coro) - loop.run_until_complete(loop.shutdown_asyncgens()) - loop.close() - return result + return asyncio_run( + coro, loop_factory=lambda: new_eventloop_with_inputhook(inputhook) + ) elif _called_from_ipython(): # workaround to make input hooks work for IPython until @@ -992,14 +988,14 @@ def _called_from_ipython() -> bool: loop = asyncio.get_event_loop() except RuntimeError: # No loop installed. Run like usual. - return asyncio.run(coro) + return asyncio_run(coro, loop_factory=EventLoop) else: # Use existing loop. return loop.run_until_complete(coro) else: # No loop installed. Run like usual. - return asyncio.run(coro) + return asyncio_run(coro, loop_factory=EventLoop) def _handle_exception( self, loop: AbstractEventLoop, context: dict[str, Any]