diff --git a/src/Makefile.am b/src/Makefile.am index c1c931b71c..16d5d90613 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -99,6 +99,7 @@ BITCOIN_CORE_H = \ init.h \ key.h \ keystore.h \ + keytree.h \ leveldbwrapper.h \ limitedmap.h \ main.h \ @@ -126,6 +127,7 @@ BITCOIN_CORE_H = \ streams.h \ sync.h \ threadsafety.h \ + thresholdtree.h \ timedata.h \ tinyformat.h \ txdb.h \ @@ -246,6 +248,7 @@ libbitcoin_common_a_SOURCES = \ hash.cpp \ key.cpp \ keystore.cpp \ + keytree.cpp \ merkleblock.cpp \ netbase.cpp \ pow.cpp \ diff --git a/src/Makefile.test.include b/src/Makefile.test.include index 1c947a4881..41017b65ba 100644 --- a/src/Makefile.test.include +++ b/src/Makefile.test.include @@ -41,6 +41,7 @@ BITCOIN_TESTS =\ test/getarg_tests.cpp \ test/hash_tests.cpp \ test/key_tests.cpp \ + test/keytree_tests.cpp \ test/main_tests.cpp \ test/mempool_tests.cpp \ test/miner_tests.cpp \ diff --git a/src/core_io.h b/src/core_io.h index 907aac9325..de1daa07da 100644 --- a/src/core_io.h +++ b/src/core_io.h @@ -13,6 +13,7 @@ class CScript; class CTransaction; class uint256; class UniValue; +class KeyTree; // core_read.cpp extern CScript ParseScript(std::string s); @@ -21,9 +22,11 @@ extern bool DecodeHexBlk(CBlock&, const std::string& strHexBlk); extern uint256 ParseHashUV(const UniValue& v, const std::string& strName); extern uint256 ParseHashStr(const std::string&, const std::string& strName); extern std::vector ParseHexUV(const UniValue& v, const std::string& strName); +extern bool ParseKeyTree(const std::string &s, KeyTree& tree); // core_write.cpp extern std::string FormatScript(const CScript& script); +extern std::string FormatKeyTree(const KeyTree& keytree); extern std::string EncodeHexTx(const CTransaction& tx); extern std::string EncodeHexBlock(const CBlock& block); extern void ScriptPubKeyToUniv(const CScript& scriptPubKey, diff --git a/src/core_read.cpp b/src/core_read.cpp index 1d6bd25609..5301234528 100644 --- a/src/core_read.cpp +++ b/src/core_read.cpp @@ -4,6 +4,7 @@ #include "core_io.h" +#include "keytree.h" #include "primitives/block.h" #include "primitives/transaction.h" #include "script/script.h" @@ -155,3 +156,86 @@ vector ParseHexUV(const UniValue& v, const string& strName) throw runtime_error(strName+" must be hexadecimal string (not '"+strHex+"')"); return ParseHex(strHex); } + +static bool ParseKeyTreeNode(const std::string &s, size_t &pos, KeyTreeNode& tree); +static bool ParseKeyTreeCall(const std::string &s, size_t &pos, unsigned long* num, std::vector& children) +{ + if (s.size() == pos) return false; + if (s[pos] != '(') return false; + pos++; + int count = 0; + if (num) { + const char *ptr = &s[pos]; + char *eptr = NULL; + *num = strtoul(ptr, &eptr, 10); + if (eptr == ptr) return false; + pos += eptr - ptr; + count++; + } + while (true) { + if (count) { + if (pos == s.size()) return false; + if (s[pos] == /*(*/')') { + pos++; + return true; + } + if (s[pos] != ',') return false; + pos++; + } + children.push_back(KeyTreeNode()); + if (!ParseKeyTreeNode(s, pos, children.back())) return false; + count++; + } +} + +static bool ParseKeyTreeNode(const std::string &s, size_t &pos, KeyTreeNode& tree) +{ + while (pos < s.size() && isspace(s[pos])) pos++; + if (s.size() >= pos + 66 && IsHex(s.substr(pos, 66))) { + std::vector data = ParseHex(s.substr(pos, 66)); + tree.leaf.Set(data.begin(), data.end()); + pos += 66; + return tree.leaf.IsFullyValid(); + } + if (s.size() >= pos + 2 && s.substr(pos, 2) == "OR") { + pos += 2; + if (!ParseKeyTreeCall(s, pos, NULL, tree.children)) return false; + if (tree.children.size() < 2) return false; + tree.threshold = 1; + return true; + } + if (s.size() >= pos + 3 && s.substr(pos, 3) == "AND") { + pos += 3; + if (!ParseKeyTreeCall(s, pos, NULL, tree.children)) return false; + if (tree.children.size() < 2) return false; + tree.threshold = tree.children.size(); + return true; + } + if (s.size() >= pos + 9 && s.substr(pos, 9) == "THRESHOLD") { + pos += 9; + unsigned long num; + if (!ParseKeyTreeCall(s, pos, &num, tree.children)) return false; + if (tree.children.size() < 2) return false; + tree.threshold = num; + if (tree.threshold <= 1) return false; + if (tree.threshold >= tree.children.size()) return false; + return true; + } + return false; +} + +bool ParseKeyTree(const std::string &s, KeyTree& tree) +{ + size_t pos = 0; + if (!ParseKeyTreeNode(s, pos, tree.root)) return false; + if (pos != s.size()) return false; + uint64_t count = 0; + tree.hash = GetMerkleRoot(&tree.root, &count); + int levels = 0; + while (count > 1) { + count = (count + 1) >> 1; + levels++; + } + tree.levels = levels; + return true; +} diff --git a/src/core_write.cpp b/src/core_write.cpp index c3982dfa00..592cb47583 100644 --- a/src/core_write.cpp +++ b/src/core_write.cpp @@ -5,6 +5,7 @@ #include "core_io.h" #include "base58.h" +#include "keytree.h" #include "primitives/transaction.h" #include "script/script.h" #include "script/standard.h" @@ -143,3 +144,29 @@ void TxToUniv(const CTransaction& tx, const uint256& hashBlock, UniValue& entry) entry.pushKV("hex", EncodeHexTx(tx)); // the hex-encoded transaction. used the name "hex" to be consistent with the verbose output of "getrawtransaction". } + +static std::string FormatKeyTreeNode(const KeyTreeNode& tree) +{ + if (tree.threshold == 0) { + return HexStr(tree.leaf.begin(), tree.leaf.end()); + } + std::string ret; + if (tree.threshold == 1) { + ret = "OR("; + } else if (tree.threshold == tree.children.size()) { + ret = "AND("; + } else { + ret = strprintf("THRESHOLD(%i,"/*)*/, tree.threshold); + } + for (size_t i = 0; i < tree.children.size(); i++) { + if (i) ret += ","; + ret += FormatKeyTreeNode(tree.children[i]); + } + ret += ")"; + return ret; +} + +std::string FormatKeyTree(const KeyTree& tree) +{ + return FormatKeyTreeNode(tree.root); +} diff --git a/src/key.cpp b/src/key.cpp index aa85e5b8bf..86cc00558e 100644 --- a/src/key.cpp +++ b/src/key.cpp @@ -144,6 +144,81 @@ bool CKey::Derive(CKey& keyChild, unsigned char ccChild[32], unsigned int nChild return ret; } +bool CKey::PartialSigningNonce(const uint256& hash, std::vector& pubnonceout) const { + if (!fValid) + return false; + secp256k1_pubkey_t pubnonce; + unsigned char secnonce[32]; + LockObject(secnonce); + int ret = secp256k1_schnorr_generate_nonce_pair(secp256k1_context, hash.begin(), begin(), secp256k1_nonce_function_rfc6979, NULL, &pubnonce, secnonce); + UnlockObject(secnonce); + if (!ret) + return false; + pubnonceout.resize(33 + 64); + int publen = 33; + secp256k1_ec_pubkey_serialize(secp256k1_context, &pubnonceout[0], &publen, &pubnonce, true); + // Sign the hash + pubnonce with a full signature, to prove possession of the corresponding private key. + uint256 hash2; + CSHA256().Write(hash.begin(), 32).Write(&pubnonceout[0], 33).Finalize(hash2.begin()); + return secp256k1_schnorr_sign(secp256k1_context, hash2.begin(), &pubnonceout[33], begin(), secp256k1_nonce_function_rfc6979, NULL); +} + +static bool CombinePubNonces(const uint256& hash, const std::vector >& pubnonces, const std::vector& pubkeys, secp256k1_pubkey_t& out) { + bool ret = pubnonces.size() > 0; + ret = ret && (pubnonces.size() == pubkeys.size()); + std::vector parsed_pubnonces; + std::vector parsed_pubnonce_pointers; + parsed_pubnonces.reserve(pubnonces.size()); + parsed_pubnonce_pointers.reserve(pubnonces.size()); + std::vector::const_iterator pit = pubkeys.begin(); + for (std::vector >::const_iterator it = pubnonces.begin(); it != pubnonces.end(); ++it, ++pit) { + secp256k1_pubkey_t other_pubnonce; + ret = ret && (it->size() == 33 + 64); + ret = ret && secp256k1_ec_pubkey_parse(secp256k1_context, &other_pubnonce, &(*it)[0], 33); + // Verify the signature on the pubnonce. + uint256 hash2; + secp256k1_pubkey_t pubkey; + CSHA256().Write(hash.begin(), 32).Write(&(*it)[0], 33).Finalize(hash2.begin()); + ret = ret && secp256k1_ec_pubkey_parse(secp256k1_context, &pubkey, &(*pit)[0], pit->size()); + ret = ret && secp256k1_schnorr_verify(secp256k1_context, hash2.begin(), &(*it)[33], &pubkey); + if (ret) { + parsed_pubnonces.push_back(other_pubnonce); + parsed_pubnonce_pointers.push_back(&parsed_pubnonces.back()); + } + } + return (ret && secp256k1_ec_pubkey_combine(secp256k1_context, &out, parsed_pubnonces.size(), &parsed_pubnonce_pointers[0])); +} + +bool CKey::PartialSign(const uint256& hash, const std::vector >& other_pubnonces_in, const std::vector& other_pubkeys_in, const std::vector& my_pubnonce_in, std::vector& vchPartialSig) const { + if (!fValid) + return false; + secp256k1_pubkey_t pubnonce, my_pubnonce, other_pubnonces; + unsigned char secnonce[32]; + LockObject(secnonce); + int ret = my_pubnonce_in.size() == 33 + 64 && secp256k1_ec_pubkey_parse(secp256k1_context, &my_pubnonce, &my_pubnonce_in[0], 33); + ret = ret && secp256k1_schnorr_generate_nonce_pair(secp256k1_context, hash.begin(), begin(), secp256k1_nonce_function_rfc6979, NULL, &pubnonce, secnonce); + ret = ret && memcmp(&pubnonce, &my_pubnonce, sizeof(pubnonce)) == 0; + ret = ret && CombinePubNonces(hash, other_pubnonces_in, other_pubkeys_in, other_pubnonces); + if (ret) { + vchPartialSig.resize(64); + ret = secp256k1_schnorr_partial_sign(secp256k1_context, hash.begin(), &vchPartialSig[0], begin(), secnonce, &other_pubnonces); + } + UnlockObject(secnonce); + return ret; +} + +bool CombinePartialSignatures(const std::vector >& input, std::vector& output) { + std::vector sig_pointers; + sig_pointers.reserve(input.size()); + for (std::vector >::const_iterator it = input.begin(); it != input.end(); ++it) { + if (it->size() != 64) return false; + sig_pointers.push_back(&((*it)[0])); + } + output.resize(64); + bool ret = !!secp256k1_schnorr_partial_combine(secp256k1_context, &output[0], sig_pointers.size(), &sig_pointers[0]); + return ret; +} + bool CExtKey::Derive(CExtKey &out, unsigned int nChild) const { out.nDepth = nDepth + 1; CKeyID id = key.GetPubKey().GetID(); @@ -205,7 +280,7 @@ bool ECC_InitSanityCheck() { void ECC_Start() { assert(secp256k1_context == NULL); - secp256k1_context_t *ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN); + secp256k1_context_t *ctx = secp256k1_context_create(SECP256K1_CONTEXT_SIGN | SECP256K1_CONTEXT_VERIFY); assert(ctx != NULL); { diff --git a/src/key.h b/src/key.h index b261d49858..c677da5229 100644 --- a/src/key.h +++ b/src/key.h @@ -133,6 +133,16 @@ class CKey */ bool Sign(const uint256& hash, std::vector& vchSig, uint32_t test_case = 0) const; + /** + * Create a public nonce to communicate to other parties for creating a multisignature. + */ + bool PartialSigningNonce(const uint256& hash, std::vector& pubnonce) const; + + /** + * Create a part of a multisignature given all parties' public nonces. + */ + bool PartialSign(const uint256& hash, const std::vector >& other_pubnonces_in, const std::vector& other_pubkeys_in, const std::vector& my_pubnonce_in, std::vector& vchPartialSig) const; + /** * Create a compact signature (65 bytes), which allows reconstructing the used public key. * The format is one header byte, followed by two times 32 bytes for the serialized r and s values. @@ -178,6 +188,9 @@ struct CExtKey { void SetMaster(const unsigned char* seed, unsigned int nSeedLen); }; +/** Combine multiple partial signatures into a full one. */ +bool CombinePartialSignatures(const std::vector >& input, std::vector& output); + /** Initialize the elliptic curve support. May not be called twice without calling ECC_Stop first. */ void ECC_Start(void); diff --git a/src/keystore.cpp b/src/keystore.cpp index 879f099720..be12b7f45b 100644 --- a/src/keystore.cpp +++ b/src/keystore.cpp @@ -86,3 +86,28 @@ bool CBasicKeyStore::HaveWatchOnly() const LOCK(cs_KeyStore); return (!setWatchOnly.empty()); } + +bool CBasicKeyStore::AddKeyTree(const KeyTree &tree) +{ + LOCK(cs_KeyStore); + mapKeyTrees[tree.hash] = tree; + return true; +} + +bool CBasicKeyStore::HaveKeyTree(const uint256& hash) const +{ + LOCK(cs_KeyStore); + return mapKeyTrees.count(hash) > 0; +} + +bool CBasicKeyStore::GetKeyTree(const uint256& hash, KeyTree& tree) const +{ + LOCK(cs_KeyStore); + KeyTreeMap::const_iterator mi = mapKeyTrees.find(hash); + if (mi != mapKeyTrees.end()) + { + tree = (*mi).second; + return true; + } + return false; +} diff --git a/src/keystore.h b/src/keystore.h index 60502e9a29..cc2b8978eb 100644 --- a/src/keystore.h +++ b/src/keystore.h @@ -7,6 +7,7 @@ #define BITCOIN_KEYSTORE_H #include "key.h" +#include "keytree.h" #include "pubkey.h" #include "sync.h" @@ -40,6 +41,11 @@ class CKeyStore virtual bool HaveCScript(const CScriptID &hash) const =0; virtual bool GetCScript(const CScriptID &hash, CScript& redeemScriptOut) const =0; + //! Support for pubkey trees + virtual bool AddKeyTree(const KeyTree &tree) =0; + virtual bool HaveKeyTree(const uint256& hash) const =0; + virtual bool GetKeyTree(const uint256& hash, KeyTree& tree) const =0; + //! Support for Watch-only addresses virtual bool AddWatchOnly(const CScript &dest) =0; virtual bool RemoveWatchOnly(const CScript &dest) =0; @@ -48,6 +54,7 @@ class CKeyStore }; typedef std::map KeyMap; +typedef std::map KeyTreeMap; typedef std::map ScriptMap; typedef std::set WatchOnlySet; @@ -56,6 +63,7 @@ class CBasicKeyStore : public CKeyStore { protected: KeyMap mapKeys; + KeyTreeMap mapKeyTrees; ScriptMap mapScripts; WatchOnlySet setWatchOnly; @@ -100,6 +108,10 @@ class CBasicKeyStore : public CKeyStore virtual bool HaveCScript(const CScriptID &hash) const; virtual bool GetCScript(const CScriptID &hash, CScript& redeemScriptOut) const; + virtual bool AddKeyTree(const KeyTree &tree); + virtual bool HaveKeyTree(const uint256& hash) const; + virtual bool GetKeyTree(const uint256& hash, KeyTree& tree) const; + virtual bool AddWatchOnly(const CScript &dest); virtual bool RemoveWatchOnly(const CScript &dest); virtual bool HaveWatchOnly(const CScript &dest) const; diff --git a/src/keytree.cpp b/src/keytree.cpp new file mode 100644 index 0000000000..d00507956f --- /dev/null +++ b/src/keytree.cpp @@ -0,0 +1,327 @@ +// Copyright (c) 2015 Pieter Wuille +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include + +#include "hash.h" +#include "keytree.h" +#include "pubkey.h" +#include "crypto/sha256.h" + +#include + +extern secp256k1_context_t* secp256k1_bitcoin_verify_context; + +namespace { + +/* This implements a constant-space merkle root/branch calculator, limited to 2^63 leaves. */ +template +void MerkleComputation(LeafSource& source, uint256* proot, std::vector* pbranch) { + if (pbranch) pbranch->clear(); + if (!source.Valid()) { + if (proot) *proot = uint256(); + return; + } + // count is the number of leaves processed so far. + uint64_t count = 0; + // inner is an array of eagerly computed subtree hashes, indexed by tree + // level (0 being the leaves). + // For example, when count is 25 (11001 in binary), inner[4] is the hash of + // the first 16 leaves, inner[3] of the next 8 leaves, and inner[0] equal to + // the last leaf. The other inner entries are undefined. + uint256 inner[64]; + // Which position in inner is a hash that depends on the matching leaf. + int matchlevel = -1; + // First process all leaves into 'inner' values. + while (source.Valid()) { + uint256 h = source.Get(); + // If there has been no match before, check whether the current leaf is. + bool matchh = (matchlevel == -1) && source.Match(); + count++; + source.Increment(); + int level; + // For each of the lower bits in count that are 0, do 1 step. Each + // corresponds to an inner value that existed before processing the + // current leaf, and each needs a hash to combine it. + for (level = 0; !(count & (((uint64_t)1) << level)); level++) { + if (pbranch) { + if (matchh) { + pbranch->push_back(inner[level]); + } else if (matchlevel == level) { + pbranch->push_back(h); + matchh = true; + } + } + CSHA256().Write(inner[level].begin(), 32).Write(h.begin(), 32).Finalize(h.begin()); + } + // Store the resulting hash at inner position level. + inner[level] = h; + if (matchh) { + matchlevel = level; + } + } + // Do a final 'sweep' over the rightmost branch of the tree to process + // odd levels, and reduce everything to a single top value. + // Level is the level (counted from the bottom) up to which we've sweeped. + int level = 0; + // As long as bit number level in count is zero, skip it. It means there + // is nothing left at this level. + while (!(count & (((uint64_t)1) << level))) { + level++; + } + uint256 h = inner[level]; + bool matchh = matchlevel == level; + static const unsigned char one[1] = {1}; + while (count != (((uint64_t)1) << level)) { + // If we reach this point, h is an inner value that is not the top. + // We combine it with 1 (special rule for odd levels) to produce a higher level one. + if (pbranch && matchh) { + pbranch->push_back(uint256()); + } + CSHA256().Write(h.begin(), 32).Write(one, 1).Finalize(h.begin()); + // Increment count to the value it would have if two entries at this + // level had existed. + count += (((uint64_t)1) << level); + level++; + // And propagate the result upwards accordingly. + while (!(count & (((uint64_t)1) << level))) { + if (pbranch) { + if (matchh) { + pbranch->push_back(inner[level]); + } else if (matchlevel == level) { + pbranch->push_back(h); + matchh = true; + } + } + CSHA256().Write(inner[level].begin(), 32).Write(h.begin(), 32).Finalize(h.begin()); + level++; + } + } + // Return result. + if (proot) *proot = h; +} + +// Lazily constructed wrapper around KeyTreeNode, suitable for use with ThresholdTreeIterator, +// with cached parsed pubkeys. +struct InnerKeyTree { + const KeyTreeNode* node; + bool processed; + bool cached_pubkey; + bool cached_match; + bool match; + secp256k1_pubkey_t pubkey; + std::vector children; + + InnerKeyTree() : node(NULL), processed(true), cached_pubkey(true), cached_match(true), match(false) {} + InnerKeyTree(const KeyTreeNode* node_) : node(node_), processed(false), cached_pubkey(false), cached_match(false), match(false) {} + + void Process() { + if (processed) return; + processed = true; + children.resize(node->children.size()); + for (size_t i = 0; i < children.size(); i++) { + children[i] = InnerKeyTree(&node->children[i]); + } + } + + bool IsLeaf() const { return node->threshold == 0; } + uint32_t Threshold() const { return node->threshold; } + uint32_t Children() { return node->children.size(); } + InnerKeyTree* Child(int pos) { Process(); return &children[pos]; } + + const secp256k1_pubkey_t* GetParsedPubKey() { + if (!cached_pubkey) { + assert(secp256k1_ec_pubkey_parse(secp256k1_bitcoin_verify_context, &pubkey, node->leaf.begin(), node->leaf.size())); + cached_pubkey = true; + } + return &pubkey; + } + + const CPubKey* GetPubKey() { + return &node->leaf; + } + + bool Matches(KeyTreeFilter& filter) { + if (!cached_match) { + match = filter(node->leaf); + cached_match = true; + } + return match; + } +}; + +// An accumulator for ThresholdTreeIterator taking leaf InnerKeyTree*'s, and supporting the computation of hashes of the combined solution pubkey sets. +struct CombinedKeyHashingAccumulator { + std::vector leaves; + uint32_t first_non_match_before; + KeyTreeFilter* filter; + + CombinedKeyHashingAccumulator(KeyTreeFilter* filter_ = NULL) : first_non_match_before(0), filter(filter_) {} + + void Push(InnerKeyTree* x) { + leaves.push_back(x); + if ((!first_non_match_before) && filter && !x->Matches(*filter)) { + first_non_match_before = leaves.size(); + } + } + + void Pop(InnerKeyTree* x) { + (void)x; + if (first_non_match_before == leaves.size()) { + first_non_match_before = 0; + } + leaves.pop_back(); + } + + uint256 ComputeHash() const { + uint256 ret; + if (leaves.size() == 1) { + // When there is just a single key, do not parse and reserialize. + const CPubKey* pubkey = leaves[0]->GetPubKey(); + CSHA256().Write(pubkey->begin(), pubkey->size()).Finalize(ret.begin()); + } else { + unsigned char pubkey[33]; + int pubkeylen = 33; + std::vector keys; + keys.resize(leaves.size()); + for (size_t i = 0; i < leaves.size(); i++) { + keys[i] = leaves[i]->GetParsedPubKey(); + } + secp256k1_pubkey_t key; + assert(secp256k1_ec_pubkey_combine(secp256k1_bitcoin_verify_context, &key, keys.size(), &keys[0])); + secp256k1_ec_pubkey_serialize(secp256k1_bitcoin_verify_context, pubkey, &pubkeylen, &key, 1); + CSHA256().Write(pubkey, 33).Finalize(ret.begin()); + } + return ret; + } + + bool Matches() { + return (filter != NULL && first_non_match_before == 0); + } + + std::vector GetMatch() const { + std::vector pubkeys; + pubkeys.reserve(leaves.size()); + for (std::vector::const_iterator it = leaves.begin(); it != leaves.end(); it++) { + pubkeys.push_back((*it)->node->leaf); + } + return pubkeys; + } +}; + +// Wrapper around a ThresholdTreeIterator for InnerKeyTrees, producing hashes of the combined pubkeys of the solution sets. +struct InnerKeyTreeIteratorLeafSource { + InnerKeyTree root; + ThresholdTreeIterator iter; + uint64_t count; + uint64_t matchpos; + bool hadmatch; + std::vector* match; + + InnerKeyTreeIteratorLeafSource(const KeyTreeNode *tree, KeyTreeFilter* filter, std::vector* matchout) : root(tree), iter(&root, filter), count(0), matchpos(0), hadmatch(false), match(matchout) {} + + bool Valid() const { + return iter.Valid(); + } + + uint256 Get() const { + return iter.GetAccumulator()->ComputeHash(); + } + + inline bool Match() { + if (iter.GetAccumulator()->Matches()) { + matchpos = count; + hadmatch = true; + if (match) { + *match = iter.GetAccumulator()->GetMatch(); + } + return true; + } + return false; + } + + void Increment() { + iter.Increment(); + ++count; + } + + bool HadMatch() const { return hadmatch; } + uint64_t GetCount() const { return count; } + uint64_t GetMatchPosition() const { return matchpos; } +}; + +} + +uint64_t GetCombinations(const KeyTreeNode* tree) { + InnerKeyTree innertree(tree); + return CountCombinations(&innertree); +} + +uint256 GetMerkleRoot(const KeyTreeNode* tree, uint64_t* count) { + uint256 root; + InnerKeyTreeIteratorLeafSource source(tree, NULL, NULL); + MerkleComputation(source, &root, NULL); + if (count) *count = source.count; + return root; +} + +bool GetMerkleBranch(const KeyTreeNode* tree, KeyTreeFilter* filter, uint256* root, uint64_t* count, uint64_t* matchpos, std::vector* branch, std::vector* pubkeys) { + InnerKeyTreeIteratorLeafSource source(tree, filter, pubkeys); + MerkleComputation(source, root, branch); + if (!source.HadMatch()) return false; + if (count) *count = source.GetCount(); + if (matchpos) *matchpos = source.GetMatchPosition(); + return true; +} + +uint256 GetMerkleRootFromBranch(const uint256& leaf, std::vector& branch, uint64_t position) { + uint256 res = leaf; + for (std::vector::const_iterator it = branch.begin(); it != branch.end(); it++) { + if (position & 1) { + CSHA256().Write(it->begin(), 32).Write(res.begin(), 32).Finalize(res.begin()); + } else if (*it != uint256()) { + CSHA256().Write(res.begin(), 32).Write(it->begin(), 32).Finalize(res.begin()); + } else { + static const unsigned char one[1] = {1}; + CSHA256().Write(res.begin(), 32).Write(one, 1).Finalize(res.begin()); + } + position >>= 1; + } + if (position) { + return uint256(); + } + return res; +} + +bool HasMatch(const KeyTreeNode* node, KeyTreeFilter* filter) { + if (node->threshold == 0) { + return (*filter)(node->leaf); + } + uint32_t matches = 0; + for (std::vector::const_iterator it = node->children.begin(); it != node->children.end(); ++it) { + if (HasMatch(&(*it), filter)) { + matches++; + if (matches == node->threshold) { + return true; + } + } + } + return false; +} + +static void GetAllLeavesRecurse(const KeyTreeNode* tree, std::set& ret) { + if (tree->threshold == 0) { + ret.insert(tree->leaf); + } else { + for (std::vector::const_iterator it = tree->children.begin(); it != tree->children.end(); ++it) { + GetAllLeavesRecurse(&*it, ret); + } + } +} + +std::set GetAllLeaves(const KeyTreeNode* tree) { + std::set ret; + GetAllLeavesRecurse(tree, ret); + return ret; +} diff --git a/src/keytree.h b/src/keytree.h new file mode 100644 index 0000000000..7b18a587fa --- /dev/null +++ b/src/keytree.h @@ -0,0 +1,92 @@ +// Copyright (c) 2015 Pieter Wuille +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_KEYTREE_H +#define BITCOIN_KEYTREE_H + +#include "thresholdtree.h" +#include "pubkey.h" + +/* A node in a threshold tree of public keys. + * Such a tree represents a set of valid combinations of public keys in a compact way. + * + * For example: + * thr=2 + * / | \ + * / | \ + * leafA thr=1 thr=2 + * / \ / \ + * leafB leafC leafD leafE + * + * Corresponds to the combinations (A,B), (A,C), (A,D,E), (B,D,E), (C,D,E). + */ +struct KeyTreeNode { + //! The public key this node requires signing with (only used in leaf node, which have threshold == 0). + CPubKey leaf; + + //! The child nodes of this node (only used when this is an inner node, which have threshold > 0). + std::vector children; + + //! The number of child nodes that need to be satisfied before this node is considered satisfied. + uint32_t threshold; + + ADD_SERIALIZE_METHODS; + + KeyTreeNode() : threshold(0) {} + + template + inline void SerializationOp(Stream& s, Operation ser_action, int nType, int nVersion) + { + READWRITE(VARINT(threshold)); + if (threshold) { + READWRITE(children); + } else { + READWRITE(leaf); + } + } + + friend bool operator==(const KeyTreeNode &a, const KeyTreeNode &b) { + if (a.threshold != b.threshold) return false; + if (a.threshold) { + return (a.children == b.children); + } else { + return (a.leaf == b.leaf); + } + } + +}; + +/* Wrapper around an entire key tree, with precomputed Merkle root and level count. */ +struct KeyTree { + uint256 hash; + int levels; + KeyTreeNode root; + + ADD_SERIALIZE_METHODS; + + KeyTree() : levels(0) {} + + template + inline void SerializationOp(Stream& s, Operation ser_action, int nType, int nVersion) + { + READWRITE(hash); + READWRITE(VARINT(levels)); + READWRITE(root); + } +}; + +struct KeyTreeFilter { + virtual bool operator()(const CPubKey& pubkey) = 0; + virtual ~KeyTreeFilter() {} +}; + + +uint64_t GetCombinations(const KeyTreeNode* tree); +uint256 GetMerkleRoot(const KeyTreeNode* tree, uint64_t* count); +bool GetMerkleBranch(const KeyTreeNode* tree, KeyTreeFilter* filter, uint256* root, uint64_t* count, uint64_t* matchpos, std::vector* branch, std::vector* pubkeys = NULL); +uint256 GetMerkleRootFromBranch(const uint256& leaf, std::vector& branch, uint64_t position); +bool HasMatch(const KeyTreeNode* tree, KeyTreeFilter* filter); +std::set GetAllLeaves(const KeyTreeNode* tree); + +#endif diff --git a/src/pubkey.cpp b/src/pubkey.cpp index a5805bf99f..11304732dd 100644 --- a/src/pubkey.cpp +++ b/src/pubkey.cpp @@ -88,6 +88,31 @@ bool CPubKey::Derive(CPubKey& pubkeyChild, unsigned char ccChild[32], unsigned i return true; } +bool CombinePubKeys(const std::vector& pubkeys, CPubKey& combination) { + bool ret = pubkeys.size() > 0; + std::vector parsed_pubkeys; + std::vector parsed_pubkey_pointers; + parsed_pubkeys.reserve(pubkeys.size()); + parsed_pubkey_pointers.reserve(pubkeys.size()); + for (std::vector::const_iterator it = pubkeys.begin(); it != pubkeys.end(); ++it) { + secp256k1_pubkey_t other_pubkey; + ret = ret && secp256k1_ec_pubkey_parse(secp256k1_context, &other_pubkey, it->begin(), it->size()); + if (ret) { + parsed_pubkeys.push_back(other_pubkey); + parsed_pubkey_pointers.push_back(&parsed_pubkeys.back()); + } + } + secp256k1_pubkey_t out; + ret = ret && secp256k1_ec_pubkey_combine(secp256k1_context, &out, parsed_pubkeys.size(), &parsed_pubkey_pointers[0]); + unsigned char outv[33]; + int outlen = 33; + ret = ret && secp256k1_ec_pubkey_serialize(secp256k1_context, outv, &outlen, &out, 1); + if (ret) { + combination.Set(&outv[0], &outv[33]); + } + return ret; +} + void CExtPubKey::Encode(unsigned char code[74]) const { code[0] = nDepth; memcpy(code+1, vchFingerprint, 4); diff --git a/src/pubkey.h b/src/pubkey.h index 2ce4a0ff3d..fd38aa5eb3 100644 --- a/src/pubkey.h +++ b/src/pubkey.h @@ -185,6 +185,8 @@ class CPubKey bool Derive(CPubKey& pubkeyChild, unsigned char ccChild[32], unsigned int nChild, const unsigned char cc[32]) const; }; +bool CombinePubKeys(const std::vector& pubkeys, CPubKey& combination); + struct CExtPubKey { unsigned char nDepth; unsigned char vchFingerprint[4]; diff --git a/src/rpcmisc.cpp b/src/rpcmisc.cpp index f6ee836fa9..d79d5072c3 100644 --- a/src/rpcmisc.cpp +++ b/src/rpcmisc.cpp @@ -4,6 +4,7 @@ // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include "base58.h" +#include "core_io.h" #include "clientversion.h" #include "init.h" #include "main.h" @@ -322,6 +323,70 @@ Value createmultisig(const Array& params, bool fHelp) return result; } +/** + * Used by addmultisigaddress / createmultisig: + */ +void _createtreesig_redeemScript(const Array& params, KeyTree& tree, CScript* pscript) +{ + const std::string& desc = params[0].get_str(); + + // Parse tree + if (!ParseKeyTree(desc, tree)) + throw runtime_error("Cannot parse key tree description"); + + CScript result = GetScriptForTree(tree); + + if (result.size() > MAX_SCRIPT_ELEMENT_SIZE) + throw runtime_error( + strprintf("redeemScript exceeds size limit: %d > %d", result.size(), MAX_SCRIPT_ELEMENT_SIZE)); + + if (pscript) { + pscript->swap(result); + } +} + +Value createtreesig(const Array& params, bool fHelp) +{ + if (fHelp || params.size() < 1 || params.size() > 1) + { + string msg = "createmultisig \"description\"\n" + "\nCreates a tree multi-signature address.\n" + + "\nArguments:\n" + "1. \"description\" (string, required) A description for the allowed combinations\n" + + "\nResult:\n" + "{\n" + " \"address\":\"multisigaddress\", (string) The value of the new multisig address.\n" + " \"redeemScript\":\"script\" (string) The string value of the hex-encoded redemption script.\n" + " \"merkleroot\":\"merkleroot\", (string) The hex-encoded merkle root of the resulting pubkey tree.\n" + " \"combinations\": number, (numeric) The number of combinations this description allows.\n" + " \"serialization\": \"hex\", (string) Hex-encoded full tree representation.\n" + "}\n" + ; + throw runtime_error(msg); + } + + // Construct using pay-to-script-hash: + KeyTree tree; + CScript inner; + _createtreesig_redeemScript(params, tree, &inner); + CScriptID innerID(inner); + CBitcoinAddress address(innerID); + + CDataStream ss(SER_DISK, CLIENT_VERSION); + ss << tree; + + Object result; + result.push_back(Pair("address", address.ToString())); + result.push_back(Pair("redeemScript", HexStr(inner.begin(), inner.end()))); + result.push_back(Pair("merkleroot", tree.hash.ToString())); + result.push_back(Pair("combinations", (int64_t)GetCombinations(&tree.root))); + result.push_back(Pair("serialization", HexStr(ss.begin(), ss.end()))); + + return result; +} + Value verifymessage(const Array& params, bool fHelp) { if (fHelp || params.size() != 3) diff --git a/src/rpcrawtransaction.cpp b/src/rpcrawtransaction.cpp index 0493b623e3..088f680e4e 100644 --- a/src/rpcrawtransaction.cpp +++ b/src/rpcrawtransaction.cpp @@ -821,6 +821,7 @@ Value signrawtransaction(const Array& params, bool fHelp) " \"nValue\": \"hex\", (string, required) The output's value commitment\n" " \"scriptPubKey\": \"hex\", (string, required) script key\n" " \"redeemScript\": \"hex\" (string, required for P2SH) redeem script\n" + " \"keytree\": \"hex\" (string, required for keytree) serialized key tree\n" " }\n" " ,...\n" " ]\n" @@ -921,6 +922,7 @@ Value signrawtransaction(const Array& params, bool fHelp) RPCTypeCheck(prevOut, map_list_of("txid", str_type)("vout", int_type)("nValue", str_type)("scriptPubKey", str_type)); + uint256 txid = ParseHashO(prevOut, "txid"); int nOut = find_value(prevOut, "vout").get_int(); @@ -954,13 +956,21 @@ Value signrawtransaction(const Array& params, bool fHelp) // if redeemScript given and not using the local wallet (private keys // given), add redeemScript to the tempKeystore so it can be signed: if (fGivenKeys && (scriptPubKey.IsPayToScriptHash() || scriptPubKey.IsWithdrawOutput())) { - RPCTypeCheck(prevOut, map_list_of("txid", str_type)("vout", int_type)("nValue", str_type)("scriptPubKey", str_type)("redeemScript",str_type)); + RPCTypeCheck(prevOut, map_list_of("txid", str_type)("vout", int_type)("nValue", str_type)("scriptPubKey", str_type)("redeemScript",str_type)("keytree",str_type)); Value v = find_value(prevOut, "redeemScript"); if (!(v == Value::null)) { vector rsData(ParseHexV(v, "redeemScript")); CScript redeemScript(rsData.begin(), rsData.end()); tempKeystore.AddCScript(redeemScript); } + Value vv = find_value(prevOut, "keytree"); + if (!(v == Value::null)) { + std::string a = vv.get_str(); + KeyTree tree; + CDataStream ss(ParseHex(a), SER_DISK, CLIENT_VERSION); + ss >> tree; + tempKeystore.AddKeyTree(tree); + } } } } @@ -1001,7 +1011,6 @@ Value signrawtransaction(const Array& params, bool fHelp) } const CScript& prevPubKey = coins->vout[txin.prevout.n].scriptPubKey; - txin.scriptSig.clear(); // Only sign SIGHASH_SINGLE if there's a corresponding output: if (!fHashSingle || (i < mergedTx.vout.size())) SignSignature(keystore, prevPubKey, coins->vout[txin.prevout.n].nValue, mergedTx, i, nHashType); diff --git a/src/rpcserver.cpp b/src/rpcserver.cpp index 9c0562f161..1ddaee1afd 100644 --- a/src/rpcserver.cpp +++ b/src/rpcserver.cpp @@ -305,6 +305,7 @@ static const CRPCCommand vRPCCommands[] = /* Utility functions */ { "util", "createmultisig", &createmultisig, true, true , false }, + { "util", "createtreesig", &createtreesig, true, true , false }, { "util", "validateaddress", &validateaddress, true, false, false }, /* uses wallet if enabled */ { "util", "verifymessage", &verifymessage, true, false, false }, { "util", "estimatefee", &estimatefee, true, true, false }, @@ -318,6 +319,7 @@ static const CRPCCommand vRPCCommands[] = #ifdef ENABLE_WALLET /* Wallet */ { "wallet", "addmultisigaddress", &addmultisigaddress, true, false, true }, + { "wallet", "addtreesigaddress", &addtreesigaddress, true, false, true }, { "wallet", "backupwallet", &backupwallet, true, false, true }, { "wallet", "dumpprivkey", &dumpprivkey, true, false, true }, { "wallet", "dumpwallet", &dumpwallet, true, false, true }, diff --git a/src/rpcserver.h b/src/rpcserver.h index 7052378651..5ab46bb1d6 100644 --- a/src/rpcserver.h +++ b/src/rpcserver.h @@ -181,7 +181,9 @@ extern json_spirit::Value movecmd(const json_spirit::Array& params, bool fHelp); extern json_spirit::Value sendfrom(const json_spirit::Array& params, bool fHelp); extern json_spirit::Value sendmany(const json_spirit::Array& params, bool fHelp); extern json_spirit::Value addmultisigaddress(const json_spirit::Array& params, bool fHelp); +extern json_spirit::Value addtreesigaddress(const json_spirit::Array& params, bool fHelp); extern json_spirit::Value createmultisig(const json_spirit::Array& params, bool fHelp); +extern json_spirit::Value createtreesig(const json_spirit::Array& params, bool fHelp); extern json_spirit::Value listreceivedbyaddress(const json_spirit::Array& params, bool fHelp); extern json_spirit::Value listreceivedbyaccount(const json_spirit::Array& params, bool fHelp); extern json_spirit::Value listtransactions(const json_spirit::Array& params, bool fHelp); diff --git a/src/rpcwallet.cpp b/src/rpcwallet.cpp index ff54548b83..6374d0784e 100644 --- a/src/rpcwallet.cpp +++ b/src/rpcwallet.cpp @@ -929,6 +929,40 @@ Value sendmany(const Array& params, bool fHelp) // Defined in rpcmisc.cpp extern CScript _createmultisig_redeemScript(const Array& params); +void _createtreesig_redeemScript(const Array& params, KeyTree& tree, CScript* pscript); + +Value addtreesigaddress(const Array& params, bool fHelp) +{ + if (fHelp || params.size() < 1 || params.size() > 2) + { + string msg = "addmultisigaddress [\"key\",...] ( \"account\" )\n" + "\nCreates a tree multi-signature address to the wallet.\n" + + "\nArguments:\n" + "1. \"description\" (string, required) A description for the allowed combinations\n" + "2. \"account\" (string, optional) An account to assign the addresses to.\n" + + "\nResult:\n" + "\"bitcoinaddress\" (string) A bitcoin address associated with the keys.\n" + ; + throw runtime_error(msg); + } + + string strAccount; + if (params.size() > 2) + strAccount = AccountFromValue(params[2]); + + // Construct using pay-to-script-hash: + KeyTree tree; + CScript inner; + _createtreesig_redeemScript(params, tree, &inner); + CScriptID innerID(inner); + pwalletMain->AddKeyTree(tree); + pwalletMain->AddCScript(inner); + + pwalletMain->SetAddressBook(innerID, strAccount, "send"); + return CBitcoinAddress(innerID).ToString(); +} Value addmultisigaddress(const Array& params, bool fHelp) { diff --git a/src/script/generic.hpp b/src/script/generic.hpp index b3d2ce1c15..93c4d1cec8 100644 --- a/src/script/generic.hpp +++ b/src/script/generic.hpp @@ -41,6 +41,24 @@ class SimpleSignatureCreator : public BaseSignatureCreator return false; return key.Sign(checker.hash, vchSig); } + bool CreatePartialSigNonce(std::vector& vchSigNonce, const CKeyID& keyid, const CScript& scriptCode) const + { + CKey key; + if (!keystore.GetKey(keyid, key)) + return false; + return key.PartialSigningNonce(checker.hash, vchSigNonce); + } + bool CreatePartialSig(std::vector& vchSig, const CKeyID& keyid, const CScript& scriptCode, const std::vector& my_pubnonce, const std::vector >& other_pubnonces, const std::vector& other_pubkeys) const + { + CKey key; + if (!keystore.GetKey(keyid, key)) + return false; + return key.PartialSign(checker.hash, other_pubnonces, other_pubkeys, my_pubnonce, vchSig); + } + bool CombinePartialSigs(std::vector& out, const std::vector >& ins) const + { + return CombinePartialSignatures(ins, out); + } }; template diff --git a/src/script/script.h b/src/script/script.h index 322fa8082f..4c37697d12 100644 --- a/src/script/script.h +++ b/src/script/script.h @@ -558,6 +558,20 @@ class CScript : public std::vector return (opcodetype)(OP_1+n-1); } + static bool DecodeInt(opcodetype opcode, const std::vector& data, int64_t* ret) + { + if (opcode == OP_0) { + *ret = 0; + return true; + } else if (opcode >= OP_1 && opcode <= OP_16) { + *ret = (int)opcode - (int)(OP_1 - 1); + return true; + } else if (data.size() < 5) { + *ret = CScriptNum(data, false).getint64(); + } + return false; + } + int FindAndDelete(const CScript& b) { int nFound = 0; diff --git a/src/script/sign.cpp b/src/script/sign.cpp index 7375f9ac26..91055b5f2b 100644 --- a/src/script/sign.cpp +++ b/src/script/sign.cpp @@ -6,6 +6,7 @@ #include "script/sign.h" #include "primitives/transaction.h" +#include "pubkey.h" #include "key.h" #include "keystore.h" #include "script/standard.h" @@ -32,6 +33,34 @@ bool TransactionSignatureCreator::CreateSig(std::vector& vchSig, return true; } +bool TransactionSignatureCreator::CreatePartialSigNonce(std::vector& nonce, const CKeyID& keyid, const CScript& scriptCode) const +{ + CKey key; + if (!keystore.GetKey(keyid, key)) + return false; + + uint256 hash = SignatureHash(scriptCode, checker.GetValueIn(), *txTo, nIn, nHashType); + return key.PartialSigningNonce(hash, nonce); +} + +bool TransactionSignatureCreator::CreatePartialSig(std::vector& vchSig, const CKeyID& keyid, const CScript& scriptCode, const std::vector& my_pubnonce, const std::vector >& other_pubnonces, const std::vector& other_pubkeys) const +{ + CKey key; + if (!keystore.GetKey(keyid, key)) + return false; + + uint256 hash = SignatureHash(scriptCode, checker.GetValueIn(), *txTo, nIn, nHashType); + return key.PartialSign(hash, other_pubnonces, other_pubkeys, my_pubnonce, vchSig); +} + +bool TransactionSignatureCreator::CombinePartialSigs(std::vector& out, const std::vector >& ins) const +{ + if (!CombinePartialSignatures(ins, out)) + return false; + out.push_back((unsigned char)nHashType); + return true; +} + static bool Sign1(const CKeyID& address, const BaseSignatureCreator& creator, const CScript& scriptCode, CScript& scriptSigRet) { vector vchSig; @@ -55,6 +84,195 @@ static bool SignN(const vector& multisigdata, const BaseSignatureCreato return nSigned==nRequired; } +static bool ParsePartialTreeSig(const std::vector& parsedSig, std::map& mapNonces, std::map, valtype>& mapSigs) { + for (std::vector::const_iterator it = parsedSig.begin(); it != parsedSig.end(); ++it) { + const valtype& data = *it; + CPubKey key, key2; + if (data.size() == 0) continue; + if (data[0] == 0 && data.size() > 34) { + key.Set(data.begin() + 1, data.begin() + 34); + if (!key.IsValid()) return false; + mapNonces[key] = valtype(data.begin() + 34, data.end()); + } else if (data[0] == 1 && data.size() > 67) { + key.Set(data.begin() + 1, data.begin() + 34); + if (!key.IsValid()) return false; + key2.Set(data.begin() + 34, data.begin() + 67); + if (!key2.IsValid()) return false; + mapSigs[std::make_pair(key, key2)] = valtype(data.begin() + 67, data.end()); + } else { + continue; + } + } + return true; +} + +static std::vector SerializePartialTreeSig(const std::map& mapNonces, const std::map, valtype>& mapSigs) { + std::vector ret; + for (std::map::const_iterator it = mapNonces.begin(); it != mapNonces.end(); ++it) { + valtype add; + add.push_back((unsigned char)0); + add.insert(add.end(), it->first.begin(), it->first.end()); + add.insert(add.end(), it->second.begin(), it->second.end()); + ret.push_back(add); + } + for (std::map, valtype>::const_iterator it = mapSigs.begin(); it != mapSigs.end(); ++it) { + valtype add; + add.push_back((unsigned char)1); + add.insert(add.end(), it->first.first.begin(), it->first.first.end()); + add.insert(add.end(), it->first.second.begin(), it->first.second.end()); + add.insert(add.end(), it->second.begin(), it->second.end()); + ret.push_back(add); + } + return ret; +} + +namespace +{ +struct SignKeyTreeFilter : public KeyTreeFilter +{ + const std::map* nonces; + + SignKeyTreeFilter(const std::map* nonces_) : nonces(nonces_) {} + + bool operator()(const CPubKey& pubkey) + { + return nonces->count(pubkey) > 0; + } +}; +} + +static CScript PushAll(const vector& values); + +static bool SignTreeSig(const BaseSignatureCreator* creator, const std::vector& vSolutions, const CScript* scriptCode, CScript& scriptSig, const std::vector* scriptSigIn1, const std::vector* scriptSigIn2) +{ + // Combine existing nonces and signatures + std::map mapNonces; + std::map, valtype> mapSigs; + if (scriptSigIn2) { + ParsePartialTreeSig(*scriptSigIn2, mapNonces, mapSigs); + } + if (scriptSigIn1) { + ParsePartialTreeSig(*scriptSigIn1, mapNonces, mapSigs); + } + + // Retrieve the tree. + uint256 hash; + memcpy(hash.begin(), &vSolutions[1][0], 32); + bool havetree = false; + KeyTree tree; + if (creator && creator->KeyStore().GetKeyTree(hash, tree)) { + if (tree.hash != hash) { + assert(!"Tree hash and requested hash mismatch"); + } + havetree = true; + } + + // Create nonces where possible + if (havetree && scriptCode) { + std::set leaves = GetAllLeaves(&tree.root); + for (std::set::const_iterator it = leaves.begin(); it != leaves.end(); ++it) { + if (mapNonces.count(*it) == 0) { + valtype nonce; + if (creator->CreatePartialSigNonce(nonce, it->GetID(), *scriptCode)) { + mapNonces[*it] = nonce; + } + } + } + } + + // Find a permissible subset of pubkeys + SignKeyTreeFilter filter(&mapNonces); + std::vector branch; + uint64_t position; + std::vector pubkeys; + uint256 root; + if (havetree && HasMatch(&tree.root, &filter)) { + if (!GetMerkleBranch(&tree.root, &filter, &root, NULL, &position, &branch, &pubkeys)) { + assert(!"No merkle branch found despite matching filter"); + } + if (root != hash) { + assert(!"Recomputed merkle root differs from merkle root in scriptPubKey"); + } + CPubKey combinedpubkey; + if (!CombinePubKeys(pubkeys, combinedpubkey)) { + assert(!"Cannot create combined pubkey"); + } + std::vector allsigs; + bool haveall = true; + valtype combinedsig; + if (pubkeys.size() == 1) { + if (!creator->CreateSig(combinedsig, pubkeys[0].GetID(), *scriptCode)) { + haveall = false; + } + } else { + for (std::vector::const_iterator it = pubkeys.begin(); it != pubkeys.end(); ++it) { + std::pair id = make_pair(combinedpubkey, *it); + if (mapSigs.count(id) == 0) { + // Try signing for this missing pubkey + CKey key; + if (scriptCode && creator->KeyStore().GetKey(it->GetID(), key)) { + valtype sig; + valtype my_pubnonce; + std::vector other_pubnonces; + other_pubnonces.reserve(pubkeys.size() - 1); + std::vector other_pubkeys; + other_pubkeys.reserve(pubkeys.size() - 1); + for (std::vector::const_iterator it2 = pubkeys.begin(); it2 != pubkeys.end(); ++it2) { + if (it == it2) { + my_pubnonce = mapNonces[*it2]; + } else { + other_pubnonces.push_back(mapNonces[*it2]); + other_pubkeys.push_back(*it2); + } + } + if (!creator->CreatePartialSig(sig, it->GetID(), *scriptCode, my_pubnonce, other_pubnonces, other_pubkeys)) { + assert(!"Failed to create partial signature, despite valid secret key"); + } + mapSigs[id] = sig; + allsigs.push_back(sig); + } else { + haveall = false; + } + } else { + allsigs.push_back(mapSigs[id]); + } + } + if (haveall && !creator->CombinePartialSigs(combinedsig, allsigs)) { + assert(!"Cannot combine partial signatures"); + } + } + if (haveall) { + scriptSig.clear(); + scriptSig << combinedsig; + scriptSig << ToByteVector(combinedpubkey); + CScript walk; + for (std::vector::const_iterator it = branch.begin(); it != branch.end(); ++it) { + CScript add; + if (*it == uint256()) { + add << OP_1; + } else { + valtype merklenode; + merklenode.resize(32); + memcpy(&merklenode[0], it->begin(), 32); + add << merklenode; + } + if (position & 1) { + add << OP_0; + } else { + add << OP_1; + } + walk = add + walk; + position >>= 1; + } + scriptSig += walk; + return true; + } + } + + scriptSig = PushAll(SerializePartialTreeSig(mapNonces, mapSigs)); + return false; +} + /** * Sign scriptPubKey using signature made with creator. * Signatures are returned in scriptSigRet (or returns false if scriptPubKey can't be signed), @@ -64,6 +282,8 @@ static bool SignN(const vector& multisigdata, const BaseSignatureCreato static bool SignStep(const BaseSignatureCreator& creator, const CScript& scriptPubKey, CScript& scriptSigRet, txnouttype& whichTypeRet) { + vector stack; + EvalScript(stack, scriptSigRet, SCRIPT_VERIFY_STRICTENC, BaseSignatureChecker()); scriptSigRet.clear(); vector vSolutions; @@ -105,6 +325,10 @@ static bool SignStep(const BaseSignatureCreator& creator, const CScript& scriptP case TX_MULTISIG: scriptSigRet << OP_0; // workaround CHECKMULTISIG bug return (SignN(vSolutions, creator, scriptPubKey, scriptSigRet)); + + case TX_TREESIG: + return SignTreeSig(&creator, vSolutions, &scriptPubKey, scriptSigRet, &stack, NULL); + case TX_TRUE: return true; } @@ -113,7 +337,11 @@ static bool SignStep(const BaseSignatureCreator& creator, const CScript& scriptP bool ProduceSignature(const BaseSignatureCreator& creator, const CScript& fromPubKey, CScript& scriptSig) { + if (VerifyScript(scriptSig, fromPubKey, STANDARD_SCRIPT_VERIFY_FLAGS, creator.Checker())) + return true; + txnouttype whichType; + CScript scriptSigOriginal = scriptSig; if (!SignStep(creator, fromPubKey, scriptSig, whichType)) return false; @@ -123,6 +351,7 @@ bool ProduceSignature(const BaseSignatureCreator& creator, const CScript& fromPu // the final scriptSig is the signatures from that // and then the serialized subscript: CScript subscript = scriptSig; + scriptSig = scriptSigOriginal; txnouttype subType; bool fSolved = @@ -273,6 +502,10 @@ static CScript CombineSignatures(const CScript& scriptPubKey, const BaseSignatur } case TX_MULTISIG: return CombineMultisig(scriptPubKey, checker, vSolutions, sigs1, sigs2); + case TX_TREESIG: + CScript result; + SignTreeSig(NULL, vSolutions, NULL, result, &sigs1, &sigs2); + return result; } return CScript(); @@ -288,6 +521,13 @@ CScript CombineSignatures(const CScript& scriptPubKey, const CTransaction& txTo, CScript CombineSignatures(const CScript& scriptPubKey, const BaseSignatureChecker& checker, const CScript& scriptSig1, const CScript& scriptSig2) { + if (VerifyScript(scriptSig1, scriptPubKey, STANDARD_SCRIPT_VERIFY_FLAGS, checker)) { + return scriptSig1; + } + if (VerifyScript(scriptSig2, scriptPubKey, STANDARD_SCRIPT_VERIFY_FLAGS, checker)) { + return scriptSig2; + } + txnouttype txType; vector > vSolutions; Solver(scriptPubKey, txType, vSolutions); diff --git a/src/script/sign.h b/src/script/sign.h index c90f016ee5..c64a98f5c2 100644 --- a/src/script/sign.h +++ b/src/script/sign.h @@ -28,6 +28,11 @@ class BaseSignatureCreator { /** Create a singular (non-script) signature. */ virtual bool CreateSig(std::vector& vchSig, const CKeyID& keyid, const CScript& scriptCode) const =0; + + /** Deal with partial signatures to produce a multisignature. */ + virtual bool CreatePartialSigNonce(std::vector& vchSigNonce, const CKeyID& keyid, const CScript& scriptCode) const =0; + virtual bool CreatePartialSig(std::vector& vchSig, const CKeyID& keyid, const CScript& scriptCode, const std::vector& my_pubnonce, const std::vector >& other_pubnonces, const std::vector& other_pubkeys) const =0; + virtual bool CombinePartialSigs(std::vector& out, const std::vector >& ins) const =0; }; /** A signature creator for transactions. */ @@ -41,6 +46,9 @@ class TransactionSignatureCreator : public BaseSignatureCreator { TransactionSignatureCreator(const CKeyStore& keystoreIn, const CTransaction* txToIn, unsigned int nInIn, const CTxOutValue& nValueIn, int nHashTypeIn=SIGHASH_ALL); const BaseSignatureChecker& Checker() const { return checker; } bool CreateSig(std::vector& vchSig, const CKeyID& keyid, const CScript& scriptCode) const; + bool CreatePartialSigNonce(std::vector& vchSigNonce, const CKeyID& keyid, const CScript& scriptCode) const; + bool CreatePartialSig(std::vector& vchSig, const CKeyID& keyid, const CScript& scriptCode, const std::vector& my_pubnonce, const std::vector >& other_pubnonces, const std::vector& other_pubkeys) const; + bool CombinePartialSigs(std::vector& out, const std::vector >& ins) const; }; /** Produce a script signature using a generic signature creator. */ diff --git a/src/script/standard.cpp b/src/script/standard.cpp index 6818afc476..d6a013a04a 100644 --- a/src/script/standard.cpp +++ b/src/script/standard.cpp @@ -29,6 +29,7 @@ const char* GetTxnOutputType(txnouttype t) case TX_PUBKEYHASH: return "pubkeyhash"; case TX_SCRIPTHASH: return "scripthash"; case TX_MULTISIG: return "multisig"; + case TX_TREESIG: return "treesig"; case TX_NULL_DATA: return "nulldata"; case TX_WITHDRAW_LOCK: return "withdraw"; case TX_WITHDRAW_OUT: return "withdrawout"; @@ -37,6 +38,35 @@ const char* GetTxnOutputType(txnouttype t) return NULL; } +bool CheckTreeSig(const CScript& scriptPubKey, int& levelsout, uint256& merklerootout) +{ + CScript::const_iterator pc = scriptPubKey.begin(); + opcodetype opcode; + std::vector data; + int64_t levels = -1; + if (!scriptPubKey.GetOp(pc, opcode, data) || !CScript::DecodeInt(opcode, data, &levels)) return false; + if (levels % 2) return false; + levels >>= 1; + if (levels > 32) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_PICK) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_SHA256) return false; + for (int64_t i = 0; i < levels; i++) { + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_SWAP) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_IF) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_SWAP) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_ENDIF) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_CAT) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_SHA256) return false; + } + if (!scriptPubKey.GetOp(pc, opcode, data) || data.size() != 32) return false; + memcpy(merklerootout.begin(), &data[0], 32); + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_EQUALVERIFY) return false; + if (!scriptPubKey.GetOp(pc, opcode) || opcode != OP_CHECKSIG) return false; + if (pc != scriptPubKey.end()) return false; + levelsout = levels; + return true; +} + /** * Return public keys or hashes from scriptPubKey, for 'standard' transaction types. */ @@ -172,6 +202,17 @@ bool Solver(const CScript& scriptPubKey, txnouttype& typeRet, vector(1, (unsigned char)levels); + vSolutionsRet[1] = std::vector(merkleroot.begin(), merkleroot.end()); + return true; + } + vSolutionsRet.clear(); typeRet = TX_NONSTANDARD; return false; @@ -194,6 +235,8 @@ int ScriptSigArgsExpected(txnouttype t, const std::vector& keys) script << CScript::EncodeOP_N(keys.size()) << OP_CHECKMULTISIG; return script; } + +CScript GetScriptForTree(const KeyTree& tree) +{ + std::vector merkleroot; + merkleroot.resize(32); + memcpy(&merkleroot[0], tree.hash.begin(), 32); + int levels = tree.levels; + + CScript script; + script << (levels * 2) << OP_PICK << OP_SHA256; + for (int i = 0; i < levels; i++) { + script << OP_SWAP << OP_IF << OP_SWAP << OP_ENDIF << OP_CAT << OP_SHA256; + } + script << merkleroot << OP_EQUALVERIFY << OP_CHECKSIG; + + return script; +} diff --git a/src/script/standard.h b/src/script/standard.h index 05e736c316..3f0dbbe661 100644 --- a/src/script/standard.h +++ b/src/script/standard.h @@ -6,6 +6,7 @@ #ifndef BITCOIN_SCRIPT_STANDARD_H #define BITCOIN_SCRIPT_STANDARD_H +#include "keytree.h" #include "script/interpreter.h" #include "uint256.h" @@ -62,6 +63,7 @@ enum txnouttype TX_PUBKEY, TX_PUBKEYHASH, TX_SCRIPTHASH, + TX_TREESIG, TX_MULTISIG, TX_NULL_DATA, TX_WITHDRAW_LOCK, @@ -94,5 +96,6 @@ bool ExtractDestinations(const CScript& scriptPubKey, txnouttype& typeRet, std:: CScript GetScriptForDestination(const CTxDestination& dest); CScript GetScriptForMultisig(int nRequired, const std::vector& keys); +CScript GetScriptForTree(const KeyTree& tree); #endif // BITCOIN_SCRIPT_STANDARD_H diff --git a/src/test/keytree_tests.cpp b/src/test/keytree_tests.cpp new file mode 100644 index 0000000000..101633ffb6 --- /dev/null +++ b/src/test/keytree_tests.cpp @@ -0,0 +1,584 @@ +// Copyright (c) 2015 Pieter Wuille +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#include "util.h" +#include + +#include + +#include "core_io.h" +#include "thresholdtree.h" +#include "keytree.h" +#include "key.h" +#include "keystore.h" +#include "script/script.h" +#include "script/sign.h" +#include "script/standard.h" + +BOOST_AUTO_TEST_SUITE(keytree_tests) + +struct SimpleTree { + std::vector children; + int leaf; + int threshold; + + inline int Children() const { return children.size(); } + inline int Threshold() const { return threshold; } + inline bool IsLeaf() const { return children.size() == 0; } + inline SimpleTree* Child(int pos) { return &children[pos]; } +}; + +struct TestAccumulator { + std::vector x; + void Push(SimpleTree* v) { + x.push_back(v->leaf); + } + void Pop(SimpleTree* v) { + BOOST_CHECK(x.size() >= 1); + BOOST_CHECK_EQUAL(x.back(), v->leaf); + x.pop_back(); + } + + void Test(int a, int b, int c, int d) { + BOOST_CHECK_EQUAL(x.size(), 4); + BOOST_CHECK_EQUAL(x[0], a); + BOOST_CHECK_EQUAL(x[1], b); + BOOST_CHECK_EQUAL(x[2], c); + BOOST_CHECK_EQUAL(x[3], d); + } +}; + +BOOST_AUTO_TEST_CASE(thresholdtree_test) +{ + SimpleTree root; + root.threshold = 2; + root.children.resize(3); + root.children[0].threshold = 2; + root.children[0].children.resize(3); + root.children[0].children[0].leaf = 11; + root.children[0].children[1].leaf = 12; + root.children[0].children[2].leaf = 13; + root.children[1].threshold = 2; + root.children[1].children.resize(3); + root.children[1].children[0].leaf = 21; + root.children[1].children[1].leaf = 22; + root.children[1].children[2].leaf = 23; + root.children[2].threshold = 2; + root.children[2].children.resize(3); + root.children[2].children[0].leaf = 31; + root.children[2].children[1].leaf = 32; + root.children[2].children[2].leaf = 33; + + ThresholdTreeIterator iter(&root); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 12, 21, 22); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 12, 21, 23); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 12, 22, 23); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 13, 21, 22); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 13, 21, 23); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 13, 22, 23); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(12, 13, 21, 22); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(12, 13, 21, 23); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(12, 13, 22, 23); iter.Increment(); + + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 12, 31, 32); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 12, 31, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 12, 32, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 13, 31, 32); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 13, 31, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(11, 13, 32, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(12, 13, 31, 32); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(12, 13, 31, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(12, 13, 32, 33); iter.Increment(); + + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(21, 22, 31, 32); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(21, 22, 31, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(21, 22, 32, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(21, 23, 31, 32); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(21, 23, 31, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(21, 23, 32, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(22, 23, 31, 32); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(22, 23, 31, 33); iter.Increment(); + BOOST_CHECK(iter.Valid()); iter.GetAccumulator()->Test(22, 23, 32, 33); iter.Increment(); + + BOOST_CHECK(!iter.Valid()); + + int64_t count = CountCombinations(&root); + BOOST_CHECK_EQUAL(count, 27); +} + +CKey SecKeyNum(uint32_t x) { + unsigned char data[32] = {0}; + data[31] = x & 0xFF; + data[30] = (x >> 8) & 0xFF; + data[29] = (x >> 16) & 0xFF; + data[28] = (x >> 24) & 0xFF; + CKey key; + key.Set(&data[0], &data[32], true); + return key; +} + +CPubKey KeyNum(uint32_t x) { + CKey key = SecKeyNum(x); + return key.GetPubKey(); +} + +uint256 KeyNumHash(uint32_t i) { + CPubKey x = KeyNum(i); + assert(x.size() == 33); + uint256 ret; + CSHA256().Write(x.begin(), 33).Finalize(ret.begin()); + return ret; +} + +uint256 CombineTwoHashes(const uint256& x, const uint256& y) { + uint256 ret; + CSHA256().Write(x.begin(), 32).Write(y.begin(), 32).Finalize(ret.begin()); + return ret; +} + +uint256 CombineOneHash(const uint256& x) { + uint256 ret; + static const unsigned char one[1] = {1}; + CSHA256().Write(x.begin(), 32).Write(one, 1).Finalize(ret.begin()); + return ret; +} + +struct SetKeyTreeFilter : public KeyTreeFilter +{ + std::set valid; + + bool operator()(const CPubKey& key) { + return valid.count(key); + } + + SetKeyTreeFilter() {} + + SetKeyTreeFilter(const CPubKey& key) { + valid.insert(key); + } +}; + +BOOST_AUTO_TEST_CASE(keytree_test0) +{ + /* 1 + */ + + static const std::string formatted = + "0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"; + KeyTree tree; + BOOST_CHECK(ParseKeyTree(formatted, tree)); + BOOST_CHECK_EQUAL(FormatKeyTree(tree), formatted); + BOOST_CHECK_EQUAL(tree.levels, 0); + + KeyTreeNode root; + root.leaf = KeyNum(1); + uint64_t count; + uint256 rootA = GetMerkleRoot(&root, &count); + BOOST_CHECK_EQUAL(count, 1); + BOOST_CHECK(rootA == tree.hash); + BOOST_CHECK(tree.root == root); + + uint256 rootB = KeyNumHash(1); + BOOST_CHECK(rootA == rootB); + + uint256 rootC; + uint64_t countC; + uint64_t positionC; + std::vector branch; + std::vector pubkeys; + SetKeyTreeFilter filterkey(KeyNum(1)); + BOOST_CHECK(GetMerkleBranch(&root, &filterkey, &rootC, &countC, &positionC, &branch, &pubkeys)); + BOOST_CHECK(rootA == rootC); + BOOST_CHECK_EQUAL(count, countC); + BOOST_CHECK_EQUAL(positionC, 0); + BOOST_CHECK_EQUAL(branch.size(), 0); + BOOST_CHECK_EQUAL(pubkeys.size(), 1); + BOOST_CHECK(pubkeys[0] == KeyNum(1)); + + SetKeyTreeFilter filterkeynone; + BOOST_CHECK(GetMerkleBranch(&root, &filterkeynone, &rootC, &countC, &positionC, &branch) == false); + + CMutableTransaction mtxFrom; + mtxFrom.vout.resize(1); + mtxFrom.vout[0].scriptPubKey = GetScriptForTree(tree); + CTransaction txFrom(mtxFrom); + CMutableTransaction mtxTo; + mtxTo.vin.resize(1); + mtxTo.vin[0].prevout.hash = txFrom.GetHash(); + mtxTo.vin[0].prevout.n = 0; + CBasicKeyStore keystore; + keystore.AddKeyTree(tree); + keystore.AddKey(SecKeyNum(1)); + BOOST_CHECK(SignSignature(keystore, txFrom, mtxTo, 0)); + BOOST_CHECK_EQUAL(mtxTo.vin[0].scriptSig.ToString(), + "6438800535962d52309f384d257e30bdeeaf895d49f80dfa73787426997324c77adb3d" + "94c052c0cf7e52783d7f8346255e5f4bac42f8e13eeb4f587175b93c1e01 0279be667" + "ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"); +} + +BOOST_AUTO_TEST_CASE(keytree_test1) +{ + /* 1-of + * / | \ + * / | \ + * 1 2 4 + */ + + static const std::string formatted = + "OR(0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798," + "02c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5,02e" + "493dbf1c10d80f3581e4904930b1404cc6c13900ee0758474fa94abe8c4cd13)"; + KeyTree tree; + BOOST_CHECK(ParseKeyTree(formatted, tree)); + BOOST_CHECK_EQUAL(FormatKeyTree(tree), formatted); + BOOST_CHECK_EQUAL(tree.levels, 2); + + KeyTreeNode root; + root.threshold = 1; + root.children.resize(3); + for (int i = 0; i < 3; i++) { + root.children[i].leaf = KeyNum(1 << i); + } + uint64_t count; + uint256 rootA = GetMerkleRoot(&root, &count); + BOOST_CHECK_EQUAL(count, 3); + BOOST_CHECK(rootA == tree.hash); + BOOST_CHECK(tree.root == root); + + uint256 hashes0[12] = { + KeyNumHash(1), KeyNumHash(2), KeyNumHash(4) + }; + uint256 hashes1[6] = { + CombineTwoHashes(hashes0[0], hashes0[1]), CombineOneHash(hashes0[2]), + }; + uint256 rootB = CombineTwoHashes(hashes1[0], hashes1[1]); + BOOST_CHECK(rootA == rootB); + + uint256 rootC; + uint64_t countC; + uint64_t positionC; + std::vector branch; + std::vector pubkeys; + SetKeyTreeFilter filterkey4(KeyNum(4)); + BOOST_CHECK(GetMerkleBranch(&root, &filterkey4, &rootC, &countC, &positionC, &branch, &pubkeys)); + BOOST_CHECK(rootA == rootC); + BOOST_CHECK_EQUAL(count, countC); + BOOST_CHECK_EQUAL(positionC, 2); + BOOST_CHECK_EQUAL(branch.size(), 2); + BOOST_CHECK(branch[0] == uint256()); + BOOST_CHECK(branch[1] == hashes1[0]); + BOOST_CHECK_EQUAL(pubkeys.size(), 1); + BOOST_CHECK(pubkeys[0] == KeyNum(4)); + + SetKeyTreeFilter filterkeynone; + BOOST_CHECK(GetMerkleBranch(&root, &filterkeynone, &rootC, &countC, &positionC, &branch) == false); + + CMutableTransaction mtxFrom; + mtxFrom.vout.resize(1); + mtxFrom.vout[0].scriptPubKey = GetScriptForTree(tree); + CTransaction txFrom(mtxFrom); + CMutableTransaction mtxTo; + mtxTo.vin.resize(1); + mtxTo.vin[0].prevout.hash = txFrom.GetHash(); + mtxTo.vin[0].prevout.n = 0; + CBasicKeyStore keystore; + keystore.AddKeyTree(tree); + keystore.AddKey(SecKeyNum(1)); + BOOST_CHECK(SignSignature(keystore, txFrom, mtxTo, 0)); + BOOST_CHECK_EQUAL(mtxTo.vin[0].scriptSig.ToString(), + "4ea7e71ddab23f1e9991bfe7579dd777cf68699a6c4809d69ab197f1a4709ced177ec2" + "bf9864b147c85faab20783830be9fb9507ee5649cc14f13314e098d1b001 0279be667" + "ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798 ed057686095e" + "5e9df917f11adf776e27199f80c34bdf106d76d91d8ec6c41cde 1 b1c9938f01121e1" + "59887ac2c8d393a22e4476ff8212de13fe1939de2a236f0a7 1"); +} + +BOOST_AUTO_TEST_CASE(keytree_test2) +{ + /* 2-of + * / | \ + * / | \ + * 1-of 1-of 1-of + * / \ / \ / \ + * 1 2 4 8 16 32 + */ + static const std::string formatted = + "THRESHOLD(2,OR(0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f28" + "15b16f81798,02c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b9" + "5c709ee5),OR(02e493dbf1c10d80f3581e4904930b1404cc6c13900ee0758474fa94a" + "be8c4cd13,022f01e5e15cca351daff3843fb70f3c2f0a1bdd05e5af888a67784ef3e1" + "0a2a01),OR(03e60fce93b59e9ec53011aabc21c23e97b2a31369b87a5ae9c44ee89e2" + "a6dec0a,03d30199d74fb5a22d47b6e054e2f378cedacffcb89904a61d75d0dbd40714" + "3e65))"; + KeyTree tree; + BOOST_CHECK(ParseKeyTree(formatted, tree)); + BOOST_CHECK_EQUAL(FormatKeyTree(tree), formatted); + BOOST_CHECK_EQUAL(tree.levels, 4); + + KeyTreeNode root; + root.threshold = 2; + root.children.resize(3); + for (int i = 0; i < 3; i++) { + root.children[i].threshold = 1; + root.children[i].children.resize(2); + for (int j = 0; j < 2; j++) { + root.children[i].children[j].leaf = KeyNum(1 << (2*i + j)); + } + } + uint64_t count; + uint256 rootA = GetMerkleRoot(&root, &count); + BOOST_CHECK_EQUAL(count, 12); + BOOST_CHECK(rootA == tree.hash); + BOOST_CHECK(tree.root == root); + + uint256 hashes0[12] = { + KeyNumHash(5), KeyNumHash(9), KeyNumHash(6), KeyNumHash(10), + KeyNumHash(17), KeyNumHash(33), KeyNumHash(18), KeyNumHash(34), + KeyNumHash(20), KeyNumHash(36), KeyNumHash(24), KeyNumHash(40) + }; + uint256 hashes1[6] = { + CombineTwoHashes(hashes0[0], hashes0[1]), CombineTwoHashes(hashes0[2], hashes0[3]), + CombineTwoHashes(hashes0[4], hashes0[5]), CombineTwoHashes(hashes0[6], hashes0[7]), + CombineTwoHashes(hashes0[8], hashes0[9]), CombineTwoHashes(hashes0[10], hashes0[11]) + }; + uint256 hashes2[3] = { + CombineTwoHashes(hashes1[0], hashes1[1]), CombineTwoHashes(hashes1[2], hashes1[3]), + CombineTwoHashes(hashes1[4], hashes1[5]) + }; + uint256 hashes3[2] = { + CombineTwoHashes(hashes2[0], hashes2[1]), CombineOneHash(hashes2[2]) + }; + uint256 rootB = CombineTwoHashes(hashes3[0], hashes3[1]); + BOOST_CHECK(rootA == rootB); + + SetKeyTreeFilter filterkey; + filterkey.valid.insert(KeyNum(16)); + filterkey.valid.insert(KeyNum(8)); + filterkey.valid.insert(KeyNum(32)); + uint256 rootC; + uint64_t countC; + uint64_t positionC; + std::vector branch; + std::vector pubkeys; + BOOST_CHECK(GetMerkleBranch(&root, &filterkey, &rootC, &countC, &positionC, &branch, &pubkeys)); + BOOST_CHECK(rootA == rootC); + BOOST_CHECK(count == countC); + BOOST_CHECK(positionC == 10); + BOOST_CHECK(branch.size() == 4); + BOOST_CHECK(branch[0] == hashes0[11]); + BOOST_CHECK(branch[1] == hashes1[4]); + BOOST_CHECK(branch[2] == uint256()); + BOOST_CHECK(branch[3] == hashes3[0]); + BOOST_CHECK(pubkeys.size() == 2); + BOOST_CHECK(pubkeys[0] == KeyNum(8)); + BOOST_CHECK(pubkeys[1] == KeyNum(16)); + + CScript redeemScript = GetScriptForTree(tree); + CScriptID p2sh(redeemScript); + CMutableTransaction mtxFrom; + mtxFrom.vout.resize(1); + mtxFrom.vout[0].scriptPubKey = GetScriptForDestination(p2sh); + CTransaction txFrom(mtxFrom); + CMutableTransaction mtxTo; + mtxTo.vin.resize(1); + mtxTo.vin[0].prevout.hash = txFrom.GetHash(); + mtxTo.vin[0].prevout.n = 0; + CBasicKeyStore keystore1; + keystore1.AddKeyTree(tree); + keystore1.AddCScript(redeemScript); + keystore1.AddKey(SecKeyNum(4)); + keystore1.AddKey(SecKeyNum(8)); + BOOST_CHECK(!SignSignature(keystore1, txFrom, mtxTo, 0)); + CBasicKeyStore keystore2; + keystore2.AddKeyTree(tree); + keystore2.AddCScript(redeemScript); + keystore2.AddKey(SecKeyNum(16)); + BOOST_CHECK(!SignSignature(keystore2, txFrom, mtxTo, 0)); + BOOST_CHECK(SignSignature(keystore1, txFrom, mtxTo, 0)); + BOOST_CHECK_EQUAL(mtxTo.vin[0].scriptSig.ToString(), + "24199b18f2148c9521c7ec97f49c5241a3ec47d2a08257f9671683aaee1a8a3bc8c2ff" + "3b9cc77c97ed863565e939494e61623d2537b9458303c340c7a7689c6701 024ce119c" + "96e2fa357200b559b2f7dd5a5f02d5290aff74b03f3e471b273211c97 bd2a91465655" + "067eadb86cf4f660671e63566c71a5328adfbcde342c17d5e450 0 1 1 6e7468b3b7a" + "82215ee3a373a658ce4bc6ad11c53fa9c896f3239650442c1f0a3 1 5417114a2fc7e0" + "f1cd82684c9b0ca852ceb600562a6adeb18cc795c2efdd9c63 1 5879a87c637c687ea" + "87c637c687ea87c637c687ea87c637c687ea820927db000c4bfec1a1f6c07cb13b36f9" + "593300885b5d816f64c269245c02724f588ac"); +} + +BOOST_AUTO_TEST_CASE(keytree_test3) +{ + /* 2-of + * / | \ + * / | \ + * 1 1-of 2-of + * / \ / \ + * 2 4 8 16 + */ + static const std::string formatted = + "THRESHOLD(2,0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b" + "16f81798,OR(02c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b9" + "5c709ee5,02e493dbf1c10d80f3581e4904930b1404cc6c13900ee0758474fa94abe8c" + "4cd13),AND(022f01e5e15cca351daff3843fb70f3c2f0a1bdd05e5af888a67784ef3e" + "10a2a01,03e60fce93b59e9ec53011aabc21c23e97b2a31369b87a5ae9c44ee89e2a6d" + "ec0a))"; + KeyTree tree; + BOOST_CHECK(ParseKeyTree(formatted, tree)); + BOOST_CHECK_EQUAL(FormatKeyTree(tree), formatted); + BOOST_CHECK_EQUAL(tree.levels, 3); + + KeyTreeNode root; + root.threshold = 2; + root.children.resize(3); + root.children[0].leaf = KeyNum(1); + root.children[1].threshold = 1; + root.children[1].children.resize(2); + root.children[1].children[0].leaf = KeyNum(2); + root.children[1].children[1].leaf = KeyNum(4); + root.children[2].threshold = 2; + root.children[2].children.resize(2); + root.children[2].children[0].leaf = KeyNum(8); + root.children[2].children[1].leaf = KeyNum(16); + uint64_t count; + uint256 rootA = GetMerkleRoot(&root, &count); + BOOST_CHECK_EQUAL(count, 5); + BOOST_CHECK(rootA == tree.hash); + BOOST_CHECK(tree.root == root); + + uint256 hashes0[5] = { + KeyNumHash(3), KeyNumHash(5), KeyNumHash(25), KeyNumHash(26), + KeyNumHash(28) + }; + uint256 hashes1[3] = { + CombineTwoHashes(hashes0[0], hashes0[1]), CombineTwoHashes(hashes0[2], hashes0[3]), + CombineOneHash(hashes0[4]), + }; + uint256 hashes2[2] = { + CombineTwoHashes(hashes1[0], hashes1[1]), CombineOneHash(hashes1[2]) + }; + uint256 rootB = CombineTwoHashes(hashes2[0], hashes2[1]); + BOOST_CHECK(rootA == rootB); + + SetKeyTreeFilter filterkey; + filterkey.valid.insert(KeyNum(2)); + filterkey.valid.insert(KeyNum(4)); + filterkey.valid.insert(KeyNum(8)); + filterkey.valid.insert(KeyNum(16)); + uint256 rootC; + uint64_t countC; + uint64_t positionC; + std::vector branch; + std::vector pubkeys; + BOOST_CHECK(GetMerkleBranch(&root, &filterkey, &rootC, &countC, &positionC, &branch, &pubkeys)); + BOOST_CHECK(rootA == rootC); + BOOST_CHECK(count == countC); + BOOST_CHECK(positionC == 3); + BOOST_CHECK(branch.size() == 3); + BOOST_CHECK(branch[0] == hashes0[2]); + BOOST_CHECK(branch[1] == hashes1[0]); + BOOST_CHECK(branch[2] == hashes2[1]); + BOOST_CHECK(pubkeys.size() == 3); + BOOST_CHECK(pubkeys[0] == KeyNum(2)); + BOOST_CHECK(pubkeys[1] == KeyNum(8)); + BOOST_CHECK(pubkeys[2] == KeyNum(16)); +} + +BOOST_AUTO_TEST_CASE(keytree_test4) +{ + /* 3-of + * / | \ + * / | \ + * 1 2-of 8 + * / \ + * 2 4 + */ + + static const std::string formatted = + "AND(0279be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798" + ",AND(02c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee" + "5,02e493dbf1c10d80f3581e4904930b1404cc6c13900ee0758474fa94abe8c4cd13)," + "022f01e5e15cca351daff3843fb70f3c2f0a1bdd05e5af888a67784ef3e10a2a01)"; + KeyTree tree; + BOOST_CHECK(ParseKeyTree(formatted, tree)); + BOOST_CHECK_EQUAL(FormatKeyTree(tree), formatted); + BOOST_CHECK_EQUAL(tree.levels, 0); + + KeyTreeNode root; + root.threshold = 3; + root.children.resize(3); + root.children[0].leaf = KeyNum(1); + root.children[1].threshold = 2; + root.children[1].children.resize(2); + root.children[1].children[0].leaf = KeyNum(2); + root.children[1].children[1].leaf = KeyNum(4); + root.children[2].leaf = KeyNum(8); + + uint64_t count; + uint256 rootA = GetMerkleRoot(&root, &count); + BOOST_CHECK_EQUAL(count, 1); + BOOST_CHECK(rootA == tree.hash); + BOOST_CHECK(tree.root == root); + + uint256 rootB = KeyNumHash(15); + BOOST_CHECK(rootA == rootB); + + uint256 rootC; + uint64_t countC; + uint64_t positionC; + std::vector branch; + std::vector pubkeys; + SetKeyTreeFilter filterkey; + filterkey.valid.insert(KeyNum(1)); + filterkey.valid.insert(KeyNum(2)); + filterkey.valid.insert(KeyNum(8)); + filterkey.valid.insert(KeyNum(4)); + BOOST_CHECK(GetMerkleBranch(&root, &filterkey, &rootC, &countC, &positionC, &branch, &pubkeys)); + BOOST_CHECK(rootA == rootC); + BOOST_CHECK_EQUAL(count, countC); + BOOST_CHECK_EQUAL(positionC, 0); + BOOST_CHECK_EQUAL(branch.size(), 0); + BOOST_CHECK_EQUAL(pubkeys.size(), 4); + BOOST_CHECK(pubkeys[0] == KeyNum(1)); + BOOST_CHECK(pubkeys[1] == KeyNum(2)); + BOOST_CHECK(pubkeys[2] == KeyNum(4)); + BOOST_CHECK(pubkeys[3] == KeyNum(8)); + + SetKeyTreeFilter filterkeynone; + BOOST_CHECK(GetMerkleBranch(&root, &filterkeynone, &rootC, &countC, &positionC, &branch) == false); + + CMutableTransaction mtxFrom; + mtxFrom.vout.resize(1); + mtxFrom.vout[0].scriptPubKey = GetScriptForTree(tree); + CTransaction txFrom(mtxFrom); + CMutableTransaction mtxTo; + mtxTo.vin.resize(1); + mtxTo.vin[0].prevout.hash = txFrom.GetHash(); + mtxTo.vin[0].prevout.n = 0; + + CBasicKeyStore keystore1; + keystore1.AddKeyTree(tree); + keystore1.AddKey(SecKeyNum(1)); + CBasicKeyStore keystore2; + keystore2.AddKeyTree(tree); + keystore2.AddKey(SecKeyNum(2)); + CBasicKeyStore keystore3; + keystore3.AddKeyTree(tree); + keystore3.AddKey(SecKeyNum(4)); + CBasicKeyStore keystore4; + keystore4.AddKeyTree(tree); + keystore4.AddKey(SecKeyNum(8)); + BOOST_CHECK(!SignSignature(keystore1, txFrom, mtxTo, 0)); + BOOST_CHECK(!SignSignature(keystore2, txFrom, mtxTo, 0)); + BOOST_CHECK(!SignSignature(keystore3, txFrom, mtxTo, 0)); + BOOST_CHECK(!SignSignature(keystore4, txFrom, mtxTo, 0)); + BOOST_CHECK(!SignSignature(keystore2, txFrom, mtxTo, 0)); + BOOST_CHECK(!SignSignature(keystore1, txFrom, mtxTo, 0)); + BOOST_CHECK(SignSignature(keystore3, txFrom, mtxTo, 0)); + BOOST_CHECK_EQUAL(mtxTo.vin[0].scriptSig.ToString(), + "e3cefe369d7d75472ef307ab9a4a13ddc9b2eac7304c93a47ac715d701f0347d89cdbd" + "c19f173cc8143091c9206f5648691f5dffa8a3570ec10ae101a9b48d7f01 02d7924d4" + "f7d43ea965a465ae3095ff41131e5946f3c85f79e44adbcf8e27e080e"); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/thresholdtree.h b/src/thresholdtree.h new file mode 100644 index 0000000000..51f6041b64 --- /dev/null +++ b/src/thresholdtree.h @@ -0,0 +1,196 @@ +// Copyright (c) 2015 Pieter Wuille +// Distributed under the MIT software license, see the accompanying +// file COPYING or http://www.opensource.org/licenses/mit-license.php. + +#ifndef BITCOIN_THRESHOLDTREE_H +#define BITCOIN_THRESHOLDTREE_H + +#include +#include +#include + +/** A generic iterator over threshold trees. + * + * A threshold tree is a tree structure that describes a set of combinations + * of leaves. + * Every inner node defines a requirement that K (the threshold) of its + * N subnodes are satisfied. When K=1, it is equivalent to an OR over the N + * children. When K=N, it is equivalent to an AND over the N children. + * + * The tree node data type should support the following methods: + * - bool Node::IsLeaf(): return whether a node is a leaf + * - int Node::Threshold(): return the threshold of a node (only for non-leaf + * nodes) + * - int Node::Children(): return the number of children of a node (only for + * non-leaf nodes). + * - Node* Node::Child(int pos): return a pointer to the pos'th child node + * (only for non-leaf nodes). + * + * An iterator is also parametrized in the type of an accumulator. This + * accumulator is an object kept inside the iterator, and matching leaves are + * added and removed from it. When a iterator reaches a stable combination, + * the accumulator can be queried to retrieve information about the matching + * subset of leaves. Using an accumulator means that early checks can be + * performed, and not all combinations are necessarily actually ever stored. + */ +template +class ThresholdTreeIterator { + // InnerIterators are created for iterating over the combinations of one + // node of the original tree. + struct InnerIterator { + Node* node; // Tree node this InnerIterator iterates over + struct InnerChild { + int position; // What position we're at (subnode iterates over node->children[position]) + InnerIterator* inner_iterator; // Inner iterator for the child. + }; + std::vector children; + }; + + // InnerIterator objects are cached and reused to avoid frequent allocation + // within the iteration code. + std::vector all_inner_iterators; + std::vector available_inner_iterators; + InnerIterator* iterator_root; + Accumulator accumulator; + + // Return a new InnerIterator object (in unknown state). + InnerIterator* AcquireInnerIterator() { + if (!available_inner_iterators.empty()) { + InnerIterator* ret = available_inner_iterators.back(); + available_inner_iterators.pop_back(); + return ret; + } + InnerIterator* ret = new InnerIterator(); + all_inner_iterators.push_back(ret); + return ret; + } + + // Return an InnerIterator object to the cache. + void ReleaseInnerIterator(InnerIterator* iter) { + available_inner_iterators.push_back(iter); + } + + // Construct an InnerIterator object for a particular tree node. + InnerIterator* BuildInnerIterator(Node* node) { + InnerIterator *ret = AcquireInnerIterator(); + ret->node = node; + if (node->IsLeaf()) { + accumulator.Push(node); + return ret; + } + int threshold = node->Threshold(); + ret->children.resize(threshold); + for (int childnum = 0; childnum < threshold; childnum++) { + ret->children[childnum].position = childnum; + ret->children[childnum].inner_iterator = BuildInnerIterator(node->Child(childnum)); + } + return ret; + } + + // Move an InnerIterator to the next combinations it or its children allow. + // This will return false if there are no combinations left. In that case, + // InnerIterator will have been released. + bool IncrementInnerIterator(InnerIterator* iter) { + // First deal with the leaf case. + if (iter->node->IsLeaf()) { + accumulator.Pop(iter->node); + ReleaseInnerIterator(iter); + return false; + } + // Try to increment its child iterators first. + for (int childnum = iter->children.size() - 1; childnum >= 0; childnum--) { + if (IncrementInnerIterator(iter->children[childnum].inner_iterator)) { + for (int childnum2 = childnum + 1; childnum2 < iter->children.size(); childnum2++) { + iter->children[childnum2].inner_iterator = BuildInnerIterator(iter->node->Child(iter->children[childnum2].position)); + } + return true; + } + } + // If we reach this point, all child InnerIterators have been exhausted and released, so + // check quickly whether we're done entirely, in which case we can release this one as + // well. Otherwise we will need to consturct new child InnerIterators for all positions. + if (iter->children[0].position + iter->children.size() == iter->node->Children()) { + ReleaseInnerIterator(iter); + return false; + } + // If not, move to the next combinations of children. + for (int childnum = iter->children.size() - 1; ; childnum--) { + if (iter->children[childnum].position + iter->children.size() != childnum + iter->node->Children()) { + iter->children[childnum].position++; + for (int childnum2 = childnum + 1; childnum2 < iter->children.size(); childnum2++) { + iter->children[childnum2].position = iter->children[childnum2 - 1].position + 1; + } + for (int childnum2 = 0; childnum2 < iter->children.size(); childnum2++) { + iter->children[childnum2].inner_iterator = BuildInnerIterator(iter->node->Child(iter->children[childnum2].position)); + } + return true; + } + } + } + +public: + // Construct an iterator for a tree, using the empty constructor for the accumulator. + ThresholdTreeIterator(Node* root) { + iterator_root = BuildInnerIterator(root); + } + + // Construct an iterator for a tree, constructing the accumulator with one argument. + template + ThresholdTreeIterator(Node* root, const A& acc) : accumulator(acc) { + iterator_root = BuildInnerIterator(root); + } + + // Destroy an iterator and all its inner iterators. + ~ThresholdTreeIterator() { + for (typename std::vector::iterator it = all_inner_iterators.begin(); it != all_inner_iterators.end(); it++) { + delete *it; + } + } + + // Retrieve a pointer to the accumulator. + Accumulator* GetAccumulator() { + return &accumulator; + } + + // Retrieve a pointer to the accumulator (const) + const Accumulator* GetAccumulator() const { + return &accumulator; + } + + // Check whether this iterator is done iterating. + bool Valid() const { + return iterator_root != NULL; + } + + // Move to the next combination. + void Increment() { + if (iterator_root && !IncrementInnerIterator(iterator_root)) { + iterator_root = NULL; + } + } +}; + +// Compute the number of combinations, given the number of combinations allowed by children. +static inline uint64_t CountCombinationsFromArray(uint32_t pick, std::vector::const_iterator begin, size_t total) { + if (pick == 0) return 1; + if (pick > total) return 0; + uint64_t ret = 0; + for (uint32_t pos = 0; pos <= total - pick; pos++) { + ret += *(begin + pos) * CountCombinationsFromArray(pick - 1, begin + pos + 1, total - pos - 1); + } + return ret; +} + +// Compute the number of combinations allowed by a threshold tree. +template +static inline uint64_t CountCombinations(Node* node) { + if (node->IsLeaf()) return 1; + std::vector list; + list.resize(node->Children()); + for (size_t pos = 0; pos < list.size(); pos++) { + list[pos] = CountCombinations(node->Child(pos)); + } + return CountCombinationsFromArray(node->Threshold(), list.begin(), list.size()); +} + +#endif diff --git a/src/wallet.cpp b/src/wallet.cpp index 7520613d51..d6cac66f78 100644 --- a/src/wallet.cpp +++ b/src/wallet.cpp @@ -163,6 +163,15 @@ bool CWallet::AddCScript(const CScript& redeemScript) return CWalletDB(strWalletFile).WriteCScript(Hash160(redeemScript), redeemScript); } +bool CWallet::AddKeyTree(const KeyTree& tree) +{ + if (!CCryptoKeyStore::AddKeyTree(tree)) + return false; + if (!fFileBacked) + return true; + return CWalletDB(strWalletFile).WriteKeyTree(tree); +} + bool CWallet::LoadCScript(const CScript& redeemScript) { /* A sanity check was added in pull #3843 to avoid adding redeemScripts @@ -179,6 +188,11 @@ bool CWallet::LoadCScript(const CScript& redeemScript) return CCryptoKeyStore::AddCScript(redeemScript); } +bool CWallet::LoadKeyTree(const KeyTree& tree) +{ + return CCryptoKeyStore::AddKeyTree(tree); +} + bool CWallet::AddWatchOnly(const CScript &dest) { if (!CCryptoKeyStore::AddWatchOnly(dest)) diff --git a/src/wallet.h b/src/wallet.h index 7a4c0217cb..e2a9dbcd96 100644 --- a/src/wallet.h +++ b/src/wallet.h @@ -253,6 +253,8 @@ class CWallet : public CCryptoKeyStore, public CValidationInterface bool LoadCryptedKey(const CPubKey &vchPubKey, const std::vector &vchCryptedSecret); bool AddCScript(const CScript& redeemScript); bool LoadCScript(const CScript& redeemScript); + bool AddKeyTree(const KeyTree& pubkeys); + bool LoadKeyTree(const KeyTree& pubkeys); //! Adds a destination data tuple to the store, and saves it to disk bool AddDestData(const CTxDestination &dest, const std::string &key, const std::string &value); diff --git a/src/wallet_ismine.cpp b/src/wallet_ismine.cpp index 342b0ede24..f560fef815 100644 --- a/src/wallet_ismine.cpp +++ b/src/wallet_ismine.cpp @@ -7,6 +7,7 @@ #include "key.h" #include "keystore.h" +#include "keytree.h" #include "script/script.h" #include "script/standard.h" @@ -34,6 +35,20 @@ isminetype IsMine(const CKeyStore &keystore, const CTxDestination& dest) return IsMine(keystore, script); } +namespace +{ +struct IsMineKeyTreeFilter : public KeyTreeFilter +{ + const CKeyStore *keystore; + + IsMineKeyTreeFilter(const CKeyStore *keystore_) : keystore(keystore_) {} + bool operator()(const CPubKey& pubkey) + { + return keystore->HaveKey(pubkey.GetID()); + } +}; +} + isminetype IsMine(const CKeyStore &keystore, const CScript& scriptPubKey) { vector vSolutions; @@ -85,6 +100,17 @@ isminetype IsMine(const CKeyStore &keystore, const CScript& scriptPubKey) return ISMINE_SPENDABLE; break; } + case TX_TREESIG: + { + std::vector merkleroot = vSolutions[1]; + uint256 hash; + memcpy(hash.begin(), &merkleroot[0], 32); + KeyTree tree; + IsMineKeyTreeFilter filter(&keystore); + if (keystore.GetKeyTree(hash, tree) && HasMatch(&tree.root, &filter)) { + return ISMINE_SPENDABLE; + } + } case TX_TRUE: return ISMINE_SPENDABLE; } diff --git a/src/walletdb.cpp b/src/walletdb.cpp index 7b848ffcac..6602a4fe1d 100644 --- a/src/walletdb.cpp +++ b/src/walletdb.cpp @@ -115,6 +115,12 @@ bool CWalletDB::WriteCScript(const uint160& hash, const CScript& redeemScript) return Write(std::make_pair(std::string("cscript"), hash), redeemScript, false); } +bool CWalletDB::WriteKeyTree(const KeyTree& tree) +{ + nWalletDBUpdated++; + return Write(std::make_pair(std::string("keytree"), tree.hash), tree, false); +} + bool CWalletDB::WriteWatchOnly(const CScript &dest) { nWalletDBUpdated++; @@ -579,6 +585,23 @@ ReadKeyValue(CWallet* pwallet, CDataStream& ssKey, CDataStream& ssValue, return false; } } + else if (strType == "keytree") + { + uint256 hash; + ssKey >> hash; + KeyTree tree; + ssValue >> tree; + if (hash != tree.hash) + { + strErr = "Error reading wallet database: keytree root hash mismatch"; + return false; + } + if (!pwallet->LoadKeyTree(tree)) + { + strErr = "Error reading wallet database: LoadKeyTree failed"; + return false; + } + } else if (strType == "orderposnext") { ssValue >> pwallet->nOrderPosNext; diff --git a/src/walletdb.h b/src/walletdb.h index 87bf1644c6..0e94bfa871 100644 --- a/src/walletdb.h +++ b/src/walletdb.h @@ -94,6 +94,7 @@ class CWalletDB : public CDB bool WriteMasterKey(unsigned int nID, const CMasterKey& kMasterKey); bool WriteCScript(const uint160& hash, const CScript& redeemScript); + bool WriteKeyTree(const KeyTree& tree); bool WriteWatchOnly(const CScript &script); bool EraseWatchOnly(const CScript &script);