From fbf8eb83353392b977b34bf3d1870ca25320414e Mon Sep 17 00:00:00 2001 From: Stan Bondi Date: Sun, 24 Oct 2021 12:27:20 +0400 Subject: [PATCH] fix: check SAF message inflight and check stored_at is in past (#3444) Description --- - Keeps track of inflight SAF requests and only accepts responses for requests that are inflight - Checks that `stored_at` is in the past - Fixes #3412, #3410 Motivation and Context --- See #3412, #3410 How Has This Been Tested? --- - New/existing unit/integration tests - memorynet - Manually --- applications/tari_base_node/src/bootstrap.rs | 7 +- .../tari_console_wallet/src/init/mod.rs | 7 +- base_layer/wallet/tests/wallet/mod.rs | 7 +- base_layer/wallet_ffi/src/lib.rs | 7 +- comms/dht/examples/memory_net/utilities.rs | 6 +- comms/dht/src/actor.rs | 1 + comms/dht/src/broadcast_strategy.rs | 2 + comms/dht/src/builder.rs | 3 +- comms/dht/src/config.rs | 50 +-- comms/dht/src/dht.rs | 17 +- comms/dht/src/lib.rs | 1 + comms/dht/src/outbound/broadcast.rs | 2 +- comms/dht/src/outbound/message_params.rs | 6 + comms/dht/src/store_forward/config.rs | 73 ++++ comms/dht/src/store_forward/error.rs | 13 +- comms/dht/src/store_forward/local_state.rs | 76 ++++ comms/dht/src/store_forward/mod.rs | 5 + .../src/store_forward/saf_handler/layer.rs | 7 +- .../store_forward/saf_handler/middleware.rs | 7 +- .../dht/src/store_forward/saf_handler/mod.rs | 4 +- .../dht/src/store_forward/saf_handler/task.rs | 334 ++++++++++++++---- comms/dht/src/store_forward/service.rs | 54 ++- comms/dht/src/store_forward/store.rs | 18 +- .../src/test_utils/store_and_forward_mock.rs | 25 +- comms/dht/tests/dht.rs | 2 +- 25 files changed, 564 insertions(+), 170 deletions(-) create mode 100644 comms/dht/src/store_forward/config.rs create mode 100644 comms/dht/src/store_forward/local_state.rs diff --git a/applications/tari_base_node/src/bootstrap.rs b/applications/tari_base_node/src/bootstrap.rs index 29b6404638..dc551a6379 100644 --- a/applications/tari_base_node/src/bootstrap.rs +++ b/applications/tari_base_node/src/bootstrap.rs @@ -28,7 +28,7 @@ use log::*; use tari_app_utilities::{consts, identity_management, utilities::create_transport_type}; use tari_common::{configuration::bootstrap::ApplicationType, GlobalConfig}; use tari_comms::{peer_manager::Peer, protocol::rpc::RpcServer, NodeIdentity, UnspawnedCommsNode}; -use tari_comms_dht::{DbConnectionUrl, Dht, DhtConfig}; +use tari_comms_dht::{store_forward::SafConfig, DbConnectionUrl, Dht, DhtConfig}; use tari_core::{ base_node, base_node::{ @@ -251,7 +251,10 @@ where B: BlockchainBackend + 'static auto_join: true, allow_test_addresses: self.config.allow_test_addresses, flood_ban_max_msg_count: self.config.flood_ban_max_msg_count, - saf_msg_validity: self.config.saf_expiry_duration, + saf_config: SafConfig { + msg_validity: self.config.saf_expiry_duration, + ..Default::default() + }, dedup_cache_capacity: self.config.dedup_cache_capacity, ..Default::default() }, diff --git a/applications/tari_console_wallet/src/init/mod.rs b/applications/tari_console_wallet/src/init/mod.rs index 1a68dd7a88..2f2249d772 100644 --- a/applications/tari_console_wallet/src/init/mod.rs +++ b/applications/tari_console_wallet/src/init/mod.rs @@ -34,7 +34,7 @@ use tari_comms::{ types::CommsSecretKey, NodeIdentity, }; -use tari_comms_dht::{DbConnectionUrl, DhtConfig}; +use tari_comms_dht::{store_forward::SafConfig, DbConnectionUrl, DhtConfig}; use tari_core::transactions::CryptoFactories; use tari_p2p::{ auto_update::AutoUpdateConfig, @@ -337,7 +337,10 @@ pub async fn init_wallet( auto_join: true, allow_test_addresses: config.allow_test_addresses, flood_ban_max_msg_count: config.flood_ban_max_msg_count, - saf_msg_validity: config.saf_expiry_duration, + saf_config: SafConfig { + msg_validity: config.saf_expiry_duration, + ..Default::default() + }, dedup_cache_capacity: config.dedup_cache_capacity, ..Default::default() }, diff --git a/base_layer/wallet/tests/wallet/mod.rs b/base_layer/wallet/tests/wallet/mod.rs index fde22fc9fc..49004f603a 100644 --- a/base_layer/wallet/tests/wallet/mod.rs +++ b/base_layer/wallet/tests/wallet/mod.rs @@ -46,7 +46,7 @@ use tari_comms::{ peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerFlags}, types::{CommsPublicKey, CommsSecretKey}, }; -use tari_comms_dht::DhtConfig; +use tari_comms_dht::{store_forward::SafConfig, DhtConfig}; use tari_core::transactions::{ tari_amount::{uT, MicroTari}, test_helpers::{create_unblinded_output, TestParams}, @@ -119,7 +119,10 @@ async fn create_wallet( dht: DhtConfig { discovery_request_timeout: Duration::from_secs(1), auto_join: true, - saf_auto_request: true, + saf_config: SafConfig { + auto_request: true, + ..Default::default() + }, ..Default::default() }, allow_test_addresses: true, diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 673148b164..f41f447071 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -157,7 +157,7 @@ use tari_comms::{ transports::MemoryTransport, types::CommsSecretKey, }; -use tari_comms_dht::{DbConnectionUrl, DhtConfig}; +use tari_comms_dht::{store_forward::SafConfig, DbConnectionUrl, DhtConfig}; use tari_core::transactions::{tari_amount::MicroTari, transaction::OutputFeatures, CryptoFactories}; use tari_p2p::{ transport::{TorConfig, TransportType, TransportType::Tor}, @@ -2591,7 +2591,10 @@ pub unsafe extern "C" fn comms_config_create( discovery_request_timeout: Duration::from_secs(discovery_timeout_in_secs), database_url: DbConnectionUrl::File(dht_database_path), auto_join: true, - saf_msg_validity: Duration::from_secs(saf_message_duration_in_secs), + saf_config: SafConfig { + msg_validity: Duration::from_secs(saf_message_duration_in_secs), + ..Default::default() + }, ..Default::default() }, // TODO: This should be set to false for non-test wallets. See the `allow_test_addresses` field diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index 4b875675b8..9cf6340ebc 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -53,6 +53,7 @@ use tari_comms_dht::{ envelope::NodeDestination, inbound::DecryptedDhtMessage, outbound::OutboundEncryption, + store_forward::SafConfig, Dht, DhtConfig, }; @@ -912,7 +913,10 @@ async fn setup_comms_dht( let dht = Dht::builder() .with_config(DhtConfig { - saf_auto_request, + saf_config: SafConfig { + auto_request: saf_auto_request, + ..Default::default() + }, auto_join: false, discovery_request_timeout: Duration::from_secs(15), num_neighbouring_nodes, diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 0e32ccdf28..7de0af251a 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -467,6 +467,7 @@ impl DhtActor { .map(|p| p.node_id) .collect()) }, + SelectedPeers(peers) => Ok(peers), Broadcast(exclude) => { let connections = connectivity .select_connections(ConnectivitySelection::random_nodes( diff --git a/comms/dht/src/broadcast_strategy.rs b/comms/dht/src/broadcast_strategy.rs index 9077cc3a58..a6ed0b6cff 100644 --- a/comms/dht/src/broadcast_strategy.rs +++ b/comms/dht/src/broadcast_strategy.rs @@ -61,6 +61,7 @@ pub enum BroadcastStrategy { /// Send directly to destination if connected but otherwise send to all n nearest Communication Nodes DirectOrClosestNodes(Box), Broadcast(Vec), + SelectedPeers(Vec), /// Propagate to a set of closest neighbours and random peers Propagate(NodeDestination, Vec), } @@ -77,6 +78,7 @@ impl fmt::Display for BroadcastStrategy { Random(n, excluded) => write!(f, "Random({}, {} excluded)", n, excluded.len()), Broadcast(excluded) => write!(f, "Broadcast({} excluded)", excluded.len()), Propagate(destination, excluded) => write!(f, "Propagate({}, {} excluded)", destination, excluded.len(),), + SelectedPeers(peers) => write!(f, "SelectedPeers({} peer(s))", peers.len()), } } } diff --git a/comms/dht/src/builder.rs b/comms/dht/src/builder.rs index a35da81ee0..7bf9ce19fa 100644 --- a/comms/dht/src/builder.rs +++ b/comms/dht/src/builder.rs @@ -66,7 +66,7 @@ impl DhtBuilder { } pub fn set_auto_store_and_forward_requests(&mut self, enabled: bool) -> &mut Self { - self.config.saf_auto_request = enabled; + self.config.saf_config.auto_request = enabled; self } @@ -112,6 +112,7 @@ impl DhtBuilder { pub fn with_num_neighbouring_nodes(&mut self, n: usize) -> &mut Self { self.config.num_neighbouring_nodes = n; + self.config.saf_config.num_neighbouring_nodes = n; self } diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs index c38d0b5cb0..ccfd260c9b 100644 --- a/comms/dht/src/config.rs +++ b/comms/dht/src/config.rs @@ -20,7 +20,12 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use crate::{network_discovery::NetworkDiscoveryConfig, storage::DbConnectionUrl, version::DhtProtocolVersion}; +use crate::{ + network_discovery::NetworkDiscoveryConfig, + storage::DbConnectionUrl, + store_forward::SafConfig, + version::DhtProtocolVersion, +}; use std::time::Duration; #[derive(Debug, Clone)] @@ -33,10 +38,10 @@ pub struct DhtConfig { /// Default: 20 pub outbound_buffer_size: usize, /// The maximum number of peer nodes that a message has to be closer to, to be considered a neighbour - /// Default: [DEFAULT_NUM_NEIGHBOURING_NODES](self::DEFAULT_NUM_NEIGHBOURING_NODES) + /// Default: 8 pub num_neighbouring_nodes: usize, /// Number of random peers to include - /// Default: [DEFAULT_NUM_RANDOM_NODES](self::DEFAULT_NUM_RANDOM_NODES) + /// Default: 4 pub num_random_nodes: usize, /// Send to this many peers when using the broadcast strategy /// Default: 8 @@ -44,30 +49,7 @@ pub struct DhtConfig { /// Send to this many peers when using the propagate strategy /// Default: 4 pub propagation_factor: usize, - /// The amount of seconds added to the current time (Utc) which will then be used to check if the message has - /// expired or not when processing the message - /// Default: 10800 - pub saf_msg_validity: Duration, - /// The maximum number of messages that can be stored using the Store-and-forward middleware. - /// Default: 100,000 - pub saf_msg_storage_capacity: usize, - /// A request to retrieve stored messages will be ignored if the requesting node is - /// not within one of this nodes _n_ closest nodes. - /// Default 8 - pub saf_num_closest_nodes: usize, - /// The maximum number of messages to return from a store and forward retrieval request. - /// Default: 100 - pub saf_max_returned_messages: usize, - /// The time-to-live duration used for storage of low priority messages by the Store-and-forward middleware. - /// Default: 6 hours - pub saf_low_priority_msg_storage_ttl: Duration, - /// The time-to-live duration used for storage of high priority messages by the Store-and-forward middleware. - /// Default: 3 days - pub saf_high_priority_msg_storage_ttl: Duration, - /// The limit on the message size to store in SAF storage in bytes. Default 500 KiB - pub saf_max_message_size: usize, - /// When true, store and forward messages are requested from peers on connect (Default: true) - pub saf_auto_request: bool, + pub saf_config: SafConfig, /// The max capacity of the message hash cache /// Default: 2,500 pub dedup_cache_capacity: usize, @@ -127,7 +109,10 @@ impl DhtConfig { pub fn default_local_test() -> Self { Self { database_url: DbConnectionUrl::Memory, - saf_auto_request: false, + saf_config: SafConfig { + auto_request: false, + ..Default::default() + }, auto_join: false, network_discovery: NetworkDiscoveryConfig { // If a test requires the peer probe they should explicitly enable it @@ -150,13 +135,7 @@ impl Default for DhtConfig { propagation_factor: 4, broadcast_factor: 8, outbound_buffer_size: 20, - saf_num_closest_nodes: 10, - saf_max_returned_messages: 50, - saf_msg_storage_capacity: 100_000, - saf_low_priority_msg_storage_ttl: Duration::from_secs(6 * 60 * 60), // 6 hours - saf_high_priority_msg_storage_ttl: Duration::from_secs(3 * 24 * 60 * 60), // 3 days - saf_auto_request: true, - saf_max_message_size: 512 * 1024, + saf_config: Default::default(), dedup_cache_capacity: 2_500, dedup_cache_trim_interval: Duration::from_secs(5 * 60), dedup_allowed_message_occurrences: 1, @@ -172,7 +151,6 @@ impl Default for DhtConfig { flood_ban_max_msg_count: 10000, flood_ban_timespan: Duration::from_secs(100), offline_peer_cooldown: Duration::from_secs(2 * 60 * 60), - saf_msg_validity: Duration::from_secs(10800), } } } diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 632cbd66d8..54f049a3da 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -233,7 +233,7 @@ impl Dht { saf_response_signal_rx: mpsc::Receiver<()>, ) -> StoreAndForwardService { StoreAndForwardService::new( - self.config.clone(), + self.config.saf_config.clone(), conn, self.peer_manager.clone(), self.dht_requester(), @@ -311,7 +311,7 @@ impl Dht { self.node_identity.node_id().short_str() ))) .layer(store_forward::StoreLayer::new( - self.config.clone(), + self.config.saf_config.clone(), Arc::clone(&self.peer_manager), Arc::clone(&self.node_identity), self.store_and_forward_requester(), @@ -321,7 +321,7 @@ impl Dht { self.node_identity.features().contains(PeerFeatures::DHT_STORE_FORWARD), )) .layer(store_forward::MessageHandlerLayer::new( - self.config.clone(), + self.config.saf_config.clone(), self.store_and_forward_requester(), self.dht_requester(), Arc::clone(&self.node_identity), @@ -640,6 +640,12 @@ mod test { .await .unwrap(); + // SAF messages need to be requested before any response is accepted + dht.store_and_forward_requester() + .request_saf_messages_from_peer(node_identity.node_id().clone()) + .await + .unwrap(); + let spy = service_spy(); let mut service = dht.inbound_middleware_layer().layer(spy.to_service()); @@ -652,10 +658,7 @@ mod test { MessageTag::new(), false, ); - dht_envelope.header.as_mut().map(|header| { - header.message_type = DhtMessageType::SafStoredMessages as i32; - header - }); + dht_envelope.header.as_mut().unwrap().message_type = DhtMessageType::SafStoredMessages as i32; let inbound_message = make_comms_inbound_message(&node_identity, dht_envelope.to_encoded_bytes().into()); service.call(inbound_message).await.unwrap_err(); diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index 24b02c1191..87edbe8eb5 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(map_entry_replace)] #![doc(html_root_url = "https://docs.rs/tower-filter/0.3.0-alpha.2")] #![cfg_attr(not(debug_assertions), deny(unused_variables))] #![cfg_attr(not(debug_assertions), deny(unused_imports))] diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index b361b3db24..d9dcb9cf8d 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -85,7 +85,7 @@ impl BroadcastLayer { dht_requester, dht_discovery_requester, node_identity, - message_validity_window: chrono::Duration::from_std(config.saf_msg_validity) + message_validity_window: chrono::Duration::from_std(config.saf_config.msg_validity) .expect("message_validity_window is too large"), protocol_version: config.protocol_version, } diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs index 81d92ab19e..e745ffd7ea 100644 --- a/comms/dht/src/outbound/message_params.rs +++ b/comms/dht/src/outbound/message_params.rs @@ -149,6 +149,12 @@ impl SendMessageParams { self } + /// Set broadcast_strategy to SelectedPeers. Messages are queued for all selected peers. + pub fn selected_peers(&mut self, peers: Vec) -> &mut Self { + self.params_mut().broadcast_strategy = BroadcastStrategy::SelectedPeers(peers); + self + } + /// Set broadcast_strategy to Neighbours. `excluded_peers` are excluded. Only Peers that have /// `PeerFeatures::MESSAGE_PROPAGATION` are included. pub fn broadcast(&mut self, excluded_peers: Vec) -> &mut Self { diff --git a/comms/dht/src/store_forward/config.rs b/comms/dht/src/store_forward/config.rs new file mode 100644 index 0000000000..126b436764 --- /dev/null +++ b/comms/dht/src/store_forward/config.rs @@ -0,0 +1,73 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::time::Duration; + +#[derive(Debug, Clone)] +pub struct SafConfig { + /// The amount of seconds added to the current time (Utc) which will then be used to check if the message has + /// expired or not when processing the message + /// Default: 3 hours + pub msg_validity: Duration, + /// The maximum number of messages that can be stored using the Store-and-forward middleware. + /// Default: 100,000 + pub msg_storage_capacity: usize, + /// A request to retrieve stored messages will be ignored if the requesting node is + /// not within one of this nodes _n_ closest nodes. + /// Default 8 + pub num_closest_nodes: usize, + /// The maximum number of messages to return from a store and forward retrieval request. + /// Default: 100 + pub max_returned_messages: usize, + /// The time-to-live duration used for storage of low priority messages by the Store-and-forward middleware. + /// Default: 6 hours + pub low_priority_msg_storage_ttl: Duration, + /// The time-to-live duration used for storage of high priority messages by the Store-and-forward middleware. + /// Default: 3 days + pub high_priority_msg_storage_ttl: Duration, + /// The limit on the message size to store in SAF storage in bytes. Default 500 KiB + pub max_message_size: usize, + /// When true, store and forward messages are requested from peers on connect (Default: true) + pub auto_request: bool, + /// The maximum allowed time between asking for a message and accepting a response + pub max_inflight_request_age: Duration, + /// The maximum number of peer nodes that a message must be closer than to get stored by SAF + /// Default: 8 + pub num_neighbouring_nodes: usize, +} + +impl Default for SafConfig { + fn default() -> Self { + Self { + msg_validity: Duration::from_secs(3 * 60 * 60), // 3 hours + num_closest_nodes: 10, + max_returned_messages: 50, + msg_storage_capacity: 100_000, + low_priority_msg_storage_ttl: Duration::from_secs(6 * 60 * 60), // 6 hours + high_priority_msg_storage_ttl: Duration::from_secs(3 * 24 * 60 * 60), // 3 days + auto_request: true, + max_message_size: 512 * 1024, + max_inflight_request_age: Duration::from_secs(120), + num_neighbouring_nodes: 8, + } + } +} diff --git a/comms/dht/src/store_forward/error.rs b/comms/dht/src/store_forward/error.rs index 81455d4882..626b6b3531 100644 --- a/comms/dht/src/store_forward/error.rs +++ b/comms/dht/src/store_forward/error.rs @@ -22,7 +22,11 @@ use crate::{actor::DhtActorError, envelope::DhtMessageError, outbound::DhtOutboundError, storage::StorageError}; use prost::DecodeError; -use tari_comms::{message::MessageError, peer_manager::PeerManagerError}; +use std::time::Duration; +use tari_comms::{ + message::MessageError, + peer_manager::{NodeId, PeerManagerError}, +}; use tari_utilities::{byte_array::ByteArrayError, ciphers::cipher::CipherError}; use thiserror::Error; @@ -62,7 +66,6 @@ pub enum StoreAndForwardError { MessageOriginRequired, #[error("The message was malformed")] MalformedMessage, - #[error("StorageError: {0}")] StorageError(#[from] StorageError), #[error("The store and forward service requester channel closed")] @@ -81,4 +84,10 @@ pub enum StoreAndForwardError { InvalidDhtMessageType, #[error("Failed to send request for store and forward messages: {0}")] RequestMessagesFailed(DhtOutboundError), + #[error("Received SAF messages that were not requested")] + ReceivedUnrequestedSafMessages, + #[error("SAF messages received from peer {peer} after deadline. Received after {0:.2?}")] + SafMessagesRecievedAfterDeadline { peer: NodeId, message_age: Duration }, + #[error("Invalid SAF request: `stored_at` cannot be in the future")] + StoredAtWasInFuture, } diff --git a/comms/dht/src/store_forward/local_state.rs b/comms/dht/src/store_forward/local_state.rs new file mode 100644 index 0000000000..15df5a7aea --- /dev/null +++ b/comms/dht/src/store_forward/local_state.rs @@ -0,0 +1,76 @@ +// Copyright 2021, The Tari Project +// +// Redistribution and use in source and binary forms, with or without modification, are permitted provided that the +// following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following +// disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote +// products derived from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, +// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +// USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +use std::{ + collections::{hash_map::Entry, HashMap}, + time::{Duration, Instant}, +}; +use tari_comms::peer_manager::NodeId; + +#[derive(Debug, Clone, Default)] +pub struct SafLocalState { + inflight_saf_requests: HashMap, +} + +impl SafLocalState { + pub fn register_inflight_requests(&mut self, peers: &[NodeId]) { + peers + .iter() + .for_each(|peer| self.register_inflight_request(peer.clone())); + } + + pub fn register_inflight_request(&mut self, peer: NodeId) { + match self.inflight_saf_requests.entry(peer) { + Entry::Occupied(entry) => { + let (count, _) = *entry.get(); + entry.replace_entry((count + 1, Instant::now())); + }, + Entry::Vacant(entry) => { + entry.insert((1, Instant::now())); + }, + } + } + + pub fn mark_infight_response_received(&mut self, peer: NodeId) -> Option { + match self.inflight_saf_requests.entry(peer) { + Entry::Occupied(entry) => { + let (count, ts) = *entry.get(); + let reduced_count = count - 1; + if reduced_count > 0 { + entry.replace_entry((reduced_count, ts)); + } else { + entry.remove(); + } + Some(ts.elapsed()) + }, + Entry::Vacant(_) => None, + } + } + + pub fn garbage_collect(&mut self, older_than: Duration) { + self.inflight_saf_requests = self + .inflight_saf_requests + .drain() + .filter(|(_, (_, i))| i.elapsed() <= older_than) + .collect(); + } +} diff --git a/comms/dht/src/store_forward/mod.rs b/comms/dht/src/store_forward/mod.rs index aa8d9a91e9..f35a5ffdc6 100644 --- a/comms/dht/src/store_forward/mod.rs +++ b/comms/dht/src/store_forward/mod.rs @@ -31,6 +31,9 @@ pub use database::StoredMessage; mod error; pub use error::StoreAndForwardError; +mod config; +pub use config::SafConfig; + mod forward; pub use forward::ForwardLayer; @@ -39,5 +42,7 @@ mod message; mod saf_handler; pub use saf_handler::MessageHandlerLayer; +mod local_state; + mod store; pub use store::StoreLayer; diff --git a/comms/dht/src/store_forward/saf_handler/layer.rs b/comms/dht/src/store_forward/saf_handler/layer.rs index 16e2760a1e..b7fb888cbc 100644 --- a/comms/dht/src/store_forward/saf_handler/layer.rs +++ b/comms/dht/src/store_forward/saf_handler/layer.rs @@ -23,9 +23,8 @@ use super::middleware::MessageHandlerMiddleware; use crate::{ actor::DhtRequester, - config::DhtConfig, outbound::OutboundMessageRequester, - store_forward::StoreAndForwardRequester, + store_forward::{SafConfig, StoreAndForwardRequester}, }; use std::sync::Arc; use tari_comms::peer_manager::{NodeIdentity, PeerManager}; @@ -33,7 +32,7 @@ use tokio::sync::mpsc; use tower::layer::Layer; pub struct MessageHandlerLayer { - config: DhtConfig, + config: SafConfig, saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, peer_manager: Arc, @@ -44,7 +43,7 @@ pub struct MessageHandlerLayer { impl MessageHandlerLayer { pub fn new( - config: DhtConfig, + config: SafConfig, saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, node_identity: Arc, diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs index 641950e4f1..2f1ddcbbd1 100644 --- a/comms/dht/src/store_forward/saf_handler/middleware.rs +++ b/comms/dht/src/store_forward/saf_handler/middleware.rs @@ -23,10 +23,9 @@ use super::task::MessageHandlerTask; use crate::{ actor::DhtRequester, - config::DhtConfig, inbound::DecryptedDhtMessage, outbound::OutboundMessageRequester, - store_forward::StoreAndForwardRequester, + store_forward::{SafConfig, StoreAndForwardRequester}, }; use futures::{future::BoxFuture, task::Context}; use std::{sync::Arc, task::Poll}; @@ -39,7 +38,7 @@ use tower::Service; #[derive(Clone)] pub struct MessageHandlerMiddleware { - config: DhtConfig, + config: SafConfig, next_service: S, saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, @@ -52,7 +51,7 @@ pub struct MessageHandlerMiddleware { impl MessageHandlerMiddleware { #[allow(clippy::too_many_arguments)] pub fn new( - config: DhtConfig, + config: SafConfig, next_service: S, saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, diff --git a/comms/dht/src/store_forward/saf_handler/mod.rs b/comms/dht/src/store_forward/saf_handler/mod.rs index df5cb40552..5d5014b5b0 100644 --- a/comms/dht/src/store_forward/saf_handler/mod.rs +++ b/comms/dht/src/store_forward/saf_handler/mod.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. mod layer; +pub use layer::MessageHandlerLayer; + mod middleware; mod task; - -pub use layer::MessageHandlerLayer; diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index a544f439a8..6fc78accff 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -22,7 +22,6 @@ use crate::{ actor::DhtRequester, - config::DhtConfig, crypt, envelope::{timestamp_to_datetime, DhtMessageFlags, DhtMessageHeader, NodeDestination}, inbound::{DecryptedDhtMessage, DhtInboundMessage}, @@ -37,7 +36,12 @@ use crate::{ }, }, storage::DhtMetadataKey, - store_forward::{error::StoreAndForwardError, service::FetchStoredMessageQuery, StoreAndForwardRequester}, + store_forward::{ + error::StoreAndForwardError, + service::FetchStoredMessageQuery, + SafConfig, + StoreAndForwardRequester, + }, }; use chrono::{DateTime, NaiveDateTime, Utc}; use digest::Digest; @@ -47,7 +51,7 @@ use prost::Message; use std::{convert::TryInto, sync::Arc}; use tari_comms::{ message::{EnvelopeBody, MessageTag}, - peer_manager::{NodeIdentity, Peer, PeerFeatures, PeerManager, PeerManagerError}, + peer_manager::{NodeId, NodeIdentity, Peer, PeerFeatures, PeerManager, PeerManagerError}, pipeline::PipelineError, types::{Challenge, CommsPublicKey}, utils::signature, @@ -59,7 +63,7 @@ use tower::{Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::storeforward::handler"; pub struct MessageHandlerTask { - config: DhtConfig, + config: SafConfig, next_service: S, dht_requester: DhtRequester, peer_manager: Arc, @@ -75,7 +79,7 @@ where S: Service { #[allow(clippy::too_many_arguments)] pub fn new( - config: DhtConfig, + config: SafConfig, next_service: S, saf_requester: StoreAndForwardRequester, dht_requester: DhtRequester, @@ -269,14 +273,20 @@ where S: Service ); let source_node_id = message.source_peer.node_id.clone(); let message_tag = message.dht_header.message_tag; - // TODO: Should check that stored messages were requested before accepting them + + if let Err(err) = self.check_saf_messages_were_requested(&source_node_id).await { + // TODO: Peer send SAF messages we didn't request?? #banheuristics + warn!(target: LOG_TARGET, "SAF response check failed: {}", err); + return Ok(()); + } + let msg = message .success() .expect("already checked that this message decrypted successfully"); let response = msg .decode_part::(0)? .ok_or(StoreAndForwardError::InvalidEnvelopeBody)?; - let source_peer = Arc::new(message.source_peer); + let source_peer = message.source_peer.clone(); debug!( target: LOG_TARGET, @@ -290,13 +300,9 @@ where S: Service message_tag ); - let mut results = Vec::with_capacity(response.messages.len()); - for msg in response.messages { - let result = self - .process_incoming_stored_message(Arc::clone(&source_peer), msg) - .await; - results.push(result); - } + let results = self + .process_incoming_stored_messages(source_peer.clone(), response.messages) + .await?; let successful_msgs_iter = results .into_iter() @@ -376,26 +382,65 @@ where S: Service Ok(()) } - async fn process_incoming_stored_message( + async fn process_incoming_stored_messages( + &mut self, + source_peer: Arc, + messages: Vec, + ) -> Result>, StoreAndForwardError> { + let mut last_saf_received = self + .dht_requester + .get_metadata::>(DhtMetadataKey::LastSafMessageReceived) + .await?; + + let mut results = Vec::with_capacity(messages.len()); + for msg in messages { + let result = self + .validate_and_decrypt_incoming_stored_message(Arc::clone(&source_peer), msg) + .await; + + if let Ok((_, stored_at)) = result.as_ref() { + if last_saf_received.as_ref().map(|dt| stored_at > dt).unwrap_or(true) { + last_saf_received = Some(*stored_at); + } + } + + results.push(result.map(|(msg, _)| msg)); + } + + if let Some(last_saf_received) = last_saf_received { + self.dht_requester + .set_metadata(DhtMetadataKey::LastSafMessageReceived, last_saf_received) + .await?; + } + + Ok(results) + } + + async fn validate_and_decrypt_incoming_stored_message( &mut self, source_peer: Arc, message: ProtoStoredMessage, - ) -> Result { + ) -> Result<(DecryptedDhtMessage, DateTime), StoreAndForwardError> { let node_identity = &self.node_identity; let peer_manager = &self.peer_manager; let config = &self.config; - if message.dht_header.is_none() { return Err(StoreAndForwardError::DhtHeaderNotProvided); } - let stored_at = match message.stored_at { - None => chrono::MIN_DATETIME, - Some(t) => DateTime::from_utc( - NaiveDateTime::from_timestamp(t.seconds, t.nanos.try_into().unwrap_or(0)), - Utc, - ), - }; + let stored_at = message + .stored_at + .map(|t| { + DateTime::from_utc( + NaiveDateTime::from_timestamp(t.seconds, t.nanos.try_into().unwrap_or(u32::MAX)), + Utc, + ) + }) + .unwrap_or(chrono::MIN_DATETIME); + + if stored_at > Utc::now() { + return Err(StoreAndForwardError::StoredAtWasInFuture); + } let dht_header: DhtMessageHeader = message .dht_header @@ -441,31 +486,9 @@ where S: Service DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body); inbound_msg.is_saf_message = true; - let last_saf_received = self - .dht_requester - .get_metadata::>(DhtMetadataKey::LastSafMessageReceived) - .await - .ok() - .flatten() - .unwrap_or(chrono::MIN_DATETIME); - - if stored_at > last_saf_received { - if let Err(err) = self - .dht_requester - .set_metadata(DhtMetadataKey::LastSafMessageReceived, stored_at) - .await - { - warn!( - target: LOG_TARGET, - "Failed to set last SAF message received timestamp: {:?}", err - ); - } - } - - Ok(DecryptedDhtMessage::succeeded( - decrypted_body, - authenticated_pk, - inbound_msg, + Ok(( + DecryptedDhtMessage::succeeded(decrypted_body, authenticated_pk, inbound_msg), + stored_at, )) } @@ -484,7 +507,7 @@ where S: Service } async fn check_destination( - config: &DhtConfig, + config: &SafConfig, peer_manager: &PeerManager, node_identity: &NodeIdentity, dht_header: &DhtMessageHeader, @@ -566,6 +589,17 @@ where S: Service Err(StoreAndForwardError::InvalidOriginMac) } } + + async fn check_saf_messages_were_requested(&mut self, peer: &NodeId) -> Result<(), StoreAndForwardError> { + match self.saf_requester.mark_saf_response_received(peer.clone()).await? { + Some(age) if age <= self.config.max_inflight_request_age => Ok(()), + Some(age) => Err(StoreAndForwardError::SafMessagesRecievedAfterDeadline { + peer: peer.clone(), + message_age: age, + }), + None => Err(StoreAndForwardError::ReceivedUnrequestedSafMessages), + } + } } #[cfg(test)] @@ -587,14 +621,13 @@ mod test { service_spy, }, }; - use chrono::{Duration as OldDuration, Utc}; - use prost::Message; + use chrono::Utc; use std::time::Duration; use tari_comms::{message::MessageExt, runtime, wrap_in_envelope_body}; use tari_crypto::tari_utilities::hex; use tari_test_utils::collect_recv; use tari_utilities::hex::Hex; - use tokio::{runtime::Handle, sync::mpsc, task, time::sleep}; + use tokio::{sync::mpsc, task, time::sleep}; // TODO: unit tests for static functions (check_signature, etc) @@ -680,17 +713,17 @@ mod test { task::spawn(task.run()); - for _ in 0..6 { - if oms_mock_state.call_count() >= 1 { - break; + task::spawn_blocking({ + let mock_state = oms_mock_state.clone(); + move || { + mock_state.wait_call_count(1, Duration::from_secs(10)).unwrap(); } - sleep(Duration::from_secs(5)).await; - } - assert_eq!(oms_mock_state.call_count(), 1); + }) + .await + .unwrap(); - let call = oms_mock_state.pop_call().unwrap(); - let body = call.1.to_vec(); - let body = EnvelopeBody::decode(body.as_slice()).unwrap(); + let (_, body) = oms_mock_state.pop_call().unwrap(); + let body = EnvelopeBody::decode(body.as_ref()).unwrap(); let msg = body.decode_part::(0).unwrap().unwrap(); assert_eq!(msg.messages().len(), 0); assert!(!spy.is_called()); @@ -702,7 +735,7 @@ mod test { assert!(fetch_call.contains(format!("{:?}", since).as_str())); let msg1_time = Utc::now() - .checked_sub_signed(OldDuration::from_std(Duration::from_secs(120)).unwrap()) + .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(120)).unwrap()) .unwrap(); let msg1 = "one".to_string(); mock_state @@ -715,7 +748,7 @@ mod test { .await; let msg2_time = Utc::now() - .checked_sub_signed(OldDuration::from_std(Duration::from_secs(30)).unwrap()) + .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(30)).unwrap()) .unwrap(); let msg2 = "two".to_string(); mock_state @@ -774,9 +807,8 @@ mod test { #[runtime::test] async fn receive_stored_messages() { - let rt_handle = Handle::current(); let spy = service_spy(); - let (requester, _) = create_store_and_forward_mock(); + let (saf_requester, saf_mock_state) = create_store_and_forward_mock(); let peer_manager = build_peer_manager(); let (oms_tx, _) = mpsc::channel(1); @@ -803,11 +835,11 @@ mod test { .unwrap(); let msg1_time = Utc::now() - .checked_sub_signed(OldDuration::from_std(Duration::from_secs(60)).unwrap()) + .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(60)).unwrap()) .unwrap(); let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body, msg1_time); let msg2_time = Utc::now() - .checked_sub_signed(OldDuration::from_std(Duration::from_secs(30)).unwrap()) + .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(30)).unwrap()) .unwrap(); let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body, msg2_time); @@ -822,7 +854,7 @@ mod test { ) .dht_header; let msg_clear_time = Utc::now() - .checked_sub_signed(OldDuration::from_std(Duration::from_secs(120)).unwrap()) + .checked_sub_signed(chrono::Duration::from_std(Duration::from_secs(120)).unwrap()) .unwrap(); let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg, msg_clear_time); let mut message = DecryptedDhtMessage::succeeded( @@ -843,7 +875,7 @@ mod test { message.dht_header.message_type = DhtMessageType::SafStoredMessages; let (mut dht_requester, mock) = create_dht_actor_mock(1); - rt_handle.spawn(mock.run()); + task::spawn(mock.run()); let (saf_response_signal_sender, mut saf_response_signal_receiver) = mpsc::channel(20); assert!(dht_requester @@ -852,10 +884,13 @@ mod test { .unwrap() .is_none()); + // Allow request inflight check to pass + saf_mock_state.set_request_inflight(Some(Duration::from_secs(10))).await; + let task = MessageHandlerTask::new( Default::default(), spy.to_service::(), - requester, + saf_requester, dht_requester.clone(), peer_manager, OutboundMessageRequester::new(oms_tx), @@ -891,4 +926,161 @@ mod test { assert_eq!(last_saf_received, msg2_time); } + + #[runtime::test] + async fn stored_at_in_future() { + let spy = service_spy(); + let (requester, _) = create_store_and_forward_mock(); + + let peer_manager = build_peer_manager(); + let (oms_tx, _) = mpsc::channel(1); + + let node_identity = make_node_identity(); + + let msg_a = wrap_in_envelope_body!(&b"A".to_vec()).to_encoded_bytes(); + let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a, DhtMessageFlags::ENCRYPTED, true, false); + peer_manager + .add_peer(Clone::clone(&*inbound_msg_a.source_peer)) + .await + .unwrap(); + + let msg1 = ProtoStoredMessage::new( + 0, + inbound_msg_a.dht_header.clone(), + inbound_msg_a.body, + Utc::now() + chrono::Duration::days(1), + ); + let mut message = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(StoredMessagesResponse { + messages: vec![msg1.clone()], + request_id: 123, + response_type: 0 + }), + None, + make_dht_inbound_message( + &node_identity, + b"Stored message".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + false, + ), + ); + message.dht_header.message_type = DhtMessageType::SafStoredMessages; + + let (mut dht_requester, mock) = create_dht_actor_mock(1); + task::spawn(mock.run()); + + let (saf_response_signal_sender, _) = mpsc::channel(1); + + let task = MessageHandlerTask::new( + Default::default(), + spy.to_service::(), + requester, + dht_requester.clone(), + peer_manager, + OutboundMessageRequester::new(oms_tx), + node_identity, + message, + saf_response_signal_sender, + ); + + task.run().await.unwrap(); + let requests = spy.take_requests(); + // Message was discarded + assert_eq!(spy.call_count(), 0); + assert_eq!(requests.len(), 0); + + let last_saf_received = dht_requester + .get_metadata::>(DhtMetadataKey::LastSafMessageReceived) + .await + .unwrap(); + + // LastSafMessageReceived was not set at all + assert!(last_saf_received.is_none()); + } + + #[runtime::test] + async fn saf_message_was_requested() { + let spy = service_spy(); + let (saf_requester, saf_mock_state) = create_store_and_forward_mock(); + + let peer_manager = build_peer_manager(); + let (oms_tx, _) = mpsc::channel(1); + + let node_identity = make_node_identity(); + + let msg_a = wrap_in_envelope_body!(&b"A".to_vec()).to_encoded_bytes(); + let inbound_msg_a = make_dht_inbound_message(&node_identity, msg_a, DhtMessageFlags::ENCRYPTED, true, false); + peer_manager + .add_peer(Clone::clone(&*inbound_msg_a.source_peer)) + .await + .unwrap(); + + let msg1 = ProtoStoredMessage::new( + 0, + inbound_msg_a.dht_header.clone(), + inbound_msg_a.body, + Utc::now() - chrono::Duration::days(1), + ); + let mut message = DecryptedDhtMessage::succeeded( + wrap_in_envelope_body!(StoredMessagesResponse { + messages: vec![msg1.clone()], + request_id: 123, + response_type: 0 + }), + None, + make_dht_inbound_message( + &node_identity, + b"Stored message".to_vec(), + DhtMessageFlags::ENCRYPTED, + true, + false, + ), + ); + message.dht_header.message_type = DhtMessageType::SafStoredMessages; + + let (dht_requester, mock) = create_dht_actor_mock(1); + task::spawn(mock.run()); + + let (saf_response_signal_sender, _) = mpsc::channel(1); + + let task = MessageHandlerTask::new( + Default::default(), + spy.to_service::(), + saf_requester.clone(), + dht_requester.clone(), + peer_manager.clone(), + OutboundMessageRequester::new(oms_tx.clone()), + node_identity.clone(), + message.clone(), + saf_response_signal_sender.clone(), + ); + + task.run().await.unwrap(); + let requests = spy.take_requests(); + // Message was discarded + assert_eq!(spy.call_count(), 0); + assert_eq!(requests.len(), 0); + + // The SAF request was made + saf_mock_state.set_request_inflight(Some(Duration::from_secs(0))).await; + + let task = MessageHandlerTask::new( + Default::default(), + spy.to_service::(), + saf_requester, + dht_requester, + peer_manager, + OutboundMessageRequester::new(oms_tx), + node_identity, + message, + saf_response_signal_sender, + ); + + task.run().await.unwrap(); + let requests = spy.take_requests(); + // Message was discarded + assert_eq!(spy.call_count(), 1); + assert_eq!(requests.len(), 1); + } } diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index cf22df1ded..91bff5037c 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -27,12 +27,13 @@ use super::{ StoreAndForwardError, }; use crate::{ + broadcast_strategy::BroadcastStrategy, envelope::DhtMessageType, event::{DhtEvent, DhtEventSender}, outbound::{OutboundMessageRequester, SendMessageParams}, proto::store_forward::{stored_messages_response::SafResponseType, StoredMessagesRequest}, storage::{DbConnection, DhtMetadataKey}, - DhtConfig, + store_forward::{local_state::SafLocalState, SafConfig}, DhtRequester, }; use chrono::{DateTime, NaiveDateTime, Utc}; @@ -96,8 +97,9 @@ pub enum StoreAndForwardRequest { InsertMessage(NewStoredMessage, oneshot::Sender>), RemoveMessages(Vec), RemoveMessagesOlderThan(DateTime), - SendStoreForwardRequestToPeer(Box), + SendStoreForwardRequestToPeer(NodeId), SendStoreForwardRequestNeighbours, + MarkSafResponseReceived(NodeId, oneshot::Sender>), } #[derive(Clone)] @@ -146,7 +148,7 @@ impl StoreAndForwardRequester { pub async fn request_saf_messages_from_peer(&mut self, node_id: NodeId) -> SafResult<()> { self.sender - .send(StoreAndForwardRequest::SendStoreForwardRequestToPeer(Box::new(node_id))) + .send(StoreAndForwardRequest::SendStoreForwardRequestToPeer(node_id)) .await .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; Ok(()) @@ -159,10 +161,19 @@ impl StoreAndForwardRequester { .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; Ok(()) } + + pub async fn mark_saf_response_received(&mut self, peer: NodeId) -> SafResult> { + let (reply_tx, reply_rx) = oneshot::channel(); + self.sender + .send(StoreAndForwardRequest::MarkSafResponseReceived(peer, reply_tx)) + .await + .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?; + reply_rx.await.map_err(|_| StoreAndForwardError::RequestCancelled) + } } pub struct StoreAndForwardService { - config: DhtConfig, + config: SafConfig, dht_requester: DhtRequester, database: StoreAndForwardDatabase, peer_manager: Arc, @@ -174,12 +185,13 @@ pub struct StoreAndForwardService { num_online_peers: Option, saf_response_signal_rx: mpsc::Receiver<()>, event_publisher: DhtEventSender, + local_state: SafLocalState, } impl StoreAndForwardService { #[allow(clippy::too_many_arguments)] pub fn new( - config: DhtConfig, + config: SafConfig, conn: DbConnection, peer_manager: Arc, dht_requester: DhtRequester, @@ -203,12 +215,13 @@ impl StoreAndForwardService { num_online_peers: None, saf_response_signal_rx, event_publisher, + local_state: Default::default(), } } pub fn spawn(self) { info!(target: LOG_TARGET, "Store and forward service started"); - task::spawn(Self::run(self)); + task::spawn(self.run()); } async fn run(mut self) { @@ -311,6 +324,9 @@ impl StoreAndForwardService { Err(err) => error!(target: LOG_TARGET, "RemoveMessage failed because '{:?}'", err), } }, + MarkSafResponseReceived(peer, reply) => { + let _ = reply.send(self.local_state.mark_infight_response_received(peer)); + }, } } @@ -320,7 +336,7 @@ impl StoreAndForwardService { #[allow(clippy::single_match)] match event { PeerConnected(conn) => { - if !self.config.saf_auto_request { + if !self.config.auto_request { debug!( target: LOG_TARGET, "Auto store and forward request disabled. Ignoring PeerConnected event" @@ -358,7 +374,7 @@ impl StoreAndForwardService { target: LOG_TARGET, "Sending store and forward request to peer '{}' (Since = {:?})", node_id, request.since ); - + self.local_state.register_inflight_request(node_id.clone()); self.outbound_requester .send_message_no_header( SendMessageParams::new() @@ -379,10 +395,17 @@ impl StoreAndForwardService { target: LOG_TARGET, "Sending store and forward request to neighbours (Since = {:?})", request.since ); + let selected_peers = self + .dht_requester + .select_peers(BroadcastStrategy::Broadcast(vec![])) + .await?; + + self.local_state.register_inflight_requests(&selected_peers); + self.outbound_requester .send_message_no_header( SendMessageParams::new() - .broadcast(vec![]) + .selected_peers(selected_peers) .with_dht_message_type(DhtMessageType::SafRequestMessages) .finish(), request, @@ -432,7 +455,7 @@ impl StoreAndForwardService { async fn handle_fetch_message_query(&self, query: FetchStoredMessageQuery) -> SafResult> { use SafResponseType::*; - let limit = i64::try_from(self.config.saf_max_returned_messages) + let limit = i64::try_from(self.config.max_returned_messages) .ok() .or(Some(std::i64::MAX)) .unwrap(); @@ -453,12 +476,15 @@ impl StoreAndForwardService { Ok(messages) } - async fn cleanup(&self) -> SafResult<()> { + async fn cleanup(&mut self) -> SafResult<()> { + self.local_state + .garbage_collect(self.config.max_inflight_request_age * 2); + let num_removed = self .database .delete_messages_with_priority_older_than( StoredMessagePriority::Low, - since(self.config.saf_low_priority_msg_storage_ttl), + since(self.config.low_priority_msg_storage_ttl), ) .await?; debug!(target: LOG_TARGET, "Cleaned {} old low priority messages", num_removed); @@ -467,14 +493,14 @@ impl StoreAndForwardService { .database .delete_messages_with_priority_older_than( StoredMessagePriority::High, - since(self.config.saf_high_priority_msg_storage_ttl), + since(self.config.high_priority_msg_storage_ttl), ) .await?; debug!(target: LOG_TARGET, "Cleaned {} old high priority messages", num_removed); let num_removed = self .database - .truncate_messages(self.config.saf_msg_storage_capacity) + .truncate_messages(self.config.msg_storage_capacity) .await?; if num_removed > 0 { debug!( diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 788a333c24..f9024f1a98 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -27,9 +27,9 @@ use crate::{ database::NewStoredMessage, error::StoreAndForwardError, message::StoredMessagePriority, + SafConfig, SafResult, }, - DhtConfig, }; use futures::{future::BoxFuture, task::Context}; use log::*; @@ -46,14 +46,14 @@ const LOG_TARGET: &str = "comms::dht::storeforward::store"; /// This layer is responsible for storing messages which have failed to decrypt pub struct StoreLayer { peer_manager: Arc, - config: DhtConfig, + config: SafConfig, node_identity: Arc, saf_requester: StoreAndForwardRequester, } impl StoreLayer { pub fn new( - config: DhtConfig, + config: SafConfig, peer_manager: Arc, node_identity: Arc, saf_requester: StoreAndForwardRequester, @@ -84,7 +84,7 @@ impl Layer for StoreLayer { #[derive(Clone)] pub struct StoreMiddleware { next_service: S, - config: DhtConfig, + config: SafConfig, peer_manager: Arc, node_identity: Arc, saf_requester: StoreAndForwardRequester, @@ -93,7 +93,7 @@ pub struct StoreMiddleware { impl StoreMiddleware { pub fn new( next_service: S, - config: DhtConfig, + config: SafConfig, peer_manager: Arc, node_identity: Arc, saf_requester: StoreAndForwardRequester, @@ -155,7 +155,7 @@ where struct StoreTask { next_service: S, peer_manager: Arc, - config: DhtConfig, + config: SafConfig, node_identity: Arc, saf_requester: StoreAndForwardRequester, } @@ -165,7 +165,7 @@ where S: Service + Se { pub fn new( next_service: S, - config: DhtConfig, + config: SafConfig, peer_manager: Arc, node_identity: Arc, saf_requester: StoreAndForwardRequester, @@ -231,11 +231,11 @@ where S: Service + Se ); }; - if message.body_len() > self.config.saf_max_message_size { + if message.body_len() > self.config.max_message_size { log_not_eligible(&format!( "the message body exceeded the maximum storage size (body size={}, max={})", message.body_len(), - self.config.saf_max_message_size + self.config.max_message_size )); return Ok(None); } diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs index 72c0861a6d..3fc3416c37 100644 --- a/comms/dht/src/test_utils/store_and_forward_mock.rs +++ b/comms/dht/src/test_utils/store_and_forward_mock.rs @@ -25,9 +25,12 @@ use chrono::Utc; use digest::Digest; use log::*; use rand::{rngs::OsRng, RngCore}; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, }; use tari_comms::types::Challenge; use tari_utilities::hex; @@ -47,20 +50,17 @@ pub fn create_store_and_forward_mock() -> (StoreAndForwardRequester, StoreAndFor (StoreAndForwardRequester::new(tx), state) } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct StoreAndForwardMockState { call_count: Arc, stored_messages: Arc>>, calls: Arc>>, + inflight_request: Arc>>, } impl StoreAndForwardMockState { pub fn new() -> Self { - Self { - call_count: Arc::new(AtomicUsize::new(0)), - stored_messages: Arc::new(RwLock::new(Vec::new())), - calls: Arc::new(RwLock::new(Vec::new())), - } + Default::default() } pub fn inc_call_count(&self) { @@ -89,6 +89,10 @@ impl StoreAndForwardMockState { self.call_count.store(0, Ordering::SeqCst); calls } + + pub async fn set_request_inflight(&self, duration: Option) { + *self.inflight_request.write().await = duration; + } } pub struct StoreAndForwardMock { @@ -161,6 +165,9 @@ impl StoreAndForwardMock { .await .retain(|msg| msg.stored_at >= threshold.naive_utc()); }, + MarkSafResponseReceived(_, reply) => { + let _ = reply.send(*self.state.inflight_request.read().await); + }, } } } diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 3d87f31c78..0739aea985 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -226,7 +226,7 @@ async fn setup_comms_dht( fn dht_config() -> DhtConfig { let mut config = DhtConfig::default_local_test(); config.allow_test_addresses = true; - config.saf_auto_request = false; + config.saf_config.auto_request = false; config.discovery_request_timeout = Duration::from_secs(60); config.num_neighbouring_nodes = 8; config