diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 3fb171065..0d8d4ff2f 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -73,7 +73,7 @@ jobs:
         run: curl -LsSf https://get.nexte.st/latest/linux | tar zxf - -C ${CARGO_HOME:-~/.cargo}/bin
       - name: Build
         run: cargo build -p qt -p quilkin -p quilkin-xds --tests
-      - run: cargo nextest run --no-tests=pass -p qt -p quilkin -p quilkin-xds quilkin
+      - run: cargo nextest run -p qt -p quilkin -p quilkin-xds quilkin

   build:
     name: Build
diff --git a/Cargo.lock b/Cargo.lock
index 3be76e4ff..138b75cc6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -142,6 +142,18 @@ dependencies = [
  "pin-project-lite",
 ]
 
+[[package]]
+name = "async-channel"
+version = "2.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "89b47800b0be77592da0afd425cc03468052844aff33b84e33cc696f64e77b6a"
+dependencies = [
+ "concurrent-queue",
+ "event-listener-strategy",
+ "futures-core",
+ "pin-project-lite",
+]
+
 [[package]]
 name = "async-stream"
 version = "0.3.6"
@@ -2422,6 +2434,7 @@ dependencies = [
 name = "qt"
 version = "0.1.0"
 dependencies = [
+ "async-channel",
  "once_cell",
  "quilkin",
  "rand",
@@ -2444,6 +2457,7 @@ name = "quilkin"
 version = "0.10.0-dev"
 dependencies = [
  "arc-swap",
+ "async-channel",
  "async-stream",
  "async-trait",
  "base64 0.22.1",
diff --git a/Cargo.toml b/Cargo.toml
index f7c2bd8cd..f871e4820 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -87,12 +87,12 @@ quilkin-proto.workspace = true
 
 # Crates.io
 arc-swap.workspace = true
+async-channel.workspace = true
 async-stream.workspace = true
 base64.workspace = true
 base64-serde = "0.8.0"
 bytes = { version = "1.8.0", features = ["serde"] }
 cached.workspace = true
-cfg-if = "1.0"
 crossbeam-utils = { version = "0.8", optional = true }
 clap = { version = "4.5.21", features = ["cargo", "derive", "env"] }
 dashmap = { version = "6.1", features = ["serde"] }
@@ -153,6 +153,7 @@ hickory-resolver = { version = "0.24", features = [
 async-trait = "0.1.83"
 strum = "0.26"
 strum_macros = "0.26"
+cfg-if = "1.0.0"
 libflate = "2.1.0"
 form_urlencoded = "1.2.1"
 enum_dispatch = "0.3.13"
@@ -193,6 +194,7 @@ edition = "2021"
 
 [workspace.dependencies]
 arc-swap = { version = "1.7.1", features = ["serde"] }
+async-channel = "2.3.1"
 async-stream = "0.3.6"
 base64 = "0.22.1"
 cached = { version = "0.54", default-features = false }
diff --git a/crates/test/Cargo.toml b/crates/test/Cargo.toml
index 85d220014..7a90f45ec 100644
--- a/crates/test/Cargo.toml
+++ b/crates/test/Cargo.toml
@@ -24,6 +24,7 @@ publish = false
 workspace = true
 
 [dependencies]
+async-channel.workspace = true
 once_cell.workspace = true
 quilkin.workspace = true
 rand.workspace = true
diff --git a/crates/test/tests/proxy.rs b/crates/test/tests/proxy.rs
index 2272af221..227cf6648 100644
--- a/crates/test/tests/proxy.rs
+++ b/crates/test/tests/proxy.rs
@@ -1,5 +1,5 @@
 use qt::*;
-use quilkin::{components::proxy, test::TestConfig};
+use quilkin::test::TestConfig;
 use tracing::Instrument as _;
 
 trace_test!(server, {
@@ -87,7 +87,8 @@ trace_test!(uring_receiver, {
 
     let (mut packet_rx, endpoint) = sb.server("server");
 
-    let (error_sender, mut error_receiver) = tokio::sync::mpsc::channel::<proxy::PipelineError>(20);
+    let (error_sender, mut error_receiver) =
+        tokio::sync::mpsc::channel::<quilkin::components::proxy::PipelineError>(20);
 
     tokio::task::spawn(
         async move {
@@ -104,32 +105,37 @@ trace_test!(uring_receiver, {
     config
         .clusters
         .modify(|clusters| clusters.insert_default([endpoint.into()].into()));
 
+    let (tx, rx) = async_channel::unbounded();
+    let (_shutdown_tx, shutdown_rx) =
+        quilkin::make_shutdown_channel(quilkin::ShutdownKind::Testing);
 
     let socket = sb.client();
     let (ws, addr) = sb.socket();
 
-    let pending_sends = proxy::PendingSends::new(1).unwrap();
-    // we'll test a single DownstreamReceiveWorkerConfig
-    proxy::packet_router::DownstreamReceiveWorkerConfig {
+    let ready = quilkin::components::proxy::packet_router::DownstreamReceiveWorkerConfig {
         worker_id: 1,
         port: addr.port(),
+        upstream_receiver: rx.clone(),
         config: config.clone(),
         error_sender,
         buffer_pool: quilkin::test::BUFFER_POOL.clone(),
-        sessions: proxy::SessionPool::new(
+        sessions: quilkin::components::proxy::SessionPool::new(
            config,
-            vec![pending_sends.0.clone()],
+            tx,
            BUFFER_POOL.clone(),
+            shutdown_rx.clone(),
        ),
    }
-    .spawn(pending_sends)
+    .spawn(shutdown_rx)
    .await
    .expect("failed to spawn task");

    // Drop the socket, otherwise it can
    drop(ws);

+    ready.recv().unwrap();
+
    let msg = "hello-downstream";
    tracing::debug!("sending packet");
    socket.send_to(msg.as_bytes(), addr).await.unwrap();
@@ -152,33 +158,36 @@ trace_test!(
        .clusters
        .modify(|clusters| clusters.insert_default([endpoint.into()].into()));

-    let pending_sends: Vec<_> = [
-        proxy::PendingSends::new(1).unwrap(),
-        proxy::PendingSends::new(1).unwrap(),
-        proxy::PendingSends::new(1).unwrap(),
-    ]
-    .into_iter()
-    .collect();
+    let (tx, rx) = async_channel::unbounded();
+    let (_shutdown_tx, shutdown_rx) =
+        quilkin::make_shutdown_channel(quilkin::ShutdownKind::Testing);

-    let sessions = proxy::SessionPool::new(
+    let sessions = quilkin::components::proxy::SessionPool::new(
        config.clone(),
-        pending_sends.iter().map(|ps| ps.0.clone()).collect(),
+        tx,
        BUFFER_POOL.clone(),
+        shutdown_rx.clone(),
    );

    const WORKER_COUNT: usize = 3;

    let (socket, addr) = sb.socket();
-    proxy::packet_router::spawn_receivers(
+    let workers = quilkin::components::proxy::packet_router::spawn_receivers(
        config,
        socket,
-        pending_sends,
+        WORKER_COUNT,
        &sessions,
+        rx,
        BUFFER_POOL.clone(),
+        shutdown_rx,
    )
    .await
    .unwrap();

+    for wn in workers {
+        wn.recv().unwrap();
+    }
+
    let socket = std::sync::Arc::new(sb.client());
    let msg = "recv-from";
diff --git a/src/collections/ttl.rs b/src/collections/ttl.rs
index c6cb015f9..6522eb5d4 100644
--- a/src/collections/ttl.rs
+++ b/src/collections/ttl.rs
@@ -55,13 +55,11 @@ impl<V> Value<V> {
    /// Get the expiration time for this value. The returned value is the
    /// number of seconds relative to some reference point (e.g UNIX_EPOCH), based
    /// on the clock being used.
-    #[inline]
    fn expiration_secs(&self) -> u64 {
        self.expires_at.load(Ordering::Relaxed)
    }

    /// Update the value's expiration time to (now + TTL).
-    #[inline]
    fn update_expiration(&self, ttl: Duration) {
        match self.clock.compute_expiration_secs(ttl) {
            Ok(new_expiration_time) => {
@@ -162,7 +160,6 @@ where
    /// Returns the current time as the number of seconds relative to some initial
    /// reference point (e.g UNIX_EPOCH), based on the clock implementation being used.
    /// In tests, this will be driven by [`tokio::time`]
-    #[inline]
    pub(crate) fn now_relative_secs(&self) -> u64 {
        self.0.clock.now_relative_secs().unwrap_or_default()
    }
@@ -240,12 +237,6 @@ where
        self.0.inner.remove(&key).is_some()
    }

-    /// Removes all entries from the map
-    #[inline]
-    pub fn clear(&self) {
-        self.0.inner.clear();
-    }
-
    /// Returns an entry for in-place updates of the specified key-value pair.
    /// Note: This acquires a write lock on the map's shard that corresponds
    /// to the entry.
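Before the proxy.rs changes below: the diff replaces the hand-rolled `PendingSends` queue with an `async-channel` based flow. A minimal sketch of the topology it moves to (illustrative only, not part of the diff; assumes the async-channel 2.x and tokio crates, and the `SendPacket` name here is a stand-in): sessions push packets into one bounded MPMC channel, every worker holds a clone of the receiver, and each packet is delivered to exactly one worker.

    #[derive(Debug)]
    struct SendPacket {
        data: Vec<u8>,
    }

    #[tokio::main]
    async fn main() {
        // One bounded queue shared by all producers (sessions) and consumers
        // (workers), mirroring the `async_channel::bounded(250)` call below.
        let (tx, rx) = async_channel::bounded::<SendPacket>(250);

        // async-channel receivers are Clone and MPMC: each packet is taken by
        // exactly one of the workers blocked on recv().
        let workers: Vec<_> = (0..3)
            .map(|id| {
                let rx = rx.clone();
                tokio::spawn(async move {
                    // recv() returns Err once every sender has been dropped,
                    // which doubles as the shutdown signal.
                    while let Ok(packet) = rx.recv().await {
                        println!("worker {id} got {} bytes", packet.data.len());
                    }
                })
            })
            .collect();

        for _ in 0..10 {
            tx.send(SendPacket { data: vec![0; 64] }).await.unwrap();
        }
        drop(tx); // close the channel so the workers exit

        for w in workers {
            w.await.unwrap();
        }
    }

The bounded capacity gives natural backpressure: `try_send` fails with `Full` instead of queueing unboundedly, which is what the new `PipelineError::ChannelFull` variants in this diff surface.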
diff --git a/src/components/proxy.rs b/src/components/proxy.rs
index f06af458a..bbf6a8ffe 100644
--- a/src/components/proxy.rs
+++ b/src/components/proxy.rs
@@ -18,79 +18,8 @@ mod error;
 pub mod packet_router;
 mod sessions;
 
-cfg_if::cfg_if! {
-    if #[cfg(target_os = "linux")] {
-        pub(crate) mod io_uring_shared;
-        pub(crate) type PacketSendReceiver = io_uring_shared::EventFd;
-        pub(crate) type PacketSendSender = io_uring_shared::EventFdWriter;
-    } else {
-        pub(crate) type PacketSendReceiver = tokio::sync::watch::Receiver<bool>;
-        pub(crate) type PacketSendSender = tokio::sync::watch::Sender<bool>;
-    }
-}
-
-/// A simple packet queue that signals when a packet is pushed
-///
-/// For io_uring this notifies an eventfd that will be processed on the next
-/// completion loop
-#[derive(Clone)]
-pub struct PendingSends {
-    packets: Arc<parking_lot::Mutex<Vec<SendPacket>>>,
-    notify: PacketSendSender,
-}
-
-impl PendingSends {
-    pub fn new(capacity: usize) -> std::io::Result<(Self, PacketSendReceiver)> {
-        #[cfg(target_os = "linux")]
-        let (notify, rx) = {
-            let rx = io_uring_shared::EventFd::new()?;
-            (rx.writer(), rx)
-        };
-        #[cfg(not(target_os = "linux"))]
-        let (notify, rx) = tokio::sync::watch::channel(true);
-
-        Ok((
-            Self {
-                packets: Arc::new(parking_lot::Mutex::new(Vec::with_capacity(capacity))),
-                notify,
-            },
-            rx,
-        ))
-    }
-
-    #[inline]
-    pub(crate) fn capacity(&self) -> usize {
-        self.packets.lock().capacity()
-    }
-
-    /// Pushes a packet onto the queue to be sent, signalling a sender that
-    /// it's available
-    #[inline]
-    pub(crate) fn push(&self, packet: SendPacket) {
-        self.packets.lock().push(packet);
-        #[cfg(target_os = "linux")]
-        self.notify.write(1);
-        #[cfg(not(target_os = "linux"))]
-        let _ = self.notify.send(true);
-    }
-
-    /// Called to shutdown the consumer side of the sends (ie the io loop that is
-    /// actually dequing and sending packets)
-    #[inline]
-    pub(crate) fn shutdown_receiver(&self) {
-        #[cfg(target_os = "linux")]
-        self.notify.write(0xdeadbeef);
-        #[cfg(not(target_os = "linux"))]
-        let _ = self.notify.send(false);
-    }
-
-    /// Swaps the current queue with an empty one so we only lock for a pointer swap
-    #[inline]
-    pub fn swap(&self, mut swap: Vec<SendPacket>) -> Vec<SendPacket> {
-        swap.clear();
-        std::mem::replace(&mut self.packets.lock(), swap)
-    }
-}
+#[cfg(target_os = "linux")]
+pub(crate) mod io_uring_shared;
 
 use super::RunArgs;
 pub use error::{ErrorMap, PipelineError};
@@ -104,11 +33,8 @@ use std::{
 };
 
 pub struct SendPacket {
-    /// The destination address of the packet
-    pub destination: socket2::SockAddr,
-    /// The packet data being sent
+    pub destination: SocketAddr,
     pub data: crate::pool::FrozenPoolBuffer,
-    /// The asn info for the sender, used for metrics
     pub asn_info: Option<MetricsIpNetEntry>,
 }
@@ -282,6 +208,18 @@ impl Proxy {
            ));
        }

+        let id = config.id.load();
+        let num_workers = self.num_workers.get();
+
+        let (upstream_sender, upstream_receiver) = async_channel::bounded(250);
+        let buffer_pool = Arc::new(crate::pool::BufferPool::new(num_workers, 64 * 1024));
+        let sessions = SessionPool::new(
+            config.clone(),
+            upstream_sender,
+            buffer_pool.clone(),
+            shutdown_rx.clone(),
+        );
+
        #[allow(clippy::type_complexity)]
        const SUBS: &[(&str, &[(&str, Vec<String>)])] = &[
            (
@@ -309,8 +247,6 @@ impl Proxy {
            *lock = Some(check.clone());
        }

-        let id = config.id.load();
-
        std::thread::Builder::new()
            .name("proxy-subscription".into())
            .spawn({
@@ -355,25 +291,14 @@ impl Proxy {
                .expect("failed to spawn proxy-subscription thread");
        }

-        let num_workers = self.num_workers.get();
-        let buffer_pool = Arc::new(crate::pool::BufferPool::new(num_workers, 2 * 1024));
-
-        let mut worker_sends = Vec::with_capacity(num_workers);
-        let mut session_sends = Vec::with_capacity(num_workers);
-        for _ in 0..num_workers {
-            let psends = PendingSends::new(15)?;
-            session_sends.push(psends.0.clone());
-            worker_sends.push(psends);
-        }
-
-        let sessions = SessionPool::new(config.clone(), session_sends, buffer_pool.clone());
-
-        packet_router::spawn_receivers(
+        let worker_notifications = packet_router::spawn_receivers(
            config.clone(),
            self.socket,
-            worker_sends,
+            num_workers,
            &sessions,
+            upstream_receiver,
            buffer_pool,
+            shutdown_rx.clone(),
        )
        .await?;
@@ -385,6 +310,10 @@
            crate::net::phoenix::Phoenix::new(crate::codec::qcmp::QcmpMeasurement::new()?),
        )?;

+        for notification in worker_notifications {
+            let _ = notification.recv();
+        }
+
        tracing::info!("Quilkin is ready");
        if let Some(initialized) = initialized {
            let _ = initialized.send(());
@@ -395,7 +324,17 @@
            .await
            .map_err(|error| eyre::eyre!(error))?;

-        sessions.shutdown(*shutdown_rx.borrow() == crate::ShutdownKind::Normal);
+        if *shutdown_rx.borrow() == crate::ShutdownKind::Normal {
+            tracing::info!(sessions=%sessions.sessions().len(), "waiting for active sessions to expire");
+
+            let interval = std::time::Duration::from_millis(100);
+
+            while sessions.sessions().is_not_empty() {
+                tokio::time::sleep(interval).await;
+                tracing::debug!(sessions=%sessions.sessions().len(), "sessions still active");
+            }
+            tracing::info!("all sessions expired");
+        }

        Ok(())
    }
diff --git a/src/components/proxy/io_uring_shared.rs b/src/components/proxy/io_uring_shared.rs
index 7593db7fd..ccc14b654 100644
--- a/src/components/proxy/io_uring_shared.rs
+++ b/src/components/proxy/io_uring_shared.rs
@@ -21,9 +21,10 @@
 //! enough that it doesn't make sense to share the same code
 
 use crate::{
-    components::proxy::{self, PendingSends, PipelineError, SendPacket},
+    components::proxy::{self, PipelineError},
     metrics,
-    pool::PoolBuffer,
+    net::maxmind_db::MetricsIpNetEntry,
+    pool::{FrozenPoolBuffer, PoolBuffer},
     time::UtcTimestamp,
 };
 use io_uring::{squeue::Entry, types::Fd};
@@ -37,11 +38,30 @@ use std::{
 ///
 /// We use eventfd to signal to io uring loops from async tasks, it is essentially
 /// the equivalent of a signalling 64 bit cross-process atomic
-pub struct EventFd {
+pub(crate) struct EventFd {
     fd: std::os::fd::OwnedFd,
     val: u64,
 }
 
+#[derive(Clone)]
+pub(crate) struct EventFdWriter {
+    fd: i32,
+}
+
+impl EventFdWriter {
+    #[inline]
+    pub(crate) fn write(&self, val: u64) {
+        // SAFETY: we have a valid descriptor, and most of the errors that apply
+        // to the general write call that eventfd_write wraps are not applicable
+        //
+        // Note that while the docs state eventfd_write is glibc, it is implemented
+        // on musl as well, but really is just a write with 8 bytes
+        unsafe {
+            libc::eventfd_write(self.fd, val);
+        }
+    }
+}
+
 impl EventFd {
     #[inline]
     pub(crate) fn new() -> std::io::Result<Self> {
@@ -82,30 +102,48 @@ impl EventFd {
     }
 }
 
+struct RecvPacket {
+    /// The buffer filled with data during recv_from
+    buffer: PoolBuffer,
+    /// The IP of the sender
+    source: std::net::SocketAddr,
+}
+
+struct SendPacket {
+    /// The destination address of the packet
+    destination: SockAddr,
+    /// The packet data being sent
+    buffer: FrozenPoolBuffer,
+    /// The asn info for the sender, used for metrics
+    asn_info: Option<MetricsIpNetEntry>,
+}
+
+/// A simple double buffer for queuing packets that need to be sent; each enqueue
+/// notifies an eventfd that sends are available
 #[derive(Clone)]
-pub(crate) struct EventFdWriter {
-    fd: i32,
+struct PendingSends {
+    packets: Arc<parking_lot::Mutex<Vec<SendPacket>>>,
+    notify: EventFdWriter,
 }
 
-impl EventFdWriter {
-    #[inline]
-    pub(crate) fn write(&self, val: u64) {
-        // SAFETY: we have a valid descriptor, and most of the errors that apply
-        // to the general write call that eventfd_write wraps are not applicable
-        //
-        // Note that while the docs state eventfd_write is glibc, it is implemented
-        // on musl as well, but really is just a write with 8 bytes
-        unsafe {
-            libc::eventfd_write(self.fd, val);
+impl PendingSends {
+    pub fn new(notify: EventFdWriter) -> Self {
+        Self {
+            packets: Default::default(),
+            notify,
         }
     }
-}
 
-struct RecvPacket {
-    /// The buffer filled with data during recv_from
-    buffer: PoolBuffer,
-    /// The IP of the sender
-    source: std::net::SocketAddr,
+    #[inline]
+    pub fn push(&self, packet: SendPacket) {
+        self.packets.lock().push(packet);
+        self.notify.write(1);
+    }
+
+    #[inline]
+    pub fn swap(&self, swap: Vec<SendPacket>) -> Vec<SendPacket> {
+        std::mem::replace(&mut self.packets.lock(), swap)
+    }
 }
 
 enum LoopPacketInner {
@@ -154,8 +192,8 @@ impl LoopPacket {
             // For sends, the length of the buffer is the actual number of initialized bytes,
             // and note that iov_base is a *mut even though for sends the buffer is not actually
             // mutated
-            self.io_vec.iov_base = send.data.as_ptr() as *mut u8 as *mut _;
-            self.io_vec.iov_len = send.data.len();
+            self.io_vec.iov_base = send.buffer.as_ptr() as *mut u8 as *mut _;
+            self.io_vec.iov_len = send.buffer.len();
 
             // SAFETY: both pointers are valid at this point, with the same size
             unsafe {
@@ -224,8 +262,62 @@ pub enum PacketProcessorCtx {
     },
 }
 
+pub enum PacketReceiver {
+    Router(crate::components::proxy::sessions::DownstreamReceiver),
+    SessionPool(tokio::sync::mpsc::Receiver<proxy::SendPacket>),
+}
+
+/// Spawns worker tasks
+///
+/// One task processes received packets, notifying the io-uring loop when a
+/// packet finishes processing; the other receives packets to send and notifies
+/// the io-uring loop when there are 1 or more packets available to be sent
+fn spawn_workers(
+    rt: &tokio::runtime::Runtime,
+    receiver: PacketReceiver,
+    pending_sends: PendingSends,
+    mut shutdown_rx: crate::ShutdownRx,
+    shutdown_event: EventFdWriter,
+) {
+    // Spawn a task that just monitors the shutdown receiver to notify the io-uring loop to exit
+    rt.spawn(async move {
+        // The result is uninteresting, either a shutdown has been signalled, or all senders have been dropped
+        // which equates to the same thing
+        let _ = shutdown_rx.changed().await;
+        shutdown_event.write(1);
+    });
+
+    match receiver {
+        PacketReceiver::Router(upstream_receiver) => {
+            rt.spawn(async move {
+                while let Ok(packet) = upstream_receiver.recv().await {
+                    let packet = SendPacket {
+                        destination: packet.destination.into(),
+                        buffer: packet.data,
+                        asn_info: packet.asn_info,
+                    };
+                    pending_sends.push(packet);
+                }
+            });
+        }
+        PacketReceiver::SessionPool(mut downstream_receiver) => {
+            rt.spawn(async move {
+                while let Some(packet) = downstream_receiver.recv().await {
+                    let packet = SendPacket {
+                        destination: packet.destination.into(),
+                        buffer: packet.data,
+                        asn_info: packet.asn_info,
+                    };
+                    pending_sends.push(packet);
+                }
+            });
+        }
+    }
+}
+
 fn process_packet(
     ctx: &mut PacketProcessorCtx,
+    packet_processed_event: &EventFdWriter,
     packet: RecvPacket,
     last_received_at: &mut Option<UtcTimestamp>,
 ) {
@@ -257,6 +349,8 @@ fn process_packet(
                error_acc,
                destinations,
            );
+
+            packet_processed_event.write(1);
        }
        PacketProcessorCtx::SessionPool { pool, port, .. } => {
            let mut last_received_at = None;
@@ -267,6 +361,8 @@
                *port,
                &mut last_received_at,
            );
+
+            packet_processed_event.write(1);
        }
    }
 }
@@ -281,8 +377,10 @@ enum Token {
     Recv { key: usize },
     /// Packet sent
     Send { key: usize },
-    /// One or more packets are ready to be sent OR shutdown of the loop is requested
+    /// One or more packets are ready to be sent
     PendingsSends,
+    /// Loop shutdown requested
+    Shutdown,
 }
 
 struct LoopCtx<'uring> {
@@ -410,6 +508,7 @@ impl<'uring> LoopCtx<'uring> {
 }
 
 pub struct IoUringLoop {
+    runtime: tokio::runtime::Runtime,
     socket: crate::net::DualStackLocalSocket,
     concurrent_sends: usize,
 }
@@ -419,7 +518,14 @@ impl IoUringLoop {
         concurrent_sends: u16,
         socket: crate::net::DualStackLocalSocket,
     ) -> Result<Self, PipelineError> {
+        let runtime = tokio::runtime::Builder::new_multi_thread()
+            .enable_all()
+            .max_blocking_threads(1)
+            .worker_threads(3)
+            .build()?;
+
         Ok(Self {
+            runtime,
             concurrent_sends: concurrent_sends as _,
             socket,
         })
@@ -429,18 +535,28 @@ impl IoUringLoop {
         self,
         thread_name: String,
         mut ctx: PacketProcessorCtx,
-        pending_sends: (PendingSends, EventFd),
+        receiver: PacketReceiver,
         buffer_pool: Arc<crate::pool::BufferPool>,
-    ) -> Result<(), PipelineError> {
+        shutdown: crate::ShutdownRx,
+    ) -> Result<std::sync::mpsc::Receiver<()>, PipelineError> {
         let dispatcher = tracing::dispatcher::get_default(|d| d.clone());
+        let (tx, rx) = std::sync::mpsc::channel();
 
+        let rt = self.runtime;
         let socket = self.socket;
         let concurrent_sends = self.concurrent_sends;
 
         let mut ring = io_uring::IoUring::new((concurrent_sends + 3) as _)?;
 
-        let mut pending_sends_event = pending_sends.1;
-        let pending_sends = pending_sends.0;
+        // Used to notify the uring loop when 1 or more packets have been queued
+        // up to be sent to a remote address
+        let mut pending_sends_event = EventFd::new()?;
+        // Used to notify the uring when a received packet has finished
+        // processing and we can perform another recv, as we (currently) only
+        // ever process a single packet at a time
+        let process_event = EventFd::new()?;
+        // Used to notify the uring loop to shut down
+        let mut shutdown_event = EventFd::new()?;
 
         std::thread::Builder::new()
             .name(thread_name)
@@ -448,11 +564,14 @@
                crate::metrics::game_traffic_tasks().inc();
                let _guard = tracing::dispatcher::set_default(&dispatcher);

-                let tokens = slab::Slab::with_capacity(concurrent_sends + 1 + 1);
+                let tokens = slab::Slab::with_capacity(concurrent_sends + 1 + 1 + 1);
                let loop_packets = slab::Slab::with_capacity(concurrent_sends + 1);

+                // Create an eventfd to notify the uring thread (this one) of
+                // pending sends
+                let pending_sends = PendingSends::new(pending_sends_event.writer());
                // Just double buffer the pending writes for simplicity
-                let mut double_pending_sends = Vec::with_capacity(pending_sends.capacity());
+                let mut double_pending_sends = Vec::new();

                // When sending packets, this is the direction used when updating metrics
                let send_dir = if matches!(ctx, PacketProcessorCtx::Router { .. }) {
                    metrics::WRITE
                } else {
                    metrics::READ
                };

@@ -461,6 +580,16 @@
+                // Spawn the worker tasks that process in an async context unlike
+                // our io-uring loop below
+                spawn_workers(
+                    &rt,
+                    receiver,
+                    pending_sends.clone(),
+                    shutdown,
+                    shutdown_event.writer(),
+                );
+
                let (submitter, sq, mut cq) = ring.split();

                let mut loop_ctx = LoopCtx {
@@ -474,12 +603,16 @@
                loop_ctx.enqueue_recv(buffer_pool.clone().alloc());
                loop_ctx
                    .push_with_token(pending_sends_event.io_uring_entry(), Token::PendingsSends);
+                loop_ctx.push_with_token(shutdown_event.io_uring_entry(), Token::Shutdown);

                // Sync always needs to be called when entries have been pushed
                // onto the submission queue for the loop to actually function (ie, similar to await on futures)
                loop_ctx.sync();

+                // Notify that we have set everything up
+                let _ = tx.send(());
+
                let mut last_received_at = None;
+                let process_event_writer = process_event.writer();

                // The core io uring loop
                'io: loop {
@@ -521,26 +654,26 @@
                            }

                            let packet = packet.finalize_recv(ret as usize);
-                            process_packet(&mut ctx, packet, &mut last_received_at);
+                            process_packet(
+                                &mut ctx,
+                                &process_event_writer,
+                                packet,
+                                &mut last_received_at,
+                            );

                            loop_ctx.enqueue_recv(buffer_pool.clone().alloc());
                        }
                        Token::PendingsSends => {
-                            if pending_sends_event.val < 0xdeadbeef {
-                                double_pending_sends = pending_sends.swap(double_pending_sends);
-                                loop_ctx.push_with_token(
-                                    pending_sends_event.io_uring_entry(),
-                                    Token::PendingsSends,
-                                );
-
-                                for pending in
-                                    double_pending_sends.drain(0..double_pending_sends.len())
-                                {
-                                    loop_ctx.enqueue_send(pending);
-                                }
-                            } else {
-                                tracing::info!("io-uring loop shutdown requested");
-                                break 'io;
+                            double_pending_sends = pending_sends.swap(double_pending_sends);
+                            loop_ctx.push_with_token(
+                                pending_sends_event.io_uring_entry(),
+                                Token::PendingsSends,
+                            );
+
+                            for pending in
+                                double_pending_sends.drain(0..double_pending_sends.len())
+                            {
+                                loop_ctx.enqueue_send(pending);
                            }
                        }
                        Token::Send { key } => {
@@ -553,7 +686,7 @@
                                metrics::errors_total(send_dir, &source, &asn_info).inc();
                                metrics::packets_dropped_total(send_dir, &source, &asn_info)
                                    .inc();
-                            } else if ret as usize != packet.data.len() {
+                            } else if ret as usize != packet.buffer.len() {
                                metrics::packets_total(send_dir, &asn_info).inc();
                                metrics::errors_total(
                                    send_dir,
@@ -566,6 +699,10 @@
                                metrics::bytes_total(send_dir, &asn_info).inc_by(ret as u64);
                            }
                        }
+                        Token::Shutdown => {
+                            tracing::info!("io-uring loop shutdown requested");
+                            break 'io;
+                        }
                    }
                }
@@ -575,7 +712,7 @@
                crate::metrics::game_traffic_task_closed().inc();
            })?;

-        Ok(())
+        Ok(rx)
    }
 }
diff --git a/src/components/proxy/packet_router.rs b/src/components/proxy/packet_router.rs
index 5d54c99e2..0915cc032 100644
--- a/src/components/proxy/packet_router.rs
+++ b/src/components/proxy/packet_router.rs
@@ -14,7 +14,10 @@
  * limitations under the License.
  */
 
-use super::{sessions::SessionKey, PipelineError, SessionPool};
+use super::{
+    sessions::{DownstreamReceiver, SessionKey},
+    PipelineError, SessionPool,
+};
 use crate::{
     filters::{Filter as _, ReadContext},
     metrics,
@@ -41,6 +44,8 @@ pub(crate) struct DownstreamPacket {
 pub struct DownstreamReceiveWorkerConfig {
     /// ID of the worker.
     pub worker_id: usize,
+    /// Receiver for packets forwarded by the session pool to be sent back to
+    /// downstream clients.
+    pub upstream_receiver: DownstreamReceiver,
     pub port: u16,
     pub config: Arc<Config>,
     pub sessions: Arc<SessionPool>,
@@ -132,17 +137,21 @@
 pub async fn spawn_receivers(
     config: Arc<Config>,
     socket: socket2::Socket,
-    worker_sends: Vec<(super::PendingSends, super::PacketSendReceiver)>,
+    num_workers: usize,
     sessions: &Arc<SessionPool>,
+    upstream_receiver: DownstreamReceiver,
     buffer_pool: Arc<BufferPool>,
-) -> crate::Result<()> {
+    shutdown: crate::ShutdownRx,
+) -> crate::Result<Vec<std::sync::mpsc::Receiver<()>>> {
     let (error_sender, mut error_receiver) = mpsc::channel(128);
 
     let port = crate::net::socket_port(&socket);
 
-    for (worker_id, ws) in worker_sends.into_iter().enumerate() {
+    let mut worker_notifications = Vec::with_capacity(num_workers);
+    for worker_id in 0..num_workers {
         let worker = DownstreamReceiveWorkerConfig {
             worker_id,
+            upstream_receiver: upstream_receiver.clone(),
             port,
             config: config.clone(),
             sessions: sessions.clone(),
@@ -150,7 +159,7 @@
             buffer_pool: buffer_pool.clone(),
         };
 
-        worker.spawn(ws).await?;
+        worker_notifications.push(worker.spawn(shutdown.clone()).await?);
     }
 
     drop(error_sender);
@@ -188,5 +197,5 @@
         }
     });
 
-    Ok(())
+    Ok(worker_notifications)
 }
diff --git a/src/components/proxy/packet_router/io_uring.rs b/src/components/proxy/packet_router/io_uring.rs
index a3bc55450..2b535f41f 100644
--- a/src/components/proxy/packet_router/io_uring.rs
+++ b/src/components/proxy/packet_router/io_uring.rs
@@ -14,18 +14,18 @@
  * limitations under the License.
  */
 
-use crate::components::proxy;
 use eyre::Context as _;
 
 impl super::DownstreamReceiveWorkerConfig {
     pub async fn spawn(
         self,
-        pending_sends: (proxy::PendingSends, proxy::PacketSendReceiver),
-    ) -> eyre::Result<()> {
+        shutdown: crate::ShutdownRx,
+    ) -> eyre::Result<std::sync::mpsc::Receiver<()>> {
         use crate::components::proxy::io_uring_shared;
 
         let Self {
             worker_id,
+            upstream_receiver,
             port,
             config,
             sessions,
@@ -47,8 +47,9 @@
                 worker_id,
                 destinations: Vec::with_capacity(1),
             },
-            pending_sends,
+            io_uring_shared::PacketReceiver::Router(upstream_receiver),
             buffer_pool,
+            shutdown,
         )
         .context("failed to spawn io-uring loop")
     }
diff --git a/src/components/proxy/packet_router/reference.rs b/src/components/proxy/packet_router/reference.rs
index efc441ef9..9a519fc1e 100644
--- a/src/components/proxy/packet_router/reference.rs
+++ b/src/components/proxy/packet_router/reference.rs
@@ -16,15 +16,14 @@
 
 //! The reference implementation is used for non-Linux targets
 
-use crate::components::proxy;
-
 impl super::DownstreamReceiveWorkerConfig {
     pub async fn spawn(
         self,
-        pending_sends: (proxy::PendingSends, proxy::PacketSendReceiver),
-    ) -> eyre::Result<()> {
+        _shutdown: crate::ShutdownRx,
+    ) -> eyre::Result<std::sync::mpsc::Receiver<()>> {
         let Self {
             worker_id,
+            upstream_receiver,
             port,
             config,
             sessions,
@@ -32,9 +31,10 @@
             buffer_pool,
         } = self;
 
+        let (tx, rx) = std::sync::mpsc::channel();
+
         let thread_span =
             uring_span!(tracing::debug_span!("receiver", id = worker_id).or_current());
-        let (tx, mut rx) = tokio::sync::oneshot::channel();
 
         let worker = uring_spawn!(thread_span, async move {
             crate::metrics::game_traffic_tasks().inc();
@@ -47,49 +47,56 @@
            let send_socket = socket.clone();

            let inner_task = async move {
-                let (pending_sends, mut sends_rx) = pending_sends;
-                let mut sends_double_buffer = Vec::with_capacity(pending_sends.capacity());
-
-                while sends_rx.changed().await.is_ok() {
-                    if !*sends_rx.borrow() {
-                        tracing::trace!("io loop shutdown requested");
-                        break;
-                    }
+                let _ = tx.send(());

-                    sends_double_buffer = pending_sends.swap(sends_double_buffer);
-
-                    for packet in sends_double_buffer.drain(..sends_double_buffer.len()) {
-                        let (result, _) = send_socket
-                            .send_to(packet.data, packet.destination.as_socket().unwrap())
-                            .await;
-                        let asn_info = packet.asn_info.as_ref().into();
-                        match result {
-                            Ok(size) => {
-                                crate::metrics::packets_total(crate::metrics::WRITE, &asn_info)
-                                    .inc();
-                                crate::metrics::bytes_total(crate::metrics::WRITE, &asn_info)
-                                    .inc_by(size as u64);
-                            }
-                            Err(error) => {
-                                let source = error.to_string();
-                                crate::metrics::errors_total(
-                                    crate::metrics::WRITE,
-                                    &source,
-                                    &asn_info,
-                                )
-                                .inc();
-                                crate::metrics::packets_dropped_total(
-                                    crate::metrics::WRITE,
-                                    &source,
-                                    &asn_info,
-                                )
-                                .inc();
+                loop {
+                    tokio::select! {
+                        result = upstream_receiver.recv() => {
+                            match result {
+                                Err(error) => {
+                                    tracing::trace!(%error, "error receiving packet");
+                                    crate::metrics::errors_total(
+                                        crate::metrics::WRITE,
+                                        &error.to_string(),
+                                        &crate::metrics::EMPTY,
+                                    )
+                                    .inc();
+                                }
+                                Ok(crate::components::proxy::SendPacket {
+                                    destination,
+                                    asn_info,
+                                    data,
+                                }) => {
+                                    let (result, _) = send_socket.send_to(data, destination).await;
+                                    let asn_info = asn_info.as_ref().into();
+                                    match result {
+                                        Ok(size) => {
+                                            crate::metrics::packets_total(crate::metrics::WRITE, &asn_info)
+                                                .inc();
+                                            crate::metrics::bytes_total(crate::metrics::WRITE, &asn_info)
+                                                .inc_by(size as u64);
+                                        }
+                                        Err(error) => {
+                                            let source = error.to_string();
+                                            crate::metrics::errors_total(
+                                                crate::metrics::WRITE,
+                                                &source,
+                                                &asn_info,
+                                            )
+                                            .inc();
+                                            crate::metrics::packets_dropped_total(
+                                                crate::metrics::WRITE,
+                                                &source,
+                                                &asn_info,
+                                            )
+                                            .inc();
+                                        }
+                                    }
+                                }
                            }
                        }
                    }
                }
-
-                let _ = tx.send(());
            };

            cfg_if::cfg_if! {
@@ -109,43 +116,35 @@
                // packet, which is the maximum value of 16 a bit integer.
                let buffer = buffer_pool.clone().alloc();

-                tokio::select! {
-                    received = socket.recv_from(buffer) => {
-                        let received_at = crate::time::UtcTimestamp::now();
-                        let (result, buffer) = received;
+                let (result, contents) = socket.recv_from(buffer).await;
+                let received_at = crate::time::UtcTimestamp::now();

-                        match result {
-                            Ok((_size, mut source)) => {
-                                source.set_ip(source.ip().to_canonical());
-                                let packet = super::DownstreamPacket { contents: buffer, source };
+                match result {
+                    Ok((_size, mut source)) => {
+                        source.set_ip(source.ip().to_canonical());
+                        let packet = super::DownstreamPacket { contents, source };

-                                if let Some(last_received_at) = last_received_at {
-                                    crate::metrics::packet_jitter(
-                                        crate::metrics::READ,
-                                        &crate::metrics::EMPTY,
-                                    )
-                                    .set((received_at - last_received_at).nanos());
-                                }
-                                last_received_at = Some(received_at);
-
-                                Self::process_task(
-                                    packet,
-                                    worker_id,
-                                    &config,
-                                    &sessions,
-                                    &mut error_acc,
-                                    &mut destinations,
-                                );
-                            }
-                            Err(error) => {
-                                tracing::error!(%error, "error receiving packet");
-                                return;
-                            }
+                        if let Some(last_received_at) = last_received_at {
+                            crate::metrics::packet_jitter(
+                                crate::metrics::READ,
+                                &crate::metrics::EMPTY,
+                            )
+                            .set((received_at - last_received_at).nanos());
                        }
+                        last_received_at = Some(received_at);
+
+                        Self::process_task(
+                            packet,
+                            worker_id,
+                            &config,
+                            &sessions,
+                            &mut error_acc,
+                            &mut destinations,
+                        );
                    }
-                    _ = &mut rx => {
+                    Err(error) => {
                        crate::metrics::game_traffic_task_closed().inc();
-                        tracing::debug!("Closing downstream socket loop, shutdown requested");
+                        tracing::error!(%error, "error receiving packet");
                        return;
                    }
                }
            }
        });

        use eyre::WrapErr as _;
        worker.recv().context("failed to spawn receiver task")?;

-        Ok(())
+        Ok(rx)
    }
 }
diff --git a/src/components/proxy/sessions.rs b/src/components/proxy/sessions.rs
index 0ee3a57dd..0fcd6a256 100644
--- a/src/components/proxy/sessions.rs
+++ b/src/components/proxy/sessions.rs
@@ -18,38 +18,38 @@ use std::{
     collections::{HashMap, HashSet},
     fmt,
     net::SocketAddr,
-    sync::{atomic, Arc},
+    sync::Arc,
     time::Duration,
 };
 
-use tokio::time::Instant;
+use tokio::{sync::mpsc, time::Instant};
 
 use crate::{
-    components::proxy::SendPacket,
+    components::proxy::{PipelineError, SendPacket},
     config::Config,
     filters::Filter,
     metrics,
     net::maxmind_db::{IpNetEntry, MetricsIpNetEntry},
     pool::{BufferPool, FrozenPoolBuffer, PoolBuffer},
     time::UtcTimestamp,
-    Loggable,
+    Loggable, ShutdownRx,
 };
 use parking_lot::RwLock;
 
-use super::PendingSends;
-
 pub(crate) mod inner_metrics;
 
 pub type SessionMap = crate::collections::ttl::TtlMap<SessionKey, Session>;
 
-cfg_if::cfg_if! {
-    if #[cfg(target_os = "linux")] {
-        mod io_uring;
-    } else {
-        mod reference;
-    }
-}
+#[cfg(target_os = "linux")]
+mod io_uring;
+#[cfg(not(target_os = "linux"))]
+mod reference;
+
+type UpstreamSender = mpsc::Sender<SendPacket>;
+
+type DownstreamSender = async_channel::Sender<SendPacket>;
+pub type DownstreamReceiver = async_channel::Receiver<SendPacket>;
 
 #[derive(PartialEq, Eq, Hash)]
 pub enum SessionError {
@@ -90,13 +90,13 @@ impl fmt::Debug for SessionError {
 /// Traffic from different gameservers is then demuxed using their address to
 /// send back to the original client.
 pub struct SessionPool {
-    ports_to_sockets: RwLock<HashMap<u16, PendingSends>>,
+    ports_to_sockets: RwLock<HashMap<u16, UpstreamSender>>,
     storage: Arc<RwLock<SocketStorage>>,
     session_map: SessionMap,
+    downstream_sender: DownstreamSender,
     buffer_pool: Arc<BufferPool>,
+    shutdown_rx: ShutdownRx,
     config: Arc<Config>,
-    downstream_sends: Vec<PendingSends>,
-    downstream_index: atomic::AtomicUsize,
 }
 
 /// The wrapper struct responsible for holding all of the socket related mappings.
@@ -114,20 +114,21 @@ impl SessionPool {
     /// to release their sockets back to the parent.
     pub fn new(
         config: Arc<Config>,
-        downstream_sends: Vec<PendingSends>,
+        downstream_sender: DownstreamSender,
         buffer_pool: Arc<BufferPool>,
+        shutdown_rx: ShutdownRx,
     ) -> Arc<Self> {
         const SESSION_TIMEOUT_SECONDS: Duration = Duration::from_secs(60);
         const SESSION_EXPIRY_POLL_INTERVAL: Duration = Duration::from_secs(60);
 
         Arc::new(Self {
             config,
+            downstream_sender,
+            shutdown_rx,
             ports_to_sockets: <_>::default(),
             storage: <_>::default(),
             session_map: SessionMap::new(SESSION_TIMEOUT_SECONDS, SESSION_EXPIRY_POLL_INTERVAL),
             buffer_pool,
-            downstream_sends,
-            downstream_index: atomic::AtomicUsize::new(0),
         })
     }
 
@@ -135,7 +136,7 @@ impl SessionPool {
     fn create_new_session_from_new_socket<'pool>(
         self: &'pool Arc<Self>,
         key: SessionKey,
-    ) -> Result<(Option<MetricsIpNetEntry>, PendingSends), super::PipelineError> {
+    ) -> Result<(Option<MetricsIpNetEntry>, UpstreamSender), super::PipelineError> {
         tracing::trace!(source=%key.source, dest=%key.dest, "creating new socket for session");
         let raw_socket = crate::net::raw_socket_with_reuse(0)?;
         let port = raw_socket
@@ -143,15 +144,19 @@
            .as_socket()
            .ok_or(SessionError::SocketAddressUnavailable)?
            .port();
+        let (downstream_sender, downstream_receiver) = mpsc::channel::<SendPacket>(15);

-        let (pending_sends, srecv) = super::PendingSends::new(15)?;
-        self.clone()
-            .spawn_session(raw_socket, port, (pending_sends.clone(), srecv))?;
+        let initialised = self
+            .clone()
+            .spawn_session(raw_socket, port, downstream_receiver)?;
+        initialised
+            .recv()
+            .map_err(|_err| PipelineError::ChannelClosed)?;

        self.ports_to_sockets
            .write()
-            .insert(port, pending_sends.clone());
-        self.create_session_from_existing_socket(key, pending_sends, port)
+            .insert(port, downstream_sender.clone());
+        self.create_session_from_existing_socket(key, downstream_sender, port)
    }

    pub(crate) fn process_received_upstream_packet(
@@ -187,6 +192,7 @@
            let _timer = metrics::processing_time(metrics::WRITE).start_timer();
            Self::process_recv_packet(
                self.config.clone(),
+                &self.downstream_sender,
                recv_addr,
                downstream_addr,
                asn_info,
@@ -194,25 +200,13 @@
            )
        };

-        match result {
-            Ok(packet) => {
-                let index = self
-                    .downstream_index
-                    .fetch_add(1, atomic::Ordering::Relaxed)
-                    % self.downstream_sends.len();
-                // SAFETY: we've ensured it's within bounds via the %
-                unsafe {
-                    self.downstream_sends.get_unchecked(index).push(packet);
-                }
-            }
-            Err((asn_info, error)) => {
-                error.log();
-                let label = format!("proxy::Session::process_recv_packet: {error}");
-                let asn_metric_info = asn_info.as_ref().into();
+        if let Err((asn_info, error)) = result {
+            error.log();
+            let label = format!("proxy::Session::process_recv_packet: {error}");
+            let asn_metric_info = asn_info.as_ref().into();

-                metrics::packets_dropped_total(metrics::WRITE, &label, &asn_metric_info).inc();
-                metrics::errors_total(metrics::WRITE, &label, &asn_metric_info).inc();
-            }
+            metrics::packets_dropped_total(metrics::WRITE, &label, &asn_metric_info).inc();
+            metrics::errors_total(metrics::WRITE, &label, &asn_metric_info).inc();
        }
    }

@@ -223,14 +217,14 @@
    pub fn get<'pool>(
        self: &'pool Arc<Self>,
        key @ SessionKey { dest, .. }: SessionKey,
-    ) -> Result<(Option<MetricsIpNetEntry>, PendingSends), super::PipelineError> {
+    ) -> Result<(Option<MetricsIpNetEntry>, UpstreamSender), super::PipelineError> {
        tracing::trace!(source=%key.source, dest=%key.dest, "SessionPool::get");
        // If we already have a session for the key pairing, return that session.
        if let Some(entry) = self.session_map.get(&key) {
            tracing::trace!("returning existing session");
            return Ok((
                entry.asn_info.as_ref().map(MetricsIpNetEntry::from),
-                entry.pending_sends.clone(),
+                entry.upstream_sender.clone(),
            ));
        }

@@ -284,9 +278,9 @@
    fn create_session_from_existing_socket<'session>(
        self: &'session Arc<Self>,
        key: SessionKey,
-        pending_sends: PendingSends,
+        upstream_sender: UpstreamSender,
        socket_port: u16,
-    ) -> Result<(Option<MetricsIpNetEntry>, PendingSends), super::PipelineError> {
+    ) -> Result<(Option<MetricsIpNetEntry>, UpstreamSender), super::PipelineError> {
        tracing::trace!(source=%key.source, dest=%key.dest, "reusing socket for session");
        let asn_info = {
            let mut storage = self.storage.write();
@@ -319,7 +313,7 @@
        let session = Session::new(
            key,
-            pending_sends.clone(),
+            upstream_sender.clone(),
            socket_port,
            self.clone(),
            asn_info,
@@ -327,17 +321,18 @@
        tracing::trace!("inserting session into map");
        self.session_map.insert(key, session);
        tracing::trace!("session inserted");
-        Ok((asn_metrics_info, pending_sends))
+        Ok((asn_metrics_info, upstream_sender))
    }

    /// process_recv_packet processes a packet that is received by this session.
    fn process_recv_packet(
        config: Arc<Config>,
+        downstream_sender: &DownstreamSender,
        source: SocketAddr,
        dest: SocketAddr,
        asn_info: Option<MetricsIpNetEntry>,
        packet: PoolBuffer,
-    ) -> Result<SendPacket, (Option<MetricsIpNetEntry>, Error)> {
+    ) -> Result<(), (Option<MetricsIpNetEntry>, Error)> {
        tracing::trace!(%source, %dest, length = packet.len(), "received packet from upstream");

        let mut context = crate::filters::WriteContext::new(source.into(), dest.into(), packet);
@@ -346,11 +341,21 @@
            return Err((asn_info, err.into()));
        }

-        Ok(SendPacket {
-            data: context.contents.freeze(),
-            destination: dest.into(),
-            asn_info,
-        })
+        let packet = context.contents.freeze();
+        tracing::trace!(%source, %dest, length = packet.len(), "sending packet downstream");
+        downstream_sender
+            .try_send(SendPacket {
+                data: packet,
+                destination: dest,
+                asn_info,
+            })
+            .map_err(|error| match error {
+                async_channel::TrySendError::Closed(packet) => {
+                    (packet.asn_info, Error::ChannelClosed)
+                }
+                async_channel::TrySendError::Full(packet) => (packet.asn_info, Error::ChannelFull),
+            })?;
+        Ok(())
    }

    /// Returns a map of active sessions.
@@ -359,30 +364,25 @@
    }

    /// Sends packet data to the appropiate session based on its `key`.
-    #[inline]
    pub fn send(
        self: &Arc<Self>,
        key: SessionKey,
        packet: FrozenPoolBuffer,
    ) -> Result<(), super::PipelineError> {
-        self.send_inner(key, packet)?;
-        Ok(())
-    }
+        use tokio::sync::mpsc::error::TrySendError;

-    #[inline]
-    fn send_inner(
-        self: &Arc<Self>,
-        key: SessionKey,
-        packet: FrozenPoolBuffer,
-    ) -> Result<PendingSends, super::PipelineError> {
        let (asn_info, sender) = self.get(key)?;

-        sender.push(SendPacket {
-            destination: key.dest.into(),
-            data: packet,
-            asn_info,
-        });
-        Ok(sender)
+        sender
+            .try_send(crate::components::proxy::SendPacket {
+                data: packet,
+                asn_info,
+                destination: key.dest,
+            })
+            .map_err(|error| match error {
+                TrySendError::Closed(_) => super::PipelineError::ChannelClosed,
+                TrySendError::Full(_) => super::PipelineError::ChannelFull,
+            })
    }

    /// Returns whether the pool contains any sockets allocated to a destination.
@@ -405,7 +405,7 @@
    }

    /// Handles the logic of releasing a socket back into the pool.
-    fn release_socket(
+    async fn release_socket(
        self: Arc<Self>,
        SessionKey {
            ref source,
@@ -440,28 +440,11 @@
        storage.destination_to_sources.remove(&(*dest, port));
        tracing::trace!("socket released");
    }
-
-    /// Closes all active sessions, and all downstream listeners
-    pub(crate) fn shutdown(self: Arc<Self>, wait: bool) {
-        // Disable downstream listeners first so sessions aren't spawned while
-        // we are trying to reap the active sessions
-        for downstream_listener in &self.downstream_sends {
-            downstream_listener.shutdown_receiver();
-        }
-
-        if wait && !self.session_map.is_empty() {
-            tracing::info!(sessions=%self.session_map.len(), "waiting for active sessions to expire");
-            self.session_map.clear();
-        }
-    }
 }
 
 impl Drop for SessionPool {
     fn drop(&mut self) {
-        let map = std::mem::take(&mut self.session_map);
-        std::thread::spawn(move || {
-            drop(map);
-        });
+        drop(std::mem::take(&mut self.session_map));
     }
 }
 
@@ -473,8 +456,8 @@ pub struct Session {
     key: SessionKey,
     /// The socket port of the session.
     socket_port: u16,
-    /// The queue of packets being sent to the upstream (server)
-    pending_sends: PendingSends,
+    /// The sender for packets being sent to the upstream (server).
+    upstream_sender: UpstreamSender,
     /// The GeoIP information of the source.
     asn_info: Option<IpNetEntry>,
     /// The socket pool of the session.
@@ -484,14 +467,14 @@
 impl Session {
     pub fn new(
         key: SessionKey,
-        pending_sends: PendingSends,
+        upstream_sender: UpstreamSender,
         socket_port: u16,
         pool: Arc<SessionPool>,
         asn_info: Option<IpNetEntry>,
     ) -> Self {
         let s = Self {
             key,
-            pending_sends,
+            upstream_sender,
             pool,
             socket_port,
             asn_info,
@@ -520,18 +503,17 @@
        inner_metrics::active_sessions(self.asn_info.as_ref())
    }

-    fn release(&mut self) {
+    fn async_drop(&mut self) -> impl std::future::Future<Output = ()> {
        self.active_session_metric().dec();
        inner_metrics::duration_secs().observe(self.created_at.elapsed().as_secs() as f64);
        tracing::debug!(source = %self.key.source, dest_address = %self.key.dest, "Session closed");
-        self.pending_sends.shutdown_receiver();
-        SessionPool::release_socket(self.pool.clone(), self.key, self.socket_port);
+        SessionPool::release_socket(self.pool.clone(), self.key, self.socket_port)
    }
 }
 
 impl Drop for Session {
     fn drop(&mut self) {
-        self.release()
+        tokio::spawn(self.async_drop());
     }
 }
 
@@ -550,6 +532,10 @@ impl From<(SocketAddr, SocketAddr)> for SessionKey {
 
 #[derive(Debug, thiserror::Error)]
 pub enum Error {
+    #[error("downstream channel closed")]
+    ChannelClosed,
+    #[error("downstream channel full")]
+    ChannelFull,
     #[error("filter {0}")]
     Filter(#[from] crate::filters::FilterError),
 }
@@ -564,24 +550,30 @@ impl Loggable for Error {
 
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::test::{alloc_buffer, available_addr, AddressType, TestHelper};
+    use crate::{
+        test::{alloc_buffer, available_addr, AddressType, TestHelper},
+        ShutdownTx,
+    };
     use std::sync::Arc;
 
-    async fn new_pool() -> (Arc<SessionPool>, PendingSends) {
-        let (pending_sends, _srecv) = PendingSends::new(1).unwrap();
+    async fn new_pool() -> (Arc<SessionPool>, ShutdownTx, DownstreamReceiver) {
+        let (tx, rx) = crate::make_shutdown_channel(crate::ShutdownKind::Testing);
+        let (sender, receiver) = async_channel::unbounded();
         (
             SessionPool::new(
                 Arc::new(Config::default_agent()),
-                vec![pending_sends.clone()],
+                sender,
                 Arc::new(BufferPool::default()),
+                rx,
             ),
-            pending_sends,
+            tx,
+            receiver,
         )
     }
 
     #[tokio::test]
     async fn insert_and_release_single_socket() {
-        let (pool, _receiver) = new_pool().await;
+        let (pool, _sender, _receiver) = new_pool().await;
         let key = (
            (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(),
            (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(),
@@ -597,7 +589,7 @@ mod tests {
 
     #[tokio::test]
     async fn insert_and_release_multiple_sockets() {
-        let (pool, _receiver) = new_pool().await;
+        let (pool, _sender, _receiver) = new_pool().await;
         let key1 = (
             (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(),
             (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(),
@@ -622,7 +614,7 @@ mod tests {
 
     #[tokio::test]
     async fn same_address_uses_different_sockets() {
-        let (pool, _receiver) = new_pool().await;
+        let (pool, _sender, _receiver) = new_pool().await;
         let key1 = (
             (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(),
             (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(),
@@ -647,7 +639,7 @@ mod tests {
 
     #[tokio::test]
     async fn different_addresses_uses_same_socket() {
-        let (pool, _receiver) = new_pool().await;
+        let (pool, _sender, _receiver) = new_pool().await;
         let key1 = (
             (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(),
             (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(),
@@ -670,7 +662,7 @@ mod tests {
 
     #[tokio::test]
     async fn spawn_safe_same_destination() {
-        let (pool, _receiver) = new_pool().await;
+        let (pool, _sender, _receiver) = new_pool().await;
         let key1 = (
             (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(),
             (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(),
@@ -695,7 +687,7 @@ mod tests {
 
     #[tokio::test]
     async fn spawn_safe_different_destination() {
-        let (pool, _receiver) = new_pool().await;
+        let (pool, _sender, _receiver) = new_pool().await;
         let key1 = (
             (std::net::Ipv4Addr::LOCALHOST, 8080u16).into(),
             (std::net::Ipv4Addr::UNSPECIFIED, 8080u16).into(),
@@ -729,14 +721,18 @@ mod tests {
         let socket = tokio::net::UdpSocket::bind(source).await.unwrap();
         let mut source = socket.local_addr().unwrap();
         crate::test::map_addr_to_localhost(&mut source);
-        let (pool, _pending_sends) = new_pool().await;
+        let (pool, _sender, receiver) = new_pool().await;
 
         let key: SessionKey = (source, dest).into();
         let msg = b"helloworld";
 
-        let pending = pool.send_inner(key, alloc_buffer(msg).freeze()).unwrap();
-        let pending = pending.swap(Vec::new());
+        pool.send(key, alloc_buffer(msg).freeze()).unwrap();
+
+        let packet = tokio::time::timeout(std::time::Duration::from_secs(1), receiver.recv())
+            .await
+            .unwrap()
+            .unwrap();
 
-        assert_eq!(msg, &*pending[0].data);
+        assert_eq!(msg, &*packet.data);
     }
 }
diff --git a/src/components/proxy/sessions/io_uring.rs b/src/components/proxy/sessions/io_uring.rs
index d345689f6..ce709f8e4 100644
--- a/src/components/proxy/sessions/io_uring.rs
+++ b/src/components/proxy/sessions/io_uring.rs
@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 
-use crate::components::proxy;
 use std::sync::Arc;
 
 static SESSION_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
@@ -24,9 +23,9 @@ impl super::SessionPool {
         self: Arc<Self>,
         raw_socket: socket2::Socket,
         port: u16,
-        pending_sends: (proxy::PendingSends, proxy::io_uring_shared::EventFd),
-    ) -> Result<(), proxy::PipelineError> {
-        use proxy::io_uring_shared;
+        downstream_receiver: tokio::sync::mpsc::Receiver<crate::components::proxy::SendPacket>,
+    ) -> Result<std::sync::mpsc::Receiver<()>, crate::components::proxy::PipelineError> {
+        use crate::components::proxy::io_uring_shared;
 
         let pool = self;
         let id = SESSION_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
@@ -37,12 +36,14 @@ impl super::SessionPool {
             crate::net::DualStackLocalSocket::from_raw(raw_socket),
         )?;
         let buffer_pool = pool.buffer_pool.clone();
+        let shutdown = pool.shutdown_rx.clone();
 
         io_loop.spawn(
             format!("session-{id}"),
             io_uring_shared::PacketProcessorCtx::SessionPool { pool, port },
-            pending_sends,
+            io_uring_shared::PacketReceiver::SessionPool(downstream_receiver),
             buffer_pool,
+            shutdown,
         )
     }
 }
diff --git a/src/components/proxy/sessions/reference.rs b/src/components/proxy/sessions/reference.rs
index bad4d2ebb..067fee4d2 100644
--- a/src/components/proxy/sessions/reference.rs
+++ b/src/components/proxy/sessions/reference.rs
@@ -14,21 +14,20 @@
  * limitations under the License.
  */
 
-use crate::components::proxy;
-
 impl super::SessionPool {
     pub(super) fn spawn_session(
         self: std::sync::Arc<Self>,
         raw_socket: socket2::Socket,
         port: u16,
-        pending_sends: (proxy::PendingSends, proxy::PacketSendReceiver),
-    ) -> Result<(), proxy::PipelineError> {
+        mut downstream_receiver: tokio::sync::mpsc::Receiver<crate::components::proxy::SendPacket>,
+    ) -> Result<std::sync::mpsc::Receiver<()>, crate::components::proxy::PipelineError> {
         let pool = self;
 
-        uring_spawn!(
+        let rx = uring_spawn!(
             uring_span!(tracing::debug_span!("session pool")),
             async move {
                 let mut last_received_at = None;
+                let mut shutdown_rx = pool.shutdown_rx.clone();
 
                 let socket =
                     std::sync::Arc::new(crate::net::DualStackLocalSocket::from_raw(raw_socket));
@@ -36,48 +35,54 @@ impl super::SessionPool {
                let (tx, mut rx) = tokio::sync::oneshot::channel();

                uring_inner_spawn!(async move {
-                    let (pending_sends, mut sends_rx) = pending_sends;
-                    let mut sends_double_buffer = Vec::with_capacity(pending_sends.capacity());
-
-                    while sends_rx.changed().await.is_ok() {
-                        if !*sends_rx.borrow() {
-                            tracing::trace!("io loop shutdown requested");
-                            break;
-                        }
-
-                        sends_double_buffer = pending_sends.swap(sends_double_buffer);
-
-                        for packet in sends_double_buffer.drain(..sends_double_buffer.len()) {
-                            let destination = packet.destination.as_socket().unwrap();
-                            tracing::trace!(
-                                %destination,
-                                length = packet.data.len(),
-                                "sending packet upstream"
-                            );
-                            let (result, _) = socket2.send_to(packet.data, destination).await;
-                            let asn_info = packet.asn_info.as_ref().into();
-                            match result {
-                                Ok(size) => {
-                                    crate::metrics::packets_total(crate::metrics::READ, &asn_info)
+                    loop {
+                        match downstream_receiver.recv().await {
+                            None => {
+                                crate::metrics::errors_total(
+                                    crate::metrics::WRITE,
+                                    "downstream channel closed",
+                                    &crate::metrics::EMPTY,
+                                )
+                                .inc();
+                                break;
+                            }
+                            Some(crate::components::proxy::SendPacket {
+                                destination,
+                                data,
+                                asn_info,
+                            }) => {
+                                tracing::trace!(%destination, length = data.len(), "sending packet upstream");
+                                let (result, _) = socket2.send_to(data, destination).await;
+                                let asn_info = asn_info.as_ref().into();
+                                match result {
+                                    Ok(size) => {
+                                        crate::metrics::packets_total(
+                                            crate::metrics::READ,
+                                            &asn_info,
+                                        )
                                        .inc();
-                                    crate::metrics::bytes_total(crate::metrics::READ, &asn_info)
+                                        crate::metrics::bytes_total(
+                                            crate::metrics::READ,
+                                            &asn_info,
+                                        )
                                        .inc_by(size as u64);
-                                }
-                                Err(error) => {
-                                    tracing::trace!(%error, "sending packet upstream failed");
-                                    let source = error.to_string();
-                                    crate::metrics::errors_total(
-                                        crate::metrics::READ,
-                                        &source,
-                                        &asn_info,
-                                    )
-                                    .inc();
-                                    crate::metrics::packets_dropped_total(
-                                        crate::metrics::READ,
-                                        &source,
-                                        &asn_info,
-                                    )
-                                    .inc();
+                                    }
+                                    Err(error) => {
+                                        tracing::trace!(%error, "sending packet upstream failed");
+                                        let source = error.to_string();
+                                        crate::metrics::errors_total(
+                                            crate::metrics::READ,
+                                            &source,
+                                            &asn_info,
+                                        )
+                                        .inc();
+                                        crate::metrics::packets_dropped_total(
+                                            crate::metrics::READ,
+                                            &source,
+                                            &asn_info,
+                                        )
+                                        .inc();
+                                    }
                                }
                            }
                        }
@@ -99,6 +104,10 @@
                            Ok((_size, recv_addr)) => pool.process_received_upstream_packet(buf, recv_addr, port, &mut last_received_at),
                        }
                    }
+                    _ = shutdown_rx.changed() => {
+                        tracing::debug!("Closing upstream socket loop");
+                        return;
+                    }
                    _ = &mut rx => {
                        tracing::debug!("Closing upstream socket loop, downstream closed");
                        return;
@@ -108,6 +117,6 @@
            }
        );

-        Ok(())
+        Ok(rx)
    }
 }
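A closing note on the `std::sync::mpsc::Receiver<()>` values the new `spawn` signatures return throughout this diff: they act as a one-shot readiness handshake. A minimal sketch of the pattern (illustrative only; `spawn_io_loop` and the worker names are stand-ins, not quilkin API): the io thread sends one unit value after its setup finishes, and the caller blocks on `recv()` before declaring itself ready, so "Quilkin is ready" is only logged once every worker loop can actually serve traffic.

    fn spawn_io_loop(name: String) -> std::sync::mpsc::Receiver<()> {
        let (tx, rx) = std::sync::mpsc::channel();

        std::thread::spawn(move || {
            // ... bind sockets / set up rings here ...

            // Notify the spawner that setup finished; if this thread dies
            // first, the sender is dropped and the caller's recv() returns
            // Err instead of hanging.
            let _ = tx.send(());

            println!("{name}: entering io loop");
            // 'io: loop { ... }
        });

        rx
    }

    fn main() {
        let ready: Vec<_> = (0..3)
            .map(|id| spawn_io_loop(format!("worker-{id}")))
            .collect();

        for rx in &ready {
            // Blocks until the matching worker signalled readiness; an Err
            // means the worker exited before it ever became ready.
            rx.recv().expect("worker failed to start");
        }

        println!("all workers ready");
    }

Using a synchronous channel here (rather than a tokio oneshot) lets both async callers and plain threads perform the handshake, which is why the tests above can simply call `ready.recv().unwrap()`.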