diff --git a/e2e_test/batch/basic/dml.slt.part b/e2e_test/batch/basic/dml.slt.part index 1b4ad5e459c2f..4e7ec10b36122 100644 --- a/e2e_test/batch/basic/dml.slt.part +++ b/e2e_test/batch/basic/dml.slt.part @@ -39,6 +39,17 @@ select v1, v2 from t order by v2; 45 810 35 1919 +statement ok +update t set v1 = DEFAULT where v2 = 810; + +query RI +select v1, v2 from t order by v2; +---- +114 10 +514 20 +NULL 810 +35 1919 + # Delete statement ok diff --git a/src/frontend/planner_test/tests/testdata/update.yaml b/src/frontend/planner_test/tests/testdata/update.yaml index d8ec77dcac352..6163724b31d90 100644 --- a/src/frontend/planner_test/tests/testdata/update.yaml +++ b/src/frontend/planner_test/tests/testdata/update.yaml @@ -27,6 +27,14 @@ └─BatchUpdate { table: t, exprs: [$1::Int32, $1, $2] } └─BatchExchange { order: [], dist: Single } └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id], distribution: UpstreamHashShard(t._row_id) } +- sql: | + create table t (v1 int, v2 real); + update t set v1 = DEFAULT; + batch_plan: | + BatchExchange { order: [], dist: Single } + └─BatchUpdate { table: t, exprs: [null:Int32, $1, $2] } + └─BatchExchange { order: [], dist: Single } + └─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id], distribution: UpstreamHashShard(t._row_id) } - sql: | create table t (v1 int, v2 int); update t set v1 = v2 + 1 where v2 > 0; diff --git a/src/frontend/src/binder/update.rs b/src/frontend/src/binder/update.rs index 75bec2a8f6e48..f2879aa8922e5 100644 --- a/src/frontend/src/binder/update.rs +++ b/src/frontend/src/binder/update.rs @@ -19,7 +19,7 @@ use itertools::Itertools; use risingwave_common::catalog::{Schema, TableVersionId}; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::util::iter_util::ZipEqFast; -use risingwave_sqlparser::ast::{Assignment, Expr, ObjectName, SelectItem}; +use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem}; use super::statement::RewriteExprsRecursive; use super::{Binder, Relation}; @@ -114,7 +114,7 @@ impl Binder { } // (col1, col2) = (subquery) - (_ids, Expr::Subquery(_)) => { + (_ids, AssignmentValue::Expr(Expr::Subquery(_))) => { return Err(ErrorCode::NotImplemented( "subquery on the right side of multi-assignment".to_owned(), None.into(), @@ -122,9 +122,11 @@ impl Binder { .into()) } // (col1, col2) = (expr1, expr2) - (ids, Expr::Row(values)) if ids.len() == values.len() => { - id.into_iter().zip_eq_fast(values.into_iter()).collect() - } + // TODO: support `DEFAULT` in multiple assignments + (ids, AssignmentValue::Expr(Expr::Row(values))) if ids.len() == values.len() => id + .into_iter() + .zip_eq_fast(values.into_iter().map(AssignmentValue::Expr)) + .collect(), // (col1, col2) = _ => { return Err(ErrorCode::BindError( @@ -148,7 +150,13 @@ impl Binder { } } - let value_expr = self.bind_expr(value)?.cast_assign(id_expr.return_type())?; + let value_expr = match value { + AssignmentValue::Expr(expr) => { + self.bind_expr(expr)?.cast_assign(id_expr.return_type())? + } + // TODO: specify default expression after we support non-`NULL` default values. + AssignmentValue::Default => ExprImpl::literal_null(id_expr.return_type()), + }; match assignment_exprs.entry(id_expr) { Entry::Occupied(_) => { diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 79af8fd138e33..b1f5619e4a874 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -1814,12 +1814,30 @@ impl fmt::Display for GrantObjects { } } -/// SQL assignment `foo = expr` as used in SQLUpdate +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] +pub enum AssignmentValue { + /// An expression, e.g. `foo = 1` + Expr(Expr), + /// The `DEFAULT` keyword, e.g. `foo = DEFAULT` + Default, +} + +impl fmt::Display for AssignmentValue { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AssignmentValue::Expr(expr) => write!(f, "{}", expr), + AssignmentValue::Default => f.write_str("DEFAULT"), + } + } +} + +/// SQL assignment `foo = { expr | DEFAULT }` as used in SQLUpdate #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] pub struct Assignment { pub id: Vec, - pub value: Expr, + pub value: AssignmentValue, } impl fmt::Display for Assignment { diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index ff784093fb0a3..ea0799e463015 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -4035,7 +4035,13 @@ impl Parser { pub fn parse_assignment(&mut self) -> Result { let id = self.parse_identifiers_non_keywords()?; self.expect_token(&Token::Eq)?; - let value = self.parse_expr()?; + + let value = if self.parse_keyword(Keyword::DEFAULT) { + AssignmentValue::Default + } else { + AssignmentValue::Expr(self.parse_expr()?) + }; + Ok(Assignment { id, value }) } diff --git a/src/sqlparser/tests/sqlparser_common.rs b/src/sqlparser/tests/sqlparser_common.rs index 361940cc3d0b0..b6a430b614bfe 100644 --- a/src/sqlparser/tests/sqlparser_common.rs +++ b/src/sqlparser/tests/sqlparser_common.rs @@ -92,7 +92,7 @@ fn parse_insert_values() { #[test] fn parse_update() { - let sql = "UPDATE t SET a = 1, b = 2, c = 3 WHERE d"; + let sql = "UPDATE t SET a = 1, b = 2, c = 3, d = DEFAULT WHERE e"; match verified_stmt(sql) { Statement::Update { table_name, @@ -106,19 +106,23 @@ fn parse_update() { vec![ Assignment { id: vec!["a".into()], - value: Expr::Value(number("1")), + value: AssignmentValue::Expr(Expr::Value(number("1"))), }, Assignment { id: vec!["b".into()], - value: Expr::Value(number("2")), + value: AssignmentValue::Expr(Expr::Value(number("2"))), }, Assignment { id: vec!["c".into()], - value: Expr::Value(number("3")), + value: AssignmentValue::Expr(Expr::Value(number("3"))), }, + Assignment { + id: vec!["d".into()], + value: AssignmentValue::Default, + } ] ); - assert_eq!(selection.unwrap(), Expr::Identifier("d".into())); + assert_eq!(selection.unwrap(), Expr::Identifier("e".into())); } _ => unreachable!(), }