Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Additional tests for third-party event rules #8468

Merged
merged 7 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions synapse/events/third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,10 @@ async def check_event_allowed(
prev_state_ids = await context.get_prev_state_ids()

# Retrieve the state events from the database.
state_events = {}
for key, event_id in prev_state_ids.items():
state_events[key] = await self.store.get_event(event_id, allow_none=True)
events = await self.store.get_events(prev_state_ids.values())
state_events = {(ev.type, ev.state_key): ev for ev in events.values()}

ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret
return await self.third_party_rules.check_event_allowed(event, state_events)

async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
Expand Down
52 changes: 39 additions & 13 deletions tests/rest/client/test_third_party_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,49 +12,57 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading

from mock import Mock

from synapse.events import EventBase
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.types import Requester
from synapse.types import Requester, StateMap

from tests import unittest

thread_local = threading.local()


class ThirdPartyRulesTestModule:
def __init__(self, config, *args, **kwargs):
pass
def __init__(self, config, module_api):
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
Comment on lines +30 to +33
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you use thread_local here as we could be running the tests in parallel?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, that was the general thought.


async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
return True

async def check_event_allowed(self, event, context):
if event.type == "foo.bar.forbidden":
return False
else:
return True
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
return True

@staticmethod
def parse_config(config):
return config


def current_rules_module() -> ThirdPartyRulesTestModule:
return thread_local.rules_module


class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]

def make_homeserver(self, reactor, clock):
config = self.default_config()
def default_config(self):
config = super().default_config()
config["third_party_event_rules"] = {
"module": __name__ + ".ThirdPartyRulesTestModule",
"config": {},
}

self.hs = self.setup_test_homeserver(config=config)
return self.hs
return config

def prepare(self, reactor, clock, homeserver):
# Create a user and room to play with during the tests
Expand All @@ -67,6 +75,14 @@ def test_third_party_rules(self):
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
return ev.type != "foo.bar.forbidden"

callback = Mock(spec=[], side_effect=check)
current_rules_module().check_event_allowed = callback

request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.allowed/1" % self.room_id,
Expand All @@ -76,6 +92,16 @@ def test_third_party_rules(self):
self.render(request)
self.assertEquals(channel.result["code"], b"200", channel.result)

callback.assert_called_once()

# there should be various state events in the state arg: do some basic cheks
richvdh marked this conversation as resolved.
Show resolved Hide resolved
state_arg = callback.call_args[0][1]
for k in (("m.room.create", ""), ("m.room.member", self.user_id)):
self.assertIn(k, state_arg)
ev = state_arg[k]
self.assertEqual(ev.type, k[0])
self.assertEqual(ev.state_key, k[1])

request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/1" % self.room_id,
Expand Down