1
1
import logging
2
2
from ssl import SSLContext
3
- from typing import Optional , Dict , Any , Iterator
3
+ from typing import Optional , Dict , Any , Iterator , AsyncIterator
4
4
5
5
from mysql_mimic .auth import (
6
6
AuthInfo ,
26
26
from mysql_mimic .session import BaseSession
27
27
from mysql_mimic .stream import MysqlStream , ConnectionClosed
28
28
from mysql_mimic .types import Capabilities
29
- from mysql_mimic .utils import seq
29
+ from mysql_mimic .utils import seq , aiterate
30
30
31
31
logger = logging .getLogger (__name__ )
32
32
@@ -352,14 +352,16 @@ async def handle_field_list(self, data: bytes) -> None:
352
352
sql = com_field_list_to_show_statement (com_field_list )
353
353
result = await self .query (sql = sql , query_attrs = {})
354
354
columns = b"" .join (
355
- make_column_definition_41 (
356
- server_charset = self .server_charset ,
357
- table = com_field_list .table ,
358
- name = row [0 ],
359
- is_com_field_list = True ,
360
- default = row [4 ],
361
- )
362
- for row in result .rows
355
+ [
356
+ make_column_definition_41 (
357
+ server_charset = self .server_charset ,
358
+ table = com_field_list .table ,
359
+ name = row [0 ],
360
+ is_com_field_list = True ,
361
+ default = row [4 ],
362
+ )
363
+ async for row in aiterate (result .rows )
364
+ ]
363
365
)
364
366
await self .stream .write (columns )
365
367
await self .stream .write (self .ok_or_eof ())
@@ -383,7 +385,7 @@ async def handle_query(self, data: bytes) -> None:
383
385
await self .stream .write (self .ok ())
384
386
return
385
387
386
- for packet in self .text_resultset (result_set ):
388
+ async for packet in self .text_resultset (result_set ):
387
389
await self .stream .write (packet )
388
390
389
391
async def handle_stmt_prepare (self , data : bytes ) -> None :
@@ -457,10 +459,11 @@ async def handle_stmt_execute(self, data: bytes) -> None:
457
459
)
458
460
)
459
461
460
- rows = (
461
- packets .make_binary_resultrow (r , result_set .columns )
462
- for r in result_set .rows
463
- )
462
+ async def gen_rows () -> AsyncIterator [bytes ]:
463
+ async for r in aiterate (result_set .rows ):
464
+ yield packets .make_binary_resultrow (r , result_set .columns )
465
+
466
+ rows = gen_rows ()
464
467
465
468
if com_stmt_execute .use_cursor :
466
469
com_stmt_execute .stmt .cursor = rows
@@ -470,7 +473,7 @@ async def handle_stmt_execute(self, data: bytes) -> None:
470
473
else :
471
474
if not self .deprecate_eof ():
472
475
await self .stream .write (self .eof ())
473
- for row in rows :
476
+ async for row in rows :
474
477
await self .stream .write (row )
475
478
await self .stream .write (self .ok_or_eof ())
476
479
@@ -485,7 +488,10 @@ async def handle_stmt_fetch(self, data: bytes) -> None:
485
488
stmt = self .get_stmt (com_stmt_fetch .stmt_id )
486
489
assert stmt .cursor is not None
487
490
count = 0
488
- for _ , packet in zip (range (com_stmt_fetch .num_rows ), stmt .cursor ):
491
+
492
+ async for packet in stmt .cursor :
493
+ if count >= com_stmt_fetch .num_rows :
494
+ break
489
495
await self .stream .write (packet )
490
496
count += 1
491
497
@@ -530,7 +536,7 @@ def get_stmt(self, stmt_id: int) -> PreparedStatement:
530
536
async def query (self , sql : str , query_attrs : Dict [str , str ]) -> ResultSet :
531
537
logger .debug ("Received query: %s" , sql )
532
538
533
- result_set = ensure_result_set (
539
+ result_set = await ensure_result_set (
534
540
await self .session .handle_query (sql , query_attrs )
535
541
)
536
542
return result_set
@@ -574,7 +580,7 @@ def error(self, **kwargs: Any) -> bytes:
574
580
def deprecate_eof (self ) -> bool :
575
581
return Capabilities .CLIENT_DEPRECATE_EOF in self .capabilities
576
582
577
- def text_resultset (self , result_set : ResultSet ) -> Iterator [bytes ]:
583
+ async def text_resultset (self , result_set : ResultSet ) -> AsyncIterator [bytes ]:
578
584
yield packets .make_column_count (
579
585
capabilities = self .capabilities , column_count = len (result_set .columns )
580
586
)
@@ -592,7 +598,7 @@ def text_resultset(self, result_set: ResultSet) -> Iterator[bytes]:
592
598
593
599
affected_rows = 0
594
600
595
- for row in result_set .rows :
601
+ async for row in aiterate ( result_set .rows ) :
596
602
affected_rows += 1
597
603
yield packets .make_text_resultset_row (row , result_set .columns )
598
604
0 commit comments