From 34fe1d9ed9f453306062945acffddf8778a9b9f0 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Tue, 7 Apr 2020 18:15:06 -0700 Subject: [PATCH 1/3] Remove `unsafe` from ReadToString We can do everything needed with only safe code. --- tokio/src/io/util/read_to_string.rs | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tokio/src/io/util/read_to_string.rs b/tokio/src/io/util/read_to_string.rs index e77d836dee9..534c3f7ad7e 100644 --- a/tokio/src/io/util/read_to_string.rs +++ b/tokio/src/io/util/read_to_string.rs @@ -4,7 +4,7 @@ use crate::io::AsyncRead; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{io, mem, str}; +use std::{io, mem}; cfg_io_util! { /// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method. @@ -25,7 +25,7 @@ where let start_len = buf.len(); ReadToString { reader, - bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) }, + bytes: mem::replace(buf, String::new()).into_bytes(), buf, start_len, } @@ -39,18 +39,17 @@ fn read_to_string_internal( start_len: usize, ) -> Poll<io::Result<usize>> { let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len)); - if str::from_utf8(&bytes).is_err() { + if let Ok(string) = String::from_utf8(mem::take(bytes)) { + debug_assert!(buf.is_empty()); + *bytes = mem::replace(buf, string).into_bytes(); + Poll::Ready(ret) + } else { Poll::Ready(ret.and_then(|_| { Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", )) })) - } else { - debug_assert!(buf.is_empty()); - // Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`. 
- mem::swap(unsafe { buf.as_mut_vec() }, bytes); - Poll::Ready(ret) } } From b74c496f9686c2ead3bd26b00067cc8a1df4e7d8 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Wed, 8 Apr 2020 13:48:03 -0700 Subject: [PATCH 2/3] Put back the original string on error --- tokio/Cargo.toml | 2 +- tokio/src/io/util/read_to_string.rs | 30 +++++++++++------- tokio/tests/read_to_string.rs | 49 +++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 12 deletions(-) create mode 100644 tokio/tests/read_to_string.rs diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index 9fd5f95de6f..b9ab5b26745 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -121,7 +121,7 @@ default-features = false optional = true [dev-dependencies] -tokio-test = { version = "0.2.0" } +tokio-test = { version = "0.2.0", path = "../tokio-test" } futures = { version = "0.3.0", features = ["async-await"] } proptest = "0.9.4" tempfile = "3.1.0" diff --git a/tokio/src/io/util/read_to_string.rs b/tokio/src/io/util/read_to_string.rs index 534c3f7ad7e..64f4cbcb8bc 100644 --- a/tokio/src/io/util/read_to_string.rs +++ b/tokio/src/io/util/read_to_string.rs @@ -38,18 +38,20 @@ fn read_to_string_internal( bytes: &mut Vec<u8>, start_len: usize, ) -> Poll<io::Result<usize>> { - let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len)); - if let Ok(string) = String::from_utf8(mem::take(bytes)) { - debug_assert!(buf.is_empty()); - *bytes = mem::replace(buf, string).into_bytes(); - Poll::Ready(ret) - } else { - Poll::Ready(ret.and_then(|_| { - Err(io::Error::new( + let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?; + match String::from_utf8(mem::take(bytes)) { + Ok(string) => { + debug_assert!(buf.is_empty()); + *buf = string; + Poll::Ready(Ok(ret)) + } + Err(e) => { + *bytes = e.into_bytes(); + Poll::Ready(Err(io::Error::new( io::ErrorKind::InvalidData, "stream did not contain valid UTF-8", - )) - })) + ))) + } } } @@ -66,7 +68,13 @@ where bytes, start_len, } = &mut *self; - 
read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len) + let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len); + if let Poll::Ready(Err(_)) = ret { + // Put back the original string. + bytes.truncate(*start_len); + **buf = String::from_utf8(mem::take(bytes)).expect("original string no longer utf-8"); + } + ret } } diff --git a/tokio/tests/read_to_string.rs b/tokio/tests/read_to_string.rs new file mode 100644 index 00000000000..db3fa1bf4bd --- /dev/null +++ b/tokio/tests/read_to_string.rs @@ -0,0 +1,49 @@ +use std::io; +use tokio::io::AsyncReadExt; +use tokio_test::io::Builder; + +#[tokio::test] +async fn to_string_does_not_truncate_on_utf8_error() { + let data = vec![0xff, 0xff, 0xff]; + + let mut s = "abc".to_string(); + + match AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s).await { + Ok(len) => panic!("Should fail: {} bytes.", len), + Err(err) if err.to_string() == "stream did not contain valid UTF-8" => {} + Err(err) => panic!("Fail: {}.", err), + } + + assert_eq!(s, "abc"); +} + +#[tokio::test] +async fn to_string_does_not_truncate_on_io_error() { + let mut mock = Builder::new() + .read(b"def") + .read_error(io::Error::new(io::ErrorKind::Other, "whoops")) + .build(); + let mut s = "abc".to_string(); + + match AsyncReadExt::read_to_string(&mut mock, &mut s).await { + Ok(len) => panic!("Should fail: {} bytes.", len), + Err(err) if err.to_string() == "whoops" => {} + Err(err) => panic!("Fail: {}.", err), + } + + assert_eq!(s, "abc"); +} + +#[tokio::test] +async fn to_string_appends() { + let data = b"def".to_vec(); + + let mut s = "abc".to_string(); + + let len = AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s) + .await + .unwrap(); + + assert_eq!(len, 3); + assert_eq!(s, "abcdef"); +} From 5cd3d89989b2a2445c80d2fe669ed23d4bcef8a8 Mon Sep 17 00:00:00 2001 From: Geoffry Song Date: Thu, 9 Apr 2020 01:12:05 -0700 Subject: [PATCH 3/3] Replace mem::take with mem::replace --- 
tokio/src/io/util/read_to_string.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tokio/src/io/util/read_to_string.rs b/tokio/src/io/util/read_to_string.rs index 64f4cbcb8bc..cab0505ab83 100644 --- a/tokio/src/io/util/read_to_string.rs +++ b/tokio/src/io/util/read_to_string.rs @@ -39,7 +39,7 @@ fn read_to_string_internal( start_len: usize, ) -> Poll<io::Result<usize>> { let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?; - match String::from_utf8(mem::take(bytes)) { + match String::from_utf8(mem::replace(bytes, Vec::new())) { Ok(string) => { debug_assert!(buf.is_empty()); *buf = string; @@ -72,7 +72,8 @@ where if let Poll::Ready(Err(_)) = ret { // Put back the original string. bytes.truncate(*start_len); - **buf = String::from_utf8(mem::take(bytes)).expect("original string no longer utf-8"); + **buf = String::from_utf8(mem::replace(bytes, Vec::new())) + .expect("original string no longer utf-8"); } ret }