Skip to content

Commit b6fc23f

Browse files
committed
Accept AsyncIterables being passed to Response
Fixes pallets/flask#5322
1 parent 2fc6d4f commit b6fc23f

File tree

3 files changed

+28
-21
lines changed

3 files changed

+28
-21
lines changed

src/quart/utils.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
from pathlib import Path
1111
from typing import (
1212
Any,
13-
AsyncGenerator,
13+
AsyncIterator,
1414
Awaitable,
1515
Callable,
1616
Coroutine,
17-
Generator,
1817
Iterable,
18+
Iterator,
1919
TYPE_CHECKING,
20+
TypeVar,
2021
)
2122

2223
from werkzeug.datastructures import Headers
@@ -66,12 +67,15 @@ async def _wrapper(*args: Any, **kwargs: Any) -> Any:
6667
return _wrapper
6768

6869

69-
def run_sync_iterable(iterable: Generator[Any, None, None]) -> AsyncGenerator[Any, None]:
70-
async def _gen_wrapper() -> AsyncGenerator[Any, None]:
70+
T = TypeVar("T")
71+
72+
73+
def run_sync_iterable(iterable: Iterator[T]) -> AsyncIterator[T]:
74+
async def _gen_wrapper() -> AsyncIterator[T]:
7175
# Wrap the generator such that each iteration runs
7276
# in the executor. Then rationalise the raised
7377
# errors so that it ends.
74-
def _inner() -> Any:
78+
def _inner() -> T:
7579
# https://bugs.python.org/issue26221
7680
# StopIteration errors are swallowed by the
7781
# run_in_exector method

src/quart/wrappers/response.py

+10-16
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4+
from builtins import aiter
45
from hashlib import md5
5-
from inspect import isasyncgen, isgenerator
66
from io import BytesIO
77
from os import PathLike
88
from types import TracebackType
@@ -102,27 +102,21 @@ async def __anext__(self) -> bytes:
102102

103103

104104
class IterableBody(ResponseBody):
105-
def __init__(self, iterable: AsyncGenerator[bytes, None] | Iterable) -> None:
106-
self.iter: AsyncGenerator[bytes, None]
107-
if isasyncgen(iterable):
108-
self.iter = iterable
109-
elif isgenerator(iterable):
110-
self.iter = run_sync_iterable(iterable)
105+
def __init__(self, iterable: AsyncIterable[Any] | Iterable[Any]) -> None:
106+
self.iter: AsyncIterator[Any]
107+
if isinstance(iterable, Iterable):
108+
self.iter = run_sync_iterable(iter(iterable))
111109
else:
112-
113-
async def _aiter() -> AsyncGenerator[bytes, None]:
114-
for data in iterable: # type: ignore
115-
yield data
116-
117-
self.iter = _aiter()
110+
self.iter = aiter(iterable)
118111

119112
async def __aenter__(self) -> IterableBody:
120113
return self
121114

122115
async def __aexit__(self, exc_type: type, exc_value: BaseException, tb: TracebackType) -> None:
123-
await self.iter.aclose()
116+
if hasattr(self.iter, "aclose"): # Is a generator?
117+
await self.iter.aclose()
124118

125-
def __aiter__(self) -> AsyncIterator:
119+
def __aiter__(self) -> AsyncIterator[Any]:
126120
return self.iter
127121

128122

@@ -262,7 +256,7 @@ class Response(SansIOResponse):
262256

263257
def __init__(
264258
self,
265-
response: ResponseBody | AnyStr | Iterable | None = None,
259+
response: ResponseBody | AnyStr | Iterable | AsyncIterable | None = None,
266260
status: int | None = None,
267261
headers: dict | Headers | None = None,
268262
mimetype: str | None = None,

tests/test_templating.py

+9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
g,
1010
Quart,
1111
render_template_string,
12+
Response,
1213
ResponseReturnValue,
1314
session,
1415
stream_template_string,
@@ -148,3 +149,11 @@ async def index() -> ResponseReturnValue:
148149
test_client = app.test_client()
149150
response = await test_client.get("/")
150151
assert (await response.data) == b"42"
152+
153+
@app.get("/2")
154+
async def index2() -> ResponseReturnValue:
155+
return Response(await stream_template_string("{{ config }}", config=43))
156+
157+
test_client = app.test_client()
158+
response = await test_client.get("/2")
159+
assert (await response.data) == b"43"

0 commit comments

Comments
 (0)