diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 7cc500e8d3c7..660ebd81259c 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -427,8 +427,8 @@ def _get_unread_counts_by_pos_txn( room_id: The room ID to get unread counts for. user_id: The user ID to get unread counts for. receipt_stream_ordering: The stream ordering of the user's latest - receipt in the room. If there are no receipts, the stream ordering - of the user's join event. + unthreaded receipt in the room. If there are no unthreaded receipts, + the stream ordering of the user's join event. Returns: A RoomNotifCounts object containing the notification count, the @@ -439,6 +439,12 @@ def _get_unread_counts_by_pos_txn( counts = NotifCounts() thread_counts = {} + receipt_types_clause, receipts_args = make_in_list_sql_clause( + self.database_engine, + "receipt_type", + (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), + ) + # First we pull the counts from the summary table. # # We check that `last_receipt_stream_ordering` matches the stream @@ -453,19 +459,38 @@ def _get_unread_counts_by_pos_txn( # updated `event_push_summary` synchronously when persisting a new read # receipt). txn.execute( - """ - SELECT stream_ordering, notif_count, COALESCE(unread_count, 0), thread_id + f""" + SELECT notif_count, COALESCE(unread_count, 0), thread_id FROM event_push_summary + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE room_id = ? AND user_id = ? AND ( - (last_receipt_stream_ordering IS NULL AND stream_ordering > ?) - OR last_receipt_stream_ordering = ? + (last_receipt_stream_ordering IS NULL AND stream_ordering > MAX(COALESCE(receipt_stream_ordering, 0), ?)) + OR last_receipt_stream_ordering = MAX(COALESCE(receipt_stream_ordering, 0), ?) ) """, - (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering), + ( + user_id, + room_id, + *receipts_args, + room_id, + user_id, + receipt_stream_ordering, + receipt_stream_ordering, + ), ) - max_summary_stream_ordering = 0 - for summary_stream_ordering, notif_count, unread_count, thread_id in txn: + summarised_threads = set() + for notif_count, unread_count, thread_id in txn: + summarised_threads.add(thread_id) if thread_id == "main": counts = NotifCounts( notify_count=notif_count, unread_count=unread_count @@ -476,21 +501,36 @@ def _get_unread_counts_by_pos_txn( notify_count=notif_count, unread_count=unread_count ) - # XXX All threads should have the same stream ordering? - max_summary_stream_ordering = max( - summary_stream_ordering, max_summary_stream_ordering - ) - # Next we need to count highlights, which aren't summarised - sql = """ + sql = f""" SELECT COUNT(*), thread_id FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) WHERE user_id = ? AND room_id = ? - AND stream_ordering > ? + AND stream_ordering > MAX(COALESCE(receipt_stream_ordering, 0), ?) AND highlight = 1 GROUP BY thread_id """ - txn.execute(sql, (user_id, room_id, receipt_stream_ordering)) + txn.execute( + sql, + ( + user_id, + room_id, + *receipts_args, + user_id, + room_id, + receipt_stream_ordering, + ), + ) for highlight_count, thread_id in txn: if thread_id == "main": counts.highlight_count += highlight_count @@ -502,18 +542,82 @@ def _get_unread_counts_by_pos_txn( notify_count=0, unread_count=0, highlight_count=highlight_count ) + # For threads which were summarised we need to count actions since the last + # rotation. + thread_id_clause, thread_id_args = make_in_list_sql_clause( + self.database_engine, "thread_id", summarised_threads + ) + + # The (inclusive) event stream ordering that was previously summarised. + rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + unread_counts = self._get_notif_unread_count_for_user_room( + txn, room_id, user_id, rotated_upto_stream_ordering + ) + for notif_count, unread_count, thread_id in unread_counts: + if thread_id not in summarised_threads: + continue + + if thread_id == "main": + counts.notify_count += notif_count + counts.unread_count += unread_count + elif thread_id in thread_counts: + thread_counts[thread_id].notify_count += notif_count + thread_counts[thread_id].unread_count += unread_count + else: + # Previous thread summaries of 0 are discarded above. + # + # TODO If empty summaries are deleted this can be removed. + thread_counts[thread_id] = NotifCounts( + notify_count=notif_count, + unread_count=unread_count, + highlight_count=0, + ) + # Finally we need to count push actions that aren't included in the # summary returned above. This might be due to recent events that haven't # been summarised yet or the summary is out of date due to a recent read # receipt. - start_unread_stream_ordering = max( - receipt_stream_ordering, max_summary_stream_ordering - ) - unread_counts = self._get_notif_unread_count_for_user_room( - txn, room_id, user_id, start_unread_stream_ordering + sql = f""" + SELECT + COUNT(CASE WHEN notif = 1 THEN 1 END), + COUNT(CASE WHEN unread = 1 THEN 1 END), + thread_id + FROM event_push_actions + LEFT JOIN ( + SELECT thread_id, MAX(stream_ordering) AS receipt_stream_ordering + FROM receipts_linearized + LEFT JOIN events USING (room_id, event_id) + WHERE + user_id = ? + AND room_id = ? + AND {receipt_types_clause} + GROUP BY thread_id + ) AS receipts USING (thread_id) + WHERE user_id = ? + AND room_id = ? + AND stream_ordering > MAX(COALESCE(receipt_stream_ordering, 0), ?) + AND NOT {thread_id_clause} + GROUP BY thread_id + """ + txn.execute( + sql, + ( + user_id, + room_id, + *receipts_args, + user_id, + room_id, + receipt_stream_ordering, + *thread_id_args, + ), ) - - for notif_count, unread_count, thread_id in unread_counts: + for notif_count, unread_count, thread_id in txn: if thread_id == "main": counts.notify_count += notif_count counts.unread_count += unread_count @@ -536,6 +640,7 @@ def _get_notif_unread_count_for_user_room( user_id: str, stream_ordering: int, max_stream_ordering: Optional[int] = None, + thread_id: Optional[str] = None, ) -> List[Tuple[int, int, str]]: """Returns the notify and unread counts from `event_push_actions` for the given user/room in the given range. @@ -561,10 +666,10 @@ def _get_notif_unread_count_for_user_room( if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): return [] - clause = "" + stream_ordering_clause = "" args = [user_id, room_id, stream_ordering] if max_stream_ordering is not None: - clause = "AND ea.stream_ordering <= ?" + stream_ordering_clause = "AND ea.stream_ordering <= ?" args.append(max_stream_ordering) # If the max stream ordering is less than the min stream ordering, @@ -572,6 +677,11 @@ def _get_notif_unread_count_for_user_room( if max_stream_ordering <= stream_ordering: return [] + thread_id_clause = "" + if thread_id is not None: + thread_id_clause = "AND thread_id = ?" + args.append(thread_id) + sql = f""" SELECT COUNT(CASE WHEN notif = 1 THEN 1 END), @@ -581,7 +691,8 @@ def _get_notif_unread_count_for_user_room( WHERE user_id = ? AND room_id = ? AND ea.stream_ordering > ? - {clause} + {stream_ordering_clause} + {thread_id_clause} GROUP BY thread_id """ @@ -1087,7 +1198,7 @@ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool: ) sql = """ - SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering + SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering FROM receipts_linearized AS r INNER JOIN events AS e USING (event_id) WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? @@ -1110,13 +1221,18 @@ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool: ) rows = txn.fetchall() - # For each new read receipt we delete push actions from before it and - # recalculate the summary. - for _, room_id, user_id, stream_ordering in rows: + # First handle all the rows without a thread ID (i.e. ones that apply to + # the entire room). + for _, room_id, user_id, thread_id, stream_ordering in rows: # Only handle our own read receipts. if not self.hs.is_mine_id(user_id): continue + if thread_id is not None: + continue + + # For each new read receipt we delete push actions from before it and + # recalculate the summary. txn.execute( """ DELETE FROM event_push_actions @@ -1158,6 +1274,64 @@ def _handle_new_receipts_for_notifs_txn(self, txn: LoggingTransaction) -> bool: value_values=[(row[0], row[1]) for row in unread_counts], ) + # For each new read receipt we delete push actions from before it and + # recalculate the summary. + for _, room_id, user_id, thread_id, stream_ordering in rows: + # Only handle our own read receipts. + if not self.hs.is_mine_id(user_id): + continue + + if thread_id is None: + continue + + # For each new read receipt we delete push actions from before it and + # recalculate the summary. + txn.execute( + """ + DELETE FROM event_push_actions + WHERE room_id = ? + AND user_id = ? + AND thread_id = ? + AND stream_ordering <= ? + AND highlight = 0 + """, + (room_id, user_id, thread_id, stream_ordering), + ) + + # Fetch the notification counts between the stream ordering of the + # latest receipt and what was previously summarised. + unread_counts = self._get_notif_unread_count_for_user_room( + txn, + room_id, + user_id, + stream_ordering, + old_rotate_stream_ordering, + thread_id, + ) + # unread_counts will be a list of 0 or 1 items. + if unread_counts: + notif_count, unread_count, _ = unread_counts[0] + else: + notif_count = 0 + unread_count = 0 + + # Update the summary of this specific thread. + self.db_pool.simple_upsert_txn( + txn, + table="event_push_summary", + keyvalues={ + "room_id": room_id, + "user_id": user_id, + "thread_id": thread_id, + }, + values={ + "notif_count": notif_count, + "unread_count": unread_count, + "stream_ordering": old_rotate_stream_ordering, + "last_receipt_stream_ordering": stream_ordering, + }, + ) + # We always update `event_push_summary_last_receipt_stream_id` to # ensure that we don't rescan the same receipts for remote users. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 52fe0db92405..c7e812224869 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -170,8 +170,8 @@ def get_last_receipt_for_user_txn( receipt_types: Collection[str], ) -> Optional[Tuple[str, int]]: """ - Fetch the event ID and stream_ordering for the latest receipt in a room - with one of the given receipt types. + Fetch the event ID and stream_ordering for the latest unthreaded receipt + in a room with one of the given receipt types. Args: user_id: The user to fetch receipts for. @@ -193,6 +193,7 @@ def get_last_receipt_for_user_txn( WHERE {clause} AND user_id = ? AND room_id = ? + AND thread_id IS NULL ORDER BY stream_ordering DESC LIMIT 1 """ diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 4201e3309310..eef4cb395a5b 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -268,7 +268,7 @@ def _create_event( def _rotate() -> None: self.get_success(self.store._rotate_notifs()) - def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + def _mark_read(event_id: str, thread_id: str = "main") -> None: self.get_success( self.store.insert_receipt( room_id, @@ -304,9 +304,12 @@ def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: _create_event() _create_event(thread_id=thread_id) _mark_read(event_id) + _assert_counts(1, 0, 3, 0) + _mark_read(event_id, thread_id) _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event() @@ -320,6 +323,7 @@ def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event(True) @@ -345,8 +349,190 @@ def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: # Check that sending read receipts at different points results in the # right counts. _mark_read(event_id) + _assert_counts(1, 0, 2, 1) + _mark_read(event_id, thread_id) _assert_counts(1, 0, 1, 0) _mark_read(last_event_id) + _assert_counts(0, 0, 1, 0) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _mark_read(last_event_id) + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + _rotate() + _assert_counts(0, 0, 0, 0) + + def test_count_aggregation_mixed(self) -> None: + """ + This is essentially the same test as test_count_aggregation_threads, but + sends both unthreaded and threaded receipts. + """ + + # Create a user to receive notifications and send receipts. + user_id = self.register_user("user1235", "pass") + token = self.login("user1235", "pass") + + # And another users to send events. + other_id = self.register_user("other", "pass") + other_token = self.login("other", "pass") + + # Create a room and put both users in it. + room_id = self.helper.create_room_as(user_id, tok=token) + self.helper.join(room_id, other_id, tok=other_token) + thread_id: str + + last_event_id: str + + def _assert_counts( + noitf_count: int, + highlight_count: int, + thread_notif_count: int, + thread_highlight_count: int, + ) -> None: + counts = self.get_success( + self.store.db_pool.runInteraction( + "get-unread-counts", + self.store._get_unread_counts_by_receipt_txn, + room_id, + user_id, + ) + ) + self.assertEqual( + counts.main_timeline, + NotifCounts( + notify_count=noitf_count, + unread_count=0, + highlight_count=highlight_count, + ), + ) + if thread_notif_count or thread_highlight_count: + self.assertEqual( + counts.threads, + { + thread_id: NotifCounts( + notify_count=thread_notif_count, + unread_count=0, + highlight_count=thread_highlight_count, + ), + }, + ) + else: + self.assertEqual(counts.threads, {}) + + def _create_event( + highlight: bool = False, thread_id: Optional[str] = None + ) -> str: + content: JsonDict = { + "msgtype": "m.text", + "body": user_id if highlight else "msg", + } + if thread_id: + content["m.relates_to"] = { + "rel_type": "m.thread", + "event_id": thread_id, + } + + result = self.helper.send_event( + room_id, + type="m.room.message", + content=content, + tok=other_token, + ) + nonlocal last_event_id + last_event_id = result["event_id"] + return last_event_id + + def _rotate() -> None: + self.get_success(self.store._rotate_notifs()) + + def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None: + self.get_success( + self.store.insert_receipt( + room_id, + "m.read", + user_id=user_id, + event_ids=[event_id], + thread_id=thread_id, + data={}, + ) + ) + + _assert_counts(0, 0, 0, 0) + thread_id = _create_event() + _assert_counts(1, 0, 0, 0) + _rotate() + _assert_counts(1, 0, 0, 0) + + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + _create_event() + _assert_counts(2, 0, 1, 0) + _rotate() + _assert_counts(2, 0, 1, 0) + + event_id = _create_event(thread_id=thread_id) + _assert_counts(2, 0, 2, 0) + _rotate() + _assert_counts(2, 0, 2, 0) + + _create_event() + _create_event(thread_id=thread_id) + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id, "main") + _mark_read(last_event_id, thread_id) + _assert_counts(0, 0, 0, 0) + + _create_event() + _create_event(thread_id=thread_id) + _assert_counts(1, 0, 1, 0) + _rotate() + _assert_counts(1, 0, 1, 0) + + # Delete old event push actions, this should not affect the (summarised) count. + self.get_success(self.store._remove_old_push_actions_that_have_rotated()) + _assert_counts(1, 0, 1, 0) + + _mark_read(last_event_id) + _assert_counts(0, 0, 0, 0) + + _create_event(True) + _assert_counts(1, 1, 0, 0) + _rotate() + _assert_counts(1, 1, 0, 0) + + event_id = _create_event(True, thread_id) + _assert_counts(1, 1, 1, 1) + _rotate() + _assert_counts(1, 1, 1, 1) + + # Check that adding another notification and rotating after highlight + # works. + _create_event() + _rotate() + _assert_counts(2, 1, 1, 1) + + _create_event(thread_id=thread_id) + _rotate() + _assert_counts(2, 1, 2, 1) + + # Check that sending read receipts at different points results in the + # right counts. + _mark_read(event_id) + _assert_counts(1, 0, 1, 0) + _mark_read(event_id, "main") + _assert_counts(1, 0, 1, 0) + _mark_read(last_event_id, "main") + _assert_counts(0, 0, 1, 0) + _mark_read(last_event_id, thread_id) _assert_counts(0, 0, 0, 0) _create_event(True)