From 03c514642933f79cd115e242222dd3048718311d Mon Sep 17 00:00:00 2001
From: Philip Robinson
Date: Tue, 10 Aug 2021 11:46:49 +0200
Subject: [PATCH] Update handling of SAF message propagation and deletion

This PR makes two changes to the way SAF messages are handled, fixing two
subtle bugs spotted while developing cucumber tests.

The first issue was that when a node propagated a SAF message it was storing
to other nodes in its neighbourhood, the broadcast strategy it used only chose
from currently connected base nodes. This meant that if the node had an active
connection to a Communication Client (wallet), it would not send the SAF
message directly to that client but instead to other base nodes in the network
region. As a result, the wallet would only receive new SAF messages when it
actively requested them on connection, even though it was directly connected
to the node. This PR adds a new broadcast strategy called
`DirectOrClosestNodes`, which first checks whether the node has a direct
active connection to the destination and, if it does, sends the SAF message
straight to it.

The second issue was that when a node started sending SAF messages to a
destination, it removed those messages from the database based only on whether
the outbound messages had been put onto the outbound message pipeline. If the
TCP connection to that peer was actually broken, the sends would fail at the
end of the pipeline, but the SAF messages had already been deleted from the
database. This PR changes the way SAF messages are deleted. When a client asks
a node for SAF messages, it also provides the timestamp of the most recent SAF
message it has received. The node then sends all SAF messages it holds for the
client since that timestamp and deletes all SAF messages from before that
timestamp. This serves as a form of acknowledgement (Ack) that the client
received the older messages at some point and that they are no longer needed.
---
 comms/dht/src/actor.rs                        | 119 ++++++----
 comms/dht/src/broadcast_strategy.rs           |  30 ++-
 comms/dht/src/config.rs                       |   6 -
 comms/dht/src/outbound/broadcast.rs           |   2 +-
 comms/dht/src/outbound/message_params.rs      |  20 +-
 comms/dht/src/outbound/mock.rs                |   2 +-
 comms/dht/src/storage/dht_setting_entry.rs    |   2 +
 comms/dht/src/store_forward/database/mod.rs   |  11 +
 comms/dht/src/store_forward/forward.rs        |   2 +-
 comms/dht/src/store_forward/message.rs        |   9 +-
 .../dht/src/store_forward/saf_handler/task.rs | 207 ++++++++++++++----
 comms/dht/src/store_forward/service.rs        |  31 ++-
 .../src/test_utils/store_and_forward_mock.rs  |  22 +-
 13 files changed, 341 insertions(+), 122 deletions(-)

diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs
index 0e0b6e6eea..c2c2d4e52a 100644
--- a/comms/dht/src/actor.rs
+++ b/comms/dht/src/actor.rs
@@ -28,7 +28,7 @@
 //!
 //! [DhtRequest]: ./enum.DhtRequest.html
 use crate::{
-    broadcast_strategy::BroadcastStrategy,
+    broadcast_strategy::{BroadcastClosestRequest, BroadcastStrategy},
     dedup::DedupCacheDatabase,
     discovery::DhtDiscoveryError,
     outbound::{DhtOutboundError, OutboundMessageRequester, SendMessageParams},
@@ -416,43 +416,19 @@ impl DhtActor {
                     .await?;
                 Ok(peers.into_iter().map(|p| p.peer_node_id().clone()).collect())
             },
-            Closest(closest_request) => {
-                let connections = connectivity
-                    .select_connections(ConnectivitySelection::closest_to(
-                        closest_request.node_id.clone(),
-                        config.broadcast_factor,
-                        closest_request.excluded_peers.clone(),
-                    ))
-                    .await?;
-
-                let mut candidates = connections
-                    .iter()
-                    .map(|conn| conn.peer_node_id())
-                    .cloned()
-                    .collect::<Vec<_>>();
-
-                if !closest_request.connected_only {
-                    let excluded = closest_request
-                        .excluded_peers
-                        .iter()
-                        .chain(candidates.iter())
-                        .cloned()
-                        .collect::<Vec<_>>();
-                    // If we don't have enough connections, let's select some more disconnected peers (at least 2)
-                    let n = cmp::max(config.broadcast_factor.saturating_sub(candidates.len()), 2);
-                    let additional = Self::select_closest_peers_for_propagation(
-                        &peer_manager,
-                        &closest_request.node_id,
-                        n,
-                        &excluded,
-                        PeerFeatures::MESSAGE_PROPAGATION,
-                    )
-                    .await?;
-
-                    candidates.extend(additional);
+            ClosestNodes(closest_request) => {
+                Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager).await
+            },
+            DirectOrClosestNodes(closest_request) => {
+                // First check if a direct connection exists
+                if connectivity
+                    .get_connection(closest_request.node_id.clone())
+                    .await?
+                    .is_some()
+                {
+                    return Ok(vec![closest_request.node_id.clone()]);
                 }
-
-                Ok(candidates)
+                Self::select_closest_node_connected(closest_request, config, connectivity, peer_manager).await
             },
             Random(n, excluded) => {
                 // Send to a random set of peers of size n that are Communication Nodes
@@ -659,6 +635,50 @@ impl DhtActor {
         Ok(peers.into_iter().map(|p| p.node_id).collect())
     }
+
+    async fn select_closest_node_connected(
+        closest_request: Box<BroadcastClosestRequest>,
+        config: DhtConfig,
+        mut connectivity: ConnectivityRequester,
+        peer_manager: Arc<PeerManager>,
+    ) -> Result<Vec<NodeId>, DhtActorError> {
+        let connections = connectivity
+            .select_connections(ConnectivitySelection::closest_to(
+                closest_request.node_id.clone(),
+                config.broadcast_factor,
+                closest_request.excluded_peers.clone(),
+            ))
+            .await?;
+
+        let mut candidates = connections
+            .iter()
+            .map(|conn| conn.peer_node_id())
+            .cloned()
+            .collect::<Vec<_>>();
+
+        if !closest_request.connected_only {
+            let excluded = closest_request
+                .excluded_peers
+                .iter()
+                .chain(candidates.iter())
+                .cloned()
+                .collect::<Vec<_>>();
+            // If we don't have enough connections, let's select some more disconnected peers (at least 2)
+            let n = cmp::max(config.broadcast_factor.saturating_sub(candidates.len()), 2);
+            let additional = Self::select_closest_peers_for_propagation(
+                &peer_manager,
+                &closest_request.node_id,
+                n,
+                &excluded,
+                PeerFeatures::MESSAGE_PROPAGATION,
+            )
+            .await?;
+
+            candidates.extend(additional);
+        }
+
+        Ok(candidates)
+    }
 }
 
 #[cfg(test)]
@@ -888,6 +908,7 @@ mod test {
         connectivity_manager_mock_state
             .set_selected_connections(vec![conn_out.clone()])
             .await;
+
         let peers = requester
             .select_peers(BroadcastStrategy::Broadcast(Vec::new()))
             .await
             .unwrap();
@@ -915,7 +936,29 @@ mod test {
             connected_only: false,
         });
         let peers = requester
-            .select_peers(BroadcastStrategy::Closest(send_request))
+            .select_peers(BroadcastStrategy::ClosestNodes(send_request))
+            .await
+            .unwrap();
+        assert_eq!(peers.len(), 2);
+
+        let send_request = Box::new(BroadcastClosestRequest {
+            node_id: node_identity.node_id().clone(),
+            excluded_peers: vec![],
+            connected_only: false,
+        });
+        let peers = requester
+            .select_peers(BroadcastStrategy::DirectOrClosestNodes(send_request))
+            .await
+            .unwrap();
+        assert_eq!(peers.len(), 1);
+
+        let send_request = Box::new(BroadcastClosestRequest {
+            node_id: client_node_identity.node_id().clone(),
+            excluded_peers: vec![],
+            connected_only: false,
+        });
+        let peers = requester
+            .select_peers(BroadcastStrategy::DirectOrClosestNodes(send_request))
             .await
             .unwrap();
         assert_eq!(peers.len(), 2);
diff --git a/comms/dht/src/broadcast_strategy.rs b/comms/dht/src/broadcast_strategy.rs
index 3e1b356067..9077cc3a58 100644
--- a/comms/dht/src/broadcast_strategy.rs
+++ b/comms/dht/src/broadcast_strategy.rs
@@ -57,7 +57,9 @@ pub enum BroadcastStrategy {
     /// Send to a random set of peers of size n that are Communication Nodes, excluding the given node IDs
     Random(usize, Vec<NodeId>),
     /// Send to all n nearest Communication Nodes according to the given BroadcastClosestRequest
-    Closest(Box<BroadcastClosestRequest>),
+    ClosestNodes(Box<BroadcastClosestRequest>),
+    /// Send directly to destination if connected but otherwise send to all n nearest Communication Nodes
+    DirectOrClosestNodes(Box<BroadcastClosestRequest>),
     Broadcast(Vec<NodeId>),
     /// Propagate to a set of closest neighbours and random peers
     Propagate(NodeDestination, Vec<NodeId>),
@@ -70,7 +72,8 @@ impl fmt::Display for BroadcastStrategy {
             DirectPublicKey(pk) => write!(f, "DirectPublicKey({})", pk),
             DirectNodeId(node_id) => write!(f, "DirectNodeId({})", node_id),
             Flood(excluded) => write!(f, "Flood({} excluded)", excluded.len()),
-            Closest(request) => write!(f, "Closest({})", request),
+            ClosestNodes(request) => write!(f, "ClosestNodes({})", request),
+            DirectOrClosestNodes(request) => write!(f, "DirectOrClosestNodes({})", request),
             Random(n, excluded) => write!(f, "Random({}, {} excluded)", n, excluded.len()),
             Broadcast(excluded) => write!(f, "Broadcast({} excluded)", excluded.len()),
             Propagate(destination, excluded) => write!(f, "Propagate({}, {} excluded)", destination, excluded.len(),),
@@ -79,13 +82,18 @@ impl fmt::Display for BroadcastStrategy {
 }
 
 impl BroadcastStrategy {
-    /// Returns true if this strategy will send multiple messages, otherwise false
-    pub fn is_multi_message(&self) -> bool {
+    /// Returns true if this strategy will send multiple indirect messages, otherwise false
+    pub fn is_multi_message(&self, chosen_peers: &[NodeId]) -> bool {
         use BroadcastStrategy::*;
-        matches!(
-            self,
-            Closest(_) | Flood(_) | Broadcast(_) | Random(_, _) | Propagate(_, _)
-        )
+
+        match self {
+            DirectOrClosestNodes(strategy) => {
+                // A single chosen peer matching the target NodeId is a direct send, not a multi-message
+                !(chosen_peers.len() == 1 && chosen_peers.first() == Some(&strategy.node_id))
+            },
+            ClosestNodes(_) | Broadcast(_) | Propagate(_, _) | Flood(_) | Random(_, _) => true,
+            _ => false,
+        }
     }
 
     pub fn is_direct(&self) -> bool {
@@ -129,7 +137,7 @@ mod test {
         assert!(!BroadcastStrategy::Broadcast(Default::default()).is_direct());
         assert!(!BroadcastStrategy::Propagate(Default::default(), Default::default()).is_direct(),);
         assert!(!BroadcastStrategy::Flood(Default::default()).is_direct());
-        assert!(!BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest {
+        assert!(!BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest {
             node_id: NodeId::default(),
             excluded_peers: Default::default(),
             connected_only: false
@@ -152,7 +160,7 @@ mod test {
         assert!(BroadcastStrategy::Flood(Default::default())
             .direct_public_key()
             .is_none());
-        assert!(BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest {
+        assert!(BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest {
             node_id: NodeId::default(),
             excluded_peers: Default::default(),
             connected_only: false
@@ -174,7 +182,7 @@ mod test {
             .direct_node_id()
             .is_none());
         assert!(BroadcastStrategy::Flood(Default::default()).direct_node_id().is_none());
-        assert!(BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest {
+        assert!(BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest {
             node_id: NodeId::default(),
             excluded_peers: Default::default(),
             connected_only: false
diff --git a/comms/dht/src/config.rs b/comms/dht/src/config.rs
index 90fc9b8b72..0612445dca 100644
--- a/comms/dht/src/config.rs
+++ b/comms/dht/src/config.rs
@@ -66,11 +66,6 @@ pub struct DhtConfig {
     pub saf_max_message_size: usize,
     /// When true, store and forward messages are requested from peers on connect (Default: true)
     pub saf_auto_request: bool,
-    /// The minimum period used to request SAF messages from a peer. When requesting SAF messages,
-    /// it will request messages since the DHT last went offline, but this may be a small amount of
-    /// time, so `minimum_request_period` can be used so that messages aren't missed.
-    /// Default: 3 days
-    pub saf_minimum_request_period: Duration,
     /// The max capacity of the message hash cache
     /// Default: 2,500
     pub dedup_cache_capacity: usize,
@@ -154,7 +149,6 @@ impl Default for DhtConfig {
             saf_high_priority_msg_storage_ttl: Duration::from_secs(3 * 24 * 60 * 60), // 3 days
             saf_auto_request: true,
             saf_max_message_size: 512 * 1024,
-            saf_minimum_request_period: Duration::from_secs(3 * 24 * 60 * 60), // 3 days
             dedup_cache_capacity: 2_500,
             dedup_cache_trim_interval: Duration::from_secs(5 * 60),
             database_url: DbConnectionUrl::Memory,
diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs
index a3b122f8ab..0aa9fab611 100644
--- a/comms/dht/src/outbound/broadcast.rs
+++ b/comms/dht/src/outbound/broadcast.rs
@@ -268,7 +268,7 @@ where S: Service
             is_discovery_enabled,
         );
 
-        let is_broadcast = broadcast_strategy.is_multi_message();
+        let is_broadcast = broadcast_strategy.is_multi_message(&peers);
 
         // Discovery is required if:
         // - Discovery is enabled for this request
diff --git a/comms/dht/src/outbound/message_params.rs b/comms/dht/src/outbound/message_params.rs
index ffc463771a..0ad00bbc4e 100644
--- a/comms/dht/src/outbound/message_params.rs
+++ b/comms/dht/src/outbound/message_params.rs
@@ -116,7 +116,7 @@ impl SendMessageParams {
     /// `node_id` - Select the closest known peers to this `NodeId`
     /// `excluded_peers` - vector of `NodeId`s to exclude from broadcast.
     pub fn closest(&mut self, node_id: NodeId, excluded_peers: Vec<NodeId>) -> &mut Self {
-        self.params_mut().broadcast_strategy = BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest {
+        self.params_mut().broadcast_strategy = BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest {
             excluded_peers,
             node_id,
             connected_only: false,
@@ -124,10 +124,10 @@ impl SendMessageParams {
         self
     }
 
-    /// Set broadcast_strategy to Closest. `excluded_peers` are excluded. Only peers that are currently connected will be
-    /// included.
+    /// Set broadcast_strategy to ClosestNodes. `excluded_peers` are excluded. Only peers that are currently connected
+    /// will be included.
     pub fn closest_connected(&mut self, node_id: NodeId, excluded_peers: Vec<NodeId>) -> &mut Self {
-        self.params_mut().broadcast_strategy = BroadcastStrategy::Closest(Box::new(BroadcastClosestRequest {
+        self.params_mut().broadcast_strategy = BroadcastStrategy::ClosestNodes(Box::new(BroadcastClosestRequest {
             excluded_peers,
             node_id,
             connected_only: true,
@@ -135,6 +135,18 @@ impl SendMessageParams {
         self
     }
 
+    /// Set broadcast_strategy to DirectOrClosestNodes. `excluded_peers` are excluded. Only peers that are currently
+    /// connected will be included.
+    pub fn direct_or_closest_connected(&mut self, node_id: NodeId, excluded_peers: Vec<NodeId>) -> &mut Self {
+        self.params_mut().broadcast_strategy =
+            BroadcastStrategy::DirectOrClosestNodes(Box::new(BroadcastClosestRequest {
+                excluded_peers,
+                node_id,
+                connected_only: true,
+            }));
+        self
+    }
+
     /// Set broadcast_strategy to Neighbours. `excluded_peers` are excluded. Only Peers that have
     /// `PeerFeatures::MESSAGE_PROPAGATION` are included.
     pub fn broadcast(&mut self, excluded_peers: Vec<NodeId>) -> &mut Self {
diff --git a/comms/dht/src/outbound/mock.rs b/comms/dht/src/outbound/mock.rs
index 6cf4b83e40..f5c3f30665 100644
--- a/comms/dht/src/outbound/mock.rs
+++ b/comms/dht/src/outbound/mock.rs
@@ -205,7 +205,7 @@ impl OutboundServiceMock {
                     },
                 };
             },
-            BroadcastStrategy::Closest(_) => {
+            BroadcastStrategy::ClosestNodes(_) => {
                 if behaviour.broadcast == ResponseType::Queued {
                     let (response, mut inner_reply_tx) = self.add_call((*params).clone(), body);
                     reply_tx.send(response).expect("Reply channel cancelled");
diff --git a/comms/dht/src/storage/dht_setting_entry.rs b/comms/dht/src/storage/dht_setting_entry.rs
index 73cb39fe69..dd1e06597f 100644
--- a/comms/dht/src/storage/dht_setting_entry.rs
+++ b/comms/dht/src/storage/dht_setting_entry.rs
@@ -27,6 +27,8 @@ use std::fmt;
 pub enum DhtMetadataKey {
     /// Timestamp each time the DHT is shut down
     OfflineTimestamp,
+    /// Timestamp of the most recent SAF message received
+    LastSafMessageReceived,
 }
 
 impl fmt::Display for DhtMetadataKey {
diff --git a/comms/dht/src/store_forward/database/mod.rs b/comms/dht/src/store_forward/database/mod.rs
index ec6b19a42e..173d00e0ef 100644
--- a/comms/dht/src/store_forward/database/mod.rs
+++ b/comms/dht/src/store_forward/database/mod.rs
@@ -217,6 +217,17 @@ impl StoreAndForwardDatabase {
             .await
     }
 
+    pub(crate) async fn delete_messages_older_than(&self, since: NaiveDateTime) -> Result<usize, StorageError> {
+        self.connection
+            .with_connection_async(move |conn| {
+                diesel::delete(stored_messages::table)
+                    .filter(stored_messages::stored_at.lt(since))
+                    .execute(conn)
+                    .map_err(Into::into)
+            })
+            .await
+    }
+
     pub(crate) async fn truncate_messages(&self, max_size: usize) -> Result<usize, StorageError> {
         self.connection
             .with_connection_async(move |conn| {
diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs
index 607dfe0fd1..95ce5e2500 100644
--- a/comms/dht/src/store_forward/forward.rs
+++ b/comms/dht/src/store_forward/forward.rs
@@ -219,7 +219,7 @@ where S: Service
                     target: LOG_TARGET,
                     "Forwarding SAF message directly to node: {}, Tag#{}", node_id, dht_header.message_tag
                 );
-                send_params.closest_connected(node_id.clone(), excluded_peers);
+                send_params.direct_or_closest_connected(node_id.clone(), excluded_peers);
             },
             _ => {
                 debug!(
diff --git a/comms/dht/src/store_forward/message.rs b/comms/dht/src/store_forward/message.rs
index d29481f3f2..85ba721934 100644
--- a/comms/dht/src/store_forward/message.rs
+++ b/comms/dht/src/store_forward/message.rs
@@ -52,12 +52,17 @@ impl StoredMessagesRequest {
 
 #[cfg(test)]
 impl StoredMessage {
-    pub fn new(version: u32, dht_header: crate::envelope::DhtMessageHeader, body: Vec<u8>) -> Self {
+    pub fn new(
+        version: u32,
+        dht_header: crate::envelope::DhtMessageHeader,
+        body: Vec<u8>,
+        stored_at: DateTime<Utc>,
+    ) -> Self {
         Self {
             version,
             dht_header: Some(dht_header.into()),
             body,
-            stored_at: Some(datetime_to_timestamp(Utc::now())),
+            stored_at: Some(datetime_to_timestamp(stored_at)),
         }
     }
 }
diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs
index e32e3f60a1..f3ba852118 100644
--- a/comms/dht/src/store_forward/saf_handler/task.rs
+++ b/comms/dht/src/store_forward/saf_handler/task.rs
@@ -36,8 +36,10 @@ use crate::{
             StoredMessagesResponse,
         },
     },
+    storage::DhtMetadataKey,
     store_forward::{error::StoreAndForwardError, service::FetchStoredMessageQuery, StoreAndForwardRequester},
 };
+use chrono::{DateTime, NaiveDateTime, Utc};
 use digest::Digest;
 use futures::{channel::mpsc, future, stream, SinkExt, StreamExt};
 use log::*;
@@ -172,15 +174,19 @@ where S: Service
         // Compile a set of stored messages for the requesting peer
         let mut query = FetchStoredMessageQuery::new(source_pubkey, source_node_id.clone());
 
-        if let Some(since) = retrieve_msgs.since.map(timestamp_to_datetime) {
-            debug!(
-                target: LOG_TARGET,
-                "Peer '{}' requested all messages since '{}'",
-                source_node_id.short_str(),
-                since
-            );
-            query.since(since);
-        }
+        let since: Option<DateTime<Utc>> = match retrieve_msgs.since.map(timestamp_to_datetime) {
+            Some(since) => {
+                debug!(
+                    target: LOG_TARGET,
+                    "Peer '{}' requested all messages since '{}'",
+                    source_node_id.short_str(),
+                    since
+                );
+                query.with_messages_since(since);
+                Some(since)
+            },
+            None => None,
+        };
 
         let response_types = vec![SafResponseType::ForMe];
 
@@ -188,7 +194,6 @@ where S: Service
             query.with_response_type(resp_type);
             let messages = self.saf_requester.fetch_messages(query.clone()).await?;
 
-            let message_ids = messages.iter().map(|msg| msg.id).collect::<Vec<_>>();
             let stored_messages = StoredMessagesResponse {
                 messages: try_convert_all(messages)?,
                 request_id: retrieve_msgs.request_id,
@@ -201,6 +206,7 @@ where S: Service
                 stored_messages.messages().len(),
                 resp_type
             );
+
             match self
                 .outbound_service
                 .send_message_no_header(
@@ -215,13 +221,15 @@ where S: Service
                 .await
             {
                 Ok(_) => {
-                    debug!(
-                        target: LOG_TARGET,
-                        "Removing {} stored message(s) for peer '{}'",
-                        message_ids.len(),
-                        message.source_peer.node_id.short_str()
-                    );
-                    self.saf_requester.remove_messages(message_ids).await?;
+                    if let Some(threshold) = since {
+                        debug!(
+                            target: LOG_TARGET,
+                            "Removing stored message(s) from before {} for peer '{}'",
+                            threshold,
+                            message.source_peer.node_id.short_str()
+                        );
+                        self.saf_requester.remove_messages_older_than(threshold).await?;
+                    }
                 },
                 Err(err) => {
                     error!(
@@ -366,6 +374,14 @@ where S: Service
             return Err(StoreAndForwardError::DhtHeaderNotProvided);
         }
 
+        let stored_at = match message.stored_at {
+            None => chrono::MIN_DATETIME,
+            Some(t) => DateTime::from_utc(
+                NaiveDateTime::from_timestamp(t.seconds, t.nanos.try_into().unwrap_or(0)),
+                Utc,
+            ),
+        };
+
         let dht_header: DhtMessageHeader = message
             .dht_header
             .expect("previously checked")
@@ -410,6 +426,27 @@ where S: Service
             DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body);
         inbound_msg.is_saf_message = true;
 
+        let last_saf_received = self
+            .dht_requester
+            .get_metadata::<DateTime<Utc>>(DhtMetadataKey::LastSafMessageReceived)
+            .await
+            .ok()
+            .flatten()
+            .unwrap_or(chrono::MIN_DATETIME);
+
+        if stored_at > last_saf_received {
+            if let Err(err) = self
+                .dht_requester
+                .set_metadata(DhtMetadataKey::LastSafMessageReceived, stored_at)
+                .await
+            {
+                warn!(
+                    target: LOG_TARGET,
+                    "Failed to set last SAF message received timestamp: {:?}", err
+                );
+            }
+        }
+
         Ok(DecryptedDhtMessage::succeeded(
             decrypted_body,
             authenticated_pk,
@@ -515,6 +552,7 @@ mod test {
     use super::*;
     use crate::{
         envelope::DhtMessageFlags,
+        outbound::mock::create_outbound_service_mock,
         proto::envelope::DhtHeader,
         store_forward::{message::StoredMessagePriority, StoredMessage},
         test_utils::{
@@ -528,7 +566,7 @@ mod test {
             service_spy,
         },
     };
-    use chrono::Utc;
+    use chrono::{Duration as OldDuration, Utc};
     use futures::channel::mpsc;
     use prost::Message;
     use std::time::Duration;
@@ -536,12 +574,17 @@ mod test {
     use tari_crypto::tari_utilities::hex;
     use tari_test_utils::collect_stream;
     use tari_utilities::hex::Hex;
-    use tokio::runtime::Handle;
+    use tokio::{runtime::Handle, task, time::delay_for};
 
     // TODO: unit tests for static functions (check_signature, etc)
 
-    fn make_stored_message(node_identity: &NodeIdentity, dht_header: DhtMessageHeader) -> StoredMessage {
-        let body = b"A".to_vec();
+    fn make_stored_message(
+        message: String,
+        node_identity: &NodeIdentity,
+        dht_header: DhtMessageHeader,
+        stored_at: NaiveDateTime,
+    ) -> StoredMessage {
+        let body = message.as_bytes().to_vec();
         let body_hash = hex::to_hex(&Challenge::new().chain(body.clone()).finalize());
         StoredMessage {
             id: 1,
@@ -554,19 +597,20 @@ mod test {
             body,
             is_encrypted: false,
             priority: StoredMessagePriority::High as i32,
-            stored_at: Utc::now().naive_utc(),
+            stored_at,
             body_hash,
         }
     }
 
-    #[tokio_macros::test_basic]
+    #[tokio_macros::test]
     async fn request_stored_messages() {
-        let rt_handle = Handle::current();
         let spy = service_spy();
         let (requester, mock_state) = create_store_and_forward_mock();
 
         let peer_manager = build_peer_manager();
-        let (oms_tx, mut oms_rx) = mpsc::channel(1);
+        let (outbound_requester, outbound_mock) = create_outbound_service_mock(10);
+        let oms_mock_state = outbound_mock.get_state();
+        task::spawn(outbound_mock.run());
 
         let node_identity = make_node_identity();
 
@@ -606,29 +650,59 @@ mod test {
             requester.clone(),
             dht_requester.clone(),
             peer_manager.clone(),
-            OutboundMessageRequester::new(oms_tx.clone()),
+            outbound_requester.clone(),
             node_identity.clone(),
             message.clone(),
             saf_response_signal_sender.clone(),
         );
 
-        rt_handle.spawn(task.run());
+        task::spawn(task.run());
 
-        let (_, body) = unwrap_oms_send_msg!(oms_rx.next().await.unwrap());
-        let body = body.to_vec();
+        for _ in 0..6 {
+            if oms_mock_state.call_count() >= 1 {
+                break;
+            }
+            delay_for(Duration::from_secs(5)).await;
+        }
+        assert_eq!(oms_mock_state.call_count(), 1);
+
+        let call = oms_mock_state.pop_call().unwrap();
+        let body = call.1.to_vec();
         let body = EnvelopeBody::decode(body.as_slice()).unwrap();
         let msg = body.decode_part::<StoredMessagesResponse>(0).unwrap().unwrap();
         assert_eq!(msg.messages().len(), 0);
         assert!(!spy.is_called());
 
-        assert_eq!(mock_state.call_count(), 1);
+        // assert_eq!(mock_state.call_count(), 2);
         let calls = mock_state.take_calls().await;
-        assert!(calls[0].contains("FetchMessages"));
-        assert!(calls[0].contains(node_identity.public_key().to_hex().as_str()));
-        assert!(calls[0].contains(format!("{:?}", since).as_str()));
+
+        let fetch_call = calls.iter().find(|c| c.contains("FetchMessages")).unwrap();
+        assert!(fetch_call.contains(node_identity.public_key().to_hex().as_str()));
+        assert!(fetch_call.contains(format!("{:?}", since).as_str()));
 
+        let msg1_time = Utc::now()
+            .checked_sub_signed(OldDuration::from_std(Duration::from_secs(120)).unwrap())
+            .unwrap();
+        let msg1 = "one".to_string();
         mock_state
-            .add_message(make_stored_message(&node_identity, dht_header))
+            .add_message(make_stored_message(
+                msg1.clone(),
+                &node_identity,
+                dht_header.clone(),
+                msg1_time.naive_utc(),
+            ))
+            .await;
+
+        let msg2_time = Utc::now()
+            .checked_sub_signed(OldDuration::from_std(Duration::from_secs(30)).unwrap())
+            .unwrap();
+        let msg2 = "two".to_string();
+        mock_state
+            .add_message(make_stored_message(
+                msg2.clone(),
+                &node_identity,
+                dht_header,
+                msg2_time.naive_utc(),
+            ))
             .await;
 
         // Now lets test its response where there are messages to return.
@@ -638,27 +712,42 @@ mod test {
             requester,
             dht_requester,
             peer_manager,
-            OutboundMessageRequester::new(oms_tx),
+            outbound_requester.clone(),
             node_identity.clone(),
             message,
             saf_response_signal_sender,
         );
 
-        rt_handle.spawn(task.run());
+        task::spawn(task.run());
 
-        let (_, body) = unwrap_oms_send_msg!(oms_rx.next().await.unwrap());
-        let body = body.to_vec();
+        for _ in 0..6 {
+            if oms_mock_state.call_count() >= 1 {
+                break;
+            }
+            delay_for(Duration::from_secs(5)).await;
+        }
+        assert_eq!(oms_mock_state.call_count(), 1);
+        let call = oms_mock_state.pop_call().unwrap();
+
+        let body = call.1.to_vec();
         let body = EnvelopeBody::decode(body.as_slice()).unwrap();
         let msg = body.decode_part::<StoredMessagesResponse>(0).unwrap().unwrap();
+
         assert_eq!(msg.messages().len(), 1);
-        assert_eq!(msg.messages()[0].body, b"A");
+        assert_eq!(msg.messages()[0].body, "two".as_bytes());
         assert!(!spy.is_called());
 
         assert_eq!(mock_state.call_count(), 2);
         let calls = mock_state.take_calls().await;
-        assert!(calls[0].contains("FetchMessages"));
-        assert!(calls[0].contains(node_identity.public_key().to_hex().as_str()));
-        assert!(calls[0].contains(format!("{:?}", since).as_str()));
+
+        let fetch_call = calls.iter().find(|c| c.contains("FetchMessages")).unwrap();
+        assert!(fetch_call.contains(node_identity.public_key().to_hex().as_str()));
+        assert!(fetch_call.contains(format!("{:?}", since).as_str()));
+
+        let stored_messages = mock_state.get_messages().await;
+
+        assert!(!stored_messages.iter().any(|s| s.body == msg1.as_bytes()));
+        assert!(stored_messages.iter().any(|s| s.body == msg2.as_bytes()));
     }
 
     #[tokio_macros::test_basic]
@@ -689,13 +778,23 @@ mod test {
             .await
             .unwrap();
 
-        let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body);
-        let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body);
+        let msg1_time = Utc::now()
+            .checked_sub_signed(OldDuration::from_std(Duration::from_secs(60)).unwrap())
+            .unwrap();
+        let msg1 = ProtoStoredMessage::new(0, inbound_msg_a.dht_header.clone(), inbound_msg_a.body, msg1_time);
+
+        let msg2_time = Utc::now()
+            .checked_sub_signed(OldDuration::from_std(Duration::from_secs(30)).unwrap())
+            .unwrap();
+        let msg2 = ProtoStoredMessage::new(0, inbound_msg_b.dht_header, inbound_msg_b.body, msg2_time);
+
         // Cleartext message
         let clear_msg = wrap_in_envelope_body!(b"Clear".to_vec()).to_encoded_bytes();
         let clear_header =
             make_dht_inbound_message(&node_identity, clear_msg.clone(), DhtMessageFlags::empty(), false).dht_header;
-        let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg);
+        let msg_clear_time = Utc::now()
+            .checked_sub_signed(OldDuration::from_std(Duration::from_secs(120)).unwrap())
+            .unwrap();
+        let msg_clear = ProtoStoredMessage::new(0, clear_header, clear_msg, msg_clear_time);
         let mut message = DecryptedDhtMessage::succeeded(
            wrap_in_envelope_body!(StoredMessagesResponse {
                messages: vec![msg1.clone(), msg2, msg_clear],
                request_id: 123,
                response_type: 0
            }),
            ...
        );
        message.dht_header.message_type = DhtMessageType::SafStoredMessages;
 
-        let (dht_requester, mock) = create_dht_actor_mock(1);
+        let (mut dht_requester, mock) = create_dht_actor_mock(1);
         rt_handle.spawn(mock.run());
         let (saf_response_signal_sender, mut saf_response_signal_receiver) = mpsc::channel(20);
 
+        assert!(dht_requester
+            .get_metadata::<DateTime<Utc>>(DhtMetadataKey::LastSafMessageReceived)
+            .await
+            .unwrap()
+            .is_none());
+
         let task = MessageHandlerTask::new(
             Default::default(),
             spy.to_service::<PipelineError>(),
             requester,
-            dht_requester,
+            dht_requester.clone(),
             peer_manager,
             OutboundMessageRequester::new(oms_tx),
             node_identity,
             ...
             timeout = Duration::from_secs(20)
         );
         assert_eq!(signals.len(), 1);
+
+        let last_saf_received = dht_requester
+            .get_metadata::<DateTime<Utc>>(DhtMetadataKey::LastSafMessageReceived)
+            .await
+            .unwrap()
+            .unwrap();
+
+        assert_eq!(last_saf_received, msg2_time);
     }
 }
diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs
index c96d4311cb..5d06d85d56 100644
--- a/comms/dht/src/store_forward/service.rs
+++ b/comms/dht/src/store_forward/service.rs
@@ -43,7 +43,7 @@ use futures::{
     StreamExt,
 };
 use log::*;
-use std::{cmp, convert::TryFrom, sync::Arc, time::Duration};
+use std::{convert::TryFrom, sync::Arc, time::Duration};
 use tari_comms::{
     connectivity::{ConnectivityEvent, ConnectivityEventRx, ConnectivityRequester},
     peer_manager::{NodeId, PeerFeatures},
@@ -76,7 +76,7 @@ impl FetchStoredMessageQuery {
         }
     }
 
-    pub fn since(&mut self, since: DateTime<Utc>) -> &mut Self {
+    pub fn with_messages_since(&mut self, since: DateTime<Utc>) -> &mut Self {
         self.since = Some(since);
         self
     }
@@ -85,6 +85,10 @@ impl FetchStoredMessageQuery {
         self.response_type = response_type;
         self
     }
+
+    pub fn since(&self) -> Option<DateTime<Utc>> {
+        self.since
+    }
 }
 
 #[derive(Debug)]
@@ -92,6 +96,7 @@ pub enum StoreAndForwardRequest {
     FetchMessages(FetchStoredMessageQuery, oneshot::Sender<SafResult<Vec<StoredMessage>>>),
     InsertMessage(NewStoredMessage, oneshot::Sender<SafResult<bool>>),
     RemoveMessages(Vec<i32>),
+    RemoveMessagesOlderThan(DateTime<Utc>),
     SendStoreForwardRequestToPeer(Box<NodeId>),
     SendStoreForwardRequestNeighbours,
 }
@@ -132,6 +137,14 @@ impl StoreAndForwardRequester {
         Ok(())
     }
 
+    pub async fn remove_messages_older_than(&mut self, threshold: DateTime<Utc>) -> SafResult<()> {
+        self.sender
+            .send(StoreAndForwardRequest::RemoveMessagesOlderThan(threshold))
+            .await
+            .map_err(|_| StoreAndForwardError::RequesterChannelClosed)?;
+        Ok(())
+    }
+
     pub async fn request_saf_messages_from_peer(&mut self, node_id: NodeId) -> SafResult<()> {
         self.sender
             .send(StoreAndForwardRequest::SendStoreForwardRequestToPeer(Box::new(node_id)))
             .await
@@ -297,6 +310,12 @@ impl StoreAndForwardService {
                     );
                 }
             },
+            RemoveMessagesOlderThan(threshold) => {
+                match self.database.delete_messages_older_than(threshold.naive_utc()).await {
+                    Ok(_) => trace!(target: LOG_TARGET, "Removed messages older than {}", threshold),
+                    Err(err) => error!(target: LOG_TARGET, "RemoveMessage failed because '{:?}'", err),
+                }
+            },
         }
     }
 
@@ -382,9 +401,9 @@ impl StoreAndForwardService {
     async fn get_saf_request(&mut self) -> SafResult<StoredMessagesRequest> {
         let request = self
             .dht_requester
-            .get_metadata(DhtMetadataKey::OfflineTimestamp)
+            .get_metadata(DhtMetadataKey::LastSafMessageReceived)
             .await?
-            .map(|t| StoredMessagesRequest::since(cmp::min(t, since_utc(self.config.saf_minimum_request_period))))
+            .map(StoredMessagesRequest::since)
             .unwrap_or_else(StoredMessagesRequest::new);
 
         Ok(request)
@@ -490,7 +509,3 @@ fn since(period: Duration) -> NaiveDateTime {
         .checked_sub_signed(period)
         .expect("period overflowed when used with checked_sub_signed")
 }
-
-fn since_utc(period: Duration) -> DateTime<Utc> {
-    DateTime::<Utc>::from_utc(since(period), Utc)
-}
diff --git a/comms/dht/src/test_utils/store_and_forward_mock.rs b/comms/dht/src/test_utils/store_and_forward_mock.rs
index 6a623a5764..0dd464c43a 100644
--- a/comms/dht/src/test_utils/store_and_forward_mock.rs
+++ b/comms/dht/src/test_utils/store_and_forward_mock.rs
@@ -83,7 +83,9 @@ impl StoreAndForwardMockState {
     }
 
     pub async fn take_calls(&self) -> Vec<String> {
-        self.calls.write().await.drain(..).collect()
+        let calls = self.calls.write().await.drain(..).collect();
+        self.call_count.store(0, Ordering::SeqCst);
+        calls
     }
 }
 
@@ -115,9 +117,16 @@ impl StoreAndForwardMock {
             trace!(target: LOG_TARGET, "StoreAndForwardMock received request {:?}", req);
             self.state.add_call(&req).await;
             match req {
-                FetchMessages(_, reply_tx) => {
+                FetchMessages(request, reply_tx) => {
+                    let since = request.since().unwrap();
+
                     let msgs = self.state.stored_messages.read().await;
-                    let _ = reply_tx.send(Ok(msgs.clone()));
+
+                    let _ = reply_tx.send(Ok(msgs
+                        .clone()
+                        .drain(..)
+                        .filter(|m| m.stored_at >= since.naive_utc())
+                        .collect()));
                 },
                 InsertMessage(msg, reply_tx) => {
                     self.state.stored_messages.write().await.push(StoredMessage {
@@ -143,6 +152,13 @@ impl StoreAndForwardMock {
                 },
                 SendStoreForwardRequestToPeer(_) => {},
                 SendStoreForwardRequestNeighbours => {},
+                RemoveMessagesOlderThan(threshold) => {
+                    self.state
+                        .stored_messages
+                        .write()
+                        .await
+                        .retain(|msg| msg.stored_at >= threshold.naive_utc());
+                },
             }
         }
    }
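
Note for reviewers: a minimal standalone sketch of the DirectOrClosestNodes
selection order described above. `is_connected` and `closest_peers` are
hypothetical stand-ins for the ConnectivityRequester calls made by DhtActor,
and node IDs are reduced to strings; this illustrates the short-circuit, it is
not the crate's API.

    // Model of the selection order: prefer a direct active connection,
    // otherwise fall back to the closest connected nodes.
    fn direct_or_closest_nodes(
        dest: &str,
        is_connected: impl Fn(&str) -> bool,
        closest_peers: impl FnOnce() -> Vec<String>,
    ) -> Vec<String> {
        if is_connected(dest) {
            // Direct active connection: deliver straight to the destination
            // instead of propagating through the neighbourhood.
            return vec![dest.to_string()];
        }
        closest_peers()
    }

    fn main() {
        let connected = ["wallet-1"];
        let select = |dest: &str| {
            direct_or_closest_nodes(
                dest,
                |d| connected.contains(&d),
                || vec!["base-a".to_string(), "base-b".to_string()],
            )
        };
        // A connected wallet receives the SAF message directly...
        assert_eq!(select("wallet-1"), vec!["wallet-1".to_string()]);
        // ...an unconnected destination falls back to neighbourhood propagation.
        assert_eq!(select("wallet-2"), vec!["base-a".to_string(), "base-b".to_string()]);
    }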
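
Likewise, a sketch of the ack-by-timestamp deletion. `fetch_and_prune` is a
hypothetical helper over plain chrono types, not the StoreAndForwardDatabase
API; the point is the invariant that deletion always runs one request behind
the client's reported watermark, so nothing is removed until the client has
implicitly acknowledged it.

    use chrono::{Duration, NaiveDateTime, Utc};

    #[derive(Debug, Clone)]
    struct StoredMessage {
        body: Vec<u8>,
        stored_at: NaiveDateTime,
    }

    // Returns the messages newer than the client's `since` watermark and
    // prunes everything from before it; the watermark acts as an implicit
    // Ack for the older messages.
    fn fetch_and_prune(store: &mut Vec<StoredMessage>, since: NaiveDateTime) -> Vec<StoredMessage> {
        let response: Vec<StoredMessage> = store.iter().filter(|m| m.stored_at > since).cloned().collect();
        // Messages sent in this round stay in the store; they are only
        // deleted on the next request, once the client reports a newer watermark.
        store.retain(|m| m.stored_at > since);
        response
    }

    fn main() {
        let now = Utc::now().naive_utc();
        let mut store = vec![
            StoredMessage { body: b"one".to_vec(), stored_at: now - Duration::seconds(120) },
            StoredMessage { body: b"two".to_vec(), stored_at: now - Duration::seconds(30) },
        ];
        // The client reports it has everything up to 60 seconds ago.
        let sent = fetch_and_prune(&mut store, now - Duration::seconds(60));
        assert_eq!(sent.len(), 1);
        assert_eq!(sent[0].body, b"two");
        assert_eq!(store.len(), 1); // "one" pruned, "two" kept until the next Ack
    }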