From 48aa05622c9564db2754ec82857939ddbcb8de13 Mon Sep 17 00:00:00 2001 From: Diggory Blake Date: Sat, 13 Mar 2021 20:38:31 +0000 Subject: [PATCH] Fix bug when a read on a BufStream is cancelled. --- sqlx-core/src/io/buf_stream.rs | 47 ++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index 4646cb7a29..6b5b55a4ae 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -104,21 +104,48 @@ where } } +// Holds a buffer which has been temporarily extended, so that +// we can read into it. Automatically shrinks the buffer back +// down if the read is cancelled. +struct BufTruncator<'a> { + buf: &'a mut BytesMut, + filled_len: usize, +} + +impl<'a> BufTruncator<'a> { + fn new(buf: &'a mut BytesMut) -> Self { + let filled_len = buf.len(); + Self { buf, filled_len } + } + fn reserve(&mut self, space: usize) { + self.buf.resize(self.filled_len + space, 0); + } + async fn read(&mut self, stream: &mut S) -> Result { + let n = stream.read(&mut self.buf[self.filled_len..]).await?; + self.filled_len += n; + Ok(n) + } + fn is_full(&self) -> bool { + self.filled_len >= self.buf.len() + } +} + +impl Drop for BufTruncator<'_> { + fn drop(&mut self) { + self.buf.truncate(self.filled_len); + } +} + async fn read_raw_into( stream: &mut S, buf: &mut BytesMut, cnt: usize, ) -> Result<(), Error> { - let offset = buf.len(); - - // zero-fills the space in the read buffer - buf.resize(offset + cnt, 0); + let mut buf = BufTruncator::new(buf); + buf.reserve(cnt); - let mut read = offset; - while (offset + cnt) > read { - // read in bytes from the stream into the read buffer starting - // from the offset we last read from - let n = stream.read(&mut buf[read..]).await?; + while !buf.is_full() { + let n = buf.read(stream).await?; if n == 0 { // a zero read when we had space in the read buffer @@ -128,8 +155,6 @@ async fn read_raw_into( return Err(io::Error::from(io::ErrorKind::ConnectionAborted).into()); } - - read += n; } Ok(())