Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improvements to bundling aggregations #11815

Merged
merged 12 commits into from
Jan 26, 2022
Merged
1 change: 1 addition & 0 deletions changelog.d/11815.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type safety of bundled aggregations code.
57 changes: 40 additions & 17 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import re
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)

from frozendict import frozendict

Expand All @@ -26,6 +36,10 @@

from . import EventBase

if TYPE_CHECKING:
from synapse.storage.databases.main.relations import BundledAggregations


# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
Expand Down Expand Up @@ -376,7 +390,7 @@ def serialize_event(
event: Union[JsonDict, EventBase],
time_now: int,
*,
bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Expand Down Expand Up @@ -415,7 +429,7 @@ def _inject_bundled_aggregations(
self,
event: EventBase,
time_now: int,
aggregations: JsonDict,
aggregations: "BundledAggregations",
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
Expand All @@ -427,13 +441,18 @@ def _inject_bundled_aggregations(
serialized_event: The serialized event which may be modified.

"""
# Make a copy in-case the object is cached.
aggregations = aggregations.copy()
serialized_aggregations = {}

if aggregations.annotations:
serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations

if aggregations.references:
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references

if RelationTypes.REPLACE in aggregations:
if aggregations.replace:
# If there is an edit replace the content, preserving existing
# relations.
edit = aggregations[RelationTypes.REPLACE]
edit = aggregations.replace

# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
Expand All @@ -451,24 +470,28 @@ def _inject_bundled_aggregations(
else:
serialized_event["content"].pop("m.relates_to", None)

aggregations[RelationTypes.REPLACE] = {
serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
"sender": edit.sender,
}

# If this event is the start of a thread, include a summary of the replies.
if RelationTypes.THREAD in aggregations:
# Serialize the latest thread event.
latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]

# Don't bundle aggregations as this could recurse forever.
aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
latest_thread_event, time_now, bundle_aggregations=None
)
if aggregations.thread:
serialized_aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": self.serialize_event(
aggregations.thread.latest_event, time_now, bundle_aggregations=None
),
"count": aggregations.thread.count,
"current_user_participated": aggregations.thread.current_user_participated,
}

# Include the bundled aggregations in the event.
serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
if serialized_aggregations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(
serialized_aggregations
)

def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
Expand Down
77 changes: 41 additions & 36 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Tuple,
)

import attr
from typing_extensions import TypedDict

from synapse.api.constants import (
Expand Down Expand Up @@ -60,6 +61,7 @@
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
Expand Down Expand Up @@ -90,6 +92,17 @@
FIVE_MINUTES_IN_MS = 5 * 60 * 1000


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventContext:
events_before: List[EventBase]
event: EventBase
events_after: List[EventBase]
state: List[EventBase]
aggregations: Dict[str, BundledAggregations]
start: str
end: str


class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
Expand Down Expand Up @@ -1119,7 +1132,7 @@ async def get_event_context(
limit: int,
event_filter: Optional[Filter],
use_admin_priviledge: bool = False,
) -> Optional[JsonDict]:
) -> Optional[EventContext]:
"""Retrieves events, pagination tokens and state around a given event
in a room.

Expand Down Expand Up @@ -1167,48 +1180,38 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
events_before = results.events_before
events_after = results.events_after

if event_filter:
results["events_before"] = await event_filter.filter(
results["events_before"]
)
results["events_after"] = await event_filter.filter(results["events_after"])
events_before = await event_filter.filter(events_before)
events_after = await event_filter.filter(events_after)

results["events_before"] = await filter_evts(results["events_before"])
results["events_after"] = await filter_evts(results["events_after"])
events_before = await filter_evts(events_before)
events_after = await filter_evts(events_after)
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
results["event"] = filtered[0]
event = filtered[0]
Comment on lines -1182 to +1195
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're still doing some mutation of results on lines +1184 to +1190 above. Not sure if that's intentional? It seems to conflict with the title of this commit!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we are...I initially had converted that too and then stopped. I think it got messy, let me give it a try again though!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated this to not mutate those results, it got a bit messy though. Let me know if you'd prefer I back it out. See 1777831.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest I don't mind too much either way. (I was mostly checking that we hadn't missed a bit of work that we'd intended to do, judging by the commit message).

I like that the change makes it a bit easier to see what you get out of get_events_around (i.e. a thing with specific fields rather than an arbitrary dictionary), but I don't really have a strong opinion here. Happy to defer to your judgement!


# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
[results["event"]], user.to_string()
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_before"], user.to_string()
)
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_after"], user.to_string()
)
)
results["aggregations"] = aggregations

if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event_id

if event_filter and event_filter.lazy_load_members:
state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
for ev in itertools.chain(
results["events_before"],
(results["event"],),
results["events_after"],
events_before,
(event,),
events_after,
)
)
else:
Expand All @@ -1226,21 +1229,23 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if event_filter:
state_events = await event_filter.filter(state_events)

results["state"] = await filter_evts(state_events)

# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
token = StreamToken.START

results["start"] = await token.copy_and_replace(
"room_key", results["start"]
).to_string(self.store)

results["end"] = await token.copy_and_replace(
"room_key", results["end"]
).to_string(self.store)

return results
return EventContext(
events_before=events_before,
event=event,
events_after=events_after,
state=await filter_evts(state_events),
aggregations=aggregations,
start=await token.copy_and_replace("room_key", results.start).to_string(
self.store
),
end=await token.copy_and_replace("room_key", results.end).to_string(
self.store
),
)


class TimestampLookupHandler:
Expand Down
45 changes: 23 additions & 22 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,36 +361,37 @@ async def search(

logger.info(
"Context for search returned %d and %d events",
len(res["events_before"]),
len(res["events_after"]),
len(res.events_before),
len(res.events_after),
)

res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
events_before = await filter_events_for_client(
self.storage, user.to_string(), res.events_before
)

res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
events_after = await filter_events_for_client(
self.storage, user.to_string(), res.events_after
)

res["start"] = await now_token.copy_and_replace(
"room_key", res["start"]
).to_string(self.store)

res["end"] = await now_token.copy_and_replace(
"room_key", res["end"]
).to_string(self.store)
context = {
"events_before": events_before,
"events_after": events_after,
"start": await now_token.copy_and_replace(
"room_key", res.start
).to_string(self.store),
"end": await now_token.copy_and_replace(
"room_key", res.end
).to_string(self.store),
}

if include_profile:
senders = {
ev.sender
for ev in itertools.chain(
res["events_before"], [event], res["events_after"]
)
for ev in itertools.chain(events_before, [event], events_after)
}

if res["events_after"]:
last_event_id = res["events_after"][-1].event_id
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event.event_id

Expand All @@ -402,7 +403,7 @@ async def search(
last_event_id, state_filter
)

res["profile_info"] = {
context["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
Expand All @@ -411,7 +412,7 @@ async def search(
if s.type == EventTypes.Member and s.state_key in senders
}

contexts[event.event_id] = res
contexts[event.event_id] = context
else:
contexts = {}

Expand All @@ -421,10 +422,10 @@ async def search(

for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now
context["events_before"], time_now # type: ignore[arg-type]
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now
context["events_after"], time_now # type: ignore[arg-type]
)
Comment on lines 423 to 429
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this still mutates the dictionary, but it seems rather messy to not do this...


state_results = {}
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
Expand Down Expand Up @@ -100,7 +101,7 @@ class TimelineBatch:
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None

def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
Expand Down
2 changes: 1 addition & 1 deletion synapse/push/mailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ async def _get_notif_vars(
}

the_events = await filter_events_for_client(
self.storage, user_id, results["events_before"]
self.storage, user_id, results.events_before
)
the_events.append(notif_event)

Expand Down
Loading