From 827806166bf4b3718a1f7f61d4c4d0db222a61b3 Mon Sep 17 00:00:00 2001 From: ZENOTME <43447882+ZENOTME@users.noreply.github.com> Date: Mon, 3 Apr 2023 16:30:10 +0800 Subject: [PATCH] refactor(frontend): refactor extended query mode (#8919) --- src/frontend/src/binder/bind_param.rs | 8 +- src/frontend/src/binder/delete.rs | 2 +- src/frontend/src/binder/insert.rs | 2 +- src/frontend/src/binder/mod.rs | 7 + src/frontend/src/binder/statement.rs | 23 +- src/frontend/src/binder/update.rs | 2 +- src/frontend/src/handler/extended_handle.rs | 118 ++- src/frontend/src/handler/query.rs | 91 +- src/frontend/src/session.rs | 357 +++---- src/frontend/src/test_utils.rs | 5 +- src/tests/e2e_extended_mode/src/test.rs | 16 +- src/utils/pgwire/src/error.rs | 9 +- src/utils/pgwire/src/pg_extended.rs | 971 +------------------- src/utils/pgwire/src/pg_protocol.rs | 318 ++++--- src/utils/pgwire/src/pg_server.rs | 288 ++---- src/utils/pgwire/src/types.rs | 8 +- 16 files changed, 716 insertions(+), 1509 deletions(-) diff --git a/src/frontend/src/binder/bind_param.rs b/src/frontend/src/binder/bind_param.rs index e6ab32dd34e1..0106a6e1d553 100644 --- a/src/frontend/src/binder/bind_param.rs +++ b/src/frontend/src/binder/bind_param.rs @@ -13,8 +13,8 @@ // limitations under the License. use bytes::Bytes; -use pgwire::types::Format; -use risingwave_common::error::{Result, RwError}; +use pgwire::types::{Format, FormatIterator}; +use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::types::ScalarImpl; use super::statement::RewriteExprsRecursive; @@ -85,8 +85,10 @@ impl BoundStatement { param_formats: Vec, ) -> Result { let mut rewriter = ParamRewriter { + param_formats: FormatIterator::new(¶m_formats, params.len()) + .map_err(ErrorCode::BindError)? + .collect(), params, - param_formats, error: None, }; diff --git a/src/frontend/src/binder/delete.rs b/src/frontend/src/binder/delete.rs index 62a268e94bab..028624404903 100644 --- a/src/frontend/src/binder/delete.rs +++ b/src/frontend/src/binder/delete.rs @@ -22,7 +22,7 @@ use crate::catalog::TableId; use crate::expr::ExprImpl; use crate::user::UserId; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BoundDelete { /// Id of the table to perform deleting. pub table_id: TableId, diff --git a/src/frontend/src/binder/insert.rs b/src/frontend/src/binder/insert.rs index 300d1bd905b2..d9a7ca39ccc4 100644 --- a/src/frontend/src/binder/insert.rs +++ b/src/frontend/src/binder/insert.rs @@ -28,7 +28,7 @@ use crate::catalog::TableId; use crate::expr::{ExprImpl, InputRef}; use crate::user::UserId; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BoundInsert { /// Id of the table to perform inserting. pub table_id: TableId, diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index 06beb720ad57..991961c3ae05 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -222,6 +222,13 @@ impl Binder { Self::new_inner(session, true, vec![]) } + pub fn new_for_stream_with_param_types( + session: &SessionImpl, + param_types: Vec, + ) -> Binder { + Self::new_inner(session, true, param_types) + } + /// Bind a [`Statement`]. pub fn bind(&mut self, stmt: Statement) -> Result { self.bind_statement(stmt) diff --git a/src/frontend/src/binder/statement.rs b/src/frontend/src/binder/statement.rs index 17603dc530ec..1a94a6ce30d2 100644 --- a/src/frontend/src/binder/statement.rs +++ b/src/frontend/src/binder/statement.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use risingwave_common::catalog::Field; use risingwave_common::error::{ErrorCode, Result}; use risingwave_sqlparser::ast::Statement; @@ -20,7 +21,7 @@ use super::update::BoundUpdate; use crate::binder::{Binder, BoundInsert, BoundQuery}; use crate::expr::ExprRewriter; -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum BoundStatement { Insert(Box), Delete(Box), @@ -28,6 +29,26 @@ pub enum BoundStatement { Query(Box), } +impl BoundStatement { + pub fn output_fields(&self) -> Vec { + match self { + BoundStatement::Insert(i) => i.returning_schema.as_ref().map_or( + vec![Field::unnamed(risingwave_common::types::DataType::Int64)], + |s| s.fields().into(), + ), + BoundStatement::Delete(d) => d.returning_schema.as_ref().map_or( + vec![Field::unnamed(risingwave_common::types::DataType::Int64)], + |s| s.fields().into(), + ), + BoundStatement::Update(u) => u.returning_schema.as_ref().map_or( + vec![Field::unnamed(risingwave_common::types::DataType::Int64)], + |s| s.fields().into(), + ), + BoundStatement::Query(q) => q.schema().fields().into(), + } + } +} + impl Binder { pub(super) fn bind_statement(&mut self, stmt: Statement) -> Result { match stmt { diff --git a/src/frontend/src/binder/update.rs b/src/frontend/src/binder/update.rs index 7a65c617db3b..75bec2a8f6e4 100644 --- a/src/frontend/src/binder/update.rs +++ b/src/frontend/src/binder/update.rs @@ -27,7 +27,7 @@ use crate::catalog::TableId; use crate::expr::{Expr as _, ExprImpl}; use crate::user::UserId; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BoundUpdate { /// Id of the table to perform updating. pub table_id: TableId, diff --git a/src/frontend/src/handler/extended_handle.rs b/src/frontend/src/handler/extended_handle.rs index 359c8e4ebe57..5cf1c7d3948e 100644 --- a/src/frontend/src/handler/extended_handle.rs +++ b/src/frontend/src/handler/extended_handle.rs @@ -18,19 +18,33 @@ use bytes::Bytes; use pgwire::types::Format; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::DataType; -use risingwave_sqlparser::ast::Statement; +use risingwave_sqlparser::ast::{Query, Statement}; -use super::{query, HandlerArgs, RwPgResponse}; +use super::{handle, query, HandlerArgs, RwPgResponse}; use crate::binder::BoundStatement; use crate::session::SessionImpl; -pub struct PrepareStatement { +#[derive(Clone)] +pub enum PrepareStatement { + Prepared(PreparedResult), + PureStatement(Statement), +} + +#[derive(Clone)] +pub struct PreparedResult { pub statement: Statement, pub bound_statement: BoundStatement, pub param_types: Vec, } -pub struct Portal { +#[derive(Clone)] +pub enum Portal { + Portal(PortalResult), + PureStatement(Statement), +} + +#[derive(Clone)] +pub struct PortalResult { pub statement: Statement, pub bound_statement: BoundStatement, pub result_formats: Vec, @@ -44,16 +58,38 @@ pub fn handle_parse( session.clear_cancel_query_flag(); let str_sql = stmt.to_string(); let handler_args = HandlerArgs::new(session, &stmt, &str_sql)?; - match stmt { + match &stmt { Statement::Query(_) | Statement::Insert { .. } | Statement::Delete { .. } | Statement::Update { .. } => query::handle_parse(handler_args, stmt, specific_param_types), - _ => Err(ErrorCode::NotSupported( - format!("Can't support {} in extended query mode now", str_sql,), - "".to_string(), - ) - .into()), + Statement::CreateView { + query, + .. + } => { + if have_parameter_in_query(query) { + return Err(ErrorCode::NotImplemented( + "CREATE VIEW with parameters".to_string(), + None.into(), + ) + .into()); + } + Ok(PrepareStatement::PureStatement(stmt)) + } + Statement::CreateTable { + query, + .. + } => { + if let Some(query) = query && have_parameter_in_query(query) { + Err(ErrorCode::NotImplemented( + "CREATE TABLE AS SELECT with parameters".to_string(), + None.into(), + ).into()) + } else { + Ok(PrepareStatement::PureStatement(stmt)) + } + } + _ => Ok(PrepareStatement::PureStatement(stmt)), } } @@ -63,32 +99,46 @@ pub fn handle_bind( param_formats: Vec, result_formats: Vec, ) -> Result { - let PrepareStatement { - statement, - bound_statement, - .. - } = prepare_statement; - let bound_statement = bound_statement.bind_parameter(params, param_formats)?; - Ok(Portal { - statement, - bound_statement, - result_formats, - }) + match prepare_statement { + PrepareStatement::Prepared(prepared_result) => { + let PreparedResult { + statement, + bound_statement, + .. + } = prepared_result; + let bound_statement = bound_statement.bind_parameter(params, param_formats)?; + Ok(Portal::Portal(PortalResult { + statement, + bound_statement, + result_formats, + })) + } + PrepareStatement::PureStatement(stmt) => Ok(Portal::PureStatement(stmt)), + } } pub async fn handle_execute(session: Arc, portal: Portal) -> Result { - session.clear_cancel_query_flag(); - let str_sql = portal.statement.to_string(); - let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?; - match &portal.statement { - Statement::Query(_) - | Statement::Insert { .. } - | Statement::Delete { .. } - | Statement::Update { .. } => query::handle_execute(handler_args, portal).await, - _ => Err(ErrorCode::NotSupported( - format!("Can't support {} in extended query mode now", str_sql,), - "".to_string(), - ) - .into()), + match portal { + Portal::Portal(portal) => { + session.clear_cancel_query_flag(); + let str_sql = portal.statement.to_string(); + let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?; + match &portal.statement { + Statement::Query(_) + | Statement::Insert { .. } + | Statement::Delete { .. } + | Statement::Update { .. } => query::handle_execute(handler_args, portal).await, + _ => unreachable!(), + } + } + Portal::PureStatement(stmt) => { + let sql = stmt.to_string(); + handle(session, stmt, &sql, vec![]).await + } } } + +/// A quick way to check if a query contains parameters. +fn have_parameter_in_query(query: &Query) -> bool { + query.to_string().contains("$1") +} diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs index bdb3b5c1bace..41c36994351d 100644 --- a/src/frontend/src/handler/query.rs +++ b/src/frontend/src/handler/query.rs @@ -27,9 +27,9 @@ use risingwave_common::session_config::QueryMode; use risingwave_common::types::DataType; use risingwave_sqlparser::ast::{SetExpr, Statement}; -use super::extended_handle::{Portal, PrepareStatement}; +use super::extended_handle::{PortalResult, PrepareStatement, PreparedResult}; use super::{PgResponseStream, RwPgResponse}; -use crate::binder::Binder; +use crate::binder::{Binder, BoundStatement}; use crate::catalog::TableId; use crate::handler::flush::do_flush; use crate::handler::privilege::resolve_privileges; @@ -368,6 +368,8 @@ pub async fn local_execute( Ok(execution.stream_rows()) } +// TODO: Following code have redundant code with `handle_query`, we may need to refactor them in +// future. pub fn handle_parse( handler_args: HandlerArgs, statement: Statement, @@ -382,15 +384,58 @@ pub fn handle_parse( let param_types = binder.export_param_types()?; - Ok(PrepareStatement { + Ok(PrepareStatement::Prepared(PreparedResult { statement, bound_statement, param_types, - }) + })) } -pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result { - let Portal { +pub fn gen_batch_query_plan_for_bound( + session: &SessionImpl, + context: OptimizerContextRef, + stmt: Statement, + bound: BoundStatement, +) -> Result<(PlanRef, QueryMode, Schema)> { + let must_dist = must_run_in_distributed_mode(&stmt)?; + + let mut planner = Planner::new(context); + + let mut logical = planner.plan(bound)?; + let schema = logical.schema(); + let batch_plan = logical.gen_batch_plan()?; + + let must_local = must_run_in_local_mode(batch_plan.clone()); + + let query_mode = match (must_dist, must_local) { + (true, true) => { + return Err(ErrorCode::InternalError( + "the query is forced to both local and distributed mode by optimizer".to_owned(), + ) + .into()) + } + (true, false) => QueryMode::Distributed, + (false, true) => QueryMode::Local, + (false, false) => match session.config().get_query_mode() { + QueryMode::Auto => determine_query_mode(batch_plan.clone()), + QueryMode::Local => QueryMode::Local, + QueryMode::Distributed => QueryMode::Distributed, + }, + }; + + let physical = match query_mode { + QueryMode::Auto => unreachable!(), + QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?, + QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?, + }; + Ok((physical, query_mode, schema)) +} + +pub async fn handle_execute( + handler_args: HandlerArgs, + portal: PortalResult, +) -> Result { + let PortalResult { statement, bound_statement, result_formats, @@ -407,38 +452,8 @@ pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result let (plan_fragmenter, query_mode, output_schema) = { let context = OptimizerContext::from_handler_args(handler_args); - let must_dist = must_run_in_distributed_mode(&statement)?; - - let mut planner = Planner::new(context.into()); - - let mut logical = planner.plan(bound_statement)?; - let schema = logical.schema(); - let batch_plan = logical.gen_batch_plan()?; - - let must_local = must_run_in_local_mode(batch_plan.clone()); - - let query_mode = match (must_dist, must_local) { - (true, true) => { - return Err(ErrorCode::InternalError( - "the query is forced to both local and distributed mode by optimizer" - .to_owned(), - ) - .into()) - } - (true, false) => QueryMode::Distributed, - (false, true) => QueryMode::Local, - (false, false) => match session.config().get_query_mode() { - QueryMode::Auto => determine_query_mode(batch_plan.clone()), - QueryMode::Local => QueryMode::Local, - QueryMode::Distributed => QueryMode::Distributed, - }, - }; - - let physical = match query_mode { - QueryMode::Auto => unreachable!(), - QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?, - QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?, - }; + let (physical, query_mode, schema) = + gen_batch_query_plan_for_bound(&session, context.into(), statement, bound_statement)?; let context = physical.plan_base().ctx.clone(); tracing::trace!( diff --git a/src/frontend/src/session.rs b/src/frontend/src/session.rs index e00d8f9b0b69..58389391f6ff 100644 --- a/src/frontend/src/session.rs +++ b/src/frontend/src/session.rs @@ -18,6 +18,7 @@ use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Duration; +use bytes::Bytes; use parking_lot::{RwLock, RwLockReadGuard}; use pgwire::pg_field_descriptor::PgFieldDescriptor; use pgwire::pg_response::PgResponse; @@ -56,19 +57,20 @@ use tokio::sync::watch; use tokio::task::JoinHandle; use tracing::info; -use crate::binder::Binder; +use crate::binder::{Binder, BoundStatement}; use crate::catalog::catalog_service::{CatalogReader, CatalogWriter, CatalogWriterImpl}; use crate::catalog::root_catalog::Catalog; use crate::catalog::{check_schema_writable, DatabaseId, SchemaId}; +use crate::handler::extended_handle::{ + handle_bind, handle_execute, handle_parse, Portal, PrepareStatement, +}; +use crate::handler::handle; use crate::handler::privilege::ObjectCheckItem; use crate::handler::util::to_pg_field; -use crate::handler::{handle, HandlerArgs}; use crate::health_service::HealthServiceImpl; use crate::meta_client::{FrontendMetaClient, FrontendMetaClientImpl}; use crate::monitor::FrontendMetrics; use crate::observer::FrontendObserverNode; -use crate::optimizer::OptimizerContext; -use crate::planner::Planner; use crate::scheduler::streaming_manager::{StreamingJobTracker, StreamingJobTrackerRef}; use crate::scheduler::worker_node_manager::{WorkerNodeManager, WorkerNodeManagerRef}; use crate::scheduler::SchedulerError::QueryCancelError; @@ -569,6 +571,51 @@ impl SessionImpl { pub fn cancel_current_creating_job(&self) { self.env.creating_streaming_job_tracker.abort_jobs(self.id); } + + /// This function only used for test now. + /// Maybe we can remove it in the future. + pub async fn run_statement( + self: Arc, + sql: &str, + formats: Vec, + ) -> std::result::Result, BoxedError> { + // Parse sql. + let mut stmts = Parser::parse_sql(sql) + .inspect_err(|e| tracing::error!("failed to parse sql:\n{}:\n{}", sql, e))?; + if stmts.is_empty() { + return Ok(PgResponse::empty_result( + pgwire::pg_response::StatementType::EMPTY, + )); + } + if stmts.len() > 1 { + return Ok(PgResponse::empty_result_with_notice( + pgwire::pg_response::StatementType::EMPTY, + "cannot insert multiple commands into statement".to_string(), + )); + } + let stmt = stmts.swap_remove(0); + let rsp = { + let mut handle_fut = Box::pin(handle(self, stmt, sql, formats)); + if cfg!(debug_assertions) { + // Report the SQL in the log periodically if the query is slow. + const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60); + loop { + match tokio::time::timeout(SLOW_QUERY_LOG_PERIOD, &mut handle_fut).await { + Ok(result) => break result, + Err(_) => tracing::warn!( + target: "risingwave_frontend_slow_query_log", + sql, + "slow query has been running for another {SLOW_QUERY_LOG_PERIOD:?}" + ), + } + } + } else { + handle_fut.await + } + } + .inspect_err(|e| tracing::error!("failed to handle sql:\n{}:\n{}", sql, e))?; + Ok(rsp) + } } pub struct SessionManagerImpl { @@ -578,7 +625,7 @@ pub struct SessionManagerImpl { number: AtomicI32, } -impl SessionManager for SessionManagerImpl { +impl SessionManager for SessionManagerImpl { type Session = SessionImpl; fn connect( @@ -717,29 +764,17 @@ impl SessionManagerImpl { } #[async_trait::async_trait] -impl Session for SessionImpl { - async fn run_statement( +impl Session for SessionImpl { + /// A copy of run_statement but exclude the parser part so each run must be at most one + /// statement. The str sql use the to_string of AST. Consider Reuse later. + async fn run_one_query( self: Arc, - sql: &str, - formats: Vec, + stmt: Statement, + format: Format, ) -> std::result::Result, BoxedError> { - // Parse sql. - let mut stmts = Parser::parse_sql(sql) - .inspect_err(|e| tracing::error!("failed to parse sql:\n{}:\n{}", sql, e))?; - if stmts.is_empty() { - return Ok(PgResponse::empty_result( - pgwire::pg_response::StatementType::EMPTY, - )); - } - if stmts.len() > 1 { - return Ok(PgResponse::empty_result_with_notice( - pgwire::pg_response::StatementType::EMPTY, - "cannot insert multiple commands into statement".to_string(), - )); - } - let stmt = stmts.swap_remove(0); + let sql_str = stmt.to_string(); let rsp = { - let mut handle_fut = Box::pin(handle(self, stmt, sql, formats)); + let mut handle_fut = Box::pin(handle(self, stmt, &sql_str, vec![format])); if cfg!(debug_assertions) { // Report the SQL in the log periodically if the query is slow. const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60); @@ -747,8 +782,7 @@ impl Session for SessionImpl { match tokio::time::timeout(SLOW_QUERY_LOG_PERIOD, &mut handle_fut).await { Ok(result) => break result, Err(_) => tracing::warn!( - target: "risingwave_frontend_slow_query_log", - sql, + sql_str, "slow query has been running for another {SLOW_QUERY_LOG_PERIOD:?}" ), } @@ -757,20 +791,47 @@ impl Session for SessionImpl { handle_fut.await } } - .inspect_err(|e| tracing::error!("failed to handle sql:\n{}:\n{}", sql, e))?; + .inspect_err(|e| tracing::error!("failed to handle sql:\n{}:\n{}", sql_str, e))?; Ok(rsp) } - /// A copy of run_statement but exclude the parser part so each run must be at most one - /// statement. The str sql use the to_string of AST. Consider Reuse later. - async fn run_one_query( + fn user_authenticator(&self) -> &UserAuthenticator { + &self.user_authenticator + } + + fn id(&self) -> SessionId { + self.id + } + + fn parse( self: Arc, - stmt: Statement, - format: Format, + statement: Statement, + params_types: Vec, + ) -> std::result::Result { + Ok(handle_parse(self, statement, params_types)?) + } + + fn bind( + self: Arc, + prepare_statement: PrepareStatement, + params: Vec, + param_formats: Vec, + result_formats: Vec, + ) -> std::result::Result { + Ok(handle_bind( + prepare_statement, + params, + param_formats, + result_formats, + )?) + } + + async fn execute( + self: Arc, + portal: Portal, ) -> std::result::Result, BoxedError> { - let sql_str = stmt.to_string(); let rsp = { - let mut handle_fut = Box::pin(handle(self, stmt, &sql_str, vec![format])); + let mut handle_fut = Box::pin(handle_execute(self, portal)); if cfg!(debug_assertions) { // Report the SQL in the log periodically if the query is slow. const SLOW_QUERY_LOG_PERIOD: Duration = Duration::from_secs(60); @@ -778,7 +839,6 @@ impl Session for SessionImpl { match tokio::time::timeout(SLOW_QUERY_LOG_PERIOD, &mut handle_fut).await { Ok(result) => break result, Err(_) => tracing::warn!( - sql_str, "slow query has been running for another {SLOW_QUERY_LOG_PERIOD:?}" ), } @@ -787,152 +847,125 @@ impl Session for SessionImpl { handle_fut.await } } - .inspect_err(|e| tracing::error!("failed to handle sql:\n{}:\n{}", sql_str, e))?; + .inspect_err(|e| tracing::error!("failed to handle execute:\n{}", e))?; Ok(rsp) } - async fn infer_return_type( + fn describe_statement( self: Arc, - sql: &str, + prepare_statement: PrepareStatement, + ) -> std::result::Result<(Vec, Vec), BoxedError> { + Ok(match prepare_statement { + PrepareStatement::Prepared(prepare_statement) => ( + prepare_statement.param_types, + infer( + Some(prepare_statement.bound_statement), + prepare_statement.statement, + )?, + ), + PrepareStatement::PureStatement(statement) => (vec![], infer(None, statement)?), + }) + } + + fn describe_portral( + self: Arc, + portal: Portal, ) -> std::result::Result, BoxedError> { - // Parse sql. - let mut stmts = Parser::parse_sql(sql) - .inspect_err(|e| tracing::error!("failed to parse sql:\n{}:\n{}", sql, e))?; - if stmts.is_empty() { - return Ok(vec![]); + match portal { + Portal::Portal(portal) => Ok(infer(Some(portal.bound_statement), portal.statement)?), + Portal::PureStatement(statement) => Ok(infer(None, statement)?), } - if stmts.len() > 1 { - return Err(Box::new(Error::new( - ErrorKind::InvalidInput, - "cannot insert multiple commands into statement", - ))); - } - let stmt = stmts.swap_remove(0); - // This part refers from src/frontend/handler/ so the Vec is same as - // result of run_statement(). - let rsp = match stmt { - Statement::Query(_) => infer(self, stmt, sql) - .inspect_err(|e| tracing::error!("failed to handle sql:\n{}:\n{}", sql, e))?, - Statement::ShowObjects(show_object) => match show_object { - ShowObject::Columns { table: _ } => { - vec![ - PgFieldDescriptor::new( - "Name".to_owned(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), - PgFieldDescriptor::new( - "Type".to_owned(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), - ] - } - _ => { - vec![PgFieldDescriptor::new( - "Name".to_owned(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - )] - } - }, - Statement::ShowCreateObject { .. } => { - vec![ - PgFieldDescriptor::new( - "Name".to_owned(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), + } +} + +/// Returns row description of the statement +fn infer(bound: Option, stmt: Statement) -> Result> { + match stmt { + Statement::Query(_) + | Statement::Insert { .. } + | Statement::Delete { .. } + | Statement::Update { .. } => Ok(bound + .unwrap() + .output_fields() + .iter() + .map(to_pg_field) + .collect()), + Statement::ShowObjects(show_object) => match show_object { + ShowObject::Columns { table: _ } => Ok(vec![ + PgFieldDescriptor::new( + "Name".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "Type".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + ]), + _ => Ok(vec![PgFieldDescriptor::new( + "Name".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + )]), + }, + Statement::ShowCreateObject { .. } => Ok(vec![ + PgFieldDescriptor::new( + "Name".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "Create Sql".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + ]), + Statement::ShowVariable { variable } => { + let name = &variable[0].real_value().to_lowercase(); + if name.eq_ignore_ascii_case("ALL") { + Ok(vec![ PgFieldDescriptor::new( - "Create Sql".to_owned(), + "Name".to_string(), DataType::Varchar.to_oid(), DataType::Varchar.type_len(), ), - ] - } - Statement::ShowVariable { variable } => { - let name = &variable[0].real_value().to_lowercase(); - if name.eq_ignore_ascii_case("ALL") { - vec![ - PgFieldDescriptor::new( - "Name".to_string(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), - PgFieldDescriptor::new( - "Setting".to_string(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), - PgFieldDescriptor::new( - "Description".to_string(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - ), - ] - } else { - vec![PgFieldDescriptor::new( - name.to_ascii_lowercase(), - DataType::Varchar.to_oid(), - DataType::Varchar.type_len(), - )] - } - } - Statement::Describe { name: _ } => { - vec![ PgFieldDescriptor::new( - "Name".to_owned(), + "Setting".to_string(), DataType::Varchar.to_oid(), DataType::Varchar.type_len(), ), PgFieldDescriptor::new( - "Type".to_owned(), + "Description".to_string(), DataType::Varchar.to_oid(), DataType::Varchar.type_len(), ), - ] - } - Statement::Explain { .. } => { - vec![PgFieldDescriptor::new( - "QUERY PLAN".to_owned(), + ]) + } else { + Ok(vec![PgFieldDescriptor::new( + name.to_ascii_lowercase(), DataType::Varchar.to_oid(), DataType::Varchar.type_len(), - )] + )]) } - _ => { - panic!("infer_return_type only support query statement"); - } - }; - Ok(rsp) - } - - fn user_authenticator(&self) -> &UserAuthenticator { - &self.user_authenticator - } - - fn id(&self) -> SessionId { - self.id + } + Statement::Describe { name: _ } => Ok(vec![ + PgFieldDescriptor::new( + "Name".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + PgFieldDescriptor::new( + "Type".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + ), + ]), + Statement::Explain { .. } => Ok(vec![PgFieldDescriptor::new( + "QUERY PLAN".to_owned(), + DataType::Varchar.to_oid(), + DataType::Varchar.type_len(), + )]), + _ => Ok(vec![]), } } - -/// Returns row description of the statement -fn infer(session: Arc, stmt: Statement, sql: &str) -> Result> { - let context = OptimizerContext::from_handler_args(HandlerArgs::new(session, &stmt, sql)?); - let session = context.session_ctx().clone(); - - let bound = { - let mut binder = Binder::new(&session); - binder.bind(stmt)? - }; - - let root = Planner::new(context.into()).plan(bound)?; - - let pg_descs = root - .schema() - .fields() - .iter() - .map(to_pg_field) - .collect::>(); - - Ok(pg_descs) -} diff --git a/src/frontend/src/test_utils.rs b/src/frontend/src/test_utils.rs index f0693eb315e9..0aa52f8d3eba 100644 --- a/src/frontend/src/test_utils.rs +++ b/src/frontend/src/test_utils.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use futures_async_stream::for_await; use parking_lot::RwLock; use pgwire::pg_response::StatementType; -use pgwire::pg_server::{BoxedError, Session, SessionId, SessionManager, UserAuthenticator}; +use pgwire::pg_server::{BoxedError, SessionId, SessionManager, UserAuthenticator}; use pgwire::types::Row; use risingwave_common::catalog::{ FunctionId, IndexId, TableId, DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, DEFAULT_SUPER_USER, @@ -47,6 +47,7 @@ use tempfile::{Builder, NamedTempFile}; use crate::catalog::catalog_service::CatalogWriter; use crate::catalog::root_catalog::Catalog; use crate::catalog::{DatabaseId, SchemaId}; +use crate::handler::extended_handle::{Portal, PrepareStatement}; use crate::handler::RwPgResponse; use crate::meta_client::FrontendMetaClient; use crate::session::{AuthContext, FrontendEnv, SessionImpl}; @@ -61,7 +62,7 @@ pub struct LocalFrontend { env: FrontendEnv, } -impl SessionManager for LocalFrontend { +impl SessionManager for LocalFrontend { type Session = SessionImpl; fn connect( diff --git a/src/tests/e2e_extended_mode/src/test.rs b/src/tests/e2e_extended_mode/src/test.rs index abf701af763c..688ab6b126b5 100644 --- a/src/tests/e2e_extended_mode/src/test.rs +++ b/src/tests/e2e_extended_mode/src/test.rs @@ -115,16 +115,8 @@ impl TestSuite { test_eq!(data, 1.234234); } - // TODO(ZENOTME): After #8112, risingwave should support this case. (DOUBLE PRECISION TYPE) - // for row in client - // .query("select $1::DOUBLE PRECISION;", &[&234234.23490238483_f64]) - // .await? - // { - // let data: f64 = row.try_get(0)?; - // test_eq!(data, 234234.23490238483); - // } for row in client - .query("select $1::FLOAT8;", &[&234234.23490238483_f64]) + .query("select $1::DOUBLE PRECISION;", &[&234234.23490238483_f64]) .await? { let data: f64 = row.try_get(0)?; @@ -199,8 +191,6 @@ impl TestSuite { Ok(()) } - /// TODO(ZENOTME): After #8112, risingwave should support to change all `prepare_typed` to - /// `prepare`. We don't need to provide the type explicitly. async fn dql_dml_with_param(&self) -> anyhow::Result<()> { let (client, connection) = tokio_postgres::connect(&self.config, NoTls).await?; @@ -215,7 +205,7 @@ impl TestSuite { client.query("create table t(id int)", &[]).await?; let insert_statement = client - .prepare_typed("insert INTO t (id) VALUES ($1)", &[Type::INT4]) + .prepare_typed("insert INTO t (id) VALUES ($1)", &[]) .await?; for i in 0..20 { @@ -288,7 +278,7 @@ impl TestSuite { client.query("create table t(id int)", &[]).await?; let insert_statement = client - .prepare_typed("insert INTO t (id) VALUES ($1)", &[Type::INT4]) + .prepare_typed("insert INTO t (id) VALUES ($1)", &[]) .await?; for i in 0..10 { diff --git a/src/utils/pgwire/src/error.rs b/src/utils/pgwire/src/error.rs index 63f4745238fd..0e6912866384 100644 --- a/src/utils/pgwire/src/error.rs +++ b/src/utils/pgwire/src/error.rs @@ -14,7 +14,6 @@ use std::io::Error as IoError; -use anyhow::anyhow; use thiserror::Error; use crate::pg_server::BoxedError; @@ -42,8 +41,8 @@ pub enum PsqlError { IoError(#[from] IoError), #[error("{0}")] - /// Include error for describe, bind, parse, execute etc. - Internal(#[from] anyhow::Error), + /// Include error for describe, bind. + Internal(BoxedError), #[error("{0}")] SslError(String), @@ -51,10 +50,10 @@ pub enum PsqlError { impl PsqlError { pub fn no_statement() -> Self { - PsqlError::Internal(anyhow!("No statement found".to_string())) + PsqlError::Internal("No statement found".into()) } pub fn no_portal() -> Self { - PsqlError::Internal(anyhow!("No portal found".to_string())) + PsqlError::Internal("No portal found".into()) } } diff --git a/src/utils/pgwire/src/pg_extended.rs b/src/utils/pgwire/src/pg_extended.rs index 4e2b8a509bd7..f4a08f20db17 100644 --- a/src/utils/pgwire/src/pg_extended.rs +++ b/src/utils/pgwire/src/pg_extended.rs @@ -12,177 +12,59 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::ops::Range; -use std::str::FromStr; -use std::sync::{Arc, LazyLock}; use std::vec::IntoIter; -use anyhow::anyhow; -use bytes::Bytes; use futures::stream::FusedStream; use futures::{Stream, StreamExt, TryStreamExt}; -use itertools::Itertools; -use postgres_types::{FromSql, Type}; -use regex::Regex; -use risingwave_common::types::DataType; -use risingwave_common::util::iter_util::ZipEqFast; use tokio::io::{AsyncRead, AsyncWrite}; use crate::error::{PsqlError, PsqlResult}; -use crate::pg_field_descriptor::PgFieldDescriptor; use crate::pg_message::{BeCommandCompleteMessage, BeMessage}; -use crate::pg_protocol::{cstr_to_str, Conn}; +use crate::pg_protocol::Conn; use crate::pg_response::{PgResponse, RowSetResult}; -use crate::pg_server::{Session, SessionManager}; -use crate::types::{Format, FormatIterator, Row}; +use crate::types::Row; -#[derive(Default)] -pub struct PgStatement { - name: String, - prepared_statement: PreparedStatement, - row_description: Vec, - is_query: bool, -} - -impl PgStatement { - pub fn new( - name: String, - prepared_statement: PreparedStatement, - row_description: Vec, - is_query: bool, - ) -> Self { - PgStatement { - name, - prepared_statement, - row_description, - is_query, - } - } - - pub fn name(&self) -> String { - self.name.clone() - } - - pub fn param_oid_desc(&self) -> Vec { - self.prepared_statement - .param_type_description() - .into_iter() - .map(|v| v.to_oid()) - .collect_vec() - } - - pub fn row_desc(&self) -> Vec { - self.row_description.clone() - } - - pub fn instance( - &self, - portal_name: String, - params: &[Bytes], - result_formats: Vec, - param_formats: Vec, - ) -> PsqlResult> - where - VS: Stream + Unpin + Send, - { - let instance_query_string = self.prepared_statement.instance(params, ¶m_formats)?; - - let format_iter = FormatIterator::new(&result_formats, self.row_description.len()) - .map_err(|err| PsqlError::Internal(anyhow!(err)))?; - let row_description: Vec = { - let mut row_description = self.row_description.clone(); - row_description - .iter_mut() - .zip_eq_fast(format_iter) - .for_each(|(desc, format)| { - if let Format::Binary = format { - desc.set_to_binary(); - } - }); - row_description - }; - - Ok(PgPortal { - name: portal_name, - query_string: instance_query_string, - result_formats, - is_query: self.is_query, - row_description, - result: None, - row_cache: vec![].into_iter(), - }) - } - - /// We define the statement start with ("select","values","show","with","describe") is query - /// statement. Because these statement will return a result set. - pub fn is_query(&self) -> bool { - self.is_query - } -} - -pub struct PgPortal +pub struct ResultCache where VS: Stream + Unpin + Send, { - name: String, - query_string: String, - result_formats: Vec, - is_query: bool, - row_description: Vec, - result: Option>, + result: PgResponse, row_cache: IntoIter, } -impl PgPortal +impl ResultCache where VS: Stream + Unpin + Send, { - pub fn name(&self) -> String { - self.name.clone() - } - - pub fn query_string(&self) -> String { - self.query_string.clone() - } - - pub fn row_desc(&self) -> Vec { - self.row_description.clone() + pub fn new(result: PgResponse) -> Self { + ResultCache { + result, + row_cache: vec![].into_iter(), + } } - /// When execute a query sql, execute will re-use the result if result will not be consumed - /// completely. Detail can refer:https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY:~:text=Once%20a%20portal,ErrorResponse%2C%20or%20PortalSuspended. - pub async fn execute, S: AsyncWrite + AsyncRead + Unpin>( + /// Return indicate whether the result is consumed completely. + pub async fn consume( &mut self, - session: Arc, row_limit: usize, msg_stream: &mut Conn, - ) -> PsqlResult<()> { - // Check if there is a result cache - let result = if let Some(result) = &mut self.result { - result - } else { - let result = session - .run_statement(self.query_string.as_str(), self.result_formats.clone()) - .await - .map_err(|err| PsqlError::ExecuteError(err))?; - self.result = Some(result); - self.result.as_mut().unwrap() - }; - - // Indicate all data from stream have been completely consumed. - let mut query_end = false; - let mut query_row_count = 0; - - if let Some(notice) = result.get_notice() { + ) -> PsqlResult { + if let Some(notice) = self.result.get_notice() { msg_stream.write_no_flush(&BeMessage::NoticeResponse(¬ice))?; } - if result.is_empty() { + if self.result.is_empty() { // Run the callback before sending the response. - result.run_callback().await?; + self.result.run_callback().await?; msg_stream.write_no_flush(&BeMessage::EmptyQueryResponse)?; - } else if result.is_query() { + return Ok(true); + } + + let mut query_end = false; + if self.result.is_query() { + let mut query_row_count = 0; + // fetch row data // if row_limit is 0, fetch all rows // if row_limit > 0, fetch row_limit rows @@ -196,7 +78,8 @@ where } } } else { - self.row_cache = if let Some(rows) = result + self.row_cache = if let Some(rows) = self + .result .values_stream() .try_next() .await @@ -209,18 +92,19 @@ where }; } } + // Check if the result is consumed completely. // If not, cache the result. - if self.row_cache.len() == 0 && result.values_stream().peekable().is_terminated() { + if self.row_cache.len() == 0 && self.result.values_stream().peekable().is_terminated() { query_end = true; } if query_end { // Run the callback before sending the `CommandComplete` message. - result.run_callback().await?; + self.result.run_callback().await?; msg_stream.write_no_flush(&BeMessage::CommandComplete( BeCommandCompleteMessage { - stmt_type: result.get_stmt_type(), + stmt_type: self.result.get_stmt_type(), rows_cnt: query_row_count as i32, }, ))?; @@ -229,804 +113,19 @@ where } } else { // Run the callback before sending the `CommandComplete` message. - result.run_callback().await?; + self.result.run_callback().await?; msg_stream.write_no_flush(&BeMessage::CommandComplete(BeCommandCompleteMessage { - stmt_type: result.get_stmt_type(), - rows_cnt: result + stmt_type: self.result.get_stmt_type(), + rows_cnt: self + .result .get_effected_rows_cnt() .expect("row count should be set"), }))?; - } - - // If the result is consumed completely or is not a query result, clear the cache. - if query_end || !self.result.as_ref().unwrap().is_query() { - self.result.take(); - } - - Ok(()) - } - - /// We define the statement start with ("select","values","show","with","describe") is query - /// statement. Because these statement will return a result set. - pub fn is_query(&self) -> bool { - self.is_query - } -} - -#[derive(Default)] -pub struct PreparedStatement { - raw_statement: String, - - /// Generic param information used for simplify replace_param(). - /// Range is the start and end index of the param in raw_statement. - /// - /// e.g. - /// raw_statement : "select $1,$2" - /// param_tokens : {{1,(7..9)},{2,(10..12)}} - param_tokens: Vec<(usize, Range)>, - - param_types: Vec, -} - -static PARAMETER_PATTERN: LazyLock = - LazyLock::new(|| Regex::new(r"\$[0-9][0-9]*::[a-zA-Z]+[0-9]*|\$[0-9][0-8]*").unwrap()); - -impl PreparedStatement { - /// parse_statement is to parse the type information from raw_statement and - /// provided_param_types. - /// - /// raw_statement is the sql statement may with generic param. (e.g. "select * from table where - /// a = $1") provided_param_types is the type information provided by user. - /// - /// Why we need parse: - /// The point is user may not provided type information in provided_param_types explicitly. - /// - They may provide in the raw_statement implicitly (e.g. "select * from table where a = - /// $1::INT") . - /// - Or they don't provide. In default, We will treat these unknow type as 'VARCHAR'. - /// So we need to integrate these param information to generate a complete type - /// information(PreparedStatement::param_types). - pub fn parse_statement( - raw_statement: String, - provided_param_oid: Vec, - ) -> PsqlResult { - let provided_param_types = provided_param_oid - .iter() - .map(|x| DataType::from_oid(*x).map_err(|e| PsqlError::ParseError(Box::new(e)))) - .collect::>>()?; - - let generic_params: Vec<_> = PARAMETER_PATTERN - .find_iter(raw_statement.as_str()) - .collect(); - - if generic_params.is_empty() { - return Ok(PreparedStatement { - raw_statement, - param_types: provided_param_types, - param_tokens: vec![], - }); - } - let mut param_tokens = Vec::with_capacity(generic_params.len()); - let mut param_records: Vec> = vec![None; 1]; - - // Parse the implicit type information. - // e.g. - // generic_params = {"$1::VARCHAR","$2::INT4","$3"} - // param_record will be {Some(Type::VARCHAR),Some(Type::INT4),None} - // None means the type information isn't provided implicitly. Such as '$3' above. - for param_match in generic_params { - let range = param_match.range(); - let mut param = param_match.as_str().split("::"); - let param_idx = param - .next() - .unwrap() - .trim_start_matches('$') - .parse::() - .unwrap(); - let param_type = if let Some(str) = param.next() { - Some(DataType::from_str(str).map_err(|_| { - PsqlError::ParseError(format!("Invalid type name {}", str).into()) - })?) - } else { - None - }; - if param_idx > param_records.len() { - param_records.resize(param_idx, None); - } - param_records[param_idx - 1] = param_type; - param_tokens.push((param_idx, range)); - } - - // Integrate the param_records and provided_param_types. - if provided_param_types.len() > param_records.len() { - param_records.resize(provided_param_types.len(), None); - } - for (idx, param_record) in param_records.iter_mut().enumerate() { - if let Some(param_record) = param_record { - // Check consistency of param type. - if idx < provided_param_types.len() && provided_param_types[idx] != *param_record { - return Err(PsqlError::ParseError( - format!("Type mismatch for parameter ${}", idx).into(), - )); - } - continue; - } - if idx < provided_param_types.len() { - *param_record = Some(provided_param_types[idx].clone()); - } else { - // If the type information isn't provided implicitly or explicitly, we just assign - // it as VARCHAR. - *param_record = Some(DataType::Varchar); - } + query_end = true; } - let param_types = param_records.into_iter().map(|x| x.unwrap()).collect(); - - Ok(PreparedStatement { - raw_statement, - param_tokens, - param_types, - }) - } - - fn parse_params( - type_description: &[DataType], - raw_params: &[Bytes], - param_formats: &[Format], - ) -> PsqlResult> { - // Invariant check - if type_description.len() != raw_params.len() { - return Err(PsqlError::Internal(anyhow!( - "The number of params doesn't match the number of types" - ))); - } - if raw_params.is_empty() { - return Ok(vec![]); - } - - let mut params = Vec::with_capacity(raw_params.len()); - - let place_hodler = Type::ANY; - let format_iter = FormatIterator::new(param_formats, raw_params.len()) - .map_err(|err| PsqlError::Internal(anyhow!(err)))?; - - for ((type_oid, raw_param), param_format) in type_description - .iter() - .zip_eq_fast(raw_params.iter()) - .zip_eq_fast(format_iter) - { - let str = match type_oid { - DataType::Varchar | DataType::Bytea => { - format!("'{}'", cstr_to_str(raw_param).unwrap().replace('\'', "''")) - } - DataType::Boolean => match param_format { - Format::Binary => bool::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string(), - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }, - DataType::Int64 => { - let tmp = match param_format { - Format::Binary => { - i64::from_sql(&place_hodler, raw_param).unwrap().to_string() - } - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("{}::INT8", tmp) - } - DataType::Int16 => { - let tmp = match param_format { - Format::Binary => { - i16::from_sql(&place_hodler, raw_param).unwrap().to_string() - } - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("{}::INT2", tmp) - } - DataType::Int32 => { - let tmp = match param_format { - Format::Binary => { - i32::from_sql(&place_hodler, raw_param).unwrap().to_string() - } - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("{}::INT4", tmp) - } - DataType::Float32 => { - let tmp = match param_format { - Format::Binary => { - f32::from_sql(&place_hodler, raw_param).unwrap().to_string() - } - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::FLOAT4", tmp) - } - DataType::Float64 => { - let tmp = match param_format { - Format::Binary => { - f64::from_sql(&place_hodler, raw_param).unwrap().to_string() - } - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::FLOAT8", tmp) - } - DataType::Date => { - let tmp = match param_format { - Format::Binary => chrono::NaiveDate::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string(), - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::DATE", tmp) - } - DataType::Time => { - let tmp = match param_format { - Format::Binary => chrono::NaiveTime::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string(), - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::TIME", tmp) - } - DataType::Timestamp => { - let tmp = match param_format { - Format::Binary => chrono::NaiveDateTime::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string(), - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::TIMESTAMP", tmp) - } - DataType::Decimal => { - let tmp = match param_format { - Format::Binary => rust_decimal::Decimal::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string(), - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::DECIMAL", tmp) - } - DataType::Timestamptz => { - let tmp = match param_format { - Format::Binary => { - chrono::DateTime::::from_sql(&place_hodler, raw_param) - .unwrap() - .to_string() - } - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::TIMESTAMPTZ", tmp) - } - DataType::Interval => { - let tmp = match param_format { - Format::Binary => pg_interval::Interval::from_sql(&place_hodler, raw_param) - .unwrap() - .to_postgres(), - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::INTERVAL", tmp) - } - DataType::Jsonb => { - let tmp = match param_format { - Format::Binary => { - use risingwave_common::types::to_text::ToText as _; - use risingwave_common::types::Scalar as _; - risingwave_common::array::JsonbVal::value_deserialize(raw_param) - .unwrap() - .as_scalar_ref() - .to_text_with_type(&DataType::Jsonb) - } - Format::Text => cstr_to_str(raw_param).unwrap().to_string(), - }; - format!("'{}'::JSONB", tmp) - } - DataType::Serial | DataType::Struct(_) | DataType::List { .. } => { - return Err(PsqlError::Internal(anyhow!( - "Unsupported param type {:?}", - type_oid - ))) - } - }; - params.push(str) - } - - Ok(params) - } - - /// `default_params` creates default params from type oids for - /// [`PreparedStatement::instance_default`]. - fn default_params(type_description: &[DataType]) -> PsqlResult> { - let mut params: _ = Vec::new(); - for oid in type_description.iter() { - match oid { - DataType::Boolean => params.push("false".to_string()), - DataType::Int64 => params.push("0::BIGINT".to_string()), - DataType::Int16 => params.push("0::SMALLINT".to_string()), - DataType::Int32 => params.push("0::INT".to_string()), - DataType::Float32 => params.push("0::FLOAT4".to_string()), - DataType::Float64 => params.push("0::FLOAT8".to_string()), - DataType::Bytea => params.push("'\\x0'".to_string()), - DataType::Varchar => params.push("'0'".to_string()), - DataType::Date => params.push("'2021-01-01'::DATE".to_string()), - DataType::Time => params.push("'00:00:00'::TIME".to_string()), - DataType::Timestamp => params.push("'2021-01-01 00:00:00'::TIMESTAMP".to_string()), - DataType::Decimal => params.push("'0'::DECIMAL".to_string()), - DataType::Timestamptz => { - params.push("'2022-10-01 12:00:00+01:00'::timestamptz".to_string()) - } - DataType::Interval => params.push("'2 months ago'::interval".to_string()), - DataType::Jsonb => params.push("'null'::JSONB".to_string()), - DataType::Serial | DataType::Struct(_) | DataType::List { .. } => { - return Err(PsqlError::Internal(anyhow!( - "Unsupported param type {:?}", - oid - ))) - } - }; - } - Ok(params) - } - - // replace_params replaces the generic params in the raw statement with the given params. - // Our replace algorithm: - // param_tokens is a vec of (param_index, param_range) in the raw statement. - // We sort the param_tokens by param_range.start to get a vec of range sorted from left to - // right. Our purpose is to split the raw statement into several parts: - // [normal part1] [generic param1] [normal part2] [generic param2] [generic param3] - // Then we create the result statement: - // For normal part, we just copy it from the raw statement. - // For generic param, we replace it with the given param. - fn replace_params(&self, params: &[String]) -> String { - let tmp = &self.raw_statement; - - let ranges: Vec<_> = self - .param_tokens - .iter() - .sorted_by(|a, b| a.1.start.cmp(&b.1.start)) - .collect(); - - let mut start_offset = 0; - let mut res = String::new(); - for (idx, range) in ranges { - let param = ¶ms[*idx - 1]; - res.push_str(&tmp[start_offset..range.start]); - res.push_str(param); - start_offset = range.end; - } - res.push_str(&tmp[start_offset..]); - - res - } - - pub fn param_type_description(&self) -> Vec { - self.param_types.clone() - } - - /// `instance_default` used in parse phase. - /// At parse phase, user still do not provide params but we need to infer the sql result.(The - /// session can't support infer the sql with generic param now). Hence to get a sql without - /// generic param, we used `default_params()` to generate default params according `param_types` - /// and replace the generic param with them. - pub fn instance_default(&self) -> PsqlResult { - let default_params = Self::default_params(&self.param_types)?; - Ok(self.replace_params(&default_params)) - } - - pub fn instance(&self, raw_params: &[Bytes], param_formats: &[Format]) -> PsqlResult { - let params = Self::parse_params(&self.param_types, raw_params, param_formats)?; - Ok(self.replace_params(¶ms)) - } -} - -#[cfg(test)] -mod tests { - use chrono::{DateTime, Utc}; - use pg_interval::Interval; - // Note this useful idiom: importing names from outer (for mod tests) scope. - use postgres_types::private::BytesMut; - use risingwave_common::types::{DataType, Date, Time, Timestamp}; - use tokio_postgres::types::{ToSql, Type}; - - use crate::pg_extended::PreparedStatement; - use crate::types::Format; - - #[test] - fn test_prepared_statement_without_param() { - let raw_statement = "SELECT * FROM test_table".to_string(); - let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("SELECT * FROM test_table" == default_sql); - let sql = prepared_statement.instance(&[], &[]).unwrap(); - assert!("SELECT * FROM test_table" == sql); - } - - #[test] - fn test_prepared_statement_with_explicit_param() { - let raw_statement = "SELECT * FROM test_table WHERE id = $1".to_string(); - let prepared_statement = - PreparedStatement::parse_statement(raw_statement, vec![DataType::Int32.to_oid()]) - .unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("SELECT * FROM test_table WHERE id = 0::INT" == default_sql); - let sql = prepared_statement.instance(&["1".into()], &[]).unwrap(); - assert!("SELECT * FROM test_table WHERE id = 1::INT4" == sql); - - let raw_statement = "INSERT INTO test (index,data) VALUES ($1,$2)".to_string(); - let prepared_statement = PreparedStatement::parse_statement( - raw_statement, - vec![DataType::Int32.to_oid(), DataType::Varchar.to_oid()], - ) - .unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("INSERT INTO test (index,data) VALUES (0::INT,'0')" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("INSERT INTO test (index,data) VALUES (1::INT4,'DATA')" == sql); - - let raw_statement = "UPDATE COFFEES SET SALES = $1 WHERE COF_NAME LIKE $2".to_string(); - let prepared_statement = PreparedStatement::parse_statement( - raw_statement, - vec![DataType::Int32.to_oid(), DataType::Varchar.to_oid()], - ) - .unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("UPDATE COFFEES SET SALES = 0::INT WHERE COF_NAME LIKE '0'" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("UPDATE COFFEES SET SALES = 1::INT4 WHERE COF_NAME LIKE 'DATA'" == sql); - - let raw_statement = "SELECT * FROM test_table WHERE id = $1 AND name = $3".to_string(); - let prepared_statement = PreparedStatement::parse_statement( - raw_statement, - vec![ - DataType::Int32.to_oid(), - DataType::Varchar.to_oid(), - DataType::Varchar.to_oid(), - ], - ) - .unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("SELECT * FROM test_table WHERE id = 0::INT AND name = '0'" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into(), "NAME".into()], &[]) - .unwrap(); - assert!("SELECT * FROM test_table WHERE id = 1::INT4 AND name = 'NAME'" == sql); - } - - #[test] - fn test_prepared_statement_with_implicit_param() { - let raw_statement = "SELECT * FROM test_table WHERE id = $1::INT".to_string(); - let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("SELECT * FROM test_table WHERE id = 0::INT" == default_sql); - let sql = prepared_statement.instance(&["1".into()], &[]).unwrap(); - assert!("SELECT * FROM test_table WHERE id = 1::INT4" == sql); - - let raw_statement = - "INSERT INTO test (index,data) VALUES ($1::INT4,$2::VARCHAR)".to_string(); - let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("INSERT INTO test (index,data) VALUES (0::INT,'0')" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("INSERT INTO test (index,data) VALUES (1::INT4,'DATA')" == sql); - - let raw_statement = - "UPDATE COFFEES SET SALES = $1::INT WHERE COF_NAME LIKE $2::VARCHAR".to_string(); - let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("UPDATE COFFEES SET SALES = 0::INT WHERE COF_NAME LIKE '0'" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("UPDATE COFFEES SET SALES = 1::INT4 WHERE COF_NAME LIKE 'DATA'" == sql); - } - - #[test] - fn test_prepared_statement_with_mix_param() { - let raw_statement = - "SELECT * FROM test_table WHERE id = $1 AND name = $2::VARCHAR".to_string(); - let prepared_statement = - PreparedStatement::parse_statement(raw_statement, vec![DataType::Int32.to_oid()]) - .unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("SELECT * FROM test_table WHERE id = 0::INT AND name = '0'" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("SELECT * FROM test_table WHERE id = 1::INT4 AND name = 'DATA'" == sql); - - let raw_statement = "INSERT INTO test (index,data) VALUES ($1,$2)".to_string(); - let prepared_statement = - PreparedStatement::parse_statement(raw_statement, vec![DataType::Int32.to_oid()]) - .unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("INSERT INTO test (index,data) VALUES (0::INT,'0')" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("INSERT INTO test (index,data) VALUES (1::INT4,'DATA')" == sql); - - let raw_statement = - "UPDATE COFFEES SET SALES = $1 WHERE COF_NAME LIKE $2::VARCHAR".to_string(); - let prepared_statement = - PreparedStatement::parse_statement(raw_statement, vec![DataType::Int32.to_oid()]) - .unwrap(); - let default_sql = prepared_statement.instance_default().unwrap(); - assert!("UPDATE COFFEES SET SALES = 0::INT WHERE COF_NAME LIKE '0'" == default_sql); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("UPDATE COFFEES SET SALES = 1::INT4 WHERE COF_NAME LIKE 'DATA'" == sql); - - let raw_statement = "SELECT $1,$2;".to_string(); - let prepared_statement = PreparedStatement::parse_statement(raw_statement, vec![]).unwrap(); - let sql = prepared_statement - .instance(&["test$2".into(), "test$1".into()], &[]) - .unwrap(); - assert!("SELECT 'test$2','test$1';" == sql); - - let raw_statement = "SELECT $1,$1::INT,$2::VARCHAR,$2;".to_string(); - let prepared_statement = - PreparedStatement::parse_statement(raw_statement, vec![DataType::Int32.to_oid()]) - .unwrap(); - let sql = prepared_statement - .instance(&["1".into(), "DATA".into()], &[]) - .unwrap(); - assert!("SELECT 1::INT4,1::INT4,'DATA','DATA';" == sql); - } - #[test] - - fn test_parse_params_text() { - let raw_params = vec!["A".into(), "B".into(), "C".into()]; - let type_description = vec![DataType::Varchar; 3]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); - assert_eq!(params, vec!["'A'", "'B'", "'C'"]); - - let raw_params = vec!["false".into(), "true".into()]; - let type_description = vec![DataType::Boolean; 2]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); - assert_eq!(params, vec!["false", "true"]); - - let raw_params = vec!["1".into(), "2".into(), "3".into()]; - let type_description = vec![DataType::Int16, DataType::Int32, DataType::Int64]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); - assert_eq!(params, vec!["1::INT2", "2::INT4", "3::INT8"]); - - let raw_params = vec![ - "1.0".into(), - "2.0".into(), - rust_decimal::Decimal::from_f32_retain(3.0_f32) - .unwrap() - .to_string() - .into(), - ]; - let type_description = vec![DataType::Float32, DataType::Float64, DataType::Decimal]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); - assert_eq!( - params, - vec!["'1.0'::FLOAT4", "'2.0'::FLOAT8", "'3'::DECIMAL"] - ); - - let raw_params = vec![ - Date::from_ymd_uncheck(2021, 1, 1).0.to_string().into(), - Time::from_hms_uncheck(12, 0, 0).0.to_string().into(), - Timestamp::from_timestamp_uncheck(1610000000, 0) - .0 - .to_string() - .into(), - ]; - let type_description = vec![DataType::Date, DataType::Time, DataType::Timestamp]; - let params = PreparedStatement::parse_params(&type_description, &raw_params, &[]).unwrap(); - assert_eq!( - params, - vec![ - "'2021-01-01'::DATE", - "'12:00:00'::TIME", - "'2021-01-07 06:13:20'::TIMESTAMP" - ] - ); - } - - #[test] - fn test_parse_params_binary() { - let place_hodler = Type::ANY; - - // Test VACHAR type. - let raw_params = vec!["A".into(), "B".into(), "C".into()]; - let type_description = vec![DataType::Varchar; 3]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) - .unwrap(); - assert_eq!(params, vec!["'A'", "'B'", "'C'"]); - - // Test BOOLEAN type. - let mut raw_params = vec![BytesMut::new(); 2]; - false.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); - true.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); - let raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - let type_description = vec![DataType::Boolean; 2]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) - .unwrap(); - assert_eq!(params, vec!["false", "true"]); - - // Test SMALLINT, INT, BIGINT type. - let mut raw_params = vec![BytesMut::new(); 3]; - 1_i16.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); - 2_i32.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); - 3_i64.to_sql(&place_hodler, &mut raw_params[2]).unwrap(); - let raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - let type_description = vec![DataType::Int16, DataType::Int32, DataType::Int64]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) - .unwrap(); - assert_eq!(params, vec!["1::INT2", "2::INT4", "3::INT8"]); - - // Test FLOAT4, FLOAT8, DECIMAL type. - let mut raw_params = vec![BytesMut::new(); 3]; - 1.0_f32.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); - 2.0_f64.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); - rust_decimal::Decimal::from_f32_retain(3.0_f32) - .unwrap() - .to_sql(&place_hodler, &mut raw_params[2]) - .unwrap(); - let raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - let type_description = vec![DataType::Float32, DataType::Float64, DataType::Decimal]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) - .unwrap(); - assert_eq!(params, vec!["'1'::FLOAT4", "'2'::FLOAT8", "'3'::DECIMAL"]); - - let mut raw_params = vec![BytesMut::new(); 3]; - f32::NAN.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); - f64::INFINITY - .to_sql(&place_hodler, &mut raw_params[1]) - .unwrap(); - f64::NEG_INFINITY - .to_sql(&place_hodler, &mut raw_params[2]) - .unwrap(); - let raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - let type_description = vec![DataType::Float32, DataType::Float64, DataType::Float64]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) - .unwrap(); - assert_eq!( - params, - vec!["'NaN'::FLOAT4", "'inf'::FLOAT8", "'-inf'::FLOAT8"] - ); - - // Test DATE, TIME, TIMESTAMP type. - let mut raw_params = vec![BytesMut::new(); 3]; - Date::from_ymd_uncheck(2021, 1, 1) - .0 - .to_sql(&place_hodler, &mut raw_params[0]) - .unwrap(); - Time::from_hms_uncheck(12, 0, 0) - .0 - .to_sql(&place_hodler, &mut raw_params[1]) - .unwrap(); - Timestamp::from_timestamp_uncheck(1610000000, 0) - .0 - .to_sql(&place_hodler, &mut raw_params[2]) - .unwrap(); - let raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - let type_description = vec![DataType::Date, DataType::Time, DataType::Timestamp]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) - .unwrap(); - assert_eq!( - params, - vec![ - "'2021-01-01'::DATE", - "'12:00:00'::TIME", - "'2021-01-07 06:13:20'::TIMESTAMP" - ] - ); - - // Test TIMESTAMPTZ, INTERVAL type. - let mut raw_params = vec![BytesMut::new(); 2]; - DateTime::::from_utc(Timestamp::from_timestamp_uncheck(1200, 0).0, Utc) - .to_sql(&place_hodler, &mut raw_params[0]) - .unwrap(); - let interval = Interval::new(1, 1, 24000000); - ToSql::to_sql(&interval, &place_hodler, &mut raw_params[1]).unwrap(); - let raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - let type_description = vec![DataType::Timestamptz, DataType::Interval]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary]) - .unwrap(); - assert_eq!( - params, - vec![ - "'1970-01-01 00:20:00 UTC'::TIMESTAMPTZ", - "'1 mons 1 days 00:00:24'::INTERVAL" - ] - ); - } - - #[test] - fn test_parse_params_mix_format() { - let place_hodler = Type::ANY; - - // Test VACHAR type. - let raw_params = vec!["A".into(), "B".into(), "C".into()]; - let type_description = vec![DataType::Varchar; 3]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Text; 3]) - .unwrap(); - assert_eq!(params, vec!["'A'", "'B'", "'C'"]); - - // Test BOOLEAN type. - let mut raw_params = vec![BytesMut::new(); 2]; - false.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); - true.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); - let raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - let type_description = vec![DataType::Boolean; 2]; - let params = - PreparedStatement::parse_params(&type_description, &raw_params, &[Format::Binary; 2]) - .unwrap(); - assert_eq!(params, vec!["false", "true"]); - - // Test SMALLINT, INT, BIGINT type. - let mut raw_params = vec![BytesMut::new(); 2]; - 1_i16.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); - 2_i32.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); - let mut raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - raw_params.push("3".into()); - let type_description = vec![DataType::Int16, DataType::Int32, DataType::Int64]; - let params = PreparedStatement::parse_params( - &type_description, - &raw_params, - &[Format::Binary, Format::Binary, Format::Text], - ) - .unwrap(); - assert_eq!(params, vec!["1::INT2", "2::INT4", "3::INT8"]); - - // Test FLOAT4, FLOAT8, DECIMAL type. - let mut raw_params = vec![BytesMut::new(); 2]; - 1.0_f32.to_sql(&place_hodler, &mut raw_params[0]).unwrap(); - 2.0_f64.to_sql(&place_hodler, &mut raw_params[1]).unwrap(); - let mut raw_params = raw_params - .into_iter() - .map(|b| b.freeze()) - .collect::>(); - raw_params.push("TEST".into()); - let type_description = vec![DataType::Float32, DataType::Float64, DataType::Varchar]; - let params = PreparedStatement::parse_params( - &type_description, - &raw_params, - &[Format::Binary, Format::Binary, Format::Text], - ) - .unwrap(); - assert_eq!(params, vec!["'1'::FLOAT4", "'2'::FLOAT8", "'TEST'"]); + Ok(query_end) } } diff --git a/src/utils/pgwire/src/pg_protocol.rs b/src/utils/pgwire/src/pg_protocol.rs index 392c2bd95624..4558f72d3432 100644 --- a/src/utils/pgwire/src/pg_protocol.rs +++ b/src/utils/pgwire/src/pg_protocol.rs @@ -16,14 +16,16 @@ use std::collections::HashMap; use std::io::{self, Error as IoError, ErrorKind}; use std::path::PathBuf; use std::pin::Pin; +use std::str; use std::str::Utf8Error; use std::sync::Arc; -use std::{str, vec}; use bytes::{Bytes, BytesMut}; use futures::stream::StreamExt; use futures::Stream; +use itertools::Itertools; use openssl::ssl::{SslAcceptor, SslContext, SslContextRef, SslMethod}; +use risingwave_common::types::DataType; use risingwave_sqlparser::parser::Parser; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_openssl::SslStream; @@ -31,22 +33,23 @@ use tracing::log::trace; use tracing::{error, warn}; use crate::error::{PsqlError, PsqlResult}; -use crate::pg_extended::{PgPortal, PgStatement, PreparedStatement}; -use crate::pg_field_descriptor::PgFieldDescriptor; +use crate::pg_extended::ResultCache; use crate::pg_message::{ BeCommandCompleteMessage, BeMessage, BeParameterStatusMessage, FeBindMessage, FeCancelMessage, FeCloseMessage, FeDescribeMessage, FeExecuteMessage, FeMessage, FeParseMessage, FePasswordMessage, FeStartupMessage, }; -use crate::pg_response::{RowSetResult, StatementType}; +use crate::pg_response::RowSetResult; use crate::pg_server::{Session, SessionManager, UserAuthenticator}; use crate::types::Format; /// The state machine for each psql connection. /// Read pg messages from tcp stream and write results back. -pub struct PgProtocol +pub struct PgProtocol where - SM: SessionManager, + PS: Send + Clone + 'static, + PO: Send + Clone + 'static, + SM: SessionManager, VS: Stream + Unpin + Send, { /// Used for write/read pg messages. @@ -59,10 +62,14 @@ where session_mgr: Arc, session: Option>, - unnamed_statement: Option, - unnamed_portal: Option>, - named_statements: HashMap, - named_portals: HashMap>, + result_cache: HashMap>, + unnamed_prepare_statement: Option, + prepare_statement_store: HashMap, + unnamed_portal: Option, + portal_store: HashMap, + // Used to store the dependency of portal and prepare statement. + // When we close a prepare statement, we need to close all the portals that depend on it. + statement_portal_dependency: HashMap>, // Used for ssl connection. // If None, not expected to build ssl connection (panic). @@ -94,9 +101,11 @@ impl TlsConfig { } } -impl Drop for PgProtocol +impl Drop for PgProtocol where - SM: SessionManager, + PS: Send + Clone + 'static, + PO: Send + Clone + 'static, + SM: SessionManager, VS: Stream + Unpin + Send, { fn drop(&mut self) { @@ -125,10 +134,12 @@ pub fn cstr_to_str(b: &Bytes) -> Result<&str, Utf8Error> { std::str::from_utf8(without_null) } -impl PgProtocol +impl PgProtocol where + PS: Send + Clone + 'static, + PO: Send + Clone + 'static, S: AsyncWrite + AsyncRead + Unpin, - SM: SessionManager, + SM: SessionManager, VS: Stream + Unpin + Send, { pub fn new(stream: S, session_mgr: Arc, tls_config: Option) -> Self { @@ -141,13 +152,15 @@ where state: PgProtocolState::Startup, session_mgr, session: None, - unnamed_statement: None, - unnamed_portal: None, - named_statements: Default::default(), - named_portals: Default::default(), tls_context: tls_config .as_ref() .and_then(|e| build_ssl_ctx_from_config(e).ok()), + result_cache: Default::default(), + unnamed_prepare_statement: Default::default(), + prepare_statement_store: Default::default(), + unnamed_portal: Default::default(), + portal_store: Default::default(), + statement_portal_dependency: Default::default(), } } @@ -218,7 +231,7 @@ where FeMessage::Query(query_msg) => self.process_query_msg(query_msg.get_sql()).await?, FeMessage::CancelQuery(m) => self.process_cancel_msg(m)?, FeMessage::Terminate => self.process_terminate(), - FeMessage::Parse(m) => self.process_parse_msg(m).await?, + FeMessage::Parse(m) => self.process_parse_msg(m)?, FeMessage::Bind(m) => self.process_bind_msg(m)?, FeMessage::Execute(m) => self.process_execute_msg(m).await?, FeMessage::Describe(m) => self.process_describe_msg(m)?, @@ -392,7 +405,7 @@ where self.is_terminate = true; } - async fn process_parse_msg(&mut self, msg: FeParseMessage) -> PsqlResult<()> { + fn process_parse_msg(&mut self, msg: FeParseMessage) -> PsqlResult<()> { let sql = cstr_to_str(&msg.sql_bytes).unwrap(); let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_string(); tracing::trace!( @@ -401,7 +414,11 @@ where statement_name ); - let is_query_sql = { + if self.prepare_statement_store.contains_key(&statement_name) { + return Err(PsqlError::ParseError("Duplicated statement name".into())); + } + + let stmt = { let stmts = Parser::parse_sql(sql) .inspect_err(|e| tracing::error!("failed to parse sql:\n{}:\n{}", sql, e)) .map_err(|err| PsqlError::ParseError(err.into()))?; @@ -412,37 +429,40 @@ where )); } + // TODO: This behavior is not compatible with Postgres. if stmts.is_empty() { - false - } else { - StatementType::infer_from_statement(&stmts[0]) - .map_or(false, |stmt_type| stmt_type.is_query()) + return Err(PsqlError::ParseError( + "Empty statement is parsed in extended query mode".into(), + )); } + + stmts.into_iter().next().unwrap() }; - let prepared_statement = PreparedStatement::parse_statement(sql.to_string(), msg.type_ids)?; + let param_types = msg + .type_ids + .iter() + .map(|&id| DataType::from_oid(id)) + .try_collect() + .map_err(|err| PsqlError::ParseError(err.into()))?; - // Create the row description. - let fields: Vec = if is_query_sql { - let sql = prepared_statement.instance_default()?; + let session = self.session.clone().unwrap(); + let prepare_statement = session + .parse(stmt, param_types) + .map_err(PsqlError::ParseError)?; - let session = self.session.clone().unwrap(); - session - .infer_return_type(&sql) - .await - .map_err(PsqlError::ParseError)? + if statement_name.is_empty() { + self.unnamed_prepare_statement.replace(prepare_statement); } else { - vec![] - }; + self.prepare_statement_store + .insert(statement_name.clone(), prepare_statement); + } - let statement = PgStatement::new(statement_name, prepared_statement, fields, is_query_sql); + self.statement_portal_dependency + .entry(statement_name) + .or_insert_with(Vec::new) + .clear(); - let name = statement.name(); - if name.is_empty() { - self.unnamed_statement.replace(statement); - } else { - self.named_statements.insert(name, statement); - } self.stream.write_no_flush(&BeMessage::ParseComplete)?; Ok(()) } @@ -450,21 +470,18 @@ where fn process_bind_msg(&mut self, msg: FeBindMessage) -> PsqlResult<()> { let statement_name = cstr_to_str(&msg.statement_name).unwrap().to_string(); let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_string(); - // 1. Get statement. + trace!( target: "pgwire_query_log", "(extended query)bind: statement name: {}, portal name: {}", &statement_name,&portal_name ); - let statement = if statement_name.is_empty() { - self.unnamed_statement - .as_ref() - .ok_or_else(PsqlError::no_statement)? - } else { - self.named_statements - .get(&statement_name) - .ok_or_else(PsqlError::no_statement)? - }; + + if self.portal_store.contains_key(&portal_name) { + return Err(PsqlError::Internal("Duplicated portal name".into())); + } + + let prepare_statement = self.get_statement(&statement_name)?; let result_formats = msg .result_format_codes @@ -477,109 +494,118 @@ where .map(|&format_code| Format::from_i16(format_code)) .try_collect()?; - // 2. Instance the statement to get the portal. - let portal = statement.instance( - portal_name.clone(), - &msg.params, - result_formats, - param_formats, - )?; + let portal = self + .session + .clone() + .unwrap() + .bind(prepare_statement, msg.params, param_formats, result_formats) + .map_err(PsqlError::Internal)?; - // 3. Insert the Portal. if portal_name.is_empty() { + self.result_cache.remove(&portal_name); self.unnamed_portal.replace(portal); } else { - self.named_portals.insert(portal_name, portal); + assert!( + self.result_cache.get(&portal_name).is_none(), + "Named portal never can be overridden." + ); + self.portal_store.insert(portal_name.clone(), portal); } + + self.statement_portal_dependency + .get_mut(&statement_name) + .unwrap() + .push(portal_name); + self.stream.write_no_flush(&BeMessage::BindComplete)?; Ok(()) } async fn process_execute_msg(&mut self, msg: FeExecuteMessage) -> PsqlResult<()> { - // 1. Get portal. let portal_name = cstr_to_str(&msg.portal_name).unwrap().to_string(); - let portal = if msg.portal_name.is_empty() { - self.unnamed_portal - .as_mut() - .ok_or_else(PsqlError::no_portal)? + let row_max = msg.max_rows as usize; + tracing::trace!(target: "pgwire_query_log", "(extended query)execute portal name: {}",portal_name); + + if let Some(mut result_cache) = self.result_cache.remove(&portal_name) { + assert!(self.portal_store.contains_key(&portal_name)); + + let is_cosume_completed = result_cache.consume::(row_max, &mut self.stream).await?; + + if !is_cosume_completed { + self.result_cache.insert(portal_name, result_cache); + } } else { - // NOTE Error handle need modify later. - self.named_portals - .get_mut(&portal_name) - .ok_or_else(PsqlError::no_portal)? - }; + let portal = self.get_portal(&portal_name)?; - tracing::trace!(target: "pgwire_query_log", "(extended query)execute query: {}, portal name: {}", portal.query_string(),portal_name); + let pg_response = self + .session + .clone() + .unwrap() + .execute(portal) + .await + .map_err(PsqlError::ExecuteError)?; - // 2. Execute instance statement using portal. - let session = self.session.clone().unwrap(); - portal - .execute::(session, msg.max_rows.try_into().unwrap(), &mut self.stream) - .await?; + let mut result_cache = ResultCache::new(pg_response); + let is_consume_completed = result_cache.consume::(row_max, &mut self.stream).await?; + if !is_consume_completed { + self.result_cache.insert(portal_name, result_cache); + } + } - // NOTE there is no ReadyForQuery message. Ok(()) } fn process_describe_msg(&mut self, msg: FeDescribeMessage) -> PsqlResult<()> { + let name = cstr_to_str(&msg.name).unwrap().to_string(); // b'S' => Statement // b'P' => Portal tracing::trace!( target: "pgwire_query_log", "(extended query)describe name: {}", - cstr_to_str(&msg.name).unwrap() + name, ); assert!(msg.kind == b'S' || msg.kind == b'P'); if msg.kind == b'S' { - let name = cstr_to_str(&msg.name).unwrap().to_string(); - let statement = if name.is_empty() { - self.unnamed_statement - .as_ref() - .ok_or_else(PsqlError::no_statement)? - } else { - // NOTE Error handle need modify later. - self.named_statements - .get(&name) - .ok_or_else(PsqlError::no_statement)? - }; + let prepare_statement = self.get_statement(&name)?; + + let (param_types, row_descriptions) = self + .session + .clone() + .unwrap() + .describe_statement(prepare_statement) + .map_err(PsqlError::Internal)?; - // 1. Send parameter description. self.stream .write_no_flush(&BeMessage::ParameterDescription( - &statement.param_oid_desc(), + ¶m_types.iter().map(|t| t.to_oid()).collect_vec(), ))?; - // 2. Send row description. - if statement.is_query() { - self.stream - .write_no_flush(&BeMessage::RowDescription(&statement.row_desc()))?; - } else { + if row_descriptions.is_empty() { // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B, // return NoData message if the statement is not a query. self.stream.write_no_flush(&BeMessage::NoData)?; - } - } else if msg.kind == b'P' { - let name = cstr_to_str(&msg.name).unwrap().to_string(); - let portal = if name.is_empty() { - self.unnamed_portal - .as_ref() - .ok_or_else(PsqlError::no_portal)? } else { - // NOTE Error handle need modify later. - self.named_portals - .get(&name) - .ok_or_else(PsqlError::no_portal)? - }; - - // 3. Send row description. - if portal.is_query() { self.stream - .write_no_flush(&BeMessage::RowDescription(&portal.row_desc()))?; - } else { + .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?; + } + } else if msg.kind == b'P' { + let portal = self.get_portal(&name)?; + + let row_descriptions = self + .session + .clone() + .unwrap() + .describe_portral(portal) + .map_err(PsqlError::Internal)?; + + if row_descriptions.is_empty() { // According https://www.postgresql.org/docs/current/protocol-flow.html#:~:text=The%20response%20is%20a%20RowDescri[…]0a%20query%20that%20will%20return%20rows%3B, // return NoData message if the statement is not a query. self.stream.write_no_flush(&BeMessage::NoData)?; + } else { + self.stream + .write_no_flush(&BeMessage::RowDescription(&row_descriptions))?; } } Ok(()) @@ -589,13 +615,71 @@ where let name = cstr_to_str(&msg.name).unwrap().to_string(); assert!(msg.kind == b'S' || msg.kind == b'P'); if msg.kind == b'S' { - self.named_statements.remove_entry(&name); + if name.is_empty() { + self.unnamed_prepare_statement = None; + } else { + self.prepare_statement_store.remove(&name); + } + for portal_name in self + .statement_portal_dependency + .remove(&name) + .unwrap_or(vec![]) + { + self.remove_portal(&portal_name); + } } else if msg.kind == b'P' { - self.named_portals.remove_entry(&name); + self.remove_portal(&name); } self.stream.write_no_flush(&BeMessage::CloseComplete)?; Ok(()) } + + fn remove_portal(&mut self, portal_name: &str) { + if portal_name.is_empty() { + self.unnamed_portal = None; + } else { + self.portal_store.remove(portal_name); + } + self.result_cache.remove(portal_name); + } + + fn get_portal(&self, portal_name: &str) -> PsqlResult { + if portal_name.is_empty() { + Ok(self + .unnamed_portal + .as_ref() + .ok_or_else(|| PsqlError::Internal("unnamed portal not found".into()))? + .clone()) + } else { + Ok(self + .portal_store + .get(portal_name) + .ok_or_else(|| { + PsqlError::Internal(format!("Portal {} not found", portal_name).into()) + })? + .clone()) + } + } + + fn get_statement(&self, statement_name: &str) -> PsqlResult { + if statement_name.is_empty() { + Ok(self + .unnamed_prepare_statement + .as_ref() + .ok_or_else(|| PsqlError::Internal("unnamed prepare statement not found".into()))? + .clone()) + } else { + Ok(self + .prepare_statement_store + .get(statement_name) + .ok_or_else(|| { + PsqlError::Internal( + format!("Prepare statement {} not found", statement_name).into(), + ) + })? + .clone()) + } + } } /// Wraps a byte stream and read/write pg messages. diff --git a/src/utils/pgwire/src/pg_server.rs b/src/utils/pgwire/src/pg_server.rs index ec28ffde5e17..eb9482b001d6 100644 --- a/src/utils/pgwire/src/pg_server.rs +++ b/src/utils/pgwire/src/pg_server.rs @@ -17,7 +17,9 @@ use std::io; use std::result::Result; use std::sync::Arc; +use bytes::Bytes; use futures::Stream; +use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpListener; @@ -32,11 +34,13 @@ pub type BoxedError = Box; pub type SessionId = (i32, i32); /// The interface for a database system behind pgwire protocol. /// We can mock it for testing purpose. -pub trait SessionManager: Send + Sync + 'static +pub trait SessionManager: Send + Sync + 'static where VS: Stream + Unpin + Send, + PS: Send + Clone + 'static, + PO: Send + Clone + 'static, { - type Session: Session; + type Session: Session; fn connect(&self, database: &str, user_name: &str) -> Result, BoxedError>; @@ -50,16 +54,12 @@ where /// A psql connection. Each connection binds with a database. Switching database will need to /// recreate another connection. #[async_trait::async_trait] -pub trait Session: Send + Sync +pub trait Session: Send + Sync where VS: Stream + Unpin + Send, + PS: Send + Clone + 'static, + PO: Send + Clone + 'static, { - async fn run_statement( - self: Arc, - sql: &str, - formats: Vec, - ) -> Result, BoxedError>; - /// The str sql can not use the unparse from AST: There is some problem when dealing with create /// view, see https://github.com/risingwavelabs/risingwave/issues/6801. async fn run_one_query( @@ -68,10 +68,29 @@ where format: Format, ) -> Result, BoxedError>; - async fn infer_return_type( + fn parse( + self: Arc, + sql: Statement, + params_types: Vec, + ) -> Result; + + fn bind( self: Arc, - sql: &str, - ) -> Result, BoxedError>; + prepare_statement: PS, + params: Vec, + param_formats: Vec, + result_formats: Vec, + ) -> Result; + + async fn execute(self: Arc, portal: PO) -> Result, BoxedError>; + + fn describe_statement( + self: Arc, + prepare_statement: PS, + ) -> Result<(Vec, Vec), BoxedError>; + + fn describe_portral(self: Arc, portal: PO) -> Result, BoxedError>; + fn user_authenticator(&self) -> &UserAuthenticator; fn id(&self) -> SessionId; @@ -103,13 +122,15 @@ impl UserAuthenticator { } /// Binds a Tcp listener at `addr`. Spawn a coroutine to serve every new connection. -pub async fn pg_serve( +pub async fn pg_serve( addr: &str, - session_mgr: Arc>, + session_mgr: Arc>, ssl_config: Option, ) -> io::Result<()> where VS: Stream + Unpin + Send + 'static, + PS: Send + Clone + 'static, + PO: Send + Clone + 'static, { let listener = TcpListener::bind(addr).await.unwrap(); // accept connections and process them, spawning a new thread for each one @@ -138,15 +159,17 @@ where } #[tracing::instrument(level = "debug", skip_all)] -pub fn handle_connection( +pub fn handle_connection( stream: S, session_mgr: Arc, tls_config: Option, ) -> impl Future> where S: AsyncWrite + AsyncRead + Unpin, - SM: SessionManager, + SM: SessionManager, VS: Stream + Unpin + Send + 'static, + PS: Send + Clone + 'static, + PO: Send + Clone + 'static, { let mut pg_proto = PgProtocol::new(stream, session_mgr, tls_config); async { @@ -169,8 +192,8 @@ mod tests { use bytes::Bytes; use futures::stream::BoxStream; use futures::StreamExt; + use risingwave_common::types::DataType; use risingwave_sqlparser::ast::Statement; - use tokio_postgres::types::*; use tokio_postgres::NoTls; use crate::pg_field_descriptor::PgFieldDescriptor; @@ -182,8 +205,9 @@ mod tests { use crate::types::Row; struct MockSessionManager {} + struct MockSession {} - impl SessionManager> for MockSessionManager { + impl SessionManager, String, String> for MockSessionManager { type Session = MockSession; fn connect( @@ -205,58 +229,52 @@ mod tests { fn end_session(&self, _session: &Self::Session) {} } - struct MockSession {} - #[async_trait::async_trait] - impl Session> for MockSession { - async fn run_statement( + impl Session, String, String> for MockSession { + async fn run_one_query( self: Arc, - sql: &str, - _format: Vec, - ) -> Result>, Box> - { - // split a statement and trim \' around the input param to construct result. - // Ex: - // SELECT 'a','b' -> result: a , b - let res: Vec> = sql - .split(&[' ', ',', ';']) - .skip(1) - .map(|x| { - Some( - x.trim_start_matches('\'') - .trim_end_matches('\'') - .to_string() - .into(), - ) - }) - .collect(); - let len = res.len(); - + _sql: Statement, + _format: types::Format, + ) -> Result>, BoxedError> { Ok(PgResponse::new_for_stream( StatementType::SELECT, None, - futures::stream::iter(vec![Ok(vec![Row::new(res)])]).boxed(), + futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])]).boxed(), vec![ // 1043 is the oid of varchar type. // -1 is the type len of varchar type. PgFieldDescriptor::new("".to_string(), 1043, -1); - len + 1 ], )) } - /// The test below will issue "BEGIN", "ROLLBACK" as simple query, but the results do not - /// matter, so just return a fake one. - async fn run_one_query( + fn parse( self: Arc, _sql: Statement, - _format: types::Format, + _params_types: Vec, + ) -> Result { + Ok(String::new()) + } + + fn bind( + self: Arc, + _prepare_statement: String, + _params: Vec, + _param_formats: Vec, + _result_formats: Vec, + ) -> Result { + Ok(String::new()) + } + + async fn execute( + self: Arc, + _portal: String, ) -> Result>, BoxedError> { - let res: Vec> = vec![Some(Bytes::new())]; Ok(PgResponse::new_for_stream( StatementType::SELECT, None, - futures::stream::iter(vec![Ok(vec![Row::new(res)])]).boxed(), + futures::stream::iter(vec![Ok(vec![Row::new(vec![Some(Bytes::new())])])]).boxed(), vec![ // 1043 is the oid of varchar type. // -1 is the type len of varchar type. @@ -266,25 +284,25 @@ mod tests { )) } - fn user_authenticator(&self) -> &UserAuthenticator { - &UserAuthenticator::None + fn describe_statement( + self: Arc, + _statement: String, + ) -> Result<(Vec, Vec), BoxedError> { + Ok(( + vec![], + vec![PgFieldDescriptor::new("".to_string(), 1043, -1)], + )) } - async fn infer_return_type( + fn describe_portral( self: Arc, - sql: &str, - ) -> Result, super::BoxedError> { - let count = sql.split(&[' ', ',', ';']).skip(1).count(); - Ok(vec![ - // 1043 is the oid of varchar type. - // -1 is the type len of varchar type. - PgFieldDescriptor::new( - "".to_string(), - 1043, - -1 - ); - count - ]) + _portal: String, + ) -> Result, BoxedError> { + Ok(vec![PgFieldDescriptor::new("".to_string(), 1043, -1)]) + } + + fn user_authenticator(&self) -> &UserAuthenticator { + &UserAuthenticator::None } fn id(&self) -> SessionId { @@ -292,21 +310,15 @@ mod tests { } } - // test_psql_extended_mode_explicit_simple - // constrain: - // - Only support simple SELECT statement. - // - Must provide all type description of the generic types. - // - Input description(params description) should include all the generic params description we - // need. #[tokio::test] - async fn test_psql_extended_mode_explicit_simple() { + async fn test_query() { let session_mgr = Arc::new(MockSessionManager {}); tokio::spawn(async move { pg_serve("127.0.0.1:10000", session_mgr, None).await }); // wait for server to start tokio::time::sleep(std::time::Duration::from_millis(10)).await; // Connect to the database. - let (mut client, connection) = tokio_postgres::connect("host=localhost port=10000", NoTls) + let (client, connection) = tokio_postgres::connect("host=localhost port=10000", NoTls) .await .unwrap(); @@ -318,121 +330,17 @@ mod tests { } }); - // explicit parameter (test pre_statement) - { - let statement = client - .prepare_typed("SELECT $1;", &[Type::VARCHAR]) - .await - .unwrap(); - - let rows = client.query(&statement, &[&"AA"]).await.unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "AA"); + let rows = client + .simple_query("SELECT ''") + .await + .expect("Error executing query"); + // Row + CommandComplete + assert_eq!(rows.len(), 2); - let rows = client.query(&statement, &[&"BB"]).await.unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "BB"); - } - // explicit parameter (test portal) - { - let transaction = client.transaction().await.unwrap(); - let statement = transaction - .prepare_typed("SELECT $1;", &[Type::VARCHAR]) - .await - .unwrap(); - let portal1 = transaction.bind(&statement, &[&"AA"]).await.unwrap(); - let portal2 = transaction.bind(&statement, &[&"BB"]).await.unwrap(); - let rows = transaction.query_portal(&portal1, 0).await.unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "AA"); - let rows = transaction.query_portal(&portal2, 0).await.unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "BB"); - transaction.rollback().await.unwrap(); - } - // mix parameter - { - let statement = client - .prepare_typed("SELECT $1,$2;", &[Type::VARCHAR, Type::VARCHAR]) - .await - .unwrap(); - let rows = client.query(&statement, &[&"AA", &"BB"]).await.unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "AA"); - let value: &str = rows[0].get(1); - assert_eq!(value, "BB"); - - let statement = client - .prepare_typed("SELECT $1,$1;", &[Type::VARCHAR]) - .await - .unwrap(); - let rows = client.query(&statement, &[&"AA"]).await.unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "AA"); - let value: &str = rows[0].get(1); - assert_eq!(value, "AA"); - - let statement = client - .prepare_typed( - "SELECT $2,$3,$1,$3,$2;", - &[Type::VARCHAR, Type::VARCHAR, Type::VARCHAR], - ) - .await - .unwrap(); - let rows = client - .query(&statement, &[&"AA", &"BB", &"CC"]) - .await - .unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "BB"); - let value: &str = rows[0].get(1); - assert_eq!(value, "CC"); - let value: &str = rows[0].get(2); - assert_eq!(value, "AA"); - let value: &str = rows[0].get(3); - assert_eq!(value, "CC"); - let value: &str = rows[0].get(4); - assert_eq!(value, "BB"); - - let statement = client - .prepare_typed( - "SELECT $3,$1;", - &[Type::VARCHAR, Type::VARCHAR, Type::VARCHAR], - ) - .await - .unwrap(); - let rows = client - .query(&statement, &[&"AA", &"BB", &"CC"]) - .await - .unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "CC"); - let value: &str = rows[0].get(1); - assert_eq!(value, "AA"); - - let statement = client - .prepare_typed( - "SELECT $2,$1;", - &[Type::VARCHAR, Type::VARCHAR, Type::VARCHAR], - ) - .await - .unwrap(); - let rows = client - .query(&statement, &[&"AA", &"BB", &"CC"]) - .await - .unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "BB"); - let value: &str = rows[0].get(1); - assert_eq!(value, "AA"); - } - // no params - { - let rows = client.query("SELECT 'AA','BB';", &[]).await.unwrap(); - let value: &str = rows[0].get(0); - assert_eq!(value, "AA"); - let value: &str = rows[0].get(1); - assert_eq!(value, "BB"); - } + let rows = client + .query("SELECT ''", &[]) + .await + .expect("Error executing query"); + assert_eq!(rows.len(), 1); } } diff --git a/src/utils/pgwire/src/types.rs b/src/utils/pgwire/src/types.rs index e05405a4cb02..e558dd0f9687 100644 --- a/src/utils/pgwire/src/types.rs +++ b/src/utils/pgwire/src/types.rs @@ -16,7 +16,6 @@ use std::iter::TrustedLen; use std::ops::Index; use std::slice::Iter; -use anyhow::anyhow; use bytes::Bytes; use crate::error::{PsqlError, PsqlResult}; @@ -67,10 +66,9 @@ impl Format { match format_code { 0 => Ok(Format::Text), 1 => Ok(Format::Binary), - _ => Err(PsqlError::Internal(anyhow!( - "Unknown format code: {}", - format_code - ))), + _ => Err(PsqlError::Internal( + format!("Unknown format code: {}", format_code).into(), + )), } } }