Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prunes repair QUIC connections #33775

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 129 additions & 15 deletions core/src/repair/quic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@ use {
rustls::{Certificate, PrivateKey},
serde_bytes::ByteBuf,
solana_quic_client::nonblocking::quic_client::SkipServerVerification,
solana_runtime::bank_forks::BankForks,
solana_sdk::{packet::PACKET_DATA_SIZE, pubkey::Pubkey, signature::Keypair},
solana_streamer::{
quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate,
},
std::{
cmp::Reverse,
collections::{hash_map::Entry, HashMap},
io::{Cursor, Error as IoError},
net::{IpAddr, SocketAddr, UdpSocket},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
},
time::Duration,
},
thiserror::Error,
Expand All @@ -40,18 +45,20 @@ const CONNECT_SERVER_NAME: &str = "solana-repair";

const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const CONNECTION_CACHE_CAPACITY: usize = 4096;
const CONNECTION_CACHE_CAPACITY: usize = 3072;
const MAX_CONCURRENT_BIDI_STREAMS: VarInt = VarInt::from_u32(512);

const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1);
const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2);
const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3);
const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4);
const CONNECTION_CLOSE_ERROR_CODE_PRUNED: VarInt = VarInt::from_u32(5);

const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN";
const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED";
const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
const CONNECTION_CLOSE_REASON_PRUNED: &[u8] = b"PRUNED";

pub(crate) type AsyncTryJoinHandle = TryJoin<JoinHandle<()>, JoinHandle<()>>;

Expand Down Expand Up @@ -108,6 +115,7 @@ pub(crate) fn new_quic_endpoint(
socket: UdpSocket,
address: IpAddr,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
) -> Result<(Endpoint, AsyncSender<LocalRequest>, AsyncTryJoinHandle), Error> {
let (cert, key) = new_self_signed_tls_certificate(keypair, address)?;
let server_config = new_server_config(cert.clone(), key.clone())?;
Expand All @@ -124,19 +132,24 @@ pub(crate) fn new_quic_endpoint(
)?
};
endpoint.set_default_client_config(client_config);
let prune_cache_pending = Arc::<AtomicBool>::default();
let cache = Arc::<Mutex<HashMap<Pubkey, Connection>>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
let router = Arc::<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>::default();
let server_task = runtime.spawn(run_server(
endpoint.clone(),
remote_request_sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
let client_task = runtime.spawn(run_client(
endpoint.clone(),
client_receiver,
remote_request_sender,
bank_forks,
prune_cache_pending,
router,
cache,
));
Expand Down Expand Up @@ -189,6 +202,8 @@ fn new_transport_config() -> TransportConfig {
async fn run_server(
endpoint: Endpoint,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand All @@ -197,6 +212,8 @@ async fn run_server(
endpoint.clone(),
connecting,
remote_request_sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
Expand All @@ -207,6 +224,8 @@ async fn run_client(
endpoint: Endpoint,
mut receiver: AsyncReceiver<LocalRequest>,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand All @@ -230,6 +249,8 @@ async fn run_client(
remote_address,
remote_request_sender.clone(),
receiver,
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
Expand Down Expand Up @@ -263,11 +284,21 @@ async fn handle_connecting_error(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) =
handle_connecting(endpoint, connecting, remote_request_sender, router, cache).await
if let Err(err) = handle_connecting(
endpoint,
connecting,
remote_request_sender,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await
{
error!("handle_connecting: {err:?}");
}
Expand All @@ -277,6 +308,8 @@ async fn handle_connecting(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
Expand All @@ -295,24 +328,37 @@ async fn handle_connecting(
connection,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await;
Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn handle_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
cache_connection(remote_pubkey, connection.clone(), &cache).await;
cache_connection(
remote_pubkey,
connection.clone(),
bank_forks,
prune_cache_pending,
router.clone(),
cache.clone(),
)
.await;
let send_requests_task = tokio::task::spawn(send_requests_task(
endpoint.clone(),
connection.clone(),
Expand Down Expand Up @@ -492,6 +538,8 @@ async fn make_connection_task(
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand All @@ -500,6 +548,8 @@ async fn make_connection_task(
remote_address,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
Expand All @@ -514,6 +564,8 @@ async fn make_connection(
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
Expand All @@ -527,6 +579,8 @@ async fn make_connection(
connection,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
Expand All @@ -550,25 +604,32 @@ fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
async fn cache_connection(
remote_pubkey: Pubkey,
connection: Connection,
cache: &Mutex<HashMap<Pubkey, Connection>>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
let old = {
let (old, should_prune_cache) = {
let mut cache = cache.lock().await;
if cache.len() >= CONNECTION_CACHE_CAPACITY {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_DROPPED,
CONNECTION_CLOSE_REASON_DROPPED,
);
return;
}
cache.insert(remote_pubkey, connection)
(
cache.insert(remote_pubkey, connection),
cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2),
)
};
if let Some(old) = old {
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
}
if should_prune_cache && !prune_cache_pending.swap(true, Ordering::Relaxed) {
tokio::task::spawn(prune_connection_cache(
bank_forks,
prune_cache_pending,
router,
cache,
));
}
}

async fn drop_connection(
Expand All @@ -587,6 +648,50 @@ async fn drop_connection(
}
}

async fn prune_connection_cache(
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
debug_assert!(prune_cache_pending.load(Ordering::Relaxed));
let staked_nodes = {
let root_bank = bank_forks.read().unwrap().root_bank();
root_bank.staked_nodes()
};
{
let mut cache = cache.lock().await;
if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
prune_cache_pending.store(false, Ordering::Relaxed);
return;
}
let mut connections: Vec<_> = cache
.drain()
.filter(|(_, connection)| connection.close_reason().is_none())
.map(|entry @ (pubkey, _)| {
let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
(stake, entry)
})
.collect();
connections
.select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake));
for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_PRUNED,
CONNECTION_CLOSE_REASON_PRUNED,
);
}
cache.extend(
connections
.into_iter()
.take(CONNECTION_CACHE_CAPACITY)
.map(|(_, entry)| entry),
);
prune_cache_pending.store(false, Ordering::Relaxed);
}
router.write().await.retain(|_, sender| !sender.is_closed());
}

impl<T> From<crossbeam_channel::SendError<T>> for Error {
fn from(_: crossbeam_channel::SendError<T>) -> Self {
Error::ChannelSendError
Expand All @@ -598,6 +703,8 @@ mod tests {
use {
super::*,
itertools::{izip, multiunzip},
solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo},
solana_runtime::bank::Bank,
solana_sdk::signature::Signer,
std::{iter::repeat_with, net::Ipv4Addr, time::Duration},
};
Expand Down Expand Up @@ -625,6 +732,12 @@ mod tests {
repeat_with(crossbeam_channel::unbounded::<RemoteRequest>)
.take(NUM_ENDPOINTS)
.unzip();
let bank_forks = {
let GenesisConfigInfo { genesis_config, .. } =
create_genesis_config(/*mint_lamports:*/ 100_000);
let bank = Bank::new_for_tests(&genesis_config);
Arc::new(RwLock::new(BankForks::new(bank)))
};
let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
keypairs
.iter()
Expand All @@ -637,6 +750,7 @@ mod tests {
socket,
IpAddr::V4(Ipv4Addr::LOCALHOST),
remote_request_sender,
bank_forks.clone(),
)
.unwrap()
}),
Expand Down
1 change: 1 addition & 0 deletions core/src/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,7 @@ impl Validator {
.expect("Operator must spin up node with valid QUIC serve-repair address")
.ip(),
repair_quic_endpoint_sender,
bank_forks.clone(),
)
.unwrap();

Expand Down