
Fix limit logic for AccountDataStream #7384

Merged: 7 commits, May 15, 2020
1 change: 1 addition & 0 deletions changelog.d/7384.bugfix
@@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.
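For context, here is a toy sketch of the failure mode being fixed (illustrative Python only, not Synapse code; the names and numbers are made up): when two independently-LIMITed queries feed a single stream, the side that hits its limit silently drops rows if the caller still advances its token to the full to_token.

LIMIT = 5
global_rows = [(stream_id, "global") for stream_id in range(1, 20, 2)]  # odd stream ids
room_rows = [(stream_id, "room") for stream_id in range(2, 21, 2)]      # even stream ids

def broken_fetch(from_token: int, to_token: int):
    # each side is truncated at LIMIT independently
    g = [row for row in global_rows if from_token < row[0] <= to_token][:LIMIT]
    r = [row for row in room_rows if from_token < row[0] <= to_token][:LIMIT]
    # BUG: to_token is returned unchanged even though both sides were truncated,
    # so the caller believes it has caught up and never asks for the rest
    return sorted(g + r), to_token

rows, token = broken_fetch(0, 20)
print([row[0] for row in rows])  # [1, 2, ..., 10]: updates 11..20 are never replicated
print(token)                     # 20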
66 changes: 54 additions & 12 deletions synapse/replication/tcp/streams/_base.py
@@ -14,14 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import logging
from collections import namedtuple
from typing import Any, Awaitable, Callable, List, Optional, Tuple
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
List,
Optional,
Tuple,
TypeVar,
)

import attr

from synapse.replication.http.streams import ReplicationGetStreamUpdates

if TYPE_CHECKING:
import synapse.server

logger = logging.getLogger(__name__)

# the number of rows to request from an update_function.
@@ -37,7 +50,7 @@
# parsing with Stream.parse_row (which turns it into a `ROW_TYPE`). Normally it's
# just a row from a database query, though this is dependent on the stream in question.
#
StreamRow = Tuple
StreamRow = TypeVar("StreamRow", bound=Tuple)

# The type returned by the update_function of a stream, as well as get_updates(),
# get_updates_since, etc.
@@ -499,32 +512,61 @@ class AccountDataStream(Stream):
"""

AccountDataStreamRow = namedtuple(
"AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
"AccountDataStream",
("user_id", "room_id", "data_type"), # str # Optional[str] # str
)

NAME = "account_data"
ROW_TYPE = AccountDataStreamRow

def __init__(self, hs):
def __init__(self, hs: "synapse.server.HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_max_account_data_stream_id,
db_query_to_update_function(self._update_function),
self._update_function,
)

async def _update_function(self, from_token, to_token, limit):
global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
async def _update_function(
self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
limited = False
global_results = await self.store.get_updated_global_account_data(
from_token, to_token, limit
)

results = list(room_results)
results.extend(
(stream_id, user_id, None, account_data_type)
# if the global results hit the limit, we'll need to limit the room results to
# the same stream token.
if len(global_results) >= limit:
to_token = global_results[-1][0]
limited = True

room_results = await self.store.get_updated_room_account_data(
from_token, to_token, limit
)

# likewise, if the room results hit the limit, limit the global results to
# the same stream token.
if len(room_results) >= limit:
to_token = room_results[-1][0]
limited = True

# convert the global results to the right format, and limit them to the to_token
# at the same time
global_rows = (
(stream_id, (user_id, None, account_data_type))
for stream_id, user_id, account_data_type in global_results
if stream_id <= to_token
)

room_rows = (
(stream_id, (user_id, room_id, account_data_type))
for stream_id, user_id, room_id, account_data_type in room_results
Member: I know that technically this doesn't need an if stream_id <= token clause, but I think adding one may make it less confusing why we're omitting it (or add a comment).

Member (Author): ok, I added a comment.

)

return results
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(room_rows, global_rows))
return updates, to_token, limited
Member: Any reason not to move some of this handling into the store? I'm mainly thinking it might be more efficient that way, if for no other reason than that we would only need one transaction rather than two.

Member (Author): I dunno, maybe?

It feels like that would be a different approach to that taken by all the other stream sources: it would mean we'd have to mess about with limited inside the store, which normally we don't do.

Another way would be to follow the example of get_all_device_list_changes_for_remotes and do one query with a UNION, but I only found that after I wrote this stuff...
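To make the new control flow concrete, here is a minimal self-contained sketch of the clamp-and-merge pattern _update_function now uses (toy data and a hard-coded limit, not Synapse code; the real code refetches the room rows with the already-clamped token rather than filtering them afterwards):

import heapq

LIMIT = 3
global_results = [(1, "u1", "m.a"), (4, "u1", "m.b"), (6, "u1", "m.c")]
room_results = [(2, "u1", "!r1:x", "m.x"), (3, "u1", "!r1:x", "m.y")]

to_token, limited = 10, False
if len(global_results) >= LIMIT:  # hit the limit: clamp to the last row returned
    to_token, limited = global_results[-1][0], True
if len(room_results) >= LIMIT:
    to_token, limited = min(to_token, room_results[-1][0]), True

# reshape each result set into (stream_id, row) pairs, dropping anything
# past the clamped token
global_rows = (
    (sid, (user, None, dtype))
    for sid, user, dtype in global_results
    if sid <= to_token
)
room_rows = (
    (sid, (user, room, dtype))
    for sid, user, room, dtype in room_results
    if sid <= to_token
)

# heapq.merge assumes each input is already sorted, which holds here
# because both queries ORDER BY stream_id ASC
print(list(heapq.merge(room_rows, global_rows)), to_token, limited)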



class GroupServerStream(Stream):
62 changes: 43 additions & 19 deletions synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@

import abc
import logging
from typing import List, Tuple

from canonicaljson import json

@@ -175,41 +176,64 @@ def get_account_data_for_room_and_type_txn(txn):
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)

def get_all_updated_account_data(
self, last_global_id, last_room_id, current_id, limit
):
"""Get all the client account_data that has changed on the server
async def get_updated_global_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get the global account_data that has changed, for the account_data stream

Args:
last_global_id(int): The position to fetch from for top level data
last_room_id(int): The position to fetch from for per room data
current_id(int): The position to fetch up to.
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return

Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
room_id string, and type string.
A list of tuples of stream_id int, user_id string,
and type string.
"""
if last_room_id == current_id and last_global_id == current_id:
return defer.succeed(([], []))
if last_id == current_id:
return []

def get_updated_account_data_txn(txn):
def get_updated_global_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_global_id, current_id, limit))
global_results = txn.fetchall()
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()

return await self.db.runInteraction(
"get_updated_global_account_data", get_updated_global_account_data_txn
)

async def get_updated_room_account_data(
self, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
"""Get the global account_data that has changed, for the account_data stream

Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return

Returns:
A list of tuples of stream_id int, user_id string,
room_id string and type string.
"""
if last_id == current_id:
return []

def get_updated_room_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
return global_results, room_results
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()

return self.db.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
return await self.db.runInteraction(
"get_updated_room_account_data", get_updated_room_account_data_txn
)

def get_updated_account_data_for_user(self, user_id, stream_id):
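Both store methods rely on the same keyset-paging shape: an exclusive lower bound, an inclusive upper bound, and ORDER BY stream_id with a LIMIT, so a caller can resume from the last stream_id it saw. A self-contained sketch of that pattern (sqlite and a toy table, not the real Synapse schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE account_data (stream_id INTEGER, user_id TEXT, account_data_type TEXT)"
)
conn.executemany(
    "INSERT INTO account_data VALUES (?, ?, ?)",
    [(i, "@user:example.com", "m.type.%i" % i) for i in range(1, 11)],
)

def fetch_batch(last_id: int, current_id: int, limit: int):
    # same WHERE clause as the queries above: rows in (last_id, current_id],
    # paged by LIMIT
    return conn.execute(
        "SELECT stream_id, user_id, account_data_type"
        " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
        " ORDER BY stream_id ASC LIMIT ?",
        (last_id, current_id, limit),
    ).fetchall()

batch = fetch_batch(0, 10, 4)
while batch:
    print([row[0] for row in batch])  # [1, 2, 3, 4] then [5, 6, 7, 8] then [9, 10]
    batch = fetch_batch(batch[-1][0], 10, 4)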
117 changes: 117 additions & 0 deletions tests/replication/tcp/streams/test_account_data.py
@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.replication.tcp.streams._base import (
_STREAM_UPDATE_TARGET_ROW_COUNT,
AccountDataStream,
)

from tests.replication.tcp.streams._base import BaseStreamTestCase


class AccountDataStreamTestCase(BaseStreamTestCase):
def test_update_function_room_account_data_limit(self):
"""Test replication with many room account data updates
"""
store = self.hs.get_datastore()

# generate lots of account data updates
updates = []
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
update = "m.test_type.%i" % (i,)
self.get_success(
store.add_account_data_to_room("test_user", "test_room", update, {})
)
updates.append(update)

# also one global update
self.get_success(store.add_account_data_for_user("test_user", "m.global", {}))

# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()

# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)

# now reconnect to pull the updates
self.reconnect()
self.replicate()

# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows

for t in updates:
(stream_name, token, row) = received_rows.pop(0)
self.assertEqual(stream_name, AccountDataStream.NAME)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, t)
self.assertEqual(row.room_id, "test_room")

(stream_name, token, row) = received_rows.pop(0)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, "m.global")
self.assertIsNone(row.room_id)

self.assertEqual([], received_rows)

def test_update_function_global_account_data_limit(self):
"""Test replication with many global account data updates
"""
store = self.hs.get_datastore()

# generate lots of account data updates
updates = []
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
update = "m.test_type.%i" % (i,)
self.get_success(store.add_account_data_for_user("test_user", update, {}))
updates.append(update)

# also one per-room update
self.get_success(
store.add_account_data_to_room("test_user", "test_room", "m.per_room", {})
)

# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()

# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)

# now reconnect to pull the updates
self.reconnect()
self.replicate()

# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows

for t in updates:
(stream_name, token, row) = received_rows.pop(0)
self.assertEqual(stream_name, AccountDataStream.NAME)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, t)
self.assertIsNone(row.room_id)

(stream_name, token, row) = received_rows.pop(0)
self.assertIsInstance(row, AccountDataStream.AccountDataStreamRow)
self.assertEqual(row.data_type, "m.per_room")
self.assertEqual(row.room_id, "test_room")

self.assertEqual([], received_rows)