Skip to content

Commit

Permalink
planner: Support assign DEFAULT in ON DUPLICATE KEY UPDATE statem…
Browse files Browse the repository at this point in the history
…ent (pingcap#13168)
  • Loading branch information
Deardrops authored and XiaTianliang committed Dec 21, 2019
1 parent a63ca96 commit 2c333de
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 44 deletions.
6 changes: 1 addition & 5 deletions ddl/db_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1929,11 +1929,7 @@ func (s *testIntegrationSuite3) TestInsertIntoGeneratedColumnWithDefaultExpr(c *
tk.MustExec("create table t5 (a int default 10, b int as (a+1))")
tk.MustGetErrCode("insert into t5 values (20, default(a))", mysql.ErrBadGeneratedColumn)

tk.MustExec("drop table t1")
tk.MustExec("drop table t2")
tk.MustExec("drop table t3")
tk.MustExec("drop table t4")
tk.MustExec("drop table t5")
tk.MustExec("drop table t1, t2, t3, t4, t5")
}

func (s *testIntegrationSuite3) TestSqlFunctionsInGeneratedColumns(c *C) {
Expand Down
35 changes: 35 additions & 0 deletions executor/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,41 @@ func (s *testSuite4) TestInsertSetWithDefault(c *C) {
tk.MustExec("drop table t1, t2")
}

func (s *testSuite4) TestInsertOnDupUpdateDefault(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
// Assign `DEFAULT` in `INSERT ... ON DUPLICATE KEY UPDATE ...` statement
tk.MustExec("drop table if exists t1, t2;")
tk.MustExec("create table t1 (a int unique, b int default 20, c int default 30);")
tk.MustExec("insert into t1 values (1,default,default);")
tk.MustExec("insert into t1 values (1,default,default) on duplicate key update b=default;")
tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30"))
tk.MustExec("insert into t1 values (1,default,default) on duplicate key update c=default, b=default;")
tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30"))
tk.MustExec("insert into t1 values (1,default,default) on duplicate key update c=default, a=2")
tk.MustQuery("select * from t1;").Check(testkit.Rows("2 20 30"))
tk.MustExec("insert into t1 values (2,default,default) on duplicate key update c=default(b)")
tk.MustQuery("select * from t1;").Check(testkit.Rows("2 20 20"))
tk.MustExec("insert into t1 values (2,default,default) on duplicate key update a=default(b)+default(c)")
tk.MustQuery("select * from t1;").Check(testkit.Rows("50 20 20"))
// With generated columns
tk.MustExec("create table t2 (a int unique, b int generated always as (-a) virtual, c int generated always as (-a) stored);")
tk.MustExec("insert into t2 values (1,default,default);")
tk.MustExec("insert into t2 values (1,default,default) on duplicate key update a=2, b=default;")
tk.MustQuery("select * from t2").Check(testkit.Rows("2 -2 -2"))
tk.MustExec("insert into t2 values (2,default,default) on duplicate key update a=3, c=default;")
tk.MustQuery("select * from t2").Check(testkit.Rows("3 -3 -3"))
tk.MustExec("insert into t2 values (3,default,default) on duplicate key update c=default, b=default, a=4;")
tk.MustQuery("select * from t2").Check(testkit.Rows("4 -4 -4"))
tk.MustExec("insert into t2 values (10,default,default) on duplicate key update b=default, a=20, c=default;")
tk.MustQuery("select * from t2").Check(testkit.Rows("4 -4 -4", "10 -10 -10"))
tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update b=default(a);", mysql.ErrBadGeneratedColumn)
tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(b), b=default(b);", mysql.ErrBadGeneratedColumn)
tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(a), c=default(c);", mysql.ErrBadGeneratedColumn)
tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(a), c=default(a);", mysql.ErrBadGeneratedColumn)
tk.MustExec("drop table t1, t2")
}

func (s *testSuite4) TestReplace(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
17 changes: 8 additions & 9 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3020,7 +3020,7 @@ func (b *PlanBuilder) buildUpdateLists(
columnFullName := fmt.Sprintf("%s.%s.%s", name.DBName.L, name.TblName.L, name.ColName.L)
// We save a flag for the column in map `modifyColumns`
// This flag indicated if assign keyword `DEFAULT` to the column
if _, ok := extractDefaultExpr(assign.Expr); ok {
if extractDefaultExpr(assign.Expr) != nil {
modifyColumns[columnFullName] = true
} else {
modifyColumns[columnFullName] = false
Expand Down Expand Up @@ -3073,7 +3073,7 @@ func (b *PlanBuilder) buildUpdateLists(
var np LogicalPlan
if i < len(list) {
// If assign `DEFAULT` to column, fill the `defaultExpr.Name` before rewrite expression
if expr, ok := extractDefaultExpr(assign.Expr); ok {
if expr := extractDefaultExpr(assign.Expr); expr != nil {
expr.Name = assign.Column
}
newExpr, np, err = b.rewrite(ctx, assign.Expr, p, nil, false)
Expand Down Expand Up @@ -3115,15 +3115,14 @@ func (b *PlanBuilder) buildUpdateLists(
return newList, p, allAssignmentsAreConstant, nil
}

// extractDefaultExpr extract a `DefaultExpr` without any parameter from a `ExprNode`,
// return the `DefaultExpr` and whether it's extracted successfully.
// Note: the SQL function `DEFAULT(a)` is not the same with keyword `DEFAULT`,
// SQL function `DEFAULT(a)` will return `false`.
func extractDefaultExpr(node ast.ExprNode) (*ast.DefaultExpr, bool) {
// extractDefaultExpr extract a `DefaultExpr` from `ExprNode`,
// If it is a `DEFAULT` function like `DEFAULT(a)`, return nil.
// Only if it is `DEFAULT` keyword, it will return the `DefaultExpr`.
func extractDefaultExpr(node ast.ExprNode) *ast.DefaultExpr {
if expr, ok := node.(*ast.DefaultExpr); ok && expr.Name == nil {
return expr, true
return expr
}
return nil, false
return nil
}

func (b *PlanBuilder) buildDelete(ctx context.Context, delete *ast.DeleteStmt) (Plan, error) {
Expand Down
67 changes: 37 additions & 30 deletions planner/core/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1829,27 +1829,13 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) (

mockTablePlan.SetSchema(insertPlan.Schema4OnDuplicate)
mockTablePlan.names = insertPlan.names4OnDuplicate
columnByName := make(map[string]*table.Column, len(insertPlan.Table.Cols()))
for _, col := range insertPlan.Table.Cols() {
columnByName[col.Name.L] = col
}
onDupColSet, dupCols, dupColNames, err := insertPlan.validateOnDup(insert.OnDuplicate, columnByName, tableInfo)

onDupColSet, err := insertPlan.resolveOnDuplicate(insert.OnDuplicate, tableInfo, func(node ast.ExprNode) (expression.Expression, error) {
return b.rewriteInsertOnDuplicateUpdate(ctx, node, mockTablePlan, insertPlan)
})
if err != nil {
return nil, err
}
for i, assign := range insert.OnDuplicate {
// Construct the function which calculates the assign value of the column.
expr, err1 := b.rewriteInsertOnDuplicateUpdate(ctx, assign.Expr, mockTablePlan, insertPlan)
if err1 != nil {
return nil, err1
}

insertPlan.OnDuplicate = append(insertPlan.OnDuplicate, &expression.Assignment{
Col: dupCols[i],
ColName: dupColNames[i].ColName,
Expr: expr,
})
}

// Calculate generated columns.
mockTablePlan.schema = insertPlan.tableSchema
Expand All @@ -1863,29 +1849,50 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) (
return insertPlan, err
}

func (p *Insert) validateOnDup(onDup []*ast.Assignment, colMap map[string]*table.Column, tblInfo *model.TableInfo) (map[string]struct{}, []*expression.Column, types.NameSlice, error) {
func (p *Insert) resolveOnDuplicate(onDup []*ast.Assignment, tblInfo *model.TableInfo, yield func(ast.ExprNode) (expression.Expression, error)) (map[string]struct{}, error) {
onDupColSet := make(map[string]struct{}, len(onDup))
dupCols := make([]*expression.Column, 0, len(onDup))
dupColNames := make(types.NameSlice, 0, len(onDup))
colMap := make(map[string]*table.Column, len(p.Table.Cols()))
for _, col := range p.Table.Cols() {
colMap[col.Name.L] = col
}
for _, assign := range onDup {
// Check whether the column to be updated exists in the source table.
idx, err := expression.FindFieldName(p.tableColNames, assign.Column)
if err != nil {
return nil, nil, nil, err
return nil, err
} else if idx < 0 {
return nil, nil, nil, ErrUnknownColumn.GenWithStackByArgs(assign.Column.OrigColName(), "field list")
return nil, ErrUnknownColumn.GenWithStackByArgs(assign.Column.OrigColName(), "field list")
}

// Check whether the column to be updated is the generated column.
column := colMap[assign.Column.Name.L]
defaultExpr := extractDefaultExpr(assign.Expr)
if defaultExpr != nil {
defaultExpr.Name = assign.Column
}
// Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT.
// see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html
if column.IsGenerated() {
return nil, nil, nil, ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tblInfo.Name.O)
if defaultExpr != nil {
continue
}
return nil, ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tblInfo.Name.O)
}

onDupColSet[column.Name.L] = struct{}{}
dupCols = append(dupCols, p.tableSchema.Columns[idx])
dupColNames = append(dupColNames, p.tableColNames[idx])

expr, err := yield(assign.Expr)
if err != nil {
return nil, err
}

p.OnDuplicate = append(p.OnDuplicate, &expression.Assignment{
Col: p.tableSchema.Columns[idx],
ColName: p.tableColNames[idx].ColName,
Expr: expr,
})
}
return onDupColSet, dupCols, dupColNames, nil
return onDupColSet, nil
}

func (b *PlanBuilder) getAffectCols(insertStmt *ast.InsertStmt, insertPlan *Insert) (affectedValuesCols []*table.Column, err error) {
Expand Down Expand Up @@ -1941,14 +1948,14 @@ func (b *PlanBuilder) buildSetValuesOfInsert(ctx context.Context, insert *ast.In

insertPlan.AllAssignmentsAreConstant = true
for i, assign := range insert.Setlist {
defaultExpr, isDefaultExpr := extractDefaultExpr(assign.Expr)
if isDefaultExpr {
defaultExpr := extractDefaultExpr(assign.Expr)
if defaultExpr != nil {
defaultExpr.Name = assign.Column
}
// Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT.
// see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html
if _, ok := generatedColumns[assign.Column.Name.L]; ok {
if isDefaultExpr {
if defaultExpr != nil {
continue
}
return ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tableInfo.Name.O)
Expand Down

0 comments on commit 2c333de

Please sign in to comment.