prunes repair QUIC connections
The commit implements lazy eviction for repair QUIC connections.
The cache is allowed to grow to 2x capacity, at which point at least
half of the entries, those with the lowest stake, are evicted,
resulting in amortized O(1) performance.
behzadnouri committed Oct 19, 2023
1 parent f13c78b commit 150934d
Showing 2 changed files with 129 additions and 15 deletions.
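
For context, here is a minimal, self-contained sketch of the lazy-eviction pattern the commit message describes, with illustrative names and a toy capacity rather than the diff's actual code. The map is allowed to grow to 2x capacity; a single O(n) pass then keeps the highest-stake half via select_nth_unstable_by_key, so the cost per insertion is amortized O(1).

use std::{cmp::Reverse, collections::HashMap};

const CAPACITY: usize = 4; // toy value; the diff uses 3072

// Insert an entry, pruning lazily once the map reaches 2x capacity.
fn insert_lazy(
    cache: &mut HashMap<u64, &'static str>,
    stakes: &HashMap<u64, u64>,
    key: u64,
    value: &'static str,
) {
    cache.insert(key, value);
    if cache.len() < CAPACITY.saturating_mul(2) {
        return;
    }
    let mut entries: Vec<_> = cache
        .drain()
        .map(|entry @ (node, _)| (stakes.get(&node).copied().unwrap_or_default(), entry))
        .collect();
    // Partition in average O(n) so the CAPACITY highest-stake entries
    // land at the front; no full sort is needed.
    entries.select_nth_unstable_by_key(CAPACITY, |&(stake, _)| Reverse(stake));
    cache.extend(entries.into_iter().take(CAPACITY).map(|(_, entry)| entry));
}

fn main() {
    let stakes: HashMap<u64, u64> = (0..8u64).map(|n| (n, n * 100)).collect();
    let mut cache = HashMap::new();
    for n in 0..8u64 {
        insert_lazy(&mut cache, &stakes, n, "conn");
    }
    // Only the CAPACITY highest-stake nodes (keys 4..8) survive the prune.
    assert_eq!(cache.len(), CAPACITY);
    assert!((4..8u64).all(|n| cache.contains_key(&n)));
}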
143 changes: 128 additions & 15 deletions core/src/repair/quic_endpoint.rs
@@ -13,15 +13,20 @@ use {
rustls::{Certificate, PrivateKey},
serde_bytes::ByteBuf,
solana_quic_client::nonblocking::quic_client::SkipServerVerification,
solana_runtime::bank_forks::BankForks,
solana_sdk::{packet::PACKET_DATA_SIZE, pubkey::Pubkey, signature::Keypair},
solana_streamer::{
quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate,
},
std::{
cmp::Reverse,
collections::{hash_map::Entry, HashMap},
io::{Cursor, Error as IoError},
net::{IpAddr, SocketAddr, UdpSocket},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
},
time::Duration,
},
thiserror::Error,
@@ -40,18 +45,20 @@ const CONNECT_SERVER_NAME: &str = "solana-repair";

const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const CONNECTION_CACHE_CAPACITY: usize = 4096;
const CONNECTION_CACHE_CAPACITY: usize = 3072;
const MAX_CONCURRENT_BIDI_STREAMS: VarInt = VarInt::from_u32(512);

const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1);
const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2);
const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3);
const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4);
const CONNECTION_CLOSE_ERROR_CODE_PRUNED: VarInt = VarInt::from_u32(5);

const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN";
const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED";
const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
const CONNECTION_CLOSE_REASON_PRUNED: &[u8] = b"PRUNED";

pub(crate) type AsyncTryJoinHandle = TryJoin<JoinHandle<()>, JoinHandle<()>>;

@@ -108,6 +115,7 @@ pub(crate) fn new_quic_endpoint(
socket: UdpSocket,
address: IpAddr,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
) -> Result<(Endpoint, AsyncSender<LocalRequest>, AsyncTryJoinHandle), Error> {
let (cert, key) = new_self_signed_tls_certificate(keypair, address)?;
let server_config = new_server_config(cert.clone(), key.clone())?;
@@ -124,19 +132,24 @@
)?
};
endpoint.set_default_client_config(client_config);
let prune_cache_pending = Arc::<AtomicBool>::default();
let cache = Arc::<Mutex<HashMap<Pubkey, Connection>>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
let router = Arc::<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>::default();
let server_task = runtime.spawn(run_server(
endpoint.clone(),
remote_request_sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
let client_task = runtime.spawn(run_client(
endpoint.clone(),
client_receiver,
remote_request_sender,
bank_forks,
prune_cache_pending,
router,
cache,
));
@@ -189,6 +202,8 @@ fn new_transport_config() -> TransportConfig {
async fn run_server(
endpoint: Endpoint,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
@@ -197,6 +212,8 @@ async fn run_server(
endpoint.clone(),
connecting,
remote_request_sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
@@ -207,6 +224,8 @@ async fn run_client(
endpoint: Endpoint,
mut receiver: AsyncReceiver<LocalRequest>,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
@@ -230,6 +249,8 @@
remote_address,
remote_request_sender.clone(),
receiver,
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
@@ -263,11 +284,21 @@ async fn handle_connecting_error(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) =
handle_connecting(endpoint, connecting, remote_request_sender, router, cache).await
if let Err(err) = handle_connecting(
endpoint,
connecting,
remote_request_sender,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await
{
error!("handle_connecting: {err:?}");
}
@@ -277,6 +308,8 @@ async fn handle_connecting(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
@@ -295,24 +328,37 @@
connection,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await;
Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn handle_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
cache_connection(remote_pubkey, connection.clone(), &cache).await;
cache_connection(
remote_pubkey,
connection.clone(),
bank_forks,
prune_cache_pending,
router.clone(),
cache.clone(),
)
.await;
let send_requests_task = tokio::task::spawn(send_requests_task(
endpoint.clone(),
connection.clone(),
@@ -492,6 +538,8 @@ async fn make_connection_task(
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
@@ -500,6 +548,8 @@
remote_address,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
@@ -514,6 +564,8 @@
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
@@ -527,6 +579,8 @@
connection,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
@@ -550,25 +604,32 @@ fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
async fn cache_connection(
remote_pubkey: Pubkey,
connection: Connection,
cache: &Mutex<HashMap<Pubkey, Connection>>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
let old = {
let (old, should_prune_cache) = {
let mut cache = cache.lock().await;
if cache.len() >= CONNECTION_CACHE_CAPACITY {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_DROPPED,
CONNECTION_CLOSE_REASON_DROPPED,
);
return;
}
cache.insert(remote_pubkey, connection)
(
cache.insert(remote_pubkey, connection),
cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2),
)
};
if let Some(old) = old {
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
}
if should_prune_cache && !prune_cache_pending.swap(true, Ordering::Relaxed) {
tokio::task::spawn(prune_connection_cache(
bank_forks,
prune_cache_pending,
router,
cache,
));
}
}

async fn drop_connection(
@@ -587,6 +648,49 @@ async fn drop_connection(
}
}

async fn prune_connection_cache(
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
debug_assert!(prune_cache_pending.load(Ordering::Relaxed));
let staked_nodes = {
let root_bank = bank_forks.read().unwrap().root_bank();
root_bank.staked_nodes()
};
{
let mut cache = cache.lock().await;
if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
return;
}
let mut connections: Vec<_> = cache
.drain()
.filter(|(_, connection)| connection.close_reason().is_none())
.map(|entry @ (pubkey, _)| {
let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
(stake, entry)
})
.collect();
connections
.select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake));
for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_PRUNED,
CONNECTION_CLOSE_REASON_PRUNED,
);
}
cache.extend(
connections
.into_iter()
.take(CONNECTION_CACHE_CAPACITY)
.map(|(_, entry)| entry),
);
prune_cache_pending.store(false, Ordering::Relaxed);
}
router.write().await.retain(|_, sender| !sender.is_closed());
}

impl<T> From<crossbeam_channel::SendError<T>> for Error {
fn from(_: crossbeam_channel::SendError<T>) -> Self {
Error::ChannelSendError
@@ -598,6 +702,8 @@ mod tests {
use {
super::*,
itertools::{izip, multiunzip},
solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo},
solana_runtime::bank::Bank,
solana_sdk::signature::Signer,
std::{iter::repeat_with, net::Ipv4Addr, time::Duration},
};
@@ -625,6 +731,12 @@
repeat_with(crossbeam_channel::unbounded::<RemoteRequest>)
.take(NUM_ENDPOINTS)
.unzip();
let bank_forks = {
let GenesisConfigInfo { genesis_config, .. } =
create_genesis_config(/*mint_lamports:*/ 100_000);
let bank = Bank::new_for_tests(&genesis_config);
Arc::new(RwLock::new(BankForks::new(bank)))
};
let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
keypairs
.iter()
@@ -637,6 +749,7 @@
socket,
IpAddr::V4(Ipv4Addr::LOCALHOST),
remote_request_sender,
bank_forks.clone(),
)
.unwrap()
}),
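One detail in the cache_connection hunk above is worth calling out: prune_cache_pending is an AtomicBool acting as a single-flight guard, so at most one prune task is in flight no matter how many concurrent insertions push the cache past 2x capacity. A minimal sketch of that guard pattern follows, simplified to a plain thread instead of the diff's tokio task:

use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

fn maybe_spawn_prune(prune_pending: &Arc<AtomicBool>) {
    // swap returns the previous value, so only the caller that flips
    // false -> true spawns a prune; concurrent callers observe true
    // and skip.
    if !prune_pending.swap(true, Ordering::Relaxed) {
        let flag = Arc::clone(prune_pending);
        std::thread::spawn(move || {
            // ... evict the lowest-stake entries here ...
            // Reset the flag so a later overflow can trigger a new prune.
            // In the diff the store happens while the cache mutex is still
            // held, keeping the flag consistent with the cache size.
            flag.store(false, Ordering::Relaxed);
        });
    }
}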
1 change: 1 addition & 0 deletions core/src/validator.rs
@@ -1208,6 +1208,7 @@ impl Validator {
.expect("Operator must spin up node with valid QUIC serve-repair address")
.ip(),
repair_quic_endpoint_sender,
bank_forks.clone(),
)
.unwrap();

