From 4b1074c09d251c1ca5a2b57882d74ae4c8c841ca Mon Sep 17 00:00:00 2001 From: Yiding Cui Date: Sun, 29 Sep 2019 20:35:32 +0800 Subject: [PATCH] planner: update's select should not change the output columns (#12476) (#12483) --- planner/core/logical_plan_builder.go | 11 ++++++++++- planner/core/logical_plan_test.go | 4 ++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 40d255baeeaf4..4d94bc43227b0 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -2674,12 +2674,21 @@ func (b *PlanBuilder) buildUpdate(ctx context.Context, update *ast.UpdateStmt) ( b.visitInfo = appendVisitInfo(b.visitInfo, mysql.SelectPriv, dbName, t.Name.L, "", nil) } + oldSchemaLen := p.Schema().Len() if sel.Where != nil { - p, err = b.buildSelection(ctx, p, sel.Where, nil) + p, err = b.buildSelection(ctx, p, update.Where, nil) if err != nil { return nil, err } } + // TODO: expression rewriter should not change the output columns. We should cut the columns here. + if p.Schema().Len() != oldSchemaLen { + proj := LogicalProjection{Exprs: expression.Column2Exprs(p.Schema().Columns[:oldSchemaLen])}.Init(b.ctx) + proj.SetSchema(expression.NewSchema(make([]*expression.Column, oldSchemaLen)...)) + copy(proj.schema.Columns, p.Schema().Columns[:oldSchemaLen]) + proj.SetChildren(p) + p = proj + } if sel.OrderBy != nil { p, err = b.buildSort(ctx, p, sel.OrderBy.Items, nil, nil) if err != nil { diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 88446a70c998e..a5f06c68d1204 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -941,6 +941,10 @@ func (s *testPlanSuite) TestPlanBuilder(c *C) { // binlog columns, because the schema and data are not consistent. plan: "LeftHashJoin{LeftHashJoin{TableReader(Table(t))->IndexLookUp(Index(t.c_d_e)[[666,666]], Table(t))}(test.t.a,test.t.b)->IndexReader(Index(t.c_d_e)[[42,42]])}(test.t.b,test.t.a)->Sel([or(6_aux_0, 10_aux_0)])->Projection->Delete", }, + { + sql: "update t set a = 2 where b in (select c from t)", + plan: "LeftHashJoin{TableReader(Table(t))->IndexReader(Index(t.c_d_e)[[NULL,+inf]])->StreamAgg}(test.t.b,test.t.c)->Projection->Update", + }, } ctx := context.Background()