Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cancel safety and more logs #77

Merged
merged 2 commits into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ use bytes::{Bytes, BytesMut};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use tokio::time::{self, Duration};
use tracing::{debug, error, info, instrument, warn, Instrument, Span};
use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};

#[cfg(feature = "noise")]
use crate::transport::NoiseTransport;
Expand Down Expand Up @@ -236,16 +236,24 @@ async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &st
// Keep sending items from the outbound channel to the server
tokio::spawn(async move {
while let Some(t) = outbound_rx.recv().await {
debug!("outbound {:?}", t);
if t.write(&mut wr).await.is_err() {
trace!("outbound {:?}", t);
if let Err(e) = t
.write(&mut wr)
.await
.with_context(|| "Failed to forward UDP traffic to the server")
{
debug!("{:?}", e);
break;
}
}
});

loop {
// Read a packet from the server
let packet = UdpTraffic::read(&mut rd).await?;
let hdr_len = rd.read_u16().await?;
let packet = UdpTraffic::read(&mut rd, hdr_len)
.await
.with_context(|| "Failed to read UDPTraffic from the server")?;
let m = port_map.read().await;

if m.get(&packet.from).is_none() {
Expand Down Expand Up @@ -290,13 +298,15 @@ async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &st
}

// Run a UdpSocket for the visitor `from`
#[instrument(skip_all, fields(from))]
async fn run_udp_forwarder(
s: UdpSocket,
mut inbound_rx: mpsc::Receiver<Bytes>,
outbount_tx: mpsc::Sender<UdpTraffic>,
from: SocketAddr,
port_map: UdpPortMap,
) -> Result<()> {
debug!("Forwarder created");
let mut buf = BytesMut::new();
buf.resize(UDP_BUFFER_SIZE, 0);

Expand Down Expand Up @@ -336,6 +346,7 @@ async fn run_udp_forwarder(
let mut port_map = port_map.write().await;
port_map.remove(&from);

debug!("Forwarder dropped");
Ok(())
}

Expand Down
34 changes: 20 additions & 14 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use lazy_static::lazy_static;
use serde::{Deserialize, Serialize};
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tracing::trace;

type ProtocolVersion = u8;
const PROTO_V0: u8 = 0u8;
Expand Down Expand Up @@ -70,12 +71,14 @@ pub struct UdpTraffic {

impl UdpTraffic {
pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
let v = bincode::serialize(&UdpHeader {
let hdr = UdpHeader {
from: self.from,
len: self.data.len() as UdpPacketLen,
})
.unwrap();
};

let v = bincode::serialize(&hdr).unwrap();

trace!("Write {:?} of length {}", hdr, v.len());
writer.write_u16(v.len() as u16).await?;
writer.write_all(&v).await?;

Expand All @@ -90,12 +93,14 @@ impl UdpTraffic {
from: SocketAddr,
data: &[u8],
) -> Result<()> {
let v = bincode::serialize(&UdpHeader {
let hdr = UdpHeader {
from,
len: data.len() as UdpPacketLen,
})
.unwrap();
};

let v = bincode::serialize(&hdr).unwrap();

trace!("Write {:?} of length {}", hdr, v.len());
writer.write_u16(v.len() as u16).await?;
writer.write_all(&v).await?;

Expand All @@ -104,24 +109,25 @@ impl UdpTraffic {
Ok(())
}

pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
let len = reader.read_u16().await? as usize;

pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u16) -> Result<UdpTraffic> {
let mut buf = Vec::new();
buf.resize(len, 0);
buf.resize(hdr_len as usize, 0);
reader
.read_exact(&mut buf)
.await
.with_context(|| "Failed to read udp header")?;
let header: UdpHeader =
bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;

let hdr: UdpHeader =
bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?;

trace!("hdr {:?}", hdr);

let mut data = BytesMut::new();
data.resize(header.len as usize, 0);
data.resize(hdr.len as usize, 0);
reader.read_exact(&mut data).await?;

Ok(UdpTraffic {
from: header.from,
from: hdr.from,
data: data.freeze(),
})
}
Expand Down
93 changes: 59 additions & 34 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ use backoff::ExponentialBackoff;

use rand::RngCore;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{broadcast, mpsc, RwLock};
use tokio::time;
Expand Down Expand Up @@ -162,12 +161,12 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
}
Ok((conn, addr)) => {
backoff.reset();
debug!("Incomming connection from {}", addr);
debug!("Incoming connection from {}", addr);

let services = self.services.clone();
let control_channels = self.control_channels.clone();
tokio::spawn(async move {
if let Err(err) = handle_connection(conn, addr, services, control_channels).await.with_context(||"Failed to handle a connection to `server.bind_addr`") {
if let Err(err) = handle_connection(conn, services, control_channels).await {
error!("{:?}", err);
}
}.instrument(info_span!("handle_connection", %addr)));
Expand Down Expand Up @@ -215,16 +214,14 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
// Handle connections to `server.bind_addr`
async fn handle_connection<T: 'static + Transport>(
mut conn: T::Stream,
addr: SocketAddr,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
) -> Result<()> {
// Read hello
let hello = read_hello(&mut conn).await?;
match hello {
ControlChannelHello(_, service_digest) => {
do_control_channel_handshake(conn, addr, services, control_channels, service_digest)
.await?;
do_control_channel_handshake(conn, services, control_channels, service_digest).await?;
}
DataChannelHello(_, nonce) => {
do_data_channel_handshake(conn, control_channels, nonce).await?;
Expand All @@ -235,12 +232,11 @@ async fn handle_connection<T: 'static + Transport>(

async fn do_control_channel_handshake<T: 'static + Transport>(
mut conn: T::Stream,
addr: SocketAddr,
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
service_digest: ServiceDigest,
) -> Result<()> {
info!("New control channel incomming from {}", addr);
info!("New control channel incoming");

// Generate a nonce
let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
Expand Down Expand Up @@ -321,6 +317,8 @@ async fn do_data_channel_handshake<T: 'static + Transport>(
control_channels: Arc<RwLock<ControlChannelMap<T>>>,
nonce: Nonce,
) -> Result<()> {
info!("New control channel incoming");

// Validate
let control_channels_guard = control_channels.read().await;
match control_channels_guard.get2(&nonce) {
Expand Down Expand Up @@ -358,27 +356,6 @@ where
// Store data channel creation requests
let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();

match service.service_type {
ServiceType::Tcp => tokio::spawn(
run_tcp_connection_pool::<T>(
service.bind_addr.clone(),
data_ch_rx,
data_ch_req_tx.clone(),
shutdown_tx.subscribe(),
)
.instrument(Span::current()),
),
ServiceType::Udp => tokio::spawn(
run_udp_connection_pool::<T>(
service.bind_addr.clone(),
data_ch_rx,
data_ch_req_tx.clone(),
shutdown_tx.subscribe(),
)
.instrument(Span::current()),
),
};

// Cache some data channels for later use
let pool_size = match service.service_type {
ServiceType::Tcp => TCP_POOL_SIZE,
Expand All @@ -391,6 +368,43 @@ where
};
}

let shutdown_rx_clone = shutdown_tx.subscribe();
let bind_addr = service.bind_addr.clone();
match service.service_type {
ServiceType::Tcp => tokio::spawn(
async move {
if let Err(e) = run_tcp_connection_pool::<T>(
bind_addr,
data_ch_rx,
data_ch_req_tx,
shutdown_rx_clone,
)
.await
.with_context(|| "Failed to run TCP connection pool")
{
error!("{:?}", e);
}
}
.instrument(Span::current()),
),
ServiceType::Udp => tokio::spawn(
async move {
if let Err(e) = run_udp_connection_pool::<T>(
bind_addr,
data_ch_rx,
data_ch_req_tx,
shutdown_rx_clone,
)
.await
.with_context(|| "Failed to run TCP connection pool")
{
error!("{:?}", e);
}
}
.instrument(Span::current()),
),
};

// Create the control channel
let ch = ControlChannel::<T> {
conn,
Expand Down Expand Up @@ -568,7 +582,16 @@ async fn run_udp_connection_pool<T: Transport>(
// TODO: Load balance

let l: UdpSocket = backoff::future::retry(listen_backoff(), || async {
Ok(UdpSocket::bind(&bind_addr).await?)
Ok(match UdpSocket::bind(&bind_addr)
.await
.with_context(|| "Failed to listen for the service")
{
Err(e) => {
error!("{:?}", e);
Err(e)
}
v => v,
}?)
})
.await
.with_context(|| "Failed to listen for the service")?;
Expand All @@ -594,16 +617,18 @@ async fn run_udp_connection_pool<T: Transport>(
},

// Forward outbound traffic from the client to the visitor
t = UdpTraffic::read(&mut conn) => {
let t = t?;
hdr_len = conn.read_u16() => {
let t = UdpTraffic::read(&mut conn, hdr_len?).await?;
l.send_to(&t.data, t.from).await?;
},
}

_ = shutdown_rx.recv() => {
break;
}
}
}

debug!("UDP pool dropped");

Ok(())
}