Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug when a read on a BufStream is cancelled. #1099

Merged
merged 1 commit into from
Apr 9, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions sqlx-core/src/io/buf_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,48 @@ where
}
}

// Holds a buffer which has been temporarily extended, so that
// we can read into it. Automatically shrinks the buffer back
// down if the read is cancelled.
struct BufTruncator<'a> {
buf: &'a mut BytesMut,
filled_len: usize,
}

impl<'a> BufTruncator<'a> {
fn new(buf: &'a mut BytesMut) -> Self {
let filled_len = buf.len();
Self { buf, filled_len }
}
fn reserve(&mut self, space: usize) {
self.buf.resize(self.filled_len + space, 0);
}
async fn read<S: AsyncRead + Unpin>(&mut self, stream: &mut S) -> Result<usize, Error> {
let n = stream.read(&mut self.buf[self.filled_len..]).await?;
self.filled_len += n;
Ok(n)
}
fn is_full(&self) -> bool {
self.filled_len >= self.buf.len()
}
}

impl Drop for BufTruncator<'_> {
fn drop(&mut self) {
self.buf.truncate(self.filled_len);
}
}

async fn read_raw_into<S: AsyncRead + Unpin>(
stream: &mut S,
buf: &mut BytesMut,
cnt: usize,
) -> Result<(), Error> {
let offset = buf.len();

// zero-fills the space in the read buffer
buf.resize(offset + cnt, 0);
let mut buf = BufTruncator::new(buf);
buf.reserve(cnt);

let mut read = offset;
while (offset + cnt) > read {
// read in bytes from the stream into the read buffer starting
// from the offset we last read from
let n = stream.read(&mut buf[read..]).await?;
while !buf.is_full() {
let n = buf.read(stream).await?;

if n == 0 {
// a zero read when we had space in the read buffer
Expand All @@ -128,8 +155,6 @@ async fn read_raw_into<S: AsyncRead + Unpin>(

return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into());
}

read += n;
}

Ok(())
Expand Down