Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Implement MSC3952: Intentional mentions
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Jan 12, 2023
1 parent b50c008 commit c9d2a3d
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 9 deletions.
1 change: 1 addition & 0 deletions changelog.d/14823.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Experimental support for [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952): intentional mentions.
16 changes: 16 additions & 0 deletions rust/src/push/base_rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed(".org.matrix.msc3952.is_user_mentioned"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsUserMention)]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"),
priority_class: 5,
Expand All @@ -139,6 +147,14 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mentioned"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::IsRoomMention)]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"),
priority_class: 5,
Expand Down
17 changes: 16 additions & 1 deletion rust/src/push/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::collections::{BTreeMap, BTreeSet};

use anyhow::{Context, Error};
use lazy_static::lazy_static;
Expand Down Expand Up @@ -68,6 +68,9 @@ pub struct PushRuleEvaluator {
/// The "content.body", if any.
body: String,

/// The mentions that were part of the message, note this has
mentions: BTreeSet<String>,

/// The number of users in the room.
room_member_count: u64,

Expand Down Expand Up @@ -100,6 +103,7 @@ impl PushRuleEvaluator {
#[new]
pub fn py_new(
flattened_keys: BTreeMap<String, String>,
mentions: BTreeSet<String>,
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
Expand All @@ -116,6 +120,7 @@ impl PushRuleEvaluator {
Ok(PushRuleEvaluator {
flattened_keys,
body,
mentions,
room_member_count,
notification_power_levels,
sender_power_level,
Expand Down Expand Up @@ -229,6 +234,14 @@ impl PushRuleEvaluator {
KnownCondition::RelatedEventMatch(event_match) => {
self.match_related_event_match(event_match, user_id)?
}
KnownCondition::IsUserMention => {
if let Some(uid) = user_id {
self.mentions.contains(uid)
} else {
false
}
}
KnownCondition::IsRoomMention => self.mentions.contains("@room"),
KnownCondition::ContainsDisplayName => {
if let Some(dn) = display_name {
if !dn.is_empty() {
Expand Down Expand Up @@ -424,6 +437,7 @@ fn push_rule_evaluator() {
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
BTreeSet::new(),
10,
Some(0),
BTreeMap::new(),
Expand All @@ -449,6 +463,7 @@ fn test_requires_room_version_supports_condition() {
let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()];
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
BTreeSet::new(),
10,
Some(0),
BTreeMap::new(),
Expand Down
26 changes: 26 additions & 0 deletions rust/src/push/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@ pub enum KnownCondition {
EventMatch(EventMatchCondition),
#[serde(rename = "im.nheko.msc3664.related_event_match")]
RelatedEventMatch(RelatedEventMatchCondition),
#[serde(rename = "org.matrix.msc3952.is_user_mention")]
IsUserMention,
#[serde(rename = "org.matrix.msc3952.is_room_mention")]
IsRoomMention,
ContainsDisplayName,
RoomMemberCount {
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -514,6 +518,28 @@ fn test_deserialize_unstable_msc3931_condition() {
));
}

#[test]
fn test_deserialize_unstable_msc3952_user_condition() {
let json = r#"{"kind":"org.matrix.msc3952.is_user_mention"}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::IsUserMention)
));
}

#[test]
fn test_deserialize_unstable_msc3952_room_condition() {
let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#;

let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::IsRoomMention)
));
}

#[test]
fn test_deserialize_custom_condition() {
let json = r#"{"kind":"custom_tag"}"#;
Expand Down
3 changes: 2 additions & 1 deletion stubs/synapse/synapse_rust/push.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union

from synapse.types import JsonDict

Expand Down Expand Up @@ -54,6 +54,7 @@ class PushRuleEvaluator:
def __init__(
self,
flattened_keys: Mapping[str, str],
mentions: Set[str],
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
Expand Down
13 changes: 13 additions & 0 deletions synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,21 @@ async def _action_for_event_by_user(
for user_id, level in notification_levels.items():
notification_levels[user_id] = int(level)

# Pull out the mentions field if it exists and trim the values to things
# that might be valid.
mentions_raw = event.content.get("mentions")
if isinstance(mentions_raw, list):
# Take the first 10 items, then strip out any non-string ones and convert
# to a tuple.
mentions = set(
filter(lambda item: isinstance(item, str), mentions_raw[:10])
)
else:
mentions = set()

evaluator = PushRuleEvaluator(
_flatten_dict(event, room_version=event.room_version),
mentions,
room_member_count,
sender_power_level,
notification_levels,
Expand Down
60 changes: 60 additions & 0 deletions tests/push/test_bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any
from unittest.mock import patch

from twisted.test.proto_helpers import MemoryReactor
Expand Down Expand Up @@ -126,3 +127,62 @@ def test_action_for_event_by_user_disabled_by_config(self) -> None:
# Ensure no actions are generated!
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
bulk_evaluator._action_for_event_by_user.assert_not_called()

def test_mentions(self) -> None:
"""Test the behavior of an event which includes invalid mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)

sentinel = object()

def create_and_process(mentions: Any = sentinel) -> bool:
content = {}
if mentions is not sentinel:
content["mentions"] = mentions

# Create a new message event which should cause a notification.
event, context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
"type": "test",
"room_id": self.room_id,
"content": content,
"sender": f"@bob:{self.hs.hostname}",
},
)
)

# Ensure no actions are generated!
self.get_success(
bulk_evaluator.action_for_events_by_user([(event, context)])
)

# If any actions are generated for this event, return true.
result = self.get_success(
self.hs.get_datastores().main.db_pool.simple_select_list(
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
retcols=("*",),
desc="get_event_push_actions_staging",
)
)
return len(result) > 0

# Not including the mentions field should result in no notifications.
self.assertFalse(create_and_process())

# Invalid data should be ignored.
mentions: Any
for mentions in (None, True, False, "foo", {}):
self.assertFalse(create_and_process(mentions))

# The Matrix ID appearing anywhere in the mentions list should match
self.assertTrue(create_and_process([self.alice]))
self.assertTrue(create_and_process(["@another:test", self.alice]))

# The Matrix ID appearing > 10 entries into the list should be ignored.
self.assertFalse(create_and_process(["@another:test"] * 10 + [self.alice]))

# Invalid entries in the list are ignored, but count towards the limit.
self.assertTrue(create_and_process([None, True, False, {}, [], self.alice]))
self.assertFalse(create_and_process([None] * 10 + [self.alice]))
63 changes: 56 additions & 7 deletions tests/push/test_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union, cast
from typing import Dict, List, Optional, Set, Union, cast

import frozendict

Expand All @@ -39,7 +39,11 @@

class PushRuleEvaluatorTestCase(unittest.TestCase):
def _get_evaluator(
self, content: JsonMapping, related_events: Optional[JsonDict] = None
self,
content: JsonMapping,
*,
mentions: Optional[Set[str]] = None,
related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
{
Expand All @@ -57,13 +61,14 @@ def _get_evaluator(
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
return PushRuleEvaluator(
_flatten_dict(event),
mentions or set(),
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
{} if related_events is None else related_events,
True,
event.room_version.msc3931_push_features,
True,
related_event_match_enabled=True,
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
)

def test_display_name(self) -> None:
Expand All @@ -90,6 +95,50 @@ def test_display_name(self) -> None:
# A display name with spaces should work fine.
self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))

def test_user_mentions(self) -> None:
"""Check for user mentions."""
condition = {"kind": "org.matrix.msc3952.is_user_mention"}

# No mentions shouldn't match.
evaluator = self._get_evaluator({})
self.assertFalse(evaluator.matches(condition, "@user:test", None))

# An empty set shouldn't match
evaluator = self._get_evaluator({}, mentions=set())
self.assertFalse(evaluator.matches(condition, "@user:test", None))

# The Matrix ID appearing anywhere in the mentions list should match
evaluator = self._get_evaluator({}, mentions={"@user:test"})
self.assertTrue(evaluator.matches(condition, "@user:test", None))

evaluator = self._get_evaluator({}, mentions={"@another:test", "@user:test"})
self.assertTrue(evaluator.matches(condition, "@user:test", None))

# Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
# since the BulkPushRuleEvaluator is what handles data sanitisation.

def test_room_mentions(self) -> None:
"""Check for room mentions."""
condition = {"kind": "org.matrix.msc3952.is_room_mention"}

# No mentions shouldn't match.
evaluator = self._get_evaluator({})
self.assertFalse(evaluator.matches(condition, None, None))

# An empty set shouldn't match
evaluator = self._get_evaluator({}, mentions=set())
self.assertFalse(evaluator.matches(condition, None, None))

# The @room appearing anywhere in the mentions list should match
evaluator = self._get_evaluator({}, mentions={"@room"})
self.assertTrue(evaluator.matches(condition, None, None))

evaluator = self._get_evaluator({}, mentions={"@another:test", "@room"})
self.assertTrue(evaluator.matches(condition, None, None))

# Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
# since the BulkPushRuleEvaluator is what handles data sanitisation.

def _assert_matches(
self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
Expand Down Expand Up @@ -308,7 +357,7 @@ def test_related_event_match(self) -> None:
},
}
},
{
related_events={
"m.in_reply_to": {
"event_id": "$parent_event_id",
"type": "m.room.message",
Expand Down Expand Up @@ -408,7 +457,7 @@ def test_related_event_match_with_fallback(self) -> None:
},
}
},
{
related_events={
"m.in_reply_to": {
"event_id": "$parent_event_id",
"type": "m.room.message",
Expand Down

0 comments on commit c9d2a3d

Please sign in to comment.