Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Type hint the constructors of the data store classes #11555

Merged
merged 3 commits into from
Dec 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11555.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add missing type hints to storage classes.
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Optional

from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
Expand All @@ -27,7 +27,12 @@


class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[
Expand Down
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache

Expand All @@ -25,7 +25,12 @@


class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.client_ip_last_seen: LruCache[tuple, int] = LruCache(
Expand Down
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
Expand All @@ -27,7 +27,12 @@


class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.hs = hs
Expand Down
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
Expand Down Expand Up @@ -58,7 +58,12 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

events_max = self._stream_id_gen.get_current_token()
Expand Down
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.filtering import FilteringStore

from ._base import BaseSlavedStore
Expand All @@ -24,7 +24,12 @@


class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

# Filters are immutable so this cache doesn't need to be expired
Expand Down
9 changes: 7 additions & 2 deletions synapse/replication/slave/storage/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache

Expand All @@ -26,7 +26,12 @@


class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.hs = hs
Expand Down
13 changes: 8 additions & 5 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,8 @@
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union

from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder

Expand All @@ -38,7 +36,12 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""

def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def commit(self) -> None:
def rollback(self) -> None:
self.conn.rollback()

def __enter__(self) -> "Connection":
def __enter__(self) -> "LoggingDatabaseConnection":
Copy link
Contributor Author

@squahtx squahtx Dec 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return value here makes its way into the data store constructor:

with make_conn(database_config, engine, "startup") as db_conn:

main = main_store_class(database, db_conn, hs)

The constructor requires a LoggingDatabaseConnection because StreamIdGenerator constructors use the txn_name argument of cursor():

def _load_current_id(
db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
) -> int:
cur = db_conn.cursor(txn_name="_load_current_id")

def cursor(
self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
) -> "LoggingTransaction":

self.conn.__enter__()
return self

Expand Down
9 changes: 7 additions & 2 deletions synapse/storage/databases/main/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import TYPE_CHECKING, List, Optional, Tuple

from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
Expand Down Expand Up @@ -129,7 +129,12 @@ class DataStore(
LockStore,
SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
Expand Down
10 changes: 7 additions & 3 deletions synapse/storage/databases/main/appservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
from synapse.config.appservice import load_appservices
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.types import Connection
from synapse.types import JsonDict
from synapse.util import json_encoder

Expand Down Expand Up @@ -58,7 +57,12 @@ def _make_exclusive_regex(


class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
self.services_cache = load_appservices(
hs.hostname, hs.config.appservice.app_service_config_files
)
Expand Down
9 changes: 7 additions & 2 deletions synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
EventsStreamEventRow,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter

Expand All @@ -41,7 +41,12 @@


class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()
Expand Down
13 changes: 11 additions & 2 deletions synapse/storage/databases/main/censor_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util import json_encoder
Expand All @@ -31,7 +35,12 @@


class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

if (
Expand Down
22 changes: 18 additions & 4 deletions synapse/storage/databases/main/client_ips.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
make_tuple_comparison_clause,
)
from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
from synapse.storage.types import Connection
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache

Expand Down Expand Up @@ -65,7 +64,12 @@ class LastConnectionInfo(TypedDict):


class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(
Expand Down Expand Up @@ -394,7 +398,12 @@ def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:


class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.user_ips_max_age = hs.config.server.user_ips_max_age
Expand Down Expand Up @@ -532,7 +541,12 @@ def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:


class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):

# (user_id, access_token, ip,) -> last_seen
self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
Expand Down
7 changes: 6 additions & 1 deletion synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,12 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
REMOVE_HIDDEN_DEVICES = "remove_hidden_devices_from_device_inbox"
REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"

def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(
Expand Down
22 changes: 19 additions & 3 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_tuple_comparison_clause,
)
Expand All @@ -61,7 +62,12 @@


class DeviceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

if hs.config.worker.run_background_tasks:
Expand Down Expand Up @@ -953,7 +959,12 @@ def _prune_txn(txn):


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(
Expand Down Expand Up @@ -1085,7 +1096,12 @@ def _txn(txn):


class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)

# Map of (user_id, device_id) -> bool. If there is an entry that implies
Expand Down
Loading