Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
moser committed Oct 6, 2023
1 parent 22ed195 commit 4aa30c1
Show file tree
Hide file tree
Showing 9 changed files with 41 additions and 35 deletions.
5 changes: 4 additions & 1 deletion depeche_db/_executor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import collections as _collections
import signal as _signal
from typing import Callable

from ._interfaces import RunOnNotification
from .tools import PgNotificationListener
Expand All @@ -8,7 +9,9 @@
class Executor:
def __init__(self, db_dsn: str):
self._db_dsn = db_dsn
self.channel_register = _collections.defaultdict(list)
self.channel_register: dict[
str, list[Callable[[], None]]
] = _collections.defaultdict(list)

def register(self, handler: RunOnNotification):
self.channel_register[handler.notification_channel].append(handler.run)
Expand Down
6 changes: 3 additions & 3 deletions depeche_db/_link_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import contextlib as _contextlib
import datetime as _dt
import uuid as _uuid
from typing import Generic, Iterable, TypeVar
from typing import Generic, Iterator, TypeVar

import sqlalchemy as _sa
from psycopg2.errors import LockNotAvailable
Expand Down Expand Up @@ -112,7 +112,7 @@ def add(
)
)

def read(self, conn: _sa.Connection, partition: int) -> Iterable[_uuid.UUID]:
def read(self, conn: _sa.Connection, partition: int) -> Iterator[_uuid.UUID]:
for row in conn.execute(
_sa.select(self._table.c.message_id)
.where(self._table.c.partition == partition)
Expand All @@ -133,7 +133,7 @@ def get_partition_statistics(
self,
position_limits: dict[int, int] = None,
result_limit: int | None = None,
) -> Iterable[StreamPartitionStatistic]:
) -> Iterator[StreamPartitionStatistic]:
with self._connection() as conn:
position_limits = position_limits or {-1: -1}
tbl = self._table.alias()
Expand Down
10 changes: 5 additions & 5 deletions depeche_db/_message_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib as _contextlib
import uuid as _uuid
from typing import Generic, Iterable, Iterator, Sequence, TypeVar
from typing import Generic, Iterator, Sequence, TypeVar

import sqlalchemy as _sa

Expand Down Expand Up @@ -39,7 +39,7 @@ def get_message_by_id(self, message_id: _uuid.UUID) -> StoredMessage[E]:

def get_messages_by_ids(
self, message_ids: list[_uuid.UUID]
) -> Iterable[StoredMessage[E]]:
) -> Iterator[StoredMessage[E]]:
for row in self._storage.get_messages_by_ids(
conn=self._conn, message_ids=message_ids
):
Expand All @@ -52,7 +52,7 @@ def get_messages_by_ids(
global_position=global_position,
)

def read(self, stream: str) -> Iterable[StoredMessage[E]]:
def read(self, stream: str) -> Iterator[StoredMessage[E]]:
for message_id, version, message, global_position in self._storage.read(
self._conn, stream
):
Expand All @@ -64,7 +64,7 @@ def read(self, stream: str) -> Iterable[StoredMessage[E]]:
global_position=global_position,
)

def read_wildcard(self, stream_wildcard: str) -> Iterable[StoredMessage[E]]:
def read_wildcard(self, stream_wildcard: str) -> Iterator[StoredMessage[E]]:
for (
message_id,
stream,
Expand Down Expand Up @@ -162,6 +162,6 @@ def reader(
with self._get_connection() as conn:
yield self._get_reader(conn)

def read(self, stream: str) -> Iterable[StoredMessage[E]]:
def read(self, stream: str) -> Iterator[StoredMessage[E]]:
with self.reader() as reader:
yield from reader.read(stream)
12 changes: 6 additions & 6 deletions depeche_db/_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import uuid as _uuid
from typing import Iterable
from typing import Iterator

import sqlalchemy as _sa
from sqlalchemy_utils import UUIDType as _UUIDType
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_max_version(self, conn: _sa.Connection, stream: str) -> MessagePosition:

def get_message_ids(
self, conn: _sa.Connection, stream: str
) -> Iterable[_uuid.UUID]:
) -> Iterator[_uuid.UUID]:
return conn.execute(
_sa.select(self.message_table.c.message_id)
.select_from(self.message_table)
Expand All @@ -128,7 +128,7 @@ def get_message_ids(

def read(
self, conn: _sa.Connection, stream: str
) -> Iterable[tuple[_uuid.UUID, int, dict, int]]:
) -> Iterator[tuple[_uuid.UUID, int, dict, int]]:
return conn.execute( # type: ignore
_sa.select(
self.message_table.c.message_id,
Expand All @@ -143,7 +143,7 @@ def read(

def read_multiple(
self, conn: _sa.Connection, streams: list[str]
) -> Iterable[tuple[_uuid.UUID, str, int, dict, int]]:
) -> Iterator[tuple[_uuid.UUID, str, int, dict, int]]:
return conn.execute( # type: ignore
_sa.select(
self.message_table.c.message_id,
Expand All @@ -159,7 +159,7 @@ def read_multiple(

def read_wildcard(
self, conn: _sa.Connection, stream_wildcard: str
) -> Iterable[tuple[_uuid.UUID, str, int, dict, int]]:
) -> Iterator[tuple[_uuid.UUID, str, int, dict, int]]:
return conn.execute( # type: ignore
_sa.select(
self.message_table.c.message_id,
Expand Down Expand Up @@ -188,7 +188,7 @@ def get_message_by_id(

def get_messages_by_ids(
self, conn: _sa.Connection, message_ids: list[_uuid.UUID]
) -> Iterable[tuple[_uuid.UUID, str, int, dict, int]]:
) -> Iterator[tuple[_uuid.UUID, str, int, dict, int]]:
return conn.execute( # type: ignore
_sa.select(
self.message_table.c.message_id,
Expand Down
4 changes: 2 additions & 2 deletions depeche_db/_subscription.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib as _contextlib
import dataclasses as _dc
from typing import Callable, Generic, Iterator, TypeVar
from typing import Callable, Generic, Iterator, Type, TypeVar

from ._interfaces import (
LockProvider,
Expand Down Expand Up @@ -109,7 +109,7 @@ def exec(self, message: SubscriptionMessage):
class SubscriptionHandler(Generic[E]):
def __init__(self, subscription: Subscription[E]):
self._subscription = subscription
self._handlers = {}
self._handlers: dict[Type[E], _Handler] = {}

@property
def notification_channel(self) -> str:
Expand Down
4 changes: 2 additions & 2 deletions depeche_db/tools/pg_notification_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import queue
import select
import threading
from typing import Iterable
from typing import Iterator

import psycopg2

Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(
self._select_timeout = select_timeout
self._queue_timeout = select_timeout / 2

def messages(self) -> Iterable[PgNotification]:
def messages(self) -> Iterator[PgNotification]:
while self._keep_running:
try:
yield self._queue.get(block=True, timeout=self._queue_timeout)
Expand Down
17 changes: 9 additions & 8 deletions depeche_db/tools/pydantic_message_serializer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from typing import Type, TypeVar
from typing import Any, TypeVar

from .._interfaces import MessageSerializer

E = TypeVar("E")
T = TypeVar("T")


class PydanticMessageSerializer(MessageSerializer[E]):
# TODO fix typing when Union is given (see typing of pydantic TypeAdapter)
def __init__(self, message_type: Type[E]):
self.message_type = message_type
class PydanticMessageSerializer(MessageSerializer[T]):
# the real type would be: (self, message_type: Type[T])
# but this is not supported by mypy (yet)
def __init__(self, message_type: Any) -> None:
self.message_type: type[T] = message_type

def serialize(self, message: E) -> dict:
def serialize(self, message: T) -> dict:
# TODO pydantic v1 compatibility
return message.model_dump(mode="json") # type: ignore

def deserialize(self, message: dict) -> E:
def deserialize(self, message: dict) -> T:
import pydantic as _pydantic

# TODO pydantic v1 compatibility
Expand Down
16 changes: 9 additions & 7 deletions examples/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class EventB(MyEvent):
from depeche_db import MessageStore
from depeche_db.tools import PydanticMessageSerializer

message_store = MessageStore(
message_store = MessageStore[EventA | EventB](
name="example_docs2",
engine=db_engine,
serializer=PydanticMessageSerializer(EventA | EventB),
Expand Down Expand Up @@ -70,27 +70,29 @@ class EventB(MyEvent):


class NumMessagePartitioner:
def get_partition(self, message: StoredMessage[EventA]) -> int:
return message.message.num % 3
def get_partition(self, message: StoredMessage[EventA | EventB]) -> int:
if isinstance(message.message, EventA):
return message.message.num % 3
return 0


link_stream = LinkStream(
name="example_docs_aggregate_me2",
store=message_store,
)
stream_projector = StreamProjector(
stream_projector = StreamProjector[EventA | EventB](
stream=link_stream,
partitioner=NumMessagePartitioner(),
stream_wildcards=["aggregate-me-%"],
)
stream_projector.update_full()

result = next(link_stream.read(conn=db_engine.connect(), partition=0))
print(result)
first_id = next(link_stream.read(conn=db_engine.connect(), partition=0))
print(first_id)
# 4680cbaf-977e-43a4-afcb-f88e92043e9c (this is the message ID of the first message in partition 0)

with message_store.reader() as reader:
print(reader.get_message_by_id(result))
print(reader.get_message_by_id(first_id))
# StoredMessage(
# message_id=UUID("4680cbaf-977e-43a4-afcb-f88e92043e9c"),
# stream="aggregate-me-0",
Expand Down
2 changes: 1 addition & 1 deletion tests/subscription/test_streams_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_subscription_handler(db_engine, stream, stream_projector):
)
handler = SubscriptionHandler(subject)

seen = []
seen: list[SubscriptionMessage[AccountRegisteredEvent] | AccountEvent] = []

@handler.register
def handle_account_registered(event: SubscriptionMessage[AccountRegisteredEvent]):
Expand Down

0 comments on commit 4aa30c1

Please sign in to comment.