diff --git a/applications/tari_base_node/src/command_handler.rs b/applications/tari_base_node/src/command_handler.rs index bab6c16228..e55f0bda86 100644 --- a/applications/tari_base_node/src/command_handler.rs +++ b/applications/tari_base_node/src/command_handler.rs @@ -406,12 +406,16 @@ impl CommandHandler { }); } - pub fn get_peer(&self, node_id: NodeId) { + pub fn get_peer(&self, partial: Vec, original_str: String) { let peer_manager = self.peer_manager.clone(); self.executor.spawn(async move { - match peer_manager.find_by_node_id(&node_id).await { - Ok(peer) => { + match peer_manager.find_all_starts_with(&partial).await { + Ok(peers) if peers.is_empty() => { + println!("No peer matching '{}'", original_str); + }, + Ok(peers) => { + let peer = peers.first().unwrap(); let eid = EmojiId::from_pubkey(&peer.public_key); println!("Emoji ID: {}", eid); println!("Public Key: {}", peer.public_key); diff --git a/applications/tari_base_node/src/parser.rs b/applications/tari_base_node/src/parser.rs index 899338e488..7eb28f3dc4 100644 --- a/applications/tari_base_node/src/parser.rs +++ b/applications/tari_base_node/src/parser.rs @@ -41,7 +41,12 @@ use tari_app_utilities::utilities::{ parse_emoji_id_or_public_key_or_node_id, }; use tari_common_types::types::{Commitment, PrivateKey, PublicKey, Signature}; -use tari_core::{crypto::tari_utilities::hex::from_hex, proof_of_work::PowAlgorithm, tari_utilities::hex::Hex}; +use tari_core::{ + crypto::tari_utilities::hex::from_hex, + proof_of_work::PowAlgorithm, + tari_utilities::{hex::Hex, ByteArray}, +}; +use tari_crypto::tari_utilities::hex; use tari_shutdown::Shutdown; /// Enum representing commands used by the basenode @@ -515,20 +520,27 @@ impl Parser { } fn process_get_peer<'a, I: Iterator>(&mut self, mut args: I) { - let node_id = match args + let (original_str, partial) = match args .next() - .map(parse_emoji_id_or_public_key_or_node_id) + .map(|s| { + parse_emoji_id_or_public_key_or_node_id(s) + .map(either_to_node_id) + .map(|n| (s.to_string(), n.to_vec())) + .or_else(|| { + let bytes = hex::from_hex(&s[..s.len() - (s.len() % 2)]).unwrap_or_default(); + Some((s.to_string(), bytes)) + }) + }) .flatten() - .map(either_to_node_id) { Some(n) => n, None => { - println!("Usage: get-peer [NodeId|PublicKey|EmojiId]"); + println!("Usage: get-peer [Partial NodeId | PublicKey | EmojiId]"); return; }, }; - self.command_handler.get_peer(node_id) + self.command_handler.get_peer(partial, original_str) } /// Function to process the list-peers command diff --git a/comms/src/peer_manager/manager.rs b/comms/src/peer_manager/manager.rs index d7df51a2ea..5cd99344d1 100644 --- a/comms/src/peer_manager/manager.rs +++ b/comms/src/peer_manager/manager.rs @@ -88,6 +88,11 @@ impl PeerManager { self.peer_storage.read().await.find_by_node_id(node_id) } + /// Find the peer with the provided substring. This currently only compares the given bytes to the NodeId + pub async fn find_all_starts_with(&self, partial: &[u8]) -> Result, PeerManagerError> { + self.peer_storage.read().await.find_all_starts_with(partial) + } + /// Find the peer with the provided PublicKey pub async fn find_by_public_key(&self, public_key: &CommsPublicKey) -> Result { self.peer_storage.read().await.find_by_public_key(public_key) diff --git a/comms/src/peer_manager/node_id.rs b/comms/src/peer_manager/node_id.rs index e6cc0c1259..adfce98268 100644 --- a/comms/src/peer_manager/node_id.rs +++ b/comms/src/peer_manager/node_id.rs @@ -41,21 +41,20 @@ use tari_crypto::tari_utilities::{ }; use thiserror::Error; -const NODE_ID_ARRAY_SIZE: usize = 13; // 104-bit as per RFC-0151 -type NodeIdArray = [u8; NODE_ID_ARRAY_SIZE]; +type NodeIdArray = [u8; NodeId::BYTE_SIZE]; pub type NodeDistance = XorDistance; // or HammingDistance #[derive(Debug, Error, Clone)] pub enum NodeIdError { - #[error("Incorrect byte count (expected {} bytes)", NODE_ID_ARRAY_SIZE)] + #[error("Incorrect byte count (expected {} bytes)", NodeId::BYTE_SIZE)] IncorrectByteCount, #[error("Invalid digest output size")] InvalidDigestOutputSize, } //------------------------------------- XOR Metric -----------------------------------------------// -const NODE_XOR_DISTANCE_ARRAY_SIZE: usize = NODE_ID_ARRAY_SIZE; +const NODE_XOR_DISTANCE_ARRAY_SIZE: usize = NodeId::BYTE_SIZE; type NodeXorDistanceArray = [u8; NODE_XOR_DISTANCE_ARRAY_SIZE]; #[derive(Clone, Debug, Eq, PartialOrd, Ord, Default)] @@ -145,7 +144,7 @@ impl HammingDistance { /// Returns the maximum distance. pub const fn max_distance() -> Self { - Self([NODE_ID_ARRAY_SIZE as u8 * 8; NODE_HAMMING_DISTANCE_ARRAY_SIZE]) + Self([NodeId::BYTE_SIZE as u8 * 8; NODE_HAMMING_DISTANCE_ARRAY_SIZE]) } } @@ -172,7 +171,7 @@ impl PartialEq for HammingDistance { /// Calculate the Exclusive OR between the node_id x and y. fn xor(x: &NodeIdArray, y: &NodeIdArray) -> NodeIdArray { - let mut nd = [0u8; NODE_ID_ARRAY_SIZE]; + let mut nd = [0u8; NodeId::BYTE_SIZE]; for i in 0..nd.len() { nd[i] = x[i] ^ y[i]; } @@ -224,6 +223,9 @@ impl fmt::Display for NodeDistance { pub struct NodeId(NodeIdArray); impl NodeId { + /// 104-bit/13 byte as per RFC-0151 + pub const BYTE_SIZE: usize = 13; + /// Construct a new node id on the origin pub fn new() -> Self { Default::default() @@ -232,9 +234,9 @@ impl NodeId { /// Derive a node id from a public key: node_id=hash(public_key) pub fn from_key(key: &K) -> Self { let bytes = key.as_bytes(); - let mut buf = [0u8; NODE_ID_ARRAY_SIZE]; - VarBlake2b::new(NODE_ID_ARRAY_SIZE) - .expect("NODE_ID_ARRAY_SIZE is invalid") + let mut buf = [0u8; NodeId::BYTE_SIZE]; + VarBlake2b::new(NodeId::BYTE_SIZE) + .expect("NodeId::NODE_ID_ARRAY_SIZE is invalid") .chain(bytes) .finalize_variable(|hash| { // Safety: output size and buf size are equal @@ -347,9 +349,9 @@ impl TryFrom<&[u8]> for NodeId { /// Construct a node id from 32 bytes fn try_from(elements: &[u8]) -> Result { - if elements.len() >= NODE_ID_ARRAY_SIZE { - let mut bytes = [0; NODE_ID_ARRAY_SIZE]; - bytes.copy_from_slice(&elements[0..NODE_ID_ARRAY_SIZE]); + if elements.len() >= NodeId::BYTE_SIZE { + let mut bytes = [0; NodeId::BYTE_SIZE]; + bytes.copy_from_slice(&elements[0..NodeId::BYTE_SIZE]); Ok(NodeId(bytes)) } else { Err(NodeIdError::IncorrectByteCount) @@ -577,7 +579,7 @@ mod test { let hamming_dist = HammingDistance::from_node_ids(&node_id1, &node_id2); assert_eq!(hamming_dist, HammingDistance([18])); - let node_max = NodeId::from_bytes(&[255; NODE_ID_ARRAY_SIZE]).unwrap(); + let node_max = NodeId::from_bytes(&[255; NodeId::BYTE_SIZE]).unwrap(); let node_min = NodeId::default(); let hamming_dist = HammingDistance::from_node_ids(&node_max, &node_min); diff --git a/comms/src/peer_manager/peer_storage.rs b/comms/src/peer_manager/peer_storage.rs index 1d79bc486f..1f1d4517a8 100644 --- a/comms/src/peer_manager/peer_storage.rs +++ b/comms/src/peer_manager/peer_storage.rs @@ -36,6 +36,7 @@ use log::*; use multiaddr::Multiaddr; use rand::{rngs::OsRng, seq::SliceRandom}; use std::{collections::HashMap, time::Duration}; +use tari_crypto::tari_utilities::ByteArray; use tari_storage::{IterationResult, KeyValueStore}; const LOG_TARGET: &str = "comms::peer_manager::peer_storage"; @@ -216,6 +217,23 @@ where DS: KeyValueStore }) } + pub fn find_all_starts_with(&self, partial: &[u8]) -> Result, PeerManagerError> { + if partial.is_empty() || partial.len() > NodeId::BYTE_SIZE { + return Ok(Vec::new()); + } + + let keys = self + .node_id_index + .iter() + .filter(|(k, _)| { + let l = partial.len(); + &k.as_bytes()[..l] == partial + }) + .map(|(_, id)| *id) + .collect::>(); + self.peer_db.get_many(&keys).map_err(PeerManagerError::DatabaseError) + } + /// Find the peer with the provided PublicKey pub fn find_by_public_key(&self, public_key: &CommsPublicKey) -> Result { let peer_key = self diff --git a/comms/src/peer_manager/wrapper.rs b/comms/src/peer_manager/wrapper.rs index 5176a8ba32..b046ee2a81 100644 --- a/comms/src/peer_manager/wrapper.rs +++ b/comms/src/peer_manager/wrapper.rs @@ -53,6 +53,13 @@ where T: KeyValueStore self.inner.get(key) } + fn get_many(&self, keys: &[PeerId]) -> Result, KeyValStoreError> { + if keys.iter().any(|k| k == &MIGRATION_VERSION_KEY) { + return Ok(Vec::new()); + } + self.inner.get_many(keys) + } + fn size(&self) -> Result { self.inner.size().map(|s| s.saturating_sub(1)) } diff --git a/infrastructure/storage/src/key_val_store/error.rs b/infrastructure/storage/src/key_val_store/error.rs index 24d3fa2636..11138e4942 100644 --- a/infrastructure/storage/src/key_val_store/error.rs +++ b/infrastructure/storage/src/key_val_store/error.rs @@ -20,6 +20,7 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +use crate::lmdb_store::LMDBError; use thiserror::Error; #[derive(Debug, Error, Clone)] @@ -35,3 +36,9 @@ pub enum KeyValStoreError { #[error("The specified key did not exist in the key-val store")] KeyNotFound, } + +impl From for KeyValStoreError { + fn from(e: LMDBError) -> Self { + KeyValStoreError::DatabaseError(format!("{:?}", e)) + } +} diff --git a/infrastructure/storage/src/key_val_store/hmap_database.rs b/infrastructure/storage/src/key_val_store/hmap_database.rs index e65a6c9036..941582e6d0 100644 --- a/infrastructure/storage/src/key_val_store/hmap_database.rs +++ b/infrastructure/storage/src/key_val_store/hmap_database.rs @@ -115,6 +115,17 @@ impl KeyValueStore for HashmapDatabase Result, KeyValStoreError> { + keys.iter() + .filter_map(|k| match self.get(k) { + Ok(Some(v)) => Some(Ok(v)), + Ok(None) => None, + Err(e) => Some(Err(e)), + }) + .collect() + } + /// Returns the total number of entries recorded in the key-value database. fn size(&self) -> Result { self.len() diff --git a/infrastructure/storage/src/key_val_store/key_val_store.rs b/infrastructure/storage/src/key_val_store/key_val_store.rs index 80e3c64f2a..4b2d1526dd 100644 --- a/infrastructure/storage/src/key_val_store/key_val_store.rs +++ b/infrastructure/storage/src/key_val_store/key_val_store.rs @@ -44,6 +44,9 @@ pub trait KeyValueStore { /// Get the value corresponding to the provided key from the key-value database. fn get(&self, key: &K) -> Result, KeyValStoreError>; + /// Get the value corresponding to the provided key from the key-value database. + fn get_many(&self, keys: &[K]) -> Result, KeyValStoreError>; + /// Returns the total number of entries recorded in the key-value database. fn size(&self) -> Result; diff --git a/infrastructure/storage/src/key_val_store/lmdb_database.rs b/infrastructure/storage/src/key_val_store/lmdb_database.rs index f6fe543579..770cd62902 100644 --- a/infrastructure/storage/src/key_val_store/lmdb_database.rs +++ b/infrastructure/storage/src/key_val_store/lmdb_database.rs @@ -63,46 +63,50 @@ where { /// Inserts a key-value pair into the key-value database. fn insert(&self, key: K, value: V) -> Result<(), KeyValStoreError> { - self.inner - .insert::(&key, &value) - .map_err(|e| KeyValStoreError::DatabaseError(format!("{:?}", e))) + self.inner.insert::(&key, &value).map_err(Into::into) } /// Get the value corresponding to the provided key from the key-value database. fn get(&self, key: &K) -> Result, KeyValStoreError> + where for<'t> V: serde::de::DeserializeOwned { + self.inner.get::(key).map_err(Into::into) + } + + /// Get the values corresponding to the provided keys from the key-value database. + fn get_many(&self, keys: &[K]) -> Result, KeyValStoreError> where for<'t> V: serde::de::DeserializeOwned { self.inner - .get::(key) - .map_err(|e| KeyValStoreError::DatabaseError(format!("{:?}", e))) + .with_read_transaction(|access| { + keys.iter() + .filter_map(|k| match access.get::(k) { + Ok(Some(v)) => Some(Ok(v)), + Ok(None) => None, + Err(e) => Some(Err(e)), + }) + .collect::, _>>() + })? + .map_err(Into::into) } /// Returns the total number of entries recorded in the key-value database. fn size(&self) -> Result { - self.inner - .len() - .map_err(|e| KeyValStoreError::DatabaseError(format!("{:?}", e))) + self.inner.len().map_err(Into::into) } /// Iterate over all the stored records and execute the function `f` for each pair in the key-value database. fn for_each(&self, f: F) -> Result<(), KeyValStoreError> where F: FnMut(Result<(K, V), KeyValStoreError>) -> IterationResult { - self.inner - .for_each::(f) - .map_err(|e| KeyValStoreError::DatabaseError(format!("{:?}", e))) + self.inner.for_each::(f).map_err(Into::into) } /// Checks whether a record exist in the key-value database that corresponds to the provided `key`. fn exists(&self, key: &K) -> Result { - self.inner - .contains_key::(key) - .map_err(|e| KeyValStoreError::DatabaseError(format!("{:?}", e))) + self.inner.contains_key::(key).map_err(Into::into) } /// Remove the record from the key-value database that corresponds with the provided `key`. fn delete(&self, key: &K) -> Result<(), KeyValStoreError> { - self.inner - .remove::(key) - .map_err(|e| KeyValStoreError::DatabaseError(format!("{:?}", e))) + self.inner.remove::(key).map_err(Into::into) } } diff --git a/infrastructure/storage/src/lmdb_store/store.rs b/infrastructure/storage/src/lmdb_store/store.rs index fbb34bfeee..256e95106a 100644 --- a/infrastructure/storage/src/lmdb_store/store.rs +++ b/infrastructure/storage/src/lmdb_store/store.rs @@ -629,18 +629,13 @@ impl LMDBDatabase { } /// Create a read-only transaction on the current database and execute the instructions given in the closure. The - /// transaction is automatically committed when the closure goes out of scope. You may provide the results of the - /// transaction to the calling scope by populating a `Vec` with the results of `txn.get(k)`. Otherwise, if the - /// results are not needed, or you did not call `get`, just return `Ok(None)`. - pub fn with_read_transaction(&self, f: F) -> Result>, LMDBError> - where - V: serde::de::DeserializeOwned, - F: FnOnce(LMDBReadTransaction) -> Result>, LMDBError>, - { + /// transaction is automatically committed when the closure goes out of scope. + pub fn with_read_transaction(&self, f: F) -> Result + where F: FnOnce(LMDBReadTransaction) -> R { let txn = ReadTransaction::new(self.env.clone())?; let access = txn.access(); let wrapper = LMDBReadTransaction { db: &self.db, access }; - f(wrapper) + Ok(f(wrapper)) } /// Create a transaction with write access on the current table. diff --git a/infrastructure/storage/tests/lmdb.rs b/infrastructure/storage/tests/lmdb.rs index de9d747e8f..fb0940e5e7 100644 --- a/infrastructure/storage/tests/lmdb.rs +++ b/infrastructure/storage/tests/lmdb.rs @@ -181,16 +181,14 @@ fn transactions() { { let (users, db) = insert_all_users("transactions"); // Test the `exists` and value retrieval functions - let res = db.with_read_transaction::<_, User>(|txn| { + db.with_read_transaction(|txn| { for user in users.iter() { assert!(txn.exists(&user.id).unwrap()); let check: User = txn.get(&user.id).unwrap().unwrap(); assert_eq!(check, *user); } - Ok(None) - }); - println!("{:?}", res); - assert!(res.unwrap().is_none()); + }) + .unwrap(); } clean_up("transactions"); // In Windows file handles must be released before files can be deleted }