From 46dcea22a3efdba1d896db3347da418e89377dce Mon Sep 17 00:00:00 2001 From: ailinkid <314806019@qq.com> Date: Tue, 23 Nov 2021 17:32:00 +0800 Subject: [PATCH 1/3] fix the updatable table name resolution in build update list Signed-off-by: ailinkid <314806019@qq.com> --- planner/core/integration_test.go | 22 +++++++++ planner/core/logical_plan_builder.go | 68 +++++++++++++++++++--------- 2 files changed, 68 insertions(+), 22 deletions(-) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 2a4fc6316e6a3..faf0393ae1c7d 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -235,6 +235,28 @@ func (s *testIntegrationSuite) TestIssue24571(c *C) { tk.MustExec("update (select 1 as a) as t, test.t set test.t.a=1;") } +func (s *testIntegrationSuite) TestBuildUpdateListResolver(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + + // For issue https://github.com/pingcap/tidb/issues/24567 + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(a int)") + tk.MustExec("create table t1(b int)") + tk.MustGetErrCode("update (select 1 as a) as t set a=1", mysql.ErrNonUpdatableTable) + tk.MustGetErrCode("update (select 1 as a) as t, t1 set a=1", mysql.ErrNonUpdatableTable) + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + + // For issue https://github.com/pingcap/tidb/issues/30031 + tk.MustExec("create table t(a int default -1, c int as (a+10) stored)") + tk.MustExec("insert into t(a) values(1)") + tk.MustExec("update test.t, (select 1 as b) as t set test.t.a=default") + tk.MustQuery("select * from t").Check(testkit.Rows("-1 9")) + tk.MustExec("drop table if exists t") +} + func (s *testIntegrationSuite) TestIssue22828(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 02451a4c4cb65..5aa508a2ed3b7 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4732,12 +4732,15 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( p = proj // update subquery table should be forbidden - var notUpdatableTbl []string - notUpdatableTbl = extractTableSourceAsNames(update.TableRefs.TableRefs, notUpdatableTbl, true) + // var notUpdatableTbl []string + // notUpdatableTbl = extractTableSourceAsNames(update.TableRefs.TableRefs, notUpdatableTbl, true) var updateTableList []*ast.TableName updateTableList = extractTableList(update.TableRefs.TableRefs, updateTableList, true) - orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, updateTableList, update.List, p, notUpdatableTbl) + + utlr := &updatableTableListResolver{} + update.Accept(utlr) + orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, utlr.updatableTableList, update.List, p) if err != nil { return nil, err } @@ -4820,8 +4823,7 @@ func isCTE(tl *ast.TableName) bool { return tl.TableInfo == nil } -func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan, - notUpdatableTbl []string) (newList []*expression.Assignment, po LogicalPlan, allAssignmentsAreConstant bool, e error) { +func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) (newList []*expression.Assignment, po LogicalPlan, allAssignmentsAreConstant bool, e error) { b.curClause = fieldList // modifyColumns indicates which columns are in set list, // and if it is set to `DEFAULT` @@ -4844,21 +4846,22 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab columnsIdx[assign.Column] = idx } name := p.OutputNames()[idx] + foundListItem := false for _, tl := range tableList { if (tl.Schema.L == "" || tl.Schema.L == name.DBName.L) && (tl.Name.L == name.TblName.L) { if isCTE(tl) || tl.TableInfo.IsView() || tl.TableInfo.IsSequence() { return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE") } - // may be a subquery - if tl.Schema.L == "" { - for _, nTbl := range notUpdatableTbl { - if nTbl == name.TblName.L { - return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE") - } - } - } + foundListItem = true } } + if !foundListItem { + // For case like: + // 1: update (select * from t1) t1 set b = 1111111 ----- (no updatable table here) + // 2: update (select 1 as a) as t, t1 set a=1 ----- (updatable t1 don't have column a) + // --- subQuery is not counted as updatable table. + return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE") + } columnFullName := fmt.Sprintf("%s.%s.%s", name.DBName.L, name.TblName.L, name.ColName.L) // We save a flag for the column in map `modifyColumns` // This flag indicated if assign keyword `DEFAULT` to the column @@ -4873,15 +4876,7 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab // And, fill virtualAssignments here; that's for generated columns. virtualAssignments := make([]*ast.Assignment, 0) for _, tn := range tableList { - // Only generate virtual to updatable table, skip not updatable table(i.e. table in update's subQuery) - updatable := true - for _, nTbl := range notUpdatableTbl { - if tn.Name.L == nTbl { - updatable = false - break - } - } - if !updatable || isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() { + if isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() { continue } @@ -5984,6 +5979,35 @@ func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectLi } } +type updatableTableListResolver struct { + updatableTableList []*ast.TableName +} + +func (u *updatableTableListResolver) Enter(inNode ast.Node) (ast.Node, bool) { + switch v := inNode.(type) { + case *ast.UpdateStmt, *ast.TableRefsClause, *ast.Join, *ast.TableSource, *ast.TableName: + return v, false + } + return inNode, true +} + +func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) { + switch v := inNode.(type) { + case *ast.TableSource: + if s, ok := v.Source.(*ast.TableName); ok { + if v.AsName.L != "" { + newTableName := *s + newTableName.Name = v.AsName + newTableName.Schema = model.NewCIStr("") + u.updatableTableList = append(u.updatableTableList, &newTableName) + } else { + u.updatableTableList = append(u.updatableTableList, s) + } + } + } + return inNode, true +} + // extractTableList extracts all the TableNames from node. // If asName is true, extract AsName prior to OrigName. // Privilege check should use OrigName, while expression may use AsName. From b656fbf738082712a31d30006bc114625814dd9f Mon Sep 17 00:00:00 2001 From: ailinkid <314806019@qq.com> Date: Tue, 23 Nov 2021 17:47:00 +0800 Subject: [PATCH 2/3] remove redundant code Signed-off-by: ailinkid <314806019@qq.com> --- planner/core/logical_plan_builder.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 5aa508a2ed3b7..ce84ccb5fe53c 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4731,13 +4731,6 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( proj.SetChildren(p) p = proj - // update subquery table should be forbidden - // var notUpdatableTbl []string - // notUpdatableTbl = extractTableSourceAsNames(update.TableRefs.TableRefs, notUpdatableTbl, true) - - var updateTableList []*ast.TableName - updateTableList = extractTableList(update.TableRefs.TableRefs, updateTableList, true) - utlr := &updatableTableListResolver{} update.Accept(utlr) orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, utlr.updatableTableList, update.List, p) From 3e21cb59e03df7c53a62f49a7c145799a1a95505 Mon Sep 17 00:00:00 2001 From: ailinkid <314806019@qq.com> Date: Tue, 23 Nov 2021 18:01:15 +0800 Subject: [PATCH 3/3] remove useless func Signed-off-by: ailinkid <314806019@qq.com> --- planner/core/logical_plan_builder.go | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index ce84ccb5fe53c..bbd6da8372101 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -6131,28 +6131,6 @@ func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, in } } -// extractTableSourceAsNames extracts TableSource.AsNames from node. -// if onlySelectStmt is set to be true, only extracts AsNames when TableSource.Source.(type) == *ast.SelectStmt -func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelectStmt bool) []string { - switch x := node.(type) { - case *ast.Join: - input = extractTableSourceAsNames(x.Left, input, onlySelectStmt) - input = extractTableSourceAsNames(x.Right, input, onlySelectStmt) - case *ast.TableSource: - if _, ok := x.Source.(*ast.SelectStmt); !ok && onlySelectStmt { - break - } - if s, ok := x.Source.(*ast.TableName); ok { - if x.AsName.L == "" { - input = append(input, s.Name.L) - break - } - } - input = append(input, x.AsName.L) - } - return input -} - func appendDynamicVisitInfo(vi []visitInfo, priv string, withGrant bool, err error) []visitInfo { return append(vi, visitInfo{ privilege: mysql.ExtendedPriv,