Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Handle interceptor errors as responses (#840) #842

Merged
merged 1 commit into from
Feb 15, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tonic-build/src/client.rs
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ pub fn generate<T: Service>(
attributes: &Attributes,
) -> TokenStream {
let service_ident = quote::format_ident!("{}Client", service.name());
let client_mod = quote::format_ident!("{}_client", naive_snake_case(&service.name()));
let client_mod = quote::format_ident!("{}_client", naive_snake_case(service.name()));
let methods = generate_methods(service, emit_package, proto_path, compile_well_known_types);

let connect = generate_connect(&service_ident);
@@ -57,8 +57,8 @@ pub fn generate<T: Service>(
impl<T> #service_ident<T>
where
T: tonic::client::GrpcService<tonic::body::BoxBody>,
T::ResponseBody: Body + Send + 'static,
T::Error: Into<StdError>,
T::ResponseBody: Default + Body<Data = Bytes> + Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub fn new(inner: T) -> Self {
1 change: 1 addition & 0 deletions tonic/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@ pub type StdError = Box<dyn std::error::Error + Send + Sync + 'static>;
#[cfg(feature = "compression")]
pub use crate::codec::{CompressionEncoding, EnabledCompressionEncodings};
pub use crate::service::interceptor::InterceptedService;
pub use bytes::Bytes;
pub use http_body::Body;

pub type BoxFuture<T, E> = self::Pin<Box<dyn self::Future<Output = Result<T, E>> + Send + 'static>>;
67 changes: 61 additions & 6 deletions tonic/src/service/interceptor.rs
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
//! See [`Interceptor`] for more details.

use crate::{request::SanitizeHeaders, Status};
use bytes::Bytes;
use pin_project::pin_project;
use std::{
fmt,
@@ -140,9 +141,11 @@ where

impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
where
ResBody: Default + http_body::Body<Data = Bytes> + Send + 'static,
F: Interceptor,
S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
S::Error: Into<crate::Error>,
ResBody::Error: Into<crate::Error>,
{
type Response = http::Response<ResBody>;
type Error = crate::Error;
@@ -215,15 +218,18 @@ impl<F, E, B> Future for ResponseFuture<F>
where
F: Future<Output = Result<http::Response<B>, E>>,
E: Into<crate::Error>,
B: Default + http_body::Body<Data = Bytes> + Send + 'static,
B::Error: Into<crate::Error>,
{
type Output = Result<http::Response<B>, crate::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().kind.project() {
KindProj::Future(future) => future.poll(cx).map_err(Into::into),
KindProj::Error(status) => {
let error = status.take().unwrap().into();
Poll::Ready(Err(error))
let response = status.take().unwrap().to_http().map(|_| B::default());

Poll::Ready(Ok(response))
}
}
}
@@ -233,11 +239,38 @@ where
mod tests {
#[allow(unused_imports)]
use super::*;
use http::header::HeaderMap;
use std::{
pin::Pin,
task::{Context, Poll},
};
use tower::ServiceExt;

#[derive(Debug, Default)]
struct TestBody;

impl http_body::Body for TestBody {
type Data = Bytes;
type Error = Status;

fn poll_data(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
Poll::Ready(None)
}

fn poll_trailers(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
Poll::Ready(Ok(None))
}
}

#[tokio::test]
async fn doesnt_remove_headers() {
let svc = tower::service_fn(|request: http::Request<hyper::Body>| async move {
async fn doesnt_remove_headers_from_requests() {
let svc = tower::service_fn(|request: http::Request<TestBody>| async move {
assert_eq!(
request
.headers()
@@ -246,7 +279,7 @@ mod tests {
"test-tonic"
);

Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |request: crate::Request<()>| {
@@ -257,14 +290,36 @@ mod tests {
.expect("missing in interceptor"),
"test-tonic"
);

Ok(request)
});

let request = http::Request::builder()
.header("user-agent", "test-tonic")
.body(hyper::Body::empty())
.body(TestBody)
.unwrap();

svc.oneshot(request).await.unwrap();
}

#[tokio::test]
async fn handles_intercepted_status_as_response() {
let message = "Blocked by the interceptor";
let expected = Status::permission_denied(message).to_http();

let svc = tower::service_fn(|_: http::Request<TestBody>| async {
Ok::<_, Status>(http::Response::new(TestBody))
});

let svc = InterceptedService::new(svc, |_: crate::Request<()>| {
Err(Status::permission_denied(message))
});

let request = http::Request::builder().body(TestBody).unwrap();
let response = svc.oneshot(request).await.unwrap();

assert_eq!(expected.status(), response.status());
assert_eq!(expected.version(), response.version());
assert_eq!(expected.headers(), response.headers());
}
}