Skip to content

Commit

Permalink
update: fix the updatable table name resolution in build update list (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
AilinKid authored Nov 25, 2021
1 parent 8dc59e6 commit bee016b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 48 deletions.
22 changes: 22 additions & 0 deletions planner/core/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,28 @@ func (s *testIntegrationSuite) TestIssue24571(c *C) {
tk.MustExec("update (select 1 as a) as t, test.t set test.t.a=1;")
}

func (s *testIntegrationSuite) TestBuildUpdateListResolver(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

// For issue https://github.com/pingcap/tidb/issues/24567
tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists t1")
tk.MustExec("create table t(a int)")
tk.MustExec("create table t1(b int)")
tk.MustGetErrCode("update (select 1 as a) as t set a=1", mysql.ErrNonUpdatableTable)
tk.MustGetErrCode("update (select 1 as a) as t, t1 set a=1", mysql.ErrNonUpdatableTable)
tk.MustExec("drop table if exists t")
tk.MustExec("drop table if exists t1")

// For issue https://github.com/pingcap/tidb/issues/30031
tk.MustExec("create table t(a int default -1, c int as (a+10) stored)")
tk.MustExec("insert into t(a) values(1)")
tk.MustExec("update test.t, (select 1 as b) as t set test.t.a=default")
tk.MustQuery("select * from t").Check(testkit.Rows("-1 9"))
tk.MustExec("drop table if exists t")
}

func (s *testIntegrationSuite) TestIssue22828(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")
Expand Down
91 changes: 43 additions & 48 deletions planner/core/logical_plan_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -4731,13 +4731,9 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) (
proj.SetChildren(p)
p = proj

// update subquery table should be forbidden
var notUpdatableTbl []string
notUpdatableTbl = extractTableSourceAsNames(update.TableRefs.TableRefs, notUpdatableTbl, true)

var updateTableList []*ast.TableName
updateTableList = extractTableList(update.TableRefs.TableRefs, updateTableList, true)
orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, updateTableList, update.List, p, notUpdatableTbl)
utlr := &updatableTableListResolver{}
update.Accept(utlr)
orderedList, np, allAssignmentsAreConstant, err := b.buildUpdateLists(ctx, utlr.updatableTableList, update.List, p)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -4820,8 +4816,7 @@ func isCTE(tl *ast.TableName) bool {
return tl.TableInfo == nil
}

func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan,
notUpdatableTbl []string) (newList []*expression.Assignment, po LogicalPlan, allAssignmentsAreConstant bool, e error) {
func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.TableName, list []*ast.Assignment, p LogicalPlan) (newList []*expression.Assignment, po LogicalPlan, allAssignmentsAreConstant bool, e error) {
b.curClause = fieldList
// modifyColumns indicates which columns are in set list,
// and if it is set to `DEFAULT`
Expand All @@ -4844,21 +4839,22 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab
columnsIdx[assign.Column] = idx
}
name := p.OutputNames()[idx]
foundListItem := false
for _, tl := range tableList {
if (tl.Schema.L == "" || tl.Schema.L == name.DBName.L) && (tl.Name.L == name.TblName.L) {
if isCTE(tl) || tl.TableInfo.IsView() || tl.TableInfo.IsSequence() {
return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE")
}
// may be a subquery
if tl.Schema.L == "" {
for _, nTbl := range notUpdatableTbl {
if nTbl == name.TblName.L {
return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE")
}
}
}
foundListItem = true
}
}
if !foundListItem {
// For case like:
// 1: update (select * from t1) t1 set b = 1111111 ----- (no updatable table here)
// 2: update (select 1 as a) as t, t1 set a=1 ----- (updatable t1 don't have column a)
// --- subQuery is not counted as updatable table.
return nil, nil, false, ErrNonUpdatableTable.GenWithStackByArgs(name.TblName.O, "UPDATE")
}
columnFullName := fmt.Sprintf("%s.%s.%s", name.DBName.L, name.TblName.L, name.ColName.L)
// We save a flag for the column in map `modifyColumns`
// This flag indicated if assign keyword `DEFAULT` to the column
Expand All @@ -4873,15 +4869,7 @@ func (b *PlanBuilder) buildUpdateLists(ctx context.Context, tableList []*ast.Tab
// And, fill virtualAssignments here; that's for generated columns.
virtualAssignments := make([]*ast.Assignment, 0)
for _, tn := range tableList {
// Only generate virtual to updatable table, skip not updatable table(i.e. table in update's subQuery)
updatable := true
for _, nTbl := range notUpdatableTbl {
if tn.Name.L == nTbl {
updatable = false
break
}
}
if !updatable || isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() {
if isCTE(tn) || tn.TableInfo.IsView() || tn.TableInfo.IsSequence() {
continue
}

Expand Down Expand Up @@ -5984,6 +5972,35 @@ func unfoldSelectList(list *ast.SetOprSelectList, unfoldList *ast.SetOprSelectLi
}
}

type updatableTableListResolver struct {
updatableTableList []*ast.TableName
}

func (u *updatableTableListResolver) Enter(inNode ast.Node) (ast.Node, bool) {
switch v := inNode.(type) {
case *ast.UpdateStmt, *ast.TableRefsClause, *ast.Join, *ast.TableSource, *ast.TableName:
return v, false
}
return inNode, true
}

func (u *updatableTableListResolver) Leave(inNode ast.Node) (ast.Node, bool) {
switch v := inNode.(type) {
case *ast.TableSource:
if s, ok := v.Source.(*ast.TableName); ok {
if v.AsName.L != "" {
newTableName := *s
newTableName.Name = v.AsName
newTableName.Schema = model.NewCIStr("")
u.updatableTableList = append(u.updatableTableList, &newTableName)
} else {
u.updatableTableList = append(u.updatableTableList, s)
}
}
}
return inNode, true
}

// extractTableList extracts all the TableNames from node.
// If asName is true, extract AsName prior to OrigName.
// Privilege check should use OrigName, while expression may use AsName.
Expand Down Expand Up @@ -6114,28 +6131,6 @@ func collectTableName(node ast.ResultSetNode, updatableName *map[string]bool, in
}
}

// extractTableSourceAsNames extracts TableSource.AsNames from node.
// if onlySelectStmt is set to be true, only extracts AsNames when TableSource.Source.(type) == *ast.SelectStmt
func extractTableSourceAsNames(node ast.ResultSetNode, input []string, onlySelectStmt bool) []string {
switch x := node.(type) {
case *ast.Join:
input = extractTableSourceAsNames(x.Left, input, onlySelectStmt)
input = extractTableSourceAsNames(x.Right, input, onlySelectStmt)
case *ast.TableSource:
if _, ok := x.Source.(*ast.SelectStmt); !ok && onlySelectStmt {
break
}
if s, ok := x.Source.(*ast.TableName); ok {
if x.AsName.L == "" {
input = append(input, s.Name.L)
break
}
}
input = append(input, x.AsName.L)
}
return input
}

func appendDynamicVisitInfo(vi []visitInfo, priv string, withGrant bool, err error) []visitInfo {
return append(vi, visitInfo{
privilege: mysql.ExtendedPriv,
Expand Down

0 comments on commit bee016b

Please sign in to comment.