diff --git a/ddl/db_integration_test.go b/ddl/db_integration_test.go index a8edcc14730d8..3cc94050027e8 100644 --- a/ddl/db_integration_test.go +++ b/ddl/db_integration_test.go @@ -1929,11 +1929,7 @@ func (s *testIntegrationSuite3) TestInsertIntoGeneratedColumnWithDefaultExpr(c * tk.MustExec("create table t5 (a int default 10, b int as (a+1))") tk.MustGetErrCode("insert into t5 values (20, default(a))", mysql.ErrBadGeneratedColumn) - tk.MustExec("drop table t1") - tk.MustExec("drop table t2") - tk.MustExec("drop table t3") - tk.MustExec("drop table t4") - tk.MustExec("drop table t5") + tk.MustExec("drop table t1, t2, t3, t4, t5") } func (s *testIntegrationSuite3) TestSqlFunctionsInGeneratedColumns(c *C) { diff --git a/executor/write_test.go b/executor/write_test.go index 838106a34ba26..befb96ae317e3 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -796,6 +796,41 @@ func (s *testSuite4) TestInsertSetWithDefault(c *C) { tk.MustExec("drop table t1, t2") } +func (s *testSuite4) TestInsertOnDupUpdateDefault(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + // Assign `DEFAULT` in `INSERT ... ON DUPLICATE KEY UPDATE ...` statement + tk.MustExec("drop table if exists t1, t2;") + tk.MustExec("create table t1 (a int unique, b int default 20, c int default 30);") + tk.MustExec("insert into t1 values (1,default,default);") + tk.MustExec("insert into t1 values (1,default,default) on duplicate key update b=default;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30")) + tk.MustExec("insert into t1 values (1,default,default) on duplicate key update c=default, b=default;") + tk.MustQuery("select * from t1;").Check(testkit.Rows("1 20 30")) + tk.MustExec("insert into t1 values (1,default,default) on duplicate key update c=default, a=2") + tk.MustQuery("select * from t1;").Check(testkit.Rows("2 20 30")) + tk.MustExec("insert into t1 values (2,default,default) on duplicate key update c=default(b)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("2 20 20")) + tk.MustExec("insert into t1 values (2,default,default) on duplicate key update a=default(b)+default(c)") + tk.MustQuery("select * from t1;").Check(testkit.Rows("50 20 20")) + // With generated columns + tk.MustExec("create table t2 (a int unique, b int generated always as (-a) virtual, c int generated always as (-a) stored);") + tk.MustExec("insert into t2 values (1,default,default);") + tk.MustExec("insert into t2 values (1,default,default) on duplicate key update a=2, b=default;") + tk.MustQuery("select * from t2").Check(testkit.Rows("2 -2 -2")) + tk.MustExec("insert into t2 values (2,default,default) on duplicate key update a=3, c=default;") + tk.MustQuery("select * from t2").Check(testkit.Rows("3 -3 -3")) + tk.MustExec("insert into t2 values (3,default,default) on duplicate key update c=default, b=default, a=4;") + tk.MustQuery("select * from t2").Check(testkit.Rows("4 -4 -4")) + tk.MustExec("insert into t2 values (10,default,default) on duplicate key update b=default, a=20, c=default;") + tk.MustQuery("select * from t2").Check(testkit.Rows("4 -4 -4", "10 -10 -10")) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update b=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(b), b=default(b);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(a), c=default(c);", mysql.ErrBadGeneratedColumn) + tk.MustGetErrCode("insert into t2 values (4,default,default) on duplicate key update a=default(a), c=default(a);", mysql.ErrBadGeneratedColumn) + tk.MustExec("drop table t1, t2") +} + func (s *testSuite4) TestReplace(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 980365d8029ad..924363cf1e2d1 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -3020,7 +3020,7 @@ func (b *PlanBuilder) buildUpdateLists( columnFullName := fmt.Sprintf("%s.%s.%s", name.DBName.L, name.TblName.L, name.ColName.L) // We save a flag for the column in map `modifyColumns` // This flag indicated if assign keyword `DEFAULT` to the column - if _, ok := extractDefaultExpr(assign.Expr); ok { + if extractDefaultExpr(assign.Expr) != nil { modifyColumns[columnFullName] = true } else { modifyColumns[columnFullName] = false @@ -3073,7 +3073,7 @@ func (b *PlanBuilder) buildUpdateLists( var np LogicalPlan if i < len(list) { // If assign `DEFAULT` to column, fill the `defaultExpr.Name` before rewrite expression - if expr, ok := extractDefaultExpr(assign.Expr); ok { + if expr := extractDefaultExpr(assign.Expr); expr != nil { expr.Name = assign.Column } newExpr, np, err = b.rewrite(ctx, assign.Expr, p, nil, false) @@ -3115,15 +3115,14 @@ func (b *PlanBuilder) buildUpdateLists( return newList, p, allAssignmentsAreConstant, nil } -// extractDefaultExpr extract a `DefaultExpr` without any parameter from a `ExprNode`, -// return the `DefaultExpr` and whether it's extracted successfully. -// Note: the SQL function `DEFAULT(a)` is not the same with keyword `DEFAULT`, -// SQL function `DEFAULT(a)` will return `false`. -func extractDefaultExpr(node ast.ExprNode) (*ast.DefaultExpr, bool) { +// extractDefaultExpr extract a `DefaultExpr` from `ExprNode`, +// If it is a `DEFAULT` function like `DEFAULT(a)`, return nil. +// Only if it is `DEFAULT` keyword, it will return the `DefaultExpr`. +func extractDefaultExpr(node ast.ExprNode) *ast.DefaultExpr { if expr, ok := node.(*ast.DefaultExpr); ok && expr.Name == nil { - return expr, true + return expr } - return nil, false + return nil } func (b *PlanBuilder) buildDelete(ctx context.Context, delete *ast.DeleteStmt) (Plan, error) { diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 6a2bd1ddfc8d8..5155cd9f3d380 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -1829,27 +1829,13 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( mockTablePlan.SetSchema(insertPlan.Schema4OnDuplicate) mockTablePlan.names = insertPlan.names4OnDuplicate - columnByName := make(map[string]*table.Column, len(insertPlan.Table.Cols())) - for _, col := range insertPlan.Table.Cols() { - columnByName[col.Name.L] = col - } - onDupColSet, dupCols, dupColNames, err := insertPlan.validateOnDup(insert.OnDuplicate, columnByName, tableInfo) + + onDupColSet, err := insertPlan.resolveOnDuplicate(insert.OnDuplicate, tableInfo, func(node ast.ExprNode) (expression.Expression, error) { + return b.rewriteInsertOnDuplicateUpdate(ctx, node, mockTablePlan, insertPlan) + }) if err != nil { return nil, err } - for i, assign := range insert.OnDuplicate { - // Construct the function which calculates the assign value of the column. - expr, err1 := b.rewriteInsertOnDuplicateUpdate(ctx, assign.Expr, mockTablePlan, insertPlan) - if err1 != nil { - return nil, err1 - } - - insertPlan.OnDuplicate = append(insertPlan.OnDuplicate, &expression.Assignment{ - Col: dupCols[i], - ColName: dupColNames[i].ColName, - Expr: expr, - }) - } // Calculate generated columns. mockTablePlan.schema = insertPlan.tableSchema @@ -1863,29 +1849,50 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( return insertPlan, err } -func (p *Insert) validateOnDup(onDup []*ast.Assignment, colMap map[string]*table.Column, tblInfo *model.TableInfo) (map[string]struct{}, []*expression.Column, types.NameSlice, error) { +func (p *Insert) resolveOnDuplicate(onDup []*ast.Assignment, tblInfo *model.TableInfo, yield func(ast.ExprNode) (expression.Expression, error)) (map[string]struct{}, error) { onDupColSet := make(map[string]struct{}, len(onDup)) - dupCols := make([]*expression.Column, 0, len(onDup)) - dupColNames := make(types.NameSlice, 0, len(onDup)) + colMap := make(map[string]*table.Column, len(p.Table.Cols())) + for _, col := range p.Table.Cols() { + colMap[col.Name.L] = col + } for _, assign := range onDup { // Check whether the column to be updated exists in the source table. idx, err := expression.FindFieldName(p.tableColNames, assign.Column) if err != nil { - return nil, nil, nil, err + return nil, err } else if idx < 0 { - return nil, nil, nil, ErrUnknownColumn.GenWithStackByArgs(assign.Column.OrigColName(), "field list") + return nil, ErrUnknownColumn.GenWithStackByArgs(assign.Column.OrigColName(), "field list") } // Check whether the column to be updated is the generated column. column := colMap[assign.Column.Name.L] + defaultExpr := extractDefaultExpr(assign.Expr) + if defaultExpr != nil { + defaultExpr.Name = assign.Column + } + // Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT. + // see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html if column.IsGenerated() { - return nil, nil, nil, ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tblInfo.Name.O) + if defaultExpr != nil { + continue + } + return nil, ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tblInfo.Name.O) } + onDupColSet[column.Name.L] = struct{}{} - dupCols = append(dupCols, p.tableSchema.Columns[idx]) - dupColNames = append(dupColNames, p.tableColNames[idx]) + + expr, err := yield(assign.Expr) + if err != nil { + return nil, err + } + + p.OnDuplicate = append(p.OnDuplicate, &expression.Assignment{ + Col: p.tableSchema.Columns[idx], + ColName: p.tableColNames[idx].ColName, + Expr: expr, + }) } - return onDupColSet, dupCols, dupColNames, nil + return onDupColSet, nil } func (b *PlanBuilder) getAffectCols(insertStmt *ast.InsertStmt, insertPlan *Insert) (affectedValuesCols []*table.Column, err error) { @@ -1941,14 +1948,14 @@ func (b *PlanBuilder) buildSetValuesOfInsert(ctx context.Context, insert *ast.In insertPlan.AllAssignmentsAreConstant = true for i, assign := range insert.Setlist { - defaultExpr, isDefaultExpr := extractDefaultExpr(assign.Expr) - if isDefaultExpr { + defaultExpr := extractDefaultExpr(assign.Expr) + if defaultExpr != nil { defaultExpr.Name = assign.Column } // Note: For INSERT, REPLACE, and UPDATE, if a generated column is inserted into, replaced, or updated explicitly, the only permitted value is DEFAULT. // see https://dev.mysql.com/doc/refman/8.0/en/create-table-generated-columns.html if _, ok := generatedColumns[assign.Column.Name.L]; ok { - if isDefaultExpr { + if defaultExpr != nil { continue } return ErrBadGeneratedColumn.GenWithStackByArgs(assign.Column.Name.O, tableInfo.Name.O)