diff --git a/tokio/src/io/util/lines.rs b/tokio/src/io/util/lines.rs index be4e86648d8..ee27400c9de 100644 --- a/tokio/src/io/util/lines.rs +++ b/tokio/src/io/util/lines.rs @@ -91,6 +91,7 @@ where let me = self.project(); let n = ready!(read_line_internal(me.reader, cx, me.buf, me.bytes, me.read))?; + debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { return Poll::Ready(Ok(None)); diff --git a/tokio/src/io/util/read_line.rs b/tokio/src/io/util/read_line.rs index c5ee597486f..d625a76b80a 100644 --- a/tokio/src/io/util/read_line.rs +++ b/tokio/src/io/util/read_line.rs @@ -5,7 +5,6 @@ use std::future::Future; use std::io; use std::mem; use std::pin::Pin; -use std::str; use std::task::{Context, Poll}; cfg_io_util! { @@ -14,45 +13,72 @@ cfg_io_util! { #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadLine<'a, R: ?Sized> { reader: &'a mut R, - buf: &'a mut String, - bytes: Vec, + /// This is the buffer we were provided. It will be replaced with an empty string + /// while reading to postpone utf-8 handling until after reading. + output: &'a mut String, + /// The actual allocation of the string is moved into a vector instead. + buf: Vec, + /// The number of bytes appended to buf. This can be less than buf.len() if + /// the buffer was not empty when the operation was started. read: usize, } } -pub(crate) fn read_line<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadLine<'a, R> +pub(crate) fn read_line<'a, R>(reader: &'a mut R, string: &'a mut String) -> ReadLine<'a, R> where R: AsyncBufRead + ?Sized + Unpin, { ReadLine { reader, - bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) }, - buf, + buf: mem::replace(string, String::new()).into_bytes(), + output: string, read: 0, } } +fn put_back_original_data(output: &mut String, mut vector: Vec, num_bytes_read: usize) { + let original_len = vector.len() - num_bytes_read; + vector.truncate(original_len); + *output = String::from_utf8(vector).expect("The original data must be valid utf-8."); +} + pub(super) fn read_line_internal( reader: Pin<&mut R>, cx: &mut Context<'_>, - buf: &mut String, - bytes: &mut Vec, + output: &mut String, + buf: &mut Vec, read: &mut usize, ) -> Poll> { - let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read)); - if str::from_utf8(&bytes).is_err() { - Poll::Ready(ret.and_then(|_| { - Err(io::Error::new( + let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read)); + let utf8_res = String::from_utf8(mem::replace(buf, Vec::new())); + + // At this point both buf and output are empty. The allocation is in utf8_res. + + debug_assert!(buf.is_empty()); + match (io_res, utf8_res) { + (Ok(num_bytes), Ok(string)) => { + debug_assert_eq!(*read, 0); + *output = string; + Poll::Ready(Ok(num_bytes)) + } + (Err(io_err), Ok(string)) => { + *output = string; + Poll::Ready(Err(io_err)) + } + (Ok(num_bytes), Err(utf8_err)) => { + debug_assert_eq!(*read, 0); + put_back_original_data(output, utf8_err.into_bytes(), num_bytes); + + Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", - )) - })) - } else { - debug_assert!(buf.is_empty()); - debug_assert_eq!(*read, 0); - // Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`. - mem::swap(unsafe { buf.as_mut_vec() }, bytes); - Poll::Ready(ret) + ))) + } + (Err(io_err), Err(utf8_err)) => { + put_back_original_data(output, utf8_err.into_bytes(), *read); + + Poll::Ready(Err(io_err)) + } } } @@ -62,11 +88,12 @@ impl Future for ReadLine<'_, R> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Self { reader, + output, buf, - bytes, read, } = &mut *self; - read_line_internal(Pin::new(reader), cx, buf, bytes, read) + + read_line_internal(Pin::new(reader), cx, output, buf, read) } } diff --git a/tokio/src/io/util/read_until.rs b/tokio/src/io/util/read_until.rs index 1adeda66f05..78dac8c2a14 100644 --- a/tokio/src/io/util/read_until.rs +++ b/tokio/src/io/util/read_until.rs @@ -8,19 +8,22 @@ use std::task::{Context, Poll}; cfg_io_util! { /// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method. + /// The delimeter is included in the resulting vector. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct ReadUntil<'a, R: ?Sized> { reader: &'a mut R, - byte: u8, + delimeter: u8, buf: &'a mut Vec, + /// The number of bytes appended to buf. This can be less than buf.len() if + /// the buffer was not empty when the operation was started. read: usize, } } pub(crate) fn read_until<'a, R>( reader: &'a mut R, - byte: u8, + delimeter: u8, buf: &'a mut Vec, ) -> ReadUntil<'a, R> where @@ -28,7 +31,7 @@ where { ReadUntil { reader, - byte, + delimeter, buf, read: 0, } @@ -37,14 +40,14 @@ where pub(super) fn read_until_internal( mut reader: Pin<&mut R>, cx: &mut Context<'_>, - byte: u8, + delimeter: u8, buf: &mut Vec, read: &mut usize, ) -> Poll> { loop { let (done, used) = { let available = ready!(reader.as_mut().poll_fill_buf(cx))?; - if let Some(i) = memchr::memchr(byte, available) { + if let Some(i) = memchr::memchr(delimeter, available) { buf.extend_from_slice(&available[..=i]); (true, i + 1) } else { @@ -66,11 +69,11 @@ impl Future for ReadUntil<'_, R> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let Self { reader, - byte, + delimeter, buf, read, } = &mut *self; - read_until_internal(Pin::new(reader), cx, *byte, buf, read) + read_until_internal(Pin::new(reader), cx, *delimeter, buf, read) } } diff --git a/tokio/src/io/util/split.rs b/tokio/src/io/util/split.rs index f1ed2fd89d3..f552ed503d1 100644 --- a/tokio/src/io/util/split.rs +++ b/tokio/src/io/util/split.rs @@ -75,6 +75,8 @@ where let n = ready!(read_until_internal( me.reader, cx, *me.delim, me.buf, me.read, ))?; + // read_until_internal resets me.read to zero once it finds the delimeter + debug_assert_eq!(*me.read, 0); if n == 0 && me.buf.is_empty() { return Poll::Ready(Ok(None)); diff --git a/tokio/tests/io_read_line.rs b/tokio/tests/io_read_line.rs index 57ae37cef3e..15841c9b49d 100644 --- a/tokio/tests/io_read_line.rs +++ b/tokio/tests/io_read_line.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; use std::io::Cursor; @@ -27,3 +28,80 @@ async fn read_line() { assert_eq!(n, 0); assert_eq!(buf, ""); } + +#[tokio::test] +async fn read_line_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld\nFizzBuz") + .read(b"z\n1\n2") + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "We say ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "Hello World\n".len()); + assert_eq!(line.as_str(), "We say Hello World\n"); + + line = "I solve ".to_string(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, "FizzBuzz\n".len()); + assert_eq!(line.as_str(), "I solve FizzBuzz\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(line.as_str(), "1\n"); + + line.clear(); + let bytes = read.read_line(&mut line).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(line.as_str(), "2"); +} + +#[tokio::test] +async fn read_line_invalid_utf8() { + let mock = Builder::new().read(b"Hello Wor\xffld.\n").build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::InvalidData); + assert_eq!(err.to_string(), "stream did not contain valid UTF-8"); + assert_eq!(line.as_str(), "Foo"); +} + +#[tokio::test] +async fn read_line_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(line.as_str(), "FooHello Wor"); +} + +#[tokio::test] +async fn read_line_fail_and_utf8_fail() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"\xff\xff\xff") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut line = "Foo".to_string(); + let err = read.read_line(&mut line).await.expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(line.as_str(), "Foo"); +} diff --git a/tokio/tests/io_read_until.rs b/tokio/tests/io_read_until.rs index 4e0e0d10d34..61800a0d9c1 100644 --- a/tokio/tests/io_read_until.rs +++ b/tokio/tests/io_read_until.rs @@ -1,8 +1,9 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::AsyncBufReadExt; -use tokio_test::assert_ok; +use std::io::ErrorKind; +use tokio::io::{AsyncBufReadExt, BufReader, Error}; +use tokio_test::{assert_ok, io::Builder}; #[tokio::test] async fn read_until() { @@ -21,3 +22,53 @@ async fn read_until() { assert_eq!(n, 0); assert_eq!(buf, []); } + +#[tokio::test] +async fn read_until_not_all_ready() { + let mock = Builder::new() + .read(b"Hello Wor") + .read(b"ld#Fizz\xffBuz") + .read(b"z#1#2") + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"We say ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, b"Hello World#".len()); + assert_eq!(chunk, b"We say Hello World#"); + + chunk = b"I solve ".to_vec(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, b"Fizz\xffBuzz\n".len()); + assert_eq!(chunk, b"I solve Fizz\xffBuzz#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 2); + assert_eq!(chunk, b"1#"); + + chunk.clear(); + let bytes = read.read_until(b'#', &mut chunk).await.unwrap(); + assert_eq!(bytes, 1); + assert_eq!(chunk, b"2"); +} + +#[tokio::test] +async fn read_until_fail() { + let mock = Builder::new() + .read(b"Hello \xffWor") + .read_error(Error::new(ErrorKind::Other, "The world has no end")) + .build(); + + let mut read = BufReader::new(mock); + + let mut chunk = b"Foo".to_vec(); + let err = read + .read_until(b'#', &mut chunk) + .await + .expect_err("Should fail"); + assert_eq!(err.kind(), ErrorKind::Other); + assert_eq!(err.to_string(), "The world has no end"); + assert_eq!(chunk, b"FooHello \xffWor"); +}