Skip to content

Commit

Permalink
Better interface
Browse files Browse the repository at this point in the history
  • Loading branch information
moser committed Oct 25, 2023
1 parent 1fda392 commit 142e614
Show file tree
Hide file tree
Showing 32 changed files with 546 additions and 133 deletions.
1 change: 0 additions & 1 deletion CHANGELOG.md

This file was deleted.

15 changes: 15 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
gendocs: docs/generated/output README.md

gendocs-auto:
while true; do \
make gendocs; \
inotifywait -e modify **/*.py; \
done

docs/generated/output: docs/generated/getting_started.py docs/generated/_docgen.py
cd docs/generated && poetry run python _docgen.py getting_started.py

README.md: docs/generated/README.md examples/readme.py docs/generated/_genreadme.py
cd docs/generated && poetry run python _genreadme.py

.PHONY: gendocs gendocs-auto
63 changes: 62 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,68 @@ poetry add depeche-db
## Example

```python
# coming soon
import pydantic, sqlalchemy, uuid, datetime as dt

from depeche_db import (
MessageStore,
StoredMessage,
MessageHandler,
SubscriptionMessage,
)
from depeche_db.tools import PydanticMessageSerializer

DB_DSN = "postgresql://depeche:depeche@localhost:4888/depeche_demo"
db_engine = sqlalchemy.create_engine(DB_DSN)


class MyMessage(pydantic.BaseModel):
content: int
message_id: uuid.UUID = pydantic.Field(default_factory=uuid.uuid4)
sent_at: dt.datetime = pydantic.Field(default_factory=dt.datetime.utcnow)

def get_message_id(self) -> uuid.UUID:
return self.message_id

def get_message_time(self) -> dt.datetime:
return self.sent_at


message_store = MessageStore[MyMessage](
name="example_store",
engine=db_engine,
serializer=PydanticMessageSerializer(MyMessage),
)
message_store.write(stream="aggregate-me-1", message=MyMessage(content=2))
print(list(message_store.read(stream="aggregate-me-1")))
# [StoredMessage(message_id=UUID('...'), stream='aggregate-me-1', version=1, message=MyMessage(content=2, message_id=UUID('...'), sent_at=datetime.datetime(...)), global_position=1)]


class ContentMessagePartitioner:
def get_partition(self, message: StoredMessage[MyMessage]) -> int:
return message.message.content % 10


class MyHandlers(MessageHandler[MyMessage]):
@MessageHandler.register
def handle_message(self, message: SubscriptionMessage[MyMessage]):
print(message)


aggregated_stream = message_store.aggregated_stream(
name="aggregated",
partitioner=ContentMessagePartitioner(),
stream_wildcards=["aggregate-me-%"],
)
subscription = aggregated_stream.subscription(
name="example_subscription",
handlers=MyHandlers(),
)

aggregated_stream.projector.run()
subscription.runner.run()
# MyHandlers.handle_message prints:
# SubscriptionMessage(partition=2, position=0, stored_message=StoredMessage(...))

```


Expand Down
2 changes: 2 additions & 0 deletions depeche_db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from ._aggregated_stream import AggregatedStream, StreamProjector # noqa: F401
from ._executor import Executor # noqa: F401
from ._factories import AggregatedStreamFactory, SubscriptionFactory # noqa: F401
from ._interfaces import ( # noqa: F401
CallMiddleware,
ErrorAction,
HandlerDescriptor,
LockProvider,
MessageHandlerRegisterProtocol,
MessagePartitioner,
MessagePosition,
MessageProtocol,
Expand Down
27 changes: 15 additions & 12 deletions depeche_db/_aggregated_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlalchemy_utils import UUIDType as _UUIDType

from ._compat import SAConnection
from ._factories import SubscriptionFactory
from ._interfaces import (
AggregatedStreamMessage,
MessagePartitioner,
Expand Down Expand Up @@ -39,9 +40,11 @@ def __init__(
Attributes:
name (str): Stream name
projector (StreamProjector): Stream projector
subscription (SubscriptionFactory): Factory to create subscriptions
"""
assert name.isidentifier(), "name must be a valid identifier"
self.name = name
self.subscription = SubscriptionFactory(self)
self._store = store
self._metadata = _sa.MetaData()
self._table = _sa.Table(
Expand Down Expand Up @@ -133,24 +136,24 @@ def add(
)
)

def read(
self, conn: SAConnection, partition: int
) -> Iterator[AggregatedStreamMessage]:
def read(self, partition: int) -> Iterator[AggregatedStreamMessage]:
"""
Read all messages from a partition of the aggregated stream.
Args:
conn: Database connection
partition: Partition number
"""
for row in conn.execute(
_sa.select(self._table.c.message_id, self._table.c.position)
.where(self._table.c.partition == partition)
.order_by(self._table.c.position)
):
yield AggregatedStreamMessage(
message_id=row.message_id, position=row.position, partition=partition
)
with self._connection() as conn:
for row in conn.execute(
_sa.select(self._table.c.message_id, self._table.c.position)
.where(self._table.c.partition == partition)
.order_by(self._table.c.position)
):
yield AggregatedStreamMessage(
message_id=row.message_id,
position=row.position,
partition=partition,
)

def read_slice(
self, partition: int, start: int, count: int
Expand Down
96 changes: 96 additions & 0 deletions depeche_db/_factories.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import (
TYPE_CHECKING,
Generic,
List,
Optional,
TypeVar,
)

from ._interfaces import (
CallMiddleware,
LockProvider,
MessageHandlerRegisterProtocol,
MessagePartitioner,
MessageProtocol,
SubscriptionErrorHandler,
SubscriptionStateProvider,
)

if TYPE_CHECKING:
from ._aggregated_stream import AggregatedStream
from ._message_store import MessageStore
from ._subscription import Subscription

E = TypeVar("E", bound=MessageProtocol)


class AggregatedStreamFactory(Generic[E]):
def __init__(self, store: "MessageStore[E]"):
self._store = store

def __call__(
self,
name: str,
partitioner: "MessagePartitioner[E]",
stream_wildcards: List[str],
) -> "AggregatedStream[E]":
"""
Create an aggregated stream.
Args:
name: The name of the stream
partitioner: A partitioner for the stream
stream_wildcards: A list of stream wildcards to be aggregated
"""
from ._aggregated_stream import AggregatedStream

return AggregatedStream(
store=self._store,
name=name,
partitioner=partitioner,
stream_wildcards=stream_wildcards,
)


class SubscriptionFactory(Generic[E]):
def __init__(self, stream: "AggregatedStream[E]"):
self._stream = stream

def __call__(
self,
name: str,
handlers: MessageHandlerRegisterProtocol[E] = None,
call_middleware: Optional[CallMiddleware] = None,
error_handler: Optional[SubscriptionErrorHandler] = None,
state_provider: Optional[SubscriptionStateProvider] = None,
lock_provider: Optional[LockProvider] = None,
) -> "Subscription[E]":
"""
Create a subscription.
Args:
name: The name of the subscription
handlers: Handlers to be called when a message is received
call_middleware: A middleware to be called before the handlers
error_handler: A handler for errors raised by the handlers
state_provider: A provider for the subscription state
lock_provider: A provider for the subscription locks
"""
from ._message_handler import MessageHandlerRegister
from ._subscription import Subscription, SubscriptionMessageHandler

if handlers is None:
# allow constructing a subscription without handlers
handlers = MessageHandlerRegister()

return Subscription(
name=name,
stream=self._stream,
message_handler=SubscriptionMessageHandler(
handler_register=handlers,
call_middleware=call_middleware,
error_handler=error_handler,
),
state_provider=state_provider,
lock_provider=lock_provider,
)
13 changes: 9 additions & 4 deletions depeche_db/_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
TypeVar,
Union,
no_type_check,
runtime_checkable,
)


class MessageProtocol:
@runtime_checkable
class MessageProtocol(Protocol):
"""
Message protocol is a base class for all messages that are used in the system.
"""
Expand Down Expand Up @@ -240,9 +242,8 @@ class RunOnNotification(Protocol):
can be registered with a [Executor][depeche_db.Executor] object.
Implemented by:
- [SubscriptionRunner][depeche_db.SubscriptionRunner]
- [StreamProjector][depeche_db.StreamProjector]
- [SubscriptionRunner][depeche_db.SubscriptionRunner]
- [StreamProjector][depeche_db.StreamProjector]
"""

@property
Expand Down Expand Up @@ -341,6 +342,10 @@ def adapt_message_type(
class MessageHandlerRegisterProtocol(Protocol, Generic[E]):
"""
Message handler register protocol is used by runners to find handlers for messages.
Implemented by:
- [MessageHandlerRegister][depeche_db.MessageHandlerRegister]
- [MessageHandler][depeche_db.MessageHandler]
"""

def get_all_handlers(self) -> Iterator[HandlerDescriptor[E]]:
Expand Down
54 changes: 54 additions & 0 deletions depeche_db/_message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,47 @@


class MessageHandlerRegister(Generic[E]):
"""
Message handler register is a registry of message handlers.
Typical usage:
handlers = MessageHandlerRegister()
@handlers.register
def handle_message(message: MyMessage):
...
@handlers.register
def handle_other_message(message: StoredMessage[MyOtherMessage]):
...
Implements: [MessageHandlerRegisterProtocol][depeche_db.MessageHandlerRegisterProtocol]
"""

def __init__(
self,
):
self._handlers: Dict[Type[E], HandlerDescriptor] = {}

def register(self, handler: H) -> H:
"""
Registers a handler for a given message type.
The handler must have at least one parameter. The first parameter must
be of a message type. `E` being your message type, the parameter can be
of type `E`, `SubscriptionMessage[E]` or `StoredMessage[E]`. When a
handler is called, the message will be passed in the requested type.
Multiple handlers can be registered for non-overlapping types of
messages. Overlaps will cause a `ValueError`.
Args:
handler: A handler function.
Returns:
The unaltered handler function.
"""
signature = _inspect.signature(handler)
if len(signature.parameters) < 1:
raise ValueError("Handler must have at least one parameter")
Expand Down Expand Up @@ -79,6 +114,25 @@ def get_all_handlers(self) -> _typing.Iterator[HandlerDescriptor[E]]:


class MessageHandler(Generic[E]):
"""
Message handler is a base class for message handlers.
This is basically a class-based version of the `MessageHandlerRegister`.
Typical usage (equivalent to the example in `MessageHandlerRegister`):
class MyMessageHandler(MessageHandler):
@MessageHandler.register
def handle_message(self, message: MyMessage):
...
@MessageHandler.register
def handle_other_message(self, message: StoredMessage[MyOtherMessage]):
...
Implements: [MessageHandlerRegisterProtocol][depeche_db.MessageHandlerRegisterProtocol]
"""

_register: MessageHandlerRegister[E]

def __init__(self):
Expand Down
5 changes: 5 additions & 0 deletions depeche_db/_message_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ._compat import SAConnection
from ._exceptions import MessageNotFound
from ._factories import AggregatedStreamFactory
from ._interfaces import (
MessagePosition,
MessageProtocol,
Expand Down Expand Up @@ -131,10 +132,14 @@ def __init__(
the database objects that are created.
engine (Engine): A SQLAlchemy engine.
serializer (MessageSerializer): A serializer for the messages.
Attributes:
aggregated_stream (AggregatedStreamFactory): A factory for aggregated streams.
"""
self.engine = engine
self._storage = Storage(name=name, engine=engine)
self._serializer = serializer
self.aggregated_stream = AggregatedStreamFactory(store=self)

def _get_connection(self) -> SAConnection:
return self.engine.connect()
Expand Down
Loading

0 comments on commit 142e614

Please sign in to comment.