diff --git a/peewee_async/aio_model.py b/peewee_async/aio_model.py index 438f635..4ec7816 100644 --- a/peewee_async/aio_model.py +++ b/peewee_async/aio_model.py @@ -1,6 +1,6 @@ import peewee -from .result_wrappers import AsyncQueryWrapper +from .result_wrappers import fetch_models from .utils import CursorProtocol @@ -45,14 +45,14 @@ class AioQueryMixin: async def aio_execute(self, database): return await database.aio_execute(self) - async def make_async_query_wrapper(self, cursor: CursorProtocol): - return await AsyncQueryWrapper.make_for_all_rows(cursor, self) + async def fetch_results(self, cursor: CursorProtocol): + return await fetch_models(cursor, self) class AioModelDelete(peewee.ModelDelete, AioQueryMixin): async def fetch_results(self, cursor: CursorProtocol): if self._returning: - return await self.make_async_query_wrapper(cursor) + return await fetch_models(cursor, self) return cursor.rowcount @@ -60,14 +60,14 @@ class AioModelUpdate(peewee.ModelUpdate, AioQueryMixin): async def fetch_results(self, cursor: CursorProtocol): if self._returning: - return await self.make_async_query_wrapper(cursor) + return await fetch_models(cursor, self) return cursor.rowcount class AioModelInsert(peewee.ModelInsert, AioQueryMixin): async def fetch_results(self, cursor: CursorProtocol): if self._returning is not None and len(self._returning) > 1: - return await self.make_async_query_wrapper(cursor) + return await fetch_models(cursor, self) if self._returning: row = await cursor.fetchone() @@ -77,15 +77,11 @@ async def fetch_results(self, cursor: CursorProtocol): class AioModelRaw(peewee.ModelRaw, AioQueryMixin): - async def fetch_results(self, cursor: CursorProtocol): - return await self.make_async_query_wrapper(cursor) + pass class AioSelectMixin(AioQueryMixin): - async def fetch_results(self, cursor: CursorProtocol): - return await self.make_async_query_wrapper(cursor) - @peewee.database_required async def aio_scalar(self, database, as_tuple=False): """ diff --git a/peewee_async/result_wrappers.py b/peewee_async/result_wrappers.py index 8a0e3cf..f5e6819 100644 --- a/peewee_async/result_wrappers.py +++ b/peewee_async/result_wrappers.py @@ -1,12 +1,12 @@ -from typing import Any, List, Iterator +from typing import Any, List from typing import Optional, Sequence -from peewee import CursorWrapper, BaseQuery +from peewee import BaseQuery from .utils import CursorProtocol -class RowsCursor(object): +class SyncCursorAdapter(object): def __init__(self, rows: List[Any], description: Optional[Sequence[Any]]) -> None: self._rows = rows self.description = description @@ -23,60 +23,8 @@ def close(self) -> None: pass -class AsyncQueryWrapper: - """Async query results wrapper for async `select()`. Internally uses - results wrapper produced by sync peewee select query. - - Arguments: - - result_wrapper -- empty results wrapper produced by sync `execute()` - call cursor -- async cursor just executed query - - To retrieve results after async fetching just iterate over this class - instance, like you generally iterate over sync results wrapper. - """ - def __init__(self, *, cursor: CursorProtocol, query: BaseQuery) -> None: - self._cursor = cursor - self._rows: List[Any] = [] - self._result_cache: Optional[List[Any]] = None - self._result_wrapper = self._get_result_wrapper(query) - - def __iter__(self) -> Iterator[Any]: - return iter(self._result_wrapper) - - def __len__(self) -> int: - return len(self._rows) - - def __getitem__(self, idx: int) -> Any: - # NOTE: side effects will appear when both - # iterating and accessing by index! - if self._result_cache is None: - self._result_cache = list(self) - return self._result_cache[idx] - - def _get_result_wrapper(self, query: BaseQuery) -> CursorWrapper: - """Get result wrapper class. - """ - cursor = RowsCursor(self._rows, self._cursor.description) - return query._get_cursor_wrapper(cursor) - - async def fetchone(self) -> None: - """Fetch single row from the cursor. - """ - row = await self._cursor.fetchone() - if not row: - raise GeneratorExit - self._rows.append(row) - - async def fetchall(self) -> None: - try: - while True: - await self.fetchone() - except GeneratorExit: - pass - - @classmethod - async def make_for_all_rows(cls, cursor: CursorProtocol, query: BaseQuery) -> 'AsyncQueryWrapper': - result = AsyncQueryWrapper(cursor=cursor, query=query) - await result.fetchall() - return result +async def fetch_models(cursor: CursorProtocol, query: BaseQuery): + rows = await cursor.fetchall() + sync_cursor = SyncCursorAdapter(rows, cursor.description) + _result_wrapper = query._get_cursor_wrapper(sync_cursor) + return list(_result_wrapper)