diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 6b78d16a8dcaf..68f6a368b3b83 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2263,14 +2263,15 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "") } + oldSchemaLen := p.Schema().Len() if sel.Where != nil { p, err = b.buildSelection(p, sel.Where, nil) if err != nil { return nil, errors.Trace(err) } } - if sel.OrderBy != nil { - p, err = b.buildSort(p, sel.OrderBy.Items, nil) + if update.Order != nil { + p, err = b.buildSort(p, update.Order.Items, nil) if err != nil { return nil, errors.Trace(err) } @@ -2281,6 +2282,14 @@ func (b *planBuilder) buildUpdate(update *ast.UpdateStmt) (Plan, error) { return nil, errors.Trace(err) } } + // TODO: expression rewriter should not change the output columns. We should cut the columns here. + if p.Schema().Len() != oldSchemaLen { + proj := LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldSchemaLen])}.init(b.ctx) + proj.SetSchema(expression.NewSchema(make([]*expression.Column, oldSchemaLen)...)) + copy(proj.schema.Columns, p.Schema().Columns[:oldSchemaLen]) + proj.SetChildren(p) + p = proj + } orderedList, np, err := b.buildUpdateLists(tableList, update.List, p) if err != nil { return nil, errors.Trace(err) diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 810a85db37d0c..b39bd330f5d6d 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -955,6 +955,10 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) { // binlog columns, because the schema and data are not consistent. plan: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[666,666]], Table(t))}(test.t.a,test.t.b)->IndexReader(Index(t.c_d_e)[[42,42]])}(test.t.b,test.t.a)->Sel([or(6_aux_0, 10_aux_0)])->Projection->Delete", }, + { + sql: "update t set a = 2 where b in (select c from t)", + plan: "LeftHashJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])->HashAgg}(Column#2,Column#25)->Projection->Update", + }, } for _, ca := range tests { comment := Commentf("for %s", ca.sql)