Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add more types to synapse.storage.database. #8127

Merged
merged 3 commits into from
Aug 20, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add a type-var for a return type.
  • Loading branch information
clokep committed Aug 20, 2020
commit 466f7426e54f62a850bd48b0b39c42942afc76b0
17 changes: 11 additions & 6 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
@@ -117,7 +117,7 @@ def make_conn(
#
# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
# that mypy sees the type but the runtime python doesn't.
_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]


class LoggingTransaction:
@@ -161,7 +161,7 @@ def __init__(
self.after_callbacks = after_callbacks
self.exception_callbacks = exception_callbacks

def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
"""Call the given callback on the main twisted thread after the
transaction has finished. Used to invalidate the caches on the
correct thread.
@@ -173,7 +173,7 @@ def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any)
self.after_callbacks.append((callback, args, kwargs))

def call_on_exception(
self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
self, callback: Callable[..., None], *args: Any, **kwargs: Any
):
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -282,6 +282,9 @@ def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
return top_n_counters


R = TypeVar("R")


class DatabasePool(object):
"""Wraps a single physical database and connection pool.

@@ -395,10 +398,10 @@ def new_transaction(
desc: str,
after_callbacks: List[_CallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: Callable,
func: Callable[..., R],
*args: Any,
**kwargs: Any
) -> Any:
) -> R:
start = monotonic_time()
txn_id = self._TXN_ID

@@ -547,7 +550,9 @@ def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):

return result

async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
async def runWithConnection(
self, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
"""Wraps the .runWithConnection() method on the underlying db_pool.

Arguments: