From b67e00fb6d044ff2f9ffa87f07b7708038505e76 Mon Sep 17 00:00:00 2001 From: broccoliSpicy <93440049+broccoliSpicy@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:21:33 +0800 Subject: [PATCH] fix(binder): Incorrect cast when specifying columns (#8770) Co-authored-by: xxchan --- e2e_test/batch/basic/dml.slt.part | 24 ++ .../planner_test/tests/testdata/insert.yaml | 6 +- src/frontend/src/binder/insert.rs | 239 ++++++++++-------- src/frontend/src/binder/query.rs | 13 + src/frontend/src/binder/set_expr.rs | 2 +- src/frontend/src/binder/values.rs | 42 ++- src/sqlparser/src/ast/query.rs | 17 ++ 7 files changed, 205 insertions(+), 138 deletions(-) diff --git a/e2e_test/batch/basic/dml.slt.part b/e2e_test/batch/basic/dml.slt.part index 4e7ec10b36122..ec60471d0990c 100644 --- a/e2e_test/batch/basic/dml.slt.part +++ b/e2e_test/batch/basic/dml.slt.part @@ -1,6 +1,30 @@ statement ok SET RW_IMPLICIT_FLUSH TO true; +statement ok +create table t1 (v1 real, v2 int, v3 varchar); + + +# Insert + +statement ok +insert into t1 (v2, v1, v3) values (1, 2, 'a'), (3, 4, 'b'); + +query RI rowsort +select * from t1; +---- +2 1 a +4 3 b + +statement ok +insert into t1 (v2, v1) values (1, 2), (3, 4); + +statement ok +insert into t1 values (1, 2), (3, 4); + +statement ok +drop table t1; + statement ok create table t (v1 real, v2 int); diff --git a/src/frontend/planner_test/tests/testdata/insert.yaml b/src/frontend/planner_test/tests/testdata/insert.yaml index 423f8e0e97f7e..9cacf3bf43ca2 100644 --- a/src/frontend/planner_test/tests/testdata/insert.yaml +++ b/src/frontend/planner_test/tests/testdata/insert.yaml @@ -107,13 +107,13 @@ - name: To many target columns sql: | create table t (v1 int, v2 int); - insert into t (v1, v2, v2) values (5, 6); - binder_error: 'Bind error: INSERT has more target columns than values' + insert into t (v1, v2) values (5); + binder_error: 'Bind error: INSERT has more target columns than expressions' - name: Not enough target columns sql: | create table t (v1 int, v2 int); insert into t (v1) values (5, 6); - binder_error: 'Bind error: INSERT has less target columns than values' + binder_error: 'Bind error: INSERT has more expressions than target columns' - name: insert literal null sql: | create table t(v1 int); diff --git a/src/frontend/src/binder/insert.rs b/src/frontend/src/binder/insert.rs index d9a7ca39ccc4f..0c38d61e13ff4 100644 --- a/src/frontend/src/binder/insert.rs +++ b/src/frontend/src/binder/insert.rs @@ -12,17 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use itertools::Itertools; -use risingwave_common::catalog::{Schema, TableVersionId}; +use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId}; use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::types::DataType; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem, SetExpr}; +use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem}; use super::statement::RewriteExprsRecursive; -use super::{BoundQuery, BoundSetExpr}; +use super::BoundQuery; use crate::binder::Binder; use crate::catalog::TableId; use crate::expr::{ExprImpl, InputRef}; @@ -89,7 +89,7 @@ impl Binder { pub(super) fn bind_insert( &mut self, name: ObjectName, - columns: Vec, + cols_to_insert_by_user: Vec, source: Query, returning_items: Vec, ) -> Result { @@ -100,16 +100,11 @@ impl Binder { let table_id = table_catalog.id; let owner = table_catalog.owner; let table_version_id = table_catalog.version_id().expect("table must be versioned"); - let columns_to_insert = table_catalog.columns_to_insert().cloned().collect_vec(); - - let expected_types: Vec = columns_to_insert - .iter() - .map(|c| c.data_type().clone()) - .collect(); + let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec(); let generated_column_names: HashSet<_> = table_catalog.generated_column_names().collect(); - for query_col in &columns { - let query_col_name = query_col.real_value(); + for col in &cols_to_insert_by_user { + let query_col_name = col.real_value(); if generated_column_names.contains(query_col_name.as_str()) { return Err(RwError::from(ErrorCode::BindError(format!( "cannot insert a non-DEFAULT value into column \"{0}\". Column \"{0}\" is a generated column.", @@ -135,21 +130,38 @@ impl Binder { } }; - // When the column types of `source` query do not match `expected_types`, casting is - // needed. + let (returning_list, fields) = self.bind_returning_list(returning_items)?; + let is_returning = !returning_list.is_empty(); + + let col_indices_to_insert = get_col_indices_to_insert( + &cols_to_insert_in_table, + &cols_to_insert_by_user, + &table_name, + )?; + let expected_types: Vec = col_indices_to_insert + .iter() + .map(|idx| cols_to_insert_in_table[*idx].data_type().clone()) + .collect(); + + // When the column types of `source` query do not match `expected_types`, + // casting is needed. // // In PG, when the `source` is a `VALUES` without order / limit / offset, special treatment // is given and it is NOT equivalent to assignment cast over potential implicit cast inside. // For example, the following is valid: + // // ``` // create table t (v1 time); // insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05'); // ``` + // // But the followings are not: + // // ``` // values (timestamp '2020-01-01 01:02:03'), (time '03:04:05'); // insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05') limit 1; // ``` + // // Because `timestamp` can cast to `time` in assignment context, but no casting between them // is allowed implicitly. // @@ -157,35 +169,14 @@ impl Binder { // internal implicit cast. // In other cases, the `source` query is handled on its own and assignment cast is done // afterwards. - let (source, cast_exprs, nulls_inserted) = match source { - Query { - with: None, - body: SetExpr::Values(values), - order_by: order, - limit: None, - offset: None, - fetch: None, - } if order.is_empty() => { - let (values, nulls_inserted) = - self.bind_values(values, Some(expected_types.clone()))?; - let body = BoundSetExpr::Values(values.into()); - ( - BoundQuery { - body, - order: vec![], - limit: None, - offset: None, - with_ties: false, - extra_order_exprs: vec![], - }, - vec![], - nulls_inserted, - ) - } - query => { - let bound = self.bind_query(query)?; - let actual_types = bound.data_types(); - let cast_exprs = match expected_types == actual_types { + let bound_query; + let cast_exprs; + + match source.as_simple_values() { + None => { + bound_query = self.bind_query(source)?; + let actual_types = bound_query.data_types(); + cast_exprs = match expected_types == actual_types { true => vec![], false => Self::cast_on_insert( &expected_types, @@ -196,71 +187,45 @@ impl Binder { .collect(), )?, }; - (bound, cast_exprs, false) } - }; - - let mut target_table_col_indices: Vec = vec![]; - 'outer: for query_column in &columns { - let column_name = query_column.real_value(); - for (col_idx, table_column) in columns_to_insert.iter().enumerate() { - if column_name == table_column.name() { - target_table_col_indices.push(col_idx); - continue 'outer; + Some(values) => { + assert!(!values.0.is_empty()); + let num_value_cols = values.0[0].len(); + let has_user_specified_columns = !cols_to_insert_by_user.is_empty(); + let num_target_cols = if has_user_specified_columns { + cols_to_insert_by_user.len() + } else { + cols_to_insert_in_table.len() + }; + let err_msg = match num_target_cols.cmp(&num_value_cols) { + std::cmp::Ordering::Equal => None, + std::cmp::Ordering::Greater => { + if has_user_specified_columns { + // e.g. insert into t (v1, v2) values (7) + Some("INSERT has more target columns than expressions") + } else { + // e.g. create table t (a int, b real) + // insert into t values (7) + // this kind of usage is fine, null values will be provided + // implicitly. + None + } + } + std::cmp::Ordering::Less => { + // e.g. create table t (a int, b real) + // insert into t (v1) values (7, 13) + // or insert into t values (7, 13, 17) + Some("INSERT has more expressions than target columns") + } + }; + if let Some(msg) = err_msg { + return Err(RwError::from(ErrorCode::BindError(msg.to_string()))); } - } - // Invalid column name found - return Err(RwError::from(ErrorCode::BindError(format!( - "Column {} not found in table {}", - column_name, table_name - )))); - } - // create table t1 (v1 int, v2 int); insert into t1 (v2) values (5); - // We added the null values above. Above is equivalent to - // insert into t1 values (NULL, 5); - let target_table_col_indices = if !target_table_col_indices.is_empty() && nulls_inserted { - let provided_insert_cols: HashSet = - target_table_col_indices.iter().cloned().collect(); - - let mut result: Vec = target_table_col_indices.clone(); - for i in 0..columns_to_insert.len() { - if !provided_insert_cols.contains(&i) { - result.push(i); - } + let values = self.bind_values(values.clone(), Some(expected_types))?; + bound_query = BoundQuery::with_values(values); + cast_exprs = vec![]; } - result - } else { - target_table_col_indices - }; - - let (returning_list, fields) = self.bind_returning_list(returning_items)?; - let is_returning = !returning_list.is_empty(); - // validate that query has a value for each target column, if target columns are used - // create table t1 (v1 int, v2 int); - // insert into t1 (v1, v2, v2) values (5, 6); // ...more target columns than values - // insert into t1 (v1) values (5, 6); // ...less target columns than values - let err_msg = match target_table_col_indices.len().cmp(&expected_types.len()) { - std::cmp::Ordering::Equal => None, - std::cmp::Ordering::Greater => Some("INSERT has more target columns than values"), - std::cmp::Ordering::Less => Some("INSERT has less target columns than values"), - }; - - if let Some(msg) = err_msg && !target_table_col_indices.is_empty() { - return Err(RwError::from(ErrorCode::BindError( - msg.to_string(), - ))); - } - - // Check if column was used multiple times in query e.g. - // insert into t1 (v1, v1) values (1, 5); - let mut uniq_cols = target_table_col_indices.clone(); - uniq_cols.sort_unstable(); - uniq_cols.dedup(); - if target_table_col_indices.len() != uniq_cols.len() { - return Err(RwError::from(ErrorCode::BindError( - "Column specified more than once".to_string(), - ))); } let insert = BoundInsert { @@ -269,8 +234,8 @@ impl Binder { table_name, owner, row_id_index, - column_indices: target_table_col_indices, - source, + column_indices: col_indices_to_insert, + source: bound_query, cast_exprs, returning_list, returning_schema: if is_returning { @@ -302,3 +267,63 @@ impl Binder { Err(ErrorCode::BindError(msg.into()).into()) } } + +/// Returned indices have the same length as `cols_to_insert_in_table`. +/// The first elements have the same order as `cols_to_insert_by_user`. +/// The rest are what's not specified by the user. +/// +/// Also checks there are no duplicate nor unknown columns provided by the user. +fn get_col_indices_to_insert( + cols_to_insert_in_table: &[ColumnCatalog], + cols_to_insert_by_user: &[Ident], + table_name: &str, +) -> Result> { + if cols_to_insert_by_user.is_empty() { + return Ok((0..cols_to_insert_in_table.len()).collect()); + } + + let mut col_indices_to_insert: Vec = Vec::new(); + + let mut col_name_to_idx: HashMap = HashMap::new(); + for (col_idx, col) in cols_to_insert_in_table.iter().enumerate() { + col_name_to_idx.insert(col.name().to_string(), col_idx); + } + + for col_name in cols_to_insert_by_user { + let col_name = &col_name.real_value(); + match col_name_to_idx.get_mut(col_name) { + Some(value_ref) => { + if *value_ref == usize::MAX { + return Err(RwError::from(ErrorCode::BindError( + "Column specified more than once".to_string(), + ))); + } + col_indices_to_insert.push(*value_ref); + *value_ref = usize::MAX; // mark this column name, for duplicate + // detection + } + None => { + // Invalid column name found + return Err(RwError::from(ErrorCode::BindError(format!( + "Column {} not found in table {}", + col_name, table_name + )))); + } + } + } + + // columns that are in the target table but not in the provided target columns + if col_indices_to_insert.len() != cols_to_insert_in_table.len() { + for col in cols_to_insert_in_table { + if let Some(col_to_insert_idx) = col_name_to_idx.get(col.name()) { + if *col_to_insert_idx != usize::MAX { + col_indices_to_insert.push(*col_to_insert_idx); + } + } else { + unreachable!(); + } + } + } + + Ok(col_indices_to_insert) +} diff --git a/src/frontend/src/binder/query.rs b/src/frontend/src/binder/query.rs index 3cac14efd1cdf..9381ecdd74229 100644 --- a/src/frontend/src/binder/query.rs +++ b/src/frontend/src/binder/query.rs @@ -22,6 +22,7 @@ use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With}; use super::statement::RewriteExprsRecursive; +use super::BoundValues; use crate::binder::{Binder, BoundSetExpr}; use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter}; @@ -95,6 +96,18 @@ impl BoundQuery { self.body .collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id) } + + /// Simple `VALUES` without other clauses. + pub fn with_values(values: BoundValues) -> Self { + BoundQuery { + body: BoundSetExpr::Values(values.into()), + order: vec![], + limit: None, + offset: None, + with_ties: false, + extra_order_exprs: vec![], + } + } } impl RewriteExprsRecursive for BoundQuery { diff --git a/src/frontend/src/binder/set_expr.rs b/src/frontend/src/binder/set_expr.rs index dd1d646ee3daa..5696e90a0b7dd 100644 --- a/src/frontend/src/binder/set_expr.rs +++ b/src/frontend/src/binder/set_expr.rs @@ -114,7 +114,7 @@ impl Binder { pub(super) fn bind_set_expr(&mut self, set_expr: SetExpr) -> Result { match set_expr { SetExpr::Select(s) => Ok(BoundSetExpr::Select(Box::new(self.bind_select(*s)?))), - SetExpr::Values(v) => Ok(BoundSetExpr::Values(Box::new(self.bind_values(v, None)?.0))), + SetExpr::Values(v) => Ok(BoundSetExpr::Values(Box::new(self.bind_values(v, None)?))), SetExpr::Query(q) => Ok(BoundSetExpr::Query(Box::new(self.bind_query(*q)?))), SetExpr::SetOperation { op, diff --git a/src/frontend/src/binder/values.rs b/src/frontend/src/binder/values.rs index 44ac158270d87..1cd164dae9979 100644 --- a/src/frontend/src/binder/values.rs +++ b/src/frontend/src/binder/values.rs @@ -84,12 +84,12 @@ fn values_column_name(values_id: usize, col_id: usize) -> String { impl Binder { /// Bind [`Values`] with given `expected_types`. If no types are expected, a compatible type for /// all rows will be used. - /// Returns true if null values were inserted + /// If values are shorter than expected, `NULL`s will be filled. pub(super) fn bind_values( &mut self, values: Values, expected_types: Option>, - ) -> Result<(BoundValues, bool)> { + ) -> Result { assert!(!values.0.is_empty()); self.context.clause = Some(Clause::Values); @@ -102,32 +102,21 @@ impl Binder { // Adding Null values in case user did not specify all columns. E.g. // create table t1 (v1 int, v2 int); insert into t1 (v2) values (5); - let vec_len = bound[0].len(); - let nulls_to_insert = if let Some(expected_types) = &expected_types && expected_types.len() > vec_len { - let nulls_to_insert = expected_types.len() - vec_len; + let mut num_columns = bound[0].len(); + if bound.iter().any(|row| row.len() != num_columns) { + return Err( + ErrorCode::BindError("VALUES lists must all be the same length".into()).into(), + ); + } + if let Some(expected_types) = &expected_types && expected_types.len() > num_columns { + let nulls_to_insert = expected_types.len() - num_columns; for row in &mut bound { - if vec_len != row.len() { - return Err(ErrorCode::BindError( - "VALUES lists must all be the same length".into(), - ) - .into()); - } for i in 0..nulls_to_insert { - let t = expected_types[vec_len + i].clone(); + let t = expected_types[num_columns + i].clone(); row.push(ExprImpl::literal_null(t)); } } - nulls_to_insert - } else { - 0 - }; - - // only check for this condition again if we did not insert any nulls - let num_columns = bound[0].len(); - if nulls_to_insert == 0 && bound.iter().any(|row| row.len() != num_columns) { - return Err( - ErrorCode::BindError("VALUES lists must all be the same length".into()).into(), - ); + num_columns = expected_types.len(); } // Calculate column types. @@ -173,13 +162,12 @@ impl Binder { ) .into()); } - Ok((bound_values, nulls_to_insert > 0)) + Ok(bound_values) } } #[cfg(test)] mod tests { - use risingwave_common::util::iter_util::zip_eq_fast; use risingwave_sqlparser::ast::{Expr, Value}; @@ -207,8 +195,8 @@ mod tests { .collect(), ); - assert_eq!(res.0.schema, schema); - for vec in res.0.rows { + assert_eq!(res.schema, schema); + for vec in res.rows { for (expr, ty) in zip_eq_fast(vec, schema.data_types()) { assert_eq!(expr.return_type(), ty); } diff --git a/src/sqlparser/src/ast/query.rs b/src/sqlparser/src/ast/query.rs index 31b7074b6e67d..c6224434cff6e 100644 --- a/src/sqlparser/src/ast/query.rs +++ b/src/sqlparser/src/ast/query.rs @@ -43,6 +43,23 @@ pub struct Query { pub fetch: Option, } +impl Query { + /// Simple `VALUES` without other clauses. + pub fn as_simple_values(&self) -> Option<&Values> { + match &self { + Query { + with: None, + body: SetExpr::Values(values), + order_by, + limit: None, + offset: None, + fetch: None, + } if order_by.is_empty() => Some(values), + _ => None, + } + } +} + impl fmt::Display for Query { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(ref with) = self.with {