Skip to content

Commit

Permalink
Use Correlation ID (#103)
Browse files Browse the repository at this point in the history
* Attach correlation ID to message headers of all messages associated with a particular task

* Fix correlation ID header name

* Fix tests and test correlation ID in stomp template
  • Loading branch information
callumforrester authored Mar 9, 2023
1 parent 79b0477 commit 8e754f3
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 33 deletions.
18 changes: 10 additions & 8 deletions src/blueapi/core/event.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from abc import ABC, abstractmethod
from typing import Callable, Dict, Generic, TypeVar
from typing import Callable, Dict, Generic, Optional, TypeVar

#: Event type
E = TypeVar("E")
Expand All @@ -15,12 +15,12 @@ class EventStream(ABC, Generic[E, S]):
"""

@abstractmethod
def subscribe(self, __callback: Callable[[E], None]) -> S:
def subscribe(self, __callback: Callable[[E, Optional[str]], None]) -> S:
"""
Subscribe to new events with a callback
Args:
__callback (Callable[[E], None]): What to do with each event
__callback: What to do with each event, optionally takes a correlation id
Returns:
S: A unique token representing the subscription
Expand All @@ -47,14 +47,14 @@ class EventPublisher(EventStream[E, int]):
Simple Observable that can be fed values to publish
"""

_subscriptions: Dict[int, Callable[[E], None]]
_subscriptions: Dict[int, Callable[[E, Optional[str]], None]]
_count: itertools.count

def __init__(self) -> None:
self._subscriptions = {}
self._count = itertools.count()

def subscribe(self, callback: Callable[[E], None]) -> int:
def subscribe(self, callback: Callable[[E, Optional[str]], None]) -> int:
sub_id = next(self._count)
self._subscriptions[sub_id] = callback
return sub_id
Expand All @@ -65,13 +65,15 @@ def unsubscribe(self, subscription: int) -> None:
def unsubscribe_all(self) -> None:
self._subscriptions = {}

def publish(self, event: E) -> None:
def publish(self, event: E, correlation_id: Optional[str] = None) -> None:
"""
Publish a new event to all subscribers
Args:
event (E): The event to publish
event: The event to publish
correlation_id: An optional ID that may be used to correlate this
event with other events
"""

for callback in self._subscriptions.values():
callback(event)
callback(event, correlation_id)
4 changes: 3 additions & 1 deletion src/blueapi/messaging/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def send_and_recieve(
destination: str,
obj: Any,
reply_type: Type = str,
correlation_id: Optional[str] = None,
) -> Future:
"""
Send a message expecting a single reply.
Expand All @@ -107,7 +108,7 @@ def callback(_: MessageContext, reply: Any) -> None:
future.set_result(reply)

callback.__annotations__["reply"] = reply_type
self.send(destination, obj, callback)
self.send(destination, obj, callback, correlation_id)
return future

@abstractmethod
Expand All @@ -116,6 +117,7 @@ def send(
__destination: str,
__obj: Any,
__on_reply: Optional[MessageListener] = None,
__correlation_id: Optional[str] = None,
) -> None:
"""
Send a message to a destination
Expand Down
1 change: 1 addition & 0 deletions src/blueapi/messaging/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ class MessageContext:

destination: str
reply_destination: Optional[str]
correlation_id: Optional[str]
19 changes: 14 additions & 5 deletions src/blueapi/messaging/stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

LOGGER = logging.getLogger(__name__)

CORRELATION_ID_HEADER = "correlation-id"


class StompDestinationProvider(DestinationProvider):
"""
Expand Down Expand Up @@ -104,19 +106,22 @@ def destinations(self) -> DestinationProvider:
return self._destination_provider

def send(
self, destination: str, obj: Any, on_reply: Optional[MessageListener] = None
self,
destination: str,
obj: Any,
on_reply: Optional[MessageListener] = None,
correlation_id: Optional[str] = None,
) -> None:
self._send_str(
destination,
json.dumps(serialize(obj)),
on_reply,
destination, json.dumps(serialize(obj)), on_reply, correlation_id
)

def _send_str(
self,
destination: str,
message: str,
on_reply: Optional[MessageListener] = None,
correlation_id: Optional[str] = None,
) -> None:
LOGGER.info(f"SENDING {message} to {destination}")

Expand All @@ -125,6 +130,8 @@ def _send_str(
reply_queue_name = self.destinations.temporary_queue(str(uuid.uuid1()))
headers = {**headers, "reply-to": reply_queue_name}
self.subscribe(reply_queue_name, on_reply)
if correlation_id:
headers = {**headers, CORRELATION_ID_HEADER: correlation_id}
self._conn.send(headers=headers, body=message, destination=destination)

def subscribe(self, destination: str, callback: MessageListener) -> None:
Expand All @@ -136,7 +143,9 @@ def wrapper(frame: Frame) -> None:
value = deserialize(obj_type, as_dict)

context = MessageContext(
frame.headers["destination"], frame.headers.get("reply-to")
frame.headers["destination"],
frame.headers.get("reply-to"),
frame.headers.get(CORRELATION_ID_HEADER),
)
callback(context, value)

Expand Down
12 changes: 8 additions & 4 deletions src/blueapi/service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,19 @@ def _publish_event_streams(
self._publish_event_stream(stream, destination)

def _publish_event_stream(self, stream: EventStream, destination: str) -> None:
stream.subscribe(lambda event: self._template.send(destination, event))
stream.subscribe(
lambda event, correlation_id: self._template.send(
destination, event, None, correlation_id
)
)

def _on_run_request(self, message_context: MessageContext, task: RunPlan) -> None:
name = str(uuid.uuid1())
self._worker.submit_task(name, task)
correlation_id = message_context.correlation_id or str(uuid.uuid1())
self._worker.submit_task(correlation_id, task)

reply_queue = message_context.reply_destination
if reply_queue is not None:
response = TaskResponse(name)
response = TaskResponse(correlation_id)
self._template.send(reply_queue, response)

def _get_plans(self, message_context: MessageContext, message: PlanRequest) -> None:
Expand Down
15 changes: 12 additions & 3 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,22 @@ def _report_status(
self._current.is_complete,
self._current.is_error or bool(errors),
)
correlation_id = self._current.name
else:
task_status = None
correlation_id = None

event = WorkerEvent(self._state, task_status, errors, warnings)
self._worker_events.publish(event)
self._worker_events.publish(event, correlation_id)

def _on_document(self, name: str, document: Mapping[str, Any]) -> None:
self._data_events.publish(DataEvent(name, document))
if self._current is not None:
correlation_id = self._current.name
self._data_events.publish(DataEvent(name, document), correlation_id)
else:
raise KeyError(
"Trying to emit a document despite the fact that the RunEngine is idle"
)

def _waiting_hook(self, statuses: Optional[Iterable[Status]]) -> None:
if statuses is not None:
Expand Down Expand Up @@ -218,5 +226,6 @@ def _publish_status_snapshot(self) -> None:
ProgressEvent(
self._current.name,
statuses=self._status_snapshot,
)
),
self._current.name,
)
4 changes: 2 additions & 2 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def __init_subclass__(cls, **kwargs):
)

@abstractmethod
def do_task(self, ctx: BlueskyContext) -> None:
def do_task(self, __ctx: BlueskyContext) -> None:
"""
Perform the task using the context
Args:
ctx (TaskContext): Context for the task
ctx: Context for the task, holds plans/device/etc
"""


Expand Down
25 changes: 17 additions & 8 deletions tests/core/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def publisher() -> EventPublisher[MyEvent]:
def test_publishes_event(publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
f: Future = Future()
publisher.subscribe(f.set_result)
publisher.subscribe(lambda r, _: f.set_result(r))
publisher.publish(event)
assert f.result(timeout=_TIMEOUT) == event

Expand All @@ -32,8 +32,8 @@ def test_multi_subscriber(publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
f1: Future = Future()
f2: Future = Future()
publisher.subscribe(f1.set_result)
publisher.subscribe(f2.set_result)
publisher.subscribe(lambda r, _: f1.set_result(r))
publisher.subscribe(lambda r, _: f2.set_result(r))
publisher.publish(event)
assert f1.result(timeout=_TIMEOUT) == f2.result(timeout=_TIMEOUT) == event

Expand All @@ -43,11 +43,11 @@ def test_can_unsubscribe(publisher: EventPublisher[MyEvent]) -> None:
event_b = MyEvent("b")
event_c = MyEvent("c")
q: Queue = Queue()
sub = publisher.subscribe(q.put)
sub = publisher.subscribe(lambda r, _: q.put(r))
publisher.publish(event_a)
publisher.unsubscribe(sub)
publisher.publish(event_b)
publisher.subscribe(q.put)
publisher.subscribe(lambda r, _: q.put(r))
publisher.publish(event_c)
assert list(_drain(q)) == [event_a, event_c]

Expand All @@ -57,16 +57,25 @@ def test_can_unsubscribe_all(publisher: EventPublisher[MyEvent]) -> None:
event_b = MyEvent("b")
event_c = MyEvent("c")
q: Queue = Queue()
publisher.subscribe(q.put)
publisher.subscribe(q.put)
publisher.subscribe(lambda r, _: q.put(r))
publisher.subscribe(lambda r, _: q.put(r))
publisher.publish(event_a)
publisher.unsubscribe_all()
publisher.publish(event_b)
publisher.subscribe(q.put)
publisher.subscribe(lambda r, _: q.put(r))
publisher.publish(event_c)
assert list(_drain(q)) == [event_a, event_a, event_c]


def test_correlation_id(publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
correlation_id = "foobar"
f: Future = Future()
publisher.subscribe(lambda _, c: f.set_result(c))
publisher.publish(event, correlation_id)
assert f.result(timeout=_TIMEOUT) == correlation_id


def _drain(queue: Queue) -> Iterable:
while not queue.empty():
yield queue.get_nowait()
34 changes: 32 additions & 2 deletions tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import itertools
from concurrent.futures import Future
from dataclasses import dataclass
from queue import Queue
from typing import Any, Iterable, Type

import pytest
Expand Down Expand Up @@ -29,6 +30,11 @@ def test_queue(template: MessagingTemplate) -> str:
return template.destinations.queue(f"test-{next(_COUNT)}")


@pytest.fixture
def test_queue_2(template: MessagingTemplate) -> str:
return template.destinations.queue(f"test-{next(_COUNT)}")


@pytest.fixture
def test_topic(template: MessagingTemplate) -> str:
return template.destinations.topic(f"test-{next(_COUNT)}")
Expand Down Expand Up @@ -141,11 +147,35 @@ def test_reconnect(template: MessagingTemplate, test_queue: str) -> None:
assert reply == "ack"


def acknowledge(template: MessagingTemplate, test_queue: str) -> None:
@pytest.mark.stomp
def test_correlation_id(
template: MessagingTemplate, test_queue: str, test_queue_2: str
) -> None:
correlation_id = "foobar"
q: Queue = Queue()

def server(ctx: MessageContext, msg: str) -> None:
q.put(ctx)
template.send(test_queue_2, msg, None, ctx.correlation_id)

def client(ctx: MessageContext, msg: str) -> None:
q.put(ctx)

template.subscribe(test_queue, server)
template.subscribe(test_queue_2, client)
template.send(test_queue, "test", None, correlation_id)

ctx_req: MessageContext = q.get(timeout=_TIMEOUT)
assert ctx_req.correlation_id == correlation_id
ctx_ack: MessageContext = q.get(timeout=_TIMEOUT)
assert ctx_ack.correlation_id == correlation_id


def acknowledge(template: MessagingTemplate, destination: str) -> None:
def server(ctx: MessageContext, message: str) -> None:
reply_queue = ctx.reply_destination
if reply_queue is None:
raise RuntimeError("reply queue is None")
template.send(reply_queue, "ack")

template.subscribe(test_queue, server)
template.subscribe(destination, server)

0 comments on commit 8e754f3

Please sign in to comment.