Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the close reason available to application code #29

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/autobahn_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async fn run_case(n: usize) -> Result<(), BoxedError> {
sender.send_text(std::str::from_utf8(&message)?).await?;
sender.flush().await?
}
Err(connection::Error::Closed) => return Ok(()),
Err(connection::Error::Closed(_)) => return Ok(()),
Err(e) => return Err(e.into())
}
}
Expand Down Expand Up @@ -97,4 +97,3 @@ fn new_client(socket: TcpStream, path: &str) -> handshake::Client<'_, BufReader<
client.add_extension(Box::new(deflate));
client
}

3 changes: 1 addition & 2 deletions examples/autobahn_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async fn main() -> Result<(), BoxedError> {
break
}
}
Err(connection::Error::Closed) => break,
Err(connection::Error::Closed(_)) => break,
Err(e) => {
log::error!("connection error: {}", e);
break
Expand All @@ -74,4 +74,3 @@ fn new_server<'a>(socket: TcpStream) -> handshake::Server<'a, BufReader<BufWrite
server.add_extension(Box::new(deflate));
server
}

84 changes: 56 additions & 28 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,24 +207,25 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Receiver<T> {
let message_len = message.len();
loop {
if self.is_closed {
log::debug!("{}: can not receive, connection is closed", self.id);
return Err(Error::Closed)
log::debug!("{}: cannot receive, connection is closed", self.id);
return Err(Error::Closed(None));
}

self.ctrl_buffer.clear();
let mut header = self.receive_header().await?;
log::trace!("{}: recv: {}", self.id, header);

// Handle control frames.
// Handle control frames: PING, PONG and CLOSE.
if header.opcode().is_control() {
self.read_buffer(&header).await?;
self.ctrl_buffer = self.buffer.split_to(header.payload_len());
base::Codec::apply_mask(&header, &mut self.ctrl_buffer);
if header.opcode() == OpCode::Pong {
return Ok(Incoming::Pong(&self.ctrl_buffer[..]))
return Ok(Incoming::Pong(&self.ctrl_buffer[..]));
}
self.on_control(&header).await?;
continue

continue;
}

length = length.saturating_add(header.payload_len());
Expand Down Expand Up @@ -351,6 +352,10 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Receiver<T> {
}

/// Answer incoming control frames.
/// `PING`: replied to immediately with a `PONG`
/// `PONG`: no action
/// `CLOSE`: replied to immediately with a `CLOSE`; returns an [`Error::Closed`] with the [`CloseReason`]
/// All other [`OpCode`]s return [`Error::UnexpectedOpCode`]
async fn on_control(&mut self, header: &Header) -> Result<(), Error> {
match header.opcode() {
OpCode::Ping => {
Expand All @@ -363,19 +368,22 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Receiver<T> {
}
OpCode::Pong => Ok(()),
OpCode::Close => {
log::trace!("Acknowledging close to sender");
self.is_closed = true;
let (mut header, code) = close_answer(&self.ctrl_buffer)?;
let (mut header, reason) = close_answer(&self.ctrl_buffer)?;
// Write back a Close frame
let mut unused = Vec::new();
if let Some(c) = code {
let mut data = c.to_be_bytes();
if let Some(CloseReason { code, .. }) = reason {
let mut data = code.to_be_bytes();
let mut data = Storage::Unique(&mut data);
write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused).await?
} else {
let mut data = Storage::Unique(&mut []);
write(self.id, self.mode, &mut self.codec, &mut self.writer, &mut header, &mut data, &mut unused).await?
}
self.flush().await?;
self.writer.lock().await.close().await.or(Err(Error::Closed))
self.writer.lock().await.close().await?;
Err(Error::Closed(reason))
}
OpCode::Binary
| OpCode::Text
Expand Down Expand Up @@ -411,7 +419,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Receiver<T> {
if self.is_closed {
return Ok(())
}
self.writer.lock().await.flush().await.or(Err(Error::Closed))
self.writer.lock().await.flush().await.or(Err(Error::Closed(None)))
}
}

Expand Down Expand Up @@ -452,7 +460,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Sender<T> {
/// Flush the socket buffer.
pub async fn flush(&mut self) -> Result<(), Error> {
log::trace!("{}: flushing connection", self.id);
self.writer.lock().await.flush().await.or(Err(Error::Closed))
self.writer.lock().await.flush().await.or(Err(Error::Closed(None)))
}

/// Send a close message and close the connection.
Expand All @@ -462,7 +470,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Sender<T> {
let code = 1000_u16.to_be_bytes(); // 1000 = normal closure
self.write(&mut header, &mut Storage::Shared(&code[..])).await?;
self.flush().await?;
self.writer.lock().await.close().await.or(Err(Error::Closed))
self.writer.lock().await.close().await.or(Err(Error::Closed(None)))
}

/// Send arbitrary websocket frames.
Expand Down Expand Up @@ -511,44 +519,57 @@ async fn write<T: AsyncWrite + Unpin>

let header_bytes = codec.encode_header(&header);
let mut w = writer.lock().await;
w.write_all(&header_bytes).await.or(Err(Error::Closed))?;
w.write_all(&header_bytes).await.or(Err(Error::Closed(None)))?;

if !header.is_masked() {
return w.write_all(data.as_ref()).await.or(Err(Error::Closed))
return w.write_all(data.as_ref()).await.or(Err(Error::Closed(None)))
}

match data {
Storage::Shared(slice) => {
mask_buffer.clear();
mask_buffer.extend_from_slice(slice);
base::Codec::apply_mask(header, mask_buffer);
w.write_all(mask_buffer).await.or(Err(Error::Closed))
w.write_all(mask_buffer).await.or(Err(Error::Closed(None)))
}
Storage::Unique(slice) => {
base::Codec::apply_mask(header, slice);
w.write_all(slice).await.or(Err(Error::Closed))
w.write_all(slice).await.or(Err(Error::Closed(None)))
}
Storage::Owned(ref mut bytes) => {
base::Codec::apply_mask(header, bytes);
w.write_all(bytes).await.or(Err(Error::Closed))
w.write_all(bytes).await.or(Err(Error::Closed(None)))
}
}
}

/// Create a close frame based on the given data.
fn close_answer(data: &[u8]) -> Result<(Header, Option<u16>), Error> {
/// Create a close frame based on the given data. The close frame is echoed back
/// to the sender.
fn close_answer(data: &[u8]) -> Result<(Header, Option<CloseReason>), Error> {
let answer = Header::new(OpCode::Close);
if data.len() < 2 {
return Ok((answer, None))
return Ok((answer, None));
}
std::str::from_utf8(&data[2 ..])?; // check reason is properly encoded
// Check that the reason string is properly encoded
let descr = std::str::from_utf8(&data[2..])?.into();
let code = u16::from_be_bytes([data[0], data[1]]);
let reason = CloseReason { code, descr: Some(descr) };
log::trace!("Closing reason: {:?}", reason);

// Status codes are defined in
// https://tools.ietf.org/html/rfc6455#section-7.4.1 and
// https://mailarchive.ietf.org/arch/msg/hybi/P_1vbD9uyHl63nbIIbFxKMfSwcM/
match code {
| 1000 ..= 1003
| 1007 ..= 1011
| 1012 // Service Restart
| 1013 // Try Again Later
| 1015
| 3000 ..= 4999 => Ok((answer, Some(code))), // acceptable codes
_ => Ok((answer, Some(1002))) // invalid code => protocol error (1002)
| 3000 ..= 4999 => Ok((answer, Some(reason))), // acceptable codes
_ => {
// invalid code => protocol error (1002)
Ok((answer, Some(CloseReason { code: 1002, descr: None})))
}
}
}

Expand All @@ -569,7 +590,14 @@ pub enum Error {
/// The total message payload data size exceeds the configured maximum.
MessageTooLarge { current: usize, maximum: usize },
/// The connection is closed.
Closed
Closed(Option<CloseReason>),
}

/// Reason for closing the connection.
#[derive(Debug)]
pub struct CloseReason {
code: u16,
descr: Option<String>,
Comment on lines +599 to +600
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you want to access these fields

Suggested change
code: u16,
descr: Option<String>,
pub code: u16,
pub descr: Option<String>,

}

impl fmt::Display for Error {
Expand All @@ -587,8 +615,8 @@ impl fmt::Display for Error {
write!(f, "utf-8 error: {}", e),
Error::MessageTooLarge { current, maximum } =>
write!(f, "message too large: len >= {}, maximum = {}", current, maximum),
Error::Closed =>
f.write_str("connection closed")
Error::Closed(reason) =>
write!(f, "connection closed (reason: {:?})", reason)
}
}
}
Expand All @@ -602,7 +630,7 @@ impl std::error::Error for Error {
Error::Utf8(e) => Some(e),
Error::UnexpectedOpCode(_)
| Error::MessageTooLarge {..}
| Error::Closed
| Error::Closed(_)
=> None
}
}
Expand All @@ -611,7 +639,7 @@ impl std::error::Error for Error {
impl From<io::Error> for Error {
fn from(e: io::Error) -> Self {
if e.kind() == io::ErrorKind::UnexpectedEof {
Error::Closed
Error::Closed(None)
} else {
Error::Io(e)
}
Expand Down