Skip to content

Commit

Permalink
refactor: avoid grpc forwarding twice (#991)
Browse files Browse the repository at this point in the history
## Rationale
Close #984 

## Detailed Changes
Add a parameter to the headers of grpc to mark that it has been
forwarded.

## Test Plan
Existing tests
  • Loading branch information
baojinri committed Jun 16, 2023
1 parent 77d0d79 commit 032c448
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 19 deletions.
29 changes: 27 additions & 2 deletions proxy/src/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ use tonic::{
transport::{self, Channel},
};

use crate::FORWARDED_FROM;

#[derive(Debug, Snafu)]
pub enum Error {
#[snafu(display(
Expand Down Expand Up @@ -68,6 +70,9 @@ pub enum Error {
source: tonic::transport::Error,
backtrace: Backtrace,
},

#[snafu(display("Request should not be forwarded twice, forward from:{}", endpoint))]
ForwardedErr { endpoint: String },
}

define_result!(Error);
Expand Down Expand Up @@ -184,6 +189,7 @@ pub struct ForwardRequest<Req> {
pub schema: String,
pub table: String,
pub req: tonic::Request<Req>,
pub forwarded_from: Option<String>,
}

impl Forwarder<DefaultClientBuilder> {
Expand Down Expand Up @@ -256,7 +262,12 @@ impl<B: ClientBuilder> Forwarder<B> {
F: ForwarderRpc<Req, Resp, Err>,
Req: std::fmt::Debug + Clone,
{
let ForwardRequest { schema, table, req } = forward_req;
let ForwardRequest {
schema,
table,
req,
forwarded_from,
} = forward_req;

let route_req = RouteRequest {
context: Some(RequestContext { database: schema }),
Expand All @@ -281,13 +292,15 @@ impl<B: ClientBuilder> Forwarder<B> {
}
};

self.forward_with_endpoint(endpoint, req, do_rpc).await
self.forward_with_endpoint(endpoint, req, forwarded_from, do_rpc)
.await
}

pub async fn forward_with_endpoint<Req, Resp, Err, F>(
&self,
endpoint: Endpoint,
mut req: tonic::Request<Req>,
forwarded_from: Option<String>,
do_rpc: F,
) -> Result<ForwardResult<Resp, Err>>
where
Expand All @@ -310,6 +323,17 @@ impl<B: ClientBuilder> Forwarder<B> {
"Try to forward request to {:?}, request:{:?}",
endpoint, req,
);

if let Some(endpoint) = forwarded_from {
return ForwardedErr { endpoint }.fail();
}

// mark forwarded
req.metadata_mut().insert(
FORWARDED_FROM,
self.local_endpoint.to_string().parse().unwrap(),
);

let client = self.get_or_create_client(&endpoint).await?;
match do_rpc(client, req, &endpoint).await {
Err(e) => {
Expand Down Expand Up @@ -461,6 +485,7 @@ mod tests {
schema: DEFAULT_SCHEMA.to_string(),
table: table.to_string(),
req: query_request.into_request(),
forwarded_from: None,
}
};

Expand Down
8 changes: 7 additions & 1 deletion proxy/src/grpc/sql_query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {

let req_context = req.context.as_ref().unwrap();
let schema = req_context.database.clone();
let req = match self.clone().maybe_forward_stream_sql_query(&req).await {
let req = match self
.clone()
.maybe_forward_stream_sql_query(ctx.clone(), &req)
.await
{
Some(resp) => match resp {
ForwardResult::Forwarded(resp) => return resp,
ForwardResult::Local => req,
Expand Down Expand Up @@ -167,6 +171,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {

async fn maybe_forward_stream_sql_query(
self: Arc<Self>,
ctx: Context,
req: &SqlQueryRequest,
) -> Option<ForwardResult<BoxStream<'static, SqlQueryResponse>, Error>> {
if req.tables.len() != 1 {
Expand All @@ -180,6 +185,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: req_ctx.database.clone(),
table: req.tables[0].clone(),
req: req.clone().into_request(),
forwarded_from: ctx.forwarded_from,
};
let do_query = |mut client: StorageServiceClient<Channel>,
request: tonic::Request<SqlQueryRequest>,
Expand Down
1 change: 1 addition & 0 deletions proxy/src/http/prom.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
runtime: self.engine_runtimes.write_runtime.clone(),
timeout: ctx.timeout,
enable_partition_table_access: false,
forwarded_from: None,
};

let result = self.handle_write_internal(ctx, table_request).await?;
Expand Down
1 change: 1 addition & 0 deletions proxy/src/http/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
timeout: ctx.timeout,
runtime: self.engine_runtimes.read_runtime.clone(),
enable_partition_table_access: true,
forwarded_from: None,
};

match self.handle_sql(context, &ctx.schema, &req.query).await? {
Expand Down
1 change: 1 addition & 0 deletions proxy/src/influxdb/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
timeout: ctx.timeout,
runtime: self.engine_runtimes.write_runtime.clone(),
enable_partition_table_access: false,
forwarded_from: None,
};
let result = self
.handle_write_internal(proxy_context, table_request)
Expand Down
4 changes: 4 additions & 0 deletions proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub mod schema_config_provider;
mod util;
mod write;

pub const FORWARDED_FROM: &str = "forwarded-from";

use std::{
sync::Arc,
time::{Duration, Instant},
Expand Down Expand Up @@ -131,6 +133,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: req_ctx.database.clone(),
table: metric,
req: req.into_request(),
forwarded_from: None,
};
let do_query = |mut client: StorageServiceClient<Channel>,
request: tonic::Request<PrometheusRemoteQueryRequest>,
Expand Down Expand Up @@ -452,4 +455,5 @@ pub struct Context {
pub timeout: Option<Duration>,
pub runtime: Arc<Runtime>,
pub enable_partition_table_access: bool,
pub forwarded_from: Option<String>,
}
7 changes: 6 additions & 1 deletion proxy/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: &str,
sql: &str,
) -> Result<SqlResponse> {
if let Some(resp) = self.maybe_forward_sql_query(schema, sql).await? {
if let Some(resp) = self
.maybe_forward_sql_query(ctx.clone(), schema, sql)
.await?
{
match resp {
ForwardResult::Forwarded(resp) => return Ok(SqlResponse::Forwarded(resp?)),
ForwardResult::Local => (),
Expand Down Expand Up @@ -149,6 +152,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {

async fn maybe_forward_sql_query(
&self,
ctx: Context,
schema: &str,
sql: &str,
) -> Result<Option<ForwardResult<SqlQueryResponse, Error>>> {
Expand All @@ -174,6 +178,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
schema: schema.to_string(),
table: table_name.unwrap(),
req: sql_request.into_request(),
forwarded_from: ctx.forwarded_from,
};
let do_query = |mut client: StorageServiceClient<Channel>,
request: tonic::Request<SqlQueryRequest>,
Expand Down
16 changes: 12 additions & 4 deletions proxy/src/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
let mut futures = Vec::with_capacity(write_requests_to_forward.len() + 1);

// Write to remote.
self.collect_write_to_remote_future(&mut futures, write_requests_to_forward)
self.collect_write_to_remote_future(&mut futures, ctx.clone(), write_requests_to_forward)
.await;

// Write to local.
Expand Down Expand Up @@ -139,7 +139,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
let mut futures = Vec::with_capacity(write_requests_to_forward.len() + 1);

// Write to remote.
self.collect_write_to_remote_future(&mut futures, write_requests_to_forward)
self.collect_write_to_remote_future(&mut futures, ctx.clone(), write_requests_to_forward)
.await;

// Create table.
Expand Down Expand Up @@ -358,12 +358,14 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
async fn collect_write_to_remote_future(
&self,
futures: &mut WriteResponseFutures<'_>,
ctx: Context,
write_request: HashMap<Endpoint, WriteRequest>,
) {
for (endpoint, table_write_request) in write_request {
let forwarder = self.forwarder.clone();
let ctx = ctx.clone();
let write_handle = self.engine_runtimes.io_runtime.spawn(async move {
Self::write_to_remote(forwarder, endpoint, table_write_request).await
Self::write_to_remote(ctx, forwarder, endpoint, table_write_request).await
});

futures.push(write_handle.boxed());
Expand Down Expand Up @@ -408,6 +410,7 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
}

async fn write_to_remote(
ctx: Context,
forwarder: ForwarderRef,
endpoint: Endpoint,
table_write_request: WriteRequest,
Expand All @@ -432,7 +435,12 @@ impl<Q: QueryExecutor + 'static> Proxy<Q> {
};

let forward_result = forwarder
.forward_with_endpoint(endpoint, tonic::Request::new(table_write_request), do_write)
.forward_with_endpoint(
endpoint,
tonic::Request::new(table_write_request),
ctx.forwarded_from,
do_write,
)
.await;
let forward_res = forward_result
.map_err(|e| {
Expand Down
48 changes: 37 additions & 11 deletions server/src/grpc/storage_service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use ceresdbproto::{
use common_util::time::InstantExt;
use futures::{stream, stream::BoxStream, StreamExt};
use http::StatusCode;
use proxy::{Context, Proxy};
use proxy::{Context, Proxy, FORWARDED_FROM};
use query_engine::executor::Executor as QueryExecutor;
use table_engine::engine::EngineRuntimes;

Expand Down Expand Up @@ -138,6 +138,10 @@ impl<Q: QueryExecutor + 'static> StorageService for StorageServiceImpl<Q> {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let stream = Self::stream_sql_query_internal(ctx, proxy, req).await;

Expand All @@ -155,13 +159,17 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<RouteRequest>,
) -> Result<tonic::Response<RouteResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self
.runtimes
Expand All @@ -186,13 +194,17 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<WriteRequest>,
) -> Result<tonic::Response<WriteResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.write_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self.runtimes.write_runtime.spawn(async move {
if req.context.is_none() {
Expand Down Expand Up @@ -226,13 +238,18 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<SqlQueryRequest>,
) -> Result<tonic::Response<SqlQueryResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self
.runtimes
.read_runtime
Expand Down Expand Up @@ -289,13 +306,18 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
&self,
req: tonic::Request<PrometheusQueryRequest>,
) -> Result<tonic::Response<PrometheusQueryResponse>, tonic::Status> {
let req = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.read_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let req = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self.runtimes.read_runtime.spawn(async move {
if req.context.is_none() {
return PrometheusQueryResponse {
Expand Down Expand Up @@ -329,13 +351,17 @@ impl<Q: QueryExecutor + 'static> StorageServiceImpl<Q> {
) -> Result<tonic::Response<WriteResponse>, tonic::Status> {
let mut total_success = 0;

let mut stream = req.into_inner();
let proxy = self.proxy.clone();
let ctx = Context {
runtime: self.runtimes.write_runtime.clone(),
timeout: self.timeout,
enable_partition_table_access: false,
forwarded_from: req
.metadata()
.get(FORWARDED_FROM)
.map(|value| value.to_str().unwrap().to_string()),
};
let mut stream = req.into_inner();
let proxy = self.proxy.clone();

let join_handle = self.runtimes.write_runtime.spawn(async move {
let mut resp = WriteResponse::default();
Expand Down

0 comments on commit 032c448

Please sign in to comment.