Skip to content

Commit

Permalink
Bugfix: wallet: Increment "update counter" only after actually making
Browse files Browse the repository at this point in the history
the applicable db changes to avoid potential races

Also does all "update counter" access via IncrementUpdateCounter
  • Loading branch information
luke-jr authored and random-zebra committed May 17, 2021
1 parent db5e692 commit 8aa93b9
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 69 deletions.
105 changes: 36 additions & 69 deletions src/wallet/walletdb.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,48 +69,39 @@ static std::atomic<unsigned int> nWalletDBUpdateCounter;

bool CWalletDB::WriteName(const std::string& strAddress, const std::string& strName)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::NAME), strAddress), strName);
return WriteIC(std::make_pair(std::string(DBKeys::NAME), strAddress), strName);
}

bool CWalletDB::EraseName(const std::string& strAddress)
{
// This should only be used for sending addresses, never for receiving addresses,
// receiving addresses must always have an address book entry if they're not change return.
nWalletDBUpdateCounter++;
return batch.Erase(std::make_pair(std::string(DBKeys::NAME), strAddress));
return EraseIC(std::make_pair(std::string(DBKeys::NAME), strAddress));
}

bool CWalletDB::WritePurpose(const std::string& strAddress, const std::string& strPurpose)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::PURPOSE), strAddress), strPurpose);
return WriteIC(std::make_pair(std::string(DBKeys::PURPOSE), strAddress), strPurpose);
}

bool CWalletDB::ErasePurpose(const std::string& strPurpose)
{
nWalletDBUpdateCounter++;
return batch.Erase(std::make_pair(std::string(DBKeys::PURPOSE), strPurpose));
return EraseIC(std::make_pair(std::string(DBKeys::PURPOSE), strPurpose));
}

bool CWalletDB::WriteTx(const CWalletTx& wtx)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::TX), wtx.GetHash()), wtx);
return WriteIC(std::make_pair(std::string(DBKeys::TX), wtx.GetHash()), wtx);
}

bool CWalletDB::EraseTx(uint256 hash)
{
nWalletDBUpdateCounter++;
return batch.Erase(std::make_pair(std::string(DBKeys::TX), hash));
return EraseIC(std::make_pair(std::string(DBKeys::TX), hash));
}

bool CWalletDB::WriteKey(const CPubKey& vchPubKey, const CPrivKey& vchPrivKey, const CKeyMetadata& keyMeta)
{
nWalletDBUpdateCounter++;

if (!batch.Write(std::make_pair(std::string(DBKeys::KEYMETA), vchPubKey),
keyMeta, false))
if (!WriteIC(std::make_pair(std::string(DBKeys::KEYMETA), vchPubKey), keyMeta, false))
return false;

// hash pubkey/privkey to accelerate wallet load
Expand All @@ -119,47 +110,42 @@ bool CWalletDB::WriteKey(const CPubKey& vchPubKey, const CPrivKey& vchPrivKey, c
vchKey.insert(vchKey.end(), vchPubKey.begin(), vchPubKey.end());
vchKey.insert(vchKey.end(), vchPrivKey.begin(), vchPrivKey.end());

return batch.Write(std::make_pair(std::string(DBKeys::KEY), vchPubKey), std::make_pair(vchPrivKey, Hash(vchKey.begin(), vchKey.end())), false);
return WriteIC(std::make_pair(std::string(DBKeys::KEY), vchPubKey), std::make_pair(vchPrivKey, Hash(vchKey.begin(), vchKey.end())), false);
}

bool CWalletDB::WriteCryptedKey(const CPubKey& vchPubKey,
const std::vector<unsigned char>& vchCryptedSecret,
const CKeyMetadata& keyMeta)
{
const bool fEraseUnencryptedKey = true;
nWalletDBUpdateCounter++;

if (!batch.Write(std::make_pair(std::string(DBKeys::KEYMETA), vchPubKey),
keyMeta))
if (!WriteIC(std::make_pair(std::string(DBKeys::KEYMETA), vchPubKey), keyMeta))
return false;

if (!batch.Write(std::make_pair(std::string(DBKeys::CRYPTED_KEY), vchPubKey), vchCryptedSecret, false))
if (!WriteIC(std::make_pair(std::string(DBKeys::CRYPTED_KEY), vchPubKey), vchCryptedSecret, false))
return false;
if (fEraseUnencryptedKey) {
batch.Erase(std::make_pair(std::string(DBKeys::KEY), vchPubKey));
EraseIC(std::make_pair(std::string(DBKeys::KEY), vchPubKey));
}

return true;
}

bool CWalletDB::WriteSaplingZKey(const libzcash::SaplingIncomingViewingKey &ivk,
const libzcash::SaplingExtendedSpendingKey &key,
const CKeyMetadata &keyMeta)
{
nWalletDBUpdateCounter++;

if (!batch.Write(std::make_pair(std::string(DBKeys::SAP_KEYMETA), ivk), keyMeta))
if (!WriteIC(std::make_pair(std::string(DBKeys::SAP_KEYMETA), ivk), keyMeta))
return false;

return batch.Write(std::make_pair(std::string(DBKeys::SAP_KEY), ivk), key, false);
return WriteIC(std::make_pair(std::string(DBKeys::SAP_KEY), ivk), key, false);
}

bool CWalletDB::WriteSaplingPaymentAddress(
const libzcash::SaplingPaymentAddress &addr,
const libzcash::SaplingIncomingViewingKey &ivk)
{
nWalletDBUpdateCounter++;

return batch.Write(std::make_pair(std::string(DBKeys::SAP_ADDR), addr), ivk, false);
return WriteIC(std::make_pair(std::string(DBKeys::SAP_ADDR), addr), ivk, false);
}

bool CWalletDB::WriteCryptedSaplingZKey(
Expand All @@ -168,25 +154,23 @@ bool CWalletDB::WriteCryptedSaplingZKey(
const CKeyMetadata &keyMeta)
{
const bool fEraseUnencryptedKey = true;
nWalletDBUpdateCounter++;
auto ivk = extfvk.fvk.in_viewing_key();

if (!batch.Write(std::make_pair(std::string(DBKeys::SAP_KEYMETA), ivk), keyMeta))
if (!WriteIC(std::make_pair(std::string(DBKeys::SAP_KEYMETA), ivk), keyMeta))
return false;

if (!batch.Write(std::make_pair(std::string(DBKeys::SAP_KEY_CRIPTED), ivk), std::make_pair(extfvk, vchCryptedSecret), false))
if (!WriteIC(std::make_pair(std::string(DBKeys::SAP_KEY_CRIPTED), ivk), std::make_pair(extfvk, vchCryptedSecret), false))
return false;

if (fEraseUnencryptedKey) {
batch.Erase(std::make_pair(std::string(DBKeys::SAP_KEY), ivk));
EraseIC(std::make_pair(std::string(DBKeys::SAP_KEY), ivk));
}
return true;
}

bool CWalletDB::WriteSaplingCommonOVK(const uint256& ovk)
{
nWalletDBUpdateCounter++;
return batch.Write(std::string(DBKeys::SAP_COMMON_OVK), ovk);
return WriteIC(std::string(DBKeys::SAP_COMMON_OVK), ovk);
}

bool CWalletDB::ReadSaplingCommonOVK(uint256& ovkRet)
Expand All @@ -196,40 +180,33 @@ bool CWalletDB::ReadSaplingCommonOVK(uint256& ovkRet)

bool CWalletDB::WriteWitnessCacheSize(int64_t nWitnessCacheSize)
{
nWalletDBUpdateCounter++;
return batch.Write(std::string(DBKeys::SAP_WITNESS_CACHE_SIZE), nWitnessCacheSize);
return WriteIC(std::string(DBKeys::SAP_WITNESS_CACHE_SIZE), nWitnessCacheSize);
}

bool CWalletDB::WriteMasterKey(unsigned int nID, const CMasterKey& kMasterKey)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::MASTER_KEY), nID), kMasterKey, true);
return WriteIC(std::make_pair(std::string(DBKeys::MASTER_KEY), nID), kMasterKey, true);
}

bool CWalletDB::WriteCScript(const uint160& hash, const CScript& redeemScript)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::CSCRIPT), hash), redeemScript, false);
return WriteIC(std::make_pair(std::string(DBKeys::CSCRIPT), hash), redeemScript, false);
}

bool CWalletDB::WriteWatchOnly(const CScript& dest)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::WATCHS), dest), '1');
return WriteIC(std::make_pair(std::string(DBKeys::WATCHS), dest), '1');
}

bool CWalletDB::EraseWatchOnly(const CScript& dest)
{

nWalletDBUpdateCounter++;
return batch.Erase(std::make_pair(std::string(DBKeys::WATCHS), dest));
return EraseIC(std::make_pair(std::string(DBKeys::WATCHS), dest));
}

bool CWalletDB::WriteBestBlock(const CBlockLocator& locator)
{
nWalletDBUpdateCounter++;
batch.Write(std::string(DBKeys::BESTBLOCK), CBlockLocator()); // Write empty block locator so versions that require a merkle branch automatically rescan
return batch.Write(std::string(DBKeys::BESTBLOCK_NOMERKLE), locator);
WriteIC(std::string(DBKeys::BESTBLOCK), CBlockLocator()); // Write empty block locator so versions that require a merkle branch automatically rescan
return WriteIC(std::string(DBKeys::BESTBLOCK_NOMERKLE), locator);
}

bool CWalletDB::ReadBestBlock(CBlockLocator& locator)
Expand All @@ -240,35 +217,30 @@ bool CWalletDB::ReadBestBlock(CBlockLocator& locator)

bool CWalletDB::WriteOrderPosNext(int64_t nOrderPosNext)
{
nWalletDBUpdateCounter++;
return batch.Write(std::string(DBKeys::ORDERPOSNEXT), nOrderPosNext);
return WriteIC(std::string(DBKeys::ORDERPOSNEXT), nOrderPosNext);
}

bool CWalletDB::WriteStakeSplitThreshold(const CAmount& nStakeSplitThreshold)
{
nWalletDBUpdateCounter++;
return batch.Write(std::string(DBKeys::STAKE_SPLIT_THRESHOLD), nStakeSplitThreshold);
return WriteIC(std::string(DBKeys::STAKE_SPLIT_THRESHOLD), nStakeSplitThreshold);
}

bool CWalletDB::WriteUseCustomFee(bool fUse)
{
nWalletDBUpdateCounter++;
return batch.Write(std::string(DBKeys::USE_CUSTOM_FEE), fUse);
return WriteIC(std::string(DBKeys::USE_CUSTOM_FEE), fUse);
}

bool CWalletDB::WriteCustomFeeValue(const CAmount& nFee)
{
nWalletDBUpdateCounter++;
return batch.Write(std::string(DBKeys::CUSTOM_FEE_VALUE), nFee);
return WriteIC(std::string(DBKeys::CUSTOM_FEE_VALUE), nFee);
}

bool CWalletDB::WriteAutoCombineSettings(bool fEnable, CAmount nCombineThreshold)
{
nWalletDBUpdateCounter++;
std::pair<bool, CAmount> pSettings;
pSettings.first = fEnable;
pSettings.second = nCombineThreshold;
return batch.Write(std::string(DBKeys::AUTOCOMBINE), pSettings, true);
return WriteIC(std::string(DBKeys::AUTOCOMBINE), pSettings, true);
}

bool CWalletDB::ReadPool(int64_t nPool, CKeyPool& keypool)
Expand All @@ -278,14 +250,12 @@ bool CWalletDB::ReadPool(int64_t nPool, CKeyPool& keypool)

bool CWalletDB::WritePool(int64_t nPool, const CKeyPool& keypool)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::POOL), nPool), keypool);
return WriteIC(std::make_pair(std::string(DBKeys::POOL), nPool), keypool);
}

bool CWalletDB::ErasePool(int64_t nPool)
{
nWalletDBUpdateCounter++;
return batch.Erase(std::make_pair(std::string(DBKeys::POOL), nPool));
return EraseIC(std::make_pair(std::string(DBKeys::POOL), nPool));
}

bool CWalletDB::WriteMinVersion(int nVersion)
Expand All @@ -295,10 +265,9 @@ bool CWalletDB::WriteMinVersion(int nVersion)

bool CWalletDB::WriteHDChain(const CHDChain& chain)
{
nWalletDBUpdateCounter++;
std::string key = chain.chainType == HDChain::ChainCounterType::Sapling ?
DBKeys::SAP_HDCHAIN : DBKeys::HDCHAIN;
return batch.Write(key, chain);
return WriteIC(key, chain);
}

DBErrors CWalletDB::ReorderTransactions(CWallet* pwallet)
Expand Down Expand Up @@ -1140,14 +1109,12 @@ bool CWalletDB::VerifyDatabaseFile(const std::string& walletFile, const fs::path

bool CWalletDB::WriteDestData(const std::string& address, const std::string& key, const std::string& value)
{
nWalletDBUpdateCounter++;
return batch.Write(std::make_pair(std::string(DBKeys::DESTDATA), std::make_pair(address, key)), value);
return WriteIC(std::make_pair(std::string(DBKeys::DESTDATA), std::make_pair(address, key)), value);
}

bool CWalletDB::EraseDestData(const std::string& address, const std::string& key)
{
nWalletDBUpdateCounter++;
return batch.Erase(std::make_pair(std::string(DBKeys::DESTDATA), std::make_pair(address, key)));
return EraseIC(std::make_pair(std::string(DBKeys::DESTDATA), std::make_pair(address, key)));
}

void CWalletDB::IncrementUpdateCounter()
Expand Down
21 changes: 21 additions & 0 deletions src/wallet/walletdb.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ class CKeyMetadata
*/
class CWalletDB
{
private:
template <typename K, typename T>
bool WriteIC(const K& key, const T& value, bool fOverwrite = true)
{
if (!batch.Write(key, value, fOverwrite)) {
return false;
}
IncrementUpdateCounter();
return true;
}

template <typename K>
bool EraseIC(const K& key)
{
if (!batch.Erase(key)) {
return false;
}
IncrementUpdateCounter();
return true;
}

public:
CWalletDB(CWalletDBWrapper& dbw, const char* pszMode = "r+", bool _fFlushOnClose = true) :
batch(dbw, pszMode, _fFlushOnClose)
Expand Down

0 comments on commit 8aa93b9

Please sign in to comment.