-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit tests for the
connect
method of all Redis connection class…
…es (#2631) * tests: move certificate discovery to a separate module * tests: add 'connect' tests for all Redis connection classes --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com>
- Loading branch information
1 parent
2bb7f10
commit 53bed27
Showing
5 changed files
with
350 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import os | ||
|
||
|
||
def get_ssl_filename(name): | ||
root = os.path.join(os.path.dirname(__file__), "..") | ||
cert_dir = os.path.abspath(os.path.join(root, "docker", "stunnel", "keys")) | ||
if not os.path.isdir(cert_dir): # github actions package validation case | ||
cert_dir = os.path.abspath( | ||
os.path.join(root, "..", "docker", "stunnel", "keys") | ||
) | ||
if not os.path.isdir(cert_dir): | ||
raise IOError(f"No SSL certificates found. They should be in {cert_dir}") | ||
|
||
return os.path.join(cert_dir, name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import asyncio | ||
import logging | ||
import re | ||
import socket | ||
import ssl | ||
|
||
import pytest | ||
|
||
from redis.asyncio.connection import ( | ||
Connection, | ||
SSLConnection, | ||
UnixDomainSocketConnection, | ||
) | ||
|
||
from ..ssl_utils import get_ssl_filename | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
_CLIENT_NAME = "test-suite-client" | ||
_CMD_SEP = b"\r\n" | ||
_SUCCESS_RESP = b"+OK" + _CMD_SEP | ||
_ERROR_RESP = b"-ERR" + _CMD_SEP | ||
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} | ||
|
||
|
||
@pytest.fixture | ||
def tcp_address(): | ||
with socket.socket() as sock: | ||
sock.bind(("127.0.0.1", 0)) | ||
return sock.getsockname() | ||
|
||
|
||
@pytest.fixture | ||
def uds_address(tmpdir): | ||
return tmpdir / "uds.sock" | ||
|
||
|
||
async def test_tcp_connect(tcp_address): | ||
host, port = tcp_address | ||
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) | ||
await _assert_connect(conn, tcp_address) | ||
|
||
|
||
async def test_uds_connect(uds_address): | ||
path = str(uds_address) | ||
conn = UnixDomainSocketConnection( | ||
path=path, client_name=_CLIENT_NAME, socket_timeout=10 | ||
) | ||
await _assert_connect(conn, path) | ||
|
||
|
||
@pytest.mark.ssl | ||
async def test_tcp_ssl_connect(tcp_address): | ||
host, port = tcp_address | ||
certfile = get_ssl_filename("server-cert.pem") | ||
keyfile = get_ssl_filename("server-key.pem") | ||
conn = SSLConnection( | ||
host=host, | ||
port=port, | ||
client_name=_CLIENT_NAME, | ||
ssl_ca_certs=certfile, | ||
socket_timeout=10, | ||
) | ||
await _assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) | ||
|
||
|
||
async def _assert_connect(conn, server_address, certfile=None, keyfile=None): | ||
stop_event = asyncio.Event() | ||
finished = asyncio.Event() | ||
|
||
async def _handler(reader, writer): | ||
try: | ||
return await _redis_request_handler(reader, writer, stop_event) | ||
finally: | ||
finished.set() | ||
|
||
if isinstance(server_address, str): | ||
server = await asyncio.start_unix_server(_handler, path=server_address) | ||
elif certfile: | ||
host, port = server_address | ||
context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) | ||
context.minimum_version = ssl.TLSVersion.TLSv1_2 | ||
context.load_cert_chain(certfile=certfile, keyfile=keyfile) | ||
server = await asyncio.start_server(_handler, host=host, port=port, ssl=context) | ||
else: | ||
host, port = server_address | ||
server = await asyncio.start_server(_handler, host=host, port=port) | ||
|
||
async with server as aserver: | ||
await aserver.start_serving() | ||
try: | ||
await conn.connect() | ||
await conn.disconnect() | ||
finally: | ||
stop_event.set() | ||
aserver.close() | ||
await aserver.wait_closed() | ||
await finished.wait() | ||
|
||
|
||
async def _redis_request_handler(reader, writer, stop_event): | ||
buffer = b"" | ||
command = None | ||
command_ptr = None | ||
fragment_length = None | ||
while not stop_event.is_set() or buffer: | ||
_logger.info(str(stop_event.is_set())) | ||
try: | ||
buffer += await asyncio.wait_for(reader.read(1024), timeout=0.5) | ||
except TimeoutError: | ||
continue | ||
if not buffer: | ||
continue | ||
parts = re.split(_CMD_SEP, buffer) | ||
buffer = parts[-1] | ||
for fragment in parts[:-1]: | ||
fragment = fragment.decode() | ||
_logger.info("Command fragment: %s", fragment) | ||
|
||
if fragment.startswith("*") and command is None: | ||
command = [None for _ in range(int(fragment[1:]))] | ||
command_ptr = 0 | ||
fragment_length = None | ||
continue | ||
|
||
if fragment.startswith("$") and command[command_ptr] is None: | ||
fragment_length = int(fragment[1:]) | ||
continue | ||
|
||
assert len(fragment) == fragment_length | ||
command[command_ptr] = fragment | ||
command_ptr += 1 | ||
|
||
if command_ptr < len(command): | ||
continue | ||
|
||
command = " ".join(command) | ||
_logger.info("Command %s", command) | ||
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) | ||
_logger.info("Response from %s", resp) | ||
writer.write(resp) | ||
await writer.drain() | ||
command = None | ||
_logger.info("Exit handler") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
import logging | ||
import re | ||
import socket | ||
import socketserver | ||
import ssl | ||
import threading | ||
|
||
import pytest | ||
|
||
from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection | ||
|
||
from .ssl_utils import get_ssl_filename | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
_CLIENT_NAME = "test-suite-client" | ||
_CMD_SEP = b"\r\n" | ||
_SUCCESS_RESP = b"+OK" + _CMD_SEP | ||
_ERROR_RESP = b"-ERR" + _CMD_SEP | ||
_SUPPORTED_CMDS = {f"CLIENT SETNAME {_CLIENT_NAME}": _SUCCESS_RESP} | ||
|
||
|
||
@pytest.fixture | ||
def tcp_address(): | ||
with socket.socket() as sock: | ||
sock.bind(("127.0.0.1", 0)) | ||
return sock.getsockname() | ||
|
||
|
||
@pytest.fixture | ||
def uds_address(tmpdir): | ||
return tmpdir / "uds.sock" | ||
|
||
|
||
def test_tcp_connect(tcp_address): | ||
host, port = tcp_address | ||
conn = Connection(host=host, port=port, client_name=_CLIENT_NAME, socket_timeout=10) | ||
_assert_connect(conn, tcp_address) | ||
|
||
|
||
def test_uds_connect(uds_address): | ||
path = str(uds_address) | ||
conn = UnixDomainSocketConnection(path, client_name=_CLIENT_NAME, socket_timeout=10) | ||
_assert_connect(conn, path) | ||
|
||
|
||
@pytest.mark.ssl | ||
def test_tcp_ssl_connect(tcp_address): | ||
host, port = tcp_address | ||
certfile = get_ssl_filename("server-cert.pem") | ||
keyfile = get_ssl_filename("server-key.pem") | ||
conn = SSLConnection( | ||
host=host, | ||
port=port, | ||
client_name=_CLIENT_NAME, | ||
ssl_ca_certs=certfile, | ||
socket_timeout=10, | ||
) | ||
_assert_connect(conn, tcp_address, certfile=certfile, keyfile=keyfile) | ||
|
||
|
||
def _assert_connect(conn, server_address, certfile=None, keyfile=None): | ||
if isinstance(server_address, str): | ||
server = _RedisUDSServer(server_address, _RedisRequestHandler) | ||
else: | ||
server = _RedisTCPServer( | ||
server_address, _RedisRequestHandler, certfile=certfile, keyfile=keyfile | ||
) | ||
with server as aserver: | ||
t = threading.Thread(target=aserver.serve_forever) | ||
t.start() | ||
try: | ||
aserver.wait_online() | ||
conn.connect() | ||
conn.disconnect() | ||
finally: | ||
aserver.stop() | ||
t.join(timeout=5) | ||
|
||
|
||
class _RedisTCPServer(socketserver.TCPServer): | ||
def __init__(self, *args, certfile=None, keyfile=None, **kw) -> None: | ||
self._ready_event = threading.Event() | ||
self._stop_requested = False | ||
self._certfile = certfile | ||
self._keyfile = keyfile | ||
super().__init__(*args, **kw) | ||
|
||
def service_actions(self): | ||
self._ready_event.set() | ||
|
||
def wait_online(self): | ||
self._ready_event.wait() | ||
|
||
def stop(self): | ||
self._stop_requested = True | ||
self.shutdown() | ||
|
||
def is_serving(self): | ||
return not self._stop_requested | ||
|
||
def get_request(self): | ||
if self._certfile is None: | ||
return super().get_request() | ||
newsocket, fromaddr = self.socket.accept() | ||
connstream = ssl.wrap_socket( | ||
newsocket, | ||
server_side=True, | ||
certfile=self._certfile, | ||
keyfile=self._keyfile, | ||
ssl_version=ssl.PROTOCOL_TLSv1_2, | ||
) | ||
return connstream, fromaddr | ||
|
||
|
||
class _RedisUDSServer(socketserver.UnixStreamServer): | ||
def __init__(self, *args, **kw) -> None: | ||
self._ready_event = threading.Event() | ||
self._stop_requested = False | ||
super().__init__(*args, **kw) | ||
|
||
def service_actions(self): | ||
self._ready_event.set() | ||
|
||
def wait_online(self): | ||
self._ready_event.wait() | ||
|
||
def stop(self): | ||
self._stop_requested = True | ||
self.shutdown() | ||
|
||
def is_serving(self): | ||
return not self._stop_requested | ||
|
||
|
||
class _RedisRequestHandler(socketserver.StreamRequestHandler): | ||
def setup(self): | ||
_logger.info("%s connected", self.client_address) | ||
|
||
def finish(self): | ||
_logger.info("%s disconnected", self.client_address) | ||
|
||
def handle(self): | ||
buffer = b"" | ||
command = None | ||
command_ptr = None | ||
fragment_length = None | ||
while self.server.is_serving() or buffer: | ||
try: | ||
buffer += self.request.recv(1024) | ||
except socket.timeout: | ||
continue | ||
if not buffer: | ||
continue | ||
parts = re.split(_CMD_SEP, buffer) | ||
buffer = parts[-1] | ||
for fragment in parts[:-1]: | ||
fragment = fragment.decode() | ||
_logger.info("Command fragment: %s", fragment) | ||
|
||
if fragment.startswith("*") and command is None: | ||
command = [None for _ in range(int(fragment[1:]))] | ||
command_ptr = 0 | ||
fragment_length = None | ||
continue | ||
|
||
if fragment.startswith("$") and command[command_ptr] is None: | ||
fragment_length = int(fragment[1:]) | ||
continue | ||
|
||
assert len(fragment) == fragment_length | ||
command[command_ptr] = fragment | ||
command_ptr += 1 | ||
|
||
if command_ptr < len(command): | ||
continue | ||
|
||
command = " ".join(command) | ||
_logger.info("Command %s", command) | ||
resp = _SUPPORTED_CMDS.get(command, _ERROR_RESP) | ||
_logger.info("Response %s", resp) | ||
self.request.sendall(resp) | ||
command = None | ||
_logger.info("Exit handler") |
Oops, something went wrong.