Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unsafe from ReadToString #2384

Merged
merged 3 commits into from
Apr 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ default-features = false
optional = true

[dev-dependencies]
tokio-test = { version = "0.2.0" }
tokio-test = { version = "0.2.0", path = "../tokio-test" }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is kosher, but a new version of tokio-test with the read_error mock hasn't been released to crates.io yet.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be ok to keep path dev-deps. What would be the reason to avoid them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't think of any reason, just wondering why it isn't already this way.

futures = { version = "0.3.0", features = ["async-await"] }
proptest = "0.9.4"
tempfile = "3.1.0"
Expand Down
36 changes: 22 additions & 14 deletions tokio/src/io/util/read_to_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::io::AsyncRead;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem, str};
use std::{io, mem};

cfg_io_util! {
/// Future for the [`read_to_string`](super::AsyncReadExt::read_to_string) method.
Expand All @@ -25,7 +25,7 @@ where
let start_len = buf.len();
ReadToString {
reader,
bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
bytes: mem::replace(buf, String::new()).into_bytes(),
buf,
start_len,
}
Expand All @@ -38,19 +38,20 @@ fn read_to_string_internal<R: AsyncRead + ?Sized>(
bytes: &mut Vec<u8>,
start_len: usize,
) -> Poll<io::Result<usize>> {
let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len));
if str::from_utf8(&bytes).is_err() {
Poll::Ready(ret.and_then(|_| {
Err(io::Error::new(
let ret = ready!(read_to_end_internal(reader, cx, bytes, start_len))?;
match String::from_utf8(mem::replace(bytes, Vec::new())) {
Ok(string) => {
debug_assert!(buf.is_empty());
*buf = string;
Poll::Ready(Ok(ret))
}
Err(e) => {
*bytes = e.into_bytes();
Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"stream did not contain valid UTF-8",
))
}))
} else {
debug_assert!(buf.is_empty());
// Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`.
mem::swap(unsafe { buf.as_mut_vec() }, bytes);
Poll::Ready(ret)
)))
}
}
}

Expand All @@ -67,7 +68,14 @@ where
bytes,
start_len,
} = &mut *self;
read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len)
let ret = read_to_string_internal(Pin::new(reader), cx, buf, bytes, *start_len);
if let Poll::Ready(Err(_)) = ret {
// Put back the original string.
bytes.truncate(*start_len);
**buf = String::from_utf8(mem::replace(bytes, Vec::new()))
.expect("original string no longer utf-8");
}
ret
}
}

Expand Down
49 changes: 49 additions & 0 deletions tokio/tests/read_to_string.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use std::io;
use tokio::io::AsyncReadExt;
use tokio_test::io::Builder;

#[tokio::test]
async fn to_string_does_not_truncate_on_utf8_error() {
let data = vec![0xff, 0xff, 0xff];

let mut s = "abc".to_string();

match AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s).await {
Ok(len) => panic!("Should fail: {} bytes.", len),
Err(err) if err.to_string() == "stream did not contain valid UTF-8" => {}
Err(err) => panic!("Fail: {}.", err),
}

assert_eq!(s, "abc");
}

#[tokio::test]
async fn to_string_does_not_truncate_on_io_error() {
let mut mock = Builder::new()
.read(b"def")
.read_error(io::Error::new(io::ErrorKind::Other, "whoops"))
.build();
let mut s = "abc".to_string();

match AsyncReadExt::read_to_string(&mut mock, &mut s).await {
Ok(len) => panic!("Should fail: {} bytes.", len),
Err(err) if err.to_string() == "whoops" => {}
Err(err) => panic!("Fail: {}.", err),
}

assert_eq!(s, "abc");
}

#[tokio::test]
async fn to_string_appends() {
let data = b"def".to_vec();

let mut s = "abc".to_string();

let len = AsyncReadExt::read_to_string(&mut data.as_slice(), &mut s)
.await
.unwrap();

assert_eq!(len, 3);
assert_eq!(s, "abcdef");
}