Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Fix remaining mypy issues due to Twisted upgrade. #9608

Merged
merged 7 commits into from
Mar 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/9608.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix incorrect type hints.
2 changes: 1 addition & 1 deletion stubs/txredisapi.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union

from twisted.internet import protocol

class RedisProtocol:
class RedisProtocol(protocol.Protocol):
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
Expand Down
12 changes: 10 additions & 2 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
ITCPTransport,
)
from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
Expand Down Expand Up @@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception):
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""

transport = None # type: Optional[ITCPTransport]

def __init__(self, deferred: defer.Deferred):
self.deferred = deferred

Expand All @@ -771,18 +775,21 @@ def _maybe_fail(self):
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()

def dataReceived(self, data: bytes) -> None:
self._maybe_fail()

def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
self._maybe_fail()


class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""

transport = None # type: Optional[ITCPTransport]

def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
Expand All @@ -805,9 +812,10 @@ def dataReceived(self, data: bytes) -> None:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()

def connectionLost(self, reason: Failure) -> None:
def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
Expand Down
4 changes: 2 additions & 2 deletions synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def start_replication(self, hs):
hs, outbound_redis_connection
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,
hs.config.redis.redis_host.encode(),
hs.config.redis.redis_port,
self._factory,
)
Expand All @@ -311,7 +311,7 @@ def start_replication(self, hs):
self._factory = DirectTcpReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self._factory)
hs.get_reactor().connectTCP(host.encode(), port, self._factory)

def get_streams(self) -> Dict[str, Stream]:
"""Get a map from stream name to all streams."""
Expand Down
9 changes: 9 additions & 0 deletions synapse/replication/tcp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from zope.interface import Interface, implementer

from twisted.internet import task
from twisted.internet.tcp import Connection
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure

Expand Down Expand Up @@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
(if they send a `PING` command)
"""

# The transport is going to be an ITCPTransport, but that doesn't have the
# (un)registerProducer methods, those are only on the implementation.
transport = None # type: Connection

delimiter = b"\n"

# Valid commands we expect to receive
Expand Down Expand Up @@ -189,6 +194,7 @@ def connectionMade(self):

connected_connections.append(self) # Register connection for metrics

assert self.transport is not None
self.transport.registerProducer(self, True) # For the *Producing callbacks

self._send_pending_commands()
Expand All @@ -213,6 +219,7 @@ def send_ping(self):
logger.info(
"[%s] Failed to close connection gracefully, aborting", self.id()
)
assert self.transport is not None
self.transport.abortConnection()
else:
if now - self.last_sent_command >= PING_TIME:
Expand Down Expand Up @@ -302,6 +309,7 @@ def handle_command(self, cmd: Command) -> None:
def close(self):
logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
assert self.transport is not None
self.transport.loseConnection()
self.on_connection_closed()

Expand Down Expand Up @@ -399,6 +407,7 @@ def stopProducing(self):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
assert reason.type is not None
connection_close_counter.labels(reason.type.__name__).inc()
else:
connection_close_counter.labels(reason.__class__.__name__).inc()
Expand Down
2 changes: 1 addition & 1 deletion synapse/replication/tcp/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,6 @@ def lazyConnection(
factory.continueTrying = reconnect

reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None)

return factory.handler
44 changes: 16 additions & 28 deletions tests/replication/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple

import attr
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
Expand Down Expand Up @@ -158,10 +156,8 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)

request_factory = OneShotRequestFactory()

# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
Expand All @@ -183,7 +179,7 @@ def handle_http_replication_attempt(self) -> SynapseRequest:
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()

return request_factory.request
return channel.request

def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
Expand Down Expand Up @@ -237,7 +233,7 @@ def setUp(self):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
b"localhost",
6379,
self.connect_any_redis_attempts,
)
Expand Down Expand Up @@ -392,10 +388,8 @@ def _handle_http_replication_attempt(self, hs, repl_port):
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)

request_factory = OneShotRequestFactory()

# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])

# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
Expand All @@ -421,7 +415,7 @@ def connect_any_redis_attempts(self):
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(host, b"localhost")
self.assertEqual(port, 6379)

client_protocol = client_factory.buildProtocol(None)
Expand Down Expand Up @@ -453,21 +447,6 @@ async def on_rdata(self, stream_name, instance_name, token, rows):
self.received_rdata_rows.append((stream_name, token, r))


@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""

request = attr.ib(default=None)

def __call__(self, *args, **kwargs):
assert self.request is None

self.request = SynapseRequest(*args, **kwargs)
return self.request


class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.

Expand All @@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel):
"""

def __init__(
self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
self, reactor: IReactorTime, request_factory: Type[Request], site: Site
):
super().__init__()
self.reactor = reactor
Expand Down Expand Up @@ -510,6 +489,11 @@ def checkPersistence(self, request, version):
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False

def requestDone(self, request):
# Store the request for inspection.
self.request = request
super().requestDone(request)


class _PullToPushProducer:
"""A push producer that wraps a pull producer."""
Expand Down Expand Up @@ -597,6 +581,8 @@ def buildProtocol(self, addr):
class FakeRedisPubSubProtocol(Protocol):
"""A connection from a client talking to the fake Redis server."""

transport = None # type: Optional[FakeTransport]

def __init__(self, server: FakeRedisPubSubServer):
self._server = server
self._reader = hiredis.Reader()
Expand Down Expand Up @@ -641,6 +627,8 @@ def handle_command(self, command, *args):

def send(self, msg):
"""Send a message back to the client."""
assert self.transport is not None

raw = self.encode(msg).encode("utf-8")

self.transport.write(raw)
Expand Down
2 changes: 2 additions & 0 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
IReactorPluggableNameResolver,
IReactorTCP,
IResolverSimple,
ITransport,
)
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
Expand Down Expand Up @@ -467,6 +468,7 @@ def get_clock():
return clock, hs_clock


@implementer(ITransport)
@attr.s(cmp=False)
class FakeTransport:
"""
Expand Down