From 134e3cb6e665a331d0dc465d6df6e8ceeefff305 Mon Sep 17 00:00:00 2001 From: Michael Eden Date: Thu, 30 Mar 2017 21:46:15 -0400 Subject: [PATCH] keep track of all the used buffers in BufReaders so no data goes missing this is possible with the help of hyperium/hyper#1107 --- Cargo.toml | 2 +- src/client/builder.rs | 11 +++++---- src/client/mod.rs | 47 ++++++++++++++++++++++--------------- src/receiver.rs | 13 ++++++---- src/server/mod.rs | 12 +++++++--- src/server/upgrade/hyper.rs | 12 +++++++--- src/server/upgrade/mod.rs | 40 ++++++++++++++++++++++--------- 7 files changed, 90 insertions(+), 47 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6fed692717..d280de4bc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,7 @@ keywords = ["websocket", "websockets", "rfc6455"] license = "MIT" [dependencies] -hyper = "^0.10" +hyper = { git = "https://github.com/hyperium/hyper.git", branch = "0.10.x" } unicase = "^1.0" url = "^1.0" rustc-serialize = "^0.3" diff --git a/src/client/builder.rs b/src/client/builder.rs index 20aa001a0b..52e20b773a 100644 --- a/src/client/builder.rs +++ b/src/client/builder.rs @@ -10,12 +10,13 @@ use hyper::header::{Headers, Host, Connection, ConnectionOption, Upgrade, Protoc use unicase::UniCase; #[cfg(feature="ssl")] use openssl::ssl::{SslMethod, SslStream, SslConnector, SslConnectorBuilder}; -#[cfg(feature="ssl")] use header::extensions::Extension; use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol, WebSocketExtensions, Origin}; use result::{WSUrlErrorKind, WebSocketResult, WebSocketError}; -use stream::{Stream, NetworkStream}; +#[cfg(feature="ssl")] +use stream::NetworkStream; +use stream::Stream; use super::Client; /// Build clients with a builder-style API @@ -251,8 +252,8 @@ impl<'u> ClientBuilder<'u> { try!(write!(stream, "{}\r\n", self.headers)); // wait for a response - // TODO: some extra data might get lost with this reader, try to avoid #72 - let response = try!(parse_response(&mut BufReader::new(&mut stream))); + let mut reader = BufReader::new(stream); + let response = try!(parse_response(&mut reader)); let status = StatusCode::from_u16(response.subject.0); // validate @@ -285,6 +286,6 @@ impl<'u> ClientBuilder<'u> { return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'")); } - Ok(Client::unchecked(stream, response.headers)) + Ok(Client::unchecked(reader, response.headers)) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 8295226560..4210515d11 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -4,7 +4,9 @@ extern crate url; use std::net::TcpStream; use std::net::SocketAddr; use std::io::Result as IoResult; +use std::io::{Read, Write}; use hyper::header::Headers; +use hyper::buffer::BufReader; use ws; use ws::sender::Sender as SenderTrait; @@ -56,7 +58,7 @@ pub use self::builder::{ClientBuilder, Url, ParseError}; pub struct Client where S: Stream { - pub stream: S, + stream: BufReader, headers: Headers, sender: Sender, receiver: Receiver, @@ -66,13 +68,13 @@ impl Client { /// Shuts down the sending half of the client connection, will cause all pending /// and future IO to return immediately with an appropriate value. pub fn shutdown_sender(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Write) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Write) } /// Shuts down the receiving half of the client connection, will cause all pending /// and future IO to return immediately with an appropriate value. pub fn shutdown_receiver(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Read) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Read) } } @@ -82,27 +84,27 @@ impl Client /// Shuts down the client connection, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Both) } /// See `TcpStream.peer_addr()`. pub fn peer_addr(&self) -> IoResult { - self.stream.as_tcp().peer_addr() + self.stream.get_ref().as_tcp().peer_addr() } /// See `TcpStream.local_addr()`. pub fn local_addr(&self) -> IoResult { - self.stream.as_tcp().local_addr() + self.stream.get_ref().as_tcp().local_addr() } /// See `TcpStream.set_nodelay()`. pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> { - self.stream.as_tcp().set_nodelay(nodelay) + self.stream.get_ref().as_tcp().set_nodelay(nodelay) } /// Changes whether the stream is in nonblocking mode. pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> { - self.stream.as_tcp().set_nonblocking(nonblocking) + self.stream.get_ref().as_tcp().set_nonblocking(nonblocking) } } @@ -113,7 +115,7 @@ impl Client /// **without sending any handshake** this is meant to only be used with /// a stream that has a websocket connection already set up. /// If in doubt, don't use this! - pub fn unchecked(stream: S, headers: Headers) -> Self { + pub fn unchecked(stream: BufReader, headers: Headers) -> Self { Client { headers: headers, stream: stream, @@ -128,7 +130,7 @@ impl Client pub fn send_dataframe(&mut self, dataframe: &D) -> WebSocketResult<()> where D: DataFrameable { - self.sender.send_dataframe(&mut self.stream, dataframe) + self.sender.send_dataframe(self.stream.get_mut(), dataframe) } /// Sends a single message to the remote endpoint. @@ -136,7 +138,7 @@ impl Client where M: ws::Message<'m, D>, D: DataFrameable { - self.sender.send_message(&mut self.stream, message) + self.sender.send_message(self.stream.get_mut(), message) } /// Reads a single data frame from the remote endpoint. @@ -145,7 +147,7 @@ impl Client } /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S> { + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader> { self.receiver.incoming_dataframes(&mut self.stream) } @@ -180,15 +182,20 @@ impl Client } pub fn stream_ref(&self) -> &S { - &self.stream + self.stream.get_ref() } - pub fn stream_ref_mut(&mut self) -> &mut S { + pub fn writer_mut(&mut self) -> &mut Write { + self.stream.get_mut() + } + + pub fn reader_mut(&mut self) -> &mut Read { &mut self.stream } - pub fn into_stream(self) -> S { - self.stream + pub fn into_stream(self) -> (S, Option<(Vec, usize, usize)>) { + let (stream, buf, pos, cap) = self.stream.into_parts(); + (stream, Some((buf, pos, cap))) } /// Returns an iterator over incoming messages. @@ -229,7 +236,8 @@ impl Client ///} ///# } ///``` - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S> + pub fn incoming_messages<'a, M, D>(&'a mut self,) + -> MessageIterator<'a, Receiver, D, M, BufReader> where M: ws::Message<'a, D>, D: DataFrameable { @@ -269,9 +277,10 @@ impl Client pub fn split (self,) -> IoResult<(Reader<::Reader>, Writer<::Writer>)> { - let (read, write) = try!(self.stream.split()); + let (stream, buf, pos, cap) = self.stream.into_parts(); + let (read, write) = try!(stream.split()); Ok((Reader { - stream: read, + stream: BufReader::from_parts(read, buf, pos, cap), receiver: self.receiver, }, Writer { diff --git a/src/receiver.rs b/src/receiver.rs index 6c55b1af9f..0f1bd25e54 100644 --- a/src/receiver.rs +++ b/src/receiver.rs @@ -3,6 +3,8 @@ use std::io::Read; use std::io::Result as IoResult; +use hyper::buffer::BufReader; + use dataframe::{DataFrame, Opcode}; use result::{WebSocketResult, WebSocketError}; use ws; @@ -15,7 +17,7 @@ pub use stream::Shutdown; pub struct Reader where R: Read { - pub stream: R, + pub stream: BufReader, pub receiver: Receiver, } @@ -28,7 +30,7 @@ impl Reader } /// Returns an iterator over incoming data frames. - pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, R> { + pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader> { self.receiver.incoming_dataframes(&mut self.stream) } @@ -40,7 +42,8 @@ impl Reader self.receiver.recv_message(&mut self.stream) } - pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, R> + pub fn incoming_messages<'a, M, D>(&'a mut self,) + -> MessageIterator<'a, Receiver, D, M, BufReader> where M: ws::Message<'a, D>, D: DataFrameable { @@ -54,13 +57,13 @@ impl Reader /// Closes the receiver side of the connection, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Read) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Read) } /// Shuts down both Sender and Receiver, will cause all pending and future IO to /// return immediately with an appropriate value. pub fn shutdown_all(&self) -> IoResult<()> { - self.stream.as_tcp().shutdown(Shutdown::Both) + self.stream.get_ref().as_tcp().shutdown(Shutdown::Both) } } diff --git a/src/server/mod.rs b/src/server/mod.rs index 5a8af0be3d..dc3cbd283d 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -5,7 +5,7 @@ use std::convert::Into; #[cfg(feature="ssl")] use openssl::ssl::{SslStream, SslAcceptor}; use stream::Stream; -use self::upgrade::{WsUpgrade, IntoWs}; +use self::upgrade::{WsUpgrade, IntoWs, Buffer}; pub use self::upgrade::{Request, HyperIntoWsError}; pub mod upgrade; @@ -15,6 +15,7 @@ pub struct InvalidConnection { pub stream: Option, pub parsed: Option, + pub buffer: Option, pub error: HyperIntoWsError, } @@ -150,6 +151,7 @@ impl Server { return Err(InvalidConnection { stream: None, parsed: None, + buffer: None, error: e.into(), }) } @@ -161,6 +163,7 @@ impl Server { return Err(InvalidConnection { stream: None, parsed: None, + buffer: None, error: io::Error::new(io::ErrorKind::Other, err).into(), }) } @@ -168,10 +171,11 @@ impl Server { match stream.into_ws() { Ok(u) => Ok(u), - Err((s, r, e)) => { + Err((s, r, b, e)) => { Err(InvalidConnection { stream: Some(s), parsed: r, + buffer: b, error: e.into(), }) } @@ -213,6 +217,7 @@ impl Server { return Err(InvalidConnection { stream: None, parsed: None, + buffer: None, error: e.into(), }) } @@ -220,10 +225,11 @@ impl Server { match stream.into_ws() { Ok(u) => Ok(u), - Err((s, r, e)) => { + Err((s, r, b, e)) => { Err(InvalidConnection { stream: Some(s), parsed: r, + buffer: b, error: e.into(), }) } diff --git a/src/server/upgrade/hyper.rs b/src/server/upgrade/hyper.rs index a70c06182e..1e44a6e1a2 100644 --- a/src/server/upgrade/hyper.rs +++ b/src/server/upgrade/hyper.rs @@ -1,7 +1,7 @@ extern crate hyper; use hyper::net::NetworkStream; -use super::{IntoWs, WsUpgrade}; +use super::{IntoWs, WsUpgrade, Buffer}; pub use hyper::http::h1::Incoming; pub use hyper::method::Method; @@ -28,12 +28,18 @@ impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> { let (_, method, headers, uri, version, reader) = self.0.deconstruct(); - // TODO: some extra data might get lost with this reader, try to avoid #72 - let stream = reader.into_inner().get_mut(); + let reader = reader.into_inner(); + let (buf, pos, cap) = reader.take_buf(); + let stream = reader.get_mut(); Ok(WsUpgrade { headers: Headers::new(), stream: stream, + buffer: Some(Buffer { + buf: buf, + pos: pos, + cap: cap, + }), request: Incoming { version: version, headers: headers, diff --git a/src/server/upgrade/mod.rs b/src/server/upgrade/mod.rs index 1c75f1744f..4aeb9ba0ab 100644 --- a/src/server/upgrade/mod.rs +++ b/src/server/upgrade/mod.rs @@ -27,6 +27,12 @@ pub use self::real_hyper::header::{Headers, Upgrade, Protocol, ProtocolName, Con pub mod hyper; +pub struct Buffer { + pub buf: Vec, + pub pos: usize, + pub cap: usize, +} + /// Intermediate representation of a half created websocket session. /// Should be used to examine the client's handshake /// accept the protocols requested, route the path, etc. @@ -39,6 +45,7 @@ pub struct WsUpgrade pub headers: Headers, pub stream: S, pub request: Request, + pub buffer: Option, } impl WsUpgrade @@ -94,7 +101,12 @@ impl WsUpgrade return Err((self.stream, e)); } - Ok(Client::unchecked(self.stream, self.headers)) + let stream = match self.buffer { + Some(Buffer { buf, pos, cap }) => BufReader::from_parts(self.stream, buf, pos, cap), + None => BufReader::new(self.stream), + }; + + Ok(Client::unchecked(stream, self.headers)) } pub fn reject(self) -> Result { @@ -186,29 +198,34 @@ impl IntoWs for S where S: Stream { type Stream = S; - type Error = (Self, Option, HyperIntoWsError); + type Error = (S, Option, Option, HyperIntoWsError); - fn into_ws(mut self) -> Result, Self::Error> { - let request = { - // TODO: some extra data might get lost with this reader, try to avoid #72 - let mut reader = BufReader::new(&mut self); - parse_request(&mut reader) - }; + fn into_ws(self) -> Result, Self::Error> { + let mut reader = BufReader::new(self); + let request = parse_request(&mut reader); + + let (stream, buf, pos, cap) = reader.into_parts(); + let buffer = Some(Buffer { + buf: buf, + cap: cap, + pos: pos, + }); let request = match request { Ok(r) => r, - Err(e) => return Err((self, None, e.into())), + Err(e) => return Err((stream, None, buffer, e.into())), }; match validate(&request.subject.0, &request.version, &request.headers) { Ok(_) => { Ok(WsUpgrade { headers: Headers::new(), - stream: self, + stream: stream, request: request, + buffer: buffer, }) } - Err(e) => Err((self, Some(request), e)), + Err(e) => Err((stream, Some(request), buffer, e)), } } } @@ -226,6 +243,7 @@ impl IntoWs for RequestStreamPair headers: Headers::new(), stream: self.0, request: self.1, + buffer: None, }) } Err(e) => Err((self.0, self.1, e)),