From 71b2ef56cee3ee59cab586fc57a1578aaa370ddb Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 9 Dec 2021 16:54:28 +0000 Subject: [PATCH 1/4] Add type hints to `synapse/storage/databases/main/account_data.py` --- changelog.d/11546.misc | 1 + mypy.ini | 4 +- .../storage/databases/main/account_data.py | 85 +++++++++++++------ synapse/storage/databases/main/tags.py | 22 ++++- 4 files changed, 86 insertions(+), 26 deletions(-) create mode 100644 changelog.d/11546.misc diff --git a/changelog.d/11546.misc b/changelog.d/11546.misc new file mode 100644 index 000000000000..d451940bf216 --- /dev/null +++ b/changelog.d/11546.misc @@ -0,0 +1 @@ +Add missing type hints to storage classes. diff --git a/mypy.ini b/mypy.ini index 1caf807e8505..4b2ddafd2042 100644 --- a/mypy.ini +++ b/mypy.ini @@ -25,7 +25,6 @@ exclude = (?x) ^( |synapse/storage/databases/__init__.py |synapse/storage/databases/main/__init__.py - |synapse/storage/databases/main/account_data.py |synapse/storage/databases/main/cache.py |synapse/storage/databases/main/devices.py |synapse/storage/databases/main/e2e_room_keys.py @@ -181,6 +180,9 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.databases.main.account_data] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.client_ips] disallow_untyped_defs = True diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index f8bec266ac41..a24a60d9a949 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,15 +14,25 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, cast from synapse.api.constants import AccountDataTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool +from synapse.storage._base import db_to_json +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine -from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator +from synapse.storage.util.id_generators import ( + AbstractStreamIdGenerator, + AbstractStreamIdTracker, + MultiWriterIdGenerator, + StreamIdGenerator, +) from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -34,14 +44,24 @@ logger = logging.getLogger(__name__) -class AccountDataWorkerStore(SQLBaseStore): +class AccountDataWorkerStore(CacheInvalidationWorkerStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ - def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): self._instance_name = hs.get_instance_name() + # `_can_write_to_account_data` indicates whether the current worker is allowed + # to write account data. 
A value of `True` implies that `_account_data_id_gen` + # is an `AbstractStreamIdGenerator` and not just a tracker. + self._account_data_id_gen: AbstractStreamIdTracker + if isinstance(database.engine, PostgresEngine): self._can_write_to_account_data = ( self._instance_name in hs.config.worker.writers.account_data @@ -61,8 +81,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): writers=hs.config.worker.writers.account_data, ) else: - self._can_write_to_account_data = True - # We shouldn't be running in worker mode with SQLite, but its useful # to support it for unit tests. # @@ -71,6 +89,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): # updated over replication. (Multiple writers are not supported for # SQLite). if hs.get_instance_name() in hs.config.worker.writers.account_data: + self._can_write_to_account_data = True self._account_data_id_gen = StreamIdGenerator( db_conn, "room_account_data", @@ -113,7 +132,9 @@ async def get_account_data_for_user( room_id string to per room account_data dicts. """ - def get_account_data_for_user_txn(txn): + def get_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: rows = self.db_pool.simple_select_list_txn( txn, "account_data", @@ -132,7 +153,7 @@ def get_account_data_for_user_txn(txn): ["room_id", "account_data_type", "content"], ) - by_room = {} + by_room: Dict[str, Dict[str, JsonDict]] = {} for row in rows: room_data = by_room.setdefault(row["room_id"], {}) room_data[row["account_data_type"]] = db_to_json(row["content"]) @@ -177,7 +198,9 @@ async def get_account_data_for_room( A dict of the room account_data """ - def get_account_data_for_room_txn(txn): + def get_account_data_for_room_txn( + txn: LoggingTransaction, + ) -> Dict[str, JsonDict]: rows = self.db_pool.simple_select_list_txn( txn, "room_account_data", @@ -207,7 +230,9 @@ async def get_account_data_for_room_and_type( The room account_data for that type, or None if there isn't any set. """ - def get_account_data_for_room_and_type_txn(txn): + def get_account_data_for_room_and_type_txn( + txn: LoggingTransaction, + ) -> Optional[JsonDict]: content_json = self.db_pool.simple_select_one_onecol_txn( txn, table="room_account_data", @@ -243,14 +268,16 @@ async def get_updated_global_account_data( if last_id == current_id: return [] - def get_updated_global_account_data_txn(txn): + def get_updated_global_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: sql = ( "SELECT stream_id, user_id, account_data_type" " FROM account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" ) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_global_account_data", get_updated_global_account_data_txn @@ -273,14 +300,16 @@ async def get_updated_room_account_data( if last_id == current_id: return [] - def get_updated_room_account_data_txn(txn): + def get_updated_room_account_data_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str]]: sql = ( "SELECT stream_id, user_id, room_id, account_data_type" " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?" " ORDER BY stream_id ASC LIMIT ?" 
) txn.execute(sql, (last_id, current_id, limit)) - return txn.fetchall() + return cast(List[Tuple[int, str, str, str]], txn.fetchall()) return await self.db_pool.runInteraction( "get_updated_room_account_data", get_updated_room_account_data_txn @@ -299,7 +328,9 @@ async def get_updated_account_data_for_user( mapping from room_id string to per room account_data dicts. """ - def get_updated_account_data_for_user_txn(txn): + def get_updated_account_data_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: sql = ( "SELECT account_data_type, content FROM account_data" " WHERE user_id = ? AND stream_id > ?" @@ -316,7 +347,7 @@ def get_updated_account_data_for_user_txn(txn): txn.execute(sql, (user_id, stream_id)) - account_data_by_room = {} + account_data_by_room: Dict[str, Dict[str, JsonDict]] = {} for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = db_to_json(row[2]) @@ -353,12 +384,15 @@ async def ignored_by(self, user_id: str) -> Set[str]: ) ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: if stream_name == TagAccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) elif stream_name == AccountDataStream.NAME: self._account_data_id_gen.advance(instance_name, token) for row in rows: @@ -372,7 +406,8 @@ def process_replication_rows(self, stream_name, instance_name, token, rows): (row.user_id, row.room_id, row.data_type) ) self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + + super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -389,6 +424,7 @@ async def add_account_data_to_room( The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -431,6 +467,7 @@ async def add_account_data_for_user( The maximum stream ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( @@ -452,7 +489,7 @@ async def add_account_data_for_user( def _add_account_data_for_user( self, - txn, + txn: LoggingTransaction, next_id: int, user_id: str, account_data_type: str, diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 8f510de53d43..c8e508a910fb 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,11 +15,13 @@ # limitations under the License. 
import logging -from typing import Dict, List, Tuple, cast +from typing import Any, Dict, Iterable, List, Tuple, cast +from synapse.replication.tcp.streams import TagAccountDataStream from synapse.storage._base import db_to_json from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.storage.util.id_generators import AbstractStreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -204,6 +206,7 @@ async def add_tag_to_room( The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) content_json = json_encoder.encode(content) @@ -230,6 +233,7 @@ async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> in The next account data ID. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) def remove_tag_txn(txn: LoggingTransaction, next_id: int) -> None: sql = ( @@ -258,6 +262,7 @@ def _update_revision_txn( next_id: The the revision to advance to. """ assert self._can_write_to_account_data + assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -287,6 +292,21 @@ def _update_revision_txn( # than the id that the client has. pass + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + + super().process_replication_rows(stream_name, instance_name, token, rows) + class TagsStore(TagsWorkerStore): pass From 002cb036a29db2011b72700d1090c6c8d799c02d Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Dec 2021 14:43:32 +0000 Subject: [PATCH 2/4] Remove `_instance_name` redefinition and use it consistently --- synapse/storage/databases/main/account_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index a24a60d9a949..54b7d9553e44 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -55,8 +55,6 @@ def __init__( db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): - self._instance_name = hs.get_instance_name() - # `_can_write_to_account_data` indicates whether the current worker is allowed # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. @@ -88,7 +86,7 @@ def __init__( # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets # updated over replication. (Multiple writers are not supported for # SQLite). 
- if hs.get_instance_name() in hs.config.worker.writers.account_data: + if self._instance_name in hs.config.worker.writers.account_data: self._can_write_to_account_data = True self._account_data_id_gen = StreamIdGenerator( db_conn, From 91a540df5052d4acb506430d3fa4f47ae082f608 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Dec 2021 14:47:05 +0000 Subject: [PATCH 3/4] Remove inaccurate docstring --- synapse/storage/databases/main/account_data.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 54b7d9553e44..36042470b4cd 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -45,10 +45,6 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore): - """This is an abstract base class where subclasses must implement - `get_max_account_data_stream_id` which can be called in the initializer. - """ - def __init__( self, database: DatabasePool, From 872f9f5ea2ddc2449cbdee29118c79169fd4f793 Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Fri, 10 Dec 2021 14:55:39 +0000 Subject: [PATCH 4/4] Call `super().__init__()` sooner --- synapse/storage/databases/main/account_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 36042470b4cd..32a553fdd7bd 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -51,6 +51,8 @@ def __init__( db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): + super().__init__(database, db_conn, hs) + # `_can_write_to_account_data` indicates whether the current worker is allowed # to write account data. A value of `True` implies that `_account_data_id_gen` # is an `AbstractStreamIdGenerator` and not just a tracker. @@ -103,8 +105,6 @@ def __init__( "AccountDataAndTagsChangeCache", account_max ) - super().__init__(database, db_conn, hs) - def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream
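
A note on the typing pattern this series leans on: `_account_data_id_gen` is declared as an `AbstractStreamIdTracker` (the read-only interface), and each write path narrows it with an `isinstance` assertion before using `get_next()`, which only `AbstractStreamIdGenerator` provides. A minimal sketch of the idea, using simplified stand-in classes rather than the real Synapse ones (in particular, the real `get_next()` is an async context manager, not a plain method):

    class StreamIdTracker:
        """Read-only view of a monotonic ID stream."""

        def __init__(self) -> None:
            self._current = 0

        def advance(self, token: int) -> None:
            # Follower workers learn about new tokens over replication.
            self._current = max(self._current, token)

        def get_current_token(self) -> int:
            return self._current


    class StreamIdGenerator(StreamIdTracker):
        """A tracker that can also mint new IDs, for writer workers."""

        def get_next(self) -> int:
            self._current += 1
            return self._current


    class Store:
        def __init__(self, is_writer: bool) -> None:
            self._can_write_to_account_data = is_writer
            # Declared as the base (tracker) type; mypy therefore rejects
            # any call to get_next() without narrowing first.
            self._id_gen: StreamIdTracker
            if is_writer:
                self._id_gen = StreamIdGenerator()
            else:
                self._id_gen = StreamIdTracker()

        def add_account_data(self) -> int:
            assert self._can_write_to_account_data
            # Only writers reach this point, so the assertion both enforces
            # the invariant at runtime and narrows the type for mypy.
            assert isinstance(self._id_gen, StreamIdGenerator)
            return self._id_gen.get_next()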
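
The `cast(List[Tuple[int, str, str]], txn.fetchall())` changes follow the usual approach for typing DB-API results: `fetchall()` is typed as returning rows of `Any`, so the caller asserts the row shape that its own SELECT guarantees. A sketch of the same pattern against plain `sqlite3` (the real code goes through Synapse's `LoggingTransaction` wrapper; the query mirrors the one in the patch):

    import sqlite3
    from typing import List, Tuple, cast

    def get_updated_rows(
        conn: sqlite3.Connection, last_id: int, limit: int
    ) -> List[Tuple[int, str, str]]:
        cur = conn.execute(
            "SELECT stream_id, user_id, account_data_type"
            " FROM account_data WHERE stream_id > ?"
            " ORDER BY stream_id ASC LIMIT ?",
            (last_id, limit),
        )
        # fetchall() is untyped beyond a list of Any; the SELECT above fixes
        # the row shape, so the cast records what we already know. It is a
        # no-op at runtime and performs no validation.
        return cast(List[Tuple[int, str, str]], cur.fetchall())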
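
Patches 2 and 4 interact: once the redundant `self._instance_name = hs.get_instance_name()` assignment is gone, `__init__` reads an attribute that (presumably) only the superclass chain sets, so `super().__init__()` has to run before anything touches it. A stand-alone sketch of why the ordering matters, with hypothetical stand-in classes:

    class Base:
        def __init__(self, name: str) -> None:
            self._instance_name = name


    class BrokenStore(Base):
        def __init__(self, name: str) -> None:
            # Raises AttributeError: _instance_name is read before
            # Base.__init__ has set it.
            self._can_write = self._instance_name == "master"
            super().__init__(name)


    class FixedStore(Base):
        def __init__(self, name: str) -> None:
            super().__init__(name)  # initialise inherited state first
            self._can_write = self._instance_name == "master"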