Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert misc database code to async (#8087)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Aug 14, 2020
1 parent 7bdf982 commit 894dae7
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 64 deletions.
1 change: 1 addition & 0 deletions changelog.d/8087.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
14 changes: 5 additions & 9 deletions synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

from canonicaljson import json

from twisted.internet import defer

from synapse.metrics.background_process_metrics import run_as_background_process

from . import engines
Expand Down Expand Up @@ -308,9 +306,8 @@ def register_noop_background_update(self, update_name):
update_name (str): Name of update
"""

@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
async def noop_update(progress, batch_size):
await self._end_background_update(update_name)
return 1

self.register_background_update_handler(update_name, noop_update)
Expand Down Expand Up @@ -409,12 +406,11 @@ def create_index_sqlite(conn):
else:
runner = create_index_sqlite

@defer.inlineCallbacks
def updater(progress, batch_size):
async def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name)
await self.db_pool.runWithConnection(runner)
await self._end_background_update(update_name)
return 1

self.register_background_update_handler(update_name, updater)
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,10 +671,9 @@ def get_device_list_last_stream_id_for_remote(self, user_id: str):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch(
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
Expand Down
9 changes: 4 additions & 5 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,18 +86,17 @@ def __init__(self, database: DatabasePool, db_conn, hs):
self._rotate_delay = 3
self._rotate_count = 10000

@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
ret = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
return ret

def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
Expand Down
9 changes: 3 additions & 6 deletions synapse/storage/databases/main/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,10 @@ def _get_presence_for_user(self, user_id):
raise NotImplementedError()

@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
def get_presence_for_users(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
Expand Down
16 changes: 6 additions & 10 deletions synapse/storage/databases/main/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,15 @@ def have_push_rules_changed_txn(txn):
)

@cachedList(
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
def bulk_get_push_rules(self, user_ids):
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}

results = {user_id: [] for user_id in user_ids}

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
Expand All @@ -194,7 +191,7 @@ def bulk_get_push_rules(self, user_ids):
for row in rows:
results.setdefault(row["user_name"], []).append(row)

enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)

for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
Expand Down Expand Up @@ -260,15 +257,14 @@ def copy_push_rules_from_room_to_room_for_user(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules_enabled(self, user_ids):
async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}

results = {user_id: {} for user_id in user_ids}

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
Expand Down
9 changes: 3 additions & 6 deletions synapse/storage/databases/main/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,10 @@ def get_if_user_has_pusher(self, user_id):
raise NotImplementedError()

@cachedList(
cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_if_users_have_pushers(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,8 @@ def f(txn):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
inlineCallbacks=True,
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}

Expand Down Expand Up @@ -243,7 +242,7 @@ def f(txn):

return self.db_pool.cursor_to_dict(txn)

txn_results = yield self.db_pool.runInteraction(
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)

Expand Down
17 changes: 6 additions & 11 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set

from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
Expand Down Expand Up @@ -92,8 +90,7 @@ def __init__(self, database: DatabasePool, db_conn, hs):
lambda: self._known_servers_count,
)

@defer.inlineCallbacks
def _count_known_servers(self):
async def _count_known_servers(self):
"""
Count the servers that this server knows about.
Expand Down Expand Up @@ -121,7 +118,7 @@ def _transact(txn):
txn.execute(query)
return list(txn)[0][0]

count = yield self.db_pool.runInteraction("get_known_servers", _transact)
count = await self.db_pool.runInteraction("get_known_servers", _transact)

# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
Expand Down Expand Up @@ -589,23 +586,21 @@ def _get_joined_profile_from_event_id(self, event_id):
raise NotImplementedError()

@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
event_ids: The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
Expand Down
5 changes: 2 additions & 3 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,11 @@ def _get_state_group_for_event(self, event_id):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
Expand Down
13 changes: 5 additions & 8 deletions synapse/storage/databases/main/user_erasure_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,25 +38,23 @@ def is_user_erased(self, user_id):
desc="is_user_erased",
).addCallback(operator.truth)

@cachedList(
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
)
def are_users_erased(self, user_ids):
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
Args:
user_ids (iterable[str]): full user id to check
Returns:
Deferred[dict[str, bool]]:
dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))

rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
Expand All @@ -65,8 +63,7 @@ def are_users_erased(self, user_ids):
)
erased_users = {row["user_id"] for row in rows}

res = {u: u in erased_users for u in user_ids}
return res
return {u: u in erased_users for u in user_ids}


class UserErasureStore(UserErasureWorkerStore):
Expand Down

0 comments on commit 894dae7

Please sign in to comment.