Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pgwire):support mix format in extended query mode #7622

Merged
merged 2 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/frontend/src/handler/create_table_as.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,5 @@ pub async fn handle_create_as(
returning: vec![],
};

handle_query(handler_args, insert, false).await
handle_query(handler_args, insert, vec![]).await
}
6 changes: 3 additions & 3 deletions src/frontend/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use futures::{Stream, StreamExt};
use pgwire::pg_response::StatementType::{ABORT, BEGIN, COMMIT, ROLLBACK, START_TRANSACTION};
use pgwire::pg_response::{PgResponse, RowSetResult};
use pgwire::pg_server::BoxedError;
use pgwire::types::Row;
use pgwire::types::{Format, Row};
use risingwave_common::error::{ErrorCode, Result};
use risingwave_sqlparser::ast::*;

Expand Down Expand Up @@ -151,7 +151,7 @@ pub async fn handle(
session: Arc<SessionImpl>,
stmt: Statement,
sql: &str,
format: bool,
formats: Vec<Format>,
) -> Result<RwPgResponse> {
session.clear_cancel_query_flag();
let handler_args = HandlerArgs::new(session, &stmt, sql)?;
Expand Down Expand Up @@ -307,7 +307,7 @@ pub async fn handle(
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_query(handler_args, stmt, format).await,
| Statement::Update { .. } => query::handle_query(handler_args, stmt, formats).await,
Statement::CreateView {
materialized,
name,
Expand Down
12 changes: 8 additions & 4 deletions src/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use futures::StreamExt;
use itertools::Itertools;
use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::{PgResponse, StatementType};
use pgwire::types::Format;
use postgres_types::FromSql;
use risingwave_common::catalog::Schema;
use risingwave_common::error::{ErrorCode, Result, RwError};
Expand Down Expand Up @@ -93,7 +94,7 @@ pub fn gen_batch_query_plan(
pub async fn handle_query(
handler_args: HandlerArgs,
stmt: Statement,
format: bool,
formats: Vec<Format>,
) -> Result<RwPgResponse> {
let stmt_type = to_statement_type(&stmt)?;
let session = handler_args.session.clone();
Expand Down Expand Up @@ -133,6 +134,9 @@ pub async fn handle_query(
.map(|f| f.data_type())
.collect_vec();

// Used in counting row count.
let first_field_format = formats.first().copied().unwrap_or(Format::Text);

let mut row_stream = {
let query_epoch = session.config().get_query_epoch();
let query_snapshot = if let Some(query_epoch) = query_epoch {
Expand All @@ -149,15 +153,15 @@ pub async fn handle_query(
QueryMode::Local => PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new(
local_execute(session.clone(), query, query_snapshot).await?,
column_types,
format,
formats,
session.clone(),
)),
// Local mode do not support cancel tasks.
QueryMode::Distributed => {
PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new(
distribute_execute(session.clone(), query, query_snapshot).await?,
column_types,
format,
formats,
session.clone(),
))
}
Expand Down Expand Up @@ -185,7 +189,7 @@ pub async fn handle_query(
let affected_rows_str = first_row_set[0].values()[0]
.as_ref()
.expect("compute node should return affected rows in output");
if format {
if let Format::Binary = first_field_format {
Some(
i64::from_sql(&postgres_types::Type::INT8, affected_rows_str)
.unwrap()
Expand Down
120 changes: 93 additions & 27 deletions src/frontend/src/handler/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ use itertools::Itertools;
use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::RowSetResult;
use pgwire::pg_server::BoxedError;
use pgwire::types::Row;
use pgwire::types::{Format, FormatIterator, Row};
use pin_project_lite::pin_project;
use risingwave_common::array::DataChunk;
use risingwave_common::catalog::{ColumnDesc, Field};
use risingwave_common::error::Result as RwResult;
use risingwave_common::error::{ErrorCode, Result as RwResult};
use risingwave_common::row::Row as _;
use risingwave_common::types::{DataType, ScalarRefImpl};
use risingwave_expr::vector_op::timestamptz::timestamptz_to_string;
Expand All @@ -47,7 +47,7 @@ pin_project! {
#[pin]
chunk_stream: VS,
column_types: Vec<DataType>,
format: bool,
formats: Vec<Format>,
session_data: StaticSessionData,
}
}
Expand All @@ -64,7 +64,7 @@ where
pub fn new(
chunk_stream: VS,
column_types: Vec<DataType>,
format: bool,
formats: Vec<Format>,
session: Arc<SessionImpl>,
) -> Self {
let session_data = StaticSessionData {
Expand All @@ -73,7 +73,7 @@ where
Self {
chunk_stream,
column_types,
format,
formats,
session_data,
}
}
Expand All @@ -92,7 +92,7 @@ where
Poll::Ready(chunk) => match chunk {
Some(chunk_result) => match chunk_result {
Ok(chunk) => Poll::Ready(Some(
to_pg_rows(this.column_types, chunk, *this.format, this.session_data)
to_pg_rows(this.column_types, chunk, this.formats, this.session_data)
.map_err(|err| err.into()),
)),
Err(err) => Poll::Ready(Some(Err(err))),
Expand All @@ -107,19 +107,20 @@ where
fn pg_value_format(
data_type: &DataType,
d: ScalarRefImpl<'_>,
format: bool,
format: Format,
session_data: &StaticSessionData,
) -> RwResult<Bytes> {
// format == false means TEXT format
// format == true means BINARY format
if !format {
if *data_type == DataType::Timestamptz {
Ok(timestamptz_to_string_with_session_data(d, session_data))
} else {
Ok(d.text_format(data_type).into())
match format {
Format::Text => {
if *data_type == DataType::Timestamptz {
Ok(timestamptz_to_string_with_session_data(d, session_data))
} else {
Ok(d.text_format(data_type).into())
}
}
} else {
d.binary_format(data_type)
Format::Binary => d.binary_format(data_type),
}
}

Expand All @@ -140,16 +141,21 @@ fn timestamptz_to_string_with_session_data(
fn to_pg_rows(
column_types: &[DataType],
chunk: DataChunk,
format: bool,
formats: &[Format],
session_data: &StaticSessionData,
) -> RwResult<Vec<Row>> {
assert_eq!(chunk.dimension(), column_types.len());

chunk
.rows()
.map(|r| {
let format_iter = FormatIterator::new(formats, chunk.dimension())
.map_err(ErrorCode::InternalError)?;
let row = r
.iter()
.zip_eq(column_types)
.map(|(data, t)| match data {
.zip_eq(format_iter)
.map(|((data, t), format)| match data {
Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(),
None => Ok(None),
})
Expand Down Expand Up @@ -190,6 +196,8 @@ pub fn to_pg_field(f: &Field) -> PgFieldDescriptor {

#[cfg(test)]
mod tests {
use bytes::BytesMut;
use postgres_types::{ToSql, Type};
use risingwave_common::array::*;

use super::*;
Expand Down Expand Up @@ -222,7 +230,7 @@ mod tests {
DataType::Varchar,
],
chunk,
false,
&[],
&static_session,
);
let expected: Vec<Vec<Option<Bytes>>> = vec![
Expand Down Expand Up @@ -250,6 +258,50 @@ mod tests {
assert_eq!(vec, expected);
}

#[test]
fn test_to_pg_rows_mix_format() {
let chunk = DataChunk::from_pretty(
"i I f T
1 6 6.01 aaa
",
);
let static_session = StaticSessionData {
timezone: "UTC".into(),
};
let rows = to_pg_rows(
&[
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Varchar,
],
chunk,
&[Format::Binary, Format::Binary, Format::Binary, Format::Text],
&static_session,
);
let mut raw_params = vec![BytesMut::new(); 3];
1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap();
6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap();
6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap();
let raw_params = raw_params
.into_iter()
.map(|b| b.freeze())
.collect::<Vec<_>>();
let expected: Vec<Vec<Option<Bytes>>> = vec![vec![
Some(raw_params[0].clone()),
Some(raw_params[1].clone()),
Some(raw_params[2].clone()),
Some("aaa".into()),
]];
let vec = rows
.unwrap()
.into_iter()
.map(|r| r.values().iter().cloned().collect_vec())
.collect_vec();

assert_eq!(vec, expected);
}

#[test]
fn test_value_format() {
use {DataType as T, ScalarRefImpl as S};
Expand All @@ -258,29 +310,43 @@ mod tests {
};

let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap();
assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), false), "1");
assert_eq!(&f(&T::Float32, S::Float32(f32::NAN.into()), false), "NaN");
assert_eq!(&f(&T::Float64, S::Float64(f64::NAN.into()), false), "NaN");
assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1");
assert_eq!(
&f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text),
"NaN"
);
assert_eq!(
&f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text),
"NaN"
);
assert_eq!(
&f(&T::Float32, S::Float32(f32::INFINITY.into()), false),
&f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text),
"Infinity"
);
assert_eq!(
&f(&T::Float32, S::Float32(f32::NEG_INFINITY.into()), false),
&f(
&T::Float32,
S::Float32(f32::NEG_INFINITY.into()),
Format::Text
),
"-Infinity"
);
assert_eq!(
&f(&T::Float64, S::Float64(f64::INFINITY.into()), false),
&f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text),
"Infinity"
);
assert_eq!(
&f(&T::Float64, S::Float64(f64::NEG_INFINITY.into()), false),
&f(
&T::Float64,
S::Float64(f64::NEG_INFINITY.into()),
Format::Text
),
"-Infinity"
);
assert_eq!(&f(&T::Boolean, S::Bool(true), false), "t");
assert_eq!(&f(&T::Boolean, S::Bool(false), false), "f");
assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t");
assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f");
assert_eq!(
&f(&T::Timestamptz, S::Int64(-1), false),
&f(&T::Timestamptz, S::Int64(-1), Format::Text),
"1969-12-31 23:59:59.999999+00:00"
);
}
Expand Down
13 changes: 5 additions & 8 deletions src/frontend/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use parking_lot::{RwLock, RwLockReadGuard};
use pgwire::pg_field_descriptor::PgFieldDescriptor;
use pgwire::pg_response::PgResponse;
use pgwire::pg_server::{BoxedError, Session, SessionId, SessionManager, UserAuthenticator};
use pgwire::types::Format;
use rand::RngCore;
use risingwave_common::array::DataChunk;
use risingwave_common::catalog::DEFAULT_SCHEMA_NAME;
Expand Down Expand Up @@ -653,11 +654,7 @@ impl Session<PgResponseStream> for SessionImpl {
async fn run_statement(
self: Arc<Self>,
sql: &str,

// format: indicate the query PgResponse format (Only meaningful for SELECT queries).
// false: TEXT
// true: BINARY
format: bool,
formats: Vec<Format>,
) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
// Parse sql.
let mut stmts = Parser::parse_sql(sql)
Expand All @@ -675,7 +672,7 @@ impl Session<PgResponseStream> for SessionImpl {
}
let stmt = stmts.swap_remove(0);
let rsp = {
let mut handle_fut = Box::pin(handle(self, stmt, sql, format));
let mut handle_fut = Box::pin(handle(self, stmt, sql, formats));
if cfg!(debug_assertions) {
// Report the SQL in the log periodically if the query is slow.
const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60);
Expand All @@ -702,11 +699,11 @@ impl Session<PgResponseStream> for SessionImpl {
async fn run_one_query(
self: Arc<Self>,
stmt: Statement,
format: bool,
format: Format,
) -> std::result::Result<PgResponse<PgResponseStream>, BoxedError> {
let sql_str = stmt.to_string();
let rsp = {
let mut handle_fut = Box::pin(handle(self, stmt, &sql_str, format));
let mut handle_fut = Box::pin(handle(self, stmt, &sql_str, vec![format]));
if cfg!(debug_assertions) {
// Report the SQL in the log periodically if the query is slow.
const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60);
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl LocalFrontend {
sql: impl Into<String>,
) -> std::result::Result<RwPgResponse, Box<dyn std::error::Error + Send + Sync>> {
let sql = sql.into();
self.session_ref().run_statement(sql.as_str(), false).await
self.session_ref().run_statement(sql.as_str(), vec![]).await
}

pub async fn run_user_sql(
Expand All @@ -103,7 +103,7 @@ impl LocalFrontend {
) -> std::result::Result<RwPgResponse, Box<dyn std::error::Error + Send + Sync>> {
let sql = sql.into();
self.session_user_ref(database, user_name, user_id)
.run_statement(sql.as_str(), false)
.run_statement(sql.as_str(), vec![])
.await
}

Expand Down
2 changes: 1 addition & 1 deletion src/tests/sqlsmith/tests/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub struct SqlsmithEnv {
/// Skip status is required, so that we know if a SQL statement writing to the database was skipped.
/// Then, we can infer the correct state of the database.
async fn handle(session: Arc<SessionImpl>, stmt: Statement, sql: &str) -> Result<bool> {
let result = handler::handle(session.clone(), stmt, sql, false)
let result = handler::handle(session.clone(), stmt, sql, vec![])
.await
.map(|_| ())
.map_err(|e| format!("Error Reason:\n{}", e).into());
Expand Down
1 change: 1 addition & 0 deletions src/utils/pgwire/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#![feature(lint_reasons, once_cell)]
#![feature(trait_alias)]
#![feature(result_option_inspect)]
#![feature(iterator_try_collect)]
#![expect(clippy::doc_markdown, reason = "FIXME: later")]

pub mod error;
Expand Down
Loading