diff --git a/bip-musig2/reference.py b/bip-musig2/reference.py index 3510368dd2..902280d53f 100644 --- a/bip-musig2/reference.py +++ b/bip-musig2/reference.py @@ -107,6 +107,27 @@ def schnorr_verify(msg: bytes, pubkey: bytes, sig: bytes) -> bool: # End of helper functions copied from BIP-340 reference implementation. # +# There are two types of exceptions that can be raised by this implementation: +# - ValueError for indicating that an input doesn't conform to some function +# precondition (e.g. an input array is the wrong length, a serialized +# representation doesn't have the correct format). +# - InvalidContributionError for indicating that a signer (or the +# aggregator) is misbehaving in the protocol. +# +# Assertions are used to (1) satisfy the type-checking system, and (2) check for +# inconvenient events that can't happen except with negligible probability (e.g. +# output of a hash function is 0) and can't be manually triggered by any +# signer. + +# This exception is raised if a party (signer or nonce aggregator) sends invalid +# values. Actual implementations should not crash when receiving invalid +# contributions. Instead, they should hold the offending party accountable. +class InvalidContributionError(Exception): + def __init__(self, signer, contrib): + self.signer = signer + # contrib is one of "pubkey", "pubnonce", "aggnonce", or "psig". + self.contrib = contrib + infinity = None def cbytes(P: Point) -> bytes: @@ -141,10 +162,12 @@ def key_agg_internal(pubkeys: List[bytes], tweaks: List[bytes], is_xonly: List[b Q = infinity for i in range(u): P_i = lift_x(pubkeys[i]) + if P_i is None: + raise InvalidContributionError(i, "pubkey") a_i = key_agg_coeff_internal(pubkeys, pubkeys[i], pk2) Q = point_add(Q, point_mul(P_i, a_i)) - if Q is None: - raise ValueError('The aggregate public key cannot be infinity.') + # Q is not the point at infinity except with negligible probability. + assert(Q is not None) gacc = 1 tacc = 0 v = len(tweaks) @@ -336,6 +359,15 @@ def partial_sig_agg(psigs: List[bytes], session_ctx: SessionContext) -> Optional def fromhex_all(l): return [bytes.fromhex(l_i) for l_i in l] +# Check that calling `try_fn` raises a `exception`. If `exception` is raised, +# examine it with `except_fn`. +def assertRaises(exception, try_fn, except_fn): + try: + try_fn() + raise RuntimeError("Exception was _not_ raised in a test where it was required.") + except exception as e: + assert(except_fn(e)) + def test_key_agg_vectors(): X = fromhex_all([ 'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9', @@ -350,11 +382,35 @@ def test_key_agg_vectors(): '2EB18851887E7BDC5E830E89B19DDBC28078F1FA88AAD0AD01CA06FE4F80210B', ]) + # Vector 1 assert key_agg([X[0], X[1], X[2]], [], []) == expected[0] + # Vector 2 assert key_agg([X[2], X[1], X[0]], [], []) == expected[1] + # Vector 3 assert key_agg([X[0], X[0], X[0]], [], []) == expected[2] + # Vector 4 assert key_agg([X[0], X[0], X[1], X[1]], [], []) == expected[3] - + # Vector 5: Invalid public key + invalid_pk = bytes.fromhex('0000000000000000000000000000000000000000000000000000000000000005') + assertRaises(InvalidContributionError, + lambda: key_agg([X[0], invalid_pk], [], []), + lambda e: e.signer == 1 and e.contrib == "pubkey") + # Vector 6: Public key exceeds field size + invalid_pk = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC30') + assertRaises(InvalidContributionError, + lambda: key_agg([X[0], invalid_pk], [], []), + lambda e: e.signer == 1 and e.contrib == "pubkey") + # Vector 7: Tweak is out of range + invalid_tweak = bytes.fromhex('FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141') + assertRaises(ValueError, + lambda: key_agg([X[0], X[1]], [invalid_tweak], [True]), + lambda e: str(e) == 'The tweak must be less than n.') + # Vector 8: Intermediate tweaking result is point at infinity + G_ = bytes_from_point(G) + coeff = bytes_from_int(n - key_agg_coeff([G_], G_)) + assertRaises(ValueError, + lambda: key_agg([G_], [coeff], [False]), + lambda e: str(e) == 'The result of tweaking cannot be infinity.') def test_nonce_gen_vectors(): def fill(i):