Skip to content

Commit

Permalink
fix: only allow one session per peer for block/header sync
Browse files Browse the repository at this point in the history
Adds check for existing block sync sessions per peer.
  • Loading branch information
sdbondi committed Oct 1, 2021
1 parent 0c6dd46 commit 505a223
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 9 deletions.
50 changes: 43 additions & 7 deletions base_layer/core/src/base_node/sync/rpc/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,30 +36,57 @@ use crate::{
},
};
use log::*;
use std::cmp;
use std::{
cmp,
sync::{Arc, Weak},
};
use tari_comms::{
peer_manager::NodeId,
protocol::rpc::{Request, Response, RpcStatus, Streaming},
utils,
};
use tari_crypto::tari_utilities::hex::Hex;
use tokio::{sync::mpsc, task};
use tokio::{
sync::{mpsc, RwLock},
task,
};
use tracing::{instrument, span, Instrument, Level};

const LOG_TARGET: &str = "c::base_node::sync_rpc";

pub struct BaseNodeSyncRpcService<B> {
db: AsyncBlockchainDb<B>,
active_sessions: RwLock<Vec<Weak<NodeId>>>,
}

impl<B: BlockchainBackend + 'static> BaseNodeSyncRpcService<B> {
pub fn new(db: AsyncBlockchainDb<B>) -> Self {
Self { db }
Self {
db,
active_sessions: RwLock::new(Vec::new()),
}
}

#[inline]
fn db(&self) -> AsyncBlockchainDb<B> {
self.db.clone()
}

pub async fn try_add_exclusive_session(&self, peer: NodeId) -> Result<Arc<NodeId>, RpcStatus> {
let mut lock = self.active_sessions.write().await;
*lock = lock.drain(..).filter(|l| l.strong_count() > 0).collect();
debug!(target: LOG_TARGET, "Number of active sync sessions: {}", lock.len());

if lock.iter().any(|p| p.upgrade().filter(|p| **p == peer).is_some()) {
return Err(RpcStatus::forbidden(
"Existing sync session found for this client. Only a single session is permitted",
));
}

let token = Arc::new(peer);
lock.push(Arc::downgrade(&token));
Ok(token)
}
}

#[tari_comms::async_trait]
Expand Down Expand Up @@ -116,20 +143,26 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ
"Initiating block sync with peer `{}` from height {} to {}", peer_node_id, start, end,
);

let session_token = self.try_add_exclusive_session(peer_node_id).await?;
// Number of blocks to load and push to the stream before loading the next batch
const BATCH_SIZE: usize = 4;
const BATCH_SIZE: usize = 2;
let (tx, rx) = mpsc::channel(BATCH_SIZE);

let span = span!(Level::TRACE, "sync_rpc::block_sync::inner_worker");
task::spawn(
async move {
// Move token into this task
let session_token = session_token;
let iter = NonOverlappingIntegerPairIter::new(start, end + 1, BATCH_SIZE);
for (start, end) in iter {
if tx.is_closed() {
break;
}

debug!(target: LOG_TARGET, "Sending blocks #{} - #{}", start, end);
debug!(
target: LOG_TARGET,
"Sending blocks #{} - #{} to '{}'", start, end, session_token
);
let blocks = db
.fetch_blocks(start..=end)
.await
Expand Down Expand Up @@ -162,7 +195,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ

debug!(
target: LOG_TARGET,
"Block sync round complete for peer `{}`.", peer_node_id,
"Block sync round complete for peer `{}`.", session_token,
);
}
.instrument(span),
Expand Down Expand Up @@ -208,10 +241,13 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ
chunk_size
);

let session_token = self.try_add_exclusive_session(peer_node_id.clone()).await?;
let (tx, rx) = mpsc::channel(chunk_size);
let span = span!(Level::TRACE, "sync_rpc::sync_headers::inner_worker");
task::spawn(
async move {
// Move token into this task
let session_token = session_token;
let iter = NonOverlappingIntegerPairIter::new(
start_header.height + 1,
start_header.height.saturating_add(count).saturating_add(1),
Expand Down Expand Up @@ -247,7 +283,7 @@ impl<B: BlockchainBackend + 'static> BaseNodeSyncService for BaseNodeSyncRpcServ

debug!(
target: LOG_TARGET,
"Header sync round complete for peer `{}`.", peer_node_id,
"Header sync round complete for peer `{}`.", session_token,
);
}
.instrument(span),
Expand Down
19 changes: 17 additions & 2 deletions comms/src/protocol/rpc/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ use crate::{
use futures::{stream, SinkExt, StreamExt};
use prost::Message;
use std::{
borrow::Cow,
future::Future,
sync::Arc,
time::{Duration, Instant},
Expand Down Expand Up @@ -539,7 +540,11 @@ where
Err(_) => {
warn!(
target: LOG_TARGET,
"RPC service was not able to complete within the deadline ({:.0?}). Request aborted.", deadline
"RPC service was not able to complete within the deadline ({:.0?}). Request aborted for peer '{}' \
({}).",
deadline,
self.node_id,
self.protocol_name()
);
return Ok(());
},
Expand All @@ -550,7 +555,13 @@ where
self.process_body(request_id, deadline, body).await?;
},
Err(err) => {
debug!(target: LOG_TARGET, "Service returned an error: {}", err);
debug!(
target: LOG_TARGET,
"(peer: {}, protocol: {}) Service returned an error: {}",
self.node_id,
self.protocol_name(),
err
);
let resp = proto::rpc::RpcResponse {
request_id,
status: err.as_code(),
Expand All @@ -565,6 +576,10 @@ where
Ok(())
}

fn protocol_name(&self) -> Cow<'_, str> {
String::from_utf8_lossy(&self.protocol)
}

async fn process_body(
&mut self,
request_id: u32,
Expand Down
13 changes: 13 additions & 0 deletions comms/src/protocol/rpc/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ impl RpcStatus {
}
}

pub fn forbidden<T: ToString>(details: T) -> Self {
Self {
code: RpcStatusCode::Forbidden,
details: details.to_string(),
}
}

/// Returns a closure that logs the given error and returns a generic general error that does not leak any
/// potentially sensitive error information. Use this function with map_err to catch "miscellaneous" errors.
pub fn log_internal_error<'a, E: std::error::Error + 'a>(target: &'a str) -> impl Fn(E) -> Self + 'a {
Expand Down Expand Up @@ -186,6 +193,8 @@ pub enum RpcStatusCode {
NotFound = 7,
/// RPC protocol error
ProtocolError = 8,
/// RPC forbidden error
Forbidden = 9,
// The following status represents anything that is not recognised (i.e not one of the above codes).
/// Unrecognised RPC status code
InvalidRpcStatusCode,
Expand Down Expand Up @@ -217,6 +226,8 @@ impl From<u32> for RpcStatusCode {
5 => MalformedResponse,
6 => General,
7 => NotFound,
8 => ProtocolError,
9 => Forbidden,
_ => InvalidRpcStatusCode,
}
}
Expand All @@ -238,6 +249,8 @@ mod test {
assert_eq!(RpcStatusCode::from(Timeout as u32), Timeout);
assert_eq!(RpcStatusCode::from(NotFound as u32), NotFound);
assert_eq!(RpcStatusCode::from(InvalidRpcStatusCode as u32), InvalidRpcStatusCode);
assert_eq!(RpcStatusCode::from(ProtocolError as u32), ProtocolError);
assert_eq!(RpcStatusCode::from(Forbidden as u32), Forbidden);
assert_eq!(RpcStatusCode::from(123), InvalidRpcStatusCode);
}
}

0 comments on commit 505a223

Please sign in to comment.