From 39c0aa74c5ccd4e7b1d54d74cadfd14ef268e722 Mon Sep 17 00:00:00 2001 From: Chris Ginter Date: Tue, 21 Nov 2023 08:32:28 -0800 Subject: [PATCH] feat: write results in a cooperative fashion (#50) --- mysql_mimic/connection.py | 8 ++++---- mysql_mimic/utils.py | 15 +++++++++++++++ mysql_mimic/version.py | 2 +- tests/test_auth.py | 2 +- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/mysql_mimic/connection.py b/mysql_mimic/connection.py index b46d9c8..db0e02d 100644 --- a/mysql_mimic/connection.py +++ b/mysql_mimic/connection.py @@ -28,7 +28,7 @@ from mysql_mimic.session import BaseSession from mysql_mimic.stream import MysqlStream, ConnectionClosed from mysql_mimic.types import Capabilities -from mysql_mimic.utils import seq, aiterate +from mysql_mimic.utils import seq, aiterate, cooperative_iterate logger = logging.getLogger(__name__) @@ -502,7 +502,7 @@ async def handle_stmt_execute(self, data: bytes) -> None: ) async def gen_rows() -> AsyncIterator[bytes]: - async for r in aiterate(result_set.rows): + async for r in cooperative_iterate(aiterate(result_set.rows)): yield packets.make_binary_resultrow(r, result_set.columns) rows = gen_rows() @@ -531,7 +531,7 @@ async def handle_stmt_fetch(self, data: bytes) -> None: assert stmt.cursor is not None count = 0 - async for packet in stmt.cursor: + async for packet in cooperative_iterate(stmt.cursor): if count >= com_stmt_fetch.num_rows: break await self.stream.write(packet) @@ -640,7 +640,7 @@ async def text_resultset(self, result_set: ResultSet) -> AsyncIterator[bytes]: affected_rows = 0 - async for row in aiterate(result_set.rows): + async for row in cooperative_iterate(aiterate(result_set.rows)): affected_rows += 1 yield packets.make_text_resultset_row(row, result_set.columns) diff --git a/mysql_mimic/utils.py b/mysql_mimic/utils.py index ccebadc..821a03c 100644 --- a/mysql_mimic/utils.py +++ b/mysql_mimic/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import inspect import sys from collections.abc import Iterator @@ -106,3 +107,17 @@ async def aiterate(iterable: AsyncIterable[T] | Iterable[T]) -> AsyncIterator[T] else: for item in cast(Iterable, iterable): yield item + + +async def cooperative_iterate( + iterable: AsyncIterable[T], batch_size: int = 10_000 +) -> AsyncIterator[T]: + """ + Iterate an async iterable in a cooperative manner, yielding control back to the event loop every `batch_size` iterations + """ + i = 0 + async for item in iterable: + if i != 0 and i % batch_size == 0: + await asyncio.sleep(0) + yield item + i += 1 diff --git a/mysql_mimic/version.py b/mysql_mimic/version.py index a5fe2d3..f928bda 100644 --- a/mysql_mimic/version.py +++ b/mysql_mimic/version.py @@ -1,6 +1,6 @@ """mysql-mimic version information""" -__version__ = "2.5.1" +__version__ = "2.5.2" def main(name: str) -> None: diff --git a/tests/test_auth.py b/tests/test_auth.py index 327be83..3c10725 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -18,7 +18,7 @@ # mysql.connector throws an error if you try to use mysql_clear_password without SSL. # That's silly, since SSL termination doesn't have to be handled by MySQL. # But it's extra silly in tests. -MySQLClearPasswordAuthPlugin.requires_ssl = False +MySQLClearPasswordAuthPlugin.requires_ssl = False # type: ignore MySQLConnectionAbstract.is_secure = True # type: ignore SIMPLE_AUTH_USER = "levon_helm"