From 289ade289c08237bae8f634345dd3be6027d5633 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 28 Jun 2022 14:27:16 +0100 Subject: [PATCH 1/6] Rate limiter: describe leaky bucket --- synapse/api/ratelimiting.py | 38 +++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 54d13026c9e5..1889537a8342 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -27,6 +27,33 @@ class Ratelimiter: """ Ratelimit actions marked by arbitrary keys. + (Note that the source code speaks of "actions" and "burst_count" rather than + "tokens" and a "bucket_size".) + + This is a "leaky bucket as a meter". For each key to be tracked there is a bucket + containing some number 0 <= T <= `burst_count` of tokens corresponding to previously + permitted requests for that key. Each bucket starts empty, and gradually leaks + tokens at a rate of `rate_hz`. + + Upon an incoming request, we must determine: + - the key that this request falls under (which bucket to inspect), and + - the cost C of this request in tokens. + Then, if there is room in the bucket for C tokens (T + C <= `burst_count`), + the request is permitted and C tokens are added to the bucket. + Otherwise the request is denied, and the bucket continues to hold T tokens. + + This means that the limiter enforces an average request frequency of `rate_hz`, + while accumulating a buffer of up to `burst_count` requests which can be consumed + instantaneously. + + The tricky bit is the leaking. We do not want to have a periodic process which + leaks every bucket! Instead, we track + - the time point when the bucket was last completely empty, and + - how many tokens have been added to the bucket since then. + Then for each incoming request, we can calculate how many tokens have leaked + since this time point, and use that to decide if we should accept or reject the + request. 
+ Args: clock: A homeserver clock, for retrieving the current time rate_hz: The long term number of actions that can be performed in a second. @@ -41,12 +68,11 @@ def __init__( self.burst_count = burst_count self.store = store - # A ordered dictionary keeping track of actions, when they were last - # performed and how often. Each entry is a mapping from a key of arbitrary type - # to a tuple representing: - # * How many times an action has occurred since a point in time - # * The point in time - # * The rate_hz of this particular entry. This can vary per request + # An ordered dictionary representing the token buckets tracked by this rate + # limiter. Each entry maps a key of arbitrary type to a tuple representing: + # * The number of tokens added since the bucket was last empty (before leaking), + # * The time point when the bucket was last completely empty, and + # * The rate_hz (leak rate) of this particular bucket. self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() async def can_do_action( From 9e168134c87b4805f0e13f19b898d89fc6dbcee8 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 28 Jun 2022 15:50:18 +0100 Subject: [PATCH 2/6] Rate limiter: Pull out some small methods --- synapse/api/ratelimiting.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 1889537a8342..2bdc90cdb155 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -75,6 +75,29 @@ def __init__( # * The rate_hz (leak rate) of this particular bucket. self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict() + def _get_key( + self, requester: Optional[Requester], key: Optional[Hashable] + ) -> Hashable: + """Use the requester's MXID as a fallback key if no key is provided. + + Pulled out so that `can_do_action` and `record_action` are consistent. 
+ """ + if key is None: + if not requester: + raise ValueError("Must supply at least one of `requester` or `key`") + + key = requester.user.to_string() + return key + + def _get_action_counts( + self, key: Hashable, time_now_s: float + ) -> Tuple[float, float, float]: + """Retrieve the action counts, with a fallback representing an empty bucket. + + Pulled out so that `can_do_action` and `record_action` are consistent. + """ + return self.actions.get(key, (0.0, time_now_s, 0.0)) + async def can_do_action( self, requester: Optional[Requester], @@ -114,11 +137,7 @@ async def can_do_action( * The reactor timestamp for when the action can be performed next. -1 if rate_hz is less than or equal to zero """ - if key is None: - if not requester: - raise ValueError("Must supply at least one of `requester` or `key`") - - key = requester.user.to_string() + key = self._get_key(requester, key) if requester: # Disable rate limiting of users belonging to any AS that is configured @@ -147,7 +166,7 @@ async def can_do_action( self._prune_message_counts(time_now_s) # Check if there is an existing count entry for this key - action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0)) + action_count, time_start, _ = self._get_action_counts(key, time_now_s) # Check whether performing another action is allowed time_delta = time_now_s - time_start From 9d1e80b588ebe769dfdf6360b13874630e57151a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 28 Jun 2022 15:50:32 +0100 Subject: [PATCH 3/6] Rate limiter: Introduce `record_action` --- synapse/api/ratelimiting.py | 31 ++++++++++++++ tests/api/test_ratelimiting.py | 74 ++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 2bdc90cdb155..930e96369d92 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -209,6 +209,37 @@ async def can_do_action( return allowed, time_allowed + def record_action( + self, + 
requester: Optional[Requester], + key: Optional[Hashable] = None, + n_actions: int = 1, + _time_now_s: Optional[float] = None, + ) -> None: + """Record that one or more actions took place, even if they violate the rate limit. + + This is useful for tracking the frequency of events that happen across + federation which we still want to impose local rate limits on. For instance, if + we are alice.com monitoring a particular room, we cannot prevent bob.com + from joining users to that room. However, we can track the number of recent + joins in the room and refuse to serve new joins ourselves if there have been too + many in the room across both homeservers. + + Args: + requester: The requester that is doing the action, if any. + key: An arbitrary key used to classify an action. Defaults to the + requester's user ID. + n_actions: The number of actions to record. The action count is always + incremented by this amount, even if doing so exceeds the rate limit + (the bucket may be overfilled). + _time_now_s: The current time. Optional, defaults to the current time according + to self.clock. Only used by tests. + """ + key = self._get_key(requester, key) + time_now_s = _time_now_s if _time_now_s is not None else self.clock.time() + action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s) + self.actions[key] = (action_count + n_actions, time_start, rate_hz) + def _prune_message_counts(self, time_now_s: float) -> None: """Remove message count entries that have not exceeded their defined rate_hz limit diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 18649c2c05dc..c86f783c5bd4 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -314,3 +314,77 @@ def consume_at(time: float) -> bool: # Check that we get rate limited after using that token. 
self.assertFalse(consume_at(11.1)) + + def test_record_action_which_doesnt_fill_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe two actions, leaving room in the bucket for one more. + limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0) + + # We should be able to take a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertTrue(success) + + # ... but not two. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + def test_record_action_which_fills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe three actions, filling up the bucket. + limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0) + + # We should be unable to take a new action now. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be able to take one action... + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertTrue(success) + + # ... but not two. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + def test_record_action_which_overfills_bucket(self) -> None: + limiter = Ratelimiter( + store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3 + ) + + # Observe four actions, exceeding the bucket. + limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0) + + # We should be prevented from taking a new action now. 
+ success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=0.0) + ) + self.assertFalse(success) + + # If we wait 10 seconds to leak a token, we should be unable to take an action + # because the bucket is still full. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=10.0) + ) + self.assertFalse(success) + + # But after another 10 seconds we leak a second token, giving us room for + # action. + success, _ = self.get_success_or_raise( + limiter.can_do_action(requester=None, key="a", _time_now_s=20.0) + ) + self.assertTrue(success) From 7624950323df03fa327601b002421d9ad5392429 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Tue, 12 Jul 2022 11:53:47 +0100 Subject: [PATCH 4/6] Changelog --- changelog.d/13253.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/13253.misc diff --git a/changelog.d/13253.misc b/changelog.d/13253.misc new file mode 100644 index 000000000000..cba6b9ee0ff0 --- /dev/null +++ b/changelog.d/13253.misc @@ -0,0 +1 @@ +Preparatory work for a per-room rate limiter on joins. From 6882a8c6240348243c51d47a9fe6512316b1ecc2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 13 Jul 2022 19:41:52 +0100 Subject: [PATCH 5/6] Omit "why is this pulled out" comments Co-authored-by: Patrick Cloke --- synapse/api/ratelimiting.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 930e96369d92..dbe074e79000 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -79,8 +79,6 @@ def _get_key( self, requester: Optional[Requester], key: Optional[Hashable] ) -> Hashable: """Use the requester's MXID as a fallback key if no key is provided. - - Pulled out so that `can_do_action` and `record_action` are consistent. 
""" if key is None: if not requester: @@ -93,8 +91,6 @@ def _get_action_counts( self, key: Hashable, time_now_s: float ) -> Tuple[float, float, float]: """Retrieve the action counts, with a fallback representing an empty bucket. - - Pulled out so that `can_do_action` and `record_action` are consistent. """ return self.actions.get(key, (0.0, time_now_s, 0.0)) From ea5dddc3fedd52bf244034965a6462b9b19ae59e Mon Sep 17 00:00:00 2001 From: David Robertson Date: Wed, 13 Jul 2022 19:44:40 +0100 Subject: [PATCH 6/6] Linter script --- synapse/api/ratelimiting.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index dbe074e79000..f43965c1c837 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -78,8 +78,7 @@ def __init__( def _get_key( self, requester: Optional[Requester], key: Optional[Hashable] ) -> Hashable: - """Use the requester's MXID as a fallback key if no key is provided. - """ + """Use the requester's MXID as a fallback key if no key is provided.""" if key is None: if not requester: raise ValueError("Must supply at least one of `requester` or `key`") @@ -90,8 +89,7 @@ def _get_key( def _get_action_counts( self, key: Hashable, time_now_s: float ) -> Tuple[float, float, float]: - """Retrieve the action counts, with a fallback representing an empty bucket. - """ + """Retrieve the action counts, with a fallback representing an empty bucket.""" return self.actions.get(key, (0.0, time_now_s, 0.0)) async def can_do_action(