Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Properly invalidate caches when an event with a relation is redacted #12121

Merged
merged 13 commits into from
Mar 7, 2022
1 change: 1 addition & 0 deletions changelog.d/12113.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug when redacting events with relations.
1 change: 0 additions & 1 deletion changelog.d/12113.misc

This file was deleted.

1 change: 1 addition & 0 deletions changelog.d/12121.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug when redacting events with relations.
2 changes: 2 additions & 0 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def _invalidate_caches_for_event(
self.get_relations_for_event.invalidate((relates_to,))
self.get_aggregation_groups_for_event.invalidate((relates_to,))
self.get_applicable_edit.invalidate((relates_to,))
self.get_thread_summary.invalidate((relates_to,))
self.get_thread_participated.invalidate((relates_to,))

async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
Expand Down
38 changes: 33 additions & 5 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1518,7 +1518,7 @@ def _update_metadata_tables_txn(
)

# Remove from relations table.
self._handle_redaction(txn, event.redacts)
self._handle_redact_relations(txn, event.redacts)

# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
Expand Down Expand Up @@ -1943,15 +1943,43 @@ def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):

txn.execute(sql, (batch_id,))

def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
any redacted relations from the database.
def _handle_redact_relations(
self, txn: LoggingTransaction, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.

Args:
txn
redacted_event_id (str): The event that was redacted.
redacted_event_id: The event that was redacted.
"""

# Fetch the current relation of the event being redacted.
redacted_relates_to = self.db_pool.simple_select_one_onecol_txn(
txn,
table="event_relations",
keyvalues={"event_id": redacted_event_id},
retcol="relates_to_id",
allow_none=True,
)
# Any relation information for the related event must be cleared.
if redacted_relates_to is not None:
self.store._invalidate_cache_and_stream(
txn, self.store.get_relations_for_event, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
txn, self.store.get_applicable_edit, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_summary, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)

self.db_pool.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
Expand Down
207 changes: 165 additions & 42 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,21 @@ def test_background_update(self) -> None:


class RelationRedactionTestCase(BaseRelationsTestCase):
"""Test the behaviour of relations when the parent or child event is redacted."""
"""
Test the behaviour of relations when the parent or child event is redacted.

The behaviour of each relation type is subtly different which causes the tests
to be a bit repetitive, they follow a naming scheme of:

test_redact_(relation|parent)_{relation_type}

The first bit of "relation" means that the event with the relation defined
on it (the child event) is to be redacted. A "parent" means that the target
of the relation (the parent event) is to be redacted.

The relation_type describes which type of relation is under test (i.e. it is
related to the value of rel_type in the event content).
"""

def _redact(self, event_id: str) -> None:
channel = self.make_request(
Expand All @@ -1284,9 +1298,53 @@ def _redact(self, event_id: str) -> None:
)
self.assertEqual(200, channel.code, channel.json_body)

def _make_relation_requests(self) -> Tuple[List[str], JsonDict]:
"""
Makes requests and ensures they result in a 200 response, returns a
tuple of results:

1. `/relations` -> Returns a list of event IDs.
2. `/event` -> Returns the response's m.relations field (from unsigned),
if it exists.
"""

# Request the relations of the event.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]

# Fetch the bundled aggregations of the event.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
bundled_relations = channel.json_body["unsigned"].get("m.relations", {})

return event_ids, bundled_relations

def _get_aggregations(self) -> List[JsonDict]:
"""Request /aggregations on the parent ID and includes the returned chunk."""
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
return channel.json_body["chunk"]

def test_redact_relation_annotation(self) -> None:
"""Test that annotations of an event are properly handled after the
"""
Test that annotations of an event are properly handled after the
annotation is redacted.

The redacted relation should not be included in bundled aggregations or
the response to relations.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEqual(200, channel.code, channel.json_body)
Expand All @@ -1296,24 +1354,97 @@ def test_redact_relation_annotation(self) -> None:
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEqual(200, channel.code, channel.json_body)
unredacted_event_id = channel.json_body["event_id"]

# Both relations should exist.
event_ids, relations = self._make_relation_requests()
self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertEquals(
relations["m.annotation"],
{"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
)

# Both relations appear in the aggregation.
chunk = self._get_aggregations()
self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])

# Redact one of the reactions.
self._redact(to_redact_event_id)

# Ensure that the aggregations are correct.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
access_token=self.user_token,
# The unredacted relation should still exist.
event_ids, relations = self._make_relation_requests()
self.assertEquals(event_ids, [unredacted_event_id])
self.assertEquals(
relations["m.annotation"],
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)

# The unredacted aggregation should still exist.
chunk = self._get_aggregations()
self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_redact_relation_thread(self) -> None:
"""
Test that thread replies are properly handled after the thread reply redacted.

The redacted event should not be included in bundled aggregations or
the response to relations.
"""
channel = self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
content={"body": "reply 1", "msgtype": "m.text"},
)
self.assertEqual(200, channel.code, channel.json_body)
unredacted_event_id = channel.json_body["event_id"]

# Note that the *last* event in the thread is redacted, as that gets
# included in the bundled aggregation.
channel = self._send_relation(
RelationTypes.THREAD,
EventTypes.Message,
content={"body": "reply 2", "msgtype": "m.text"},
)
self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]

# Both relations exist.
event_ids, relations = self._make_relation_requests()
self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
self.assertDictContainsSubset(
{
"count": 2,
"current_user_participated": True,
},
relations[RelationTypes.THREAD],
)
# And the latest event returned is the event that will be redacted.
self.assertEqual(
channel.json_body,
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
to_redact_event_id,
)

def test_redact_relation_edit(self) -> None:
# Redact one of the reactions.
self._redact(to_redact_event_id)

# The unredacted relation should still exist.
event_ids, relations = self._make_relation_requests()
self.assertEquals(event_ids, [unredacted_event_id])
self.assertDictContainsSubset(
{
"count": 1,
"current_user_participated": True,
},
relations[RelationTypes.THREAD],
)
# And the latest event is now the unredacted event.
self.assertEqual(
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
unredacted_event_id,
)

def test_redact_parent_edit(self) -> None:
"""Test that edits of an event are redacted when the original event
is redacted.
"""
Expand All @@ -1331,51 +1462,43 @@ def test_redact_relation_edit(self) -> None:
self.assertEqual(200, channel.code, channel.json_body)

# Check the relation is returned
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations"
f"/{self.parent_id}/m.replace/m.room.message",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

self.assertIn("chunk", channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1)
event_ids, relations = self._make_relation_requests()
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.REPLACE, relations)

# Redact the original event
self._redact(self.parent_id)

# Try to check for remaining m.replace relations
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations"
f"/{self.parent_id}/m.replace/m.room.message",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)

# Check that no relations are returned
self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
# The relations are not returned.
event_ids, relations = self._make_relation_requests()
self.assertEqual(len(event_ids), 0)
self.assertEqual(relations, {})

def test_redact_parent(self) -> None:
def test_redact_parent_annotation(self) -> None:
"""Test that annotations of an event are redacted when the original event
is redacted.
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEqual(200, channel.code, channel.json_body)

# The relations should exist.
event_ids, relations = self._make_relation_requests()
self.assertEqual(len(event_ids), 1)
self.assertIn(RelationTypes.ANNOTATION, relations)

# The aggregation should exist.
chunk = self._get_aggregations()
self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])

# Redact the original event.
self._redact(self.parent_id)

# Check that aggregations returns zero
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
# The relations are not returned.
event_ids, relations = self._make_relation_requests()
self.assertEqual(event_ids, [])
self.assertEqual(relations, {})

self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
# There's nothing to aggregate.
chunk = self._get_aggregations()
self.assertEqual(chunk, [])