Skip to content

Commit

Permalink
feat: write results in a cooperative fashion (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
ginter authored Nov 21, 2023
1 parent cea2b90 commit 39c0aa7
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
8 changes: 4 additions & 4 deletions mysql_mimic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from mysql_mimic.session import BaseSession
from mysql_mimic.stream import MysqlStream, ConnectionClosed
from mysql_mimic.types import Capabilities
from mysql_mimic.utils import seq, aiterate
from mysql_mimic.utils import seq, aiterate, cooperative_iterate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -502,7 +502,7 @@ async def handle_stmt_execute(self, data: bytes) -> None:
)

async def gen_rows() -> AsyncIterator[bytes]:
async for r in aiterate(result_set.rows):
async for r in cooperative_iterate(aiterate(result_set.rows)):
yield packets.make_binary_resultrow(r, result_set.columns)

rows = gen_rows()
Expand Down Expand Up @@ -531,7 +531,7 @@ async def handle_stmt_fetch(self, data: bytes) -> None:
assert stmt.cursor is not None
count = 0

async for packet in stmt.cursor:
async for packet in cooperative_iterate(stmt.cursor):
if count >= com_stmt_fetch.num_rows:
break
await self.stream.write(packet)
Expand Down Expand Up @@ -640,7 +640,7 @@ async def text_resultset(self, result_set: ResultSet) -> AsyncIterator[bytes]:

affected_rows = 0

async for row in aiterate(result_set.rows):
async for row in cooperative_iterate(aiterate(result_set.rows)):
affected_rows += 1
yield packets.make_text_resultset_row(row, result_set.columns)

Expand Down
15 changes: 15 additions & 0 deletions mysql_mimic/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import inspect
import sys
from collections.abc import Iterator
Expand Down Expand Up @@ -106,3 +107,17 @@ async def aiterate(iterable: AsyncIterable[T] | Iterable[T]) -> AsyncIterator[T]
else:
for item in cast(Iterable, iterable):
yield item


async def cooperative_iterate(
iterable: AsyncIterable[T], batch_size: int = 10_000
) -> AsyncIterator[T]:
"""
Iterate an async iterable in a cooperative manner, yielding control back to the event loop every `batch_size` iterations
"""
i = 0
async for item in iterable:
if i != 0 and i % batch_size == 0:
await asyncio.sleep(0)
yield item
i += 1
2 changes: 1 addition & 1 deletion mysql_mimic/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""mysql-mimic version information"""

__version__ = "2.5.1"
__version__ = "2.5.2"


def main(name: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# mysql.connector throws an error if you try to use mysql_clear_password without SSL.
# That's silly, since SSL termination doesn't have to be handled by MySQL.
# But it's extra silly in tests.
MySQLClearPasswordAuthPlugin.requires_ssl = False
MySQLClearPasswordAuthPlugin.requires_ssl = False # type: ignore
MySQLConnectionAbstract.is_secure = True # type: ignore

SIMPLE_AUTH_USER = "levon_helm"
Expand Down

0 comments on commit 39c0aa7

Please sign in to comment.