Skip to content

Commit

Permalink
feat(frontend): support bind paramater (risingwavelabs#8543)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME authored Mar 15, 2023
1 parent f92d7f6 commit 1a11c3f
Show file tree
Hide file tree
Showing 16 changed files with 432 additions and 1 deletion.
195 changes: 195 additions & 0 deletions src/frontend/src/binder/bind_param.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::error::{Result, RwError};
use risingwave_common::types::ScalarImpl;

use super::statement::RewriteExprsRecursive;
use super::BoundStatement;
use crate::expr::{Expr, ExprImpl, ExprRewriter, Literal};

/// Rewrites parameter expressions to literals.
pub(crate) struct ParamRewriter {
pub(crate) params: Vec<Bytes>,
pub(crate) param_formats: Vec<Format>,
pub(crate) error: Option<RwError>,
}

impl ExprRewriter for ParamRewriter {
fn rewrite_expr(&mut self, expr: ExprImpl) -> ExprImpl {
if self.error.is_some() {
return expr;
}
match expr {
ExprImpl::InputRef(inner) => self.rewrite_input_ref(*inner),
ExprImpl::Literal(inner) => self.rewrite_literal(*inner),
ExprImpl::FunctionCall(inner) => self.rewrite_function_call(*inner),
ExprImpl::AggCall(inner) => self.rewrite_agg_call(*inner),
ExprImpl::Subquery(inner) => self.rewrite_subquery(*inner),
ExprImpl::CorrelatedInputRef(inner) => self.rewrite_correlated_input_ref(*inner),
ExprImpl::TableFunction(inner) => self.rewrite_table_function(*inner),
ExprImpl::WindowFunction(inner) => self.rewrite_window_function(*inner),
ExprImpl::UserDefinedFunction(inner) => self.rewrite_user_defined_function(*inner),
ExprImpl::Parameter(inner) => self.rewrite_parameter(*inner),
}
}

fn rewrite_parameter(&mut self, parameter: crate::expr::Parameter) -> ExprImpl {
let data_type = parameter.return_type();

// original parameter.index is 1-based.
let parameter_index = (parameter.index - 1) as usize;

let format = self.param_formats[parameter_index];
let scalar = {
let res = match format {
Format::Text => {
let value = self.params[parameter_index].clone();
ScalarImpl::from_text(&value, &data_type)
}
Format::Binary => {
let value = self.params[parameter_index].clone();
ScalarImpl::from_binary(&value, &data_type)
}
};

match res {
Ok(datum) => datum,
Err(e) => {
self.error = Some(e);
return parameter.into();
}
}
};
Literal::new(Some(scalar), data_type).into()
}
}

impl BoundStatement {
pub fn bind_parameter(
mut self,
params: Vec<Bytes>,
param_formats: Vec<Format>,
) -> Result<BoundStatement> {
let mut rewriter = ParamRewriter {
params,
param_formats,
error: None,
};

self.rewrite_exprs_recursive(&mut rewriter);

if let Some(err) = rewriter.error {
return Err(err);
}

Ok(self)
}
}

#[cfg(test)]
mod test {
use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::types::DataType;
use risingwave_sqlparser::test_utils::parse_sql_statements;

use crate::binder::test_utils::{mock_binder, mock_binder_with_param_types};
use crate::binder::BoundStatement;

fn create_expect_bound(sql: &str) -> BoundStatement {
let mut binder = mock_binder();
let stmt = parse_sql_statements(sql).unwrap().remove(0);
binder.bind(stmt).unwrap()
}

fn create_actual_bound(
sql: &str,
param_types: Vec<DataType>,
params: Vec<Bytes>,
param_formats: Vec<Format>,
) -> BoundStatement {
let mut binder = mock_binder_with_param_types(param_types);
let stmt = parse_sql_statements(sql).unwrap().remove(0);
let bound = binder.bind(stmt).unwrap();
bound.bind_parameter(params, param_formats).unwrap()
}

fn expect_actual_eq(expect: BoundStatement, actual: BoundStatement) {
// Use debug format to compare. May modify in future.
assert!(format!("{:?}", expect) == format!("{:?}", actual));
}

#[tokio::test]
async fn basic_select() {
expect_actual_eq(
create_expect_bound("select 1::int4"),
create_actual_bound(
"select $1::int4",
vec![],
vec!["1".into()],
vec![Format::Text],
),
);
}

#[tokio::test]
async fn basic_value() {
expect_actual_eq(
create_expect_bound("values(1::int4)"),
create_actual_bound(
"values($1::int4)",
vec![],
vec!["1".into()],
vec![Format::Text],
),
);
}

#[tokio::test]
async fn default_type() {
expect_actual_eq(
create_expect_bound("select '1'"),
create_actual_bound("select $1", vec![], vec!["1".into()], vec![Format::Text]),
);
}

#[tokio::test]
async fn cast_after_specific() {
expect_actual_eq(
create_expect_bound("select 1::varchar"),
create_actual_bound(
"select $1::varchar",
vec![DataType::Int32],
vec!["1".into()],
vec![Format::Text],
),
);
}

#[tokio::test]
async fn infer_case() {
expect_actual_eq(
create_expect_bound("select 1,1::INT4"),
create_actual_bound(
"select $1,$1::INT4",
vec![],
vec!["1".into()],
vec![Format::Text],
),
);
}
}
14 changes: 14 additions & 0 deletions src/frontend/src/binder/delete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use risingwave_common::catalog::{Schema, TableVersionId};
use risingwave_common::error::Result;
use risingwave_sqlparser::ast::{Expr, ObjectName, SelectItem};

use super::statement::RewriteExprsRecursive;
use super::{Binder, BoundBaseTable};
use crate::catalog::TableId;
use crate::expr::ExprImpl;
Expand Down Expand Up @@ -48,6 +49,19 @@ pub struct BoundDelete {
pub returning_schema: Option<Schema>,
}

impl RewriteExprsRecursive for BoundDelete {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
self.selection =
std::mem::take(&mut self.selection).map(|expr| rewriter.rewrite_expr(expr));

let new_returning_list = std::mem::take(&mut self.returning_list)
.into_iter()
.map(|expr| rewriter.rewrite_expr(expr))
.collect::<Vec<_>>();
self.returning_list = new_returning_list;
}
}

impl Binder {
pub(super) fn bind_delete(
&mut self,
Expand Down
19 changes: 19 additions & 0 deletions src/frontend/src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem, SetExpr};

use super::statement::RewriteExprsRecursive;
use super::{BoundQuery, BoundSetExpr};
use crate::binder::Binder;
use crate::catalog::TableId;
Expand Down Expand Up @@ -66,6 +67,24 @@ pub struct BoundInsert {
pub returning_schema: Option<Schema>,
}

impl RewriteExprsRecursive for BoundInsert {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
self.source.rewrite_exprs_recursive(rewriter);

let new_cast_exprs = std::mem::take(&mut self.cast_exprs)
.into_iter()
.map(|expr| rewriter.rewrite_expr(expr))
.collect::<Vec<_>>();
self.cast_exprs = new_cast_exprs;

let new_returning_list = std::mem::take(&mut self.returning_list)
.into_iter()
.map(|expr| rewriter.rewrite_expr(expr))
.collect::<Vec<_>>();
self.returning_list = new_returning_list;
}
}

impl Binder {
pub(super) fn bind_insert(
&mut self,
Expand Down
12 changes: 12 additions & 0 deletions src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use risingwave_common::util::iter_util::ZipEqDebug;
use risingwave_sqlparser::ast::Statement;

mod bind_context;
mod bind_param;
mod create;
mod delete;
mod expr;
Expand Down Expand Up @@ -209,6 +210,10 @@ impl Binder {
Self::new_inner(session, false, vec![])
}

pub fn new_with_param_types(session: &SessionImpl, param_types: Vec<DataType>) -> Binder {
Self::new_inner(session, false, param_types)
}

pub fn new_for_stream(session: &SessionImpl) -> Binder {
Self::new_inner(session, true, vec![])
}
Expand Down Expand Up @@ -295,13 +300,20 @@ impl Binder {

#[cfg(test)]
pub mod test_utils {
use risingwave_common::types::DataType;

use super::Binder;
use crate::session::SessionImpl;

#[cfg(test)]
pub fn mock_binder() -> Binder {
Binder::new(&SessionImpl::mock())
}

#[cfg(test)]
pub fn mock_binder_with_param_types(param_types: Vec<DataType>) -> Binder {
Binder::new_with_param_types(&SessionImpl::mock(), param_types)
}
}

/// The column name stored in [`BindContext`] for a column without an alias.
Expand Down
15 changes: 14 additions & 1 deletion src/frontend/src/binder/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With};

use super::statement::RewriteExprsRecursive;
use crate::binder::{Binder, BoundSetExpr};
use crate::expr::{CorrelatedId, Depth, ExprImpl};
use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};

/// A validated sql query, including order and union.
/// An example of its relationship with `BoundSetExpr` and `BoundSelect` can be found here: <https://bit.ly/3GQwgPz>
Expand Down Expand Up @@ -96,6 +97,18 @@ impl BoundQuery {
}
}

impl RewriteExprsRecursive for BoundQuery {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl ExprRewriter) {
let new_extra_order_exprs = std::mem::take(&mut self.extra_order_exprs)
.into_iter()
.map(|expr| rewriter.rewrite_expr(expr))
.collect::<Vec<_>>();
self.extra_order_exprs = new_extra_order_exprs;

self.body.rewrite_exprs_recursive(rewriter);
}
}

impl Binder {
/// Bind a [`Query`].
///
Expand Down
10 changes: 10 additions & 0 deletions src/frontend/src/binder/relation/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use risingwave_sqlparser::ast::{
};

use crate::binder::bind_context::BindContext;
use crate::binder::statement::RewriteExprsRecursive;
use crate::binder::{Binder, Relation, COLUMN_GROUP_PREFIX};
use crate::expr::ExprImpl;

Expand All @@ -30,6 +31,15 @@ pub struct BoundJoin {
pub cond: ExprImpl,
}

impl RewriteExprsRecursive for BoundJoin {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
self.left.rewrite_exprs_recursive(rewriter);
self.right.rewrite_exprs_recursive(rewriter);
let dummy = ExprImpl::literal_bool(true);
self.cond = rewriter.rewrite_expr(std::mem::replace(&mut self.cond, dummy));
}
}

impl Binder {
pub(crate) fn bind_vec_table_with_joins(
&mut self,
Expand Down
21 changes: 21 additions & 0 deletions src/frontend/src/binder/relation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use risingwave_sqlparser::ast::{

use self::watermark::is_watermark_func;
use super::bind_context::ColumnBinding;
use super::statement::RewriteExprsRecursive;
use crate::binder::{Binder, BoundSetExpr};
use crate::catalog::system_catalog::pg_catalog::{
PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME,
Expand Down Expand Up @@ -64,6 +65,26 @@ pub enum Relation {
Share(Box<BoundShare>),
}

impl RewriteExprsRecursive for Relation {
fn rewrite_exprs_recursive(&mut self, rewriter: &mut impl crate::expr::ExprRewriter) {
match self {
Relation::Subquery(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Join(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::WindowTableFunction(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Watermark(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Share(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::TableFunction(inner) => {
let new_args = std::mem::take(&mut inner.args)
.into_iter()
.map(|expr| rewriter.rewrite_expr(expr))
.collect();
inner.args = new_args;
}
_ => {}
}
}
}

impl Relation {
pub fn contains_sys_table(&self) -> bool {
match self {
Expand Down
Loading

0 comments on commit 1a11c3f

Please sign in to comment.