From bc8fbc1e478aa9838364d61cf5bb58641b27c35f Mon Sep 17 00:00:00 2001 From: Zejiong Dong <810918843@qq.com> Date: Wed, 1 Feb 2023 20:24:30 +0800 Subject: [PATCH] support mix format in extended query mode --- src/frontend/src/handler/create_table_as.rs | 2 +- src/frontend/src/handler/mod.rs | 6 +- src/frontend/src/handler/query.rs | 12 +- src/frontend/src/handler/util.rs | 128 ++++++-- src/frontend/src/session.rs | 13 +- src/frontend/src/test_utils.rs | 4 +- src/tests/sqlsmith/tests/frontend/mod.rs | 2 +- src/utils/pgwire/src/lib.rs | 1 + src/utils/pgwire/src/pg_extended.rs | 323 +++++++++++--------- src/utils/pgwire/src/pg_message.rs | 55 +--- src/utils/pgwire/src/pg_protocol.rs | 18 +- src/utils/pgwire/src/pg_server.rs | 14 +- src/utils/pgwire/src/types.rs | 82 +++++ 13 files changed, 419 insertions(+), 241 deletions(-) diff --git a/src/frontend/src/handler/create_table_as.rs b/src/frontend/src/handler/create_table_as.rs index d9ee2aaf8df1..fb53d539b34b 100644 --- a/src/frontend/src/handler/create_table_as.rs +++ b/src/frontend/src/handler/create_table_as.rs @@ -120,5 +120,5 @@ pub async fn handle_create_as( returning: vec![], }; - handle_query(handler_args, insert, false).await + handle_query(handler_args, insert, vec![]).await } diff --git a/src/frontend/src/handler/mod.rs b/src/frontend/src/handler/mod.rs index 717818cb2a8b..7644abe6e008 100644 --- a/src/frontend/src/handler/mod.rs +++ b/src/frontend/src/handler/mod.rs @@ -21,7 +21,7 @@ use futures::{Stream, StreamExt}; use pgwire::pg_response::StatementType::{ABORT, BEGIN, COMMIT, ROLLBACK, START_TRANSACTION}; use pgwire::pg_response::{PgResponse, RowSetResult}; use pgwire::pg_server::BoxedError; -use pgwire::types::Row; +use pgwire::types::{Format, Row}; use risingwave_common::error::{ErrorCode, Result}; use risingwave_sqlparser::ast::*; @@ -151,7 +151,7 @@ pub async fn handle( session: Arc, stmt: Statement, sql: &str, - format: bool, + formats: Vec, ) -> Result { session.clear_cancel_query_flag(); let handler_args = HandlerArgs::new(session, &stmt, sql)?; @@ -307,7 +307,7 @@ pub async fn handle( Statement::Query(_) | Statement::Insert { .. } | Statement::Delete { .. } - | Statement::Update { .. } => query::handle_query(handler_args, stmt, format).await, + | Statement::Update { .. } => query::handle_query(handler_args, stmt, formats).await, Statement::CreateView { materialized, name, diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs index 832681ea6bf6..452e3998e7ee 100644 --- a/src/frontend/src/handler/query.rs +++ b/src/frontend/src/handler/query.rs @@ -19,6 +19,7 @@ use futures::StreamExt; use itertools::Itertools; use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_response::{PgResponse, StatementType}; +use pgwire::types::Format; use postgres_types::FromSql; use risingwave_common::catalog::Schema; use risingwave_common::error::{ErrorCode, Result, RwError}; @@ -93,7 +94,7 @@ pub fn gen_batch_query_plan( pub async fn handle_query( handler_args: HandlerArgs, stmt: Statement, - format: bool, + formats: Vec, ) -> Result { let stmt_type = to_statement_type(&stmt)?; let session = handler_args.session.clone(); @@ -133,6 +134,9 @@ pub async fn handle_query( .map(|f| f.data_type()) .collect_vec(); + // Used in counting row count. + let first_field_format = formats.first().copied().unwrap_or(Format::Text); + let mut row_stream = { let query_epoch = session.config().get_query_epoch(); let query_snapshot = if let Some(query_epoch) = query_epoch { @@ -149,7 +153,7 @@ pub async fn handle_query( QueryMode::Local => PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new( local_execute(session.clone(), query, query_snapshot).await?, column_types, - format, + formats, session.clone(), )), // Local mode do not support cancel tasks. @@ -157,7 +161,7 @@ pub async fn handle_query( PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new( distribute_execute(session.clone(), query, query_snapshot).await?, column_types, - format, + formats, session.clone(), )) } @@ -179,7 +183,7 @@ pub async fn handle_query( let affected_rows_str = first_row_set[0].values()[0] .as_ref() .expect("compute node should return affected rows in output"); - if format { + if let Format::Binary = first_field_format { Some( i64::from_sql(&postgres_types::Type::INT8, affected_rows_str) .unwrap() diff --git a/src/frontend/src/handler/util.rs b/src/frontend/src/handler/util.rs index fc7b3ba7d905..5fc3657dd18b 100644 --- a/src/frontend/src/handler/util.rs +++ b/src/frontend/src/handler/util.rs @@ -22,11 +22,11 @@ use itertools::Itertools; use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_response::RowSetResult; use pgwire::pg_server::BoxedError; -use pgwire::types::Row; +use pgwire::types::{Format, FormatIterator, Row}; use pin_project_lite::pin_project; use risingwave_common::array::DataChunk; use risingwave_common::catalog::{ColumnDesc, Field}; -use risingwave_common::error::Result as RwResult; +use risingwave_common::error::{ErrorCode, Result as RwResult}; use risingwave_common::row::Row as _; use risingwave_common::types::{DataType, ScalarRefImpl}; use risingwave_expr::vector_op::timestamptz::timestamptz_to_string; @@ -47,7 +47,7 @@ pin_project! { #[pin] chunk_stream: VS, column_types: Vec, - format: bool, + formats: Vec, session_data: StaticSessionData, } } @@ -64,7 +64,7 @@ where pub fn new( chunk_stream: VS, column_types: Vec, - format: bool, + formats: Vec, session: Arc, ) -> Self { let session_data = StaticSessionData { @@ -73,7 +73,7 @@ where Self { chunk_stream, column_types, - format, + formats, session_data, } } @@ -92,7 +92,7 @@ where Poll::Ready(chunk) => match chunk { Some(chunk_result) => match chunk_result { Ok(chunk) => Poll::Ready(Some( - to_pg_rows(this.column_types, chunk, *this.format, this.session_data) + to_pg_rows(this.column_types, chunk, this.formats, this.session_data) .map_err(|err| err.into()), )), Err(err) => Poll::Ready(Some(Err(err))), @@ -107,19 +107,20 @@ where fn pg_value_format( data_type: &DataType, d: ScalarRefImpl<'_>, - format: bool, + format: Format, session_data: &StaticSessionData, ) -> RwResult { // format == false means TEXT format // format == true means BINARY format - if !format { - if *data_type == DataType::Timestamptz { - Ok(timestamptz_to_string_with_session_data(d, session_data)) - } else { - Ok(d.text_format(data_type).into()) + match format { + Format::Text => { + if *data_type == DataType::Timestamptz { + Ok(timestamptz_to_string_with_session_data(d, session_data)) + } else { + Ok(d.text_format(data_type).into()) + } } - } else { - d.binary_format(data_type) + Format::Binary => d.binary_format(data_type), } } @@ -140,16 +141,29 @@ fn timestamptz_to_string_with_session_data( fn to_pg_rows( column_types: &[DataType], chunk: DataChunk, - format: bool, + formats: &[Format], session_data: &StaticSessionData, ) -> RwResult> { + // Invariant check + if !formats.is_empty() && formats.len() != 1 && formats.len() != column_types.len() { + return Err(ErrorCode::InternalError(format!( + "format codes length {} is not 0, 1 or equal to column length {}", + formats.len(), + column_types.len() + )) + .into()); + } + chunk .rows() .map(|r| { + let format_iter = FormatIterator::new(formats, chunk.dimension()) + .map_err(ErrorCode::InternalError)?; let row = r .iter() .zip_eq(column_types) - .map(|(data, t)| match data { + .zip_eq(format_iter) + .map(|((data, t), format)| match data { Some(data) => Some(pg_value_format(t, data, format, session_data)).transpose(), None => Ok(None), }) @@ -190,6 +204,8 @@ pub fn to_pg_field(f: &Field) -> PgFieldDescriptor { #[cfg(test)] mod tests { + use bytes::BytesMut; + use postgres_types::{ToSql, Type}; use risingwave_common::array::*; use super::*; @@ -222,7 +238,7 @@ mod tests { DataType::Varchar, ], chunk, - false, + &[], &static_session, ); let expected: Vec>> = vec![ @@ -250,6 +266,50 @@ mod tests { assert_eq!(vec, expected); } + #[test] + fn test_to_pg_rows_mix_format() { + let chunk = DataChunk::from_pretty( + "i I f T + 1 6 6.01 aaa + ", + ); + let static_session = StaticSessionData { + timezone: "UTC".into(), + }; + let rows = to_pg_rows( + &[ + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Varchar, + ], + chunk, + &[Format::Binary, Format::Binary, Format::Binary, Format::Text], + &static_session, + ); + let mut raw_params = vec![BytesMut::new(); 3]; + 1_i32.to_sql(&Type::ANY, &mut raw_params[0]).unwrap(); + 6_i64.to_sql(&Type::ANY, &mut raw_params[1]).unwrap(); + 6.01_f32.to_sql(&Type::ANY, &mut raw_params[2]).unwrap(); + let raw_params = raw_params + .into_iter() + .map(|b| b.freeze()) + .collect::>(); + let expected: Vec>> = vec![vec![ + Some(raw_params[0].clone()), + Some(raw_params[1].clone()), + Some(raw_params[2].clone()), + Some("aaa".into()), + ]]; + let vec = rows + .unwrap() + .into_iter() + .map(|r| r.values().iter().cloned().collect_vec()) + .collect_vec(); + + assert_eq!(vec, expected); + } + #[test] fn test_value_format() { use {DataType as T, ScalarRefImpl as S}; @@ -258,29 +318,43 @@ mod tests { }; let f = |t, d, f| pg_value_format(t, d, f, &static_session).unwrap(); - assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), false), "1"); - assert_eq!(&f(&T::Float32, S::Float32(f32::NAN.into()), false), "NaN"); - assert_eq!(&f(&T::Float64, S::Float64(f64::NAN.into()), false), "NaN"); + assert_eq!(&f(&T::Float32, S::Float32(1_f32.into()), Format::Text), "1"); + assert_eq!( + &f(&T::Float32, S::Float32(f32::NAN.into()), Format::Text), + "NaN" + ); + assert_eq!( + &f(&T::Float64, S::Float64(f64::NAN.into()), Format::Text), + "NaN" + ); assert_eq!( - &f(&T::Float32, S::Float32(f32::INFINITY.into()), false), + &f(&T::Float32, S::Float32(f32::INFINITY.into()), Format::Text), "Infinity" ); assert_eq!( - &f(&T::Float32, S::Float32(f32::NEG_INFINITY.into()), false), + &f( + &T::Float32, + S::Float32(f32::NEG_INFINITY.into()), + Format::Text + ), "-Infinity" ); assert_eq!( - &f(&T::Float64, S::Float64(f64::INFINITY.into()), false), + &f(&T::Float64, S::Float64(f64::INFINITY.into()), Format::Text), "Infinity" ); assert_eq!( - &f(&T::Float64, S::Float64(f64::NEG_INFINITY.into()), false), + &f( + &T::Float64, + S::Float64(f64::NEG_INFINITY.into()), + Format::Text + ), "-Infinity" ); - assert_eq!(&f(&T::Boolean, S::Bool(true), false), "t"); - assert_eq!(&f(&T::Boolean, S::Bool(false), false), "f"); + assert_eq!(&f(&T::Boolean, S::Bool(true), Format::Text), "t"); + assert_eq!(&f(&T::Boolean, S::Bool(false), Format::Text), "f"); assert_eq!( - &f(&T::Timestamptz, S::Int64(-1), false), + &f(&T::Timestamptz, S::Int64(-1), Format::Text), "1969-12-31 23:59:59.999999+00:00" ); } diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index 748848542db1..17d74c9ef0e3 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -23,6 +23,7 @@ use parking_lot::{RwLock, RwLockReadGuard}; use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_response::PgResponse; use pgwire::pg_server::{BoxedError, Session, SessionId, SessionManager, UserAuthenticator}; +use pgwire::types::Format; use rand::RngCore; use risingwave_common::array::DataChunk; use risingwave_common::catalog::DEFAULT_SCHEMA_NAME; @@ -652,11 +653,7 @@ impl Session for SessionImpl { async fn run_statement( self: Arc, sql: &str, - - // format: indicate the query PgResponse format (Only meaningful for SELECT queries). - // false: TEXT - // true: BINARY - format: bool, + formats: Vec, ) -> std::result::Result, BoxedError> { // Parse sql. let mut stmts = Parser::parse_sql(sql) @@ -674,7 +671,7 @@ impl Session for SessionImpl { } let stmt = stmts.swap_remove(0); let rsp = { - let mut handle_fut = Box::pin(handle(self, stmt, sql, format)); + let mut handle_fut = Box::pin(handle(self, stmt, sql, formats)); if cfg!(debug_assertions) { // Report the SQL in the log periodically if the query is slow. const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60); @@ -701,11 +698,11 @@ impl Session for SessionImpl { async fn run_one_query( self: Arc, stmt: Statement, - format: bool, + format: Format, ) -> std::result::Result, BoxedError> { let sql_str = stmt.to_string(); let rsp = { - let mut handle_fut = Box::pin(handle(self, stmt, &sql_str, format)); + let mut handle_fut = Box::pin(handle(self, stmt, &sql_str, vec![format])); if cfg!(debug_assertions) { // Report the SQL in the log periodically if the query is slow. const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60); diff --git a/src/frontend/src/test_utils.rs b/src/frontend/src/test_utils.rs index 2b5acd929be4..131a7b04a9ae 100644 --- a/src/frontend/src/test_utils.rs +++ b/src/frontend/src/test_utils.rs @@ -91,7 +91,7 @@ impl LocalFrontend { sql: impl Into, ) -> std::result::Result> { let sql = sql.into(); - self.session_ref().run_statement(sql.as_str(), false).await + self.session_ref().run_statement(sql.as_str(), vec![]).await } pub async fn run_user_sql( @@ -103,7 +103,7 @@ impl LocalFrontend { ) -> std::result::Result> { let sql = sql.into(); self.session_user_ref(database, user_name, user_id) - .run_statement(sql.as_str(), false) + .run_statement(sql.as_str(), vec![]) .await } diff --git a/src/tests/sqlsmith/tests/frontend/mod.rs b/src/tests/sqlsmith/tests/frontend/mod.rs index 7997ae22fe2a..2f37b8247b9e 100644 --- a/src/tests/sqlsmith/tests/frontend/mod.rs +++ b/src/tests/sqlsmith/tests/frontend/mod.rs @@ -46,7 +46,7 @@ pub struct SqlsmithEnv { /// Skip status is required, so that we know if a SQL statement writing to the database was skipped. /// Then, we can infer the correct state of the database. async fn handle(session: Arc, stmt: Statement, sql: &str) -> Result { - let result = handler::handle(session.clone(), stmt, sql, false) + let result = handler::handle(session.clone(), stmt, sql, vec![]) .await .map(|_| ()) .map_err(|e| format!("Error Reason:\n{}", e).into()); diff --git a/src/utils/pgwire/src/lib.rs b/src/utils/pgwire/src/lib.rs index a26a7d20d3ee..dad979455d45 100644 --- a/src/utils/pgwire/src/lib.rs +++ b/src/utils/pgwire/src/lib.rs @@ -16,6 +16,7 @@ #![feature(lint_reasons, once_cell)] #![feature(trait_alias)] #![feature(result_option_inspect)] +#![feature(iterator_try_collect)] #![expect(clippy::doc_markdown, reason = "FIXME: later")] pub mod error; diff --git a/src/utils/pgwire/src/pg_extended.rs b/src/utils/pgwire/src/pg_extended.rs index 929ad7c4c4cf..e30d7ef5fa72 100644 --- a/src/utils/pgwire/src/pg_extended.rs +++ b/src/utils/pgwire/src/pg_extended.rs @@ -33,7 +33,7 @@ use crate::pg_message::{BeCommandCompleteMessage, BeMessage}; use crate::pg_protocol::{cstr_to_str, Conn}; use crate::pg_response::{PgResponse, RowSetResult}; use crate::pg_server::{Session, SessionManager}; -use crate::types::Row; +use crate::types::{Format, FormatIterator, Row}; #[derive(Default)] pub struct PgStatement { @@ -78,28 +78,33 @@ impl PgStatement { &self, portal_name: String, params: &[Bytes], - result_format: bool, - param_format: bool, + result_formats: Vec, + param_formats: Vec, ) -> PsqlResult> where VS: Stream + Unpin + Send, { - let instance_query_string = self.prepared_statement.instance(params, param_format)?; + let instance_query_string = self.prepared_statement.instance(params, ¶m_formats)?; - let row_description: Vec = if result_format { + let format_iter = FormatIterator::new(&result_formats, self.row_description.len()) + .map_err(|err| PsqlError::Internal(anyhow!(err)))?; + let row_description: Vec = { let mut row_description = self.row_description.clone(); row_description .iter_mut() - .for_each(|desc| desc.set_to_binary()); + .zip_eq(format_iter) + .for_each(|(desc, format)| { + if let Format::Binary = format { + desc.set_to_binary(); + } + }); row_description - } else { - self.row_description.clone() }; Ok(PgPortal { name: portal_name, query_string: instance_query_string, - result_format, + result_formats, is_query: self.is_query, row_description, result: None, @@ -120,7 +125,7 @@ where { name: String, query_string: String, - result_format: bool, + result_formats: Vec, is_query: bool, row_description: Vec, result: Option>, @@ -156,7 +161,7 @@ where result } else { let result = session - .run_statement(self.query_string.as_str(), self.result_format) + .run_statement(self.query_string.as_str(), self.result_formats.clone()) .await .map_err(|err| PsqlError::ExecuteError(err))?; self.result = Some(result); @@ -364,150 +369,123 @@ impl PreparedStatement { }) } - /// parse_params is to parse raw_params:&[Bytes] into params:[String]. - /// The param produced by this function will be used in the PreparedStatement. - /// - /// type_description is a list of type oids. - /// raw_params is a list of raw params. - /// param_format is format code : false for text, true for binary. - /// - /// # Example - /// - /// ```ignore - /// let raw_params = vec!["A".into(), "B".into(), "C".into()]; - /// let type_description = vec![DataType::Varchar; 3]; - /// let params = parse_params(&type_description, &raw_params,false); - /// assert_eq!(params, vec!["'A'", "'B'", "'C'"]) - /// - /// let raw_params = vec!["1".into(), "2".into(), "3.1".into()]; - /// let type_description = vec![DataType::INT,DataType::INT,DataType::FLOAT4]; - /// let params = parse_params(&type_description, &raw_params,false); - /// assert_eq!(params, vec!["1::INT", "2::INT", "3.1::FLOAT4"]) - /// ``` fn parse_params( type_description: &[DataType], raw_params: &[Bytes], - param_format: bool, + param_formats: &[Format], ) -> PsqlResult> { + // Invariant check if type_description.len() != raw_params.len() { return Err(PsqlError::Internal(anyhow!( "The number of params doesn't match the number of types" ))); } - assert_eq!(type_description.len(), raw_params.len()); + if raw_params.is_empty() { + return Ok(vec![]); + } let mut params = Vec::with_capacity(raw_params.len()); - // BINARY FORMAT PARAMS let place_hodler = Type::ANY; - for (type_oid, raw_param) in zip_eq(type_description.iter(), raw_params.iter()) { + let format_iter = FormatIterator::new(param_formats, raw_params.len()) + .map_err(|err| PsqlError::Internal(anyhow!(err)))?; + + for ((type_oid, raw_param), param_format) in + zip_eq(type_description.iter(), raw_params.iter()).zip_eq(format_iter) + { let str = match type_oid { DataType::Varchar | DataType::Bytea => { format!("'{}'", cstr_to_str(raw_param).unwrap().replace('\'', "''")) } - DataType::Boolean => { - if param_format { - bool::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() - } - } - DataType::Int64 => { - if param_format { - i64::from_sql(&place_hodler, raw_param).unwrap().to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() - } - } - DataType::Int16 => { - if param_format { - i16::from_sql(&place_hodler, raw_param).unwrap().to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() - } - } - DataType::Int32 => { - if param_format { - i32::from_sql(&place_hodler, raw_param).unwrap().to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() - } - } + DataType::Boolean => match param_format { + Format::Binary => bool::from_sql(&place_hodler, raw_param) + .unwrap() + .to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), + }, + DataType::Int64 => match param_format { + Format::Binary => i64::from_sql(&place_hodler, raw_param).unwrap().to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), + }, + DataType::Int16 => match param_format { + Format::Binary => i16::from_sql(&place_hodler, raw_param).unwrap().to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), + }, + DataType::Int32 => match param_format { + Format::Binary => i32::from_sql(&place_hodler, raw_param).unwrap().to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), + }, DataType::Float32 => { - let tmp = if param_format { - f32::from_sql(&place_hodler, raw_param).unwrap().to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() + let tmp = match param_format { + Format::Binary => { + f32::from_sql(&place_hodler, raw_param).unwrap().to_string() + } + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::FLOAT4", tmp) } DataType::Float64 => { - let tmp = if param_format { - f64::from_sql(&place_hodler, raw_param).unwrap().to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() + let tmp = match param_format { + Format::Binary => { + f64::from_sql(&place_hodler, raw_param).unwrap().to_string() + } + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::FLOAT8", tmp) } DataType::Date => { - let tmp = if param_format { - chrono::NaiveDate::from_sql(&place_hodler, raw_param) + let tmp = match param_format { + Format::Binary => chrono::NaiveDate::from_sql(&place_hodler, raw_param) .unwrap() - .to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() + .to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::DATE", tmp) } DataType::Time => { - let tmp = if param_format { - chrono::NaiveTime::from_sql(&place_hodler, raw_param) + let tmp = match param_format { + Format::Binary => chrono::NaiveTime::from_sql(&place_hodler, raw_param) .unwrap() - .to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() + .to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::TIME", tmp) } DataType::Timestamp => { - let tmp = if param_format { - chrono::NaiveDateTime::from_sql(&place_hodler, raw_param) + let tmp = match param_format { + Format::Binary => chrono::NaiveDateTime::from_sql(&place_hodler, raw_param) .unwrap() - .to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() + .to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::TIMESTAMP", tmp) } DataType::Decimal => { - let tmp = if param_format { - rust_decimal::Decimal::from_sql(&place_hodler, raw_param) + let tmp = match param_format { + Format::Binary => rust_decimal::Decimal::from_sql(&place_hodler, raw_param) .unwrap() - .to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() + .to_string(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::DECIMAL", tmp) } DataType::Timestamptz => { - let tmp = if param_format { - chrono::DateTime::::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string() - } else { - cstr_to_str(raw_param).unwrap().to_string() + let tmp = match param_format { + Format::Binary => { + chrono::DateTime::::from_sql(&place_hodler, raw_param) + .unwrap() + .to_string() + } + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::TIMESTAMPTZ", tmp) } DataType::Interval => { - let tmp = if param_format { - pg_interval::Interval::from_sql(&place_hodler, raw_param) + let tmp = match param_format { + Format::Binary => pg_interval::Interval::from_sql(&place_hodler, raw_param) .unwrap() - .to_postgres() - } else { - cstr_to_str(raw_param).unwrap().to_string() + .to_postgres(), + Format::Text => cstr_to_str(raw_param).unwrap().to_string(), }; format!("'{}'::INTERVAL", tmp) } @@ -602,8 +580,8 @@ impl PreparedStatement { Ok(self.replace_params(&default_params)) } - pub fn instance(&self, raw_params: &[Bytes], param_format: bool) -> PsqlResult { - let params = Self::parse_params(&self.param_types, raw_params, param_format)?; + pub fn instance(&self, raw_params: &[Bytes], param_formats: &[Format]) -> PsqlResult { + let params = Self::parse_params(&self.param_types, raw_params, param_formats)?; Ok(self.replace_params(¶ms)) } } @@ -620,6 +598,7 @@ mod tests { use tokio_postgres::types::{ToSql, Type}; use crate::pg_extended::PreparedStatement; + use crate::types::Format; #[test] fn test_prepared_statement_without_param() { @@ -627,7 +606,7 @@ mod tests { let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); let default_sql = prepared_statement.instance_default().unwrap(); assert!("SELECT * FROM test_table" == default_sql); - let sql = prepared_statement.instance(&[], false).unwrap(); + let sql = prepared_statement.instance(&[], &[]).unwrap(); assert!("SELECT * FROM test_table" == sql); } @@ -639,7 +618,7 @@ mod tests { .unwrap(); let default_sql = prepared_statement.instance_default().unwrap(); assert!("SELECT * FROM test_table WHERE id = 0::INT" == default_sql); - let sql = prepared_statement.instance(&["1".into()], false).unwrap(); + let sql = prepared_statement.instance(&["1".into()], &[]).unwrap(); assert!("SELECT * FROM test_table WHERE id = 1" == sql); let raw_statement = "INSERT INTO test (index,data) VALUES ($1,$2)".to_string(); @@ -651,7 +630,7 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("INSERT INTO test (index,data) VALUES (0::INT,'0')" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("INSERT INTO test (index,data) VALUES (1,'DATA')" == sql); @@ -664,7 +643,7 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("UPDATE COFFEES SET SALES = 0::INT WHERE COF_NAME LIKE '0'" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("UPDATE COFFEES SET SALES = 1 WHERE COF_NAME LIKE 'DATA'" == sql); @@ -681,7 +660,7 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("SELECT * FROM test_table WHERE id = 0::INT AND name = '0'" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into(), "NAME".into()], false) + .instance(&["1".into(), "DATA".into(), "NAME".into()], &[]) .unwrap(); assert!("SELECT * FROM test_table WHERE id = 1 AND name = 'NAME'" == sql); } @@ -692,7 +671,7 @@ mod tests { let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); let default_sql = prepared_statement.instance_default().unwrap(); assert!("SELECT * FROM test_table WHERE id = 0::INT" == default_sql); - let sql = prepared_statement.instance(&["1".into()], false).unwrap(); + let sql = prepared_statement.instance(&["1".into()], &[]).unwrap(); assert!("SELECT * FROM test_table WHERE id = 1" == sql); let raw_statement = @@ -701,7 +680,7 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("INSERT INTO test (index,data) VALUES (0::INT,'0')" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("INSERT INTO test (index,data) VALUES (1,'DATA')" == sql); @@ -711,7 +690,7 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("UPDATE COFFEES SET SALES = 0::INT WHERE COF_NAME LIKE '0'" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("UPDATE COFFEES SET SALES = 1 WHERE COF_NAME LIKE 'DATA'" == sql); } @@ -726,7 +705,7 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("SELECT * FROM test_table WHERE id = 0::INT AND name = '0'" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("SELECT * FROM test_table WHERE id = 1 AND name = 'DATA'" == sql); @@ -737,7 +716,7 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("INSERT INTO test (index,data) VALUES (0::INT,'0')" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("INSERT INTO test (index,data) VALUES (1,'DATA')" == sql); @@ -749,14 +728,14 @@ mod tests { let default_sql = prepared_statement.instance_default().unwrap(); assert!("UPDATE COFFEES SET SALES = 0::INT WHERE COF_NAME LIKE '0'" == default_sql); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("UPDATE COFFEES SET SALES = 1 WHERE COF_NAME LIKE 'DATA'" == sql); let raw_statement = "SELECT $1,$2;".to_string(); let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); let sql = prepared_statement - .instance(&["test$2".into(), "test$1".into()], false) + .instance(&["test$2".into(), "test$1".into()], &[]) .unwrap(); assert!("SELECT 'test$2','test$1';" == sql); @@ -765,7 +744,7 @@ mod tests { PreparedStatement::parse_statement(raw_statement, vec![DataType::INT32.to_oid()]) .unwrap(); let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], false) + .instance(&["1".into(), "DATA".into()], &[]) .unwrap(); assert!("SELECT 1,1,'DATA','DATA';" == sql); } @@ -774,20 +753,17 @@ mod tests { fn test_parse_params_text() { let raw_params = vec!["A".into(), "B".into(), "C".into()]; let type_description = vec![DataType::Varchar; 3]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, false).unwrap(); + let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); assert_eq!(params, vec!["'A'", "'B'", "'C'"]); let raw_params = vec!["false".into(), "true".into()]; let type_description = vec![DataType::Boolean; 2]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, false).unwrap(); + let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); assert_eq!(params, vec!["false", "true"]); let raw_params = vec!["1".into(), "2".into(), "3".into()]; let type_description = vec![DataType::Int16, DataType::Int32, DataType::Int64]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, false).unwrap(); + let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); assert_eq!(params, vec!["1", "2", "3"]); let raw_params = vec![ @@ -799,8 +775,7 @@ mod tests { .into(), ]; let type_description = vec![DataType::Float32, DataType::Float64, DataType::Decimal]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, false).unwrap(); + let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); assert_eq!( params, vec!["'1.0'::FLOAT4", "'2.0'::FLOAT8", "'3'::DECIMAL"] @@ -821,8 +796,7 @@ mod tests { .into(), ]; let type_description = vec![DataType::Date, DataType::Time, DataType::Timestamp]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, false).unwrap(); + let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); assert_eq!( params, vec![ @@ -840,7 +814,9 @@ mod tests { // Test VACHAR type. let raw_params = vec!["A".into(), "B".into(), "C".into()]; let type_description = vec![DataType::Varchar; 3]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, true).unwrap(); + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) + .unwrap(); assert_eq!(params, vec!["'A'", "'B'", "'C'"]); // Test BOOLEAN type. @@ -852,7 +828,9 @@ mod tests { .map(|b| b.freeze()) .collect::>(); let type_description = vec![DataType::Boolean; 2]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, true).unwrap(); + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) + .unwrap(); assert_eq!(params, vec!["false", "true"]); // Test SMALLINT, INT, BIGINT type. @@ -865,7 +843,9 @@ mod tests { .map(|b| b.freeze()) .collect::>(); let type_description = vec![DataType::Int16, DataType::Int32, DataType::Int64]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, true).unwrap(); + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) + .unwrap(); assert_eq!(params, vec!["1", "2", "3"]); // Test FLOAT4, FLOAT8, DECIMAL type. @@ -881,7 +861,9 @@ mod tests { .map(|b| b.freeze()) .collect::>(); let type_description = vec![DataType::Float32, DataType::Float64, DataType::Decimal]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, true).unwrap(); + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) + .unwrap(); assert_eq!(params, vec!["'1'::FLOAT4", "'2'::FLOAT8", "'3'::DECIMAL"]); let mut raw_params = vec![BytesMut::new(); 3]; @@ -897,7 +879,9 @@ mod tests { .map(|b| b.freeze()) .collect::>(); let type_description = vec![DataType::Float32, DataType::Float64, DataType::Float64]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, true).unwrap(); + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) + .unwrap(); assert_eq!( params, vec!["'NaN'::FLOAT4", "'inf'::FLOAT8", "'-inf'::FLOAT8"] @@ -922,7 +906,9 @@ mod tests { .map(|b| b.freeze()) .collect::>(); let type_description = vec![DataType::Date, DataType::Time, DataType::Timestamp]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, true).unwrap(); + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) + .unwrap(); assert_eq!( params, vec![ @@ -944,7 +930,9 @@ mod tests { .map(|b| b.freeze()) .collect::>(); let type_description = vec![DataType::Timestamptz, DataType::Interval]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, true).unwrap(); + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) + .unwrap(); assert_eq!( params, vec![ @@ -953,4 +941,67 @@ mod tests { ] ); } + + #[test] + fn test_parse_params_mix_format() { + let place_hodler = Type::ANY; + + // Test VACHAR type. + let raw_params = vec!["A".into(), "B".into(), "C".into()]; + let type_description = vec![DataType::Varchar; 3]; + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Text; 3]) + .unwrap(); + assert_eq!(params, vec!["'A'", "'B'", "'C'"]); + + // Test BOOLEAN type. + let mut raw_params = vec![BytesMut::new(); 2]; + false.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); + true.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); + let raw_params = raw_params + .into_iter() + .map(|b| b.freeze()) + .collect::>(); + let type_description = vec![DataType::Boolean; 2]; + let params = + PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary; 2]) + .unwrap(); + assert_eq!(params, vec!["false", "true"]); + + // Test SMALLINT, INT, BIGINT type. + let mut raw_params = vec![BytesMut::new(); 2]; + 1_i16.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); + 2_i32.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); + let mut raw_params = raw_params + .into_iter() + .map(|b| b.freeze()) + .collect::>(); + raw_params.push("3".into()); + let type_description = vec![DataType::Int16, DataType::Int32, DataType::Int64]; + let params = PreparedStatement::parse_params( + &type_description, + &raw_params, + &[Format::Binary, Format::Binary, Format::Text], + ) + .unwrap(); + assert_eq!(params, vec!["1", "2", "3"]); + + // Test FLOAT4, FLOAT8, DECIMAL type. + let mut raw_params = vec![BytesMut::new(); 2]; + 1.0_f32.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); + 2.0_f64.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); + let mut raw_params = raw_params + .into_iter() + .map(|b| b.freeze()) + .collect::>(); + raw_params.push("TEST".into()); + let type_description = vec![DataType::Float32, DataType::Float64, DataType::VARCHAR]; + let params = PreparedStatement::parse_params( + &type_description, + &raw_params, + &[Format::Binary, Format::Binary, Format::Text], + ) + .unwrap(); + assert_eq!(params, vec!["'1'::FLOAT4", "'2'::FLOAT8", "'TEST'"]); + } } diff --git a/src/utils/pgwire/src/pg_message.rs b/src/utils/pgwire/src/pg_message.rs index 68e351d29804..634d2f861b20 100644 --- a/src/utils/pgwire/src/pg_message.rs +++ b/src/utils/pgwire/src/pg_message.rs @@ -82,15 +82,8 @@ pub struct FeQueryMessage { #[derive(Debug)] pub struct FeBindMessage { - // param_format_code: - // false: text - // true: binary - pub param_format_code: bool, - - // result_format_code: - // false: text - // true: binary - pub result_format_code: bool, + pub param_format_codes: Vec, + pub result_format_codes: Vec, pub params: Vec, pub portal_name: Bytes, @@ -175,22 +168,10 @@ impl FeBindMessage { pub fn parse(mut buf: Bytes) -> Result { let portal_name = read_null_terminated(&mut buf)?; let statement_name = read_null_terminated(&mut buf)?; - // Read FormatCode + let len = buf.get_i16(); + let param_format_codes = (0..len).map(|_| buf.get_i16()).collect(); - let param_format_code = if len == 0 || len == 1 { - if len == 0 { - false - } else { - buf.get_i16() == 1 - } - } else { - let first_value = buf.get_i16(); - for _ in 1..len { - assert!(buf.get_i16() == first_value,"Only support uniform param format (TEXT or BINARY), can't support mix format now."); - } - first_value == 1 - }; // Read Params let len = buf.get_i16(); let params = (0..len) @@ -200,34 +181,12 @@ impl FeBindMessage { }) .collect(); - // Read ResultFormatCode - // result format code depend on following rule: - // - If the length is 0, format is false(text). - // - If the length is 1, format is decide by format_codes[0]. - // - If the length > 1, each column can have their own format and it depend on according - // format code. But RisingWave can't support return col with different format now, when - // length>1, we guarantee all format code is the same (0,0,0..) or (1,1,1,...). let len = buf.get_i16(); - let format_codes = (0..len).map(|_| buf.get_i16()).collect::>(); - let all_elements_are_equal = format_codes.iter().all(|&x| x == format_codes[0]); - - if !all_elements_are_equal { - return Err(Error::new( - ErrorKind::InvalidInput, - "Only support uniform result format (TEXT or BINARY), can't support mix format now.", - )); - } - - let result_format_code = if len == 0 { - // default format:text - false - } else { - format_codes[0] == 1 - }; + let result_format_codes = (0..len).map(|_| buf.get_i16()).collect(); Ok(FeMessage::Bind(FeBindMessage { - param_format_code, - result_format_code, + param_format_codes, + result_format_codes, params, portal_name, statement_name, diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 8acd0e4156ce..f249ade3691d 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -40,6 +40,7 @@ use crate::pg_message::{ }; use crate::pg_response::RowSetResult; use crate::pg_server::{Session, SessionManager, UserAuthenticator}; +use crate::types::Format; /// The state machine for each psql connection. /// Read pg messages from tcp stream and write results back. @@ -335,7 +336,7 @@ where // execute query let mut res = session - .run_one_query(stmt, false) + .run_one_query(stmt, Format::Text) .await .map_err(|err| PsqlError::QueryError(err))?; @@ -459,13 +460,24 @@ where .ok_or_else(PsqlError::no_statement)? }; + let result_formats = msg + .result_format_codes + .iter() + .map(|&format_code| Format::from_i16(format_code)) + .try_collect()?; + let param_formats = msg + .param_format_codes + .iter() + .map(|&format_code| Format::from_i16(format_code)) + .try_collect()?; + // 2. Instance the statement to get the portal. let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_string(); let portal = statement.instance( portal_name.clone(), &msg.params, - msg.result_format_code, - msg.param_format_code, + result_formats, + param_formats, )?; // 3. Insert the Portal. diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index e0f6a0ba3dce..8782a4236fcb 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -26,6 +26,7 @@ use tracing::debug; use crate::pg_field_descriptor::PgFieldDescriptor; use crate::pg_protocol::{PgProtocol, TlsConfig}; use crate::pg_response::{PgResponse, RowSetResult}; +use crate::types::Format; pub type BoxedError = Box; pub type SessionId = (i32, i32); @@ -46,10 +47,6 @@ where /// A psql connection. Each connection binds with a database. Switching database will need to /// recreate another connection. -/// -/// format: -/// false: TEXT -/// true: BINARY #[async_trait::async_trait] pub trait Session: Send + Sync where @@ -58,7 +55,7 @@ where async fn run_statement( self: Arc, sql: &str, - format: bool, + formats: Vec, ) -> Result, BoxedError>; /// The str sql can not use the unparse from AST: There is some problem when dealing with create @@ -66,7 +63,7 @@ where async fn run_one_query( self: Arc, sql: Statement, - format: bool, + format: Format, ) -> Result, BoxedError>; async fn infer_return_type( @@ -179,6 +176,7 @@ mod tests { use crate::pg_server::{ pg_serve, BoxedError, Session, SessionId, SessionManager, UserAuthenticator, }; + use crate::types; use crate::types::Row; struct MockSessionManager {} @@ -208,7 +206,7 @@ mod tests { async fn run_statement( self: Arc, sql: &str, - _format: bool, + _format: Vec, ) -> Result>, Box> { // split a statement and trim \' around the input param to construct result. @@ -246,7 +244,7 @@ mod tests { async fn run_one_query( self: Arc, _sql: Statement, - _format: bool, + _format: types::Format, ) -> Result>, BoxedError> { let res: Vec> = vec![Some(Bytes::new())]; Ok(PgResponse::new_for_stream( diff --git a/src/utils/pgwire/src/types.rs b/src/utils/pgwire/src/types.rs index 0e3e4b390891..6b6b91db0294 100644 --- a/src/utils/pgwire/src/types.rs +++ b/src/utils/pgwire/src/types.rs @@ -13,9 +13,13 @@ // limitations under the License. use std::ops::Index; +use std::slice::Iter; +use anyhow::anyhow; use bytes::Bytes; +use crate::error::{PsqlError, PsqlResult}; + /// A row of data returned from the database by a query. #[derive(Debug, Clone)] // NOTE: Since we only support simple query protocol, the values are represented as strings. @@ -50,3 +54,81 @@ impl Index for Row { &self.0[index] } } + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Format { + Binary, + Text, +} + +impl Format { + pub fn from_i16(format_code: i16) -> PsqlResult { + match format_code { + 0 => Ok(Format::Text), + 1 => Ok(Format::Binary), + _ => Err(PsqlError::Internal(anyhow!( + "Unknown format code: {}", + format_code + ))), + } + } +} + +/// FormatIterator used to generate formats of actual length given the provided format. +/// According Postgres Document: +/// - If the length of provided format is 0, all format will be default format(TEXT). +/// - If the length of provided format is 1, all format will be the same as this only format. +/// - If the length of provided format > 1, provided format should be the actual format. +#[derive(Debug, Clone)] +pub struct FormatIterator<'a, 'b> +where + 'a: 'b, +{ + _formats: &'a [Format], + format_iter: Iter<'b, Format>, + actual_len: usize, + default_format: Format, +} + +impl<'a, 'b> FormatIterator<'a, 'b> { + pub fn new(provided_formats: &'a [Format], actual_len: usize) -> Result { + if !provided_formats.is_empty() + && provided_formats.len() != 1 + && provided_formats.len() != actual_len + { + return Err(format!( + "format codes length {} is not 0, 1 or equal to actual length {}", + provided_formats.len(), + actual_len + )); + } + + let default_format = provided_formats.get(0).copied().unwrap_or(Format::Text); + + Ok(Self { + _formats: provided_formats, + default_format, + format_iter: provided_formats.iter(), + actual_len, + }) + } +} + +impl Iterator for FormatIterator<'_, '_> { + type Item = Format; + + fn next(&mut self) -> Option { + if self.actual_len == 0 { + return None; + } + + self.actual_len -= 1; + + Some( + self.format_iter + .next() + .copied() + .unwrap_or(self.default_format), + ) + } +}