From 8535c32cea6f43f9c7c9a9154881be8f6809a154 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 23 Mar 2020 11:32:02 +0000 Subject: [PATCH] Add redis support --- mypy.ini | 3 + synapse/app/homeserver.py | 7 ++ synapse/config/homeserver.py | 2 + synapse/config/redis.py | 47 +++++++++ synapse/python_dependencies.py | 1 + synapse/replication/tcp/handler.py | 18 +++- synapse/replication/tcp/redis.py | 161 +++++++++++++++++++++++++++++ 7 files changed, 234 insertions(+), 5 deletions(-) create mode 100644 synapse/config/redis.py create mode 100644 synapse/replication/tcp/redis.py diff --git a/mypy.ini b/mypy.ini index 69be2f67adf2..d6197d607d2a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -75,3 +75,6 @@ ignore_missing_imports = True [mypy-jwt.*] ignore_missing_imports = True + +[mypy-txredisapi] +ignore_missing_imports = True diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f2b56a636f52..2e477e33dc93 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -263,6 +263,12 @@ def _configure_named_resource(self, name, compress=False): def start_listening(self, listeners): config = self.get_config() + if config.redis_enabled: + # If redis is enabled we connect via the replication command handler + # in the same way as the workers (since we're effectively a client + # rather than a server). 
+ self.get_tcp_replication().start_replication(self) + for listener in listeners: if listener["type"] == "http": self._listening_services.extend(self._listener_http(config, listener)) @@ -282,6 +288,7 @@ def start_listening(self, listeners): ) for s in services: reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) + elif listener["type"] == "metrics": if not self.get_config().enable_metrics: logger.warning( diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b4bca08b20aa..be6c6afa7469 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -31,6 +31,7 @@ from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig +from .redis import RedisConfig from .registration import RegistrationConfig from .repository import ContentRepositoryConfig from .room_directory import RoomDirectoryConfig @@ -82,4 +83,5 @@ class HomeServerConfig(RootConfig): RoomDirectoryConfig, ThirdPartyRulesConfig, TracerConfig, + RedisConfig, ] diff --git a/synapse/config/redis.py b/synapse/config/redis.py new file mode 100644 index 000000000000..7355ebe124b3 --- /dev/null +++ b/synapse/config/redis.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from synapse.config._base import Config, ConfigError

try:
    import txredisapi
except ImportError:
    # Redis support is optional; the missing library is only reported if the
    # config actually enables redis (see RedisConfig.read_config).
    txredisapi = None


MISSING_REDIS = """Missing 'txredisapi' library. This is required for redis support.

    Install by running:
        pip install txredisapi
"""


class RedisConfig(Config):
    """Config section controlling the optional redis replication transport.

    After `read_config` has run, `redis_enabled` is always set. The
    connection settings (`redis_host`, `redis_port`, `redis_dbid`,
    `redis_password`) are only set when redis is enabled.
    """

    section = "redis"

    def read_config(self, config, **kwargs):
        """Read the `redis` section out of the parsed config.

        Args:
            config (dict): the parsed homeserver config.

        Raises:
            ConfigError: if redis is enabled but `txredisapi` is not installed.
        """
        # A bare `redis:` key (no value) is loaded as None rather than as a
        # missing key, so `config.get("redis", {})` would hand back None and
        # the `.get` calls below would blow up. Treat it like an empty section.
        redis_config = config.get("redis") or {}
        self.redis_enabled = redis_config.get("enabled", False)

        if not self.redis_enabled:
            return

        if txredisapi is None:
            raise ConfigError(MISSING_REDIS)

        self.redis_host = redis_config.get("host", "localhost")
        self.redis_port = redis_config.get("port", 6379)
        self.redis_dbid = redis_config.get("dbid")
        self.redis_password = redis_config.get("password")
""" - client_name = hs.config.worker_name - self.factory = ReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self.factory) + if hs.config.redis.redis_enabled: + from synapse.replication.tcp.redis import RedisFactory + + logger.info("Connecting to redis.") + self.factory = RedisFactory(hs) + hs.get_reactor().connectTCP( + hs.config.redis.redis_host, hs.config.redis.redis_port, self.factory, + ) + else: + self.factory = ReplicationClientFactory(hs, hs.config.worker_name, self) + host = hs.config.worker_replication_host + port = hs.config.worker_replication_port + hs.get_reactor().connectTCP(host, port, self.factory) async def on_REPLICATE(self, cmd: ReplicateCommand): # We only want to announce positions by the writer of the streams. diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py new file mode 100644 index 000000000000..2a02fe7f3410 --- /dev/null +++ b/synapse/replication/tcp/redis.py @@ -0,0 +1,161 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import logging

import txredisapi

from synapse.logging.context import PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
    COMMAND_MAP,
    Command,
    RdataCommand,
    ReplicateCommand,
)
from synapse.util.stringutils import random_string

logger = logging.getLogger(__name__)


class RedisSubscriber(txredisapi.SubscriberProtocol):
    """Connection to redis subscribed to replication stream.

    Instances are built by `RedisFactory.buildProtocol`, which sets the
    following attributes before the connection is made:

        handler: the replication command handler that commands are passed to.
        redis_connection: the connection used to publish outgoing commands.
        conn_id (str): identifier used in log lines.
        stream_name (str): the redis channel to subscribe and publish to.
    """

    def connectionMade(self):
        # NOTE(review): this override does not call the base class
        # `connectionMade`; confirm txredisapi does not rely on it (e.g. for
        # sending the password) before depending on authenticated connections.
        logger.info("Connected to redis instance")
        self.subscribe(self.stream_name)
        self.send_command(ReplicateCommand())

        self.handler.new_connection(self)

    def messageReceived(self, pattern: str, channel: str, message: str):
        """Received a message from redis.

        Parses the message as a replication command and hands it to
        `handle_command` as a background process. Blank, malformed or unknown
        lines are logged and dropped rather than raised.
        """

        if message.strip() == "":
            # Ignore blank lines
            return

        # `partition` (rather than `split` + unpacking) so that a line with no
        # space in it does not raise before we get a chance to log it.
        cmd_name, _, rest_of_line = message.partition(" ")

        try:
            # The lookup is inside the `try` so that an unknown command name
            # is reported in the same way as an unparseable one.
            cmd = COMMAND_MAP[cmd_name].from_line(rest_of_line)
        except Exception:
            # This class defines neither `id()` nor `send_error()`, so use the
            # connection id set by the factory and simply log the failure.
            logger.exception(
                "[%s] failed to parse line %r: %r",
                self.conn_id,
                cmd_name,
                rest_of_line,
            )
            return

        # Now let's try and call the on_<COMMAND> function
        run_as_background_process(
            "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
        )

    async def handle_command(self, cmd: Command):
        """Handle a command we have received over the replication stream.

        Delegates to `on_<COMMAND>` (where `<COMMAND>` is `cmd.NAME`), first on
        this instance and then on the handler. Each such method should return
        an awaitable. A warning is logged if neither defines one.

        Args:
            cmd: received command
        """
        handled = False

        # First call any command handlers on this instance. These are for redis
        # specific handling.
        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
        if cmd_func:
            await cmd_func(cmd)
            handled = True

        # Then call out to the handler.
        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
        if cmd_func:
            await cmd_func(cmd)
            handled = True

        if not handled:
            logger.warning("Unhandled command: %r", cmd)

    def connectionLost(self, reason):
        logger.info("Lost connection to redis instance")
        self.handler.lost_connection(self)

    def send_command(self, cmd):
        """Publish a command to the replication channel.

        The publish itself happens in a background process, so this returns
        before the command has been sent.

        Args:
            cmd (Command)

        Raises:
            Exception: if the serialised command contains a newline.
        """
        string = "%s %s" % (cmd.NAME, cmd.to_line())
        if "\n" in string:
            # Interpolate explicitly: exceptions do not apply %-formatting to
            # their arguments the way logging calls do.
            raise Exception("Unexpected newline in command: %r" % (string,))

        encoded_string = string.encode("utf-8")

        async def _send():
            with PreserveLoggingContext():
                await self.redis_connection.publish(self.stream_name, encoded_string)

        run_as_background_process("send-cmd", _send)

    def stream_update(self, stream_name, token, data):
        """Called when a new update is available to stream to clients.

        The update is published as an RDATA command to everyone on the
        channel; no per-client filtering is done here.
        """
        self.send_command(RdataCommand(stream_name, token, data))


class RedisFactory(txredisapi.SubscriberFactory):
    """This is a reconnecting factory that connects to redis and immediately
    subscribes to a stream.
    """

    maxDelay = 5
    continueTrying = True
    protocol = RedisSubscriber

    def __init__(self, hs):
        super().__init__()

        # Read everything via the `redis` config section, matching how the
        # host/port are looked up when the factory is connected.
        redis_config = hs.config.redis

        self.password = redis_config.redis_password

        self.handler = hs.get_tcp_replication()
        self.stream_name = hs.hostname

        # Connection handed to each protocol instance and used by
        # `RedisSubscriber.send_command` to publish outgoing commands.
        self.redis_connection = txredisapi.lazyConnection(
            host=redis_config.redis_host,
            port=redis_config.redis_port,
            dbid=redis_config.redis_dbid,
            password=redis_config.redis_password,
            reconnect=True,
        )

        self.conn_id = random_string(5)

    def buildProtocol(self, addr):
        """Build a `RedisSubscriber` and attach the state it needs."""
        p = super().buildProtocol(addr)
        p.handler = self.handler
        p.redis_connection = self.redis_connection
        p.conn_id = self.conn_id
        p.stream_name = self.stream_name
        return p