Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid race condition between pending frames and closing stream #156

Merged
merged 15 commits
May 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion test-harness/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ log = "0.4.17"
[dev-dependencies]
env_logger = "0.10"
constrained-connection = "0.1"

2 changes: 2 additions & 0 deletions yamux/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ nohash-hasher = "0.2"
parking_lot = "0.12"
rand = "0.8.3"
static_assertions = "1"
pin-project = "1.1.0"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using pin-project, we could as well implement Unpin for Stream.

impl Unpin for Stream {}

That said, I doubt there is any async Rust code-base without pin-project in its dependency tree already. Thus fine to include it here as well.


[dev-dependencies]
anyhow = "1"
Expand All @@ -26,6 +27,7 @@ quickcheck = "1.0"
tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] }
tokio-util = { version = "0.7", features = ["compat"] }
constrained-connection = "0.1"
futures_ringbuf = "0.3.1"

[[bench]]
name = "concurrent"
Expand Down
225 changes: 108 additions & 117 deletions yamux/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,18 @@ use crate::{
error::ConnectionError,
frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID},
frame::{self, Frame},
Config, WindowUpdateMode, DEFAULT_CREDIT, MAX_COMMAND_BACKLOG,
Config, WindowUpdateMode, DEFAULT_CREDIT,
};
use cleanup::Cleanup;
use closing::Closing;
use futures::stream::SelectAll;
use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
use nohash_hasher::IntMap;
use std::collections::VecDeque;
use std::task::Context;
use std::task::{Context, Waker};
use std::{fmt, sync::Arc, task::Poll};

use crate::tagged_stream::TaggedStream;
pub use stream::{Packet, State, Stream};

/// How the connection is used.
Expand Down Expand Up @@ -347,10 +349,11 @@ struct Active<T> {
config: Arc<Config>,
socket: Fuse<frame::Io<T>>,
next_id: u32,

streams: IntMap<StreamId, Stream>,
stream_sender: mpsc::Sender<StreamCommand>,
stream_receiver: mpsc::Receiver<StreamCommand>,
dropped_streams: Vec<StreamId>,
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
no_streams_waker: Option<Waker>,

pending_frames: VecDeque<Frame<()>>,
}

Expand All @@ -360,7 +363,7 @@ pub(crate) enum StreamCommand {
/// A new frame should be sent to the remote.
SendFrame(Frame<Either<Data, WindowUpdate>>),
/// Close a stream.
CloseStream { id: StreamId, ack: bool },
CloseStream { ack: bool },
}

/// Possible actions as a result of incoming frame handling.
Expand Down Expand Up @@ -408,28 +411,26 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn new(socket: T, cfg: Config, mode: Mode) -> Self {
let id = Id::random();
log::debug!("new connection: {} ({:?})", id, mode);
let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG);
let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse();
Active {
id,
mode,
config: Arc::new(cfg),
socket,
streams: IntMap::default(),
stream_sender,
stream_receiver,
stream_receivers: SelectAll::default(),
no_streams_waker: None,
next_id: match mode {
Mode::Client => 1,
Mode::Server => 2,
},
dropped_streams: Vec::new(),
pending_frames: VecDeque::default(),
}
}

/// Gracefully close the connection to the remote.
fn close(self) -> Closing<T> {
Closing::new(self.stream_receiver, self.pending_frames, self.socket)
Closing::new(self.stream_receivers, self.pending_frames, self.socket)
}

/// Cleanup all our resources.
Expand All @@ -438,13 +439,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
fn cleanup(mut self, error: ConnectionError) -> Cleanup {
self.drop_all_streams();

Cleanup::new(self.stream_receiver, error)
Cleanup::new(self.stream_receivers, error)
}

fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
loop {
self.garbage_collect();

if self.socket.poll_ready_unpin(cx).is_ready() {
if let Some(frame) = self.pending_frames.pop_front() {
self.socket.start_send_unpin(frame)?;
Expand All @@ -457,17 +456,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Poll::Pending => {}
}

match self.stream_receiver.poll_next_unpin(cx) {
Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
self.on_send_frame(frame);
match self.stream_receivers.poll_next_unpin(cx) {
Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
self.on_send_frame(frame.into());
continue;
}
Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
self.on_close_stream(id, ack);
continue;
}
Poll::Ready(Some((id, None))) => {
self.on_drop_stream(id);
continue;
}
Poll::Ready(None) => {
debug_assert!(false, "Only closed during shutdown")
self.no_streams_waker = Some(cx.waker().clone());
}
Poll::Pending => {}
}
Expand Down Expand Up @@ -508,16 +511,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
self.pending_frames.push_back(frame.into());
}

let stream = {
let config = self.config.clone();
let sender = self.stream_sender.clone();
let window = self.config.receive_window;
let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender);
if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}
stream
};
let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT);

if extra_credit == 0 {
stream.set_flag(stream::Flag::Syn)
}

log::debug!("{}: new outbound {} of {}", self.id, stream, self);
self.streams.insert(id, stream.clone());
Expand All @@ -541,6 +539,69 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
.push_back(Frame::close_stream(id, ack).into());
}

/// Handle the user dropping their last handle to a stream.
///
/// Called from the poll loop when the stream's command channel yields
/// `None`, i.e. the `Stream`'s sender half was dropped. Depending on the
/// state the stream was in, we may need to notify the remote with an RST
/// or FIN frame; the frame (if any) is queued on `pending_frames`.
fn on_drop_stream(&mut self, id: StreamId) {
    // The stream must still be tracked; its receiver yielding `None` is
    // the only path that removes it here.
    let stream = self.streams.remove(&id).expect("stream not found");

    log::trace!("{}: removing dropped {}", self.id, stream);
    let stream_id = stream.id();
    let frame = {
        let mut shared = stream.shared();
        let frame = match shared.update_state(self.id, stream_id, State::Closed) {
            // The stream was dropped without calling `poll_close`.
            // We reset the stream to inform the remote of the closure.
            State::Open => {
                let mut header = Header::data(stream_id, 0);
                header.rst();
                Some(Frame::new(header))
            }
            // The stream was dropped without calling `poll_close`.
            // We have already received a FIN from remote and send one
            // back which closes the stream for good.
            State::RecvClosed => {
                let mut header = Header::data(stream_id, 0);
                header.fin();
                Some(Frame::new(header))
            }
            // The stream was properly closed. We already sent our FIN frame.
            // The remote may be out of credit though and blocked on
            // writing more data. We may need to reset the stream.
            State::SendClosed => {
                if self.config.window_update_mode == WindowUpdateMode::OnRead
                    && shared.window == 0
                {
                    // The remote may be waiting for a window update
                    // which we will never send, so reset the stream now.
                    let mut header = Header::data(stream_id, 0);
                    header.rst();
                    Some(Frame::new(header))
                } else {
                    // The remote has either still credit or will be given more
                    // (due to an enqueued window update or because the update
                    // mode is `OnReceive`) or we already have inbound frames in
                    // the socket buffer which will be processed later. In any
                    // case we will reply with an RST in `Connection::on_data`
                    // because the stream will no longer be known.
                    None
                }
            }
            // The stream was properly closed. We already have sent our FIN frame. The
            // remote end has already done so in the past.
            State::Closed => None,
        };
        // Wake any parked reader/writer so they observe the closed state
        // instead of waiting forever.
        if let Some(w) = shared.reader.take() {
            w.wake()
        }
        if let Some(w) = shared.writer.take() {
            w.wake()
        }
        frame
    };
    if let Some(f) = frame {
        log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
        self.pending_frames.push_back(f.into());
    }
}

/// Process the result of reading from the socket.
///
/// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed
Expand Down Expand Up @@ -628,12 +689,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::internal_error());
}
let mut stream = {
let config = self.config.clone();
let credit = DEFAULT_CREDIT;
let sender = self.stream_sender.clone();
Stream::new(stream_id, self.id, config, credit, credit, sender)
};
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT);
let mut window_update = None;
{
let mut shared = stream.shared();
Expand Down Expand Up @@ -748,15 +804,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
log::error!("{}: maximum number of streams reached", self.id);
return Action::Terminate(Frame::protocol_error());
}
let stream = {
let credit = frame.header().credit() + DEFAULT_CREDIT;
let config = self.config.clone();
let sender = self.stream_sender.clone();
let mut stream =
Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender);
stream.set_flag(stream::Flag::Ack);
stream
};

let credit = frame.header().credit() + DEFAULT_CREDIT;
let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit);
stream.set_flag(stream::Flag::Ack);

if is_finish {
stream
.shared()
Expand Down Expand Up @@ -821,6 +873,18 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Action::None
}

/// Construct a new [`Stream`] handle for `id` and register its command
/// channel with the connection's receiver set.
fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
    // Per-stream command channel: the `Stream` handle sends, the
    // connection's poll loop receives. Capacity 10 is arbitrary.
    let (sender, receiver) = mpsc::channel(10);

    // Registering a new receiver may make the previously-empty `SelectAll`
    // yield again, so wake the poll loop if it parked itself on it.
    self.stream_receivers.push(TaggedStream::new(id, receiver));
    if let Some(waker) = self.no_streams_waker.take() {
        waker.wake();
    }

    Stream::new(id, self.id, self.config.clone(), window, credit, sender)
}

fn next_stream_id(&mut self) -> Result<StreamId> {
let proposed = StreamId::new(self.next_id);
self.next_id = self
Expand All @@ -844,79 +908,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
Mode::Server => id.is_client(),
}
}

/// Remove stale streams and create necessary messages to be sent to the remote.
///
/// A stream is considered stale when the connection holds the only
/// remaining reference to it (`strong_count() == 1`), i.e. the user has
/// dropped their handle. Depending on the stream's state we may need to
/// queue an RST or FIN frame for the remote.
fn garbage_collect(&mut self) {
    let conn_id = self.id;
    let win_update_mode = self.config.window_update_mode;
    for stream in self.streams.values_mut() {
        // More than one strong reference means the user still holds a
        // handle; nothing to collect for this stream.
        if stream.strong_count() > 1 {
            continue;
        }
        log::trace!("{}: removing dropped {}", conn_id, stream);
        let stream_id = stream.id();
        let frame = {
            let mut shared = stream.shared();
            let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
                // The stream was dropped without calling `poll_close`.
                // We reset the stream to inform the remote of the closure.
                State::Open => {
                    let mut header = Header::data(stream_id, 0);
                    header.rst();
                    Some(Frame::new(header))
                }
                // The stream was dropped without calling `poll_close`.
                // We have already received a FIN from remote and send one
                // back which closes the stream for good.
                State::RecvClosed => {
                    let mut header = Header::data(stream_id, 0);
                    header.fin();
                    Some(Frame::new(header))
                }
                // The stream was properly closed. We either already have
                // or will at some later point send our FIN frame.
                // The remote may be out of credit though and blocked on
                // writing more data. We may need to reset the stream.
                State::SendClosed => {
                    if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
                        // The remote may be waiting for a window update
                        // which we will never send, so reset the stream now.
                        let mut header = Header::data(stream_id, 0);
                        header.rst();
                        Some(Frame::new(header))
                    } else {
                        // The remote has either still credit or will be given more
                        // (due to an enqueued window update or because the update
                        // mode is `OnReceive`) or we already have inbound frames in
                        // the socket buffer which will be processed later. In any
                        // case we will reply with an RST in `Connection::on_data`
                        // because the stream will no longer be known.
                        None
                    }
                }
                // The stream was properly closed. We either already have
                // or will at some later point send our FIN frame. The
                // remote end has already done so in the past.
                State::Closed => None,
            };
            // Wake any parked reader/writer so they observe the closed state.
            if let Some(w) = shared.reader.take() {
                w.wake()
            }
            if let Some(w) = shared.writer.take() {
                w.wake()
            }
            frame
        };
        if let Some(f) = frame {
            log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
            self.pending_frames.push_back(f.into());
        }
        // Removal is deferred: we cannot mutate `self.streams` while
        // iterating over it, so collect the ids and remove below.
        self.dropped_streams.push(stream_id)
    }
    for id in self.dropped_streams.drain(..) {
        self.streams.remove(&id);
    }
}
}

impl<T> Active<T> {
Expand Down
Loading