diff --git a/src/client.rs b/src/client.rs index 89d00649..dc6710ca 100644 --- a/src/client.rs +++ b/src/client.rs @@ -13,11 +13,11 @@ use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; +use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time::{self, Duration}; -use tracing::{debug, error, info, instrument, warn, Instrument, Span}; +use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span}; #[cfg(feature = "noise")] use crate::transport::NoiseTransport; @@ -236,8 +236,13 @@ async fn run_data_channel_for_udp(conn: T::Stream, local_addr: &st // Keep sending items from the outbound channel to the server tokio::spawn(async move { while let Some(t) = outbound_rx.recv().await { - debug!("outbound {:?}", t); - if t.write(&mut wr).await.is_err() { + trace!("outbound {:?}", t); + if let Err(e) = t + .write(&mut wr) + .await + .with_context(|| "Failed to forward UDP traffic to the server") + { + debug!("{:?}", e); break; } } @@ -245,7 +250,10 @@ async fn run_data_channel_for_udp(conn: T::Stream, local_addr: &st loop { // Read a packet from the server - let packet = UdpTraffic::read(&mut rd).await?; + let hdr_len = rd.read_u16().await?; + let packet = UdpTraffic::read(&mut rd, hdr_len) + .await + .with_context(|| "Failed to read UDPTraffic from the server")?; let m = port_map.read().await; if m.get(&packet.from).is_none() { @@ -290,6 +298,7 @@ async fn run_data_channel_for_udp(conn: T::Stream, local_addr: &st } // Run a UdpSocket for the visitor `from` +#[instrument(skip_all, fields(from))] async fn run_udp_forwarder( s: UdpSocket, mut inbound_rx: mpsc::Receiver, @@ -297,6 +306,7 @@ async fn run_udp_forwarder( from: SocketAddr, port_map: UdpPortMap, ) -> Result<()> { + debug!("Forwarder created"); let mut buf = BytesMut::new(); buf.resize(UDP_BUFFER_SIZE, 0); @@ -336,6 +346,7 @@ async fn run_udp_forwarder( let mut port_map = port_map.write().await; port_map.remove(&from); + debug!("Forwarder dropped"); Ok(()) } diff --git a/src/protocol.rs b/src/protocol.rs index 883b6541..9be620e6 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -6,6 +6,7 @@ use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use std::net::SocketAddr; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tracing::trace; type ProtocolVersion = u8; const PROTO_V0: u8 = 0u8; @@ -70,12 +71,14 @@ pub struct UdpTraffic { impl UdpTraffic { pub async fn write(&self, writer: &mut T) -> Result<()> { - let v = bincode::serialize(&UdpHeader { + let hdr = UdpHeader { from: self.from, len: self.data.len() as UdpPacketLen, - }) - .unwrap(); + }; + + let v = bincode::serialize(&hdr).unwrap(); + trace!("Write {:?} of length {}", hdr, v.len()); writer.write_u16(v.len() as u16).await?; writer.write_all(&v).await?; @@ -90,12 +93,14 @@ impl UdpTraffic { from: SocketAddr, data: &[u8], ) -> Result<()> { - let v = bincode::serialize(&UdpHeader { + let hdr = UdpHeader { from, len: data.len() as UdpPacketLen, - }) - .unwrap(); + }; + + let v = bincode::serialize(&hdr).unwrap(); + trace!("Write {:?} of length {}", hdr, v.len()); writer.write_u16(v.len() as u16).await?; writer.write_all(&v).await?; @@ -104,24 +109,25 @@ impl UdpTraffic { Ok(()) } - pub async fn read(reader: &mut T) -> Result { - let len = reader.read_u16().await? as usize; - + pub async fn read(reader: &mut T, hdr_len: u16) -> Result { let mut buf = Vec::new(); - buf.resize(len, 0); + buf.resize(hdr_len as usize, 0); reader .read_exact(&mut buf) .await .with_context(|| "Failed to read udp header")?; - let header: UdpHeader = - bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?; + + let hdr: UdpHeader = + bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?; + + trace!("hdr {:?}", hdr); let mut data = BytesMut::new(); - data.resize(header.len as usize, 0); + data.resize(hdr.len as usize, 0); reader.read_exact(&mut data).await?; Ok(UdpTraffic { - from: header.from, + from: hdr.from, data: data.freeze(), }) } diff --git a/src/server.rs b/src/server.rs index 0e2bbcc1..0003c9b8 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,10 +14,9 @@ use backoff::ExponentialBackoff; use rand::RngCore; use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; +use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, RwLock}; use tokio::time; @@ -162,12 +161,12 @@ impl<'a, T: 'static + Transport> Server<'a, T> { } Ok((conn, addr)) => { backoff.reset(); - debug!("Incomming connection from {}", addr); + debug!("Incoming connection from {}", addr); let services = self.services.clone(); let control_channels = self.control_channels.clone(); tokio::spawn(async move { - if let Err(err) = handle_connection(conn, addr, services, control_channels).await.with_context(||"Failed to handle a connection to `server.bind_addr`") { + if let Err(err) = handle_connection(conn, services, control_channels).await { error!("{:?}", err); } }.instrument(info_span!("handle_connection", %addr))); @@ -215,7 +214,6 @@ impl<'a, T: 'static + Transport> Server<'a, T> { // Handle connections to `server.bind_addr` async fn handle_connection( mut conn: T::Stream, - addr: SocketAddr, services: Arc>>, control_channels: Arc>>, ) -> Result<()> { @@ -223,8 +221,7 @@ async fn handle_connection( let hello = read_hello(&mut conn).await?; match hello { ControlChannelHello(_, service_digest) => { - do_control_channel_handshake(conn, addr, services, control_channels, service_digest) - .await?; + do_control_channel_handshake(conn, services, control_channels, service_digest).await?; } DataChannelHello(_, nonce) => { do_data_channel_handshake(conn, control_channels, nonce).await?; @@ -235,12 +232,11 @@ async fn handle_connection( async fn do_control_channel_handshake( mut conn: T::Stream, - addr: SocketAddr, services: Arc>>, control_channels: Arc>>, service_digest: ServiceDigest, ) -> Result<()> { - info!("New control channel incomming from {}", addr); + info!("New control channel incoming"); // Generate a nonce let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES]; @@ -321,6 +317,8 @@ async fn do_data_channel_handshake( control_channels: Arc>>, nonce: Nonce, ) -> Result<()> { + info!("New control channel incoming"); + // Validate let control_channels_guard = control_channels.read().await; match control_channels_guard.get2(&nonce) { @@ -358,27 +356,6 @@ where // Store data channel creation requests let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel(); - match service.service_type { - ServiceType::Tcp => tokio::spawn( - run_tcp_connection_pool::( - service.bind_addr.clone(), - data_ch_rx, - data_ch_req_tx.clone(), - shutdown_tx.subscribe(), - ) - .instrument(Span::current()), - ), - ServiceType::Udp => tokio::spawn( - run_udp_connection_pool::( - service.bind_addr.clone(), - data_ch_rx, - data_ch_req_tx.clone(), - shutdown_tx.subscribe(), - ) - .instrument(Span::current()), - ), - }; - // Cache some data channels for later use let pool_size = match service.service_type { ServiceType::Tcp => TCP_POOL_SIZE, @@ -391,6 +368,43 @@ where }; } + let shutdown_rx_clone = shutdown_tx.subscribe(); + let bind_addr = service.bind_addr.clone(); + match service.service_type { + ServiceType::Tcp => tokio::spawn( + async move { + if let Err(e) = run_tcp_connection_pool::( + bind_addr, + data_ch_rx, + data_ch_req_tx, + shutdown_rx_clone, + ) + .await + .with_context(|| "Failed to run TCP connection pool") + { + error!("{:?}", e); + } + } + .instrument(Span::current()), + ), + ServiceType::Udp => tokio::spawn( + async move { + if let Err(e) = run_udp_connection_pool::( + bind_addr, + data_ch_rx, + data_ch_req_tx, + shutdown_rx_clone, + ) + .await + .with_context(|| "Failed to run TCP connection pool") + { + error!("{:?}", e); + } + } + .instrument(Span::current()), + ), + }; + // Create the control channel let ch = ControlChannel:: { conn, @@ -568,7 +582,16 @@ async fn run_udp_connection_pool( // TODO: Load balance let l: UdpSocket = backoff::future::retry(listen_backoff(), || async { - Ok(UdpSocket::bind(&bind_addr).await?) + Ok(match UdpSocket::bind(&bind_addr) + .await + .with_context(|| "Failed to listen for the service") + { + Err(e) => { + error!("{:?}", e); + Err(e) + } + v => v, + }?) }) .await .with_context(|| "Failed to listen for the service")?; @@ -594,10 +617,10 @@ async fn run_udp_connection_pool( }, // Forward outbound traffic from the client to the visitor - t = UdpTraffic::read(&mut conn) => { - let t = t?; + hdr_len = conn.read_u16() => { + let t = UdpTraffic::read(&mut conn, hdr_len?).await?; l.send_to(&t.data, t.from).await?; - }, + } _ = shutdown_rx.recv() => { break; @@ -605,5 +628,7 @@ async fn run_udp_connection_pool( } } + debug!("UDP pool dropped"); + Ok(()) }