Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove implicit calls to asyncio.set_event_loop so asyncio.get_event_loop() works later #1956

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions src/prompt_toolkit/application/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import asyncio
import sys
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

__all__ = ["EventLoop", "asyncio_run"]

_T = TypeVar("_T")

if sys.version_info >= (3, 13):
from asyncio import EventLoop
elif sys.platform == "win32":
from asyncio import ProactorEventLoop as EventLoop
else:
from asyncio import SelectorEventLoop as EventLoop

if sys.version_info >= (3, 12):
asyncio_run = asyncio.run
elif sys.version_info >= (3, 11):

def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
# asyncio.run from Python 3.12
# https://docs.python.org/3/license.html#psf-license
with asyncio.Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(main)

else:
# modified version of asyncio.run from Python 3.10 to add loop_factory kwarg
# https://docs.python.org/3/license.html#psf-license
def asyncio_run(
main: Coroutine[Any, Any, _T],
*,
debug: bool = False,
loop_factory: Callable[[], asyncio.AbstractEventLoop] | None = None,
) -> _T:
try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError(
"asyncio.run() cannot be called from a running event loop"
)

if not asyncio.iscoroutine(main):
raise ValueError(f"a coroutine was expected, got {main!r}")

if loop_factory is None:
loop = asyncio.new_event_loop()
else:
loop = loop_factory()
try:
if loop_factory is None:
asyncio.set_event_loop(loop)
if debug is not None:
loop.set_debug(debug)
return loop.run_until_complete(main)
finally:
try:
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
if sys.version_info >= (3, 9):
loop.run_until_complete(loop.shutdown_default_executor())
finally:
if loop_factory is None:
asyncio.set_event_loop(None)
loop.close()

def _cancel_all_tasks(loop: asyncio.AbstractEventLoop) -> None:
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
16 changes: 6 additions & 10 deletions src/prompt_toolkit/application/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
)
from prompt_toolkit.utils import Event, in_main_thread

from ._compat import EventLoop, asyncio_run
from .current import get_app_session, set_app
from .run_in_terminal import in_terminal, run_in_terminal

Expand Down Expand Up @@ -971,14 +972,9 @@ def _called_from_ipython() -> bool:
return False

if inputhook is not None:
# Create new event loop with given input hook and run the app.
# In Python 3.12, we can use asyncio.run(loop_factory=...)
# For now, use `run_until_complete()`.
loop = new_eventloop_with_inputhook(inputhook)
result = loop.run_until_complete(coro)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()
return result
return asyncio_run(
coro, loop_factory=lambda: new_eventloop_with_inputhook(inputhook)
)

elif _called_from_ipython():
# workaround to make input hooks work for IPython until
Expand All @@ -992,14 +988,14 @@ def _called_from_ipython() -> bool:
loop = asyncio.get_event_loop()
except RuntimeError:
# No loop installed. Run like usual.
return asyncio.run(coro)
return asyncio_run(coro, loop_factory=EventLoop)
else:
# Use existing loop.
return loop.run_until_complete(coro)

else:
# No loop installed. Run like usual.
return asyncio.run(coro)
return asyncio_run(coro, loop_factory=EventLoop)

def _handle_exception(
self, loop: AbstractEventLoop, context: dict[str, Any]
Expand Down
Loading