Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: AsyncQueryWrapper removed #283

Merged
merged 2 commits into from
Aug 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions peewee_async/aio_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import peewee

from .result_wrappers import AsyncQueryWrapper
from .result_wrappers import fetch_models
from .utils import CursorProtocol


Expand Down Expand Up @@ -45,29 +45,29 @@ class AioQueryMixin:
async def aio_execute(self, database):
return await database.aio_execute(self)

async def make_async_query_wrapper(self, cursor: CursorProtocol):
return await AsyncQueryWrapper.make_for_all_rows(cursor, self)
async def fetch_results(self, cursor: CursorProtocol):
return await fetch_models(cursor, self)


class AioModelDelete(peewee.ModelDelete, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return await fetch_models(cursor, self)
return cursor.rowcount


class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin):

async def fetch_results(self, cursor: CursorProtocol):
if self._returning:
return await self.make_async_query_wrapper(cursor)
return await fetch_models(cursor, self)
return cursor.rowcount


class AioModelInsert(peewee.ModelInsert, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
if self._returning is not None and len(self._returning) > 1:
return await self.make_async_query_wrapper(cursor)
return await fetch_models(cursor, self)

if self._returning:
row = await cursor.fetchone()
Expand All @@ -77,15 +77,11 @@ async def fetch_results(self, cursor: CursorProtocol):


class AioModelRaw(peewee.ModelRaw, AioQueryMixin):
async def fetch_results(self, cursor: CursorProtocol):
return await self.make_async_query_wrapper(cursor)
pass


class AioSelectMixin(AioQueryMixin):

async def fetch_results(self, cursor: CursorProtocol):
return await self.make_async_query_wrapper(cursor)

@peewee.database_required
async def aio_scalar(self, database, as_tuple=False):
"""
Expand Down
68 changes: 8 additions & 60 deletions peewee_async/result_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any, List, Iterator
from typing import Any, List
from typing import Optional, Sequence

from peewee import CursorWrapper, BaseQuery
from peewee import BaseQuery

from .utils import CursorProtocol


class RowsCursor(object):
class SyncCursorAdapter(object):
def __init__(self, rows: List[Any], description: Optional[Sequence[Any]]) -> None:
self._rows = rows
self.description = description
Expand All @@ -23,60 +23,8 @@ def close(self) -> None:
pass


class AsyncQueryWrapper:
"""Async query results wrapper for async `select()`. Internally uses
results wrapper produced by sync peewee select query.

Arguments:

result_wrapper -- empty results wrapper produced by sync `execute()`
call cursor -- async cursor just executed query

To retrieve results after async fetching just iterate over this class
instance, like you generally iterate over sync results wrapper.
"""
def __init__(self, *, cursor: CursorProtocol, query: BaseQuery) -> None:
self._cursor = cursor
self._rows: List[Any] = []
self._result_cache: Optional[List[Any]] = None
self._result_wrapper = self._get_result_wrapper(query)

def __iter__(self) -> Iterator[Any]:
return iter(self._result_wrapper)

def __len__(self) -> int:
return len(self._rows)

def __getitem__(self, idx: int) -> Any:
# NOTE: side effects will appear when both
# iterating and accessing by index!
if self._result_cache is None:
self._result_cache = list(self)
return self._result_cache[idx]

def _get_result_wrapper(self, query: BaseQuery) -> CursorWrapper:
"""Get result wrapper class.
"""
cursor = RowsCursor(self._rows, self._cursor.description)
return query._get_cursor_wrapper(cursor)

async def fetchone(self) -> None:
"""Fetch single row from the cursor.
"""
row = await self._cursor.fetchone()
if not row:
raise GeneratorExit
self._rows.append(row)

async def fetchall(self) -> None:
try:
while True:
await self.fetchone()
except GeneratorExit:
pass

@classmethod
async def make_for_all_rows(cls, cursor: CursorProtocol, query: BaseQuery) -> 'AsyncQueryWrapper':
result = AsyncQueryWrapper(cursor=cursor, query=query)
await result.fetchall()
return result
async def fetch_models(cursor: CursorProtocol, query: BaseQuery):
rows = await cursor.fetchall()
sync_cursor = SyncCursorAdapter(rows, cursor.description)
_result_wrapper = query._get_cursor_wrapper(sync_cursor)
return list(_result_wrapper)
Loading