Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add strict mode to validate protocol strings #3638

Merged
merged 14 commits into from
Apr 15, 2024
6 changes: 5 additions & 1 deletion src/frontend/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ where

if opts.prom_store.enable {
builder = builder
.with_prom_handler(self.instance.clone(), opts.prom_store.with_metric_engine)
.with_prom_handler(
self.instance.clone(),
opts.prom_store.with_metric_engine,
opts.http.is_strict_mode,
)
.with_prometheus_handler(self.instance.clone());
}

Expand Down
3 changes: 2 additions & 1 deletion src/servers/benches/prom_decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ fn bench_decode_prom_request(c: &mut Criterion) {

let mut request = WriteRequest::default();
let mut prom_request = PromWriteRequest::default();
let strict_mode = true;
c.benchmark_group("decode")
.measurement_time(Duration::from_secs(3))
.bench_function("write_request", |b| {
Expand All @@ -43,7 +44,7 @@ fn bench_decode_prom_request(c: &mut Criterion) {
.bench_function("prom_write_request", |b| {
b.iter(|| {
let data = data.clone();
prom_request.merge(data).unwrap();
prom_request.merge(data, strict_mode).unwrap();
prom_request.as_row_insert_requests();
});
});
Expand Down
36 changes: 28 additions & 8 deletions src/servers/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ pub struct HttpOptions {
pub disable_dashboard: bool,

pub body_limit: ReadableSize,

pub is_strict_mode: bool,
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
}

impl Default for HttpOptions {
Expand All @@ -136,6 +138,7 @@ impl Default for HttpOptions {
timeout: Duration::from_secs(30),
disable_dashboard: false,
body_limit: DEFAULT_BODY_LIMIT,
is_strict_mode: false,
}
}
}
Expand Down Expand Up @@ -502,11 +505,12 @@ impl HttpServerBuilder {
self,
handler: PromStoreProtocolHandlerRef,
prom_store_with_metric_engine: bool,
strict_mode: bool,
) -> Self {
Self {
router: self.router.nest(
&format!("/{HTTP_API_VERSION}/prometheus"),
HttpServer::route_prom(handler, prom_store_with_metric_engine),
HttpServer::route_prom(handler, prom_store_with_metric_engine, strict_mode),
),
..self
}
Expand Down Expand Up @@ -698,15 +702,31 @@ impl HttpServer {
fn route_prom<S>(
prom_handler: PromStoreProtocolHandlerRef,
prom_store_with_metric_engine: bool,
strict_mode: bool,
) -> Router<S> {
let mut router = Router::new().route("/read", routing::post(prom_store::remote_read));
if prom_store_with_metric_engine {
router = router.route("/write", routing::post(prom_store::remote_write));
} else {
router = router.route(
"/write",
routing::post(prom_store::route_write_without_metric_engine),
);
match (prom_store_with_metric_engine, strict_mode) {
(true, true) => {
router = router.route("/write", routing::post(prom_store::remote_write))
}
(true, false) => {
router = router.route(
"/write",
routing::post(prom_store::remote_write_without_strict_mode),
)
}
(false, true) => {
router = router.route(
"/write",
routing::post(prom_store::route_write_without_metric_engine),
)
}
(false, false) => {
router = router.route(
"/write",
routing::post(prom_store::route_write_without_metric_engine_and_strict_mode),
)
}
}
router.with_state(prom_handler)
}
Expand Down
77 changes: 73 additions & 4 deletions src/servers/src/http/prom_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,37 @@ pub async fn route_write_without_metric_engine(
.start_timer();

let is_zstd = content_encoding.contains(VM_ENCODING);
let (request, samples) = decode_remote_write_request(is_zstd, body).await?;
let (request, samples) = decode_remote_write_request(is_zstd, body, true).await?;
// reject if physical table is specified when metric engine is disabled
if params.physical_table.is_some() {
return UnexpectedPhysicalTableSnafu {}.fail();
}

let output = handler.write(request, query_ctx, false).await?;
crate::metrics::PROM_STORE_REMOTE_WRITE_SAMPLES.inc_by(samples as u64);
Ok((
StatusCode::NO_CONTENT,
write_cost_header_map(output.meta.cost),
)
.into_response())
}

/// Same with [remote_write] but won't store data to metric engine.
#[axum_macros::debug_handler]
pub async fn route_write_without_metric_engine_and_strict_mode(
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
State(handler): State<PromStoreProtocolHandlerRef>,
Query(params): Query<RemoteWriteQuery>,
Extension(query_ctx): Extension<QueryContextRef>,
content_encoding: TypedHeader<headers::ContentEncoding>,
RawBody(body): RawBody,
) -> Result<impl IntoResponse> {
let db = params.db.clone().unwrap_or_default();
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_WRITE_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();

let is_zstd = content_encoding.contains(VM_ENCODING);
let (request, samples) = decode_remote_write_request(is_zstd, body, false).await?;
// reject if physical table is specified when metric engine is disabled
if params.physical_table.is_some() {
return UnexpectedPhysicalTableSnafu {}.fail();
Expand Down Expand Up @@ -127,7 +157,8 @@ pub async fn remote_write(
.start_timer();

let is_zstd = content_encoding.contains(VM_ENCODING);
let (request, samples) = decode_remote_write_request_to_row_inserts(is_zstd, body).await?;
let (request, samples) =
decode_remote_write_request_to_row_inserts(is_zstd, body, true).await?;

if let Some(physical_table) = params.physical_table {
let mut new_query_ctx = query_ctx.as_ref().clone();
Expand All @@ -144,6 +175,42 @@ pub async fn remote_write(
.into_response())
}

#[axum_macros::debug_handler]
#[tracing::instrument(
skip_all,
fields(protocol = "prometheus", request_type = "remote_write")
)]
pub async fn remote_write_without_strict_mode(
State(handler): State<PromStoreProtocolHandlerRef>,
Query(params): Query<RemoteWriteQuery>,
Extension(mut query_ctx): Extension<QueryContextRef>,
content_encoding: TypedHeader<headers::ContentEncoding>,
RawBody(body): RawBody,
) -> Result<impl IntoResponse> {
let db = params.db.clone().unwrap_or_default();
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_WRITE_ELAPSED
.with_label_values(&[db.as_str()])
.start_timer();

v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
let is_zstd = content_encoding.contains(VM_ENCODING);
let (request, samples) =
decode_remote_write_request_to_row_inserts(is_zstd, body, false).await?;

if let Some(physical_table) = params.physical_table {
let mut new_query_ctx = query_ctx.as_ref().clone();
new_query_ctx.set_extension(PHYSICAL_TABLE_PARAM, physical_table);
query_ctx = Arc::new(new_query_ctx);
}

let output = handler.write(request, query_ctx, false).await?;
crate::metrics::PROM_STORE_REMOTE_WRITE_SAMPLES.inc_by(samples as u64);
Ok((
StatusCode::NO_CONTENT,
write_cost_header_map(output.meta.cost),
)
.into_response())
}

impl IntoResponse for PromStoreResponse {
fn into_response(self) -> axum::response::Response {
let mut header_map = HeaderMap::new();
Expand Down Expand Up @@ -187,6 +254,7 @@ pub async fn remote_read(
async fn decode_remote_write_request_to_row_inserts(
is_zstd: bool,
body: Body,
strict_mode: bool,
) -> Result<(RowInsertRequests, usize)> {
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_DECODE_ELAPSED.start_timer();
let body = hyper::body::to_bytes(body)
Expand All @@ -201,14 +269,15 @@ async fn decode_remote_write_request_to_row_inserts(

let mut request = PROM_WRITE_REQUEST_POOL.pull(PromWriteRequest::default);
request
.merge(buf)
.merge(buf, strict_mode)
.context(error::DecodePromRemoteRequestSnafu)?;
Ok(request.as_row_insert_requests())
}

async fn decode_remote_write_request(
is_zstd: bool,
body: Body,
strict_mode: bool,
) -> Result<(RowInsertRequests, usize)> {
let _timer = crate::metrics::METRIC_HTTP_PROM_STORE_DECODE_ELAPSED.start_timer();
let body = hyper::body::to_bytes(body)
Expand All @@ -223,7 +292,7 @@ async fn decode_remote_write_request(

let mut request = PromWriteRequest::default();
request
.merge(buf)
.merge(buf, strict_mode)
.context(error::DecodePromRemoteRequestSnafu)?;
Ok(request.as_row_insert_requests())
}
Expand Down
57 changes: 50 additions & 7 deletions src/servers/src/prom_row_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use api::v1::{
use common_query::prelude::{GREPTIME_TIMESTAMP, GREPTIME_VALUE};
use hashbrown::hash_map::Entry;
use hashbrown::HashMap;
use prost::DecodeError;

use crate::proto::PromLabel;
use crate::repeated_field::Clear;
Expand Down Expand Up @@ -118,13 +119,33 @@ impl TableBuilder {
}

/// Adds a set of labels and samples to table builder.
pub(crate) fn add_labels_and_samples(&mut self, labels: &[PromLabel], samples: &[Sample]) {
pub(crate) fn add_labels_and_samples(
&mut self,
labels: &[PromLabel],
samples: &[Sample],
strict_mode: bool,
) -> Result<(), DecodeError> {
let mut row = vec![Value { value_data: None }; self.col_indexes.len()];

for PromLabel { name, value } in labels {
// safety: we expect all labels are UTF-8 encoded strings.
let tag_name = unsafe { String::from_utf8_unchecked(name.to_vec()) };
let tag_value = unsafe { String::from_utf8_unchecked(value.to_vec()) };
let tag_name;
let tag_value;
v0y4g3r marked this conversation as resolved.
Show resolved Hide resolved
let (tag_name, tag_value) = if strict_mode {
tag_name = match String::from_utf8(name.to_vec()) {
Ok(s) => s,
Err(_) => return Err(DecodeError::new("invalid utf-8")),
};
tag_value = match String::from_utf8(value.to_vec()) {
Ok(s) => s,
Err(_) => return Err(DecodeError::new("invalid utf-8")),
};
(tag_name, tag_value)
} else {
tag_name = unsafe { String::from_utf8_unchecked(name.to_vec()) };
tag_value = unsafe { String::from_utf8_unchecked(value.to_vec()) };
(tag_name, tag_value)
};

let tag_value = Some(ValueData::StringValue(tag_value));
let tag_num = self.col_indexes.len();

Expand Down Expand Up @@ -153,7 +174,7 @@ impl TableBuilder {
row[0].value_data = Some(ValueData::TimestampMillisecondValue(sample.timestamp));
row[1].value_data = Some(ValueData::F64Value(sample.value));
self.rows.push(Row { values: row });
return;
return Ok(());
}
for sample in samples {
row[0].value_data = Some(ValueData::TimestampMillisecondValue(sample.timestamp));
Expand All @@ -162,6 +183,8 @@ impl TableBuilder {
values: row.clone(),
});
}

Ok(())
}

/// Converts [TableBuilder] to [RowInsertRequest] and clears buffered data.
Expand All @@ -187,14 +210,17 @@ mod tests {
use api::prom_store::remote::Sample;
use api::v1::value::ValueData;
use api::v1::Value;
use arrow::datatypes::ToByteSlice;
use bytes::Bytes;
use prost::DecodeError;

use crate::prom_row_builder::TableBuilder;
use crate::proto::PromLabel;
#[test]
fn test_table_builder() {
let mut builder = TableBuilder::default();
builder.add_labels_and_samples(
let with_strict_mode = true;
let _ = builder.add_labels_and_samples(
&[
PromLabel {
name: Bytes::from("tag0"),
Expand All @@ -209,9 +235,10 @@ mod tests {
value: 0.0,
timestamp: 0,
}],
with_strict_mode,
);

builder.add_labels_and_samples(
let _ = builder.add_labels_and_samples(
&[
PromLabel {
name: Bytes::from("tag0"),
Expand All @@ -226,6 +253,7 @@ mod tests {
value: 0.1,
timestamp: 1,
}],
with_strict_mode,
);

let request = builder.as_row_insert_request("test".to_string());
Expand Down Expand Up @@ -269,5 +297,20 @@ mod tests {
],
rows[1].values
);

let invalid_utf8_bytes = &[0xFF, 0xFF, 0xFF];

let res = builder.add_labels_and_samples(
&[PromLabel {
name: Bytes::from("tag0"),
value: invalid_utf8_bytes.to_byte_slice().into(),
}],
&[Sample {
value: 0.1,
timestamp: 1,
}],
with_strict_mode,
);
assert_eq!(res, Err(DecodeError::new("invalid utf-8")));
}
}
Loading