diff --git a/changelog.d/6409.feature b/changelog.d/6409.feature new file mode 100644 index 000000000000..653ff5a5ad6a --- /dev/null +++ b/changelog.d/6409.feature @@ -0,0 +1 @@ +Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). diff --git a/synapse/api/constants.py b/synapse/api/constants.py index e3f086f1c361..69cef369a5bb 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -147,3 +147,7 @@ class EventContentFields(object): # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326 LABELS = "org.matrix.labels" + + # Timestamp to delete the event after + # cf https://github.com/matrix-org/matrix-doc/pull/2228 + SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after" diff --git a/synapse/config/server.py b/synapse/config/server.py index 7a9d7116692d..837fbe158285 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -490,6 +490,8 @@ class LimitRemoteRoomsConfig(object): "cleanup_extremities_with_dummy_events", True ) + self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False) + def has_tls_listener(self) -> bool: return any(l["tls"] for l in self.listeners) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d3267734f7c1..d9d0cd9eef3d 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -121,6 +121,7 @@ def __init__(self, hs): self.pusher_pool = hs.get_pusherpool() self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self._message_handler = hs.get_message_handler() self._server_notices_mxid = hs.config.server_notices_mxid self.config = hs.config self.http_client = hs.get_simple_http_client() @@ -141,6 +142,8 @@ def __init__(self, hs): self.third_party_event_rules = hs.get_third_party_event_rules() + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): """ Process a PDU received via a federation /send/ transaction, or @@ -2715,6 +2718,11 @@ def persist_events_and_notify(self, event_and_contexts, backfilled=False): event_and_contexts, backfilled=backfilled ) + if self._ephemeral_messages_enabled: + for (event, context) in event_and_contexts: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: yield self._notify_persisted_event(event, max_stream_id) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 3b0156f516b2..4f53a5f5dc61 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from typing import Optional from six import iteritems, itervalues, string_types @@ -22,9 +23,16 @@ from twisted.internet import defer from twisted.internet.defer import succeed +from twisted.internet.interfaces import IDelayedCall from synapse import event_auth -from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + Membership, + RelationTypes, + UserTypes, +) from synapse.api.errors import ( AuthError, Codes, @@ -62,6 +70,17 @@ def __init__(self, hs): self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._is_worker_app = bool(hs.config.worker_app) + + # The scheduled call to self._expire_event. None if no call is currently + # scheduled. + self._scheduled_expiry = None # type: Optional[IDelayedCall] + + if not hs.config.worker_app: + run_as_background_process( + "_schedule_next_expiry", self._schedule_next_expiry + ) @defer.inlineCallbacks def get_room_data( @@ -225,6 +244,100 @@ def get_joined_members(self, requester, room_id): for user_id, profile in iteritems(users_with_profile) } + def maybe_schedule_expiry(self, event): + """Schedule the expiry of an event if there's not already one scheduled, + or if the one running is for an event that will expire after the provided + timestamp. + + This function needs to invalidate the event cache, which is only possible on + the master process, and therefore needs to be run on there. + + Args: + event (EventBase): The event to schedule the expiry of. + """ + assert not self._is_worker_app + + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if not isinstance(expiry_ts, int) or event.is_state(): + return + + # _schedule_expiry_for_event won't actually schedule anything if there's already + # a task scheduled for a timestamp that's sooner than the provided one. + self._schedule_expiry_for_event(event.event_id, expiry_ts) + + @defer.inlineCallbacks + def _schedule_next_expiry(self): + """Retrieve the ID and the expiry timestamp of the next event to be expired, + and schedule an expiry task for it. + + If there's no event left to expire, set _expiry_scheduled to None so that a + future call to save_expiry_ts can schedule a new expiry task. + """ + # Try to get the expiry timestamp of the next event to expire. + res = yield self.store.get_next_event_to_expire() + if res: + event_id, expiry_ts = res + self._schedule_expiry_for_event(event_id, expiry_ts) + + def _schedule_expiry_for_event(self, event_id, expiry_ts): + """Schedule an expiry task for the provided event if there's not already one + scheduled at a timestamp that's sooner than the provided one. + + Args: + event_id (str): The ID of the event to expire. + expiry_ts (int): The timestamp at which to expire the event. + """ + if self._scheduled_expiry: + # If the provided timestamp refers to a time before the scheduled time of the + # next expiry task, cancel that task and reschedule it for this timestamp. + next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000 + if expiry_ts < next_scheduled_expiry_ts: + self._scheduled_expiry.cancel() + else: + return + + # Figure out how many seconds we need to wait before expiring the event. + now_ms = self.clock.time_msec() + delay = (expiry_ts - now_ms) / 1000 + + # callLater doesn't support negative delays, so trim the delay to 0 if we're + # in that case. + if delay < 0: + delay = 0 + + logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay) + + self._scheduled_expiry = self.clock.call_later( + delay, + run_as_background_process, + "_expire_event", + self._expire_event, + event_id, + ) + + @defer.inlineCallbacks + def _expire_event(self, event_id): + """Retrieve and expire an event that needs to be expired from the database. + + If the event doesn't exist in the database, log it and delete the expiry date + from the database (so that we don't try to expire it again). + """ + assert self._ephemeral_events_enabled + + self._scheduled_expiry = None + + logger.info("Expiring event %s", event_id) + + try: + # Expire the event if we know about it. This function also deletes the expiry + # date from the database in the same database transaction. + yield self.store.expire_event(event_id) + except Exception as e: + logger.error("Could not expire event %s: %r", event_id, e) + + # Schedule the expiry of the next event to expire. + yield self._schedule_next_expiry() + # The duration (in ms) after which rooms should be removed # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try @@ -295,6 +408,10 @@ def __init__(self, hs): 5 * 60 * 1000, ) + self._message_handler = hs.get_message_handler() + + self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def create_event( self, @@ -877,6 +994,10 @@ def is_inviter_member_event(e): event, context=context ) + if self._ephemeral_events_enabled: + # If there's an expiry timestamp on the event, schedule its expiry. + self._message_handler.maybe_schedule_expiry(event) + yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) def _notify(): diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 2737a1d3aeca..79c91fe284ff 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -130,6 +130,8 @@ def _censor_redactions(): if self.hs.config.redaction_retention_period is not None: hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000) + self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + @defer.inlineCallbacks def _read_forward_extremities(self): def fetch(txn): @@ -940,6 +942,12 @@ def _update_metadata_tables_txn( txn, event.event_id, labels, event.room_id, event.depth ) + if self._ephemeral_messages_enabled: + # If there's an expiry timestamp on the event, store it. + expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER) + if isinstance(expiry_ts, int) and not event.is_state(): + self._insert_event_expiry_txn(txn, event.event_id, expiry_ts) + # Insert into the room_memberships table. self._store_room_members_txn( txn, @@ -1101,12 +1109,7 @@ def _censor_redactions(self): def _update_censor_txn(txn): for redaction_id, event_id, pruned_json in updates: if pruned_json: - self._simple_update_one_txn( - txn, - table="event_json", - keyvalues={"event_id": event_id}, - updatevalues={"json": pruned_json}, - ) + self._censor_event_txn(txn, event_id, pruned_json) self._simple_update_one_txn( txn, @@ -1117,6 +1120,22 @@ def _update_censor_txn(txn): yield self.runInteraction("_update_censor_txn", _update_censor_txn) + def _censor_event_txn(self, txn, event_id, pruned_json): + """Censor an event by replacing its JSON in the event_json table with the + provided pruned JSON. + + Args: + txn (LoggingTransaction): The database transaction. + event_id (str): The ID of the event to censor. + pruned_json (str): The pruned JSON + """ + self._simple_update_one_txn( + txn, + table="event_json", + keyvalues={"event_id": event_id}, + updatevalues={"json": pruned_json}, + ) + @defer.inlineCallbacks def count_daily_messages(self): """ @@ -1957,6 +1976,101 @@ def insert_labels_for_event_txn( ], ) + def _insert_event_expiry_txn(self, txn, event_id, expiry_ts): + """Save the expiry timestamp associated with a given event ID. + + Args: + txn (LoggingTransaction): The database transaction to use. + event_id (str): The event ID the expiry timestamp is associated with. + expiry_ts (int): The timestamp at which to expire (delete) the event. + """ + return self._simple_insert_txn( + txn=txn, + table="event_expiry", + values={"event_id": event_id, "expiry_ts": expiry_ts}, + ) + + @defer.inlineCallbacks + def expire_event(self, event_id): + """Retrieve and expire an event that has expired, and delete its associated + expiry timestamp. If the event can't be retrieved, delete its associated + timestamp so we don't try to expire it again in the future. + + Args: + event_id (str): The ID of the event to delete. + """ + # Try to retrieve the event's content from the database or the event cache. + event = yield self.get_event(event_id) + + def delete_expired_event_txn(txn): + # Delete the expiry timestamp associated with this event from the database. + self._delete_event_expiry_txn(txn, event_id) + + if not event: + # If we can't find the event, log a warning and delete the expiry date + # from the database so that we don't try to expire it again in the + # future. + logger.warning( + "Can't expire event %s because we don't have it.", event_id + ) + return + + # Prune the event's dict then convert it to JSON. + pruned_json = encode_json(prune_event_dict(event.get_dict())) + + # Update the event_json table to replace the event's JSON with the pruned + # JSON. + self._censor_event_txn(txn, event.event_id, pruned_json) + + # We need to invalidate the event cache entry for this event because we + # changed its content in the database. We can't call + # self._invalidate_cache_and_stream because self.get_event_cache isn't of the + # right type. + txn.call_after(self._get_event_cache.invalidate, (event.event_id,)) + # Send that invalidation to replication so that other workers also invalidate + # the event cache. + self._send_invalidation_to_replication( + txn, "_get_event_cache", (event.event_id,) + ) + + yield self.runInteraction("delete_expired_event", delete_expired_event_txn) + + def _delete_event_expiry_txn(self, txn, event_id): + """Delete the expiry timestamp associated with an event ID without deleting the + actual event. + + Args: + txn (LoggingTransaction): The transaction to use to perform the deletion. + event_id (str): The event ID to delete the associated expiry timestamp of. + """ + return self._simple_delete_txn( + txn=txn, table="event_expiry", keyvalues={"event_id": event_id} + ) + + def get_next_event_to_expire(self): + """Retrieve the entry with the lowest expiry timestamp in the event_expiry + table, or None if there's no more event to expire. + + Returns: Deferred[Optional[Tuple[str, int]]] + A tuple containing the event ID as its first element and an expiry timestamp + as its second one, if there's at least one row in the event_expiry table. + None otherwise. + """ + + def get_next_event_to_expire_txn(txn): + txn.execute( + """ + SELECT event_id, expiry_ts FROM event_expiry + ORDER BY expiry_ts ASC LIMIT 1 + """ + ) + + return txn.fetchone() + + return self.runInteraction( + desc="get_next_event_to_expire", func=get_next_event_to_expire_txn + ) + AllNewEventsResult = namedtuple( "AllNewEventsResult", diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql new file mode 100644 index 000000000000..81a36a8b1dc9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS event_expiry ( + event_id TEXT PRIMARY KEY, + expiry_ts BIGINT NOT NULL +); + +CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts); diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py new file mode 100644 index 000000000000..5e9c07ebf3ff --- /dev/null +++ b/tests/rest/client/test_ephemeral_message.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventContentFields, EventTypes +from synapse.rest import admin +from synapse.rest.client.v1 import room + +from tests import unittest + + +class EphemeralMessageTestCase(unittest.HomeserverTestCase): + + user_id = "@user:test" + + servlets = [ + admin.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + + config["enable_ephemeral_messages"] = True + + self.hs = self.setup_test_homeserver(config=config) + return self.hs + + def prepare(self, reactor, clock, homeserver): + self.room_id = self.helper.create_room_as(self.user_id) + + def test_message_expiry_no_delay(self): + """Tests that sending a message sent with a m.self_destruct_after field set to the + past results in that event being deleted right away. + """ + # Send a message in the room that has expired. From here, the reactor clock is + # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock + # is at 0ms the code path is the same if the event's expiry timestamp is the + # current timestamp. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: 0, + }, + ) + event_id = res["event_id"] + + # Check that we can't retrieve the content of the event. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def test_message_expiry_delay(self): + """Tests that sending a message with a m.self_destruct_after field set to the + future results in that event not being deleted right away, but advancing the + clock to after that expiry timestamp causes the event to be deleted. + """ + # Send a message in the room that'll expire in 1s. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "hello", + EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000, + }, + ) + event_id = res["event_id"] + + # Check that we can retrieve the content of the event before it has expired. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertTrue(bool(event_content), event_content) + + # Advance the clock to after the deletion. + self.reactor.advance(1) + + # Check that we can't retrieve the content of the event anymore. + event_content = self.get_event(self.room_id, event_id)["content"] + self.assertFalse(bool(event_content), event_content) + + def get_event(self, room_id, event_id, expected_code=200): + url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id) + + request, channel = self.make_request("GET", url) + self.render(request) + + self.assertEqual(channel.code, expected_code, channel.result) + + return channel.json_body