diff --git a/changelog.d/12106.misc b/changelog.d/12106.misc new file mode 100644 index 000000000000..d918e9e3b16d --- /dev/null +++ b/changelog.d/12106.misc @@ -0,0 +1 @@ +Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index a83296a2292b..81320b89721e 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -665,3 +665,22 @@ def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]: return value return DoneAwaitable(value) + + +def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`. + + Args: + deferred: The `Deferred` to protect against cancellation. Must not follow the + Synapse logcontext rules. + + Returns: + A new `Deferred`, which will contain the result of the original `Deferred`, + but will not propagate cancellation through to the original. When cancelled, + the new `Deferred` will fail with a `CancelledError` and will not follow the + Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap + the new `Deferred`. + """ + new_deferred: defer.Deferred[T] = defer.Deferred() + deferred.chainDeferred(new_deferred) + return new_deferred diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index cce8d595fc7e..362014f4cb6f 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -27,6 +27,7 @@ from synapse.util.async_helpers import ( ObservableDeferred, concurrently_execute, + stop_cancellation, timeout_deferred, ) @@ -282,3 +283,47 @@ async def caller(): d2 = ensureDeferred(caller()) d1.callback(0) self.successResultOf(d2) + + +class StopCancellationTests(TestCase): + """Tests for the `stop_cancellation` function.""" + + def test_succeed(self): + """Test that the new `Deferred` receives the result.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Success should propagate through. + deferred.callback("success") + self.assertTrue(wrapper_deferred.called) + self.assertEqual("success", self.successResultOf(wrapper_deferred)) + + def test_failure(self): + """Test that the new `Deferred` receives the `Failure`.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Failure should propagate through. + deferred.errback(ValueError("abc")) + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, ValueError) + self.assertIsNone(deferred.result, "`Failure` was not consumed") + + def test_cancellation(self): + """Test that cancellation of the new `Deferred` leaves the original running.""" + deferred: "Deferred[str]" = Deferred() + wrapper_deferred = stop_cancellation(deferred) + + # Cancel the new `Deferred`. + wrapper_deferred.cancel() + self.assertTrue(wrapper_deferred.called) + self.failureResultOf(wrapper_deferred, CancelledError) + self.assertFalse( + deferred.called, "Original `Deferred` was unexpectedly cancelled." + ) + + # Now make the inner `Deferred` fail. + # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed + # in logs. + deferred.errback(ValueError("abc")) + self.assertIsNone(deferred.result, "`Failure` was not consumed")