Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix up BatchingQueue #10078

Merged
merged 8 commits into from
May 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/10078.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix up `BatchingQueue` implementation.
70 changes: 48 additions & 22 deletions synapse/util/batching_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@
TypeVar,
)

from prometheus_client import Gauge

from twisted.internet import defer

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock

Expand All @@ -38,6 +39,24 @@
V = TypeVar("V")
R = TypeVar("R")

# Module-level Prometheus gauges shared by every BatchingQueue instance.
# Each queue distinguishes itself via the "name" label, which is why the
# queue name must be unique (see the BatchingQueue docstring).
number_queued = Gauge(
    "synapse_util_batching_queue_number_queued",
    "The number of items waiting in the queue across all keys",
    labelnames=("name",),
)

# NOTE: the variable is `number_in_flight` but the exported metric name ends
# in `_number_pending`; the mismatch appears intentional here — confirm
# before renaming either side.
number_in_flight = Gauge(
    "synapse_util_batching_queue_number_pending",
    "The number of items across all keys either being processed or waiting in a queue",
    labelnames=("name",),
)

number_of_keys = Gauge(
    "synapse_util_batching_queue_number_of_keys",
    "The number of distinct keys that have items queued",
    labelnames=("name",),
)


class BatchingQueue(Generic[V, R]):
"""A queue that batches up work, calling the provided processing function
Expand All @@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]):
called, and will keep being called until the queue has been drained (for the
given key).

If the processing function raises an exception then the exception is proxied
through to the callers waiting on that batch of work.

Note that the return value of `add_to_queue` will be the return value of the
processing function that processed the given item. This means that the
returned value will likely include data for other items that were in the
batch.

Args:
name: A name for the queue, used for logging contexts and metrics.
This must be unique, otherwise the metrics will be wrong.
clock: The clock to use to schedule work.
process_batch_callback: The callback to be run to process a batch of
work.
"""

def __init__(
Expand All @@ -73,19 +102,15 @@ def __init__(
# The function to call with batches of values.
self._process_batch_callback = process_batch_callback

LaterGauge(
"synapse_util_batching_queue_number_queued",
"The number of items waiting in the queue across all keys",
labels=("name",),
caller=lambda: sum(len(v) for v in self._next_values.values()),
number_queued.labels(self._name).set_function(
lambda: sum(len(q) for q in self._next_values.values())
)

LaterGauge(
"synapse_util_batching_queue_number_of_keys",
"The number of distinct keys that have items queued",
labels=("name",),
caller=lambda: len(self._next_values),
)
number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))

self._number_in_flight_metric = number_in_flight.labels(
self._name
) # type: Gauge

async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
"""Adds the value to the queue with the given key, returning the result
Expand All @@ -107,17 +132,18 @@ async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
if key not in self._processing_keys:
run_as_background_process(self._name, self._process_queue, key)

return await make_deferred_yieldable(d)
with self._number_in_flight_metric.track_inprogress():
return await make_deferred_yieldable(d)

async def _process_queue(self, key: Hashable) -> None:
"""A background task to repeatedly pull things off the queue for the
given key and call the `self._process_batch_callback` with the values.
"""

try:
if key in self._processing_keys:
return
if key in self._processing_keys:
return

try:
self._processing_keys.add(key)

while True:
Expand All @@ -137,16 +163,16 @@ async def _process_queue(self, key: Hashable) -> None:
values = [value for value, _ in next_values]
results = await self._process_batch_callback(values)

for _, deferred in next_values:
with PreserveLoggingContext():
with PreserveLoggingContext():
for _, deferred in next_values:
deferred.callback(results)

except Exception as e:
for _, deferred in next_values:
if deferred.called:
continue
with PreserveLoggingContext():
for _, deferred in next_values:
if deferred.called:
continue

with PreserveLoggingContext():
deferred.errback(e)

finally:
Expand Down
78 changes: 76 additions & 2 deletions tests/util/test_batching_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable
from synapse.util.batching_queue import BatchingQueue
from synapse.util.batching_queue import (
BatchingQueue,
number_in_flight,
number_of_keys,
number_queued,
)

from tests.server import get_clock
from tests.unittest import TestCase
Expand All @@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase):
def setUp(self):
    self.clock, hs_clock = get_clock()

    # We ensure that we remove any existing metrics for "test_queue".
    # The gauges are module-level singletons, so label state leaks between
    # test runs; `remove` raises KeyError when the labelset was never
    # created, which is fine to ignore on the first run.
    try:
        number_queued.remove("test_queue")
        number_of_keys.remove("test_queue")
        number_in_flight.remove("test_queue")
    except KeyError:
        pass

    # Each entry is a (values, deferred) pair recorded by `_process_queue`.
    self._pending_calls = []
    self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue)

Expand All @@ -32,6 +45,41 @@ async def _process_queue(self, values):
self._pending_calls.append((values, d))
return await make_deferred_yieldable(d)

def _assert_metrics(self, queued, keys, in_flight):
    """Assert that the batching-queue metrics report the expected values.

    Args:
        queued: expected value of the `number_queued` gauge.
        keys: expected value of the `number_of_keys` gauge.
        in_flight: expected value of the `number_in_flight` gauge.
    """

    expected_labels = {"name": self.queue._name}

    # Bug fix: the original asserted `number_queued.collect()[0].samples[0].labels`
    # for all three gauges (copy-paste error), so the label checks for
    # `number_of_keys` and `number_in_flight` never actually ran against
    # those gauges. Check each gauge's own samples instead.
    for gauge, expected, metric_name in (
        (number_queued, queued, "number_queued"),
        (number_of_keys, keys, "number_of_keys"),
        (number_in_flight, in_flight, "number_in_flight"),
    ):
        metric_families = gauge.collect()
        self.assertEqual(len(metric_families), 1)

        samples = metric_families[0].samples
        self.assertEqual(len(samples), 1)
        self.assertEqual(samples[0].labels, expected_labels)
        self.assertEqual(samples[0].value, expected, metric_name)

def test_simple(self):
"""Tests the basic case of calling `add_to_queue` once and having
`_process_queue` return.
Expand All @@ -41,6 +89,8 @@ def test_simple(self):

queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))

self._assert_metrics(queued=1, keys=1, in_flight=1)

# The queue should wait a reactor tick before calling the processing
# function.
self.assertFalse(self._pending_calls)
Expand All @@ -52,12 +102,15 @@ def test_simple(self):
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo"])
self.assertFalse(queue_d.called)
self._assert_metrics(queued=0, keys=0, in_flight=1)

# Return value of the `_process_queue` should be propagated back.
self._pending_calls.pop()[1].callback("bar")

self.assertEqual(self.successResultOf(queue_d), "bar")

self._assert_metrics(queued=0, keys=0, in_flight=0)

def test_batching(self):
"""Test that multiple calls at the same time get batched up into one
call to `_process_queue`.
Expand All @@ -68,19 +121,23 @@ def test_batching(self):
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))

self._assert_metrics(queued=2, keys=1, in_flight=2)

self.clock.pump([0])

# We should see only *one* call to `_process_queue`
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
self.assertFalse(queue_d1.called)
self.assertFalse(queue_d2.called)
self._assert_metrics(queued=0, keys=0, in_flight=2)

# Return value of the `_process_queue` should be propagated back to both.
self._pending_calls.pop()[1].callback("bar")

self.assertEqual(self.successResultOf(queue_d1), "bar")
self.assertEqual(self.successResultOf(queue_d2), "bar")
self._assert_metrics(queued=0, keys=0, in_flight=0)

def test_queuing(self):
"""Test that we queue up requests while a `_process_queue` is being
Expand All @@ -92,32 +149,45 @@ def test_queuing(self):
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
self.clock.pump([0])

self.assertEqual(len(self._pending_calls), 1)

# We queue up work after the process function has been called, testing
# that they get correctly queued up.
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3"))

# We should see only *one* call to `_process_queue`
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo1"])
self.assertFalse(queue_d1.called)
self.assertFalse(queue_d2.called)
self.assertFalse(queue_d3.called)
self._assert_metrics(queued=2, keys=1, in_flight=3)

# Return value of the `_process_queue` should be propagated back to the
# first.
self._pending_calls.pop()[1].callback("bar1")

self.assertEqual(self.successResultOf(queue_d1), "bar1")
self.assertFalse(queue_d2.called)
self.assertFalse(queue_d3.called)
self._assert_metrics(queued=2, keys=1, in_flight=2)

# We should now see a second call to `_process_queue`
self.clock.pump([0])
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo2"])
self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"])
self.assertFalse(queue_d2.called)
self.assertFalse(queue_d3.called)
self._assert_metrics(queued=0, keys=0, in_flight=2)

# Return value of the `_process_queue` should be propagated back to the
# second.
self._pending_calls.pop()[1].callback("bar2")

self.assertEqual(self.successResultOf(queue_d2), "bar2")
self.assertEqual(self.successResultOf(queue_d3), "bar2")
self._assert_metrics(queued=0, keys=0, in_flight=0)

def test_different_keys(self):
"""Test that calls to different keys get processed in parallel."""
Expand All @@ -140,6 +210,7 @@ def test_different_keys(self):
self.assertFalse(queue_d1.called)
self.assertFalse(queue_d2.called)
self.assertFalse(queue_d3.called)
self._assert_metrics(queued=1, keys=1, in_flight=3)

# Return value of the `_process_queue` should be propagated back to the
# first.
Expand All @@ -148,6 +219,7 @@ def test_different_keys(self):
self.assertEqual(self.successResultOf(queue_d1), "bar1")
self.assertFalse(queue_d2.called)
self.assertFalse(queue_d3.called)
self._assert_metrics(queued=1, keys=1, in_flight=2)

# Return value of the `_process_queue` should be propagated back to the
# second.
Expand All @@ -161,9 +233,11 @@ def test_different_keys(self):
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo3"])
self.assertFalse(queue_d3.called)
self._assert_metrics(queued=0, keys=0, in_flight=1)

# Return value of the `_process_queue` should be propagated back to the
# third deferred.
self._pending_calls.pop()[1].callback("bar4")

self.assertEqual(self.successResultOf(queue_d3), "bar4")
self._assert_metrics(queued=0, keys=0, in_flight=0)