Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(frontend): support update column to default value #8987

Merged
merged 3 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions e2e_test/batch/basic/dml.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ select v1, v2 from t order by v2;
45 810
35 1919

statement ok
update t set v1 = DEFAULT where v2 = 810;

query RI
select v1, v2 from t order by v2;
----
114 10
514 20
NULL 810
35 1919

# Delete

statement ok
Expand Down
8 changes: 8 additions & 0 deletions src/frontend/planner_test/tests/testdata/update.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
└─BatchUpdate { table: t, exprs: [$1::Int32, $1, $2] }
└─BatchExchange { order: [], dist: Single }
└─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id], distribution: UpstreamHashShard(t._row_id) }
- sql: |
create table t (v1 int, v2 real);
update t set v1 = DEFAULT;
batch_plan: |
BatchExchange { order: [], dist: Single }
└─BatchUpdate { table: t, exprs: [null:Int32, $1, $2] }
└─BatchExchange { order: [], dist: Single }
└─BatchScan { table: t, columns: [t.v1, t.v2, t._row_id], distribution: UpstreamHashShard(t._row_id) }
- sql: |
create table t (v1 int, v2 int);
update t set v1 = v2 + 1 where v2 > 0;
Expand Down
20 changes: 14 additions & 6 deletions src/frontend/src/binder/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use itertools::Itertools;
use risingwave_common::catalog::{Schema, TableVersionId};
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_sqlparser::ast::{Assignment, Expr, ObjectName, SelectItem};
use risingwave_sqlparser::ast::{Assignment, AssignmentValue, Expr, ObjectName, SelectItem};

use super::statement::RewriteExprsRecursive;
use super::{Binder, Relation};
Expand Down Expand Up @@ -114,17 +114,19 @@ impl Binder {
}

// (col1, col2) = (subquery)
(_ids, Expr::Subquery(_)) => {
(_ids, AssignmentValue::Expr(Expr::Subquery(_))) => {
return Err(ErrorCode::NotImplemented(
"subquery on the right side of multi-assignment".to_owned(),
None.into(),
)
.into())
}
// (col1, col2) = (expr1, expr2)
(ids, Expr::Row(values)) if ids.len() == values.len() => {
id.into_iter().zip_eq_fast(values.into_iter()).collect()
}
// TODO: support `DEFAULT` in multiple assignments
(ids, AssignmentValue::Expr(Expr::Row(values))) if ids.len() == values.len() => id
.into_iter()
.zip_eq_fast(values.into_iter().map(AssignmentValue::Expr))
.collect(),
// (col1, col2) = <other expr>
_ => {
return Err(ErrorCode::BindError(
Expand All @@ -148,7 +150,13 @@ impl Binder {
}
}

let value_expr = self.bind_expr(value)?.cast_assign(id_expr.return_type())?;
let value_expr = match value {
AssignmentValue::Expr(expr) => {
self.bind_expr(expr)?.cast_assign(id_expr.return_type())?
}
// TODO: specify default expression after we support non-`NULL` default values.
AssignmentValue::Default => ExprImpl::literal_null(id_expr.return_type()),
};

match assignment_exprs.entry(id_expr) {
Entry::Occupied(_) => {
Expand Down
22 changes: 20 additions & 2 deletions src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1814,12 +1814,30 @@ impl fmt::Display for GrantObjects {
}
}

/// SQL assignment `foo = expr` as used in SQLUpdate
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum AssignmentValue {
/// An expression, e.g. `foo = 1`
Expr(Expr),
/// The `DEFAULT` keyword, e.g. `foo = DEFAULT`
Default,
}

impl fmt::Display for AssignmentValue {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AssignmentValue::Expr(expr) => write!(f, "{}", expr),
AssignmentValue::Default => f.write_str("DEFAULT"),
}
}
}

/// SQL assignment `foo = { expr | DEFAULT }` as used in SQLUpdate
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Assignment {
pub id: Vec<Ident>,
pub value: Expr,
pub value: AssignmentValue,
}

impl fmt::Display for Assignment {
Expand Down
8 changes: 7 additions & 1 deletion src/sqlparser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4035,7 +4035,13 @@ impl Parser {
pub fn parse_assignment(&mut self) -> Result<Assignment, ParserError> {
let id = self.parse_identifiers_non_keywords()?;
self.expect_token(&Token::Eq)?;
let value = self.parse_expr()?;

let value = if self.parse_keyword(Keyword::DEFAULT) {
AssignmentValue::Default
} else {
AssignmentValue::Expr(self.parse_expr()?)
};

Ok(Assignment { id, value })
}

Expand Down
14 changes: 9 additions & 5 deletions src/sqlparser/tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn parse_insert_values() {

#[test]
fn parse_update() {
let sql = "UPDATE t SET a = 1, b = 2, c = 3 WHERE d";
let sql = "UPDATE t SET a = 1, b = 2, c = 3, d = DEFAULT WHERE e";
match verified_stmt(sql) {
Statement::Update {
table_name,
Expand All @@ -106,19 +106,23 @@ fn parse_update() {
vec![
Assignment {
id: vec!["a".into()],
value: Expr::Value(number("1")),
value: AssignmentValue::Expr(Expr::Value(number("1"))),
},
Assignment {
id: vec!["b".into()],
value: Expr::Value(number("2")),
value: AssignmentValue::Expr(Expr::Value(number("2"))),
},
Assignment {
id: vec!["c".into()],
value: Expr::Value(number("3")),
value: AssignmentValue::Expr(Expr::Value(number("3"))),
},
Assignment {
id: vec!["d".into()],
value: AssignmentValue::Default,
}
]
);
assert_eq!(selection.unwrap(), Expr::Identifier("d".into()));
assert_eq!(selection.unwrap(), Expr::Identifier("e".into()));
}
_ => unreachable!(),
}
Expand Down