diff --git a/changelog.d/11376.bugfix b/changelog.d/11376.bugfix new file mode 100644 index 000000000000..639e48b59b0d --- /dev/null +++ b/changelog.d/11376.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection, for real this time. Also fix a race condition introduced in the previous insufficient fix in 1.47.0. diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index c6bf316d5bf3..31587e857116 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -602,7 +602,7 @@ async def _get_events_from_cache_or_db( # already due to `_get_events_from_db`). fetching_deferred: ObservableDeferred[ Dict[str, _EventCacheEntry] - ] = ObservableDeferred(defer.Deferred()) + ] = ObservableDeferred(defer.Deferred(), consumeErrors=True) for event_id in missing_events_ids: self._current_event_fetches[event_id] = fetching_deferred @@ -736,35 +736,56 @@ async def get_stripped_room_state_from_event_context( for e in state_to_include.values() ] - def _do_fetch(self, conn: Connection) -> None: + async def _do_fetch(self) -> None: + """Services requests for events from the `_event_fetch_list` queue.""" + try: + await self.db_pool.runWithConnection(self._do_fetch_txn) + except BaseException as e: + with self._event_fetch_lock: + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + + if self._event_fetch_ongoing == 0 and self._event_fetch_list: + # We are the last remaining fetcher and we have just failed. + # Fail any outstanding event fetches, since no one else will process + # them. + failed_event_list = self._event_fetch_list + self._event_fetch_list = [] + else: + failed_event_list = [] + + for _, deferred in failed_event_list: + if not deferred.called: + with PreserveLoggingContext(): + deferred.errback(e) + + def _do_fetch_txn(self, conn: Connection) -> None: """Takes a database connection and waits for requests for events from the _event_fetch_list queue. """ - try: - i = 0 - while True: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] + i = 0 + while True: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if ( + not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING + or single_threaded + or i > EVENT_QUEUE_ITERATIONS + ): + self._event_fetch_ongoing -= 1 + event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 - if not event_list: - single_threaded = self.database_engine.single_threaded - if ( - not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING - or single_threaded - or i > EVENT_QUEUE_ITERATIONS - ): - break - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - self._fetch_event_list(conn, event_list) - finally: - self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + self._fetch_event_list(conn, event_list) def _fetch_event_list( self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]] @@ -994,9 +1015,7 @@ async def _enqueue_events(self, events: Iterable[str]) -> Dict[str, _EventRow]: should_start = False if should_start: - run_as_background_process( - "fetch_events", self.db_pool.runWithConnection, self._do_fetch - ) + run_as_background_process("fetch_events", self._do_fetch) logger.debug("Loading %d events: %s", len(events), events) with PreserveLoggingContext(): diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index a649e8c61872..ac1cbcd61917 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -12,11 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. import json +from contextlib import contextmanager +from typing import Generator +from twisted.enterprise.adbapi import ConnectionPool +from twisted.internet.defer import ensureDeferred +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import EventFormatVersions, RoomVersions from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room -from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.server import HomeServer +from synapse.storage.databases.main.events_worker import ( + EVENT_QUEUE_THREADS, + EventsWorkerStore, +) +from synapse.storage.types import Connection +from synapse.util import Clock from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -144,3 +157,108 @@ def test_dedupe(self): # We should have fetched the event from the DB self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + +class DatabaseOutageTestCase(unittest.HomeserverTestCase): + """Test event fetching during a database outage.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): + self.store: EventsWorkerStore = hs.get_datastore() + + self.room_id = f"!room:{hs.hostname}" + self.event_ids = [f"event{i}" for i in range(20)] + + self._populate_events() + + def _populate_events(self) -> None: + """Ensure that there are test events in the database.""" + self.get_success( + self.store.db_pool.simple_upsert( + "rooms", + {"room_id": self.room_id}, + {"room_version": RoomVersions.V4.identifier}, + ) + ) + + self.event_ids = [f"event{i}" for i in range(20)] + for idx, event_id in enumerate(self.event_ids): + self.get_success( + self.store.db_pool.simple_upsert( + "events", + {"event_id": event_id}, + { + "event_id": event_id, + "room_id": self.room_id, + "topological_ordering": idx, + "stream_ordering": idx, + "type": "test", + "processed": True, + "outlier": False, + }, + ) + ) + self.get_success( + self.store.db_pool.simple_upsert( + "event_json", + {"event_id": event_id}, + { + "room_id": self.room_id, + "json": json.dumps({"type": "test", "room_id": self.room_id}), + "internal_metadata": "{}", + "format_version": EventFormatVersions.V3, + }, + ) + ) + + @contextmanager + def _outage(self) -> Generator[None, None, None]: + """Simulate a database outage.""" + connection_pool = self.store.db_pool._db_pool + connection_pool.close() + connection_pool.start() + original_connection_factory = connection_pool.connectionFactory + + def connection_factory(_pool: ConnectionPool) -> Connection: + raise Exception("Could not connect to the database.") + + connection_pool.connectionFactory = connection_factory # type: ignore[assignment] + try: + yield + finally: + connection_pool.connectionFactory = original_connection_factory + + # If the in-memory SQLite database is being used, all the events are gone. + # Restore the test data. + self._populate_events() + + def test_failure(self) -> None: + """Test that event fetches do not get stuck during a database outage.""" + with self._outage(): + failure = self.get_failure( + self.store.get_event(self.event_ids[0]), Exception + ) + self.assertEqual(str(failure.value), "Could not connect to the database.") + + def test_recovery(self) -> None: + """Test that event fetchers recover after a database outage.""" + with self._outage(): + # Kick off a bunch of event fetches but do not pump the reactor + event_deferreds = [] + for event_id in self.event_ids: + event_deferreds.append(ensureDeferred(self.store.get_event(event_id))) + + # We should have maxed out on event fetcher threads + self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS) + + # All the event fetchers will fail + self.pump() + self.assertEqual(self.store._event_fetch_ongoing, 0) + + for event_deferred in event_deferreds: + failure = self.get_failure(event_deferred, Exception) + self.assertEqual( + str(failure.value), "Could not connect to the database." + ) + + # This next event fetch should succeed + self.get_success(self.store.get_event(self.event_ids[0]))