-
Notifications
You must be signed in to change notification settings - Fork 208
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
musig-spec: Add naive Python reference implementation
- Loading branch information
1 parent
73f0cbd
commit 58f8bf0
Showing
2 changed files
with
287 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
from typing import Any, List, Optional, Tuple | ||
import hashlib | ||
import secrets | ||
|
||
# | ||
# The following helper functions were copied from the BIP-340 reference implementation: | ||
# https://github.com/bitcoin/bips/blob/master/bip-0340/reference.py | ||
# | ||
|
||
p = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F | ||
n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 | ||
|
||
# Points are tuples of X and Y coordinates and the point at infinity is | ||
# represented by the None keyword. | ||
G = (0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798, 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8) | ||
|
||
Point = Tuple[int, int] | ||
|
||
# This implementation can be sped up by storing the midstate after hashing | ||
# tag_hash instead of rehashing it all the time. | ||
def tagged_hash(tag: str, msg: bytes) -> bytes: | ||
tag_hash = hashlib.sha256(tag.encode()).digest() | ||
return hashlib.sha256(tag_hash + tag_hash + msg).digest() | ||
|
||
def is_infinite(P: Optional[Point]) -> bool: | ||
return P is None | ||
|
||
def x(P: Point) -> int: | ||
assert not is_infinite(P) | ||
return P[0] | ||
|
||
def y(P: Point) -> int: | ||
assert not is_infinite(P) | ||
return P[1] | ||
|
||
def point_add(P1: Optional[Point], P2: Optional[Point]) -> Optional[Point]: | ||
if P1 is None: | ||
return P2 | ||
if P2 is None: | ||
return P1 | ||
if (x(P1) == x(P2)) and (y(P1) != y(P2)): | ||
return None | ||
if P1 == P2: | ||
lam = (3 * x(P1) * x(P1) * pow(2 * y(P1), p - 2, p)) % p | ||
else: | ||
lam = ((y(P2) - y(P1)) * pow(x(P2) - x(P1), p - 2, p)) % p | ||
x3 = (lam * lam - x(P1) - x(P2)) % p | ||
return (x3, (lam * (x(P1) - x3) - y(P1)) % p) | ||
|
||
def point_mul(P: Optional[Point], n: int) -> Optional[Point]: | ||
R = None | ||
for i in range(256): | ||
if (n >> i) & 1: | ||
R = point_add(R, P) | ||
P = point_add(P, P) | ||
return R | ||
|
||
def bytes_from_int(x: int) -> bytes: | ||
return x.to_bytes(32, byteorder="big") | ||
|
||
def bytes_from_point(P: Point) -> bytes: | ||
return bytes_from_int(x(P)) | ||
|
||
def lift_x(b: bytes) -> Optional[Point]: | ||
x = int_from_bytes(b) | ||
if x >= p: | ||
return None | ||
y_sq = (pow(x, 3, p) + 7) % p | ||
y = pow(y_sq, (p + 1) // 4, p) | ||
if pow(y, 2, p) != y_sq: | ||
return None | ||
return (x, y if y & 1 == 0 else p-y) | ||
|
||
def int_from_bytes(b: bytes) -> int: | ||
return int.from_bytes(b, byteorder="big") | ||
|
||
def has_even_y(P: Point) -> bool: | ||
assert not is_infinite(P) | ||
return y(P) % 2 == 0 | ||
|
||
# | ||
# End of helper functions copied from BIP-340 reference implementation. | ||
# | ||
|
||
def cbytes(P: Point) -> bytes: | ||
a = b'\x02' if has_even_y(P) else b'\x03' | ||
return a + bytes_from_point(P) | ||
|
||
def point_negate(P: Point) -> Point: | ||
if is_infinite(P): | ||
return P | ||
return (x(P), p - y(P)) | ||
|
||
def pointc(x: bytes) -> Point: | ||
P = lift_x(x[1:33]) | ||
if x[0] == 2: | ||
return P | ||
elif x[0] == 3: | ||
return point_negate(P) | ||
assert False | ||
|
||
def key_agg(pubkeys: List[bytes]) -> bytes: | ||
Q = key_agg_internal(pubkeys) | ||
return bytes_from_point(Q) | ||
|
||
def key_agg_internal(pubkeys: List[bytes]) -> Point: | ||
u = len(pubkeys) | ||
Q = None | ||
for i in range(u): | ||
a_i = key_agg_coeff(pubkeys, pubkeys[i]) | ||
P_i = lift_x(pubkeys[i]) | ||
Q = point_add(Q, point_mul(P_i, a_i)) | ||
assert not is_infinite(Q) | ||
return Q | ||
|
||
def hash_keys(pubkeys: List[bytes]) -> bytes: | ||
return tagged_hash('KeyAgg list', b''.join(pubkeys)) | ||
|
||
def is_second(pubkeys: List[bytes], pk: bytes) -> bool: | ||
u = len(pubkeys) | ||
for j in range(u): | ||
if pubkeys[j] != pubkeys[0]: | ||
return pubkeys[j] == pk | ||
return False | ||
|
||
def key_agg_coeff(pubkeys: List[bytes], pk: bytes) -> int: | ||
if is_second(pubkeys, pk): | ||
return 1 | ||
else: | ||
L = hash_keys(pubkeys) | ||
return int_from_bytes(tagged_hash('KeyAgg coefficient', L + pk)) % n | ||
|
||
def nonce_gen() -> Tuple[bytes, bytes]: | ||
k_1 = 1 + secrets.randbelow(n - 2) | ||
k_2 = 1 + secrets.randbelow(n - 2) | ||
R_1 = point_mul(G, k_1) | ||
R_2 = point_mul(G, k_2) | ||
pubnonce = cbytes(R_1) + cbytes(R_2) | ||
secnonce = bytes_from_int(k_1) + bytes_from_int(k_2) | ||
return secnonce, pubnonce | ||
|
||
def nonce_agg(pubnonces: List[bytes]) -> bytes: | ||
u = len(pubnonces) | ||
aggnonce = b'' | ||
for i in (1, 2): | ||
R_i_ = None | ||
for j in range(u): | ||
R_i_ = point_add(R_i_, pointc(pubnonces[j][(i-1)*33:i*33])) | ||
R_i = R_i_ if not is_infinite(R_i_) else G | ||
aggnonce += cbytes(R_i) | ||
return aggnonce | ||
|
||
def partial_sign(secnonce: bytes, sk: bytes, aggnonce: bytes, pubkeys: List[bytes], msg: bytes) -> bytes: | ||
R_1 = pointc(aggnonce[0:33]) | ||
R_2 = pointc(aggnonce[33:66]) | ||
Q = key_agg_internal(pubkeys) | ||
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n | ||
R = point_add(R_1, point_mul(R_2, b)) | ||
assert not is_infinite(R) | ||
k_1_ = int_from_bytes(secnonce[0:32]) | ||
k_2_ = int_from_bytes(secnonce[32:64]) | ||
assert 0 < k_1_ < n | ||
assert 0 < k_2_ < n | ||
k_1 = k_1_ if has_even_y(R) else n - k_1_ | ||
k_2 = k_2_ if has_even_y(R) else n - k_2_ | ||
d_ = int_from_bytes(sk) | ||
assert 0 < d_ < n | ||
P = point_mul(G, d_) | ||
d = n - d_ if has_even_y(P) != has_even_y(Q) else d_ | ||
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n | ||
mu = key_agg_coeff(pubkeys, bytes_from_point(P)) | ||
s = (k_1 + b * k_2 + e * mu * d) % n | ||
psig = bytes_from_int(s) | ||
pubnonce = cbytes(point_mul(G, k_1_)) + cbytes(point_mul(G, k_2_)) | ||
assert partial_sig_verify_internal(psig, pubnonce, aggnonce, pubkeys, bytes_from_point(P), msg) | ||
return psig | ||
|
||
def partial_sig_verify(psig: bytes, pubnonces: List[bytes], pubkeys: List[bytes], msg: bytes, i: int) -> bool: | ||
aggnonce = nonce_agg(pubnonces) | ||
return partial_sig_verify_internal(psig, pubnonces[i], aggnonce, pubkeys, pubkeys[i], msg) | ||
|
||
def partial_sig_verify_internal(psig: bytes, pubnonce: bytes, aggnonce: bytes, pubkeys: List[bytes], pk: bytes, msg: bytes) -> bool: | ||
s = int_from_bytes(psig) | ||
assert s < n | ||
R_1 = pointc(aggnonce[0:33]) | ||
R_2 = pointc(aggnonce[33:66]) | ||
Q = key_agg_internal(pubkeys) | ||
b = int_from_bytes(tagged_hash('MuSig/noncecoef', aggnonce + bytes_from_point(Q) + msg)) % n | ||
R = point_add(R_1, point_mul(R_2, b)) | ||
R_1_ = pointc(pubnonce[0:33]) | ||
R_2_ = pointc(pubnonce[33:66]) | ||
R__ = point_add(R_1_, point_mul(R_2_, b)) | ||
R_ = R__ if has_even_y(R) else point_negate(R__) | ||
e = int_from_bytes(tagged_hash('BIP0340/challenge', bytes_from_point(R) + bytes_from_point(Q) + msg)) % n | ||
mu = key_agg_coeff(pubkeys, pk) | ||
P_ = lift_x(pk) | ||
P = P_ if has_even_y(Q) else point_negate(P_) | ||
return point_mul(G, s) == point_add(R_, point_mul(P, e * mu % n)) | ||
|
||
# | ||
# The following code is only used for testing. | ||
# Test vectors were copied from libsecp256k1's MuSig test file. | ||
# See `musig_test_vectors_keyagg` and `musig_test_vectors_sign` in | ||
# https://github.com/ElementsProject/secp256k1-zkp/blob/master/src/modules/musig/tests_impl.h | ||
# | ||
def fromhex_all(l): | ||
return [bytes.fromhex(l_i) for l_i in l] | ||
|
||
def test_key_agg_vectors(): | ||
X = fromhex_all([ | ||
'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9', | ||
'DFF1D77F2A671C5F36183726DB2341BE58FEAE1DA2DECED843240F7B502BA659', | ||
'3590A94E768F8E1815C2F24B4D80A8E3149316C3518CE7B7AD338368D038CA66', | ||
]) | ||
|
||
expected = fromhex_all([ | ||
'E5830140512195D74C8307E39637CBE5FB730EBEAB80EC514CF88A877CEEEE0B', | ||
'D70CD69A2647F7390973DF48CBFA2CCC407B8B2D60B08C5F1641185C7998A290', | ||
'81A8B093912C9E481408D09776CEFB48AEB8B65481B6BAAFB3C5810106717BEB', | ||
'2EB18851887E7BDC5E830E89B19DDBC28078F1FA88AAD0AD01CA06FE4F80210B', | ||
]) | ||
|
||
assert key_agg([X[0], X[1], X[2]]) == expected[0] | ||
assert key_agg([X[2], X[1], X[0]]) == expected[1] | ||
assert key_agg([X[0], X[0], X[0]]) == expected[2] | ||
assert key_agg([X[0], X[0], X[1], X[1]]) == expected[3] | ||
|
||
def test_partial_sign_vectors(): | ||
X = fromhex_all([ | ||
'F9308A019258C31049344F85F89D5229B531C845836F99B08601F113BCE036F9', | ||
'DFF1D77F2A671C5F36183726DB2341BE58FEAE1DA2DECED843240F7B502BA659', | ||
]) | ||
|
||
secnonce = bytes.fromhex( | ||
'508B81A611F100A6B2B6B29656590898AF488BCF2E1F55CF22E5CFB84421FE61' + | ||
'FA27FD49B1D50085B481285E1CA205D55C82CC1B31FF5CD54A489829355901F7') | ||
|
||
aggnonce = bytes.fromhex( | ||
'028465FCF0BBDBCF443AABCCE533D42B4B5A10966AC09A49655E8C42DAAB8FCD61' + | ||
'037496A3CC86926D452CAFCFD55D25972CA1675D549310DE296BFF42F72EEEA8C9') | ||
|
||
sk = bytes.fromhex('7FB9E0E687ADA1EEBF7ECFE2F21E73EBDB51A7D450948DFE8D76D7F2D1007671') | ||
msg = bytes.fromhex('F95466D086770E689964664219266FE5ED215C92AE20BAB5C9D79ADDDDF3C0CF') | ||
|
||
expected = fromhex_all([ | ||
'68537CC5234E505BD14061F8DA9E90C220A181855FD8BDB7F127BB12403B4D3B', | ||
'2DF67BFFF18E3DE797E13C6475C963048138DAEC5CB20A357CECA7C8424295EA', | ||
'0D5B651E6DE34A29A12DE7A8B4183B4AE6A7F7FBE15CDCAFA4A3D1BCAABC7517', | ||
]) | ||
|
||
pk = bytes_from_point(point_mul(G, int_from_bytes(sk))) | ||
|
||
assert partial_sign(secnonce, sk, aggnonce, [pk, X[0], X[1]], msg) == expected[0] | ||
assert partial_sign(secnonce, sk, aggnonce, [X[0], pk, X[1]], msg) == expected[1] | ||
assert partial_sign(secnonce, sk, aggnonce, [X[0], X[1], pk], msg) == expected[2] | ||
|
||
def test_sign_and_verify_random(iters): | ||
for i in range(iters): | ||
sk_1 = secrets.token_bytes(32) | ||
sk_2 = secrets.token_bytes(32) | ||
pk_1 = bytes_from_point(point_mul(G, int_from_bytes(sk_1))) | ||
pk_2 = bytes_from_point(point_mul(G, int_from_bytes(sk_2))) | ||
pubkeys = [pk_1, pk_2] | ||
|
||
secnonce_1, pubnonce_1 = nonce_gen() | ||
secnonce_2, pubnonce_2 = nonce_gen() | ||
pubnonces = [pubnonce_1, pubnonce_2] | ||
aggnonce = nonce_agg(pubnonces) | ||
|
||
msg = secrets.token_bytes(32) | ||
|
||
psig = partial_sign(secnonce_1, sk_1, aggnonce, pubkeys, msg) | ||
assert partial_sig_verify(psig, pubnonces, pubkeys, msg, 0) | ||
|
||
# Wrong signer index | ||
assert not partial_sig_verify(psig, pubnonces, pubkeys, msg, 1) | ||
|
||
# Wrong message | ||
assert not partial_sig_verify(psig, pubnonces, pubkeys, secrets.token_bytes(32), 0) | ||
|
||
if __name__ == '__main__': | ||
test_key_agg_vectors() | ||
test_partial_sign_vectors() | ||
test_sign_and_verify_random(4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters