Skip to content

Commit

Permalink
fix(binder): Incorrect cast when specifying columns (risingwavelabs#8770
Browse files Browse the repository at this point in the history
)

Co-authored-by: xxchan <xxchan22f@gmail.com>
  • Loading branch information
broccoliSpicy and xxchan authored Apr 5, 2023
1 parent ce25cf5 commit b67e00f
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 138 deletions.
24 changes: 24 additions & 0 deletions e2e_test/batch/basic/dml.slt.part
Original file line number Diff line number Diff line change
@@ -1,6 +1,30 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t1 (v1 real, v2 int, v3 varchar);


# Insert

statement ok
insert into t1 (v2, v1, v3) values (1, 2, 'a'), (3, 4, 'b');

query RI rowsort
select * from t1;
----
2 1 a
4 3 b

statement ok
insert into t1 (v2, v1) values (1, 2), (3, 4);

statement ok
insert into t1 values (1, 2), (3, 4);

statement ok
drop table t1;

statement ok
create table t (v1 real, v2 int);

Expand Down
6 changes: 3 additions & 3 deletions src/frontend/planner_test/tests/testdata/insert.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@
- name: Too many target columns
sql: |
create table t (v1 int, v2 int);
insert into t (v1, v2, v2) values (5, 6);
binder_error: 'Bind error: INSERT has more target columns than values'
insert into t (v1, v2) values (5);
binder_error: 'Bind error: INSERT has more target columns than expressions'
- name: Not enough target columns
sql: |
create table t (v1 int, v2 int);
insert into t (v1) values (5, 6);
binder_error: 'Bind error: INSERT has less target columns than values'
binder_error: 'Bind error: INSERT has more expressions than target columns'
- name: insert literal null
sql: |
create table t(v1 int);
Expand Down
239 changes: 132 additions & 107 deletions src/frontend/src/binder/insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

use itertools::Itertools;
use risingwave_common::catalog::{Schema, TableVersionId};
use risingwave_common::catalog::{ColumnCatalog, Schema, TableVersionId};
use risingwave_common::error::{ErrorCode, Result, RwError};
use risingwave_common::types::DataType;
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem, SetExpr};
use risingwave_sqlparser::ast::{Ident, ObjectName, Query, SelectItem};

use super::statement::RewriteExprsRecursive;
use super::{BoundQuery, BoundSetExpr};
use super::BoundQuery;
use crate::binder::Binder;
use crate::catalog::TableId;
use crate::expr::{ExprImpl, InputRef};
Expand Down Expand Up @@ -89,7 +89,7 @@ impl Binder {
pub(super) fn bind_insert(
&mut self,
name: ObjectName,
columns: Vec<Ident>,
cols_to_insert_by_user: Vec<Ident>,
source: Query,
returning_items: Vec<SelectItem>,
) -> Result<BoundInsert> {
Expand All @@ -100,16 +100,11 @@ impl Binder {
let table_id = table_catalog.id;
let owner = table_catalog.owner;
let table_version_id = table_catalog.version_id().expect("table must be versioned");
let columns_to_insert = table_catalog.columns_to_insert().cloned().collect_vec();

let expected_types: Vec<DataType> = columns_to_insert
.iter()
.map(|c| c.data_type().clone())
.collect();
let cols_to_insert_in_table = table_catalog.columns_to_insert().cloned().collect_vec();

let generated_column_names: HashSet<_> = table_catalog.generated_column_names().collect();
for query_col in &columns {
let query_col_name = query_col.real_value();
for col in &cols_to_insert_by_user {
let query_col_name = col.real_value();
if generated_column_names.contains(query_col_name.as_str()) {
return Err(RwError::from(ErrorCode::BindError(format!(
"cannot insert a non-DEFAULT value into column \"{0}\". Column \"{0}\" is a generated column.",
Expand All @@ -135,57 +130,53 @@ impl Binder {
}
};

// When the column types of `source` query do not match `expected_types`, casting is
// needed.
let (returning_list, fields) = self.bind_returning_list(returning_items)?;
let is_returning = !returning_list.is_empty();

let col_indices_to_insert = get_col_indices_to_insert(
&cols_to_insert_in_table,
&cols_to_insert_by_user,
&table_name,
)?;
let expected_types: Vec<DataType> = col_indices_to_insert
.iter()
.map(|idx| cols_to_insert_in_table[*idx].data_type().clone())
.collect();

// When the column types of `source` query do not match `expected_types`,
// casting is needed.
//
// In PG, when the `source` is a `VALUES` without order / limit / offset, special treatment
// is given and it is NOT equivalent to assignment cast over potential implicit cast inside.
// For example, the following is valid:
//
// ```
// create table t (v1 time);
// insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
// ```
//
// But the followings are not:
//
// ```
// values (timestamp '2020-01-01 01:02:03'), (time '03:04:05');
// insert into t values (timestamp '2020-01-01 01:02:03'), (time '03:04:05') limit 1;
// ```
//
// Because `timestamp` can cast to `time` in assignment context, but no casting between them
// is allowed implicitly.
//
// In this case, assignment cast should be used directly in `VALUES`, suppressing its
// internal implicit cast.
// In other cases, the `source` query is handled on its own and assignment cast is done
// afterwards.
let (source, cast_exprs, nulls_inserted) = match source {
Query {
with: None,
body: SetExpr::Values(values),
order_by: order,
limit: None,
offset: None,
fetch: None,
} if order.is_empty() => {
let (values, nulls_inserted) =
self.bind_values(values, Some(expected_types.clone()))?;
let body = BoundSetExpr::Values(values.into());
(
BoundQuery {
body,
order: vec![],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: vec![],
},
vec![],
nulls_inserted,
)
}
query => {
let bound = self.bind_query(query)?;
let actual_types = bound.data_types();
let cast_exprs = match expected_types == actual_types {
let bound_query;
let cast_exprs;

match source.as_simple_values() {
None => {
bound_query = self.bind_query(source)?;
let actual_types = bound_query.data_types();
cast_exprs = match expected_types == actual_types {
true => vec![],
false => Self::cast_on_insert(
&expected_types,
Expand All @@ -196,71 +187,45 @@ impl Binder {
.collect(),
)?,
};
(bound, cast_exprs, false)
}
};

let mut target_table_col_indices: Vec<usize> = vec![];
'outer: for query_column in &columns {
let column_name = query_column.real_value();
for (col_idx, table_column) in columns_to_insert.iter().enumerate() {
if column_name == table_column.name() {
target_table_col_indices.push(col_idx);
continue 'outer;
Some(values) => {
assert!(!values.0.is_empty());
let num_value_cols = values.0[0].len();
let has_user_specified_columns = !cols_to_insert_by_user.is_empty();
let num_target_cols = if has_user_specified_columns {
cols_to_insert_by_user.len()
} else {
cols_to_insert_in_table.len()
};
let err_msg = match num_target_cols.cmp(&num_value_cols) {
std::cmp::Ordering::Equal => None,
std::cmp::Ordering::Greater => {
if has_user_specified_columns {
// e.g. insert into t (v1, v2) values (7)
Some("INSERT has more target columns than expressions")
} else {
// e.g. create table t (a int, b real)
// insert into t values (7)
// this kind of usage is fine, null values will be provided
// implicitly.
None
}
}
std::cmp::Ordering::Less => {
// e.g. create table t (a int, b real)
// insert into t (a) values (7, 13)
// or insert into t values (7, 13, 17)
Some("INSERT has more expressions than target columns")
}
};
if let Some(msg) = err_msg {
return Err(RwError::from(ErrorCode::BindError(msg.to_string())));
}
}
// Invalid column name found
return Err(RwError::from(ErrorCode::BindError(format!(
"Column {} not found in table {}",
column_name, table_name
))));
}

// create table t1 (v1 int, v2 int); insert into t1 (v2) values (5);
// We added the null values above. Above is equivalent to
// insert into t1 values (NULL, 5);
let target_table_col_indices = if !target_table_col_indices.is_empty() && nulls_inserted {
let provided_insert_cols: HashSet<usize> =
target_table_col_indices.iter().cloned().collect();

let mut result: Vec<usize> = target_table_col_indices.clone();
for i in 0..columns_to_insert.len() {
if !provided_insert_cols.contains(&i) {
result.push(i);
}
let values = self.bind_values(values.clone(), Some(expected_types))?;
bound_query = BoundQuery::with_values(values);
cast_exprs = vec![];
}
result
} else {
target_table_col_indices
};

let (returning_list, fields) = self.bind_returning_list(returning_items)?;
let is_returning = !returning_list.is_empty();
// validate that query has a value for each target column, if target columns are used
// create table t1 (v1 int, v2 int);
// insert into t1 (v1, v2, v2) values (5, 6); // ...more target columns than values
// insert into t1 (v1) values (5, 6); // ...less target columns than values
let err_msg = match target_table_col_indices.len().cmp(&expected_types.len()) {
std::cmp::Ordering::Equal => None,
std::cmp::Ordering::Greater => Some("INSERT has more target columns than values"),
std::cmp::Ordering::Less => Some("INSERT has less target columns than values"),
};

if let Some(msg) = err_msg && !target_table_col_indices.is_empty() {
return Err(RwError::from(ErrorCode::BindError(
msg.to_string(),
)));
}

// Check if column was used multiple times in query e.g.
// insert into t1 (v1, v1) values (1, 5);
let mut uniq_cols = target_table_col_indices.clone();
uniq_cols.sort_unstable();
uniq_cols.dedup();
if target_table_col_indices.len() != uniq_cols.len() {
return Err(RwError::from(ErrorCode::BindError(
"Column specified more than once".to_string(),
)));
}

let insert = BoundInsert {
Expand All @@ -269,8 +234,8 @@ impl Binder {
table_name,
owner,
row_id_index,
column_indices: target_table_col_indices,
source,
column_indices: col_indices_to_insert,
source: bound_query,
cast_exprs,
returning_list,
returning_schema: if is_returning {
Expand Down Expand Up @@ -302,3 +267,63 @@ impl Binder {
Err(ErrorCode::BindError(msg.into()).into())
}
}

/// Maps the target columns named in an `INSERT` statement to indices into
/// `cols_to_insert_in_table`.
///
/// The returned vector always has the same length as `cols_to_insert_in_table`:
/// - the leading entries are the indices of the columns listed in
///   `cols_to_insert_by_user`, in the order the user wrote them;
/// - the remaining entries are the indices of every column the user did not
///   mention, in table order.
///
/// When `cols_to_insert_by_user` is empty, the result is simply
/// `0..cols_to_insert_in_table.len()`.
///
/// # Errors
///
/// Returns a `BindError` when a user-specified column does not exist in
/// `cols_to_insert_in_table` (the message mentions `table_name`), or when the
/// same column is specified more than once.
fn get_col_indices_to_insert(
    cols_to_insert_in_table: &[ColumnCatalog],
    cols_to_insert_by_user: &[Ident],
    table_name: &str,
) -> Result<Vec<usize>> {
    let num_table_cols = cols_to_insert_in_table.len();
    if cols_to_insert_by_user.is_empty() {
        return Ok((0..num_table_cols).collect());
    }

    // Borrow the column names as keys instead of allocating a `String` per column.
    let col_name_to_idx: HashMap<&str, usize> = cols_to_insert_in_table
        .iter()
        .enumerate()
        .map(|(col_idx, col)| (col.name(), col_idx))
        .collect();

    // `specified_by_user[i]` is set once table column `i` has been named by the user.
    // Tracking this separately keeps the name -> index map free of sentinel values.
    let mut specified_by_user = vec![false; num_table_cols];
    let mut col_indices_to_insert = Vec::with_capacity(num_table_cols);

    for ident in cols_to_insert_by_user {
        let col_name = ident.real_value();
        let col_idx = col_name_to_idx
            .get(col_name.as_str())
            .copied()
            .ok_or_else(|| {
                // Invalid column name found
                RwError::from(ErrorCode::BindError(format!(
                    "Column {} not found in table {}",
                    col_name, table_name
                )))
            })?;
        if specified_by_user[col_idx] {
            return Err(RwError::from(ErrorCode::BindError(
                "Column specified more than once".to_string(),
            )));
        }
        specified_by_user[col_idx] = true;
        col_indices_to_insert.push(col_idx);
    }

    // Columns that are in the target table but not in the provided target columns,
    // appended in table order.
    col_indices_to_insert
        .extend((0..num_table_cols).filter(|&col_idx| !specified_by_user[col_idx]));

    Ok(col_indices_to_insert)
}
13 changes: 13 additions & 0 deletions src/frontend/src/binder/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_sqlparser::ast::{Cte, Expr, Fetch, OrderByExpr, Query, Value, With};

use super::statement::RewriteExprsRecursive;
use super::BoundValues;
use crate::binder::{Binder, BoundSetExpr};
use crate::expr::{CorrelatedId, Depth, ExprImpl, ExprRewriter};

Expand Down Expand Up @@ -95,6 +96,18 @@ impl BoundQuery {
self.body
.collect_correlated_indices_by_depth_and_assign_id(depth, correlated_id)
}

/// Simple `VALUES` without other clauses.
pub fn with_values(values: BoundValues) -> Self {
BoundQuery {
body: BoundSetExpr::Values(values.into()),
order: vec![],
limit: None,
offset: None,
with_ties: false,
extra_order_exprs: vec![],
}
}
}

impl RewriteExprsRecursive for BoundQuery {
Expand Down
Loading

0 comments on commit b67e00f

Please sign in to comment.