diff --git a/base_layer/core/src/transactions/key_manager/inner.rs b/base_layer/core/src/transactions/key_manager/inner.rs index ced54a4e37..ca36b77bf2 100644 --- a/base_layer/core/src/transactions/key_manager/inner.rs +++ b/base_layer/core/src/transactions/key_manager/inner.rs @@ -21,7 +21,6 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use std::{collections::HashMap, ops::Shl}; -use futures::lock::Mutex; use log::*; use rand::rngs::OsRng; use strum::IntoEnumIterator; @@ -53,6 +52,7 @@ use tari_key_manager::{ }, }; use tari_utilities::{hex::Hex, ByteArray}; +use tokio::sync::RwLock; use crate::{ one_sided::diffie_hellman_stealth_domain_hasher, @@ -87,7 +87,7 @@ use crate::{ hash_domain!(KeyManagerHashingDomain, "base_layer.core.key_manager"); pub struct TransactionKeyManagerInner { - key_managers: HashMap>>, + key_managers: HashMap>>, db: KeyManagerDatabase, master_seed: CipherSeed, crypto_factories: CryptoFactories, @@ -141,7 +141,7 @@ where TBackend: KeyManagerBackend + 'static }; self.key_managers.insert( branch.to_string(), - Mutex::new(KeyManager::::from( + RwLock::new(KeyManager::::from( self.master_seed.clone(), state.branch_seed, state.primary_key_index, @@ -155,7 +155,7 @@ where TBackend: KeyManagerBackend + 'static .key_managers .get(branch) .ok_or(KeyManagerServiceError::UnknownKeyBranch)? - .lock() + .write() .await; self.db.increment_key_index(branch)?; let index = km.increment_key_index(1); @@ -186,7 +186,7 @@ where TBackend: KeyManagerBackend + 'static .key_managers .get(branch) .ok_or(KeyManagerServiceError::UnknownKeyBranch)? - .lock() + .read() .await; Ok(km.derive_public_key(*index)?.key) }, @@ -220,7 +220,7 @@ where TBackend: KeyManagerBackend + 'static .key_managers .get(branch) .ok_or(KeyManagerServiceError::UnknownKeyBranch)? - .lock() + .read() .await; let current_index = km.key_index(); @@ -242,7 +242,7 @@ where TBackend: KeyManagerBackend + 'static .key_managers .get(branch) .ok_or(KeyManagerServiceError::UnknownKeyBranch)? - .lock() + .read() .await; let current_index = km.key_index(); @@ -268,7 +268,7 @@ where TBackend: KeyManagerBackend + 'static .key_managers .get(branch) .ok_or(KeyManagerServiceError::UnknownKeyBranch)? - .lock() + .write() .await; let current_index = km.key_index(); if index > current_index { @@ -296,7 +296,7 @@ where TBackend: KeyManagerBackend + 'static .key_managers .get(branch) .ok_or(KeyManagerServiceError::UnknownKeyBranch)? - .lock() + .read() .await; let key = km.get_private_key(*index)?; Ok(key) diff --git a/base_layer/core/src/transactions/key_manager/wrapper.rs b/base_layer/core/src/transactions/key_manager/wrapper.rs index 6a59eb7e6c..29a94b0844 100644 --- a/base_layer/core/src/transactions/key_manager/wrapper.rs +++ b/base_layer/core/src/transactions/key_manager/wrapper.rs @@ -94,7 +94,7 @@ impl KeyManagerInterface for TransactionKeyManagerWrapper + 'static { async fn add_new_branch + Send>(&self, branch: T) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .write() .await .add_key_manager_branch(&branch.into()) @@ -104,7 +104,7 @@ where TBackend: KeyManagerBackend + 'static &self, branch: T, ) -> Result<(TariKeyId, PublicKey), KeyManagerServiceError> { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_next_key(&branch.into()) @@ -112,7 +112,7 @@ where TBackend: KeyManagerBackend + 'static } async fn get_static_key + Send>(&self, branch: T) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_static_key(&branch.into()) @@ -120,7 +120,7 @@ where TBackend: KeyManagerBackend + 'static } async fn get_public_key_at_key_id(&self, key_id: &TariKeyId) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_public_key_at_key_id(key_id) @@ -132,7 +132,7 @@ where TBackend: KeyManagerBackend + 'static branch: T, key: &PublicKey, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .find_key_index(&branch.into(), key) @@ -144,7 +144,7 @@ where TBackend: KeyManagerBackend + 'static branch: T, index: u64, ) -> Result<(), KeyManagerServiceError> { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .update_current_key_index_if_higher(&branch.into(), index) @@ -152,7 +152,7 @@ where TBackend: KeyManagerBackend + 'static } async fn import_key(&self, private_key: PrivateKey) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .import_key(private_key) @@ -169,7 +169,7 @@ where TBackend: KeyManagerBackend + 'static spend_key_id: &TariKeyId, value: &PrivateKey, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_commitment(spend_key_id, value) @@ -182,7 +182,7 @@ where TBackend: KeyManagerBackend + 'static spending_key_id: &TariKeyId, value: u64, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .verify_mask(commitment, spending_key_id, value) @@ -197,7 +197,7 @@ where TBackend: KeyManagerBackend + 'static async fn get_next_spend_and_script_key_ids( &self, ) -> Result<(TariKeyId, PublicKey, TariKeyId, PublicKey), KeyManagerServiceError> { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_next_spend_and_script_key_ids() @@ -209,7 +209,7 @@ where TBackend: KeyManagerBackend + 'static secret_key_id: &TariKeyId, public_key: &PublicKey, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_diffie_hellman_shared_secret(secret_key_id, public_key) @@ -221,7 +221,7 @@ where TBackend: KeyManagerBackend + 'static secret_key_id: &TariKeyId, public_key: &PublicKey, ) -> Result, TransactionError> { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_diffie_hellman_stealth_domain_hasher(secret_key_id, public_key) @@ -233,7 +233,7 @@ where TBackend: KeyManagerBackend + 'static secret_key_id: &TariKeyId, offset: PrivateKey, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .import_add_offset_to_private_key(secret_key_id, offset) @@ -241,7 +241,7 @@ where TBackend: KeyManagerBackend + 'static } async fn get_spending_key_id(&self, public_spending_key: &PublicKey) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_spending_key_id(public_spending_key) @@ -254,7 +254,7 @@ where TBackend: KeyManagerBackend + 'static value: u64, min_value: u64, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .construct_range_proof(spend_key_id, value, min_value) @@ -269,7 +269,7 @@ where TBackend: KeyManagerBackend + 'static txi_version: &TransactionInputVersion, script_message: &[u8; 32], ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_script_signature(script_key_id, spend_key_id, value, txi_version, script_message) @@ -287,7 +287,7 @@ where TBackend: KeyManagerBackend + 'static kernel_features: &KernelFeatures, txo_type: TxoStage, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_partial_txo_kernel_signature( @@ -308,7 +308,7 @@ where TBackend: KeyManagerBackend + 'static spend_key_id: &TariKeyId, nonce_id: &TariKeyId, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_txo_kernel_signature_excess_with_offset(spend_key_id, nonce_id) @@ -320,7 +320,7 @@ where TBackend: KeyManagerBackend + 'static spend_key_id: &TariKeyId, nonce_id: &TariKeyId, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_txo_private_kernel_offset(spend_key_id, nonce_id) @@ -333,7 +333,7 @@ where TBackend: KeyManagerBackend + 'static custom_recovery_key_id: Option<&TariKeyId>, value: u64, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .encrypt_data_for_recovery(spend_key_id, custom_recovery_key_id, value) @@ -345,7 +345,7 @@ where TBackend: KeyManagerBackend + 'static output: &TransactionOutput, custom_recovery_key_id: Option<&TariKeyId>, ) -> Result<(TariKeyId, MicroTari), TransactionError> { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .try_output_key_recovery(output, custom_recovery_key_id) @@ -357,7 +357,7 @@ where TBackend: KeyManagerBackend + 'static script_key_ids: &[TariKeyId], sender_offset_key_ids: &[TariKeyId], ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_script_offset(script_key_ids, sender_offset_key_ids) @@ -369,7 +369,7 @@ where TBackend: KeyManagerBackend + 'static nonce_id: &TariKeyId, range_proof_type: RangeProofType, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_metadata_signature_ephemeral_commitment(nonce_id, range_proof_type) @@ -385,7 +385,7 @@ where TBackend: KeyManagerBackend + 'static metadata_signature_message: &[u8; 32], range_proof_type: RangeProofType, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_metadata_signature( @@ -409,7 +409,7 @@ where TBackend: KeyManagerBackend + 'static metadata_signature_message: &[u8; 32], range_proof_type: RangeProofType, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_receiver_partial_metadata_signature( @@ -433,7 +433,7 @@ where TBackend: KeyManagerBackend + 'static txo_version: &TransactionOutputVersion, metadata_signature_message: &[u8; 32], ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_sender_partial_metadata_signature( @@ -453,7 +453,7 @@ where TBackend: KeyManagerBackend + 'static amount: &PrivateKey, claim_public_key: &PublicKey, ) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .generate_burn_proof(spending_key, amount, claim_public_key) @@ -466,7 +466,7 @@ impl SecretTransactionKeyManagerInterface for TransactionKeyManagerWra where TBackend: KeyManagerBackend + 'static { async fn get_private_key(&self, key_id: &TariKeyId) -> Result { - (*self.transaction_key_manager_inner) + self.transaction_key_manager_inner .read() .await .get_private_key(key_id)