feat: add an option to turn on compression for arrow output (#4730)
* feat: add an option to turn on compression for arrow output

* fix: typo
sunng87 authored Sep 19, 2024 · 1 parent d1e0602 · commit 08bd403
Showing 5 changed files with 100 additions and 9 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
```diff
@@ -90,7 +90,7 @@
 aquamarine = "0.3"
 arrow = { version = "51.0.0", features = ["prettyprint"] }
 arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] }
 arrow-flight = "51.0"
-arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4"] }
+arrow-ipc = { version = "51.0.0", default-features = false, features = ["lz4", "zstd"] }
 arrow-schema = { version = "51.0", features = ["serde"] }
 async-stream = "0.3"
 async-trait = "0.1"
```
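For context, arrow-ipc feature-gates each IPC body-compression codec: with a feature disabled, the codec name still type-checks, but compressing a record batch fails at write time. A minimal sketch of the writer options the new `zstd` feature unlocks, assuming arrow-ipc 51 as pinned above (this snippet is not part of the commit):

```rust
use arrow_ipc::writer::IpcWriteOptions;
use arrow_ipc::CompressionType;

fn main() {
    // Valid for IPC metadata V5 (the default); actually compressing batch
    // bodies with ZSTD requires arrow-ipc's "zstd" cargo feature.
    let _options = IpcWriteOptions::default()
        .try_with_compression(Some(CompressionType::ZSTD))
        .expect("ZSTD is a supported IPC compression type");
}
```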
2 changes: 1 addition & 1 deletion src/servers/src/http.rs
```diff
@@ -1136,7 +1136,7 @@ mod test {
             RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
         let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
         let json_resp = match format {
-            ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
+            ResponseFormat::Arrow => ArrowResponse::from_output(outputs, None).await,
             ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
             ResponseFormat::Table => TableResponse::from_output(outputs).await,
             ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
```
95 changes: 90 additions & 5 deletions src/servers/src/http/arrow_result.rs
```diff
@@ -16,7 +16,8 @@ use std::pin::Pin;
 use std::sync::Arc;
 
 use arrow::datatypes::Schema;
-use arrow_ipc::writer::FileWriter;
+use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
+use arrow_ipc::CompressionType;
 use axum::http::{header, HeaderValue};
 use axum::response::{IntoResponse, Response};
 use common_error::status_code::StatusCode;
@@ -41,10 +42,15 @@ pub struct ArrowResponse {
 async fn write_arrow_bytes(
     mut recordbatches: Pin<Box<dyn RecordBatchStream + Send>>,
     schema: &Arc<Schema>,
+    compression: Option<CompressionType>,
 ) -> Result<Vec<u8>, Error> {
     let mut bytes = Vec::new();
     {
-        let mut writer = FileWriter::try_new(&mut bytes, schema).context(error::ArrowSnafu)?;
+        let options = IpcWriteOptions::default()
+            .try_with_compression(compression)
+            .context(error::ArrowSnafu)?;
+        let mut writer = FileWriter::try_new_with_options(&mut bytes, schema, options)
+            .context(error::ArrowSnafu)?;
 
         while let Some(rb) = recordbatches.next().await {
             let rb = rb.context(error::CollectRecordbatchSnafu)?;
@@ -59,15 +65,31 @@ async fn write_arrow_bytes(
     Ok(bytes)
 }
 
+fn compression_type(compression: Option<String>) -> Option<CompressionType> {
+    match compression
+        .map(|compression| compression.to_lowercase())
+        .as_deref()
+    {
+        Some("zstd") => Some(CompressionType::ZSTD),
+        Some("lz4") => Some(CompressionType::LZ4_FRAME),
+        _ => None,
+    }
+}
+
 impl ArrowResponse {
-    pub async fn from_output(mut outputs: Vec<error::Result<Output>>) -> HttpResponse {
+    pub async fn from_output(
+        mut outputs: Vec<error::Result<Output>>,
+        compression: Option<String>,
+    ) -> HttpResponse {
         if outputs.len() > 1 {
             return HttpResponse::Error(ErrorResponse::from_error_message(
                 StatusCode::InvalidArguments,
                 "cannot output multi-statements result in arrow format".to_string(),
             ));
         }
 
+        let compression = compression_type(compression);
+
         match outputs.pop() {
             None => HttpResponse::Arrow(ArrowResponse {
                 data: vec![],
@@ -80,7 +102,9 @@ impl ArrowResponse {
                 }),
                 OutputData::RecordBatches(batches) => {
                     let schema = batches.schema();
-                    match write_arrow_bytes(batches.as_stream(), schema.arrow_schema()).await {
+                    match write_arrow_bytes(batches.as_stream(), schema.arrow_schema(), compression)
+                        .await
+                    {
                         Ok(payload) => HttpResponse::Arrow(ArrowResponse {
                             data: payload,
                             execution_time_ms: 0,
@@ -90,7 +114,7 @@
                 }
                 OutputData::Stream(batches) => {
                     let schema = batches.schema();
-                    match write_arrow_bytes(batches, schema.arrow_schema()).await {
+                    match write_arrow_bytes(batches, schema.arrow_schema(), compression).await {
                         Ok(payload) => HttpResponse::Arrow(ArrowResponse {
                             data: payload,
                             execution_time_ms: 0,
@@ -136,3 +160,64 @@ impl IntoResponse for ArrowResponse {
             .into_response()
     }
 }
+
+#[cfg(test)]
+mod test {
+    use std::io::Cursor;
+
+    use arrow_ipc::reader::FileReader;
+    use arrow_schema::DataType;
+    use common_recordbatch::{RecordBatch, RecordBatches};
+    use datatypes::prelude::*;
+    use datatypes::schema::{ColumnSchema, Schema};
+    use datatypes::vectors::{StringVector, UInt32Vector};
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_arrow_output() {
+        let column_schemas = vec![
+            ColumnSchema::new("numbers", ConcreteDataType::uint32_datatype(), false),
+            ColumnSchema::new("strings", ConcreteDataType::string_datatype(), true),
+        ];
+        let schema = Arc::new(Schema::new(column_schemas));
+        let columns: Vec<VectorRef> = vec![
+            Arc::new(UInt32Vector::from_slice(vec![1, 2, 3, 4])),
+            Arc::new(StringVector::from(vec![
+                None,
+                Some("hello"),
+                Some("greptime"),
+                None,
+            ])),
+        ];
+
+        for compression in [None, Some("zstd".to_string()), Some("lz4".to_string())].into_iter() {
+            let recordbatch = RecordBatch::new(schema.clone(), columns.clone()).unwrap();
+            let recordbatches =
+                RecordBatches::try_new(schema.clone(), vec![recordbatch.clone()]).unwrap();
+            let outputs = vec![Ok(Output::new_with_record_batches(recordbatches))];
+
+            let http_resp = ArrowResponse::from_output(outputs, compression).await;
+            match http_resp {
+                HttpResponse::Arrow(resp) => {
+                    let output = resp.data;
+                    let mut reader =
+                        FileReader::try_new(Cursor::new(output), None).expect("Arrow reader error");
+                    let schema = reader.schema();
+                    assert_eq!(schema.fields[0].name(), "numbers");
+                    assert_eq!(schema.fields[0].data_type(), &DataType::UInt32);
+                    assert_eq!(schema.fields[1].name(), "strings");
+                    assert_eq!(schema.fields[1].data_type(), &DataType::Utf8);
+
+                    let rb = reader.next().unwrap().expect("read record batch failed");
+                    assert_eq!(rb.num_columns(), 2);
+                    assert_eq!(rb.num_rows(), 4);
+                }
+                HttpResponse::Error(e) => {
+                    panic!("unexpected {:?}", e);
+                }
+                _ => unreachable!(),
+            }
+        }
+    }
+}
```
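Note that the test above reads every payload with a plain `FileReader` and never tells the reader which codec was used: the Arrow IPC format records the compression in each message's metadata, so decompression is transparent on the read side. Here is a self-contained round trip using only stock arrow types (a sketch assuming the arrow 51 crates and `lz4`/`zstd` features pinned in Cargo.toml above; not part of this commit):

```rust
use std::io::Cursor;
use std::sync::Arc;

use arrow_array::{Int32Array, RecordBatch};
use arrow_ipc::reader::FileReader;
use arrow_ipc::writer::{FileWriter, IpcWriteOptions};
use arrow_ipc::CompressionType;
use arrow_schema::{DataType, Field, Schema};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )?;

    // Write an IPC file whose record batch bodies are ZSTD-compressed.
    let mut bytes = Vec::new();
    {
        let options =
            IpcWriteOptions::default().try_with_compression(Some(CompressionType::ZSTD))?;
        let mut writer = FileWriter::try_new_with_options(&mut bytes, &schema, options)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // The reader takes no codec hint; it decompresses based on the IPC metadata.
    let mut reader = FileReader::try_new(Cursor::new(bytes), None)?;
    let decoded = reader.next().expect("one batch")?;
    assert_eq!(decoded.num_rows(), 3);
    Ok(())
}
```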
9 changes: 7 additions & 2 deletions src/servers/src/http/handler.rs
```diff
@@ -51,7 +51,8 @@ use crate::query_handler::sql::ServerSqlQueryHandlerRef;
 pub struct SqlQuery {
     pub db: Option<String>,
     pub sql: Option<String>,
-    // (Optional) result format: [`greptimedb_v1`, `influxdb_v1`, `csv`],
+    // (Optional) result format: [`greptimedb_v1`, `influxdb_v1`, `csv`,
+    // `arrow`],
     // the default value is `greptimedb_v1`
     pub format: Option<String>,
     // Returns epoch timestamps with the specified precision.
@@ -64,6 +65,8 @@ pub struct SqlQuery {
     // param too.
     pub epoch: Option<String>,
     pub limit: Option<usize>,
+    // For arrow output
+    pub compression: Option<String>,
 }
 
 /// Handler to execute sql
@@ -128,7 +131,9 @@ pub async fn sql(
     };
 
     let mut resp = match format {
-        ResponseFormat::Arrow => ArrowResponse::from_output(outputs).await,
+        ResponseFormat::Arrow => {
+            ArrowResponse::from_output(outputs, query_params.compression).await
+        }
         ResponseFormat::Csv => CsvResponse::from_output(outputs).await,
         ResponseFormat::Table => TableResponse::from_output(outputs).await,
         ResponseFormat::GreptimedbV1 => GreptimedbV1Response::from_output(outputs).await,
```
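With the parameter threaded through, a client opts in per request, e.g. `GET /v1/sql?sql=SELECT%201&format=arrow&compression=zstd` (assuming GreptimeDB's usual `/v1/sql` HTTP SQL endpoint; the route itself is outside this diff). Formats other than `arrow` simply never see `compression`. Since `compression_type` lower-cases its input and defaults to `None`, the accepted values are easy to pin down; a hypothetical unit test that would sit next to the private helper in `arrow_result.rs` (not part of this commit):

```rust
use arrow_ipc::CompressionType;

// Hypothetical test inside src/servers/src/http/arrow_result.rs, where the
// private `compression_type` helper is in scope.
#[test]
fn test_compression_type_mapping() {
    // Matching is case-insensitive.
    assert_eq!(
        compression_type(Some("ZSTD".to_string())),
        Some(CompressionType::ZSTD)
    );
    assert_eq!(
        compression_type(Some("lz4".to_string())),
        Some(CompressionType::LZ4_FRAME)
    );
    // Unknown codecs and an omitted parameter fall back to uncompressed output.
    assert_eq!(compression_type(Some("snappy".to_string())), None);
    assert_eq!(compression_type(None), None);
}
```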
