Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add type hints to synapse/storage/databases/main/events_worker.py #11411

Merged
merged 15 commits into from
Nov 26, 2021
Merged
1 change: 1 addition & 0 deletions changelog.d/11411.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to storage classes.
4 changes: 3 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ exclude = (?x)
|synapse/storage/databases/main/event_federation.py
|synapse/storage/databases/main/event_push_actions.py
|synapse/storage/databases/main/events_bg_updates.py
|synapse/storage/databases/main/events_worker.py
|synapse/storage/databases/main/group_server.py
|synapse/storage/databases/main/metrics.py
|synapse/storage/databases/main/monthly_active_users.py
Expand Down Expand Up @@ -181,6 +180,9 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.directory]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.events_worker]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.room_batch]
disallow_untyped_defs = True

Expand Down
14 changes: 2 additions & 12 deletions synapse/replication/slave/storage/_slaved_id_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from typing import List, Optional, Tuple

from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.util.id_generators import _load_current_id
from synapse.storage.util.id_generators import AbstractStreamIdTracker, _load_current_id


class SlavedIdTracker:
class SlavedIdTracker(AbstractStreamIdTracker):
def __init__(
self,
db_conn: LoggingDatabaseConnection,
Expand All @@ -36,17 +36,7 @@ def advance(self, instance_name: Optional[str], new_id: int):
self._current = (max if self.step > 0 else min)(self._current, new_id)

def get_current_token(self) -> int:
"""

Returns:
int
"""
return self._current

def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.

For streams with single writers this is equivalent to
`get_current_token`.
"""
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
return self.get_current_token()
4 changes: 0 additions & 4 deletions synapse/replication/slave/storage/push_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore

Expand All @@ -25,9 +24,6 @@ def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()

def process_replication_rows(self, stream_name, instance_name, token, rows):
# We assert this for the benefit of mypy
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)

if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
Expand Down
6 changes: 3 additions & 3 deletions synapse/replication/tcp/streams/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
import heapq
from collections.abc import Iterable
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Optional, Tuple, Type

import attr

Expand Down Expand Up @@ -157,7 +157,7 @@ async def _update_function(

# now we fetch up to that many rows from the events table

event_rows: List[Tuple] = await self._store.get_all_new_forward_event_rows(
event_rows = await self._store.get_all_new_forward_event_rows(
instance_name, from_token, current_token, target_row_count
)

Expand Down Expand Up @@ -191,7 +191,7 @@ async def _update_function(
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.

ex_outliers_rows: List[Tuple] = await self._store.get_ex_outlier_stream_rows(
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
instance_name, from_token, upper_limit
)

Expand Down
2 changes: 1 addition & 1 deletion synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,7 +764,7 @@ class StateResolutionStore:
store: "DataStore"

def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
self, event_ids: Collection[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database

Expand Down
3 changes: 2 additions & 1 deletion synapse/state/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import (
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Expand Down Expand Up @@ -44,7 +45,7 @@ async def resolve_events_with_store(
room_version: RoomVersion,
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
state_map_factory: Callable[[Collection[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.types import StreamToken, get_domain_from_id
from synapse.types import get_domain_from_id
from synapse.util import json_decoder

if TYPE_CHECKING:
Expand All @@ -48,7 +48,7 @@ def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: StreamToken,
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved
token: int,
rows: Iterable[Any],
) -> None:
pass
Expand Down
29 changes: 17 additions & 12 deletions synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.
import itertools
import logging
from collections import OrderedDict, namedtuple
from collections import OrderedDict
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -41,9 +41,10 @@
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
Expand All @@ -64,9 +65,6 @@
)


_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))


@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
Expand Down Expand Up @@ -108,16 +106,21 @@ def __init__(
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id

# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen: MultiWriterIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: MultiWriterIdGenerator = self.store._stream_id_gen

# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"

# Since we have been configured to write, we ought to have id generators,
# rather than id trackers.
assert isinstance(self.store._backfill_id_gen, AbstractStreamIdGenerator)
assert isinstance(self.store._stream_id_gen, AbstractStreamIdGenerator)

# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen

async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
Expand Down Expand Up @@ -1553,11 +1556,13 @@ def _add_to_cache(self, txn, events_and_contexts):
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
to_prefill.append(_EventCacheEntry(event=event, redacted_event=None))
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))

def prefill():
for cache_entry in to_prefill:
self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
self.store._get_event_cache.set(
(cache_entry.event.event_id,), cache_entry
)

txn.call_after(prefill)

Expand Down
Loading