Skip to content

Commit

Permalink
Merge pull request #199 from tottoto/refactor-response-body
Browse files Browse the repository at this point in the history
Refactor response body
  • Loading branch information
lipanski authored Apr 5, 2024
2 parents f477e54 + 42e3efe commit 08f2fa3
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ appveyor = { repository = "lipanski/mockito", branch = "master", service = "gith
assert-json-diff = "2.0"
bytes = "1"
colored = { version = "2.0", optional = true }
futures-core = "0.3"
futures-util = { version = "0.3", default-features = false }
http = "1"
http-body = "1"
http-body-util = "0.1"
Expand Down
11 changes: 5 additions & 6 deletions src/response.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use crate::error::Error;
use crate::Request;
use bytes::Bytes;
use futures_core::stream::Stream;
use futures_util::Stream;
use http::{HeaderMap, StatusCode};
use http_body::Frame;
use std::fmt;
use std::io;
use std::sync::Arc;
Expand Down Expand Up @@ -117,7 +116,7 @@ impl Drop for ChunkedStream {
}

impl Stream for ChunkedStream {
type Item = io::Result<Frame<Bytes>>;
type Item = io::Result<Bytes>;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
Expand All @@ -126,9 +125,9 @@ impl Stream for ChunkedStream {
self.receiver
.as_mut()
.map(move |receiver| {
receiver.poll_recv(cx).map(|received| {
received.map(|result| result.map(|data| Frame::data(Bytes::from(data))))
})
receiver
.poll_recv(cx)
.map(|received| received.map(|result| result.map(Into::into)))
})
.unwrap_or(Poll::Ready(None))
}
Expand Down
93 changes: 72 additions & 21 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ use crate::request::Request;
use crate::response::{Body as ResponseBody, ChunkedStream};
use crate::ServerGuard;
use crate::{Error, ErrorKind, Matcher, Mock};
use bytes::Bytes;
use futures_util::{TryStream, TryStreamExt};
use http::{Request as HttpRequest, Response, StatusCode};
use http_body_util::{BodyExt, Empty, Full, StreamBody};
use http_body::{Body as HttpBody, Frame, SizeHint};
use http_body_util::{BodyExt, StreamBody};
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::{TokioExecutor, TokioIo};
Expand All @@ -14,8 +17,10 @@ use std::error::Error as StdError;
use std::fmt;
use std::net::{IpAddr, SocketAddr};
use std::ops::Drop;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{mpsc, Arc, RwLock};
use std::task::{ready, Context, Poll};
use std::thread;
use tokio::net::TcpListener;
use tokio::runtime;
Expand Down Expand Up @@ -446,26 +451,72 @@ impl fmt::Display for Server {
}

type BoxError = Box<dyn StdError + Send + Sync>;
type BoxBody = http_body_util::combinators::UnsyncBoxBody<bytes::Bytes, BoxError>;

trait IntoBoxBody {
fn into_box_body(self) -> BoxBody;
enum Body {
Once(Option<Bytes>),
Wrap(http_body_util::combinators::UnsyncBoxBody<Bytes, BoxError>),
}

impl<B> IntoBoxBody for B
where
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B::Error: Into<BoxError>,
{
fn into_box_body(self) -> BoxBody {
self.map_err(Into::into).boxed_unsync()
impl Body {
fn empty() -> Self {
Self::Once(None)
}

fn from_data_stream<S>(stream: S) -> Self
where
S: TryStream<Ok = Bytes> + Send + 'static,
S::Error: Into<BoxError>,
{
let body = StreamBody::new(stream.map_ok(Frame::data).map_err(Into::into)).boxed_unsync();
Self::Wrap(body)
}
}

impl From<Bytes> for Body {
fn from(bytes: Bytes) -> Self {
if bytes.is_empty() {
Self::empty()
} else {
Self::Once(Some(bytes))
}
}
}

impl HttpBody for Body {
type Data = Bytes;
type Error = BoxError;

fn poll_frame(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
match self.as_mut().get_mut() {
Self::Once(val) => Poll::Ready(Ok(val.take().map(Frame::data)).transpose()),
Self::Wrap(body) => Poll::Ready(ready!(Pin::new(body).poll_frame(cx))),
}
}

fn size_hint(&self) -> SizeHint {
match self {
Self::Once(None) => SizeHint::with_exact(0),
Self::Once(Some(bytes)) => SizeHint::with_exact(bytes.len() as u64),
Self::Wrap(body) => body.size_hint(),
}
}

fn is_end_stream(&self) -> bool {
match self {
Self::Once(None) => true,
Self::Once(Some(bytes)) => bytes.is_empty(),
Self::Wrap(body) => body.is_end_stream(),
}
}
}

async fn handle_request(
hyper_request: HttpRequest<Incoming>,
state: Arc<RwLock<State>>,
) -> Result<Response<BoxBody>, Error> {
) -> Result<Response<Body>, Error> {
let mut request = Request::new(hyper_request);
request.read_body().await;
log::debug!("Request received: {}", request.formatted());
Expand Down Expand Up @@ -498,7 +549,7 @@ async fn handle_request(
}
}

fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<BoxBody>, Error> {
fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Body>, Error> {
let status: StatusCode = mock.inner.response.status;
let mut response = Response::builder().status(status);

Expand All @@ -512,32 +563,32 @@ fn respond_with_mock(request: Request, mock: &RemoteMock) -> Result<Response<Box
if !request.has_header("content-length") {
response = response.header("content-length", bytes.len());
}
Full::new(bytes.to_owned()).into_box_body()
Body::from(bytes.to_owned())
}
ResponseBody::FnWithWriter(body_fn) => {
let stream = ChunkedStream::new(Arc::clone(body_fn))?;
StreamBody::new(stream).into_box_body()
Body::from_data_stream(stream)
}
ResponseBody::FnWithRequest(body_fn) => {
let bytes = body_fn(&request);
Full::new(bytes.to_owned()).into_box_body()
Body::from(bytes)
}
}
} else {
Empty::new().into_box_body()
Body::empty()
};

let response: Response<BoxBody> = response
let response = response
.body(body)
.map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;

Ok(response)
}

fn respond_with_mock_not_found() -> Result<Response<BoxBody>, Error> {
let response: Response<BoxBody> = Response::builder()
fn respond_with_mock_not_found() -> Result<Response<Body>, Error> {
let response = Response::builder()
.status(StatusCode::NOT_IMPLEMENTED)
.body(Empty::new().into_box_body())
.body(Body::empty())
.map_err(|err| Error::new_with_context(ErrorKind::ResponseFailure, err))?;

Ok(response)
Expand Down

0 comments on commit 08f2fa3

Please sign in to comment.