From de8a4bc235b4f7dc6a650003196e7bb514d5adac Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Jan 2022 15:06:23 -0500 Subject: [PATCH 1/7] Do not mutate result dictionaries. --- synapse/handlers/room.py | 45 ++++++++++++++++++++++++------------- synapse/rest/admin/rooms.py | 39 +++++++++++++++++++------------- synapse/rest/client/room.py | 39 +++++++++++++++++++------------- 3 files changed, 77 insertions(+), 46 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f963078e596c..2d7ad5dd0c41 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -30,6 +30,7 @@ Tuple, ) +import attr from typing_extensions import TypedDict from synapse.api.constants import ( @@ -90,6 +91,17 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventContext: + events_before: List[EventBase] + event: EventBase + events_after: List[EventBase] + state: List[EventBase] + aggregations: Dict[str, JsonDict] + start: str + end: str + + class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1119,7 +1131,7 @@ async def get_event_context( limit: int, event_filter: Optional[Filter], use_admin_priviledge: bool = False, - ) -> Optional[JsonDict]: + ) -> Optional[EventContext]: """Retrieves events, pagination tokens and state around a given event in a room. @@ -1179,11 +1191,11 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. - results["event"] = filtered[0] + event = filtered[0] # Fetch the aggregations. aggregations = await self.store.get_bundled_aggregations( - [results["event"]], user.to_string() + [event], user.to_string() ) aggregations.update( await self.store.get_bundled_aggregations( @@ -1195,7 +1207,6 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: results["events_after"], user.to_string() ) ) - results["aggregations"] = aggregations if results["events_after"]: last_event_id = results["events_after"][-1].event_id @@ -1207,7 +1218,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: ev.sender for ev in itertools.chain( results["events_before"], - (results["event"],), + (event,), results["events_after"], ) ) @@ -1226,21 +1237,23 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: if event_filter: state_events = await event_filter.filter(state_events) - results["state"] = await filter_evts(state_events) - # We use a dummy token here as we only care about the room portion of # the token, which we replace. token = StreamToken.START - results["start"] = await token.copy_and_replace( - "room_key", results["start"] - ).to_string(self.store) - - results["end"] = await token.copy_and_replace( - "room_key", results["end"] - ).to_string(self.store) - - return results + return EventContext( + events_before=results["events_before"], + event=event, + events_after=results["events_after"], + state=await filter_evts(state_events), + aggregations=aggregations, + start=await token.copy_and_replace("room_key", results["start"]).to_string( + self.store + ), + end=await token.copy_and_replace("room_key", results["end"]).to_string( + self.store + ), + ) class TimestampLookupHandler: diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index efe25fe7ebf7..5b706efbcff0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -729,7 +729,7 @@ async def on_GET( else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, @@ -738,25 +738,34 @@ async def on_GET( use_admin_priviledge=True, ) - if not results: + if not event_context: raise SynapseError( HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND ) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return HTTPStatus.OK, results diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90bb9142a098..90355e44b25e 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -706,27 +706,36 @@ async def on_GET( else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, limit, event_filter ) - if not results: + if not event_context: raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return 200, results From 42e1817eef73f8968dbd2318f666143dd4dfba16 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Jan 2022 15:09:27 -0500 Subject: [PATCH 2/7] Fetch all bundled aggregations at once. --- synapse/handlers/room.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 2d7ad5dd0c41..c8d01fb3f801 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1195,17 +1195,10 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: # Fetch the aggregations. aggregations = await self.store.get_bundled_aggregations( - [event], user.to_string() - ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_before"], user.to_string() - ) - ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_after"], user.to_string() - ) + itertools.chain( + results["events_before"], (event,), results["events_after"] + ), + user.to_string(), ) if results["events_after"]: From 12e8832a2d952be7bfa4f87c6afeb2e8cf97d23c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Jan 2022 14:25:15 -0500 Subject: [PATCH 3/7] Create a BundledAggregations attrs. --- synapse/events/utils.py | 57 +++++++++++++++------ synapse/handlers/room.py | 3 +- synapse/handlers/sync.py | 3 +- synapse/rest/client/sync.py | 3 +- synapse/storage/databases/main/relations.py | 56 ++++++++++++-------- 5 files changed, 80 insertions(+), 42 deletions(-) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 918adeecf8cd..243696b35724 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,7 +14,17 @@ # limitations under the License. import collections.abc import re -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, +) from frozendict import frozendict @@ -26,6 +36,10 @@ from . import EventBase +if TYPE_CHECKING: + from synapse.storage.databases.main.relations import BundledAggregations + + # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (? JsonDict: """Serializes a single event. @@ -415,7 +429,7 @@ def _inject_bundled_aggregations( self, event: EventBase, time_now: int, - aggregations: JsonDict, + aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. @@ -427,13 +441,18 @@ def _inject_bundled_aggregations( serialized_event: The serialized event which may be modified. """ - # Make a copy in-case the object is cached. - aggregations = aggregations.copy() + serialized_aggregations = {} + + if aggregations.annotations: + serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations + + if aggregations.references: + serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references - if RelationTypes.REPLACE in aggregations: + if aggregations.replace: # If there is an edit replace the content, preserving existing # relations. - edit = aggregations[RelationTypes.REPLACE] + edit = aggregations.replace # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -451,24 +470,28 @@ def _inject_bundled_aggregations( else: serialized_event["content"].pop("m.relates_to", None) - aggregations[RelationTypes.REPLACE] = { + serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, } # If this event is the start of a thread, include a summary of the replies. - if RelationTypes.THREAD in aggregations: - # Serialize the latest thread event. - latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] - - # Don't bundle aggregations as this could recurse forever. - aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=None - ) + if aggregations.thread: + serialized_aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": self.serialize_event( + aggregations.thread.latest_event, time_now, bundle_aggregations=None + ), + "count": aggregations.thread.count, + "current_user_participated": aggregations.thread.current_user_participated, + } # Include the bundled aggregations in the event. - serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + if serialized_aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + serialized_aggregations + ) def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index c8d01fb3f801..7972aa8289ab 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -61,6 +61,7 @@ from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -97,7 +98,7 @@ class EventContext: event: EventBase events_after: List[EventBase] state: List[EventBase] - aggregations: Dict[str, JsonDict] + aggregations: Dict[str, BundledAggregations] start: str end: str diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7e2a892b63ae..c72ed7c2907a 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -100,7 +101,7 @@ class TimelineBatch: limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. - bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None + bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index d20ae1421e19..f9615da52583 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -48,6 +48,7 @@ from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace +from synapse.storage.databases.main.relations import BundledAggregations from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -526,7 +527,7 @@ async def encode_room( def serialize( events: Iterable[EventBase], - aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + aggregations: Optional[Dict[str, BundledAggregations]] = None, ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2cb5d06c1352..2f1ee0bd0364 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,17 +13,7 @@ # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast import attr from frozendict import frozendict @@ -43,6 +33,7 @@ PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -51,6 +42,27 @@ logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + latest_event: EventBase + count: int + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -585,7 +597,7 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool: async def _get_bundled_aggregation_for_event( self, event: EventBase, user_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[BundledAggregations]: """Generate bundled aggregations for an event. Note that this does not use a cache, but depends on cached methods. @@ -616,24 +628,24 @@ async def _get_bundled_aggregation_for_event( # The bundled aggregations to include, a mapping of relation type to a # type-specific value. Some types include the direct return type here # while others need more processing during serialization. - aggregations: Dict[str, Any] = {} + aggregations = BundledAggregations() annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations.annotations = annotations.to_dict() references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() + aggregations.references = references.to_dict() edit = None if event.type == EventTypes.Message: edit = await self.get_applicable_edit(event_id, room_id) if edit: - aggregations[RelationTypes.REPLACE] = edit + aggregations.replace = edit # If this event is the start of a thread, include a summary of the replies. if self._msc3440_enabled: @@ -644,11 +656,11 @@ async def _get_bundled_aggregation_for_event( event_id, room_id, user_id ) if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - "latest_event": latest_thread_event, - "count": thread_count, - "current_user_participated": participated, - } + aggregations.thread = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + current_user_participated=participated, + ) # Store the bundled aggregations in the event metadata for later use. return aggregations @@ -657,7 +669,7 @@ async def get_bundled_aggregations( self, events: Iterable[EventBase], user_id: str, - ) -> Dict[str, Dict[str, Any]]: + ) -> Dict[str, BundledAggregations]: """Generate bundled aggregations for events. Args: From 9f3d783d3a495af61647919d5ad62509102fb533 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 24 Jan 2022 14:44:08 -0500 Subject: [PATCH 4/7] Newsfragment --- changelog.d/11815.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/11815.misc diff --git a/changelog.d/11815.misc b/changelog.d/11815.misc new file mode 100644 index 000000000000..83aa6d6eb046 --- /dev/null +++ b/changelog.d/11815.misc @@ -0,0 +1 @@ +Improve type safety of bundled aggregations code. From 177783111825eb78c76c797fe033dd0ee1b2311d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 25 Jan 2022 13:23:08 -0500 Subject: [PATCH 5/7] Convert retun type of get_events_around to attrs. --- synapse/handlers/room.py | 32 ++++++++--------- synapse/handlers/search.py | 45 ++++++++++++------------ synapse/push/mailer.py | 2 +- synapse/storage/databases/main/stream.py | 22 ++++++++---- 4 files changed, 54 insertions(+), 47 deletions(-) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 7972aa8289ab..1420d6772955 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1180,15 +1180,15 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) + events_before = results.events_before + events_after = results.events_after if event_filter: - results["events_before"] = await event_filter.filter( - results["events_before"] - ) - results["events_after"] = await event_filter.filter(results["events_after"]) + events_before = await event_filter.filter(events_before) + events_after = await event_filter.filter(events_after) - results["events_before"] = await filter_evts(results["events_before"]) - results["events_after"] = await filter_evts(results["events_after"]) + events_before = await filter_evts(events_before) + events_after = await filter_evts(events_after) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. @@ -1196,14 +1196,12 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: # Fetch the aggregations. aggregations = await self.store.get_bundled_aggregations( - itertools.chain( - results["events_before"], (event,), results["events_after"] - ), + itertools.chain(events_before, (event,), events_after), user.to_string(), ) - if results["events_after"]: - last_event_id = results["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event_id @@ -1211,9 +1209,9 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: state_filter = StateFilter.from_lazy_load_member_list( ev.sender for ev in itertools.chain( - results["events_before"], + events_before, (event,), - results["events_after"], + events_after, ) ) else: @@ -1236,15 +1234,15 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]: token = StreamToken.START return EventContext( - events_before=results["events_before"], + events_before=events_before, event=event, - events_after=results["events_after"], + events_after=events_after, state=await filter_evts(state_events), aggregations=aggregations, - start=await token.copy_and_replace("room_key", results["start"]).to_string( + start=await token.copy_and_replace("room_key", results.start).to_string( self.store ), - end=await token.copy_and_replace("room_key", results["end"]).to_string( + end=await token.copy_and_replace("room_key", results.end).to_string( self.store ), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0b153a682261..02bb5ae72f51 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -361,36 +361,37 @@ async def search( logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), - len(res["events_after"]), + len(res.events_before), + len(res.events_after), ) - res["events_before"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_before"] + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before ) - res["events_after"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_after"] + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after ) - res["start"] = await now_token.copy_and_replace( - "room_key", res["start"] - ).to_string(self.store) - - res["end"] = await now_token.copy_and_replace( - "room_key", res["end"] - ).to_string(self.store) + context = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace( + "room_key", res.end + ).to_string(self.store), + } if include_profile: senders = { ev.sender - for ev in itertools.chain( - res["events_before"], [event], res["events_after"] - ) + for ev in itertools.chain(events_before, [event], events_after) } - if res["events_after"]: - last_event_id = res["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event.event_id @@ -402,7 +403,7 @@ async def search( last_event_id, state_filter ) - res["profile_info"] = { + context["profile_info"] = { s.state_key: { "displayname": s.content.get("displayname", None), "avatar_url": s.content.get("avatar_url", None), @@ -411,7 +412,7 @@ async def search( if s.type == EventTypes.Member and s.state_key in senders } - contexts[event.event_id] = res + contexts[event.event_id] = context else: contexts = {} @@ -421,10 +422,10 @@ async def search( for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now + context["events_before"], time_now # type: ignore[arg-type] ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now + context["events_after"], time_now # type: ignore[arg-type] ) state_results = {} diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index dadfc574134c..3df8452eecf7 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -455,7 +455,7 @@ async def _get_notif_vars( } the_events = await filter_events_for_client( - self.storage, user_id, results["events_before"] + self.storage, user_id, results.events_before ) the_events.append(notif_event) diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 319464b1fa83..a898f847e7d5 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -81,6 +81,14 @@ class _EventDictReturn: stream_ordering: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventsAround: + events_before: List[EventBase] + events_after: List[EventBase] + start: RoomStreamToken + end: RoomStreamToken + + def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], @@ -846,7 +854,7 @@ async def get_events_around( before_limit: int, after_limit: int, event_filter: Optional[Filter] = None, - ) -> dict: + ) -> _EventsAround: """Retrieve events and pagination tokens around a given event in a room. """ @@ -869,12 +877,12 @@ async def get_events_around( list(results["after"]["event_ids"]), get_prev_content=True ) - return { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } + return _EventsAround( + events_before=events_before, + events_after=events_after, + start=results["before"]["token"], + end=results["after"]["token"], + ) def _get_events_around_txn( self, From 067d4bcf69b5e214e9945bf5ed7060674cd46e1b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 25 Jan 2022 14:35:20 -0500 Subject: [PATCH 6/7] Add missing assert in tests. --- tests/rest/client/test_relations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c9b220e73d1a..96ae7790bb15 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -577,7 +577,7 @@ def assert_bundle(event_json: JsonDict) -> None: self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - self._find_event_in_chunk(room_timeline["events"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included From d976be7000fc6e17c4963dafeaf3b88ad2644989 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 25 Jan 2022 14:42:36 -0500 Subject: [PATCH 7/7] Do not include empty bundled aggregations. --- synapse/storage/databases/main/relations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2f1ee0bd0364..a9a5dd5f03b6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -62,6 +62,9 @@ class BundledAggregations: replace: Optional[EventBase] = None thread: Optional[_ThreadAggregation] = None + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + class RelationsWorkerStore(SQLBaseStore): def __init__( @@ -688,7 +691,7 @@ async def get_bundled_aggregations( results = {} for event in events: event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result is not None: + if event_result: results[event.event_id] = event_result return results