Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve tests for get_unread_push_actions_for_user_in_range #13893

Merged
merged 5 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13893.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
35 changes: 21 additions & 14 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,18 @@ def f(txn: LoggingTransaction) -> List[str]:

def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
) -> List[Tuple[str, int]]:
) -> Dict[str, int]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.

Args:
txn:
user_id: The user to fetch receipts for.

Returns:
A map of room ID to stream ordering for all rooms the user has a receipt in.
"""
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
Expand All @@ -580,7 +591,7 @@ def _get_receipts_by_room_txn(

args.extend((user_id,))
txn.execute(sql, args)
return cast(List[Tuple[str, int]], txn.fetchall())
return dict(cast(List[Tuple[str, int]], txn.fetchall()))
clokep marked this conversation as resolved.
Show resolved Hide resolved

async def get_unread_push_actions_for_user_in_range_for_http(
self,
Expand All @@ -605,12 +616,10 @@ async def get_unread_push_actions_for_user_in_range_for_http(
The list will have between 0~limit entries.
"""

receipts_by_room = dict(
await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
),
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)

def get_push_actions_txn(
Expand Down Expand Up @@ -679,12 +688,10 @@ async def get_unread_push_actions_for_user_in_range_for_email(
The list will have between 0~limit entries.
"""

receipts_by_room = dict(
await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
),
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)

def get_push_actions_txn(
Expand Down
88 changes: 72 additions & 16 deletions tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

from twisted.test.proto_helpers import MemoryReactor

from synapse.rest import admin
Expand All @@ -22,8 +24,6 @@

from tests.unittest import HomeserverTestCase

USER_ID = "@user:example.com"


class EventPushActionsStoreTestCase(HomeserverTestCase):
servlets = [
Expand All @@ -38,21 +38,13 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
assert persist_events_store is not None
self.persist_events_store = persist_events_store

def test_get_unread_push_actions_for_user_in_range_for_http(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20
)
)
def _create_users_and_room(self) -> Tuple[str, str, str, str, str]:
"""
Creates two users and a shared room.

def test_get_unread_push_actions_for_user_in_range_for_email(self) -> None:
self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20
)
)

def test_count_aggregation(self) -> None:
Returns:
Tuple of (user 1 ID, user 1 token, user 2 ID, user 2 token, room ID).
"""
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
Expand All @@ -65,6 +57,70 @@ def test_count_aggregation(self) -> None:
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)

return user_id, token, other_id, other_token, room_id

def test_get_unread_push_actions_for_user_in_range(self) -> None:
"""Test getting unread push actions for HTTP and email pushers."""
user_id, token, _, other_token, room_id = self._create_users_and_room()

# Create two events, one of which is a highlight.
self.helper.send_event(
room_id,
type="m.room.message",
content={"msgtype": "m.text", "body": "msg"},
tok=other_token,
)
event_id = self.helper.send_event(
room_id,
type="m.room.message",
content={"msgtype": "m.text", "body": user_id},
tok=other_token,
)["event_id"]

# Fetch unread actions for HTTP pushers.
http_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
user_id, 0, 1000, 20
)
)
self.assertEqual(2, len(http_actions))

# Fetch unread actions for email pushers.
email_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
user_id, 0, 1000, 20
)
)
self.assertEqual(2, len(email_actions))

# Send a receipt, which should clear any actions.
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
event_ids=[event_id],
thread_id=None,
data={},
)
)
http_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_http(
user_id, 0, 1000, 20
)
)
self.assertEqual([], http_actions)
email_actions = self.get_success(
self.store.get_unread_push_actions_for_user_in_range_for_email(
user_id, 0, 1000, 20
)
)
self.assertEqual([], email_actions)

def test_count_aggregation(self) -> None:
# Create a user to receive notifications and send receipts.
user_id, token, _, other_token, room_id = self._create_users_and_room()

last_event_id: str

def _assert_counts(noitf_count: int, highlight_count: int) -> None:
Expand Down