From c9d2a3d97f41817e2a089e77666c6dfc787e54f7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 11 Jan 2023 10:11:40 -0500 Subject: [PATCH] Implement MSC3952: Intentional mentions --- changelog.d/14823.feature | 1 + rust/src/push/base_rules.rs | 16 ++++++ rust/src/push/evaluator.rs | 17 +++++- rust/src/push/mod.rs | 26 +++++++++ stubs/synapse/synapse_rust/push.pyi | 3 +- synapse/push/bulk_push_rule_evaluator.py | 13 +++++ tests/push/test_bulk_push_rule_evaluator.py | 60 ++++++++++++++++++++ tests/push/test_push_rule_evaluator.py | 63 ++++++++++++++++++--- 8 files changed, 190 insertions(+), 9 deletions(-) create mode 100644 changelog.d/14823.feature diff --git a/changelog.d/14823.feature b/changelog.d/14823.feature new file mode 100644 index 000000000000..8293e99effbc --- /dev/null +++ b/changelog.d/14823.feature @@ -0,0 +1 @@ +Experimental support for [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952): intentional mentions. diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 35129691ca43..8db7bc7db9c1 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -131,6 +131,14 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_user_mentioned"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsUserMention)]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"), priority_class: 5, @@ -139,6 +147,14 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, + PushRule { + rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mentioned"), + priority_class: 5, + conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsRoomMention)]), + actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"), priority_class: 5, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index c901c0fbcc60..57a4c2752ed9 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -68,6 +68,9 @@ pub struct PushRuleEvaluator { /// The "content.body", if any. body: String, + /// The mentions that were part of the message, note this has + mentions: BTreeSet, + /// The number of users in the room. room_member_count: u64, @@ -100,6 +103,7 @@ impl PushRuleEvaluator { #[new] pub fn py_new( flattened_keys: BTreeMap, + mentions: BTreeSet, room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, @@ -116,6 +120,7 @@ impl PushRuleEvaluator { Ok(PushRuleEvaluator { flattened_keys, body, + mentions, room_member_count, notification_power_levels, sender_power_level, @@ -229,6 +234,14 @@ impl PushRuleEvaluator { KnownCondition::RelatedEventMatch(event_match) => { self.match_related_event_match(event_match, user_id)? } + KnownCondition::IsUserMention => { + if let Some(uid) = user_id { + self.mentions.contains(uid) + } else { + false + } + } + KnownCondition::IsRoomMention => self.mentions.contains("@room"), KnownCondition::ContainsDisplayName => { if let Some(dn) = display_name { if !dn.is_empty() { @@ -424,6 +437,7 @@ fn push_rule_evaluator() { flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); let evaluator = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), 10, Some(0), BTreeMap::new(), @@ -449,6 +463,7 @@ fn test_requires_room_version_supports_condition() { let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()]; let evaluator = PushRuleEvaluator::py_new( flattened_keys, + BTreeSet::new(), 10, Some(0), BTreeMap::new(), diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 2e9d3e38a17b..19a7db6cb144 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -269,6 +269,10 @@ pub enum KnownCondition { EventMatch(EventMatchCondition), #[serde(rename = "im.nheko.msc3664.related_event_match")] RelatedEventMatch(RelatedEventMatchCondition), + #[serde(rename = "org.matrix.msc3952.is_user_mention")] + IsUserMention, + #[serde(rename = "org.matrix.msc3952.is_room_mention")] + IsRoomMention, ContainsDisplayName, RoomMemberCount { #[serde(skip_serializing_if = "Option::is_none")] @@ -514,6 +518,28 @@ fn test_deserialize_unstable_msc3931_condition() { )); } +#[test] +fn test_deserialize_unstable_msc3952_user_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::IsUserMention) + )); +} + +#[test] +fn test_deserialize_unstable_msc3952_room_condition() { + let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#; + + let condition: Condition = serde_json::from_str(json).unwrap(); + assert!(matches!( + condition, + Condition::Known(KnownCondition::IsRoomMention) + )); +} + #[test] fn test_deserialize_custom_condition() { let json = r#"{"kind":"custom_tag"}"#; diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 373b40740b37..739bf4409cef 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union from synapse.types import JsonDict @@ -54,6 +54,7 @@ class PushRuleEvaluator: def __init__( self, flattened_keys: Mapping[str, str], + mentions: Set[str], room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f27ba64d5365..05347d2b753a 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -342,8 +342,21 @@ async def _action_for_event_by_user( for user_id, level in notification_levels.items(): notification_levels[user_id] = int(level) + # Pull out the mentions field if it exists and trim the values to things + # that might be valid. + mentions_raw = event.content.get("mentions") + if isinstance(mentions_raw, list): + # Take the first 10 items, then strip out any non-string ones and convert + # to a tuple. + mentions = set( + filter(lambda item: isinstance(item, str), mentions_raw[:10]) + ) + else: + mentions = set() + evaluator = PushRuleEvaluator( _flatten_dict(event, room_version=event.room_version), + mentions, room_member_count, sender_power_level, notification_levels, diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 9c17a42b650f..d4cdb8728413 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -126,3 +127,62 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None: # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) bulk_evaluator._action_for_event_by_user.assert_not_called() + + def test_mentions(self) -> None: + """Test the behavior of an event which includes invalid mentions.""" + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + + sentinel = object() + + def create_and_process(mentions: Any = sentinel) -> bool: + content = {} + if mentions is not sentinel: + content["mentions"] = mentions + + # Create a new message event which should cause a notification. + event, context = self.get_success( + self.event_creation_handler.create_event( + self.requester, + { + "type": "test", + "room_id": self.room_id, + "content": content, + "sender": f"@bob:{self.hs.hostname}", + }, + ) + ) + + # Ensure no actions are generated! + self.get_success( + bulk_evaluator.action_for_events_by_user([(event, context)]) + ) + + # If any actions are generated for this event, return true. + result = self.get_success( + self.hs.get_datastores().main.db_pool.simple_select_list( + table="event_push_actions_staging", + keyvalues={"event_id": event.event_id}, + retcols=("*",), + desc="get_event_push_actions_staging", + ) + ) + return len(result) > 0 + + # Not including the mentions field should result in no notifications. + self.assertFalse(create_and_process()) + + # Invalid data should be ignored. + mentions: Any + for mentions in (None, True, False, "foo", {}): + self.assertFalse(create_and_process(mentions)) + + # The Matrix ID appearing anywhere in the mentions list should match + self.assertTrue(create_and_process([self.alice])) + self.assertTrue(create_and_process(["@another:test", self.alice])) + + # The Matrix ID appearing > 10 entries into the list should be ignored. + self.assertFalse(create_and_process(["@another:test"] * 10 + [self.alice])) + + # Invalid entries in the list are ignored, but count towards the limit. + self.assertTrue(create_and_process([None, True, False, {}, [], self.alice])) + self.assertFalse(create_and_process([None] * 10 + [self.alice])) diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 1b87756b751a..0441dd76a796 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union, cast +from typing import Dict, List, Optional, Set, Union, cast import frozendict @@ -39,7 +39,11 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): def _get_evaluator( - self, content: JsonMapping, related_events: Optional[JsonDict] = None + self, + content: JsonMapping, + *, + mentions: Optional[Set[str]] = None, + related_events: Optional[JsonDict] = None, ) -> PushRuleEvaluator: event = FrozenEvent( { @@ -57,13 +61,14 @@ def _get_evaluator( power_levels: Dict[str, Union[int, Dict[str, int]]] = {} return PushRuleEvaluator( _flatten_dict(event), + mentions or set(), room_member_count, sender_power_level, cast(Dict[str, int], power_levels.get("notifications", {})), {} if related_events is None else related_events, - True, - event.room_version.msc3931_push_features, - True, + related_event_match_enabled=True, + room_version_feature_flags=event.room_version.msc3931_push_features, + msc3931_enabled=True, ) def test_display_name(self) -> None: @@ -90,6 +95,50 @@ def test_display_name(self) -> None: # A display name with spaces should work fine. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar")) + def test_user_mentions(self) -> None: + """Check for user mentions.""" + condition = {"kind": "org.matrix.msc3952.is_user_mention"} + + # No mentions shouldn't match. + evaluator = self._get_evaluator({}) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # An empty set shouldn't match + evaluator = self._get_evaluator({}, mentions=set()) + self.assertFalse(evaluator.matches(condition, "@user:test", None)) + + # The Matrix ID appearing anywhere in the mentions list should match + evaluator = self._get_evaluator({}, mentions={"@user:test"}) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + evaluator = self._get_evaluator({}, mentions={"@another:test", "@user:test"}) + self.assertTrue(evaluator.matches(condition, "@user:test", None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + + def test_room_mentions(self) -> None: + """Check for room mentions.""" + condition = {"kind": "org.matrix.msc3952.is_room_mention"} + + # No mentions shouldn't match. + evaluator = self._get_evaluator({}) + self.assertFalse(evaluator.matches(condition, None, None)) + + # An empty set shouldn't match + evaluator = self._get_evaluator({}, mentions=set()) + self.assertFalse(evaluator.matches(condition, None, None)) + + # The @room appearing anywhere in the mentions list should match + evaluator = self._get_evaluator({}, mentions={"@room"}) + self.assertTrue(evaluator.matches(condition, None, None)) + + evaluator = self._get_evaluator({}, mentions={"@another:test", "@room"}) + self.assertTrue(evaluator.matches(condition, None, None)) + + # Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions + # since the BulkPushRuleEvaluator is what handles data sanitisation. + def _assert_matches( self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None ) -> None: @@ -308,7 +357,7 @@ def test_related_event_match(self) -> None: }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message", @@ -408,7 +457,7 @@ def test_related_event_match_with_fallback(self) -> None: }, } }, - { + related_events={ "m.in_reply_to": { "event_id": "$parent_event_id", "type": "m.room.message",