Skip to content

Commit

Permalink
refactor(frontend): refine extended handle (risingwavelabs#8992)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZENOTME authored Apr 4, 2023
1 parent 381e3b8 commit 29c4185
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 32 deletions.
43 changes: 33 additions & 10 deletions src/frontend/src/handler/extended_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ use bytes::Bytes;
use pgwire::types::Format;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{Query, Statement};
use risingwave_sqlparser::ast::{CreateSink, Query, Statement};

use super::{handle, query, HandlerArgs, RwPgResponse};
use crate::binder::BoundStatement;
use crate::session::SessionImpl;

/// Except for Query,Insert,Delete,Update statement, we store other statement as `PureStatement`.
/// We separate them because `PureStatement` don't have query and parameters (except
/// create-table-as, create-view-as, create-sink-as), so we don't need to do extra work(infer and
/// bind parameter) for them.
/// For create-table-as, create-view-as, create-sink-as with query parameters, we can't
/// support them. If we find that there are parameter in their query, we return a error otherwise we
/// store them as `PureStatement`.
#[derive(Clone)]
pub enum PrepareStatement {
Prepared(PreparedResult),
Expand Down Expand Up @@ -52,17 +59,17 @@ pub struct PortalResult {

pub fn handle_parse(
session: Arc<SessionImpl>,
stmt: Statement,
statement: Statement,
specific_param_types: Vec<DataType>,
) -> Result<PrepareStatement> {
session.clear_cancel_query_flag();
let str_sql = stmt.to_string();
let handler_args = HandlerArgs::new(session, &stmt, &str_sql)?;
match &stmt {
let str_sql = statement.to_string();
let handler_args = HandlerArgs::new(session, &statement, &str_sql)?;
match &statement {
Statement::Query(_)
| Statement::Insert { .. }
| Statement::Delete { .. }
| Statement::Update { .. } => query::handle_parse(handler_args, stmt, specific_param_types),
| Statement::Update { .. } => query::handle_parse(handler_args, statement, specific_param_types),
Statement::CreateView {
query,
..
Expand All @@ -74,7 +81,7 @@ pub fn handle_parse(
)
.into());
}
Ok(PrepareStatement::PureStatement(stmt))
Ok(PrepareStatement::PureStatement(statement))
}
Statement::CreateTable {
query,
Expand All @@ -86,10 +93,20 @@ pub fn handle_parse(
None.into(),
).into())
} else {
Ok(PrepareStatement::PureStatement(stmt))
Ok(PrepareStatement::PureStatement(statement))
}
}
_ => Ok(PrepareStatement::PureStatement(stmt)),
Statement::CreateSink { stmt } => {
if let CreateSink::AsQuery(query) = &stmt.sink_from && have_parameter_in_query(query) {
Err(ErrorCode::NotImplemented(
"CREATE SINK AS SELECT with parameters".to_string(),
None.into(),
).into())
} else {
Ok(PrepareStatement::PureStatement(statement))
}
}
_ => Ok(PrepareStatement::PureStatement(statement)),
}
}

Expand All @@ -113,7 +130,13 @@ pub fn handle_bind(
result_formats,
}))
}
PrepareStatement::PureStatement(stmt) => Ok(Portal::PureStatement(stmt)),
PrepareStatement::PureStatement(stmt) => {
assert!(
params.is_empty() && param_formats.is_empty(),
"params and param_formats should be empty for pure statement"
);
Ok(Portal::PureStatement(stmt))
}
}
}

Expand Down
97 changes: 75 additions & 22 deletions src/tests/e2e_extended_mode/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use pg_interval::Interval;
use rust_decimal::prelude::FromPrimitive;
use rust_decimal::Decimal;
use tokio_postgres::types::Type;
use tokio_postgres::NoTls;
use tokio_postgres::{Client, NoTls};

use crate::opts::Opts;

Expand All @@ -32,8 +32,10 @@ macro_rules! test_eq {
(left_val, right_val) => {
if !(*left_val == *right_val) {
return Err(anyhow!(
"assertion failed: `(left == right)` \
"{}:{} assertion failed: `(left == right)` \
(left: `{:?}`, right: `{:?}`)",
file!(),
line!(),
left_val,
right_val
));
Expand Down Expand Up @@ -67,11 +69,12 @@ impl TestSuite {
self.binary_param_and_result().await?;
self.dql_dml_with_param().await?;
self.max_row().await?;
self.multiple_on_going_portal().await?;
self.create_with_parameter().await?;
Ok(())
}

pub async fn binary_param_and_result(&self) -> anyhow::Result<()> {
// Connect to the database.
async fn create_client(&self) -> anyhow::Result<Client> {
let (client, connection) = tokio_postgres::connect(&self.config, NoTls).await?;

// The connection object performs the actual communication with the database,
Expand All @@ -82,6 +85,12 @@ impl TestSuite {
}
});

Ok(client)
}

pub async fn binary_param_and_result(&self) -> anyhow::Result<()> {
let client = self.create_client().await?;

for row in client.query("select $1::SMALLINT;", &[&1024_i16]).await? {
let data: i16 = row.try_get(0)?;
test_eq!(data, 1024);
Expand Down Expand Up @@ -192,15 +201,7 @@ impl TestSuite {
}

async fn dql_dml_with_param(&self) -> anyhow::Result<()> {
let (client, connection) = tokio_postgres::connect(&self.config, NoTls).await?;

// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let client = self.create_client().await?;

client.query("create table t(id int)", &[]).await?;

Expand Down Expand Up @@ -265,15 +266,7 @@ impl TestSuite {
}

async fn max_row(&self) -> anyhow::Result<()> {
let (mut client, connection) = tokio_postgres::connect(&self.config, NoTls).await?;

// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
let mut client = self.create_client().await?;

client.query("create table t(id int)", &[]).await?;

Expand Down Expand Up @@ -321,4 +314,64 @@ impl TestSuite {

Ok(())
}

async fn multiple_on_going_portal(&self) -> anyhow::Result<()> {
let mut client = self.create_client().await?;

let transaction = client.transaction().await?;
let statement = transaction
.prepare_typed("SELECT generate_series(1,5,1)", &[])
.await?;
let portal_1 = transaction.bind(&statement, &[]).await?;
let portal_2 = transaction.bind(&statement, &[]).await?;

let rows = transaction.query_portal(&portal_1, 1).await?;
test_eq!(rows.len(), 1);
test_eq!(rows.get(0).unwrap().get::<usize, i32>(0), 1);

let rows = transaction.query_portal(&portal_2, 1).await?;
test_eq!(rows.len(), 1);
test_eq!(rows.get(0).unwrap().get::<usize, i32>(0), 1);

let rows = transaction.query_portal(&portal_2, 3).await?;
test_eq!(rows.len(), 3);
test_eq!(rows.get(0).unwrap().get::<usize, i32>(0), 2);
test_eq!(rows.get(1).unwrap().get::<usize, i32>(0), 3);
test_eq!(rows.get(2).unwrap().get::<usize, i32>(0), 4);

let rows = transaction.query_portal(&portal_1, 1).await?;
test_eq!(rows.len(), 1);
test_eq!(rows.get(0).unwrap().get::<usize, i32>(0), 2);

Ok(())
}

// Can't support these sql
async fn create_with_parameter(&self) -> anyhow::Result<()> {
let client = self.create_client().await?;

test_eq!(
client
.query("create table t as select $1", &[])
.await
.is_err(),
true
);
test_eq!(
client
.query("create view v as select $1", &[])
.await
.is_err(),
true
);
test_eq!(
client
.query("create materialized view v as select $1", &[])
.await
.is_err(),
true
);

Ok(())
}
}

0 comments on commit 29c4185

Please sign in to comment.