diff --git a/base_layer/core/src/base_node/sync/rpc/service.rs b/base_layer/core/src/base_node/sync/rpc/service.rs index e9df7073a2..d10ce6a0ea 100644 --- a/base_layer/core/src/base_node/sync/rpc/service.rs +++ b/base_layer/core/src/base_node/sync/rpc/service.rs @@ -36,30 +36,57 @@ use crate::{ }, }; use log::*; -use std::cmp; +use std::{ + cmp, + sync::{Arc, Weak}, +}; use tari_comms::{ + peer_manager::NodeId, protocol::rpc::{Request, Response, RpcStatus, Streaming}, utils, }; use tari_crypto::tari_utilities::hex::Hex; -use tokio::{sync::mpsc, task}; +use tokio::{ + sync::{mpsc, RwLock}, + task, +}; use tracing::{instrument, span, Instrument, Level}; const LOG_TARGET: &str = "c::base_node::sync_rpc"; pub struct BaseNodeSyncRpcService { db: AsyncBlockchainDb, + active_sessions: RwLock>>, } impl BaseNodeSyncRpcService { pub fn new(db: AsyncBlockchainDb) -> Self { - Self { db } + Self { + db, + active_sessions: RwLock::new(Vec::new()), + } } #[inline] fn db(&self) -> AsyncBlockchainDb { self.db.clone() } + + pub async fn try_add_exclusive_session(&self, peer: NodeId) -> Result, RpcStatus> { + let mut lock = self.active_sessions.write().await; + *lock = lock.drain(..).filter(|l| l.strong_count() > 0).collect(); + debug!(target: LOG_TARGET, "Number of active sync sessions: {}", lock.len()); + + if lock.iter().any(|p| p.upgrade().filter(|p| **p == peer).is_some()) { + return Err(RpcStatus::forbidden( + "Existing sync session found for this client. Only a single session is permitted", + )); + } + + let token = Arc::new(peer); + lock.push(Arc::downgrade(&token)); + Ok(token) + } } #[tari_comms::async_trait] @@ -116,20 +143,26 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ "Initiating block sync with peer `{}` from height {} to {}", peer_node_id, start, end, ); + let session_token = self.try_add_exclusive_session(peer_node_id).await?; // Number of blocks to load and push to the stream before loading the next batch - const BATCH_SIZE: usize = 4; + const BATCH_SIZE: usize = 2; let (tx, rx) = mpsc::channel(BATCH_SIZE); let span = span!(Level::TRACE, "sync_rpc::block_sync::inner_worker"); task::spawn( async move { + // Move token into this task + let session_token = session_token; let iter = NonOverlappingIntegerPairIter::new(start, end + 1, BATCH_SIZE); for (start, end) in iter { if tx.is_closed() { break; } - debug!(target: LOG_TARGET, "Sending blocks #{} - #{}", start, end); + debug!( + target: LOG_TARGET, + "Sending blocks #{} - #{} to '{}'", start, end, session_token + ); let blocks = db .fetch_blocks(start..=end) .await @@ -162,7 +195,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ debug!( target: LOG_TARGET, - "Block sync round complete for peer `{}`.", peer_node_id, + "Block sync round complete for peer `{}`.", session_token, ); } .instrument(span), @@ -208,10 +241,13 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ chunk_size ); + let session_token = self.try_add_exclusive_session(peer_node_id.clone()).await?; let (tx, rx) = mpsc::channel(chunk_size); let span = span!(Level::TRACE, "sync_rpc::sync_headers::inner_worker"); task::spawn( async move { + // Move token into this task + let session_token = session_token; let iter = NonOverlappingIntegerPairIter::new( start_header.height + 1, start_header.height.saturating_add(count).saturating_add(1), @@ -247,7 +283,7 @@ impl BaseNodeSyncService for BaseNodeSyncRpcServ debug!( target: LOG_TARGET, - "Header sync round complete for peer `{}`.", peer_node_id, + "Header sync round complete for peer `{}`.", session_token, ); } .instrument(span), diff --git a/comms/src/protocol/rpc/server/mod.rs b/comms/src/protocol/rpc/server/mod.rs index d02545ec24..11b11fcbe1 100644 --- a/comms/src/protocol/rpc/server/mod.rs +++ b/comms/src/protocol/rpc/server/mod.rs @@ -66,6 +66,7 @@ use crate::{ use futures::{stream, SinkExt, StreamExt}; use prost::Message; use std::{ + borrow::Cow, future::Future, sync::Arc, time::{Duration, Instant}, @@ -539,7 +540,11 @@ where Err(_) => { warn!( target: LOG_TARGET, - "RPC service was not able to complete within the deadline ({:.0?}). Request aborted.", deadline + "RPC service was not able to complete within the deadline ({:.0?}). Request aborted for peer '{}' \ + ({}).", + deadline, + self.node_id, + self.protocol_name() ); return Ok(()); }, @@ -550,7 +555,13 @@ where self.process_body(request_id, deadline, body).await?; }, Err(err) => { - debug!(target: LOG_TARGET, "Service returned an error: {}", err); + debug!( + target: LOG_TARGET, + "(peer: {}, protocol: {}) Service returned an error: {}", + self.node_id, + self.protocol_name(), + err + ); let resp = proto::rpc::RpcResponse { request_id, status: err.as_code(), @@ -565,6 +576,10 @@ where Ok(()) } + fn protocol_name(&self) -> Cow<'_, str> { + String::from_utf8_lossy(&self.protocol) + } + async fn process_body( &mut self, request_id: u32, diff --git a/comms/src/protocol/rpc/status.rs b/comms/src/protocol/rpc/status.rs index c0453f161d..9194d36479 100644 --- a/comms/src/protocol/rpc/status.rs +++ b/comms/src/protocol/rpc/status.rs @@ -90,6 +90,13 @@ impl RpcStatus { } } + pub fn forbidden(details: T) -> Self { + Self { + code: RpcStatusCode::Forbidden, + details: details.to_string(), + } + } + /// Returns a closure that logs the given error and returns a generic general error that does not leak any /// potentially sensitive error information. Use this function with map_err to catch "miscellaneous" errors. pub fn log_internal_error<'a, E: std::error::Error + 'a>(target: &'a str) -> impl Fn(E) -> Self + 'a { @@ -186,6 +193,8 @@ pub enum RpcStatusCode { NotFound = 7, /// RPC protocol error ProtocolError = 8, + /// RPC forbidden error + Forbidden = 9, // The following status represents anything that is not recognised (i.e not one of the above codes). /// Unrecognised RPC status code InvalidRpcStatusCode, @@ -217,6 +226,8 @@ impl From for RpcStatusCode { 5 => MalformedResponse, 6 => General, 7 => NotFound, + 8 => ProtocolError, + 9 => Forbidden, _ => InvalidRpcStatusCode, } } @@ -238,6 +249,8 @@ mod test { assert_eq!(RpcStatusCode::from(Timeout as u32), Timeout); assert_eq!(RpcStatusCode::from(NotFound as u32), NotFound); assert_eq!(RpcStatusCode::from(InvalidRpcStatusCode as u32), InvalidRpcStatusCode); + assert_eq!(RpcStatusCode::from(ProtocolError as u32), ProtocolError); + assert_eq!(RpcStatusCode::from(Forbidden as u32), Forbidden); assert_eq!(RpcStatusCode::from(123), InvalidRpcStatusCode); } }