Skip to content

Commit

Permalink
feat(frontend): support extended query protocol handle (risingwavelab…
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME authored Mar 16, 2023
1 parent 018fe9e commit 25a9127
Show file tree
Hide file tree
Showing 3 changed files with 337 additions and 9 deletions.
94 changes: 94 additions & 0 deletions src/frontend/src/handler/extended_handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::Statement;

use super::{query, HandlerArgs, RwPgResponse};
use crate::binder::BoundStatement;
use crate::session::SessionImpl;

pub struct PrepareStatement {
pub statement: Statement,
pub bound_statement: BoundStatement,
pub param_types: Vec<DataType>,
}

pub struct Portal {
pub statement: Statement,
pub bound_statement: BoundStatement,
pub result_formats: Vec<Format>,
}

pub fn handle_parse(
session: Arc<SessionImpl>,
stmt: Statement,
specific_param_types: Vec<DataType>,
) -> Result<PrepareStatement> {
session.clear_cancel_query_flag();
let str_sql = stmt.to_string();
let handler_args = HandlerArgs::new(session, &stmt, &str_sql)?;
match stmt {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_parse(handler_args, stmt, specific_param_types),
_ => Err(ErrorCode::NotSupported(
format!("Can't support {} in extended query mode now", str_sql,),
"".to_string(),
)
.into()),
}
}

pub fn handle_bind(
prepare_statement: PrepareStatement,
params: Vec<Bytes>,
param_formats: Vec<Format>,
result_formats: Vec<Format>,
) -> Result<Portal> {
let PrepareStatement {
statement,
bound_statement,
..
} = prepare_statement;
let bound_statement = bound_statement.bind_parameter(params, param_formats)?;
Ok(Portal {
statement,
bound_statement,
result_formats,
})
}

pub async fn handle_execute(session: Arc<SessionImpl>, portal: Portal) -> Result<RwPgResponse> {
session.clear_cancel_query_flag();
let str_sql = portal.statement.to_string();
let handler_args = HandlerArgs::new(session, &portal.statement, &str_sql)?;
match &portal.statement {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_execute(handler_args, portal).await,
_ => Err(ErrorCode::NotSupported(
format!("Can't support {} in extended query mode now", str_sql,),
"".to_string(),
)
.into()),
}
}
1 change: 1 addition & 0 deletions src/frontend/src/handler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ pub mod drop_table;
pub mod drop_user;
mod drop_view;
pub mod explain;
pub mod extended_handle;
mod flush;
pub mod handle_privilege;
pub mod privilege;
Expand Down
251 changes: 242 additions & 9 deletions src/frontend/src/handler/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ use postgres_types::FromSql;
use risingwave_common::catalog::Schema;
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::session_config::QueryMode;
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{SetExpr, Statement};

use super::extended_handle::{Portal, PrepareStatement};
use super::{PgResponseStream, RwPgResponse};
use crate::binder::{Binder, BoundSetExpr, BoundStatement};
use crate::handler::flush::do_flush;
Expand Down Expand Up @@ -74,6 +76,20 @@ fn must_run_in_distributed_mode(stmt: &Statement) -> Result<bool> {
) | is_insert_using_select(stmt))
}

fn must_run_in_local_mode(bound: &BoundStatement) -> bool {
let mut must_local = false;

if let BoundStatement::Query(query) = &bound {
if let BoundSetExpr::Select(select) = &query.body
&& let Some(relation) = &select.from
&& relation.contains_sys_table() {
must_local = true;
}
}

must_local
}

pub fn gen_batch_query_plan(
session: &SessionImpl,
context: OptimizerContextRef,
Expand All @@ -89,16 +105,9 @@ pub fn gen_batch_query_plan(
let check_items = resolve_privileges(&bound);
session.check_privileges(&check_items)?;

let mut planner = Planner::new(context);
let must_local = must_run_in_local_mode(&bound);

let mut must_local = false;
if let BoundStatement::Query(query) = &bound {
if let BoundSetExpr::Select(select) = &query.body
&& let Some(relation) = &select.from
&& relation.contains_sys_table() {
must_local = true;
}
}
let mut planner = Planner::new(context);

let mut logical = planner.plan(bound)?;
let schema = logical.schema();
Expand Down Expand Up @@ -339,3 +348,227 @@ pub async fn local_execute(

Ok(execution.stream_rows())
}

pub fn handle_parse(
handler_args: HandlerArgs,
statement: Statement,
specific_param_types: Vec<DataType>,
) -> Result<PrepareStatement> {
let session = handler_args.session;
let mut binder = Binder::new_with_param_types(&session, specific_param_types);
let bound_statement = binder.bind(statement.clone())?;

let check_items = resolve_privileges(&bound_statement);
session.check_privileges(&check_items)?;

let param_types = binder.export_param_types()?;

Ok(PrepareStatement {
statement,
bound_statement,
param_types,
})
}

pub async fn handle_execute(handler_args: HandlerArgs, portal: Portal) -> Result<RwPgResponse> {
let Portal {
statement,
bound_statement,
result_formats,
} = portal;

let stmt_type = StatementType::infer_from_statement(&statement)
.map_err(|err| RwError::from(ErrorCode::InvalidInputSyntax(err)))?;
let session = handler_args.session.clone();
let query_start_time = Instant::now();
let only_checkpoint_visible = handler_args.session.config().only_checkpoint_visible();
let mut notice = String::new();

// Subblock to make sure PlanRef (an Rc) is dropped before `await` below.
let (plan_fragmenter, query_mode, output_schema) = {
let context = OptimizerContext::from_handler_args(handler_args);

let must_dist = must_run_in_distributed_mode(&statement)?;
let must_local = must_run_in_local_mode(&bound_statement);

let mut planner = Planner::new(context.into());

let mut logical = planner.plan(bound_statement)?;
let schema = logical.schema();
let batch_plan = logical.gen_batch_plan()?;

let query_mode = match (must_dist, must_local) {
(true, true) => {
return Err(ErrorCode::InternalError(
"the query is forced to both local and distributed mode by optimizer"
.to_owned(),
)
.into())
}
(true, false) => QueryMode::Distributed,
(false, true) => QueryMode::Local,
(false, false) => match session.config().get_query_mode() {
QueryMode::Auto => determine_query_mode(batch_plan.clone()),
QueryMode::Local => QueryMode::Local,
QueryMode::Distributed => QueryMode::Distributed,
},
};

let physical = match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => logical.gen_batch_local_plan(batch_plan)?,
QueryMode::Distributed => logical.gen_batch_distributed_plan(batch_plan)?,
};

let context = physical.plan_base().ctx.clone();
tracing::trace!(
"Generated query plan: {:?}, query_mode:{:?}",
physical.explain_to_string()?,
query_mode
);
let plan_fragmenter = BatchPlanFragmenter::new(
session.env().worker_node_manager_ref(),
session.env().catalog_reader().clone(),
session.config().get_batch_parallelism(),
physical,
)?;
context.append_notice(&mut notice);
(plan_fragmenter, query_mode, schema)
};
let query = plan_fragmenter.generate_complete_query().await?;
tracing::trace!("Generated query after plan fragmenter: {:?}", &query);

let pg_descs = output_schema
.fields()
.iter()
.map(to_pg_field)
.collect::<Vec<PgFieldDescriptor>>();
let column_types = output_schema
.fields()
.iter()
.map(|f| f.data_type())
.collect_vec();

// Used in counting row count.
let first_field_format = result_formats.first().copied().unwrap_or(Format::Text);

let mut row_stream = {
let query_epoch = session.config().get_query_epoch();
let query_snapshot = if let Some(query_epoch) = query_epoch {
PinnedHummockSnapshot::Other(query_epoch)
} else {
// Acquire hummock snapshot for execution.
// TODO: if there's no table scan, we don't need to acquire snapshot.
let hummock_snapshot_manager = session.env().hummock_snapshot_manager();
let query_id = query.query_id().clone();
let pinned_snapshot = hummock_snapshot_manager.acquire(&query_id).await?;
PinnedHummockSnapshot::FrontendPinned(pinned_snapshot, only_checkpoint_visible)
};
match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new(
local_execute(session.clone(), query, query_snapshot).await?,
column_types,
result_formats,
session.clone(),
)),
// Local mode do not support cancel tasks.
QueryMode::Distributed => {
PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new(
distribute_execute(session.clone(), query, query_snapshot).await?,
column_types,
result_formats,
session.clone(),
))
}
}
};

let rows_count: Option<i32> = match stmt_type {
StatementType::SELECT
| StatementType::INSERT_RETURNING
| StatementType::DELETE_RETURNING
| StatementType::UPDATE_RETURNING => None,

StatementType::INSERT | StatementType::DELETE | StatementType::UPDATE => {
let first_row_set = row_stream.next().await;
let first_row_set = match first_row_set {
None => {
return Err(RwError::from(ErrorCode::InternalError(
"no affected rows in output".to_string(),
)))
}
Some(row) => {
row.map_err(|err| RwError::from(ErrorCode::InternalError(format!("{}", err))))?
}
};
let affected_rows_str = first_row_set[0].values()[0]
.as_ref()
.expect("compute node should return affected rows in output");
if let Format::Binary = first_field_format {
Some(
i64::from_sql(&postgres_types::Type::INT8, affected_rows_str)
.unwrap()
.try_into()
.expect("affected rows count large than i32"),
)
} else {
Some(
String::from_utf8(affected_rows_str.to_vec())
.unwrap()
.parse()
.unwrap_or_default(),
)
}
}
_ => unreachable!(),
};

// We need to do some post work after the query is finished and before the `Complete` response
// it sent. This is achieved by the `callback` in `PgResponse`.
let callback = async move {
// Implicitly flush the writes.
if session.config().get_implicit_flush() && stmt_type.is_dml() {
do_flush(&session).await?;
}

// update some metrics
match query_mode {
QueryMode::Auto => unreachable!(),
QueryMode::Local => {
session
.env()
.frontend_metrics
.latency_local_execution
.observe(query_start_time.elapsed().as_secs_f64());

session
.env()
.frontend_metrics
.query_counter_local_execution
.inc();
}
QueryMode::Distributed => {
session
.env()
.query_manager()
.query_metrics
.query_latency
.observe(query_start_time.elapsed().as_secs_f64());

session
.env()
.query_manager()
.query_metrics
.completed_query_counter
.inc();
}
}

Ok(())
};

Ok(PgResponse::new_for_stream_extra(
stmt_type, rows_count, row_stream, pg_descs, notice, callback,
))
}

0 comments on commit 25a9127

Please sign in to comment.