Skip to content

Commit

Permalink
io: fix panic in read_line (tokio-rs#2541)
Browse files Browse the repository at this point in the history
  • Loading branch information
Darksonn authored and jensim committed Jun 7, 2020
1 parent ff00506 commit 671bad8
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 33 deletions.
1 change: 1 addition & 0 deletions tokio/src/io/util/lines.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ where
let me = self.project();

let n = ready!(read_line_internal(me.reader, cx, me.buf, me.bytes, me.read))?;
debug_assert_eq!(*me.read, 0);

if n == 0 && me.buf.is_empty() {
return Poll::Ready(Ok(None));
Expand Down
71 changes: 49 additions & 22 deletions tokio/src/io/util/read_line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::future::Future;
use std::io;
use std::mem;
use std::pin::Pin;
use std::str;
use std::task::{Context, Poll};

cfg_io_util! {
Expand All @@ -14,45 +13,72 @@ cfg_io_util! {
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadLine<'a, R: ?Sized> {
reader: &'a mut R,
buf: &'a mut String,
bytes: Vec<u8>,
/// This is the buffer we were provided. It will be replaced with an empty string
/// while reading to postpone utf-8 handling until after reading.
output: &'a mut String,
/// The actual allocation of the string is moved into a vector instead.
buf: Vec<u8>,
/// The number of bytes appended to buf. This can be less than buf.len() if
/// the buffer was not empty when the operation was started.
read: usize,
}
}

pub(crate) fn read_line<'a, R>(reader: &'a mut R, buf: &'a mut String) -> ReadLine<'a, R>
pub(crate) fn read_line<'a, R>(reader: &'a mut R, string: &'a mut String) -> ReadLine<'a, R>
where
R: AsyncBufRead + ?Sized + Unpin,
{
ReadLine {
reader,
bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
buf,
buf: mem::replace(string, String::new()).into_bytes(),
output: string,
read: 0,
}
}

fn put_back_original_data(output: &mut String, mut vector: Vec<u8>, num_bytes_read: usize) {
let original_len = vector.len() - num_bytes_read;
vector.truncate(original_len);
*output = String::from_utf8(vector).expect("The original data must be valid utf-8.");
}

pub(super) fn read_line_internal<R: AsyncBufRead + ?Sized>(
reader: Pin<&mut R>,
cx: &mut Context<'_>,
buf: &mut String,
bytes: &mut Vec<u8>,
output: &mut String,
buf: &mut Vec<u8>,
read: &mut usize,
) -> Poll<io::Result<usize>> {
let ret = ready!(read_until_internal(reader, cx, b'\n', bytes, read));
if str::from_utf8(&bytes).is_err() {
Poll::Ready(ret.and_then(|_| {
Err(io::Error::new(
let io_res = ready!(read_until_internal(reader, cx, b'\n', buf, read));
let utf8_res = String::from_utf8(mem::replace(buf, Vec::new()));

// At this point both buf and output are empty. The allocation is in utf8_res.

debug_assert!(buf.is_empty());
match (io_res, utf8_res) {
(Ok(num_bytes), Ok(string)) => {
debug_assert_eq!(*read, 0);
*output = string;
Poll::Ready(Ok(num_bytes))
}
(Err(io_err), Ok(string)) => {
*output = string;
Poll::Ready(Err(io_err))
}
(Ok(num_bytes), Err(utf8_err)) => {
debug_assert_eq!(*read, 0);
put_back_original_data(output, utf8_err.into_bytes(), num_bytes);

Poll::Ready(Err(io::Error::new(
io::ErrorKind::InvalidData,
"stream did not contain valid UTF-8",
))
}))
} else {
debug_assert!(buf.is_empty());
debug_assert_eq!(*read, 0);
// Safety: `bytes` is a valid UTF-8 because `str::from_utf8` returned `Ok`.
mem::swap(unsafe { buf.as_mut_vec() }, bytes);
Poll::Ready(ret)
)))
}
(Err(io_err), Err(utf8_err)) => {
put_back_original_data(output, utf8_err.into_bytes(), *read);

Poll::Ready(Err(io_err))
}
}
}

Expand All @@ -62,11 +88,12 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadLine<'_, R> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
reader,
output,
buf,
bytes,
read,
} = &mut *self;
read_line_internal(Pin::new(reader), cx, buf, bytes, read)

read_line_internal(Pin::new(reader), cx, output, buf, read)
}
}

Expand Down
17 changes: 10 additions & 7 deletions tokio/src/io/util/read_until.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,30 @@ use std::task::{Context, Poll};

cfg_io_util! {
/// Future for the [`read_until`](crate::io::AsyncBufReadExt::read_until) method.
/// The delimeter is included in the resulting vector.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadUntil<'a, R: ?Sized> {
reader: &'a mut R,
byte: u8,
delimeter: u8,
buf: &'a mut Vec<u8>,
/// The number of bytes appended to buf. This can be less than buf.len() if
/// the buffer was not empty when the operation was started.
read: usize,
}
}

pub(crate) fn read_until<'a, R>(
reader: &'a mut R,
byte: u8,
delimeter: u8,
buf: &'a mut Vec<u8>,
) -> ReadUntil<'a, R>
where
R: AsyncBufRead + ?Sized + Unpin,
{
ReadUntil {
reader,
byte,
delimeter,
buf,
read: 0,
}
Expand All @@ -37,14 +40,14 @@ where
pub(super) fn read_until_internal<R: AsyncBufRead + ?Sized>(
mut reader: Pin<&mut R>,
cx: &mut Context<'_>,
byte: u8,
delimeter: u8,
buf: &mut Vec<u8>,
read: &mut usize,
) -> Poll<io::Result<usize>> {
loop {
let (done, used) = {
let available = ready!(reader.as_mut().poll_fill_buf(cx))?;
if let Some(i) = memchr::memchr(byte, available) {
if let Some(i) = memchr::memchr(delimeter, available) {
buf.extend_from_slice(&available[..=i]);
(true, i + 1)
} else {
Expand All @@ -66,11 +69,11 @@ impl<R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntil<'_, R> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let Self {
reader,
byte,
delimeter,
buf,
read,
} = &mut *self;
read_until_internal(Pin::new(reader), cx, *byte, buf, read)
read_until_internal(Pin::new(reader), cx, *delimeter, buf, read)
}
}

Expand Down
2 changes: 2 additions & 0 deletions tokio/src/io/util/split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ where
let n = ready!(read_until_internal(
me.reader, cx, *me.delim, me.buf, me.read,
))?;
// read_until_internal resets me.read to zero once it finds the delimeter
debug_assert_eq!(*me.read, 0);

if n == 0 && me.buf.is_empty() {
return Poll::Ready(Ok(None));
Expand Down
82 changes: 80 additions & 2 deletions tokio/tests/io_read_line.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::AsyncBufReadExt;
use tokio_test::assert_ok;
use std::io::ErrorKind;
use tokio::io::{AsyncBufReadExt, BufReader, Error};
use tokio_test::{assert_ok, io::Builder};

use std::io::Cursor;

Expand All @@ -27,3 +28,80 @@ async fn read_line() {
assert_eq!(n, 0);
assert_eq!(buf, "");
}

#[tokio::test]
async fn read_line_not_all_ready() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"ld\nFizzBuz")
.read(b"z\n1\n2")
.build();

let mut read = BufReader::new(mock);

let mut line = "We say ".to_string();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, "Hello World\n".len());
assert_eq!(line.as_str(), "We say Hello World\n");

line = "I solve ".to_string();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, "FizzBuzz\n".len());
assert_eq!(line.as_str(), "I solve FizzBuzz\n");

line.clear();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, 2);
assert_eq!(line.as_str(), "1\n");

line.clear();
let bytes = read.read_line(&mut line).await.unwrap();
assert_eq!(bytes, 1);
assert_eq!(line.as_str(), "2");
}

#[tokio::test]
async fn read_line_invalid_utf8() {
let mock = Builder::new().read(b"Hello Wor\xffld.\n").build();

let mut read = BufReader::new(mock);

let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::InvalidData);
assert_eq!(err.to_string(), "stream did not contain valid UTF-8");
assert_eq!(line.as_str(), "Foo");
}

#[tokio::test]
async fn read_line_fail() {
let mock = Builder::new()
.read(b"Hello Wor")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();

let mut read = BufReader::new(mock);

let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(line.as_str(), "FooHello Wor");
}

#[tokio::test]
async fn read_line_fail_and_utf8_fail() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"\xff\xff\xff")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();

let mut read = BufReader::new(mock);

let mut line = "Foo".to_string();
let err = read.read_line(&mut line).await.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(line.as_str(), "Foo");
}
55 changes: 53 additions & 2 deletions tokio/tests/io_read_until.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use tokio::io::AsyncBufReadExt;
use tokio_test::assert_ok;
use std::io::ErrorKind;
use tokio::io::{AsyncBufReadExt, BufReader, Error};
use tokio_test::{assert_ok, io::Builder};

#[tokio::test]
async fn read_until() {
Expand All @@ -21,3 +22,53 @@ async fn read_until() {
assert_eq!(n, 0);
assert_eq!(buf, []);
}

#[tokio::test]
async fn read_until_not_all_ready() {
let mock = Builder::new()
.read(b"Hello Wor")
.read(b"ld#Fizz\xffBuz")
.read(b"z#1#2")
.build();

let mut read = BufReader::new(mock);

let mut chunk = b"We say ".to_vec();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, b"Hello World#".len());
assert_eq!(chunk, b"We say Hello World#");

chunk = b"I solve ".to_vec();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, b"Fizz\xffBuzz\n".len());
assert_eq!(chunk, b"I solve Fizz\xffBuzz#");

chunk.clear();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, 2);
assert_eq!(chunk, b"1#");

chunk.clear();
let bytes = read.read_until(b'#', &mut chunk).await.unwrap();
assert_eq!(bytes, 1);
assert_eq!(chunk, b"2");
}

#[tokio::test]
async fn read_until_fail() {
let mock = Builder::new()
.read(b"Hello \xffWor")
.read_error(Error::new(ErrorKind::Other, "The world has no end"))
.build();

let mut read = BufReader::new(mock);

let mut chunk = b"Foo".to_vec();
let err = read
.read_until(b'#', &mut chunk)
.await
.expect_err("Should fail");
assert_eq!(err.kind(), ErrorKind::Other);
assert_eq!(err.to_string(), "The world has no end");
assert_eq!(chunk, b"FooHello \xffWor");
}

0 comments on commit 671bad8

Please sign in to comment.