Skip to content

Commit

Permalink
redis: add support for broker priority
Browse files Browse the repository at this point in the history
Create different queues for priority steps and consuming them in that order.
  • Loading branch information
davidt99 committed Oct 29, 2023
1 parent 0c02bb0 commit cc81586
Show file tree
Hide file tree
Showing 7 changed files with 297 additions and 17 deletions.
43 changes: 43 additions & 0 deletions docs/source/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,49 @@ failover if the currently connected node fails.
.. _high availability cluster: https://www.rabbitmq.com/ha.html
.. _connection parameters: https://pika.readthedocs.io/en/0.12.0/modules/parameters.html

Broker Priority Queues
^^^^^^^^^^^^^^^^^^^^^^^
Dramatiq supports priority queues on RabbitMQ and Redis.
To configure the broker to work with priority queues, you should set the ``max_priority`` parameter of the broker.
To enqueue a message with a priority, you should set the ``broker_priority`` parameter of the |Message|'s options.


.. code-block:: python
import dramatiq
from dramatiq.brokers.rabbitmq import RabbitmqBroker
# Using max_priority parameter:
rabbitmq_broker = RabbitmqBroker(url="amqp://guest:guest@127.0.0.1:5672", max_priority=10)
# Define an actor with priority:
@dramatiq.actor(broker=rabbitmq_broker)
def operation(priority):
print(priority)
# Enqueue a message with priority (lower number means higher priority):
operation.send_with_options(args=(3,), options={"broker_priority": 3})
operation.send_with_options(args=(2,), options={"broker_priority": 2})
operation.send_with_options(args=(1,), options={"broker_priority": 1})
RabbitMQ
~~~~~~~~
Dramatiq supports RabbitMQ's `priority queues`_ feature.
To use it, you should set the ``max_priority`` parameter of the |RabbitmqBroker| to a value
between 0 and 255 (10 is the recommended values).

.. `priority queues`_: https://www.rabbitmq.com/priority.html
Redis
~~~~~
Dramatiq take similar approach as celery to implement priority queues on Redis.
To use it, you should set the ``max_priority`` parameter of the |RedisBroker| to a value up to 10.
The broker will created multiple queues for each priority level defined by ``priority_steps``, (default is 4 steps).
The consumer will consume messages from the highest priority queue first.
This method isn't as reliable as RabbitMQ's, for example, large number of messages ending up in the same queue.

Other brokers
^^^^^^^^^^^^^
Expand Down
10 changes: 10 additions & 0 deletions dramatiq/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,16 @@ def join(self, queue_name, *, timeout=None): # pragma: no cover
"""
raise NotImplementedError()

def get_queue_size(self, queue_name: str):
"""
Get the number of messages in a queue. This method is only meant to be used in unit and integration tests.
Parameters:
queue_name(str): The queue whose message counts to get.
Returns: The number of messages in the queue, including the delay queue
"""
raise NotImplementedError()


class Consumer:
"""Consumers iterate over messages on a queue.
Expand Down
10 changes: 10 additions & 0 deletions dramatiq/brokers/rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,16 @@ def join(self, queue_name, min_successes=10, idle_time=100, *, timeout=None):

self.connection.sleep(idle_time / 1000)

def get_queue_size(self, queue_name: str):
"""
Get the number of messages in a queue. This method is only meant to be used in unit and integration tests.
Parameters:
queue_name(str): The queue whose message counts to get.
Returns: The number of messages in the queue, including the delay queue
"""
return sum(self.get_queue_message_counts(queue_name)[:-1])


def URLRabbitmqBroker(url, *, middleware=None):
"""Alias for the RabbitMQ broker that takes a connection URL as a
Expand Down
119 changes: 103 additions & 16 deletions dramatiq/brokers/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import random
import time
import warnings
from bisect import bisect
from collections import defaultdict
from os import path
from threading import Lock
from uuid import uuid4
Expand All @@ -32,7 +34,7 @@
from ..message import Message

MAINTENANCE_SCALE = 1000000
MAINTENANCE_COMMAND_BLACKLIST = {"ack", "nack"}
MAINTENANCE_COMMAND_BLACKLIST = {"ack", "nack", "qsize"}

#: How many commands out of a million should trigger queue
#: maintenance.
Expand All @@ -50,6 +52,33 @@
#: the first time it's run, but it may be overwritten using this var.
DEFAULT_LUA_MAX_STACK = getenv_int("dramatiq_lua_max_stack")

#: The default priority steps. Each step will create a new queue
DEFAULT_PRIORITY_STEPS = [0, 3, 6, 9]


def _get_all_priority_queue_names(queue_name, priority_steps):
"""
Yields the queue names for a given queue name and a list of priority steps.
Parameters:
queue_name(str): The queue name
priority_steps(list[int]): The configured priority steps
Returns: The queue names for the given queue name and priority steps
"""
if dq_name(queue_name) == queue_name:
return
for step in priority_steps:
yield pri_name(queue_name, step)


def pri_name(queue_name, priority):
"""Returns the queue name for a given queue name and a priority. If the given
queue name already belongs to a priority queue, then it is returned
unchanged.
"""
if queue_name.endswith(".PR{}".format(priority)):
return queue_name
return "{}.PR{}".format(queue_name, priority)


class RedisBroker(Broker):
"""A broker than can be used with Redis.
Expand Down Expand Up @@ -81,6 +110,11 @@ class RedisBroker(Broker):
dead-lettered messages are kept in Redis for.
requeue_deadline(int): Deprecated. Does nothing.
requeue_interval(int): Deprecated. Does nothing.
max_priority(int): Configure queues with max priority to support message’s broker_priority option.
The queuing is done by having multiple queues for each named queue.
The queues are then consumed by in order of priority. The max value of max_priority is 10.
priority_steps(list[int]): The priority range that is collapsed into the queues (4 by default).
The number of steps can be configured by providing a list of numbers in sorted order
client(redis.StrictRedis): A redis client to use.
**parameters: Connection parameters are passed directly
to :class:`redis.Redis`.
Expand All @@ -97,6 +131,8 @@ def __init__(
requeue_deadline=None,
requeue_interval=None,
client=None,
max_priority=None,
priority_steps=None,
**parameters
):
super().__init__(middleware=middleware)
Expand All @@ -114,6 +150,14 @@ def __init__(
self.heartbeat_timeout = heartbeat_timeout
self.dead_message_ttl = dead_message_ttl
self.queues = set()
if max_priority:
if max_priority > 10:
raise ValueError("max priority is supported up to 10")
if not priority_steps:
self.priority_steps = DEFAULT_PRIORITY_STEPS[:bisect(DEFAULT_PRIORITY_STEPS, max_priority) - 1]
self.priority_steps = priority_steps or []
else:
self.priority_steps = []
# TODO: Replace usages of StrictRedis (redis-py 2.x) with Redis in Dramatiq 2.0.
self.client = client or redis.StrictRedis(**parameters)
self.scripts = {name: self.client.register_script(script) for name, script in _scripts.items()}
Expand Down Expand Up @@ -163,6 +207,9 @@ def enqueue(self, message, *, delay=None):
ValueError: If ``delay`` is longer than 7 days.
"""
queue_name = message.queue_name
if "broker_priority" in message.options and delay is None:
priority = message.options["broker_priority"]
queue_name = self.priority_queue_name(queue_name, priority)

# Each enqueued message must have a unique id in Redis so
# using the Message's id isn't safe because messages may be
Expand Down Expand Up @@ -202,7 +249,7 @@ def flush(self, queue_name):
Parameters:
queue_name(str): The queue to flush.
"""
for name in (queue_name, dq_name(queue_name)):
for name in (queue_name, dq_name(queue_name), *_get_all_priority_queue_names(queue_name, self.priority_steps)):
self.do_purge(name)

def flush_all(self):
Expand Down Expand Up @@ -231,13 +278,36 @@ def join(self, queue_name, *, interval=100, timeout=None):
if deadline and time.monotonic() >= deadline:
raise QueueJoinTimeout(queue_name)

size = self.do_qsize(queue_name)
size = self.get_queue_size(queue_name)

if size == 0:
return

time.sleep(interval / 1000)

def get_queue_size(self, queue_name):
"""
Get the number of messages in a queue. This method is only meant to be used in unit and integration tests.
Parameters:
queue_name(str): The queue whose message counts to get.
Returns: The number of messages in the queue, including the delay queue
"""
size = 0
if self.priority_steps:
for queue_name in _get_all_priority_queue_names(queue_name, self.priority_steps):
qsize = self.do_qsize(queue_name)
size += qsize
size += self.do_qsize(queue_name)
return size

def priority_queue_name(self, queue, priority):
if priority is None or dq_name(queue) == queue:
return queue

queue_number = self.priority_steps[bisect(self.priority_steps, priority) - 1]
return pri_name(queue, queue_number)

def _should_do_maintenance(self, command):
return int(
command not in MAINTENANCE_COMMAND_BLACKLIST and
Expand Down Expand Up @@ -310,7 +380,8 @@ def ack(self, message):
# The current queue might be different from message.queue_name
# if the message has been delayed so we want to ack on the
# current queue.
self.broker.do_ack(self.queue_name, message.options["redis_message_id"])
queue_name = self.broker.priority_queue_name(self.queue_name, message.options.get("broker_priority"))
self.broker.do_ack(queue_name, message.options["redis_message_id"])
except redis.ConnectionError as e:
raise ConnectionClosed(e) from None
finally:
Expand All @@ -319,19 +390,26 @@ def ack(self, message):
def nack(self, message):
try:
# Same deal as above.
self.broker.do_nack(self.queue_name, message.options["redis_message_id"])
queue_name = self.broker.priority_queue_name(self.queue_name, message.options.get("broker_priority"))
self.broker.do_nack(queue_name, message.options["redis_message_id"])
except redis.ConnectionError as e:
raise ConnectionClosed(e) from None
finally:
self.queued_message_ids.discard(message.message_id)

def requeue(self, messages):
message_ids = [message.options["redis_message_id"] for message in messages]
if not message_ids:
return

self.logger.debug("Re-enqueueing %r on queue %r.", message_ids, self.queue_name)
self.broker.do_requeue(self.queue_name, *message_ids)
messages_id_by_queue = defaultdict(list)
for message in messages:
priority = message.options.get("broker_priority")
if priority is None:
queue_name = self.queue_name
else:
queue_name = self.broker.priority_queue_name(self.queue_name, priority)
messages_id_by_queue[queue_name].append(message.options["redis_message_id"])

for queue_name, message_ids in messages_id_by_queue.items():
self.logger.debug("Re-enqueueing %r on queue %r.", message_ids, self.queue_name)
self.broker.do_requeue(queue_name, *message_ids)

def __next__(self):
try:
Expand Down Expand Up @@ -360,11 +438,16 @@ def __next__(self):
# prefetch up to that number of messages.
messages = []
if self.outstanding_message_count < self.prefetch:
self.message_cache = messages = self.broker.do_fetch(
self.queue_name,
self.prefetch - self.outstanding_message_count,
)

for queue_name in self.queue_names():
# Ideally, we would want to sort the messages by their priority,
# but that will require decoding them now
self.message_cache = messages = self.broker.do_fetch(
queue_name,
self.prefetch - self.outstanding_message_count,
)

if messages:
break
# Because we didn't get any messages, we should
# progressively long poll up to the idle timeout.
if not messages:
Expand All @@ -374,6 +457,10 @@ def __next__(self):
except redis.ConnectionError as e:
raise ConnectionClosed(e) from None

def queue_names(self):
yield from _get_all_priority_queue_names(self.queue_name, self.broker.priority_steps)
yield self.queue_name


_scripts = {}
_scripts_path = path.join(path.abspath(path.dirname(__file__)), "redis")
Expand Down
3 changes: 3 additions & 0 deletions dramatiq/brokers/redis/dispatch.lua
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ local queue_acks = acks .. "." .. queue_name
local queue_full_name = namespace .. ":" .. queue_name
local queue_messages = queue_full_name .. ".msgs"
local xqueue_full_name = namespace .. ":" .. queue_canonical_name .. ".XQ"
if string.sub(queue_canonical_name, -4, -2) == ".PR" then
xqueue_full_name = namespace .. ":" .. string.sub(queue_canonical_name, 1, -5) .. ".XQ"
end
local xqueue_messages = xqueue_full_name .. ".msgs"

-- Command-specific arguments.
Expand Down
10 changes: 10 additions & 0 deletions dramatiq/brokers/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,16 @@ def join(self, queue_name, *, fail_fast=False, timeout=None):
raise message._exception from None

return
def get_queue_size(self, queue_name):
"""Returns the number of messages in a queue.
Parameters:
queue_name(str): The queue to inspect.
Returns:
int: The number of messages in the queue.
"""
return self.queues[queue_name].qsize() + self.queues[dq_name(queue_name)].qsize()


class _StubConsumer(Consumer):
Expand Down
Loading

0 comments on commit cc81586

Please sign in to comment.