Skip to content

Commit

Permalink
Auto-reconnect to Stomp when connection lost (#101)
Browse files Browse the repository at this point in the history
* Working reconnect

* Working tests for stomp template

* Skip stomp tests with flag

* Include activemq server in stomp tests (CI)
  • Loading branch information
callumforrester authored Mar 8, 2023
1 parent 71e54a2 commit 8e95744
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 28 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ jobs:
python: "3.8"
install: ".[dev]"

services:
activemq:
image: rmohr/activemq:5.14.5-alpine
ports:
- 61613:61613

runs-on: ${{ matrix.os }}
env:
# https://github.com/pytest-dev/pytest/issues/2042
Expand Down
92 changes: 80 additions & 12 deletions src/blueapi/messaging/stomptemplate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import itertools
import json
import logging
import time
import uuid
from typing import Any, Callable, Dict, Optional
from dataclasses import dataclass
from threading import Event
from typing import Any, Callable, Dict, List, Optional, Set

import stomp
from apischema import deserialize, serialize
from stomp.exception import ConnectFailedException
from stomp.utils import Frame

from blueapi.config import StompConfig
Expand Down Expand Up @@ -36,22 +40,51 @@ def temporary_queue(self, name: str) -> str:
default = queue


@dataclass
class StompReconnectPolicy:
"""
Details of how often stomp will try to reconnect if connection is unexpectedly lost
"""

initial_delay: float = 0.0
attempt_period: float = 10.0


@dataclass
class Subscription:
"""
Details of a subscription, the template needs its own representation to
defer subscriptions until after connection
"""

destination: str
callback: Callable[[Frame], None]


class StompMessagingTemplate(MessagingTemplate):
"""
MessagingTemplate that uses the stompp protocol, meant for use
with ActiveMQ.
"""

_conn: stomp.Connection
_reconnect_policy: StompReconnectPolicy
_sub_num: itertools.count
_listener: stomp.ConnectionListener
_subscriptions: Dict[str, Callable[[Frame], None]]
_subscriptions: Dict[str, Subscription]
_pending_subscriptions: Set[str]
_disconnected: Event

# Stateless implementation means attribute can be static
_destination_provider: DestinationProvider = StompDestinationProvider()

def __init__(self, conn: stomp.Connection) -> None:
def __init__(
self,
conn: stomp.Connection,
reconnect_policy: Optional[StompReconnectPolicy] = None,
) -> None:
self._conn = conn
self._reconnect_policy = reconnect_policy or StompReconnectPolicy()
self._sub_num = itertools.count()
self._listener = stomp.ConnectionListener()

Expand Down Expand Up @@ -95,8 +128,7 @@ def _send_str(
self._conn.send(headers=headers, body=message, destination=destination)

def subscribe(self, destination: str, callback: MessageListener) -> None:
LOGGER.info(f"New subscription to {destination}")

LOGGER.debug(f"New subscription to {destination}")
obj_type = determine_deserialization_type(callback, default=str)

def wrapper(frame: Frame) -> None:
Expand All @@ -109,22 +141,58 @@ def wrapper(frame: Frame) -> None:
callback(context, value)

sub_id = str(next(self._sub_num))

self._subscriptions[sub_id] = wrapper
self._conn.subscribe(destination=destination, id=sub_id, ack="auto")
self._subscriptions[sub_id] = Subscription(destination, wrapper)
# If we're connected, subscribe immediately, otherwise the subscription is
# deferred until connection.
self._ensure_subscribed([sub_id])

def connect(self) -> None:
self._conn.connect()
LOGGER.info("Connecting...")
self._conn.connect(wait=True)
self._listener.on_disconnected = self._on_disconnected
self._ensure_subscribed()

def _ensure_subscribed(self, sub_ids: Optional[List[str]] = None) -> None:
# We must defer subscription until after connection, because stomp literally
# sends a SUB to the broker. But it still nice to be able to call subscribe
# on template before it connects, then just run the subscribes after connection.
if self._conn.is_connected():
for sub_id in sub_ids or self._subscriptions.keys():
sub = self._subscriptions[sub_id]
LOGGER.info(f"Subscribing to {sub.destination}")
self._conn.subscribe(destination=sub.destination, id=sub_id, ack="auto")

def disconnect(self) -> None:
LOGGER.info("Disconnecting...")

# We need to synchronise the disconnect on an event because the stomp Connection
# object doesn't do it for us
disconnected = Event()
self._listener.on_disconnected = disconnected.set
self._conn.disconnect()
disconnected.wait()
self._listener.on_disconnected = None

@handle_all_exceptions
def _on_disconnected(self) -> None:
LOGGER.warn(
"Stomp connection lost, will attempt reconnection with "
f"policy {self._reconnect_policy}"
)
time.sleep(self._reconnect_policy.initial_delay)
while not self._conn.is_connected():
try:
self.connect()
except ConnectFailedException as ex:
LOGGER.error("Reconnect failed", ex)
time.sleep(self._reconnect_policy.attempt_period)

@handle_all_exceptions
def _on_message(self, frame: Frame) -> None:
LOGGER.info(f"Recieved {frame}")
sub_id = frame.headers.get("subscription")
callback = self._subscriptions.get(sub_id)
if callback is not None:
callback(frame)
sub = self._subscriptions.get(sub_id)
if sub is not None:
sub.callback(frame)
else:
LOGGER.warn(f"No subscription active for id: {sub_id}")
4 changes: 2 additions & 2 deletions src/blueapi/service/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def run(self) -> None:
}
)

self._template.connect()

self._template.subscribe("worker.run", self._on_run_request)
self._template.subscribe("worker.plans", self._get_plans)
self._template.subscribe("worker.devices", self._get_devices)

self._template.connect()

self._worker.run_forever()

def _publish_event_streams(
Expand Down
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Based on https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # noqa: E501

import pytest


def pytest_addoption(parser):
parser.addoption(
"--skip-stomp",
action="store_true",
default=False,
help="skip stomp tests (e.g. because a server is unavailable)",
)


def pytest_configure(config):
config.addinivalue_line("markers", "stomp: mark test as requiring stomp broker")


def pytest_collection_modifyitems(config, items):
if config.getoption("--skip-stomp"):
skip_stomp = pytest.mark.skip(reason="skipping stomp tests at user request")
for item in items:
if "stomp" in item.keywords:
item.add_marker(skip_stomp)
129 changes: 115 additions & 14 deletions tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,121 @@
import json
from typing import Tuple
from unittest.mock import MagicMock
import itertools
from concurrent.futures import Future
from dataclasses import dataclass
from typing import Any, Iterable, Type

from blueapi.messaging import MessagingTemplate, StompMessagingTemplate
import pytest

from blueapi.config import StompConfig
from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate

def test_send() -> None:
template, conn = _mock_template()
template.send("test", "test_message")
conn.send.assert_called_once_with(
body=json.dumps("test_message"),
destination="test",
headers={},
_TIMEOUT: float = 10.0
_COUNT = itertools.count()


@pytest.fixture
def disconnected_template() -> MessagingTemplate:
return StompMessagingTemplate.autoconfigured(StompConfig())


@pytest.fixture
def template(disconnected_template: MessagingTemplate) -> Iterable[MessagingTemplate]:
disconnected_template.connect()
yield disconnected_template
disconnected_template.disconnect()


@pytest.fixture
def test_queue(template: MessagingTemplate) -> str:
return template.destinations.queue(f"test-{next(_COUNT)}")


@pytest.mark.stomp
def test_send(template: MessagingTemplate, test_queue: str) -> None:
f: Future = Future()

def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.subscribe(test_queue, callback)
template.send(test_queue, "test_message")
assert f.result(timeout=_TIMEOUT)


@pytest.mark.stomp
def test_send_on_reply(template: MessagingTemplate, test_queue: str) -> None:
acknowledge(template, test_queue)

f: Future = Future()

def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.send(test_queue, "test_message", callback)
assert f.result(timeout=_TIMEOUT)


@pytest.mark.stomp
def test_send_and_recieve(template: MessagingTemplate, test_queue: str) -> None:
acknowledge(template, test_queue)
reply = template.send_and_recieve(test_queue, "test", str).result(timeout=_TIMEOUT)
assert reply == "ack"


@dataclass
class Foo:
a: int
b: str


@pytest.mark.stomp
@pytest.mark.parametrize(
"message,message_type",
[("test", str), (1, int), (Foo(1, "test"), Foo)],
)
def test_deserialization(
template: MessagingTemplate, test_queue: str, message: Any, message_type: Type
) -> None:
def server(ctx: MessageContext, message: message_type) -> None: # type: ignore
reply_queue = ctx.reply_destination
if reply_queue is None:
raise RuntimeError("reply queue is None")
template.send(reply_queue, message)

template.subscribe(test_queue, server)
reply = template.send_and_recieve(test_queue, message, message_type).result(
timeout=_TIMEOUT
)
assert reply == message


@pytest.mark.stomp
def test_subscribe_before_connect(
disconnected_template: MessagingTemplate, test_queue: str
) -> None:
acknowledge(disconnected_template, test_queue)
disconnected_template.connect()
reply = disconnected_template.send_and_recieve(test_queue, "test", str).result(
timeout=_TIMEOUT
)
assert reply == "ack"


@pytest.mark.stomp
def test_reconnect(template: MessagingTemplate, test_queue: str) -> None:
acknowledge(template, test_queue)
reply = template.send_and_recieve(test_queue, "test", str).result(timeout=_TIMEOUT)
assert reply == "ack"
template.disconnect()
template.connect()
reply = template.send_and_recieve(test_queue, "test", str).result(timeout=_TIMEOUT)
assert reply == "ack"


def acknowledge(template: MessagingTemplate, test_queue: str) -> None:
def server(ctx: MessageContext, message: str) -> None:
reply_queue = ctx.reply_destination
if reply_queue is None:
raise RuntimeError("reply queue is None")
template.send(reply_queue, "ack")

def _mock_template() -> Tuple[MessagingTemplate, MagicMock]:
conn = MagicMock()
return StompMessagingTemplate(conn), conn
template.subscribe(test_queue, server)

0 comments on commit 8e95744

Please sign in to comment.