Skip to content

Commit 3a10339

Browse files
authored
feat: return async iterator (#44)
1 parent 960d096 commit 3a10339

File tree

8 files changed

+197
-56
lines changed

8 files changed

+197
-56
lines changed

examples/streaming_results.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import logging
2+
import asyncio
3+
4+
from mysql_mimic import MysqlServer, Session
5+
6+
7+
class MySession(Session):
    """Example session whose query results are produced lazily."""

    async def generate_rows(self, n):
        """Asynchronously yield ``n`` one-element rows holding 0..n-1.

        Before every 100th row (starting with the first) a log line is
        emitted and the coroutine sleeps for one second.
        """
        for value in range(n):
            starts_batch = value % 100 == 0
            if starts_batch:
                logging.info("Pretending to fetch another batch of results...")
                await asyncio.sleep(1)
            yield (value,)

    async def query(self, expression, sql, attrs):
        """Return a 1000-row async generator together with the column names."""
        rows = self.generate_rows(1000)
        column_names = ["a"]
        return rows, column_names
17+
18+
19+
async def main():
    """Enable debug logging, then serve ``MySession`` connections forever."""
    logging.basicConfig(level=logging.DEBUG)
    mimic = MysqlServer(session_factory=MySession)
    await mimic.serve_forever()


# Only start the server when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())

mysql_mimic/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
from mysql_mimic.results import AllowedResult, ResultColumn, ResultSet
1010
from mysql_mimic.session import Session
1111
from mysql_mimic.server import MysqlServer
12+
from mysql_mimic.types import ColumnType

mysql_mimic/connection.py

+26-20
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from ssl import SSLContext
3-
from typing import Optional, Dict, Any, Iterator
3+
from typing import Optional, Dict, Any, Iterator, AsyncIterator
44

55
from mysql_mimic.auth import (
66
AuthInfo,
@@ -26,7 +26,7 @@
2626
from mysql_mimic.session import BaseSession
2727
from mysql_mimic.stream import MysqlStream, ConnectionClosed
2828
from mysql_mimic.types import Capabilities
29-
from mysql_mimic.utils import seq
29+
from mysql_mimic.utils import seq, aiterate
3030

3131
logger = logging.getLogger(__name__)
3232

@@ -352,14 +352,16 @@ async def handle_field_list(self, data: bytes) -> None:
352352
sql = com_field_list_to_show_statement(com_field_list)
353353
result = await self.query(sql=sql, query_attrs={})
354354
columns = b"".join(
355-
make_column_definition_41(
356-
server_charset=self.server_charset,
357-
table=com_field_list.table,
358-
name=row[0],
359-
is_com_field_list=True,
360-
default=row[4],
361-
)
362-
for row in result.rows
355+
[
356+
make_column_definition_41(
357+
server_charset=self.server_charset,
358+
table=com_field_list.table,
359+
name=row[0],
360+
is_com_field_list=True,
361+
default=row[4],
362+
)
363+
async for row in aiterate(result.rows)
364+
]
363365
)
364366
await self.stream.write(columns)
365367
await self.stream.write(self.ok_or_eof())
@@ -383,7 +385,7 @@ async def handle_query(self, data: bytes) -> None:
383385
await self.stream.write(self.ok())
384386
return
385387

386-
for packet in self.text_resultset(result_set):
388+
async for packet in self.text_resultset(result_set):
387389
await self.stream.write(packet)
388390

389391
async def handle_stmt_prepare(self, data: bytes) -> None:
@@ -457,10 +459,11 @@ async def handle_stmt_execute(self, data: bytes) -> None:
457459
)
458460
)
459461

460-
rows = (
461-
packets.make_binary_resultrow(r, result_set.columns)
462-
for r in result_set.rows
463-
)
462+
async def gen_rows() -> AsyncIterator[bytes]:
463+
async for r in aiterate(result_set.rows):
464+
yield packets.make_binary_resultrow(r, result_set.columns)
465+
466+
rows = gen_rows()
464467

465468
if com_stmt_execute.use_cursor:
466469
com_stmt_execute.stmt.cursor = rows
@@ -470,7 +473,7 @@ async def handle_stmt_execute(self, data: bytes) -> None:
470473
else:
471474
if not self.deprecate_eof():
472475
await self.stream.write(self.eof())
473-
for row in rows:
476+
async for row in rows:
474477
await self.stream.write(row)
475478
await self.stream.write(self.ok_or_eof())
476479

@@ -485,7 +488,10 @@ async def handle_stmt_fetch(self, data: bytes) -> None:
485488
stmt = self.get_stmt(com_stmt_fetch.stmt_id)
486489
assert stmt.cursor is not None
487490
count = 0
488-
for _, packet in zip(range(com_stmt_fetch.num_rows), stmt.cursor):
491+
492+
async for packet in stmt.cursor:
493+
if count >= com_stmt_fetch.num_rows:
494+
break
489495
await self.stream.write(packet)
490496
count += 1
491497

@@ -530,7 +536,7 @@ def get_stmt(self, stmt_id: int) -> PreparedStatement:
530536
async def query(self, sql: str, query_attrs: Dict[str, str]) -> ResultSet:
531537
logger.debug("Received query: %s", sql)
532538

533-
result_set = ensure_result_set(
539+
result_set = await ensure_result_set(
534540
await self.session.handle_query(sql, query_attrs)
535541
)
536542
return result_set
@@ -574,7 +580,7 @@ def error(self, **kwargs: Any) -> bytes:
574580
def deprecate_eof(self) -> bool:
575581
return Capabilities.CLIENT_DEPRECATE_EOF in self.capabilities
576582

577-
def text_resultset(self, result_set: ResultSet) -> Iterator[bytes]:
583+
async def text_resultset(self, result_set: ResultSet) -> AsyncIterator[bytes]:
578584
yield packets.make_column_count(
579585
capabilities=self.capabilities, column_count=len(result_set.columns)
580586
)
@@ -592,7 +598,7 @@ def text_resultset(self, result_set: ResultSet) -> Iterator[bytes]:
592598

593599
affected_rows = 0
594600

595-
for row in result_set.rows:
601+
async for row in aiterate(result_set.rows):
596602
affected_rows += 1
597603
yield packets.make_text_resultset_row(row, result_set.columns)
598604

mysql_mimic/prepared.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import re
22
from dataclasses import dataclass
3-
from typing import Optional, Dict, Iterable
4-
3+
from typing import Optional, Dict, AsyncIterable
54

65
# Borrowed from mysql-connector-python
76
REGEX_PARAM = re.compile(r"""\?(?=(?:[^"'`]*["'`][^"'`]*["'`])*[^"'`]*$)""")
@@ -13,4 +12,4 @@ class PreparedStatement:
1312
sql: str
1413
num_params: int
1514
param_buffers: Optional[Dict[int, bytearray]] = None
16-
cursor: Optional[Iterable] = None
15+
cursor: Optional[AsyncIterable[bytes]] = None

mysql_mimic/results.py

+78-29
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,23 @@
44
import struct
55
from dataclasses import dataclass
66
from datetime import datetime, date, timedelta
7-
from typing import Iterable, Sequence, Optional, Callable, Any, Union, Tuple, Dict
7+
from typing import (
8+
Iterable,
9+
Sequence,
10+
Optional,
11+
Callable,
12+
Any,
13+
Union,
14+
Tuple,
15+
Dict,
16+
AsyncIterable,
17+
cast,
18+
)
819

920
from mysql_mimic.errors import MysqlError
1021
from mysql_mimic.types import ColumnType, str_len, uint_1, uint_2, uint_4
1122
from mysql_mimic.charset import CharacterSet
12-
23+
from mysql_mimic.utils import aiterate
1324

1425
Encoder = Callable[[Any, "ResultColumn"], bytes]
1526

@@ -54,7 +65,7 @@ def __repr__(self) -> str:
5465

5566
@dataclass
5667
class ResultSet:
57-
rows: Iterable[Sequence]
68+
rows: Iterable[Sequence] | AsyncIterable[Sequence]
5869
columns: Sequence[ResultColumn]
5970

6071
def __bool__(self) -> bool:
@@ -63,11 +74,14 @@ def __bool__(self) -> bool:
6374

6475
AllowedColumn = Union[ResultColumn, str]
6576
AllowedResult = Union[
66-
ResultSet, Tuple[Sequence[Sequence[Any]], Sequence[AllowedColumn]], None
77+
ResultSet,
78+
Tuple[Sequence[Sequence[Any]], Sequence[AllowedColumn]],
79+
Tuple[AsyncIterable[Sequence[Any]], Sequence[AllowedColumn]],
80+
None,
6781
]
6882

6983

70-
def ensure_result_set(result: AllowedResult) -> ResultSet:
84+
async def ensure_result_set(result: AllowedResult) -> ResultSet:
7185
if result is None:
7286
return ResultSet([], [])
7387
if isinstance(result, ResultSet):
@@ -80,39 +94,74 @@ def ensure_result_set(result: AllowedResult) -> ResultSet:
8094
rows = result[0]
8195
columns = result[1]
8296

83-
return ResultSet(
84-
rows=rows,
85-
columns=[_ensure_result_col(col, i, rows) for i, col in enumerate(columns)],
86-
)
97+
return await _ensure_result_cols(rows, columns)
8798

8899
raise MysqlError(f"Unexpected result set type: {type(result)}")
89100

90101

91-
def _ensure_result_col(
92-
column: AllowedColumn, idx: int, rows: Sequence[Sequence[Any]]
93-
) -> ResultColumn:
94-
if isinstance(column, ResultColumn):
95-
return column
102+
async def _ensure_result_cols(
103+
rows: Sequence[Sequence[Any]] | AsyncIterable[Sequence[Any]],
104+
columns: Sequence[AllowedColumn],
105+
) -> ResultSet:
106+
# Which columns need to be inferred?
107+
remaining = {
108+
col: i for i, col in enumerate(columns) if not isinstance(col, ResultColumn)
109+
}
110+
111+
if not remaining:
112+
return ResultSet(
113+
rows=rows,
114+
columns=cast(Sequence[ResultColumn], columns),
115+
)
96116

97-
if isinstance(column, str):
98-
value = _find_first_non_null_value(idx, rows)
99-
type_ = infer_type(value)
100-
return ResultColumn(
101-
name=column,
102-
type=type_,
117+
# Copy the columns
118+
columns = list(columns)
119+
120+
arows = aiterate(rows)
121+
122+
# Keep track of rows we've consumed from the iterator so we can add them back
123+
peeks = []
124+
125+
# Find the first non-null value for each column
126+
while remaining:
127+
try:
128+
peek = await arows.__anext__()
129+
except StopAsyncIteration:
130+
break
131+
132+
peeks.append(peek)
133+
134+
inferred = []
135+
for name, i in remaining.items():
136+
value = peek[i]
137+
if value is not None:
138+
type_ = infer_type(value)
139+
columns[i] = ResultColumn(
140+
name=str(name),
141+
type=type_,
142+
)
143+
inferred.append(name)
144+
145+
for name in inferred:
146+
remaining.pop(name)
147+
148+
# If we failed to find a non-null value, set the type to NULL
149+
for name, i in remaining.items():
150+
columns[i] = ResultColumn(
151+
name=str(name),
152+
type=ColumnType.NULL,
103153
)
104154

105-
raise MysqlError(f"Unexpected result column value: {column}")
155+
# Add the consumed rows back in to the iterator
156+
async def gen_rows() -> AsyncIterable[Sequence[Any]]:
157+
for row in peeks:
158+
yield row
106159

160+
async for row in arows:
161+
yield row
107162

108-
def _find_first_non_null_value(
109-
idx: int, rows: Sequence[Sequence[Any]]
110-
) -> Optional[Any]:
111-
for row in rows:
112-
value = row[idx]
113-
if value is not None:
114-
return value
115-
return None
163+
assert all(isinstance(col, ResultColumn) for col in columns)
164+
return ResultSet(rows=gen_rows(), columns=cast(Sequence[ResultColumn], columns))
116165

117166

118167
def _binary_encode_tiny(col: ResultColumn, val: Any) -> bytes:

mysql_mimic/utils.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from __future__ import annotations
2+
3+
import inspect
24
import sys
35
from collections.abc import Iterator
46
import random
5-
from typing import List
7+
from typing import List, TypeVar, AsyncIterable, Iterable, AsyncIterator, cast
68
import string
79

810
from sqlglot import expressions as exp
911
from sqlglot.optimizer.scope import traverse_scope
1012

1113

14+
T = TypeVar("T")
15+
1216
# MySQL Connector/J uses ASCII to decode nonce
1317
SAFE_NONCE_CHARS = (string.ascii_letters + string.digits).encode()
1418

@@ -92,3 +96,13 @@ def dict_depth(d: dict) -> int:
9296
except StopIteration:
9397
# d.values() returns an empty sequence
9498
return 1
99+
100+
101+
async def aiterate(iterable: AsyncIterable[T] | Iterable[T]) -> AsyncIterator[T]:
    """Iterate either an async iterable or a regular iterable.

    Args:
        iterable: An object usable with ``async for`` (an async generator or
            any other object defining ``__aiter__``), or one usable with a
            plain ``for`` loop.

    Yields:
        The items of ``iterable``, in order.
    """
    # Check the protocol instead of ``inspect.isasyncgen``: that helper only
    # recognises async *generator* objects, so a class-based async iterator
    # would fall through to the synchronous loop and raise ``TypeError``.
    if isinstance(iterable, AsyncIterable):
        async for item in iterable:
            yield item
    else:
        for item in cast(Iterable, iterable):
            yield item

tests/test_query.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import io
23
from contextlib import closing
34
from datetime import date, datetime, timedelta
@@ -835,3 +836,26 @@ async def test_unsupported_commands(
835836
with pytest.raises(Exception) as ctx:
836837
await query_fixture(sql)
837838
assert msg in str(ctx.value)
839+
840+
841+
@pytest.mark.asyncio
async def test_async_iterator(
    session: MockSession,
    server: MysqlServer,
    query_fixture: QueryFixture,
) -> None:
    """A result backed by an async generator returns every row it yields."""

    async def row_source() -> Any:
        # Yield control between rows so the rows are produced across
        # separate event-loop steps.
        yield 1, None, None
        await asyncio.sleep(0)
        yield None, "2", None
        await asyncio.sleep(0)
        yield None, None, None

    column_names = ["a", "b", "c"]
    session.return_value = (row_source(), column_names)

    fetched = await query_fixture("SELECT * FROM x")

    expected = [
        {"a": 1, "b": None, "c": None},
        {"a": None, "b": "2", "c": None},
        {"a": None, "b": None, "c": None},
    ]
    assert expected == fetched

0 commit comments

Comments
 (0)