Skip to content

Commit

Permalink
keep track of all the used buffers in BufReaders so no data goes missing
Browse files Browse the repository at this point in the history
this is possible with the help of hyperium/hyper#1107
  • Loading branch information
illegalprime committed Apr 1, 2017
1 parent df581cf commit 134e3cb
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ keywords = ["websocket", "websockets", "rfc6455"]
license = "MIT"

[dependencies]
hyper = "^0.10"
hyper = { git = "https://github.com/hyperium/hyper.git", branch = "0.10.x" }
unicase = "^1.0"
url = "^1.0"
rustc-serialize = "^0.3"
Expand Down
11 changes: 6 additions & 5 deletions src/client/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ use hyper::header::{Headers, Host, Connection, ConnectionOption, Upgrade, Protoc
use unicase::UniCase;
#[cfg(feature="ssl")]
use openssl::ssl::{SslMethod, SslStream, SslConnector, SslConnectorBuilder};
#[cfg(feature="ssl")]
use header::extensions::Extension;
use header::{WebSocketAccept, WebSocketKey, WebSocketVersion, WebSocketProtocol,
WebSocketExtensions, Origin};
use result::{WSUrlErrorKind, WebSocketResult, WebSocketError};
use stream::{Stream, NetworkStream};
#[cfg(feature="ssl")]
use stream::NetworkStream;
use stream::Stream;
use super::Client;

/// Build clients with a builder-style API
Expand Down Expand Up @@ -251,8 +252,8 @@ impl<'u> ClientBuilder<'u> {
try!(write!(stream, "{}\r\n", self.headers));

// wait for a response
// TODO: some extra data might get lost with this reader, try to avoid #72
let response = try!(parse_response(&mut BufReader::new(&mut stream)));
let mut reader = BufReader::new(stream);
let response = try!(parse_response(&mut reader));
let status = StatusCode::from_u16(response.subject.0);

// validate
Expand Down Expand Up @@ -285,6 +286,6 @@ impl<'u> ClientBuilder<'u> {
return Err(WebSocketError::ResponseError("Connection field must be 'Upgrade'"));
}

Ok(Client::unchecked(stream, response.headers))
Ok(Client::unchecked(reader, response.headers))
}
}
47 changes: 28 additions & 19 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ extern crate url;
use std::net::TcpStream;
use std::net::SocketAddr;
use std::io::Result as IoResult;
use std::io::{Read, Write};
use hyper::header::Headers;
use hyper::buffer::BufReader;

use ws;
use ws::sender::Sender as SenderTrait;
Expand Down Expand Up @@ -56,7 +58,7 @@ pub use self::builder::{ClientBuilder, Url, ParseError};
pub struct Client<S>
where S: Stream
{
pub stream: S,
stream: BufReader<S>,
headers: Headers,
sender: Sender,
receiver: Receiver,
Expand All @@ -66,13 +68,13 @@ impl Client<TcpStream> {
/// Shuts down the sending half of the client connection, will cause all pending
/// and future IO to return immediately with an appropriate value.
pub fn shutdown_sender(&self) -> IoResult<()> {
self.stream.as_tcp().shutdown(Shutdown::Write)
self.stream.get_ref().as_tcp().shutdown(Shutdown::Write)
}

/// Shuts down the receiving half of the client connection, will cause all pending
/// and future IO to return immediately with an appropriate value.
pub fn shutdown_receiver(&self) -> IoResult<()> {
self.stream.as_tcp().shutdown(Shutdown::Read)
self.stream.get_ref().as_tcp().shutdown(Shutdown::Read)
}
}

Expand All @@ -82,27 +84,27 @@ impl<S> Client<S>
/// Shuts down the client connection, will cause all pending and future IO to
/// return immediately with an appropriate value.
pub fn shutdown(&self) -> IoResult<()> {
self.stream.as_tcp().shutdown(Shutdown::Both)
self.stream.get_ref().as_tcp().shutdown(Shutdown::Both)
}

/// See `TcpStream.peer_addr()`.
pub fn peer_addr(&self) -> IoResult<SocketAddr> {
self.stream.as_tcp().peer_addr()
self.stream.get_ref().as_tcp().peer_addr()
}

/// See `TcpStream.local_addr()`.
pub fn local_addr(&self) -> IoResult<SocketAddr> {
self.stream.as_tcp().local_addr()
self.stream.get_ref().as_tcp().local_addr()
}

/// See `TcpStream.set_nodelay()`.
pub fn set_nodelay(&mut self, nodelay: bool) -> IoResult<()> {
self.stream.as_tcp().set_nodelay(nodelay)
self.stream.get_ref().as_tcp().set_nodelay(nodelay)
}

/// Changes whether the stream is in nonblocking mode.
pub fn set_nonblocking(&self, nonblocking: bool) -> IoResult<()> {
self.stream.as_tcp().set_nonblocking(nonblocking)
self.stream.get_ref().as_tcp().set_nonblocking(nonblocking)
}
}

Expand All @@ -113,7 +115,7 @@ impl<S> Client<S>
/// **without sending any handshake** this is meant to only be used with
/// a stream that has a websocket connection already set up.
/// If in doubt, don't use this!
pub fn unchecked(stream: S, headers: Headers) -> Self {
pub fn unchecked(stream: BufReader<S>, headers: Headers) -> Self {
Client {
headers: headers,
stream: stream,
Expand All @@ -128,15 +130,15 @@ impl<S> Client<S>
pub fn send_dataframe<D>(&mut self, dataframe: &D) -> WebSocketResult<()>
where D: DataFrameable
{
self.sender.send_dataframe(&mut self.stream, dataframe)
self.sender.send_dataframe(self.stream.get_mut(), dataframe)
}

/// Sends a single message to the remote endpoint.
pub fn send_message<'m, M, D>(&mut self, message: &'m M) -> WebSocketResult<()>
where M: ws::Message<'m, D>,
D: DataFrameable
{
self.sender.send_message(&mut self.stream, message)
self.sender.send_message(self.stream.get_mut(), message)
}

/// Reads a single data frame from the remote endpoint.
Expand All @@ -145,7 +147,7 @@ impl<S> Client<S>
}

/// Returns an iterator over incoming data frames.
pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, S> {
pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader<S>> {
self.receiver.incoming_dataframes(&mut self.stream)
}

Expand Down Expand Up @@ -180,15 +182,20 @@ impl<S> Client<S>
}

pub fn stream_ref(&self) -> &S {
&self.stream
self.stream.get_ref()
}

pub fn stream_ref_mut(&mut self) -> &mut S {
pub fn writer_mut(&mut self) -> &mut Write {
self.stream.get_mut()
}

pub fn reader_mut(&mut self) -> &mut Read {
&mut self.stream
}

pub fn into_stream(self) -> S {
self.stream
pub fn into_stream(self) -> (S, Option<(Vec<u8>, usize, usize)>) {
let (stream, buf, pos, cap) = self.stream.into_parts();
(stream, Some((buf, pos, cap)))
}

/// Returns an iterator over incoming messages.
Expand Down Expand Up @@ -229,7 +236,8 @@ impl<S> Client<S>
///}
///# }
///```
pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, S>
pub fn incoming_messages<'a, M, D>(&'a mut self,)
-> MessageIterator<'a, Receiver, D, M, BufReader<S>>
where M: ws::Message<'a, D>,
D: DataFrameable
{
Expand Down Expand Up @@ -269,9 +277,10 @@ impl<S> Client<S>
pub fn split
(self,)
-> IoResult<(Reader<<S as Splittable>::Reader>, Writer<<S as Splittable>::Writer>)> {
let (read, write) = try!(self.stream.split());
let (stream, buf, pos, cap) = self.stream.into_parts();
let (read, write) = try!(stream.split());
Ok((Reader {
stream: read,
stream: BufReader::from_parts(read, buf, pos, cap),
receiver: self.receiver,
},
Writer {
Expand Down
13 changes: 8 additions & 5 deletions src/receiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
use std::io::Read;
use std::io::Result as IoResult;

use hyper::buffer::BufReader;

use dataframe::{DataFrame, Opcode};
use result::{WebSocketResult, WebSocketError};
use ws;
Expand All @@ -15,7 +17,7 @@ pub use stream::Shutdown;
pub struct Reader<R>
where R: Read
{
pub stream: R,
pub stream: BufReader<R>,
pub receiver: Receiver,
}

Expand All @@ -28,7 +30,7 @@ impl<R> Reader<R>
}

/// Returns an iterator over incoming data frames.
pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, R> {
pub fn incoming_dataframes<'a>(&'a mut self) -> DataFrameIterator<'a, Receiver, BufReader<R>> {
self.receiver.incoming_dataframes(&mut self.stream)
}

Expand All @@ -40,7 +42,8 @@ impl<R> Reader<R>
self.receiver.recv_message(&mut self.stream)
}

pub fn incoming_messages<'a, M, D>(&'a mut self) -> MessageIterator<'a, Receiver, D, M, R>
pub fn incoming_messages<'a, M, D>(&'a mut self,)
-> MessageIterator<'a, Receiver, D, M, BufReader<R>>
where M: ws::Message<'a, D>,
D: DataFrameable
{
Expand All @@ -54,13 +57,13 @@ impl<S> Reader<S>
/// Closes the receiver side of the connection, will cause all pending and future IO to
/// return immediately with an appropriate value.
pub fn shutdown(&self) -> IoResult<()> {
self.stream.as_tcp().shutdown(Shutdown::Read)
self.stream.get_ref().as_tcp().shutdown(Shutdown::Read)
}

/// Shuts down both Sender and Receiver, will cause all pending and future IO to
/// return immediately with an appropriate value.
pub fn shutdown_all(&self) -> IoResult<()> {
self.stream.as_tcp().shutdown(Shutdown::Both)
self.stream.get_ref().as_tcp().shutdown(Shutdown::Both)
}
}

Expand Down
12 changes: 9 additions & 3 deletions src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::convert::Into;
#[cfg(feature="ssl")]
use openssl::ssl::{SslStream, SslAcceptor};
use stream::Stream;
use self::upgrade::{WsUpgrade, IntoWs};
use self::upgrade::{WsUpgrade, IntoWs, Buffer};
pub use self::upgrade::{Request, HyperIntoWsError};

pub mod upgrade;
Expand All @@ -15,6 +15,7 @@ pub struct InvalidConnection<S>
{
pub stream: Option<S>,
pub parsed: Option<Request>,
pub buffer: Option<Buffer>,
pub error: HyperIntoWsError,
}

Expand Down Expand Up @@ -150,6 +151,7 @@ impl Server<SslAcceptor> {
return Err(InvalidConnection {
stream: None,
parsed: None,
buffer: None,
error: e.into(),
})
}
Expand All @@ -161,17 +163,19 @@ impl Server<SslAcceptor> {
return Err(InvalidConnection {
stream: None,
parsed: None,
buffer: None,
error: io::Error::new(io::ErrorKind::Other, err).into(),
})
}
};

match stream.into_ws() {
Ok(u) => Ok(u),
Err((s, r, e)) => {
Err((s, r, b, e)) => {
Err(InvalidConnection {
stream: Some(s),
parsed: r,
buffer: b,
error: e.into(),
})
}
Expand Down Expand Up @@ -213,17 +217,19 @@ impl Server<NoSslAcceptor> {
return Err(InvalidConnection {
stream: None,
parsed: None,
buffer: None,
error: e.into(),
})
}
};

match stream.into_ws() {
Ok(u) => Ok(u),
Err((s, r, e)) => {
Err((s, r, b, e)) => {
Err(InvalidConnection {
stream: Some(s),
parsed: r,
buffer: b,
error: e.into(),
})
}
Expand Down
12 changes: 9 additions & 3 deletions src/server/upgrade/hyper.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
extern crate hyper;

use hyper::net::NetworkStream;
use super::{IntoWs, WsUpgrade};
use super::{IntoWs, WsUpgrade, Buffer};

pub use hyper::http::h1::Incoming;
pub use hyper::method::Method;
Expand All @@ -28,12 +28,18 @@ impl<'a, 'b> IntoWs for HyperRequest<'a, 'b> {
let (_, method, headers, uri, version, reader) =
self.0.deconstruct();

// TODO: some extra data might get lost with this reader, try to avoid #72
let stream = reader.into_inner().get_mut();
let reader = reader.into_inner();
let (buf, pos, cap) = reader.take_buf();
let stream = reader.get_mut();

Ok(WsUpgrade {
headers: Headers::new(),
stream: stream,
buffer: Some(Buffer {
buf: buf,
pos: pos,
cap: cap,
}),
request: Incoming {
version: version,
headers: headers,
Expand Down
Loading

0 comments on commit 134e3cb

Please sign in to comment.