diff --git a/tests/fixtures.py b/tests/fixtures.py index 9f6a805..2a854cc 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,6 +1,7 @@ from concurrent import futures from ephemeral_port_reserve import reserve from test_framework.bitcoind import Bitcoind, BitcoindRpcProxy +from test_framework.coordinator import DummyCoordinator from test_framework.miradord import Miradord from test_framework.utils import ( get_descriptors, @@ -9,6 +10,7 @@ MANS_XPUBS, COSIG_PUBKEYS, CSV, + NoiseKeypair, ) import os @@ -114,7 +116,40 @@ def bitcoind(directory): @pytest.fixture -def miradord(request, bitcoind, directory): +def noise_keys(): + # The Noise keys are interdependant, so generate everything in advance + # to avoid roundtrips + coordinator_keys = NoiseKeypair(os.urandom(32)) + manager_keys = NoiseKeypair(os.urandom(32)) + stakeholder_keys = NoiseKeypair(os.urandom(32)) + watchtower_keys = NoiseKeypair(os.urandom(32)) + noise_keys = { + "coordinator": coordinator_keys, + "manager": manager_keys, + "stakeholder": stakeholder_keys, + "watchtower": watchtower_keys, + } + yield noise_keys + + +@pytest.fixture +def coordinator(noise_keys): + coordinator_port = reserve() + coordinator = DummyCoordinator( + coordinator_port, + noise_keys["coordinator"].privkey, + [ + noise_keys["manager"].pubkey, + noise_keys["stakeholder"].pubkey, + noise_keys["watchtower"].pubkey + ], + ) + coordinator.start() + yield coordinator + + +@pytest.fixture +def miradord(request, bitcoind, coordinator, noise_keys, directory): """If a 'mock_bitcoind' pytest marker is set, it will create a proxy for the communication from the miradord process to the bitcoind process. An optional 'mocks' parameter can be set for this marker in order to specify some pre-registered mock of RPC commands. @@ -130,10 +165,6 @@ def miradord(request, bitcoind, directory): ) emer_addr = "bcrt1qewc2348370pgw8kjz8gy09z8xyh0d9fxde6nzamd3txc9gkmjqmq8m4cdq" - coordinator_noise_key = ( - "d91563973102454a7830137e92d0548bc83b4ea2799f1df04622ca1307381402" - ) - bitcoind_cookie = os.path.join(bitcoind.bitcoin_dir, "regtest", ".cookie") bitcoind_rpcport = bitcoind.rpcport @@ -151,10 +182,10 @@ def miradord(request, bitcoind, directory): cpfp_desc, emer_addr, reserve(), - os.urandom(32), - os.urandom(32), - coordinator_noise_key, # Unused yet - reserve(), # Unused yet + noise_keys["watchtower"].privkey, + noise_keys["stakeholder"].privkey, + coordinator.coordinator_pubkey, + coordinator.port, bitcoind_rpcport, bitcoind_cookie, ) diff --git a/tests/test_framework/coordinator.py b/tests/test_framework/coordinator.py new file mode 100644 index 0000000..54fc1a7 --- /dev/null +++ b/tests/test_framework/coordinator.py @@ -0,0 +1,192 @@ +import cryptography +import json +import os +import select +import socket +import threading + +from nacl.public import PrivateKey as Curve25519Private +from noise.connection import NoiseConnection, Keypair +from test_framework.utils import ( + TIMEOUT, +) + +HANDSHAKE_MSG = b"practical_revault_0" + + +class DummyCoordinator: + """A simple in-RAM synchronization server.""" + + def __init__( + self, + port, + coordinator_privkey, + client_pubkeys, + ): + self.port = port + self.coordinator_privkey = coordinator_privkey + self.coordinator_pubkey = bytes( + Curve25519Private(coordinator_privkey).public_key + ) + self.client_pubkeys = client_pubkeys + + # Spin up the coordinator proxy + self.s = socket.socket() + self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.s.bind(("localhost", self.port)) + self.s.listen(1_000) + # Use a pipe to communicate to threads to stop + self.r_close_chann, self.w_close_chann = os.pipe() + + # A mapping from txid to pubkey to signature + self.sigs = {} + # A mapping from deposit_outpoint to base64 tx + self.spend_txs = {} + + def __del__(self): + self.cleanup() + + def start(self): + self.server_thread = threading.Thread(target=self.run) + self.server_thread.start() + + def cleanup(self): + # Write to the pipe to notify the thread it needs to stop + os.write(self.w_close_chann, b".") + self.server_thread.join() + + def run(self): + """Accept new connections until we are told to stop.""" + while True: + r_fds, _, _ = select.select([self.r_close_chann, self.s.fileno()], [], []) + + # First check if we've been told to stop, then spawn a new thread per connection + if self.r_close_chann in r_fds: + break + if self.s.fileno() in r_fds: + t = threading.Thread(target=self.connection_handle, daemon=True) + t.start() + + def connection_handle(self): + """Read and treat requests from this client. Blocking.""" + client_fd, _ = self.s.accept() + client_fd.settimeout(TIMEOUT // 2) + client_noise = self.server_noise_conn(client_fd) + + while True: + # Manually do the select to check if we've been told to stop + r_fds, _, _ = select.select([self.r_close_chann, client_fd], [], []) + if self.r_close_chann in r_fds: + break + req = self.read_msg(client_fd, client_noise) + if req is None: + break + request = json.loads(req) + method, params = request["method"], request["params"] + + if method == "sig": + # TODO: mutex + if params["txid"] not in self.sigs: + self.sigs[params["txid"]] = {} + self.sigs[params["txid"]][params["pubkey"]] = params["signature"] + # TODO: remove this useless response from the protocol + resp = {"result": {"ack": True}, "id": request["id"]} + self.send_msg(client_fd, client_noise, json.dumps(resp)) + + elif method == "get_sigs": + txid = params["txid"] + sigs = self.sigs.get(txid, {}) + resp = {"result": {"signatures": sigs}, "id": request["id"]} + self.send_msg(client_fd, client_noise, json.dumps(resp)) + + elif method == "set_spend_tx": + for outpoint in params["deposit_outpoints"]: + self.spend_txs[outpoint] = params["spend_tx"] + # TODO: remove this useless response from the protocol + resp = {"result": {"ack": True}, "id": request["id"]} + self.send_msg(client_fd, client_noise, json.dumps(resp)) + + elif method == "get_spend_tx": + spend_tx = self.spend_txs.get(params["deposit_outpoint"]) + resp = {"result": {"spend_tx": spend_tx}, "id": request["id"]} + self.send_msg(client_fd, client_noise, json.dumps(resp)) + + else: + assert False, "Invalid request '{}'".format(method) + + def server_noise_conn(self, fd): + """Do practical-revault's Noise handshake with a given client connection.""" + # Read the first message of the handshake only once + data = self.read_data(fd, 32 + len(HANDSHAKE_MSG) + 16) + + # We brute force all pubkeys. FIXME! + for pubkey in self.client_pubkeys: + # Set the local and remote static keys + conn = NoiseConnection.from_name(b"Noise_KK_25519_ChaChaPoly_SHA256") + conn.set_as_responder() + conn.set_keypair_from_private_bytes( + Keypair.STATIC, self.coordinator_privkey + ) + conn.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pubkey) + + # Now, get the first message of the handshake + conn.start_handshake() + try: + plaintext = conn.read_message(data) + except cryptography.exceptions.InvalidTag: + continue + else: + assert plaintext[: len(HANDSHAKE_MSG)] == HANDSHAKE_MSG + + # If it didn't fail it was the right key! Finalize the handshake. + resp = conn.write_message() + fd.sendall(resp) + assert conn.handshake_finished + + return conn + + raise Exception( + f"Unknown client key. Keys: {','.join(k.hex() for k in self.client_pubkeys)}" + ) + + def read_msg(self, fd, noise_conn): + """read a noise-encrypted message from this stream. + + Returns None if the socket closed. + """ + # Read first the length prefix + cypher_header = self.read_data(fd, 2 + 16) + if cypher_header == b"": + return None + msg_header = noise_conn.decrypt(cypher_header) + msg_len = int.from_bytes(msg_header, "big") + + # And then the message + cypher_msg = self.read_data(fd, msg_len) + assert len(cypher_msg) == msg_len + msg = noise_conn.decrypt(cypher_msg).decode("utf-8") + return msg + + def send_msg(self, fd, noise_conn, msg): + """Write a noise-encrypted message from this stream.""" + assert isinstance(msg, str) + + # Compute the message header + msg_raw = msg.encode("utf-8") + length_prefix = (len(msg_raw) + 16).to_bytes(2, "big") + encrypted_header = noise_conn.encrypt(length_prefix) + encrypted_body = noise_conn.encrypt(msg_raw) + + # Then send both the header and the message concatenated + fd.sendall(encrypted_header + encrypted_body) + + def read_data(self, fd, max_len): + """Read data from the given fd until there is nothing to read.""" + data = b"" + while True: + d = fd.recv(max_len) + if len(d) == max_len: + return d + if d == b"": + return data + data += d diff --git a/tests/test_framework/miradord.py b/tests/test_framework/miradord.py index 65f4154..ee9b753 100644 --- a/tests/test_framework/miradord.py +++ b/tests/test_framework/miradord.py @@ -76,13 +76,12 @@ def __init__( f.write("daemon = false\n") f.write(f"log_level = '{LOG_LEVEL}'\n") + f.write(f'listen = "127.0.0.1:{listen_port}"\n') f.write(f'stakeholder_noise_key = "{stk_noise_key.hex()}"\n') - f.write(f'coordinator_host = "127.0.0.1:{coordinator_port}"\n') - f.write(f'coordinator_noise_key = "{coordinator_noise_key}"\n') - f.write("coordinator_poll_seconds = 5\n") - - f.write(f'listen = "127.0.0.1:{listen_port}"\n') + f.write("[coordinator_config]\n") + f.write(f'host = "127.0.0.1:{coordinator_port}"\n') + f.write(f'noise_key = "{coordinator_noise_key.hex()}"\n') f.write("[scripts_config]\n") f.write(f'deposit_descriptor = "{deposit_desc}"\n') diff --git a/tests/test_framework/utils.py b/tests/test_framework/utils.py index e970374..823236a 100644 --- a/tests/test_framework/utils.py +++ b/tests/test_framework/utils.py @@ -12,6 +12,7 @@ import threading import time +from nacl.public import PrivateKey as Curve25519Private TIMEOUT = int(os.getenv("TIMEOUT", 60)) EXECUTOR_WORKERS = int(os.getenv("EXECUTOR_WORKERS", 20)) @@ -63,6 +64,19 @@ DEPOSIT_ADDRESS = "bcrt1qgprmrfkz5mucga0ec046v0sf8yg2y4za99c0h26ew5ycfx64sgdsl0u2j3" +class NoiseKeypair: + """An exchange of paired keys""" + + def __init__( + self, + privkey, + ): + self.privkey = privkey + self.pubkey = bytes( + Curve25519Private(privkey).public_key + ) + + def wait_for(success, timeout=TIMEOUT, debug_fn=None): """ Run success() either until it returns True, or until the timeout is reached.