Skip to content

Commit

Permalink
fix: work further on proxy proto
Browse files Browse the repository at this point in the history
  • Loading branch information
conblem committed Jan 13, 2021
1 parent 4438d44 commit ab57cee
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ use metrics::{metrics, metrics_wrapper};
use sqlx::PgPool;
use tokio::net::TcpListener;
use tokio::net::ToSocketAddrs;
use tokio_stream::wrappers::TcpListenerStream;
use tracing::{debug_span, info};
use tracing_futures::Instrument;
use tokio_stream::wrappers::TcpListenerStream;

mod metrics;
mod proxy;
Expand Down
84 changes: 53 additions & 31 deletions src/api/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use futures_util::future::{ready, BoxFuture, FutureExt};
use std::future::Future;
use std::io::{Cursor, IoSlice, Write};
use std::mem::MaybeUninit;
use std::net::SocketAddr;
use std::pin::Pin;
Expand All @@ -11,70 +12,85 @@ pub(super) trait PeerAddr<E: std::error::Error> {
fn proxy_peer<'a>(&'a mut self) -> BoxFuture<'a, Result<SocketAddr, E>>;
}

impl PeerAddr<tokio::io::Error> for TcpStream {
fn proxy_peer(&mut self) -> BoxFuture<IoResult<SocketAddr>> {
ready(self.peer_addr()).boxed()
}
}

struct PeerAddrFuture<'a> {
stream: Pin<&'a mut TcpStream>,
data: Vec<u8>,
stream: &'a mut ProxyStream,
}

impl<'a> PeerAddrFuture<'a> {
fn new(stream: &'a mut TcpStream) -> Self {
PeerAddrFuture {
stream: Pin::new(stream),
data: vec![],
}
fn new(stream: &'a mut ProxyStream) -> Self {
PeerAddrFuture { stream }
}
}

impl<'a> Future for PeerAddrFuture<'a> {
type Output = Result<SocketAddr, tokio::io::Error>;
type Output = IoResult<SocketAddr>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let peer_addr_future: &mut Self = self.get_mut();
let stream = peer_addr_future.stream.as_mut();
let this = &mut self.get_mut().stream;
// add option again to make impossible to pull future later
let data = match &mut this.data {
Some(ref mut data) => data,
None => unreachable!("Future cannot be polled anymore"),
};

let mut buf = [MaybeUninit::<u8>::uninit(); 1024];
let mut buf = ReadBuf::uninit(&mut buf);

let data = match stream.poll_read(cx, &mut buf) {
let stream = Pin::new(&mut this.stream);
let buf = match stream.poll_read(cx, &mut buf) {
Poll::Ready(Ok(_)) => buf.filled_mut(),
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
};

peer_addr_future.data.copy_from_slice(data);
if let Err(e) = data.write_all(buf) {
return Poll::Ready(Err(e));
}

Poll::Pending
}
}

impl PeerAddr<tokio::io::Error> for TcpStream {
fn proxy_peer(&mut self) -> BoxFuture<Result<SocketAddr, tokio::io::Error>> {
ready(self.peer_addr()).boxed()
}
}

pub(super) struct ProxyStream {
stream: TcpStream,
data: Option<Cursor<Vec<u8>>>,
start_of_data: usize,
}

impl From<TcpStream> for ProxyStream {
fn from(stream: TcpStream) -> Self {
ProxyStream { stream }
ProxyStream {
stream,
data: Some(Default::default()),
start_of_data: 0,
}
}
}

impl PeerAddr<tokio::io::Error> for ProxyStream {
fn proxy_peer(&mut self) -> BoxFuture<Result<SocketAddr, tokio::io::Error>> {
PeerAddrFuture::new(&mut self.stream).boxed()
fn proxy_peer(&mut self) -> BoxFuture<IoResult<SocketAddr>> {
PeerAddrFuture::new(self).boxed()
}
}

impl AsyncRead for ProxyStream {
fn poll_read(
mut self: Pin<&mut Self>,
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<IoResult<()>> {
Pin::new(&mut self.stream).poll_read(cx, buf)
let this = self.get_mut();
// handle the case were the full data has no space in the first place
if let Some(data) = this.data.take() {
buf.put_slice(&data.get_ref()[this.start_of_data..])
}
Pin::new(&mut this.stream).poll_read(cx, buf)
}
}

Expand All @@ -83,21 +99,27 @@ impl AsyncWrite for ProxyStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, tokio::io::Error>> {
) -> Poll<IoResult<usize>> {
Pin::new(&mut self.stream).poll_write(cx, buf)
}

fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), tokio::io::Error>> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
Pin::new(&mut self.stream).poll_flush(cx)
}

fn poll_shutdown(
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<()>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), tokio::io::Error>> {
Pin::new(&mut self.stream).poll_shutdown(cx)
bufs: &[IoSlice<'_>],
) -> Poll<IoResult<usize>> {
Pin::new(&mut self.stream).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}
}

0 comments on commit ab57cee

Please sign in to comment.