Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Bundle aggregations outside of the serialization method #11612

Merged
merged 4 commits into from
Jan 7, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 125 additions & 3 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,30 @@
# limitations under the License.

import logging
from typing import List, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
cast,
)

import attr
from frozendict import frozendict

from synapse.api.constants import RelationTypes
from synapse.api.constants import EventTypes, RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import (
AggregationPaginationToken,
Expand All @@ -29,10 +45,24 @@
)
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._msc1849_enabled = hs.config.experimental.msc1849_enabled
self._msc3440_enabled = hs.config.experimental.msc3440_enabled

@cached(tree=True)
async def get_relations_for_event(
self,
Expand Down Expand Up @@ -515,6 +545,98 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)

async def _get_bundled_aggregation_for_event(
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
self, event: EventBase
) -> Optional[Dict[str, Any]]:
"""Generate bundled aggregations for an event.

Note that this does not use a cache, but depends on cached methods.

Args:
event: The event to calculate bundled aggregations for.

Returns:
The bundled aggregations for an event, if bundled aggregations are
enabled and the event can have bundled aggregations.
"""
# State events and redacted events do not get bundled aggregations.
if event.is_state() or event.internal_metadata.is_redacted():
return None

# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return None

event_id = event.event_id
room_id = event.room_id

# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
aggregations: Dict[str, Any] = {}

annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()

references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict()

edit = None
if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id)

if edit:
aggregations[RelationTypes.REPLACE] = edit

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
(
thread_count,
latest_thread_event,
) = await self.get_thread_summary(event_id, room_id)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": latest_thread_event,
"count": thread_count,
}

# Store the bundled aggregations in the event metadata for later use.
return aggregations

async def get_bundled_aggregations(
self, events: Iterable[EventBase]
) -> Dict[str, Dict[str, Any]]:
"""Generate bundled aggregations for events.

Args:
events: The iterable of events to calculate bundled aggregations for.

Returns:
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
# If bundled aggregations are disabled, nothing to do.
if not self._msc1849_enabled:
return {}

# TODO Parallelize.
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event)
if event_result is not None:
results[event.event_id] = event_result

return results


class RelationsStore(RelationsWorkerStore):
pass