From ab57cee04c99eb63ce711b320a6471ca85aec4c6 Mon Sep 17 00:00:00 2001 From: conblem Date: Wed, 13 Jan 2021 11:32:34 +0100 Subject: [PATCH] fix: work further on proxy proto --- src/api/mod.rs | 2 +- src/api/proxy.rs | 84 ++++++++++++++++++++++++++++++------------------ 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 44c5f09..f555b2e 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -5,9 +5,9 @@ use metrics::{metrics, metrics_wrapper}; use sqlx::PgPool; use tokio::net::TcpListener; use tokio::net::ToSocketAddrs; +use tokio_stream::wrappers::TcpListenerStream; use tracing::{debug_span, info}; use tracing_futures::Instrument; -use tokio_stream::wrappers::TcpListenerStream; mod metrics; mod proxy; diff --git a/src/api/proxy.rs b/src/api/proxy.rs index 703db13..c96cb80 100644 --- a/src/api/proxy.rs +++ b/src/api/proxy.rs @@ -1,5 +1,6 @@ use futures_util::future::{ready, BoxFuture, FutureExt}; use std::future::Future; +use std::io::{Cursor, IoSlice, Write}; use std::mem::MaybeUninit; use std::net::SocketAddr; use std::pin::Pin; @@ -11,70 +12,85 @@ pub(super) trait PeerAddr { fn proxy_peer<'a>(&'a mut self) -> BoxFuture<'a, Result>; } +impl PeerAddr for TcpStream { + fn proxy_peer(&mut self) -> BoxFuture> { + ready(self.peer_addr()).boxed() + } +} + struct PeerAddrFuture<'a> { - stream: Pin<&'a mut TcpStream>, - data: Vec, + stream: &'a mut ProxyStream, } impl<'a> PeerAddrFuture<'a> { - fn new(stream: &'a mut TcpStream) -> Self { - PeerAddrFuture { - stream: Pin::new(stream), - data: vec![], - } + fn new(stream: &'a mut ProxyStream) -> Self { + PeerAddrFuture { stream } } } impl<'a> Future for PeerAddrFuture<'a> { - type Output = Result; + type Output = IoResult; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let peer_addr_future: &mut Self = self.get_mut(); - let stream = peer_addr_future.stream.as_mut(); + let this = &mut self.get_mut().stream; + // add option again to make impossible to pull future later + let data = match &mut this.data { + Some(ref mut data) => data, + None => unreachable!("Future cannot be polled anymore"), + }; + let mut buf = [MaybeUninit::::uninit(); 1024]; let mut buf = ReadBuf::uninit(&mut buf); - let data = match stream.poll_read(cx, &mut buf) { + let stream = Pin::new(&mut this.stream); + let buf = match stream.poll_read(cx, &mut buf) { Poll::Ready(Ok(_)) => buf.filled_mut(), Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, }; - peer_addr_future.data.copy_from_slice(data); + if let Err(e) = data.write_all(buf) { + return Poll::Ready(Err(e)); + } Poll::Pending } } -impl PeerAddr for TcpStream { - fn proxy_peer(&mut self) -> BoxFuture> { - ready(self.peer_addr()).boxed() - } -} - pub(super) struct ProxyStream { stream: TcpStream, + data: Option>>, + start_of_data: usize, } impl From for ProxyStream { fn from(stream: TcpStream) -> Self { - ProxyStream { stream } + ProxyStream { + stream, + data: Some(Default::default()), + start_of_data: 0, + } } } impl PeerAddr for ProxyStream { - fn proxy_peer(&mut self) -> BoxFuture> { - PeerAddrFuture::new(&mut self.stream).boxed() + fn proxy_peer(&mut self) -> BoxFuture> { + PeerAddrFuture::new(self).boxed() } } impl AsyncRead for ProxyStream { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - Pin::new(&mut self.stream).poll_read(cx, buf) + let this = self.get_mut(); + // handle the case were the full data has no space in the first place + if let Some(data) = this.data.take() { + buf.put_slice(&data.get_ref()[this.start_of_data..]) + } + Pin::new(&mut this.stream).poll_read(cx, buf) } } @@ -83,21 +99,27 @@ impl AsyncWrite for ProxyStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], - ) -> Poll> { + ) -> Poll> { Pin::new(&mut self.stream).poll_write(cx, buf) } - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.stream).poll_flush(cx) } - fn poll_shutdown( + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } + + fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll> { - Pin::new(&mut self.stream).poll_shutdown(cx) + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() } }