diff --git a/docs/usage/databases/sqlalchemy/models_and_repository.rst b/docs/usage/databases/sqlalchemy/models_and_repository.rst index c2edb6f63a..86b28dcada 100644 --- a/docs/usage/databases/sqlalchemy/models_and_repository.rst +++ b/docs/usage/databases/sqlalchemy/models_and_repository.rst @@ -13,7 +13,7 @@ Features a `sentinel column `_, and an optional version with audit columns. * Generic synchronous and asynchronous repositories for select, insert, update, and delete operations on SQLAlchemy models -* Implements optimized methods for bulk inserts, updates, and deletes. +* Implements optimized methods for bulk inserts, updates, and deletes and uses `lambda_stmt `_ when possible. * Integrated counts, pagination, sorting, filtering with ``LIKE``, ``IN``, and dates before and/or after. * Tested support for multiple database backends including: @@ -37,6 +37,9 @@ implementations: Both include a ``UUID`` based primary key and ``UUIDAuditBase`` includes an ``updated`` and ``created`` timestamp column. +The ``UUID`` will be a native ``UUID``/``GUID`` type on databases that support it such as Postgres. For other engines without +a native UUID data type, the UUID is stored as a 16-byte ``BYTES`` or ``RAW`` field. + * :class:`BigIntBase ` * :class:`BigIntAuditBase ` @@ -47,7 +50,8 @@ Models using these bases also include the following enhancements: * Auto-generated snake-case table name from class name * Pydantic BaseModel and Dict classes map to an optimized JSON type that is - :class:`JSONB ` for the Postgres and + :class:`JSONB ` for Postgres, + `VARCHAR` or `BYTES` with JSON check constraint for Oracle, and :class:`JSON ` for other dialects. .. literalinclude:: /examples/contrib/sqlalchemy/sqlalchemy_declarative_models.py diff --git a/litestar/contrib/sqlalchemy/base.py b/litestar/contrib/sqlalchemy/base.py index 369590c376..0d287a6307 100644 --- a/litestar/contrib/sqlalchemy/base.py +++ b/litestar/contrib/sqlalchemy/base.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from sqlalchemy.sql import FromClause + from sqlalchemy.sql.schema import _NamingSchemaParameter as NamingSchemaParameter __all__ = ( "AuditColumns", @@ -42,7 +43,7 @@ UUIDBaseT = TypeVar("UUIDBaseT", bound="UUIDBase") BigIntBaseT = TypeVar("BigIntBaseT", bound="BigIntBase") -convention = { +convention: NamingSchemaParameter = { "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", @@ -151,7 +152,7 @@ def to_dict(self, exclude: set[str] | None = None) -> dict[str, Any]: def create_registry() -> registry: """Create a new SQLAlchemy registry.""" - meta = MetaData(naming_convention=convention) # type: ignore[arg-type] + meta = MetaData(naming_convention=convention) return registry( metadata=meta, type_annotation_map={ diff --git a/litestar/contrib/sqlalchemy/repository/_async.py b/litestar/contrib/sqlalchemy/repository/_async.py index 7b5a64cd83..90fc69efa2 100644 --- a/litestar/contrib/sqlalchemy/repository/_async.py +++ b/litestar/contrib/sqlalchemy/repository/_async.py @@ -2,7 +2,18 @@ from typing import TYPE_CHECKING, Any, Final, Generic, Iterable, Literal, cast -from sqlalchemy import Result, Select, TextClause, delete, over, select, text, update +from sqlalchemy import ( + Result, + Select, + StatementLambdaElement, + TextClause, + delete, + lambda_stmt, + over, + select, + text, + update, +) from sqlalchemy import func as sql_func from sqlalchemy.orm import InstrumentedAttribute @@ -20,7 +31,7 @@ ) from ._util import get_instrumented_attr, wrap_sqlalchemy_exception -from .types import ModelT, RowT, SelectT +from .types import ModelT if TYPE_CHECKING: from collections import abc @@ -41,7 +52,7 @@ class SQLAlchemyAsyncRepository(AbstractAsyncRepository[ModelT], Generic[ModelT] def __init__( self, *, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, session: AsyncSession, auto_expunge: bool = False, auto_refresh: bool = True, @@ -64,7 +75,13 @@ def __init__( self.auto_refresh = auto_refresh self.auto_commit = auto_commit self.session = session - self.statement = statement if statement is not None else select(self.model_type) + if isinstance(statement, Select): + self.statement = lambda_stmt(lambda: statement) + elif statement is None: + statement = select(self.model_type) + self.statement = lambda_stmt(lambda: statement) + else: + self.statement = statement if not self.session.bind: # this shouldn't actually ever happen, but we include it anyway to properly # narrow down the types @@ -192,12 +209,36 @@ async def delete_many( if self._dialect.delete_executemany_returning: instances.extend( await self.session.scalars( - delete(self.model_type).where(id_attribute.in_(chunk)).returning(self.model_type) + self._get_delete_many_statement( + statement_type="delete", + model_type=self.model_type, + id_attribute=id_attribute, + id_chunk=chunk, + supports_returning=self._dialect.delete_executemany_returning, + ) ) ) else: - instances.extend(await self.session.scalars(select(self.model_type).where(id_attribute.in_(chunk)))) - await self.session.execute(delete(self.model_type).where(id_attribute.in_(chunk))) + instances.extend( + await self.session.scalars( + self._get_delete_many_statement( + statement_type="select", + model_type=self.model_type, + id_attribute=id_attribute, + id_chunk=chunk, + supports_returning=self._dialect.delete_executemany_returning, + ) + ) + ) + await self.session.execute( + self._get_delete_many_statement( + statement_type="delete", + model_type=self.model_type, + id_attribute=id_attribute, + id_chunk=chunk, + supports_returning=self._dialect.delete_executemany_returning, + ) + ) await self._flush_or_commit(auto_commit=auto_commit) for instance in instances: self._expunge(instance, auto_expunge=auto_expunge) @@ -219,11 +260,35 @@ async def exists(self, **kwargs: Any) -> bool: existing = await self.count(**kwargs) return existing > 0 + def _get_base_stmt( + self, statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None + ) -> StatementLambdaElement: + if isinstance(statement, Select): + return lambda_stmt(lambda: statement) + return self.statement if statement is None else statement + + @staticmethod + def _get_delete_many_statement( + model_type: type[ModelT], + id_attribute: InstrumentedAttribute, + id_chunk: list[Any], + supports_returning: bool, + statement_type: Literal["delete", "select"] = "delete", + ) -> StatementLambdaElement: + if statement_type == "delete": + statement = lambda_stmt(lambda: delete(model_type)) + elif statement_type == "select": + statement = lambda_stmt(lambda: select(model_type)) + statement += lambda s: s.where(id_attribute.in_(id_chunk)) + if supports_returning and statement_type != "select": + statement += lambda s: s.returning(model_type) + return statement + async def get( # type: ignore[override] self, item_id: Any, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, id_attribute: str | InstrumentedAttribute | None = None, ) -> ModelT: """Get instance identified by `item_id`. @@ -245,8 +310,8 @@ async def get( # type: ignore[override] """ with wrap_sqlalchemy_exception(): id_attribute = id_attribute if id_attribute is not None else self.id_attribute - statement = statement if statement is not None else self.statement - statement = self._filter_select_by_kwargs(statement=statement, kwargs=[(id_attribute, item_id)]) + statement = self._get_base_stmt(statement) + statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)]) instance = (await self._execute(statement)).scalar_one_or_none() instance = self.check_not_found(instance) self._expunge(instance, auto_expunge=auto_expunge) @@ -255,7 +320,7 @@ async def get( # type: ignore[override] async def get_one( self, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> ModelT: """Get instance identified by ``kwargs``. @@ -274,8 +339,8 @@ async def get_one( NotFoundError: If no instance found identified by `item_id`. """ with wrap_sqlalchemy_exception(): - statement = statement if statement is not None else self.statement - statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs) + statement = self._get_base_stmt(statement) + statement = self._filter_select_by_kwargs(statement, kwargs) instance = (await self._execute(statement)).scalar_one_or_none() instance = self.check_not_found(instance) self._expunge(instance, auto_expunge=auto_expunge) @@ -284,7 +349,7 @@ async def get_one( async def get_one_or_none( self, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> ModelT | None: """Get instance identified by ``kwargs`` or None if not found. @@ -300,9 +365,9 @@ async def get_one_or_none( The retrieved instance or None """ with wrap_sqlalchemy_exception(): - statement = statement if statement is not None else self.statement - statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs) - instance = (await self._execute(statement)).scalar_one_or_none() + statement = self._get_base_stmt(statement) + statement = self._filter_select_by_kwargs(statement, kwargs) + instance = cast("Result[tuple[ModelT]]", (await self._execute(statement))).scalar_one_or_none() if instance: self._expunge(instance, auto_expunge=auto_expunge) return instance @@ -373,7 +438,7 @@ async def get_or_create( async def count( self, *filters: FilterTypes, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> int: """Get the count of records returned by a query. @@ -387,11 +452,9 @@ async def count( Returns: Count of records returned by query, ignoring pagination. """ - statement = statement if statement is not None else self.statement - statement = statement.with_only_columns( - sql_func.count(self.get_id_attribute_value(self.model_type)), - maintain_column_froms=True, - ).order_by(None) + statement = self._get_base_stmt(statement) + fragment = self.get_id_attribute_value(self.model_type) + statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None) statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) results = await self._execute(statement) @@ -475,10 +538,12 @@ async def update_many( """ data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore with wrap_sqlalchemy_exception(): - if self._dialect.update_executemany_returning and self._dialect.name != "oracle": + supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle" + statement = self._get_update_many_statement(self.model_type, supports_returning) + if supports_returning: instances = list( await self.session.scalars( - update(self.model_type).returning(self.model_type), + statement, cast("_CoreSingleExecuteParams", data_to_update), # this is not correct but the only way # currently to deal with an SQLAlchemy typing issue. See # https://github.com/sqlalchemy/sqlalchemy/discussions/9925 @@ -488,17 +553,25 @@ async def update_many( for instance in instances: self._expunge(instance, auto_expunge=auto_expunge) return instances - await self.session.execute(update(self.model_type), data_to_update) + await self.session.execute(statement, data_to_update) await self._flush_or_commit(auto_commit=auto_commit) return data + @staticmethod + def _get_update_many_statement(model_type: type[ModelT], supports_returning: bool) -> StatementLambdaElement: + statement = lambda_stmt(lambda: update(model_type)) + if supports_returning: + statement += lambda s: s.returning(model_type) + return statement + async def list_and_count( self, *filters: FilterTypes, auto_commit: bool | None = None, auto_expunge: bool | None = None, auto_refresh: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, + force_basic_query_mode: bool | None = None, **kwargs: Any, ) -> tuple[list[ModelT], int]: """List records with total count. @@ -513,12 +586,13 @@ async def list_and_count( :class:`SQLAlchemyAsyncRepository.auto_commit ` statement: To facilitate customization of the underlying select query. Defaults to :class:`SQLAlchemyAsyncRepository.statement ` + force_basic_query_mode: Force list and count to use two queries instead of an analytical window function. **kwargs: Instance attribute value filters. Returns: Count of records returned by query, ignoring pagination. """ - if self._dialect.name in {"spanner", "spanner+spanner"}: + if self._dialect.name in {"spanner", "spanner+spanner"} or force_basic_query_mode: return await self._list_and_count_basic(*filters, auto_expunge=auto_expunge, statement=statement, **kwargs) return await self._list_and_count_window(*filters, auto_expunge=auto_expunge, statement=statement, **kwargs) @@ -554,7 +628,7 @@ async def _list_and_count_window( self, *filters: FilterTypes, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> tuple[list[ModelT], int]: """List records with total count. @@ -570,8 +644,9 @@ async def _list_and_count_window( Returns: Count of records returned by query using an analytical window function, ignoring pagination. """ - statement = statement if statement is not None else self.statement - statement = statement.add_columns(over(sql_func.count(self.get_id_attribute_value(self.model_type)))) + statement = self._get_base_stmt(statement) + field = self.get_id_attribute_value(self.model_type) + statement += lambda s: s.add_columns(over(sql_func.count(field))) statement = self._apply_filters(*filters, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) with wrap_sqlalchemy_exception(): @@ -589,7 +664,7 @@ async def _list_and_count_basic( self, *filters: FilterTypes, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> tuple[list[ModelT], int]: """List records with total count. @@ -605,15 +680,12 @@ async def _list_and_count_basic( Returns: Count of records returned by query using 2 queries, ignoring pagination. """ - statement = statement if statement is not None else self.statement + statement = self._get_base_stmt(statement) statement = self._apply_filters(*filters, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) - count_statement = statement.with_only_columns( - sql_func.count(self.get_id_attribute_value(self.model_type)), - maintain_column_froms=True, - ).order_by(None) + with wrap_sqlalchemy_exception(): - count_result = await self.session.execute(count_statement) + count_result = await self.session.execute(self._get_count_stmt(statement)) count = count_result.scalar_one() result = await self._execute(statement) instances: list[ModelT] = [] @@ -622,6 +694,12 @@ async def _list_and_count_basic( instances.append(instance) return instances, count + def _get_count_stmt(self, statement: StatementLambdaElement) -> StatementLambdaElement: + fragment = self.get_id_attribute_value(self.model_type) + statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True) + statement += lambda s: s.order_by(None) + return statement + async def upsert( self, data: ModelT, @@ -718,7 +796,7 @@ async def list( self, *filters: FilterTypes, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> list[ModelT]: """Get a list of instances, optionally filtered. @@ -734,7 +812,7 @@ async def list( Returns: The list of instances, after filtering applied. """ - statement = statement if statement is not None else self.statement + statement = self._get_base_stmt(statement) statement = self._apply_filters(*filters, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) @@ -746,8 +824,8 @@ async def list( return instances def filter_collection_by_kwargs( # type:ignore[override] - self, collection: SelectT, /, **kwargs: Any - ) -> SelectT: + self, collection: Select[tuple[ModelT]] | StatementLambdaElement, /, **kwargs: Any + ) -> StatementLambdaElement: """Filter the collection by kwargs. Args: @@ -756,7 +834,9 @@ def filter_collection_by_kwargs( # type:ignore[override] have the property that their attribute named `key` has value equal to `value`. """ with wrap_sqlalchemy_exception(): - return collection.filter_by(**kwargs) + collection = lambda_stmt(lambda: collection) + collection += lambda s: s.filter_by(**kwargs) + return collection @classmethod async def check_health(cls, session: AsyncSession) -> bool: @@ -799,13 +879,18 @@ async def _attach_to_session(self, model: ModelT, strategy: Literal["add", "merg return await self.session.merge(model) raise ValueError("Unexpected value for `strategy`, must be `'add'` or `'merge'`") - async def _execute(self, statement: Select[RowT]) -> Result[RowT]: - return cast("Result[RowT]", await self.session.execute(statement)) + async def _execute(self, statement: Select[Any] | StatementLambdaElement) -> Result[Any]: + return await self.session.execute(statement) - def _apply_limit_offset_pagination(self, limit: int, offset: int, statement: SelectT) -> SelectT: - return statement.limit(limit).offset(offset) + def _apply_limit_offset_pagination( + self, limit: int, offset: int, statement: StatementLambdaElement + ) -> StatementLambdaElement: + statement += lambda s: s.limit(limit).offset(offset) + return statement - def _apply_filters(self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT) -> SelectT: + def _apply_filters( + self, *filters: FilterTypes, apply_pagination: bool = True, statement: StatementLambdaElement + ) -> StatementLambdaElement: """Apply filters to a select statement. Args: @@ -843,11 +928,7 @@ def _apply_filters(self, *filters: FilterTypes, apply_pagination: bool = True, s elif isinstance(filter_, CollectionFilter): statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement) elif isinstance(filter_, OrderBy): - statement = self._order_by( - statement, - filter_.field_name, - sort_desc=filter_.sort_order == "desc", - ) + statement = self._order_by(statement, filter_.field_name, sort_desc=filter_.sort_order == "desc") elif isinstance(filter_, SearchFilter): statement = self._filter_by_like( statement, filter_.field_name, value=filter_.value, ignore_case=bool(filter_.ignore_case) @@ -860,57 +941,85 @@ def _apply_filters(self, *filters: FilterTypes, apply_pagination: bool = True, s raise RepositoryError(f"Unexpected filter: {filter_}") return statement - def _filter_in_collection(self, field_name: str, values: abc.Collection[Any], statement: SelectT) -> SelectT: + def _filter_in_collection( + self, field_name: str, values: abc.Collection[Any], statement: StatementLambdaElement + ) -> StatementLambdaElement: if not values: return statement - return statement.where(getattr(self.model_type, field_name).in_(values)) + field = getattr(self.model_type, field_name) + statement += lambda s: s.where(field.in_(values)) + return statement - def _filter_not_in_collection(self, field_name: str, values: abc.Collection[Any], statement: SelectT) -> SelectT: + def _filter_not_in_collection( + self, field_name: str, values: abc.Collection[Any], statement: StatementLambdaElement + ) -> StatementLambdaElement: if not values: return statement - return statement.where(getattr(self.model_type, field_name).notin_(values)) + field = getattr(self.model_type, field_name) + statement += lambda s: s.where(field.notin_(values)) + return statement def _filter_on_datetime_field( self, field_name: str, - statement: SelectT, + statement: StatementLambdaElement, before: datetime | None = None, after: datetime | None = None, on_or_before: datetime | None = None, on_or_after: datetime | None = None, - ) -> SelectT: + ) -> StatementLambdaElement: field = getattr(self.model_type, field_name) if before is not None: - statement = statement.where(field < before) + statement += lambda s: s.where(field < before) if after is not None: - statement = statement.where(field > after) + statement += lambda s: s.where(field > after) if on_or_before is not None: - statement = statement.where(field <= on_or_before) + statement += lambda s: s.where(field <= on_or_before) if on_or_after is not None: - statement = statement.where(field >= on_or_after) + statement += lambda s: s.where(field >= on_or_after) return statement def _filter_select_by_kwargs( - self, statement: SelectT, kwargs: dict[Any, Any] | Iterable[tuple[Any, Any]] - ) -> SelectT: + self, statement: StatementLambdaElement, kwargs: dict[Any, Any] | Iterable[tuple[Any, Any]] + ) -> StatementLambdaElement: for key, val in kwargs.items() if isinstance(kwargs, dict) else kwargs: - statement = statement.where(get_instrumented_attr(self.model_type, key) == val) # pyright: ignore + statement = self._filter_by_where(statement, key, val) # pyright: ignore[reportGeneralTypeIssues] + return statement + + def _filter_by_where(self, statement: StatementLambdaElement, key: str, val: Any) -> StatementLambdaElement: + model_type = self.model_type + field = get_instrumented_attr(model_type, key) + statement += lambda s: s.where(field == val) return statement def _filter_by_like( - self, statement: SelectT, field_name: str | InstrumentedAttribute, value: str, ignore_case: bool - ) -> SelectT: + self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, value: str, ignore_case: bool + ) -> StatementLambdaElement: field = get_instrumented_attr(self.model_type, field_name) search_text = f"%{value}%" - return statement.where(field.ilike(search_text) if ignore_case else field.like(search_text)) + if ignore_case: + statement += lambda s: s.where(field.ilike(search_text)) + else: + statement += lambda s: s.where(field.like(search_text)) + return statement - def _filter_by_not_like(self, statement: SelectT, field_name: str, value: str, ignore_case: bool) -> SelectT: + def _filter_by_not_like( + self, statement: StatementLambdaElement, field_name: str, value: str, ignore_case: bool + ) -> StatementLambdaElement: field = getattr(self.model_type, field_name) search_text = f"%{value}%" - return statement.where(field.not_ilike(search_text) if ignore_case else field.not_like(search_text)) + if ignore_case: + statement += lambda s: s.where(field.not_ilike(search_text)) + else: + statement += lambda s: s.where(field.not_like(search_text)) + return statement def _order_by( - self, statement: SelectT, field_name: str | InstrumentedAttribute, sort_desc: bool = False - ) -> SelectT: + self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, sort_desc: bool = False + ) -> StatementLambdaElement: field = get_instrumented_attr(self.model_type, field_name) - return statement.order_by(field.desc() if sort_desc else field.asc()) + if sort_desc: + statement += lambda s: s.order_by(field.desc()) + else: + statement += lambda s: s.order_by(field.asc()) + return statement diff --git a/litestar/contrib/sqlalchemy/repository/_sync.py b/litestar/contrib/sqlalchemy/repository/_sync.py index dd65115183..4ffbdf517a 100644 --- a/litestar/contrib/sqlalchemy/repository/_sync.py +++ b/litestar/contrib/sqlalchemy/repository/_sync.py @@ -4,7 +4,18 @@ from typing import TYPE_CHECKING, Any, Final, Generic, Iterable, Literal, cast -from sqlalchemy import Result, Select, TextClause, delete, over, select, text, update +from sqlalchemy import ( + Result, + Select, + StatementLambdaElement, + TextClause, + delete, + lambda_stmt, + over, + select, + text, + update, +) from sqlalchemy import func as sql_func from sqlalchemy.orm import InstrumentedAttribute, Session @@ -22,7 +33,7 @@ ) from ._util import get_instrumented_attr, wrap_sqlalchemy_exception -from .types import ModelT, RowT, SelectT +from .types import ModelT if TYPE_CHECKING: from collections import abc @@ -42,7 +53,7 @@ class SQLAlchemySyncRepository(AbstractSyncRepository[ModelT], Generic[ModelT]): def __init__( self, *, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, session: Session, auto_expunge: bool = False, auto_refresh: bool = True, @@ -65,7 +76,13 @@ def __init__( self.auto_refresh = auto_refresh self.auto_commit = auto_commit self.session = session - self.statement = statement if statement is not None else select(self.model_type) + if isinstance(statement, Select): + self.statement = lambda_stmt(lambda: statement) + elif statement is None: + statement = select(self.model_type) + self.statement = lambda_stmt(lambda: statement) + else: + self.statement = statement if not self.session.bind: # this shouldn't actually ever happen, but we include it anyway to properly # narrow down the types @@ -193,12 +210,36 @@ def delete_many( if self._dialect.delete_executemany_returning: instances.extend( self.session.scalars( - delete(self.model_type).where(id_attribute.in_(chunk)).returning(self.model_type) + self._get_delete_many_statement( + statement_type="delete", + model_type=self.model_type, + id_attribute=id_attribute, + id_chunk=chunk, + supports_returning=self._dialect.delete_executemany_returning, + ) ) ) else: - instances.extend(self.session.scalars(select(self.model_type).where(id_attribute.in_(chunk)))) - self.session.execute(delete(self.model_type).where(id_attribute.in_(chunk))) + instances.extend( + self.session.scalars( + self._get_delete_many_statement( + statement_type="select", + model_type=self.model_type, + id_attribute=id_attribute, + id_chunk=chunk, + supports_returning=self._dialect.delete_executemany_returning, + ) + ) + ) + self.session.execute( + self._get_delete_many_statement( + statement_type="delete", + model_type=self.model_type, + id_attribute=id_attribute, + id_chunk=chunk, + supports_returning=self._dialect.delete_executemany_returning, + ) + ) self._flush_or_commit(auto_commit=auto_commit) for instance in instances: self._expunge(instance, auto_expunge=auto_expunge) @@ -220,11 +261,35 @@ def exists(self, **kwargs: Any) -> bool: existing = self.count(**kwargs) return existing > 0 + def _get_base_stmt( + self, statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None + ) -> StatementLambdaElement: + if isinstance(statement, Select): + return lambda_stmt(lambda: statement) + return self.statement if statement is None else statement + + @staticmethod + def _get_delete_many_statement( + model_type: type[ModelT], + id_attribute: InstrumentedAttribute, + id_chunk: list[Any], + supports_returning: bool, + statement_type: Literal["delete", "select"] = "delete", + ) -> StatementLambdaElement: + if statement_type == "delete": + statement = lambda_stmt(lambda: delete(model_type)) + elif statement_type == "select": + statement = lambda_stmt(lambda: select(model_type)) + statement += lambda s: s.where(id_attribute.in_(id_chunk)) + if supports_returning and statement_type != "select": + statement += lambda s: s.returning(model_type) + return statement + def get( # type: ignore[override] self, item_id: Any, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, id_attribute: str | InstrumentedAttribute | None = None, ) -> ModelT: """Get instance identified by `item_id`. @@ -246,8 +311,8 @@ def get( # type: ignore[override] """ with wrap_sqlalchemy_exception(): id_attribute = id_attribute if id_attribute is not None else self.id_attribute - statement = statement if statement is not None else self.statement - statement = self._filter_select_by_kwargs(statement=statement, kwargs=[(id_attribute, item_id)]) + statement = self._get_base_stmt(statement) + statement = self._filter_select_by_kwargs(statement, [(id_attribute, item_id)]) instance = (self._execute(statement)).scalar_one_or_none() instance = self.check_not_found(instance) self._expunge(instance, auto_expunge=auto_expunge) @@ -256,7 +321,7 @@ def get( # type: ignore[override] def get_one( self, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> ModelT: """Get instance identified by ``kwargs``. @@ -275,8 +340,8 @@ def get_one( NotFoundError: If no instance found identified by `item_id`. """ with wrap_sqlalchemy_exception(): - statement = statement if statement is not None else self.statement - statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs) + statement = self._get_base_stmt(statement) + statement = self._filter_select_by_kwargs(statement, kwargs) instance = (self._execute(statement)).scalar_one_or_none() instance = self.check_not_found(instance) self._expunge(instance, auto_expunge=auto_expunge) @@ -285,7 +350,7 @@ def get_one( def get_one_or_none( self, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> ModelT | None: """Get instance identified by ``kwargs`` or None if not found. @@ -301,9 +366,9 @@ def get_one_or_none( The retrieved instance or None """ with wrap_sqlalchemy_exception(): - statement = statement if statement is not None else self.statement - statement = self._filter_select_by_kwargs(statement=statement, kwargs=kwargs) - instance = (self._execute(statement)).scalar_one_or_none() + statement = self._get_base_stmt(statement) + statement = self._filter_select_by_kwargs(statement, kwargs) + instance = cast("Result[tuple[ModelT]]", (self._execute(statement))).scalar_one_or_none() if instance: self._expunge(instance, auto_expunge=auto_expunge) return instance @@ -374,7 +439,7 @@ def get_or_create( def count( self, *filters: FilterTypes, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> int: """Get the count of records returned by a query. @@ -388,11 +453,9 @@ def count( Returns: Count of records returned by query, ignoring pagination. """ - statement = statement if statement is not None else self.statement - statement = statement.with_only_columns( - sql_func.count(self.get_id_attribute_value(self.model_type)), - maintain_column_froms=True, - ).order_by(None) + statement = self._get_base_stmt(statement) + fragment = self.get_id_attribute_value(self.model_type) + statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True).order_by(None) statement = self._apply_filters(*filters, apply_pagination=False, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) results = self._execute(statement) @@ -476,10 +539,12 @@ def update_many( """ data_to_update: list[dict[str, Any]] = [v.to_dict() if isinstance(v, self.model_type) else v for v in data] # type: ignore with wrap_sqlalchemy_exception(): - if self._dialect.update_executemany_returning and self._dialect.name != "oracle": + supports_returning = self._dialect.update_executemany_returning and self._dialect.name != "oracle" + statement = self._get_update_many_statement(self.model_type, supports_returning) + if supports_returning: instances = list( self.session.scalars( - update(self.model_type).returning(self.model_type), + statement, cast("_CoreSingleExecuteParams", data_to_update), # this is not correct but the only way # currently to deal with an SQLAlchemy typing issue. See # https://github.com/sqlalchemy/sqlalchemy/discussions/9925 @@ -489,17 +554,25 @@ def update_many( for instance in instances: self._expunge(instance, auto_expunge=auto_expunge) return instances - self.session.execute(update(self.model_type), data_to_update) + self.session.execute(statement, data_to_update) self._flush_or_commit(auto_commit=auto_commit) return data + @staticmethod + def _get_update_many_statement(model_type: type[ModelT], supports_returning: bool) -> StatementLambdaElement: + statement = lambda_stmt(lambda: update(model_type)) + if supports_returning: + statement += lambda s: s.returning(model_type) + return statement + def list_and_count( self, *filters: FilterTypes, auto_commit: bool | None = None, auto_expunge: bool | None = None, auto_refresh: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, + force_basic_query_mode: bool | None = None, **kwargs: Any, ) -> tuple[list[ModelT], int]: """List records with total count. @@ -514,12 +587,13 @@ def list_and_count( :class:`SQLAlchemyAsyncRepository.auto_commit ` statement: To facilitate customization of the underlying select query. Defaults to :class:`SQLAlchemyAsyncRepository.statement ` + force_basic_query_mode: Force list and count to use two queries instead of an analytical window function. **kwargs: Instance attribute value filters. Returns: Count of records returned by query, ignoring pagination. """ - if self._dialect.name in {"spanner", "spanner+spanner"}: + if self._dialect.name in {"spanner", "spanner+spanner"} or force_basic_query_mode: return self._list_and_count_basic(*filters, auto_expunge=auto_expunge, statement=statement, **kwargs) return self._list_and_count_window(*filters, auto_expunge=auto_expunge, statement=statement, **kwargs) @@ -555,7 +629,7 @@ def _list_and_count_window( self, *filters: FilterTypes, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> tuple[list[ModelT], int]: """List records with total count. @@ -571,8 +645,9 @@ def _list_and_count_window( Returns: Count of records returned by query using an analytical window function, ignoring pagination. """ - statement = statement if statement is not None else self.statement - statement = statement.add_columns(over(sql_func.count(self.get_id_attribute_value(self.model_type)))) + statement = self._get_base_stmt(statement) + field = self.get_id_attribute_value(self.model_type) + statement += lambda s: s.add_columns(over(sql_func.count(field))) statement = self._apply_filters(*filters, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) with wrap_sqlalchemy_exception(): @@ -590,7 +665,7 @@ def _list_and_count_basic( self, *filters: FilterTypes, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> tuple[list[ModelT], int]: """List records with total count. @@ -606,15 +681,12 @@ def _list_and_count_basic( Returns: Count of records returned by query using 2 queries, ignoring pagination. """ - statement = statement if statement is not None else self.statement + statement = self._get_base_stmt(statement) statement = self._apply_filters(*filters, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) - count_statement = statement.with_only_columns( - sql_func.count(self.get_id_attribute_value(self.model_type)), - maintain_column_froms=True, - ).order_by(None) + with wrap_sqlalchemy_exception(): - count_result = self.session.execute(count_statement) + count_result = self.session.execute(self._get_count_stmt(statement)) count = count_result.scalar_one() result = self._execute(statement) instances: list[ModelT] = [] @@ -623,6 +695,12 @@ def _list_and_count_basic( instances.append(instance) return instances, count + def _get_count_stmt(self, statement: StatementLambdaElement) -> StatementLambdaElement: + fragment = self.get_id_attribute_value(self.model_type) + statement += lambda s: s.with_only_columns(sql_func.count(fragment), maintain_column_froms=True) + statement += lambda s: s.order_by(None) + return statement + def upsert( self, data: ModelT, @@ -719,7 +797,7 @@ def list( self, *filters: FilterTypes, auto_expunge: bool | None = None, - statement: Select[tuple[ModelT]] | None = None, + statement: Select[tuple[ModelT]] | StatementLambdaElement | None = None, **kwargs: Any, ) -> list[ModelT]: """Get a list of instances, optionally filtered. @@ -735,7 +813,7 @@ def list( Returns: The list of instances, after filtering applied. """ - statement = statement if statement is not None else self.statement + statement = self._get_base_stmt(statement) statement = self._apply_filters(*filters, statement=statement) statement = self._filter_select_by_kwargs(statement, kwargs) @@ -747,8 +825,8 @@ def list( return instances def filter_collection_by_kwargs( # type:ignore[override] - self, collection: SelectT, /, **kwargs: Any - ) -> SelectT: + self, collection: Select[tuple[ModelT]] | StatementLambdaElement, /, **kwargs: Any + ) -> StatementLambdaElement: """Filter the collection by kwargs. Args: @@ -757,7 +835,9 @@ def filter_collection_by_kwargs( # type:ignore[override] have the property that their attribute named `key` has value equal to `value`. """ with wrap_sqlalchemy_exception(): - return collection.filter_by(**kwargs) + collection = lambda_stmt(lambda: collection) + collection += lambda s: s.filter_by(**kwargs) + return collection @classmethod def check_health(cls, session: Session) -> bool: @@ -800,13 +880,18 @@ def _attach_to_session(self, model: ModelT, strategy: Literal["add", "merge"] = return self.session.merge(model) raise ValueError("Unexpected value for `strategy`, must be `'add'` or `'merge'`") - def _execute(self, statement: Select[RowT]) -> Result[RowT]: - return cast("Result[RowT]", self.session.execute(statement)) + def _execute(self, statement: Select[Any] | StatementLambdaElement) -> Result[Any]: + return self.session.execute(statement) - def _apply_limit_offset_pagination(self, limit: int, offset: int, statement: SelectT) -> SelectT: - return statement.limit(limit).offset(offset) + def _apply_limit_offset_pagination( + self, limit: int, offset: int, statement: StatementLambdaElement + ) -> StatementLambdaElement: + statement += lambda s: s.limit(limit).offset(offset) + return statement - def _apply_filters(self, *filters: FilterTypes, apply_pagination: bool = True, statement: SelectT) -> SelectT: + def _apply_filters( + self, *filters: FilterTypes, apply_pagination: bool = True, statement: StatementLambdaElement + ) -> StatementLambdaElement: """Apply filters to a select statement. Args: @@ -844,11 +929,7 @@ def _apply_filters(self, *filters: FilterTypes, apply_pagination: bool = True, s elif isinstance(filter_, CollectionFilter): statement = self._filter_in_collection(filter_.field_name, filter_.values, statement=statement) elif isinstance(filter_, OrderBy): - statement = self._order_by( - statement, - filter_.field_name, - sort_desc=filter_.sort_order == "desc", - ) + statement = self._order_by(statement, filter_.field_name, sort_desc=filter_.sort_order == "desc") elif isinstance(filter_, SearchFilter): statement = self._filter_by_like( statement, filter_.field_name, value=filter_.value, ignore_case=bool(filter_.ignore_case) @@ -861,57 +942,85 @@ def _apply_filters(self, *filters: FilterTypes, apply_pagination: bool = True, s raise RepositoryError(f"Unexpected filter: {filter_}") return statement - def _filter_in_collection(self, field_name: str, values: abc.Collection[Any], statement: SelectT) -> SelectT: + def _filter_in_collection( + self, field_name: str, values: abc.Collection[Any], statement: StatementLambdaElement + ) -> StatementLambdaElement: if not values: return statement - return statement.where(getattr(self.model_type, field_name).in_(values)) + field = getattr(self.model_type, field_name) + statement += lambda s: s.where(field.in_(values)) + return statement - def _filter_not_in_collection(self, field_name: str, values: abc.Collection[Any], statement: SelectT) -> SelectT: + def _filter_not_in_collection( + self, field_name: str, values: abc.Collection[Any], statement: StatementLambdaElement + ) -> StatementLambdaElement: if not values: return statement - return statement.where(getattr(self.model_type, field_name).notin_(values)) + field = getattr(self.model_type, field_name) + statement += lambda s: s.where(field.notin_(values)) + return statement def _filter_on_datetime_field( self, field_name: str, - statement: SelectT, + statement: StatementLambdaElement, before: datetime | None = None, after: datetime | None = None, on_or_before: datetime | None = None, on_or_after: datetime | None = None, - ) -> SelectT: + ) -> StatementLambdaElement: field = getattr(self.model_type, field_name) if before is not None: - statement = statement.where(field < before) + statement += lambda s: s.where(field < before) if after is not None: - statement = statement.where(field > after) + statement += lambda s: s.where(field > after) if on_or_before is not None: - statement = statement.where(field <= on_or_before) + statement += lambda s: s.where(field <= on_or_before) if on_or_after is not None: - statement = statement.where(field >= on_or_after) + statement += lambda s: s.where(field >= on_or_after) return statement def _filter_select_by_kwargs( - self, statement: SelectT, kwargs: dict[Any, Any] | Iterable[tuple[Any, Any]] - ) -> SelectT: + self, statement: StatementLambdaElement, kwargs: dict[Any, Any] | Iterable[tuple[Any, Any]] + ) -> StatementLambdaElement: for key, val in kwargs.items() if isinstance(kwargs, dict) else kwargs: - statement = statement.where(get_instrumented_attr(self.model_type, key) == val) # pyright: ignore + statement = self._filter_by_where(statement, key, val) # pyright: ignore[reportGeneralTypeIssues] + return statement + + def _filter_by_where(self, statement: StatementLambdaElement, key: str, val: Any) -> StatementLambdaElement: + model_type = self.model_type + field = get_instrumented_attr(model_type, key) + statement += lambda s: s.where(field == val) return statement def _filter_by_like( - self, statement: SelectT, field_name: str | InstrumentedAttribute, value: str, ignore_case: bool - ) -> SelectT: + self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, value: str, ignore_case: bool + ) -> StatementLambdaElement: field = get_instrumented_attr(self.model_type, field_name) search_text = f"%{value}%" - return statement.where(field.ilike(search_text) if ignore_case else field.like(search_text)) + if ignore_case: + statement += lambda s: s.where(field.ilike(search_text)) + else: + statement += lambda s: s.where(field.like(search_text)) + return statement - def _filter_by_not_like(self, statement: SelectT, field_name: str, value: str, ignore_case: bool) -> SelectT: + def _filter_by_not_like( + self, statement: StatementLambdaElement, field_name: str, value: str, ignore_case: bool + ) -> StatementLambdaElement: field = getattr(self.model_type, field_name) search_text = f"%{value}%" - return statement.where(field.not_ilike(search_text) if ignore_case else field.not_like(search_text)) + if ignore_case: + statement += lambda s: s.where(field.not_ilike(search_text)) + else: + statement += lambda s: s.where(field.not_like(search_text)) + return statement def _order_by( - self, statement: SelectT, field_name: str | InstrumentedAttribute, sort_desc: bool = False - ) -> SelectT: + self, statement: StatementLambdaElement, field_name: str | InstrumentedAttribute, sort_desc: bool = False + ) -> StatementLambdaElement: field = get_instrumented_attr(self.model_type, field_name) - return statement.order_by(field.desc() if sort_desc else field.asc()) + if sort_desc: + statement += lambda s: s.order_by(field.desc()) + else: + statement += lambda s: s.order_by(field.asc()) + return statement diff --git a/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_repository.py b/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_repository.py index 9de345c40e..355433b189 100644 --- a/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_repository.py +++ b/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_repository.py @@ -573,6 +573,20 @@ async def test_repo_list_and_count_method(raw_authors: RawRecordData, author_rep assert len(collection) == exp_count +async def test_repo_list_and_count_basic_method(raw_authors: RawRecordData, author_repo: AuthorRepository) -> None: + """Test SQLAlchemy basic list with count in asyncpg. + + Args: + raw_authors: list of authors pre-seeded into the mock repository + author_repo: The author mock repository + """ + exp_count = len(raw_authors) + collection, count = await maybe_async(author_repo.list_and_count(force_basic_query_mode=True)) + assert exp_count == count + assert isinstance(collection, list) + assert len(collection) == exp_count + + async def test_repo_list_and_count_method_empty(book_repo: BookRepository) -> None: collection, count = await maybe_async(book_repo.list_and_count()) assert 0 == count diff --git a/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_sqlalchemy.py b/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_sqlalchemy.py index 20103326e4..e8bf0d4b32 100644 --- a/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_sqlalchemy.py +++ b/tests/unit/test_contrib/test_sqlalchemy/test_repository/test_sqlalchemy.py @@ -3,7 +3,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Union, cast -from unittest.mock import AsyncMock, MagicMock, call +from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest @@ -476,7 +476,7 @@ async def test_sqlalchemy_repo_list_and_count(mock_repo: SQLAlchemyAsyncReposito """Test expected method calls for list operation.""" mock_instances = [MagicMock(), MagicMock()] mock_count = len(mock_instances) - mocker.patch.object(mock_repo, "_list_and_count_window", return_value=(mock_instances, mock_count)) + mocker.patch.object(mock_repo, "_list_and_count_basic", return_value=(mock_instances, mock_count)) mocker.patch.object(mock_repo, "_list_and_count_window", return_value=(mock_instances, mock_count)) instances, instance_count = await maybe_async(mock_repo.list_and_count()) @@ -487,6 +487,23 @@ async def test_sqlalchemy_repo_list_and_count(mock_repo: SQLAlchemyAsyncReposito mock_repo.session.commit.assert_not_called() +async def test_sqlalchemy_repo_list_and_count_basic( + mock_repo: SQLAlchemyAsyncRepository, mocker: MockerFixture +) -> None: + """Test expected method calls for list operation.""" + mock_instances = [MagicMock(), MagicMock()] + mock_count = len(mock_instances) + mocker.patch.object(mock_repo, "_list_and_count_basic", return_value=(mock_instances, mock_count)) + mocker.patch.object(mock_repo, "_list_and_count_window", return_value=(mock_instances, mock_count)) + + instances, instance_count = await maybe_async(mock_repo.list_and_count(force_basic_query_mode=True)) + + assert instances == mock_instances + assert instance_count == mock_count + mock_repo.session.expunge.assert_not_called() + mock_repo.session.commit.assert_not_called() + + async def test_sqlalchemy_repo_exists( mock_repo: SQLAlchemyAsyncRepository, monkeypatch: MonkeyPatch, @@ -518,17 +535,15 @@ async def test_sqlalchemy_repo_count( async def test_sqlalchemy_repo_list_with_pagination( - mock_repo: SQLAlchemyAsyncRepository, monkeypatch: MonkeyPatch, mock_repo_execute: AnyMock + mock_repo: SQLAlchemyAsyncRepository, monkeypatch: MonkeyPatch, mock_repo_execute: AnyMock, mocker: MockerFixture ) -> None: """Test list operation with pagination.""" + mocker.patch.object(mock_repo, "_apply_limit_offset_pagination", return_value=mock_repo.statement) mock_repo_execute.return_value = MagicMock() - mock_repo.statement.limit.return_value = mock_repo.statement - mock_repo.statement.offset.return_value = mock_repo.statement - + mock_repo.statement.where.return_value = mock_repo.statement await maybe_async(mock_repo.list(LimitOffset(2, 3))) - - mock_repo.statement.limit.assert_called_once_with(2) - mock_repo.statement.limit().offset.assert_called_once_with(3) # type:ignore[call-arg] + assert mock_repo._apply_limit_offset_pagination.call_count == 1 + mock_repo._apply_limit_offset_pagination.assert_called_with(2, 3, statement=mock_repo.statement) async def test_sqlalchemy_repo_list_with_before_after_filter( @@ -537,14 +552,14 @@ async def test_sqlalchemy_repo_list_with_before_after_filter( """Test list operation with BeforeAfter filter.""" mocker.patch.object(mock_repo.model_type.updated_at, "__lt__", return_value="lt") mocker.patch.object(mock_repo.model_type.updated_at, "__gt__", return_value="gt") - + mocker.patch.object(mock_repo, "_filter_on_datetime_field", return_value=mock_repo.statement) mock_repo_execute.return_value = MagicMock() mock_repo.statement.where.return_value = mock_repo.statement - await maybe_async(mock_repo.list(BeforeAfter("updated_at", datetime.max, datetime.min))) - - assert mock_repo.statement.where.call_count == 2 - mock_repo.statement.where.assert_has_calls([call("gt"), call("lt")], any_order=True) + assert mock_repo._filter_on_datetime_field.call_count == 1 + mock_repo._filter_on_datetime_field.assert_called_with( + field_name="updated_at", before=datetime.max, after=datetime.min, statement=mock_repo.statement + ) async def test_sqlalchemy_repo_list_with_on_before_after_filter( @@ -553,43 +568,42 @@ async def test_sqlalchemy_repo_list_with_on_before_after_filter( """Test list operation with BeforeAfter filter.""" mocker.patch.object(mock_repo.model_type.updated_at, "__le__", return_value="le") mocker.patch.object(mock_repo.model_type.updated_at, "__ge__", return_value="ge") - + mocker.patch.object(mock_repo, "_filter_on_datetime_field", return_value=mock_repo.statement) mock_repo_execute.return_value = MagicMock() mock_repo.statement.where.return_value = mock_repo.statement await maybe_async(mock_repo.list(OnBeforeAfter("updated_at", datetime.max, datetime.min))) - - assert mock_repo.statement.where.call_count == 2 - mock_repo.statement.where.assert_has_calls([call("ge"), call("le")], any_order=True) + assert mock_repo._filter_on_datetime_field.call_count == 1 + mock_repo._filter_on_datetime_field.assert_called_with( + field_name="updated_at", on_or_before=datetime.max, on_or_after=datetime.min, statement=mock_repo.statement + ) async def test_sqlalchemy_repo_list_with_collection_filter( - mock_repo: SQLAlchemyAsyncRepository, monkeypatch: MonkeyPatch, mock_repo_execute: AnyMock + mock_repo: SQLAlchemyAsyncRepository, monkeypatch: MonkeyPatch, mock_repo_execute: AnyMock, mocker: MockerFixture ) -> None: """Test behavior of list operation given CollectionFilter.""" field_name = "id" mock_repo_execute.return_value = MagicMock() mock_repo.statement.where.return_value = mock_repo.statement + mocker.patch.object(mock_repo, "_filter_in_collection", return_value=mock_repo.statement) values = [1, 2, 3] - await maybe_async(mock_repo.list(CollectionFilter(field_name, values))) - - mock_repo.statement.where.assert_called_once() - getattr(mock_repo.model_type, field_name).in_.assert_called_once_with(values) + assert mock_repo._filter_in_collection.call_count == 1 + mock_repo._filter_in_collection.assert_called_with(field_name, values, statement=mock_repo.statement) async def test_sqlalchemy_repo_list_with_not_in_collection_filter( - mock_repo: SQLAlchemyAsyncRepository, monkeypatch: MonkeyPatch, mock_repo_execute: AnyMock + mock_repo: SQLAlchemyAsyncRepository, monkeypatch: MonkeyPatch, mock_repo_execute: AnyMock, mocker: MockerFixture ) -> None: """Test behavior of list operation given CollectionFilter.""" field_name = "id" mock_repo_execute.return_value = MagicMock() mock_repo.statement.where.return_value = mock_repo.statement + mocker.patch.object(mock_repo, "_filter_not_in_collection", return_value=mock_repo.statement) values = [1, 2, 3] - await maybe_async(mock_repo.list(NotInCollectionFilter(field_name, values))) - - mock_repo.statement.where.assert_called_once() - getattr(mock_repo.model_type, field_name).notin_.assert_called_once_with(values) + assert mock_repo._filter_not_in_collection.call_count == 1 + mock_repo._filter_not_in_collection.assert_called_with(field_name, values, statement=mock_repo.statement) async def test_sqlalchemy_repo_unknown_filter_type_raises(mock_repo: SQLAlchemyAsyncRepository) -> None: @@ -650,7 +664,7 @@ async def test_execute(mock_repo: SQLAlchemyAsyncRepository) -> None: def test_filter_in_collection_noop_if_collection_empty(mock_repo: SQLAlchemyAsyncRepository) -> None: """Ensures we don't filter on an empty collection.""" mock_repo._filter_in_collection("id", [], statement=mock_repo.statement) - mock_repo.statement.where.assert_not_called() # type: ignore + mock_repo.statement.where.assert_not_called() @pytest.mark.parametrize( @@ -670,18 +684,21 @@ def test_filter_on_datetime_field(before: datetime, after: datetime, mock_repo: mock_repo._filter_on_datetime_field("updated_at", before=before, after=after, statement=mock_repo.statement) -def test_filter_collection_by_kwargs(mock_repo: SQLAlchemyAsyncRepository) -> None: +def test_filter_collection_by_kwargs(mock_repo: SQLAlchemyAsyncRepository, mocker: MockerFixture) -> None: """Test `filter_by()` called with kwargs.""" + mock_repo_execute.return_value = MagicMock() + mock_repo.statement.where.return_value = mock_repo.statement + mocker.patch.object(mock_repo, "filter_collection_by_kwargs", return_value=mock_repo.statement) _ = mock_repo.filter_collection_by_kwargs(mock_repo.statement, a=1, b=2) - mock_repo.statement.filter_by.assert_called_once_with(a=1, b=2) + mock_repo.filter_collection_by_kwargs.assert_called_once_with(mock_repo.statement, a=1, b=2) def test_filter_collection_by_kwargs_raises_repository_exception_for_attribute_error( - mock_repo: SQLAlchemyAsyncRepository, + mock_repo: SQLAlchemyAsyncRepository, mocker: MockerFixture ) -> None: """Test that we raise a repository exception if an attribute name is incorrect.""" - mock_repo.statement.filter_by = MagicMock( # type:ignore[method-assign] + mock_repo.statement.filter_by = MagicMock( # pyright: ignore[reportGeneralTypeIssues] side_effect=InvalidRequestError, ) with pytest.raises(RepositoryError):