diff --git a/core/src/repair/quic_endpoint.rs b/core/src/repair/quic_endpoint.rs index f7b445011c937a..bf3a1802144a42 100644 --- a/core/src/repair/quic_endpoint.rs +++ b/core/src/repair/quic_endpoint.rs @@ -21,16 +21,15 @@ use { collections::{hash_map::Entry, HashMap}, io::{Cursor, Error as IoError}, net::{IpAddr, SocketAddr, UdpSocket}, - ops::Deref, sync::Arc, time::Duration, }, thiserror::Error, tokio::{ sync::{ - mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender}, + mpsc::{error::TrySendError, Receiver as AsyncReceiver, Sender as AsyncSender}, oneshot::Sender as OneShotSender, - RwLock, + Mutex, RwLock as AsyncRwLock, }, task::JoinHandle, }, @@ -39,7 +38,8 @@ use { const ALPN_REPAIR_PROTOCOL_ID: &[u8] = b"solana-repair"; const CONNECT_SERVER_NAME: &str = "solana-repair"; -const CLIENT_CHANNEL_CAPACITY: usize = 1 << 14; +const CLIENT_CHANNEL_BUFFER: usize = 1 << 14; +const ROUTER_CHANNEL_BUFFER: usize = 64; const CONNECTION_CACHE_CAPACITY: usize = 4096; const MAX_CONCURRENT_BIDI_STREAMS: VarInt = VarInt::from_u32(512); @@ -54,7 +54,6 @@ const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY"; const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED"; pub(crate) type AsyncTryJoinHandle = TryJoin, JoinHandle<()>>; -type ConnectionCache = HashMap<(SocketAddr, Option), Arc>>>; // Outgoing local requests. pub struct LocalRequest { @@ -125,17 +124,20 @@ pub(crate) fn new_quic_endpoint( )? }; endpoint.set_default_client_config(client_config); - let cache = Arc::>::default(); - let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_CAPACITY); + let cache = Arc::>>::default(); + let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER); + let router = Arc::>>>::default(); let server_task = runtime.spawn(run_server( endpoint.clone(), remote_request_sender.clone(), + router.clone(), cache.clone(), )); let client_task = runtime.spawn(run_client( endpoint.clone(), client_receiver, remote_request_sender, + router, cache, )); let task = futures::future::try_join(server_task, client_task); @@ -187,13 +189,15 @@ fn new_transport_config() -> TransportConfig { async fn run_server( endpoint: Endpoint, remote_request_sender: Sender, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { while let Some(connecting) = endpoint.accept().await { tokio::task::spawn(handle_connecting_error( endpoint.clone(), connecting, remote_request_sender.clone(), + router.clone(), cache.clone(), )); } @@ -203,26 +207,68 @@ async fn run_client( endpoint: Endpoint, mut receiver: AsyncReceiver, remote_request_sender: Sender, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { while let Some(request) = receiver.recv().await { - tokio::task::spawn(send_request_task( + let Some(request) = try_route_request(request, &*router.read().await) else { + continue; + }; + let remote_address = request.remote_address; + let receiver = { + let mut router = router.write().await; + let Some(request) = try_route_request(request, &router) else { + continue; + }; + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + sender.try_send(request).unwrap(); + router.insert(remote_address, sender); + receiver + }; + tokio::task::spawn(make_connection_task( endpoint.clone(), - request, + remote_address, remote_request_sender.clone(), + receiver, + router.clone(), cache.clone(), )); } close_quic_endpoint(&endpoint); + // Drop sender channels to unblock threads waiting on the receiving end. + router.write().await.clear(); +} + +// Routes the local request to respective channel. Drops the request if the +// channel is full. Bounces the request back if the channel is closed or does +// not exist. +fn try_route_request( + request: LocalRequest, + router: &HashMap>, +) -> Option { + match router.get(&request.remote_address) { + None => Some(request), + Some(sender) => match sender.try_send(request) { + Ok(()) => None, + Err(TrySendError::Full(request)) => { + error!("TrySendError::Full {}", request.remote_address); + None + } + Err(TrySendError::Closed(request)) => Some(request), + }, + } } async fn handle_connecting_error( endpoint: Endpoint, connecting: Connecting, remote_request_sender: Sender, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { - if let Err(err) = handle_connecting(endpoint, connecting, remote_request_sender, cache).await { + if let Err(err) = + handle_connecting(endpoint, connecting, remote_request_sender, router, cache).await + { error!("handle_connecting: {err:?}"); } } @@ -231,52 +277,75 @@ async fn handle_connecting( endpoint: Endpoint, connecting: Connecting, remote_request_sender: Sender, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) -> Result<(), Error> { let connection = connecting.await?; let remote_address = connection.remote_address(); let remote_pubkey = get_remote_pubkey(&connection)?; - handle_connection_error( + let receiver = { + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + router.write().await.insert(remote_address, sender); + receiver + }; + handle_connection( endpoint, remote_address, remote_pubkey, connection, remote_request_sender, + receiver, + router, cache, ) .await; Ok(()) } -async fn handle_connection_error( +async fn handle_connection( endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, remote_request_sender: Sender, - cache: Arc>, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, ) { - cache_connection(remote_address, remote_pubkey, connection.clone(), &cache).await; - if let Err(err) = handle_connection( - &endpoint, + cache_connection(remote_pubkey, connection.clone(), &cache).await; + let send_requests_task = tokio::task::spawn(send_requests_task( + endpoint.clone(), + connection.clone(), + receiver, + )); + let recv_requests_task = tokio::task::spawn(recv_requests_task( + endpoint, remote_address, remote_pubkey, - &connection, - &remote_request_sender, - ) - .await - { - drop_connection(remote_address, remote_pubkey, &connection, &cache).await; - error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"); + connection.clone(), + remote_request_sender, + )); + match futures::future::try_join(send_requests_task, recv_requests_task).await { + Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"), + Ok(((), Err(ref err))) => { + error!("recv_requests_task: {remote_pubkey}, {remote_address}, {err:?}"); + } + Ok(((), Ok(()))) => (), + } + drop_connection(remote_pubkey, &connection, &cache).await; + if let Entry::Occupied(entry) = router.write().await.entry(remote_address) { + if entry.get().is_closed() { + entry.remove(); + } } } -async fn handle_connection( - endpoint: &Endpoint, +async fn recv_requests_task( + endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, - connection: &Connection, - remote_request_sender: &Sender, + connection: Connection, + remote_request_sender: Sender, ) -> Result<(), Error> { loop { let (send_stream, recv_stream) = connection.accept_bi().await?; @@ -352,32 +421,39 @@ async fn handle_streams( send_stream.finish().await.map_err(Error::from) } -async fn send_request_task( +async fn send_requests_task( endpoint: Endpoint, - request: LocalRequest, - remote_request_sender: Sender, - cache: Arc>, + connection: Connection, + mut receiver: AsyncReceiver, ) { - if let Err(err) = send_request(&endpoint, request, remote_request_sender, cache).await { - error!("send_request_task: {err:?}"); + while let Some(request) = receiver.recv().await { + tokio::task::spawn(send_request_task( + endpoint.clone(), + connection.clone(), + request, + )); + } +} + +async fn send_request_task(endpoint: Endpoint, connection: Connection, request: LocalRequest) { + if let Err(err) = send_request(endpoint, connection, request).await { + error!("send_request: {err:?}") } } async fn send_request( - endpoint: &Endpoint, + endpoint: Endpoint, + connection: Connection, LocalRequest { - remote_address, + remote_address: _, bytes, num_expected_responses, response_sender, }: LocalRequest, - remote_request_sender: Sender, - cache: Arc>, ) -> Result<(), Error> { // Assert that send won't block. debug_assert_eq!(response_sender.capacity(), None); const READ_TIMEOUT_DURATION: Duration = Duration::from_secs(10); - let connection = get_connection(endpoint, remote_address, remote_request_sender, cache).await?; let (mut send_stream, mut recv_stream) = connection.open_bi().await?; send_stream.write_all(&bytes).await?; send_stream.finish().await?; @@ -405,50 +481,57 @@ async fn send_request( response_sender .send((remote_address, chunk)) .map_err(|err| { - close_quic_endpoint(endpoint); + close_quic_endpoint(&endpoint); Error::from(err) }) }) } -async fn get_connection( - endpoint: &Endpoint, +async fn make_connection_task( + endpoint: Endpoint, remote_address: SocketAddr, remote_request_sender: Sender, - cache: Arc>, -) -> Result { - let entry = get_cache_entry(remote_address, &cache).await; + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, +) { + if let Err(err) = make_connection( + endpoint, + remote_address, + remote_request_sender, + receiver, + router, + cache, + ) + .await { - let connection: Option = entry.read().await.clone(); - if let Some(connection) = connection { - if connection.close_reason().is_none() { - return Ok(connection); - } - } + error!("make_connection: {remote_address}, {err:?}"); } - let connection = { - // Need to write lock here so that only one task initiates - // a new connection to the same remote_address. - let mut entry = entry.write().await; - if let Some(connection) = entry.deref() { - if connection.close_reason().is_none() { - return Ok(connection.clone()); - } - } - let connection = endpoint - .connect(remote_address, CONNECT_SERVER_NAME)? - .await?; - entry.insert(connection).clone() - }; - tokio::task::spawn(handle_connection_error( - endpoint.clone(), +} + +async fn make_connection( + endpoint: Endpoint, + remote_address: SocketAddr, + remote_request_sender: Sender, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, +) -> Result<(), Error> { + let connection = endpoint + .connect(remote_address, CONNECT_SERVER_NAME)? + .await?; + handle_connection( + endpoint, connection.remote_address(), get_remote_pubkey(&connection)?, - connection.clone(), + connection, remote_request_sender, + receiver, + router, cache, - )); - Ok(connection) + ) + .await; + Ok(()) } fn get_remote_pubkey(connection: &Connection) -> Result { @@ -464,27 +547,13 @@ fn get_remote_pubkey(connection: &Connection) -> Result { } } -async fn get_cache_entry( - remote_address: SocketAddr, - cache: &RwLock, -) -> Arc>> { - let key = (remote_address, /*remote_pubkey:*/ None); - if let Some(entry) = cache.read().await.get(&key) { - return entry.clone(); - } - cache.write().await.entry(key).or_default().clone() -} - async fn cache_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, - cache: &RwLock, + cache: &Mutex>, ) { - // The 2nd cache entry with remote_pubkey == None allows to lookup an entry - // only by SocketAddr when establishing outgoing connections. - let entries: [Arc>>; 2] = { - let mut cache = cache.write().await; + let old = { + let mut cache = cache.lock().await; if cache.len() >= CONNECTION_CACHE_CAPACITY { connection.close( CONNECTION_CLOSE_ERROR_CODE_DROPPED, @@ -492,15 +561,9 @@ async fn cache_connection( ); return; } - [Some(remote_pubkey), None].map(|remote_pubkey| { - let key = (remote_address, remote_pubkey); - cache.entry(key).or_default().clone() - }) + cache.insert(remote_pubkey, connection) }; - let mut entry = entries[0].write().await; - *entries[1].write().await = Some(connection.clone()); - if let Some(old) = entry.replace(connection) { - drop(entry); + if let Some(old) = old { old.close( CONNECTION_CLOSE_ERROR_CODE_REPLACED, CONNECTION_CLOSE_REASON_REPLACED, @@ -509,26 +572,19 @@ async fn cache_connection( } async fn drop_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: &Connection, - cache: &RwLock, + cache: &Mutex>, ) { - if connection.close_reason().is_none() { - connection.close( - CONNECTION_CLOSE_ERROR_CODE_DROPPED, - CONNECTION_CLOSE_REASON_DROPPED, - ); - } - let key = (remote_address, Some(remote_pubkey)); - if let Entry::Occupied(entry) = cache.write().await.entry(key) { - if matches!(entry.get().read().await.deref(), - Some(entry) if entry.stable_id() == connection.stable_id()) - { + connection.close( + CONNECTION_CLOSE_ERROR_CODE_DROPPED, + CONNECTION_CLOSE_REASON_DROPPED, + ); + if let Entry::Occupied(entry) = cache.lock().await.entry(remote_pubkey) { + if entry.get().stable_id() == connection.stable_id() { entry.remove(); } } - // Cache entry for (remote_address, None) will be lazily evicted. } impl From> for Error {