diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index a08fc942a6a..91a16ad51b4 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -121,7 +121,7 @@ default-features = false optional = true [dev-dependencies] -tokio-test = { version = "0.2.0" } +tokio-test = { version = "0.2.0", path = "../tokio-test" } futures = { version = "0.3.0", features = ["async-await"] } proptest = "0.9.4" tempfile = "3.1.0" diff --git a/tokio/src/io/util/read_to_string.rs b/tokio/src/io/util/read_to_string.rs index e77d836dee9..cab0505ab83 100644 --- a/tokio/src/io/util/read_to_string.rs +++ b/tokio/src/io/util/read_to_string.rs @@ -4,7 +4,7 @@ use crate::io::AsyncRead; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{io, mem, str}; +use std::{io, mem}; cfg_io_util! { /// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method. @@ -25,7 +25,7 @@ where let start_len = buf.len(); ReadToString { reader, - bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) }, + bytes: mem::replace(buf, String::new()).into_bytes(), buf, start_len, } @@ -38,19 +38,20 @@ fn read_to_string_internal( bytes: &mut Vec, start_len: usize, ) -> Poll> { - let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len)); - if str::from_utf8(&bytes).is_err() { - Poll::Ready(ret.and_then(|_| { - Err(io::Error::new( + let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?; + match String::from_utf8(mem::replace(bytes, Vec::new())) { + Ok(string) => { + debug_assert!(buf.is_empty()); + *buf = string; + Poll::Ready(Ok(ret)) + } + Err(e) => { + *bytes = e.into_bytes(); + Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", - )) - })) - } else { - debug_assert!(buf.is_empty()); - // Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`. - mem::swap(unsafe { buf.as_mut_vec() }, bytes); - Poll::Ready(ret) + ))) + } } } @@ -67,7 +68,14 @@ where bytes, start_len, } = &mut *self; - read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len) + let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len); + if let Poll::Ready(Err(_)) = ret { + // Put back the original string. + bytes.truncate(*start_len); + **buf = String::from_utf8(mem::replace(bytes, Vec::new())) + .expect("original string no longer utf-8"); + } + ret } } diff --git a/tokio/tests/read_to_string.rs b/tokio/tests/read_to_string.rs new file mode 100644 index 00000000000..db3fa1bf4bd --- /dev/null +++ b/tokio/tests/read_to_string.rs @@ -0,0 +1,49 @@ +use std::io; +use tokio::io::AsyncReadExt; +use tokio_test::io::Builder; + +#[tokio::test] +async fn to_string_does_not_truncate_on_utf8_error() { + let data = vec![0xff, 0xff, 0xff]; + + let mut s = "abc".to_string(); + + match AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s).await { + Ok(len) => panic!("Should fail: {} bytes.", len), + Err(err) if err.to_string() == "stream did not contain valid UTF-8" => {} + Err(err) => panic!("Fail: {}.", err), + } + + assert_eq!(s, "abc"); +} + +#[tokio::test] +async fn to_string_does_not_truncate_on_io_error() { + let mut mock = Builder::new() + .read(b"def") + .read_error(io::Error::new(io::ErrorKind::Other, "whoops")) + .build(); + let mut s = "abc".to_string(); + + match AsyncReadExt::read_to_string(&mut mock, &mut s).await { + Ok(len) => panic!("Should fail: {} bytes.", len), + Err(err) if err.to_string() == "whoops" => {} + Err(err) => panic!("Fail: {}.", err), + } + + assert_eq!(s, "abc"); +} + +#[tokio::test] +async fn to_string_appends() { + let data = b"def".to_vec(); + + let mut s = "abc".to_string(); + + let len = AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s) + .await + .unwrap(); + + assert_eq!(len, 3); + assert_eq!(s, "abcdef"); +}