Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Generics for ObservableDeferred
Browse files Browse the repository at this point in the history
Now that `Deferred` is a generic class, let's update `ObeservableDeferred` to
follow suit.
  • Loading branch information
richvdh committed Jul 28, 2021
1 parent b653472 commit ac016ff
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
1 change: 1 addition & 0 deletions changelog.d/10491.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type annotations for `ObservableDeferred`.
5 changes: 3 additions & 2 deletions synapse/notifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def __init__(
self.last_notified_token = current_token
self.last_notified_ms = time_now_ms

with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
self.notify_deferred: ObservableDeferred[StreamToken] = ObservableDeferred(
defer.Deferred()
)

def notify(
self,
Expand Down
4 changes: 3 additions & 1 deletion synapse/storage/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ async def add_to_queue(
end_item = queue[-1]
else:
# need to make a new queue item
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
deferred: ObservableDeferred[_PersistResult] = ObservableDeferred(
defer.Deferred(), consumeErrors=True
)

end_item = _EventPersistQueueItem(
events_and_contexts=[],
Expand Down
13 changes: 7 additions & 6 deletions synapse/util/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Awaitable,
Callable,
Dict,
Generic,
Hashable,
Iterable,
List,
Expand Down Expand Up @@ -52,7 +53,7 @@
_T = TypeVar("_T")


class ObservableDeferred:
class ObservableDeferred(Generic[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These
observer deferreds do not affect the callback chain of the original
deferred.
Expand All @@ -70,7 +71,7 @@ class ObservableDeferred:

__slots__ = ["_deferred", "_observers", "_result"]

def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
def __init__(self, deferred: defer.Deferred[_T], consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
Expand Down Expand Up @@ -115,15 +116,15 @@ def errback(f):

deferred.addCallbacks(callback, errback)

def observe(self) -> defer.Deferred:
def observe(self) -> defer.Deferred[_T]:
"""Observe the underlying deferred.
This returns a brand new deferred that is resolved when the underlying
deferred is resolved. Interacting with the returned deferred does not
effect the underlying deferred.
"""
if not self._result:
d: defer.Deferred[Any] = defer.Deferred()
d: defer.Deferred[_T] = defer.Deferred()

def remove(r):
self._observers.discard(d)
Expand All @@ -137,7 +138,7 @@ def remove(r):
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)

def observers(self) -> List[defer.Deferred]:
def observers(self) -> List[defer.Deferred[_T]]:
return self._observers

def has_called(self) -> bool:
Expand All @@ -146,7 +147,7 @@ def has_called(self) -> bool:
def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True

def get_result(self) -> Any:
def get_result(self) -> _T:
return self._result[1]

def __getattr__(self, name: str) -> Any:
Expand Down

0 comments on commit ac016ff

Please sign in to comment.