Skip to content

Commit

Permalink
Add dummy coordinator to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edouardparis committed Jun 17, 2022
1 parent c65602a commit 6af6b56
Show file tree
Hide file tree
Showing 4 changed files with 250 additions and 14 deletions.
49 changes: 40 additions & 9 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from concurrent import futures
from ephemeral_port_reserve import reserve
from test_framework.bitcoind import Bitcoind, BitcoindRpcProxy
from test_framework.coordinator import DummyCoordinator
from test_framework.miradord import Miradord
from test_framework.utils import (
get_descriptors,
Expand All @@ -9,6 +10,7 @@
MANS_XPUBS,
COSIG_PUBKEYS,
CSV,
NoiseKeypair,
)

import os
Expand Down Expand Up @@ -114,7 +116,40 @@ def bitcoind(directory):


@pytest.fixture
def miradord(request, bitcoind, directory):
def noise_keys():
# The Noise keys are interdependant, so generate everything in advance
# to avoid roundtrips
coordinator_keys = NoiseKeypair(os.urandom(32))
manager_keys = NoiseKeypair(os.urandom(32))
stakeholder_keys = NoiseKeypair(os.urandom(32))
watchtower_keys = NoiseKeypair(os.urandom(32))
noise_keys = {
"coordinator": coordinator_keys,
"manager": manager_keys,
"stakeholder": stakeholder_keys,
"watchtower": watchtower_keys,
}
yield noise_keys


@pytest.fixture
def coordinator(noise_keys):
coordinator_port = reserve()
coordinator = DummyCoordinator(
coordinator_port,
noise_keys["coordinator"].privkey,
[
noise_keys["manager"].pubkey,
noise_keys["stakeholder"].pubkey,
noise_keys["watchtower"].pubkey
],
)
coordinator.start()
yield coordinator


@pytest.fixture
def miradord(request, bitcoind, coordinator, noise_keys, directory):
"""If a 'mock_bitcoind' pytest marker is set, it will create a proxy for the communication
from the miradord process to the bitcoind process. An optional 'mocks' parameter can be set
for this marker in order to specify some pre-registered mock of RPC commands.
Expand All @@ -130,10 +165,6 @@ def miradord(request, bitcoind, directory):
)
emer_addr = "bcrt1qewc2348370pgw8kjz8gy09z8xyh0d9fxde6nzamd3txc9gkmjqmq8m4cdq"

coordinator_noise_key = (
"d91563973102454a7830137e92d0548bc83b4ea2799f1df04622ca1307381402"
)

bitcoind_cookie = os.path.join(bitcoind.bitcoin_dir, "regtest", ".cookie")
bitcoind_rpcport = bitcoind.rpcport

Expand All @@ -151,10 +182,10 @@ def miradord(request, bitcoind, directory):
cpfp_desc,
emer_addr,
reserve(),
os.urandom(32),
os.urandom(32),
coordinator_noise_key, # Unused yet
reserve(), # Unused yet
noise_keys["watchtower"].privkey,
noise_keys["stakeholder"].privkey,
coordinator.coordinator_pubkey,
coordinator.port,
bitcoind_rpcport,
bitcoind_cookie,
)
Expand Down
192 changes: 192 additions & 0 deletions tests/test_framework/coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import cryptography
import json
import os
import select
import socket
import threading

from nacl.public import PrivateKey as Curve25519Private
from noise.connection import NoiseConnection, Keypair
from test_framework.utils import (
TIMEOUT,
)

HANDSHAKE_MSG = b"practical_revault_0"


class DummyCoordinator:
"""A simple in-RAM synchronization server."""

def __init__(
self,
port,
coordinator_privkey,
client_pubkeys,
):
self.port = port
self.coordinator_privkey = coordinator_privkey
self.coordinator_pubkey = bytes(
Curve25519Private(coordinator_privkey).public_key
)
self.client_pubkeys = client_pubkeys

# Spin up the coordinator proxy
self.s = socket.socket()
self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.s.bind(("localhost", self.port))
self.s.listen(1_000)
# Use a pipe to communicate to threads to stop
self.r_close_chann, self.w_close_chann = os.pipe()

# A mapping from txid to pubkey to signature
self.sigs = {}
# A mapping from deposit_outpoint to base64 tx
self.spend_txs = {}

def __del__(self):
self.cleanup()

def start(self):
self.server_thread = threading.Thread(target=self.run)
self.server_thread.start()

def cleanup(self):
# Write to the pipe to notify the thread it needs to stop
os.write(self.w_close_chann, b".")
self.server_thread.join()

def run(self):
"""Accept new connections until we are told to stop."""
while True:
r_fds, _, _ = select.select([self.r_close_chann, self.s.fileno()], [], [])

# First check if we've been told to stop, then spawn a new thread per connection
if self.r_close_chann in r_fds:
break
if self.s.fileno() in r_fds:
t = threading.Thread(target=self.connection_handle, daemon=True)
t.start()

def connection_handle(self):
"""Read and treat requests from this client. Blocking."""
client_fd, _ = self.s.accept()
client_fd.settimeout(TIMEOUT // 2)
client_noise = self.server_noise_conn(client_fd)

while True:
# Manually do the select to check if we've been told to stop
r_fds, _, _ = select.select([self.r_close_chann, client_fd], [], [])
if self.r_close_chann in r_fds:
break
req = self.read_msg(client_fd, client_noise)
if req is None:
break
request = json.loads(req)
method, params = request["method"], request["params"]

if method == "sig":
# TODO: mutex
if params["txid"] not in self.sigs:
self.sigs[params["txid"]] = {}
self.sigs[params["txid"]][params["pubkey"]] = params["signature"]
# TODO: remove this useless response from the protocol
resp = {"result": {"ack": True}, "id": request["id"]}
self.send_msg(client_fd, client_noise, json.dumps(resp))

elif method == "get_sigs":
txid = params["txid"]
sigs = self.sigs.get(txid, {})
resp = {"result": {"signatures": sigs}, "id": request["id"]}
self.send_msg(client_fd, client_noise, json.dumps(resp))

elif method == "set_spend_tx":
for outpoint in params["deposit_outpoints"]:
self.spend_txs[outpoint] = params["spend_tx"]
# TODO: remove this useless response from the protocol
resp = {"result": {"ack": True}, "id": request["id"]}
self.send_msg(client_fd, client_noise, json.dumps(resp))

elif method == "get_spend_tx":
spend_tx = self.spend_txs.get(params["deposit_outpoint"])
resp = {"result": {"spend_tx": spend_tx}, "id": request["id"]}
self.send_msg(client_fd, client_noise, json.dumps(resp))

else:
assert False, "Invalid request '{}'".format(method)

def server_noise_conn(self, fd):
"""Do practical-revault's Noise handshake with a given client connection."""
# Read the first message of the handshake only once
data = self.read_data(fd, 32 + len(HANDSHAKE_MSG) + 16)

# We brute force all pubkeys. FIXME!
for pubkey in self.client_pubkeys:
# Set the local and remote static keys
conn = NoiseConnection.from_name(b"Noise_KK_25519_ChaChaPoly_SHA256")
conn.set_as_responder()
conn.set_keypair_from_private_bytes(
Keypair.STATIC, self.coordinator_privkey
)
conn.set_keypair_from_public_bytes(Keypair.REMOTE_STATIC, pubkey)

# Now, get the first message of the handshake
conn.start_handshake()
try:
plaintext = conn.read_message(data)
except cryptography.exceptions.InvalidTag:
continue
else:
assert plaintext[: len(HANDSHAKE_MSG)] == HANDSHAKE_MSG

# If it didn't fail it was the right key! Finalize the handshake.
resp = conn.write_message()
fd.sendall(resp)
assert conn.handshake_finished

return conn

raise Exception(
f"Unknown client key. Keys: {','.join(k.hex() for k in self.client_pubkeys)}"
)

def read_msg(self, fd, noise_conn):
"""read a noise-encrypted message from this stream.
Returns None if the socket closed.
"""
# Read first the length prefix
cypher_header = self.read_data(fd, 2 + 16)
if cypher_header == b"":
return None
msg_header = noise_conn.decrypt(cypher_header)
msg_len = int.from_bytes(msg_header, "big")

# And then the message
cypher_msg = self.read_data(fd, msg_len)
assert len(cypher_msg) == msg_len
msg = noise_conn.decrypt(cypher_msg).decode("utf-8")
return msg

def send_msg(self, fd, noise_conn, msg):
"""Write a noise-encrypted message from this stream."""
assert isinstance(msg, str)

# Compute the message header
msg_raw = msg.encode("utf-8")
length_prefix = (len(msg_raw) + 16).to_bytes(2, "big")
encrypted_header = noise_conn.encrypt(length_prefix)
encrypted_body = noise_conn.encrypt(msg_raw)

# Then send both the header and the message concatenated
fd.sendall(encrypted_header + encrypted_body)

def read_data(self, fd, max_len):
"""Read data from the given fd until there is nothing to read."""
data = b""
while True:
d = fd.recv(max_len)
if len(d) == max_len:
return d
if d == b"":
return data
data += d
9 changes: 4 additions & 5 deletions tests/test_framework/miradord.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ def __init__(
f.write("daemon = false\n")
f.write(f"log_level = '{LOG_LEVEL}'\n")

f.write(f'listen = "127.0.0.1:{listen_port}"\n')
f.write(f'stakeholder_noise_key = "{stk_noise_key.hex()}"\n')

f.write(f'coordinator_host = "127.0.0.1:{coordinator_port}"\n')
f.write(f'coordinator_noise_key = "{coordinator_noise_key}"\n')
f.write("coordinator_poll_seconds = 5\n")

f.write(f'listen = "127.0.0.1:{listen_port}"\n')
f.write("[coordinator_config]\n")
f.write(f'host = "127.0.0.1:{coordinator_port}"\n')
f.write(f'noise_key = "{coordinator_noise_key.hex()}"\n')

f.write("[scripts_config]\n")
f.write(f'deposit_descriptor = "{deposit_desc}"\n')
Expand Down
14 changes: 14 additions & 0 deletions tests/test_framework/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import threading
import time

from nacl.public import PrivateKey as Curve25519Private

TIMEOUT = int(os.getenv("TIMEOUT", 60))
EXECUTOR_WORKERS = int(os.getenv("EXECUTOR_WORKERS", 20))
Expand Down Expand Up @@ -63,6 +64,19 @@
DEPOSIT_ADDRESS = "bcrt1qgprmrfkz5mucga0ec046v0sf8yg2y4za99c0h26ew5ycfx64sgdsl0u2j3"


class NoiseKeypair:
"""An exchange of paired keys"""

def __init__(
self,
privkey,
):
self.privkey = privkey
self.pubkey = bytes(
Curve25519Private(privkey).public_key
)


def wait_for(success, timeout=TIMEOUT, debug_fn=None):
"""
Run success() either until it returns True, or until the timeout is reached.
Expand Down

0 comments on commit 6af6b56

Please sign in to comment.