From 4aa30c13f79d5b1d87f4e15e6c357a9e69597063 Mon Sep 17 00:00:00 2001 From: Martin Vielsmaier Date: Fri, 6 Oct 2023 11:01:51 +0200 Subject: [PATCH] Fix typing --- depeche_db/_executor.py | 5 ++++- depeche_db/_link_stream.py | 6 +++--- depeche_db/_message_store.py | 10 +++++----- depeche_db/_storage.py | 12 ++++++------ depeche_db/_subscription.py | 4 ++-- depeche_db/tools/pg_notification_listener.py | 4 ++-- depeche_db/tools/pydantic_message_serializer.py | 17 +++++++++-------- examples/docs.py | 16 +++++++++------- .../subscription/test_streams_subscriptions.py | 2 +- 9 files changed, 41 insertions(+), 35 deletions(-) diff --git a/depeche_db/_executor.py b/depeche_db/_executor.py index 897e32a..590e2f4 100644 --- a/depeche_db/_executor.py +++ b/depeche_db/_executor.py @@ -1,5 +1,6 @@ import collections as _collections import signal as _signal +from typing import Callable from ._interfaces import RunOnNotification from .tools import PgNotificationListener @@ -8,7 +9,9 @@ class Executor: def __init__(self, db_dsn: str): self._db_dsn = db_dsn - self.channel_register = _collections.defaultdict(list) + self.channel_register: dict[ + str, list[Callable[[], None]] + ] = _collections.defaultdict(list) def register(self, handler: RunOnNotification): self.channel_register[handler.notification_channel].append(handler.run) diff --git a/depeche_db/_link_stream.py b/depeche_db/_link_stream.py index 0298a67..fe14f00 100644 --- a/depeche_db/_link_stream.py +++ b/depeche_db/_link_stream.py @@ -2,7 +2,7 @@ import contextlib as _contextlib import datetime as _dt import uuid as _uuid -from typing import Generic, Iterable, TypeVar +from typing import Generic, Iterator, TypeVar import sqlalchemy as _sa from psycopg2.errors import LockNotAvailable @@ -112,7 +112,7 @@ def add( ) ) - def read(self, conn: _sa.Connection, partition: int) -> Iterable[_uuid.UUID]: + def read(self, conn: _sa.Connection, partition: int) -> Iterator[_uuid.UUID]: for row in conn.execute( _sa.select(self._table.c.message_id) .where(self._table.c.partition == partition) @@ -133,7 +133,7 @@ def get_partition_statistics( self, position_limits: dict[int, int] = None, result_limit: int | None = None, - ) -> Iterable[StreamPartitionStatistic]: + ) -> Iterator[StreamPartitionStatistic]: with self._connection() as conn: position_limits = position_limits or {-1: -1} tbl = self._table.alias() diff --git a/depeche_db/_message_store.py b/depeche_db/_message_store.py index 7854105..26988e0 100644 --- a/depeche_db/_message_store.py +++ b/depeche_db/_message_store.py @@ -1,6 +1,6 @@ import contextlib as _contextlib import uuid as _uuid -from typing import Generic, Iterable, Iterator, Sequence, TypeVar +from typing import Generic, Iterator, Sequence, TypeVar import sqlalchemy as _sa @@ -39,7 +39,7 @@ def get_message_by_id(self, message_id: _uuid.UUID) -> StoredMessage[E]: def get_messages_by_ids( self, message_ids: list[_uuid.UUID] - ) -> Iterable[StoredMessage[E]]: + ) -> Iterator[StoredMessage[E]]: for row in self._storage.get_messages_by_ids( conn=self._conn, message_ids=message_ids ): @@ -52,7 +52,7 @@ def get_messages_by_ids( global_position=global_position, ) - def read(self, stream: str) -> Iterable[StoredMessage[E]]: + def read(self, stream: str) -> Iterator[StoredMessage[E]]: for message_id, version, message, global_position in self._storage.read( self._conn, stream ): @@ -64,7 +64,7 @@ def read(self, stream: str) -> Iterable[StoredMessage[E]]: global_position=global_position, ) - def read_wildcard(self, stream_wildcard: str) -> Iterable[StoredMessage[E]]: + def read_wildcard(self, stream_wildcard: str) -> Iterator[StoredMessage[E]]: for ( message_id, stream, @@ -162,6 +162,6 @@ def reader( with self._get_connection() as conn: yield self._get_reader(conn) - def read(self, stream: str) -> Iterable[StoredMessage[E]]: + def read(self, stream: str) -> Iterator[StoredMessage[E]]: with self.reader() as reader: yield from reader.read(stream) diff --git a/depeche_db/_storage.py b/depeche_db/_storage.py index 9bc9a2d..ab1c9e1 100644 --- a/depeche_db/_storage.py +++ b/depeche_db/_storage.py @@ -1,5 +1,5 @@ import uuid as _uuid -from typing import Iterable +from typing import Iterator import sqlalchemy as _sa from sqlalchemy_utils import UUIDType as _UUIDType @@ -118,7 +118,7 @@ def get_max_version(self, conn: _sa.Connection, stream: str) -> MessagePosition: def get_message_ids( self, conn: _sa.Connection, stream: str - ) -> Iterable[_uuid.UUID]: + ) -> Iterator[_uuid.UUID]: return conn.execute( _sa.select(self.message_table.c.message_id) .select_from(self.message_table) @@ -128,7 +128,7 @@ def get_message_ids( def read( self, conn: _sa.Connection, stream: str - ) -> Iterable[tuple[_uuid.UUID, int, dict, int]]: + ) -> Iterator[tuple[_uuid.UUID, int, dict, int]]: return conn.execute( # type: ignore _sa.select( self.message_table.c.message_id, @@ -143,7 +143,7 @@ def read( def read_multiple( self, conn: _sa.Connection, streams: list[str] - ) -> Iterable[tuple[_uuid.UUID, str, int, dict, int]]: + ) -> Iterator[tuple[_uuid.UUID, str, int, dict, int]]: return conn.execute( # type: ignore _sa.select( self.message_table.c.message_id, @@ -159,7 +159,7 @@ def read_multiple( def read_wildcard( self, conn: _sa.Connection, stream_wildcard: str - ) -> Iterable[tuple[_uuid.UUID, str, int, dict, int]]: + ) -> Iterator[tuple[_uuid.UUID, str, int, dict, int]]: return conn.execute( # type: ignore _sa.select( self.message_table.c.message_id, @@ -188,7 +188,7 @@ def get_message_by_id( def get_messages_by_ids( self, conn: _sa.Connection, message_ids: list[_uuid.UUID] - ) -> Iterable[tuple[_uuid.UUID, str, int, dict, int]]: + ) -> Iterator[tuple[_uuid.UUID, str, int, dict, int]]: return conn.execute( # type: ignore _sa.select( self.message_table.c.message_id, diff --git a/depeche_db/_subscription.py b/depeche_db/_subscription.py index 28f0f04..8508432 100644 --- a/depeche_db/_subscription.py +++ b/depeche_db/_subscription.py @@ -1,6 +1,6 @@ import contextlib as _contextlib import dataclasses as _dc -from typing import Callable, Generic, Iterator, TypeVar +from typing import Callable, Generic, Iterator, Type, TypeVar from ._interfaces import ( LockProvider, @@ -109,7 +109,7 @@ def exec(self, message: SubscriptionMessage): class SubscriptionHandler(Generic[E]): def __init__(self, subscription: Subscription[E]): self._subscription = subscription - self._handlers = {} + self._handlers: dict[Type[E], _Handler] = {} @property def notification_channel(self) -> str: diff --git a/depeche_db/tools/pg_notification_listener.py b/depeche_db/tools/pg_notification_listener.py index 88d7a2d..397374f 100644 --- a/depeche_db/tools/pg_notification_listener.py +++ b/depeche_db/tools/pg_notification_listener.py @@ -4,7 +4,7 @@ import queue import select import threading -from typing import Iterable +from typing import Iterator import psycopg2 @@ -34,7 +34,7 @@ def __init__( self._select_timeout = select_timeout self._queue_timeout = select_timeout / 2 - def messages(self) -> Iterable[PgNotification]: + def messages(self) -> Iterator[PgNotification]: while self._keep_running: try: yield self._queue.get(block=True, timeout=self._queue_timeout) diff --git a/depeche_db/tools/pydantic_message_serializer.py b/depeche_db/tools/pydantic_message_serializer.py index 7a48440..e06aa22 100644 --- a/depeche_db/tools/pydantic_message_serializer.py +++ b/depeche_db/tools/pydantic_message_serializer.py @@ -1,20 +1,21 @@ -from typing import Type, TypeVar +from typing import Any, TypeVar from .._interfaces import MessageSerializer -E = TypeVar("E") +T = TypeVar("T") -class PydanticMessageSerializer(MessageSerializer[E]): - # TODO fix typing when Union is given (see typing of pydantic TypeAdapter) - def __init__(self, message_type: Type[E]): - self.message_type = message_type +class PydanticMessageSerializer(MessageSerializer[T]): + # the real type would be: (self, message_type: Type[T]) + # but this is not supported by mypy (yet) + def __init__(self, message_type: Any) -> None: + self.message_type: type[T] = message_type - def serialize(self, message: E) -> dict: + def serialize(self, message: T) -> dict: # TODO pydantic v1 compatibility return message.model_dump(mode="json") # type: ignore - def deserialize(self, message: dict) -> E: + def deserialize(self, message: dict) -> T: import pydantic as _pydantic # TODO pydantic v1 compatibility diff --git a/examples/docs.py b/examples/docs.py index 7fd29c0..e60d0b3 100644 --- a/examples/docs.py +++ b/examples/docs.py @@ -33,7 +33,7 @@ class EventB(MyEvent): from depeche_db import MessageStore from depeche_db.tools import PydanticMessageSerializer -message_store = MessageStore( +message_store = MessageStore[EventA | EventB]( name="example_docs2", engine=db_engine, serializer=PydanticMessageSerializer(EventA | EventB), @@ -70,27 +70,29 @@ class EventB(MyEvent): class NumMessagePartitioner: - def get_partition(self, message: StoredMessage[EventA]) -> int: - return message.message.num % 3 + def get_partition(self, message: StoredMessage[EventA | EventB]) -> int: + if isinstance(message.message, EventA): + return message.message.num % 3 + return 0 link_stream = LinkStream( name="example_docs_aggregate_me2", store=message_store, ) -stream_projector = StreamProjector( +stream_projector = StreamProjector[EventA | EventB]( stream=link_stream, partitioner=NumMessagePartitioner(), stream_wildcards=["aggregate-me-%"], ) stream_projector.update_full() -result = next(link_stream.read(conn=db_engine.connect(), partition=0)) -print(result) +first_id = next(link_stream.read(conn=db_engine.connect(), partition=0)) +print(first_id) # 4680cbaf-977e-43a4-afcb-f88e92043e9c (this is the message ID of the first message in partition 0) with message_store.reader() as reader: - print(reader.get_message_by_id(result)) + print(reader.get_message_by_id(first_id)) # StoredMessage( # message_id=UUID("4680cbaf-977e-43a4-afcb-f88e92043e9c"), # stream="aggregate-me-0", diff --git a/tests/subscription/test_streams_subscriptions.py b/tests/subscription/test_streams_subscriptions.py index a72a958..679ab18 100644 --- a/tests/subscription/test_streams_subscriptions.py +++ b/tests/subscription/test_streams_subscriptions.py @@ -309,7 +309,7 @@ def test_subscription_handler(db_engine, stream, stream_projector): ) handler = SubscriptionHandler(subject) - seen = [] + seen: list[SubscriptionMessage[AccountRegisteredEvent] | AccountEvent] = [] @handler.register def handle_account_registered(event: SubscriptionMessage[AccountRegisteredEvent]):