From 4a350fee1f7a57826c2227dc4f32a9d05dd1d2b8 Mon Sep 17 00:00:00 2001 From: Feng Liyuan Date: Thu, 14 May 2020 12:40:58 +0800 Subject: [PATCH] executor: support GROUP_CONCAT(ORDER BY) (#16591) (#17184) --- cmd/explaintest/r/explain.result | 4 +- executor/aggfuncs/aggfunc_test.go | 145 +++++++- executor/aggfuncs/builder.go | 24 +- executor/aggfuncs/func_group_concat.go | 330 ++++++++++++++++++- executor/aggfuncs/func_group_concat_test.go | 19 +- executor/aggregate_test.go | 94 ++++++ executor/builder.go | 2 +- executor/executor_required_rows_test.go | 13 +- executor/sort.go | 3 +- expression/aggregation/agg_to_pb.go | 3 + expression/aggregation/descriptor.go | 15 + expression/aggregation/explain.go | 21 +- go.mod | 2 +- go.sum | 4 +- planner/cascades/enforcer_rules.go | 5 +- planner/core/exhaust_physical_plans.go | 3 +- planner/core/explain.go | 3 +- planner/core/find_best_task.go | 3 +- planner/core/logical_plan_builder.go | 94 ++++-- planner/core/logical_plan_test.go | 5 +- planner/core/logical_plans.go | 5 +- planner/core/physical_plan_test.go | 43 +++ planner/core/physical_plans.go | 5 +- planner/core/plan.go | 5 +- planner/core/property_cols_prune.go | 3 +- planner/core/resolve_indices.go | 6 + planner/core/rule_aggregation_push_down.go | 28 +- planner/core/rule_column_pruning.go | 35 +- planner/core/rule_inject_extra_projection.go | 22 +- planner/core/rule_max_min_eliminate.go | 3 +- planner/core/rule_topn_push_down.go | 5 +- planner/core/task.go | 3 +- planner/core/testdata/plan_suite_in.json | 9 + planner/core/testdata/plan_suite_out.json | 59 ++++ planner/util/byitem.go | 45 +++ sessionctx/variable/varsutil.go | 13 +- types/datum.go | 44 +++ types/datum_test.go | 2 + 38 files changed, 1028 insertions(+), 99 deletions(-) create mode 100644 planner/util/byitem.go diff --git a/cmd/explaintest/r/explain.result b/cmd/explaintest/r/explain.result index a8b430562c55f..800f68fa8bf7f 100644 --- a/cmd/explaintest/r/explain.result +++ b/cmd/explaintest/r/explain.result @@ -28,13 +28,13 @@ set session tidb_hashagg_partial_concurrency = 1; set session tidb_hashagg_final_concurrency = 1; explain select group_concat(a) from t group by id; id count task operator info -StreamAgg_8 8000.00 root group by:col_1, funcs:group_concat(col_0, ",") +StreamAgg_8 8000.00 root group by:col_1, funcs:group_concat(col_0 separator ",") └─Projection_18 10000.00 root cast(test.t.a), test.t.id └─TableReader_15 10000.00 root data:TableScan_14 └─TableScan_14 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:true, stats:pseudo explain select group_concat(a, b) from t group by id; id count task operator info -StreamAgg_8 8000.00 root group by:col_2, funcs:group_concat(col_0, col_1, ",") +StreamAgg_8 8000.00 root group by:col_2, funcs:group_concat(col_0, col_1 separator ",") └─Projection_18 10000.00 root cast(test.t.a), cast(test.t.b), test.t.id └─TableReader_15 10000.00 root data:TableScan_14 └─TableScan_14 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:true, stats:pseudo diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index 19cdd55d83dbc..72a6f9ce2132c 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/executor/aggfuncs" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" @@ -68,6 +69,7 @@ type aggTest 
struct { dataGen func(i int) types.Datum funcName string results []types.Datum + orderBy bool } type multiArgsAggTest struct { @@ -77,6 +79,7 @@ type multiArgsAggTest struct { dataGens []func(i int) types.Datum funcName string results []types.Datum + orderBy bool } func (s *testSuite) testMergePartialResult(c *C, p aggTest) { @@ -93,6 +96,11 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { } desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) c.Assert(err, IsNil) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } partialDesc, finalDesc := desc.Split([]int{0, 1}) // build partial func for partial phase. @@ -112,7 +120,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { dt := resultChk.GetRow(0).GetDatum(0, p.dataType) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0])) err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr) c.Assert(err, IsNil) @@ -128,7 +136,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { dt = resultChk.GetRow(0).GetDatum(0, p.dataType) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1])) err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr) c.Assert(err, IsNil) @@ -139,7 +147,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { dt = resultChk.GetRow(0).GetDatum(0, p.dataType) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[2])) } func buildAggTester(funcName string, tp byte, numRows int, results ...interface{}) aggTest { @@ -159,6 +167,77 @@ func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows i return pt } +func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) { + srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows) + for i := 0; i < p.numRows; i++ { + for j := 0; j < len(p.dataGens); j++ { + fdt := p.dataGens[j](i) + srcChk.AppendDatum(j, &fdt) + } + } + iter := chunk.NewIterator4Chunk(srcChk) + + args := make([]expression.Expression, len(p.dataTypes)) + for k := 0; k < len(p.dataTypes); k++ { + args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k} + } + + desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + c.Assert(err, IsNil) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } + partialDesc, finalDesc := desc.Split([]int{0, 1}) + + // build partial func for partial phase. + partialFunc := aggfuncs.Build(s.ctx, partialDesc, 0) + partialResult := partialFunc.AllocPartialResult() + + // build final func for final phase. + finalFunc := aggfuncs.Build(s.ctx, finalDesc, 0) + finalPr := finalFunc.AllocPartialResult() + resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.retType}, 1) + + // update partial result. 
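+	// Mirrors testMergePartialResult: the first partial result consumes all
+	// rows and is checked against p.results[0]; after merging it into the
+	// final result and resetting it, rows [2, numRows) are fed in again and
+	// checked against p.results[1]; the merged final result is then checked
+	// against p.results[2].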
+ for row := iter.Begin(); row != iter.End(); row = iter.Next() { + partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) + } + partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) + dt := resultChk.GetRow(0).GetDatum(0, p.retType) + result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) + c.Assert(err, IsNil) + c.Assert(result, Equals, 0) + + err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr) + c.Assert(err, IsNil) + partialFunc.ResetPartialResult(partialResult) + + iter.Begin() + iter.Next() + for row := iter.Next(); row != iter.End(); row = iter.Next() { + partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) + } + resultChk.Reset() + partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) + dt = resultChk.GetRow(0).GetDatum(0, p.retType) + result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) + c.Assert(err, IsNil) + c.Assert(result, Equals, 0) + err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr) + c.Assert(err, IsNil) + + resultChk.Reset() + err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) + c.Assert(err, IsNil) + + dt = resultChk.GetRow(0).GetDatum(0, p.retType) + result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2]) + c.Assert(err, IsNil) + c.Assert(result, Equals, 0) +} + // for multiple args in aggfuncs such as json_objectagg(c1, c2) func buildMultiArgsAggTester(funcName string, tps []byte, rt byte, numRows int, results ...interface{}) multiArgsAggTest { fts := make([]*types.FieldType, len(tps)) @@ -234,6 +313,11 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { } desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) c.Assert(err, IsNil) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc := aggfuncs.Build(s.ctx, desc, 0) finalPr := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) @@ -246,7 +330,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1])) // test the empty input resultChk.Reset() @@ -255,11 +339,16 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0])) // test the agg func with distinct desc, err = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true) c.Assert(err, IsNil) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc = aggfuncs.Build(s.ctx, desc, 0) finalPr = finalFunc.AllocPartialResult() @@ -275,7 +364,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1])) // test the empty input resultChk.Reset() @@ -284,7 +373,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) { dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err = 
dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0])) } func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { @@ -301,9 +390,17 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { for k := 0; k < len(p.dataTypes); k++ { args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k} } + if p.funcName == ast.AggFuncGroupConcat { + args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)}) + } desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) c.Assert(err, IsNil) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc := aggfuncs.Build(s.ctx, desc, 0) finalPr := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) @@ -316,7 +413,7 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1])) // test the empty input resultChk.Reset() @@ -325,11 +422,16 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0])) // test the agg func with distinct desc, err = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true) c.Assert(err, IsNil) + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc = aggfuncs.Build(s.ctx, desc, 0) finalPr = finalFunc.AllocPartialResult() @@ -345,7 +447,7 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) { dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp) result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) c.Assert(err, IsNil) - c.Assert(result, Equals, 0) + c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1])) // test the empty input resultChk.Reset() @@ -373,6 +475,11 @@ func (s *testSuite) benchmarkAggFunc(b *testing.B, p aggTest) { if err != nil { b.Fatal(err) } + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc := aggfuncs.Build(s.ctx, desc, 0) resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) iter := chunk.NewIterator4Chunk(srcChk) @@ -388,6 +495,11 @@ func (s *testSuite) benchmarkAggFunc(b *testing.B, p aggTest) { if err != nil { b.Fatal(err) } + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc = aggfuncs.Build(s.ctx, desc, 0) resultChk.Reset() b.Run(fmt.Sprintf("%v(distinct)/%v", p.funcName, p.dataType), func(b *testing.B) { @@ -409,11 +521,19 @@ func (s *testSuite) benchmarkMultiArgsAggFunc(b *testing.B, p multiArgsAggTest) for k := 0; k < len(p.dataTypes); k++ { args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k} } + if p.funcName == ast.AggFuncGroupConcat { + args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)}) + } desc, err := 
aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) if err != nil { b.Fatal(err) } + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc := aggfuncs.Build(s.ctx, desc, 0) resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1) iter := chunk.NewIterator4Chunk(srcChk) @@ -429,6 +549,11 @@ func (s *testSuite) benchmarkMultiArgsAggFunc(b *testing.B, p multiArgsAggTest) if err != nil { b.Fatal(err) } + if p.orderBy { + desc.OrderByItems = []*util.ByItems{ + {Expr: args[0], Desc: true}, + } + } finalFunc = aggfuncs.Build(s.ctx, desc, 0) resultChk.Reset() b.Run(fmt.Sprintf("%v(distinct)/%v", p.funcName, p.dataTypes), func(b *testing.B) { diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index c3676c77f408e..710d1aa3b1f0e 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -313,10 +313,6 @@ func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDe case aggregation.DedupMode: return nil default: - base := baseAggFunc{ - args: aggFuncDesc.Args[:len(aggFuncDesc.Args)-1], - ordinal: ordinal, - } // The last arg is promised to be a not-null string constant, so the error can be ignored. c, _ := aggFuncDesc.Args[len(aggFuncDesc.Args)-1].(*expression.Constant) sep, _, err := c.EvalString(nil, chunk.Row{}) @@ -335,10 +331,26 @@ func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDe panic(fmt.Sprintf("Error happened when buildGroupConcat: %s", err.Error())) } var truncated int32 + base := baseGroupConcat4String{ + baseAggFunc: baseAggFunc{ + args: aggFuncDesc.Args[:len(aggFuncDesc.Args)-1], + ordinal: ordinal, + }, + byItems: aggFuncDesc.OrderByItems, + sep: sep, + maxLen: maxLen, + truncated: &truncated, + } if aggFuncDesc.HasDistinct { - return &groupConcatDistinct{baseGroupConcat4String{baseAggFunc: base, sep: sep, maxLen: maxLen, truncated: &truncated}} + if len(aggFuncDesc.OrderByItems) > 0 { + return &groupConcatDistinctOrder{base} + } + return &groupConcatDistinct{base} + } + if len(aggFuncDesc.OrderByItems) > 0 { + return &groupConcatOrder{base} } - return &groupConcat{baseGroupConcat4String{baseAggFunc: base, sep: sep, maxLen: maxLen, truncated: &truncated}} + return &groupConcat{base} } } diff --git a/executor/aggfuncs/func_group_concat.go b/executor/aggfuncs/func_group_concat.go index 636b2bb69006d..932d25b901514 100644 --- a/executor/aggfuncs/func_group_concat.go +++ b/executor/aggfuncs/func_group_concat.go @@ -15,11 +15,16 @@ package aggfuncs import ( "bytes" + "container/heap" + "sort" "sync/atomic" - "github.com/cznic/mathutil" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/hack" @@ -28,6 +33,7 @@ import ( type baseGroupConcat4String struct { baseAggFunc + byItems []*util.ByItems sep string maxLen uint64 @@ -47,19 +53,20 @@ func (e *baseGroupConcat4String) AppendFinalResult2Chunk(sctx sessionctx.Context return nil } +func (e *baseGroupConcat4String) handleTruncateError(sctx sessionctx.Context) (err error) { + if atomic.CompareAndSwapInt32(e.truncated, 0, 1) { + if !sctx.GetSessionVars().StmtCtx.TruncateAsWarning { + return expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String()) + } + 
sctx.GetSessionVars().StmtCtx.AppendWarning(expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String())) + } + return nil +} + func (e *baseGroupConcat4String) truncatePartialResultIfNeed(sctx sessionctx.Context, buffer *bytes.Buffer) (err error) { if e.maxLen > 0 && uint64(buffer.Len()) > e.maxLen { - i := mathutil.MaxInt - if uint64(i) > e.maxLen { - i = int(e.maxLen) - } - buffer.Truncate(i) - if atomic.CompareAndSwapInt32(e.truncated, 0, 1) { - if !sctx.GetSessionVars().StmtCtx.TruncateAsWarning { - return expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String()) - } - sctx.GetSessionVars().StmtCtx.AppendWarning(expression.ErrCutValueGroupConcat.GenWithStackByArgs(e.args[0].String())) - } + buffer.Truncate(int(e.maxLen)) + return e.handleTruncateError(sctx) } return nil } @@ -214,3 +221,302 @@ func (e *groupConcatDistinct) SetTruncated(t *int32) { func (e *groupConcatDistinct) GetTruncated() *int32 { return e.truncated } + +type sortRow struct { + buffer *bytes.Buffer + byItems []types.Datum +} + +type topNRows struct { + rows []sortRow + desc []bool + sctx sessionctx.Context + err error + + currSize uint64 + limitSize uint64 + sepSize uint64 +} + +func (h topNRows) Len() int { + return len(h.rows) +} + +func (h topNRows) Less(i, j int) bool { + n := len(h.rows[i].byItems) + for k := 0; k < n; k++ { + ret, err := h.rows[i].byItems[k].CompareDatum(h.sctx.GetSessionVars().StmtCtx, &h.rows[j].byItems[k]) + if err != nil { + h.err = err + return false + } + if h.desc[k] { + ret = -ret + } + if ret > 0 { + return true + } + if ret < 0 { + return false + } + } + return false +} + +func (h topNRows) Swap(i, j int) { + h.rows[i], h.rows[j] = h.rows[j], h.rows[i] +} + +func (h *topNRows) Push(x interface{}) { + h.rows = append(h.rows, x.(sortRow)) +} + +func (h *topNRows) Pop() interface{} { + n := len(h.rows) + x := h.rows[n-1] + h.rows = h.rows[:n-1] + return x +} + +func (h *topNRows) tryToAdd(row sortRow) (truncated bool) { + h.currSize += uint64(row.buffer.Len()) + if len(h.rows) > 0 { + h.currSize += h.sepSize + } + heap.Push(h, row) + if h.currSize <= h.limitSize { + return false + } + + for h.currSize > h.limitSize { + debt := h.currSize - h.limitSize + if uint64(h.rows[0].buffer.Len()) > debt { + h.currSize -= debt + h.rows[0].buffer.Truncate(h.rows[0].buffer.Len() - int(debt)) + } else { + h.currSize -= uint64(h.rows[0].buffer.Len()) + h.sepSize + heap.Pop(h) + } + } + return true +} + +func (h *topNRows) reset() { + h.rows = h.rows[:0] + h.err = nil + h.currSize = 0 +} + +func (h *topNRows) concat(sep string, truncated bool) string { + buffer := new(bytes.Buffer) + sort.Sort(sort.Reverse(h)) + for i, row := range h.rows { + if i != 0 { + buffer.WriteString(sep) + } + buffer.Write(row.buffer.Bytes()) + } + if truncated && uint64(buffer.Len()) < h.limitSize { + // append the last separator, because the last separator may be truncated in tryToAdd. 
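+		// tryToAdd accounts sepSize for every stored row, so when it pops a
+		// row the separator's bytes are released as well and the joined string
+		// can end up shorter than limitSize even though truncation happened;
+		// writing one separator back and truncating at limitSize restores the
+		// cut-off position.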
+ buffer.WriteString(sep) + buffer.Truncate(int(h.limitSize)) + } + return buffer.String() +} + +type partialResult4GroupConcatOrder struct { + topN *topNRows +} + +type groupConcatOrder struct { + baseGroupConcat4String +} + +func (e *groupConcatOrder) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4GroupConcatOrder)(pr) + if p.topN.Len() == 0 { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendString(e.ordinal, p.topN.concat(e.sep, *e.truncated == 1)) + return nil +} + +func (e *groupConcatOrder) AllocPartialResult() PartialResult { + desc := make([]bool, len(e.byItems)) + for i, byItem := range e.byItems { + desc[i] = byItem.Desc + } + p := &partialResult4GroupConcatOrder{ + topN: &topNRows{ + desc: desc, + currSize: 0, + limitSize: e.maxLen, + sepSize: uint64(len(e.sep)), + }, + } + return PartialResult(p) +} + +func (e *groupConcatOrder) ResetPartialResult(pr PartialResult) { + p := (*partialResult4GroupConcatOrder)(pr) + p.topN.reset() +} + +func (e *groupConcatOrder) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4GroupConcatOrder)(pr) + p.topN.sctx = sctx + v, isNull := "", false + for _, row := range rowsInGroup { + buffer := new(bytes.Buffer) + for _, arg := range e.args { + v, isNull, err = arg.EvalString(sctx, row) + if err != nil { + return err + } + if isNull { + break + } + buffer.WriteString(v) + } + if isNull { + continue + } + sortRow := sortRow{ + buffer: buffer, + byItems: make([]types.Datum, 0, len(e.byItems)), + } + for _, byItem := range e.byItems { + d, err := byItem.Expr.Eval(row) + if err != nil { + return err + } + sortRow.byItems = append(sortRow.byItems, d) + } + truncated := p.topN.tryToAdd(sortRow) + if p.topN.err != nil { + return p.topN.err + } + if truncated { + if err := e.handleTruncateError(sctx); err != nil { + return err + } + } + } + return nil +} + +func (e *groupConcatOrder) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) error { + // If order by exists, the parallel hash aggregation is forbidden in executorBuilder.buildHashAgg. + // So MergePartialResult will not be called. + return terror.ClassOptimizer.New(mysql.ErrInternal, mysql.MySQLErrName[mysql.ErrInternal]).GenWithStack("groupConcatOrder.MergePartialResult should not be called") +} + +// SetTruncated will be called in `executorBuilder#buildHashAgg` with duck-type. +func (e *groupConcatOrder) SetTruncated(t *int32) { + e.truncated = t +} + +// GetTruncated will be called in `executorBuilder#buildHashAgg` with duck-type. 
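+// Sharing one flag lets the CompareAndSwap in handleTruncateError report the
+// truncation warning at most once, no matter which phase trips the limit first.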
+func (e *groupConcatOrder) GetTruncated() *int32 { + return e.truncated +} + +type partialResult4GroupConcatOrderDistinct struct { + topN *topNRows + valSet set.StringSet + encodeBytesBuffer []byte +} + +type groupConcatDistinctOrder struct { + baseGroupConcat4String +} + +func (e *groupConcatDistinctOrder) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4GroupConcatOrderDistinct)(pr) + if p.topN.Len() == 0 { + chk.AppendNull(e.ordinal) + return nil + } + chk.AppendString(e.ordinal, p.topN.concat(e.sep, *e.truncated == 1)) + return nil +} + +func (e *groupConcatDistinctOrder) AllocPartialResult() PartialResult { + desc := make([]bool, len(e.byItems)) + for i, byItem := range e.byItems { + desc[i] = byItem.Desc + } + p := &partialResult4GroupConcatOrderDistinct{ + topN: &topNRows{ + desc: desc, + currSize: 0, + limitSize: e.maxLen, + sepSize: uint64(len(e.sep)), + }, + valSet: set.NewStringSet(), + } + return PartialResult(p) +} + +func (e *groupConcatDistinctOrder) ResetPartialResult(pr PartialResult) { + p := (*partialResult4GroupConcatOrderDistinct)(pr) + p.topN.reset() + p.valSet = set.NewStringSet() +} + +func (e *groupConcatDistinctOrder) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4GroupConcatOrderDistinct)(pr) + p.topN.sctx = sctx + v, isNull := "", false + for _, row := range rowsInGroup { + buffer := new(bytes.Buffer) + p.encodeBytesBuffer = p.encodeBytesBuffer[:0] + for _, arg := range e.args { + v, isNull, err = arg.EvalString(sctx, row) + if err != nil { + return err + } + if isNull { + break + } + p.encodeBytesBuffer = codec.EncodeBytes(p.encodeBytesBuffer, hack.Slice(v)) + buffer.WriteString(v) + } + if isNull { + continue + } + joinedVal := string(p.encodeBytesBuffer) + if p.valSet.Exist(joinedVal) { + continue + } + p.valSet.Insert(joinedVal) + sortRow := sortRow{ + buffer: buffer, + byItems: make([]types.Datum, 0, len(e.byItems)), + } + for _, byItem := range e.byItems { + d, err := byItem.Expr.Eval(row) + if err != nil { + return err + } + sortRow.byItems = append(sortRow.byItems, d) + } + truncated := p.topN.tryToAdd(sortRow) + if p.topN.err != nil { + return p.topN.err + } + if truncated { + if err := e.handleTruncateError(sctx); err != nil { + return err + } + } + } + return nil +} + +func (e *groupConcatDistinctOrder) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) error { + // If order by exists, the parallel hash aggregation is forbidden in executorBuilder.buildHashAgg. + // So MergePartialResult will not be called. + return terror.ClassOptimizer.New(mysql.ErrInternal, mysql.MySQLErrName[mysql.ErrInternal]).GenWithStack("groupConcatDistinctOrder.MergePartialResult should not be called") +} diff --git a/executor/aggfuncs/func_group_concat_test.go b/executor/aggfuncs/func_group_concat_test.go index 576ed8ffd2ad9..7e68e93cfed3b 100644 --- a/executor/aggfuncs/func_group_concat_test.go +++ b/executor/aggfuncs/func_group_concat_test.go @@ -14,9 +14,13 @@ package aggfuncs_test import ( + "fmt" + . 
"github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/types" ) func (s *testSuite) TestMergePartialResult4GroupConcat(c *C) { @@ -25,6 +29,19 @@ func (s *testSuite) TestMergePartialResult4GroupConcat(c *C) { } func (s *testSuite) TestGroupConcat(c *C) { - test := buildAggTester(ast.AggFuncGroupConcat, mysql.TypeString, 5, nil, "0 1 2 3 4", "0 1 2 3 4 2 3 4") + test := buildAggTester(ast.AggFuncGroupConcat, mysql.TypeString, 5, nil, "0 1 2 3 4") s.testAggFunc(c, test) + + test2 := buildMultiArgsAggTester(ast.AggFuncGroupConcat, []byte{mysql.TypeString, mysql.TypeString}, mysql.TypeString, 5, nil, "44 33 22 11 00") + test2.orderBy = true + s.testMultiArgsAggFunc(c, test2) + + defer variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.GroupConcatMaxLen, types.NewStringDatum("1024")) + // minimum GroupConcatMaxLen is 4 + for i := 4; i <= 7; i++ { + variable.SetSessionSystemVar(s.ctx.GetSessionVars(), variable.GroupConcatMaxLen, types.NewStringDatum(fmt.Sprint(i))) + test2 = buildMultiArgsAggTester(ast.AggFuncGroupConcat, []byte{mysql.TypeString, mysql.TypeString}, mysql.TypeString, 5, nil, "44 33 22 11 00"[:i]) + test2.orderBy = true + s.testMultiArgsAggFunc(c, test2) + } } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 1dac01d17f0f8..37f9463a5f704 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -14,6 +14,8 @@ package executor_test import ( + "fmt" + . "github.com/pingcap/check" "github.com/pingcap/errors" "github.com/pingcap/parser/terror" @@ -378,9 +380,11 @@ func (s *testSuite1) TestAggPrune(c *C) { } func (s *testSuite1) TestGroupConcatAggr(c *C) { + var err error // issue #5411 tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") + tk.MustExec("drop table if exists test;") tk.MustExec("create table test(id int, name int)") tk.MustExec("insert into test values(1, 10);") tk.MustExec("insert into test values(1, 20);") @@ -406,6 +410,96 @@ func (s *testSuite1) TestGroupConcatAggr(c *C) { result = tk.MustQuery("select id, group_concat(name SEPARATOR '123') from test group by id order by id") result.Check(testkit.Rows("1 101232012330", "2 20", "3 200123500")) + tk.MustQuery("select group_concat(id ORDER BY name) from (select * from test order by id, name limit 2,2) t").Check(testkit.Rows("2,1")) + tk.MustQuery("select group_concat(id ORDER BY name desc) from (select * from test order by id, name limit 2,2) t").Check(testkit.Rows("1,2")) + tk.MustQuery("select group_concat(name ORDER BY id) from (select * from test order by id, name limit 2,2) t").Check(testkit.Rows("30,20")) + tk.MustQuery("select group_concat(name ORDER BY id desc) from (select * from test order by id, name limit 2,2) t").Check(testkit.Rows("20,30")) + + result = tk.MustQuery("select group_concat(name ORDER BY name desc SEPARATOR '++') from test;") + result.Check(testkit.Rows("500++200++30++20++20++10")) + + result = tk.MustQuery("select group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from test;") + result.Check(testkit.Rows("3--3--1--1--2--1")) + + result = tk.MustQuery("select group_concat(name ORDER BY name desc SEPARATOR '++'), group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from test;") + result.Check(testkit.Rows("500++200++30++20++20++10 3--3--1--1--2--1")) + + result = tk.MustQuery("select group_concat(distinct name order by name desc) from test;") + result.Check(testkit.Rows("500,200,30,20,10")) + + expected := 
"3--3--1--1--2--1" + for maxLen := 4; maxLen < len(expected); maxLen++ { + tk.MustExec(fmt.Sprintf("set session group_concat_max_len=%v", maxLen)) + result = tk.MustQuery("select group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from test;") + result.Check(testkit.Rows(expected[:maxLen])) + c.Assert(tk.Se.GetSessionVars().StmtCtx.GetWarnings(), HasLen, 1) + } + expected = "1--2--1--1--3--3" + for maxLen := 4; maxLen < len(expected); maxLen++ { + tk.MustExec(fmt.Sprintf("set session group_concat_max_len=%v", maxLen)) + result = tk.MustQuery("select group_concat(id ORDER BY name asc, id desc SEPARATOR '--') from test;") + result.Check(testkit.Rows(expected[:maxLen])) + c.Assert(tk.Se.GetSessionVars().StmtCtx.GetWarnings(), HasLen, 1) + } + expected = "500,200,30,20,10" + for maxLen := 4; maxLen < len(expected); maxLen++ { + tk.MustExec(fmt.Sprintf("set session group_concat_max_len=%v", maxLen)) + result = tk.MustQuery("select group_concat(distinct name order by name desc) from test;") + result.Check(testkit.Rows(expected[:maxLen])) + c.Assert(tk.Se.GetSessionVars().StmtCtx.GetWarnings(), HasLen, 1) + } + + tk.MustExec(fmt.Sprintf("set session group_concat_max_len=%v", 1024)) + + // test varchar table + tk.MustExec("drop table if exists test2;") + tk.MustExec("create table test2(id varchar(20), name varchar(20));") + tk.MustExec("insert into test2 select * from test;") + + tk.MustQuery("select group_concat(id ORDER BY name) from (select * from test2 order by id, name limit 2,2) t").Check(testkit.Rows("2,1")) + tk.MustQuery("select group_concat(id ORDER BY name desc) from (select * from test2 order by id, name limit 2,2) t").Check(testkit.Rows("1,2")) + tk.MustQuery("select group_concat(name ORDER BY id) from (select * from test2 order by id, name limit 2,2) t").Check(testkit.Rows("30,20")) + tk.MustQuery("select group_concat(name ORDER BY id desc) from (select * from test2 order by id, name limit 2,2) t").Check(testkit.Rows("20,30")) + + result = tk.MustQuery("select group_concat(name ORDER BY name desc SEPARATOR '++'), group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from test2;") + result.Check(testkit.Rows("500++30++200++20++20++10 3--1--3--1--2--1")) + + // test Position Expr + tk.MustQuery("select 1, 2, 3, 4, 5 , group_concat(name, id ORDER BY 1 desc, id SEPARATOR '++') from test;").Check(testkit.Rows("1 2 3 4 5 5003++2003++301++201++202++101")) + tk.MustQuery("select 1, 2, 3, 4, 5 , group_concat(name, id ORDER BY 2 desc, name SEPARATOR '++') from test;").Check(testkit.Rows("1 2 3 4 5 2003++5003++202++101++201++301")) + err = tk.ExecToErr("select 1, 2, 3, 4, 5 , group_concat(name, id ORDER BY 3 desc, name SEPARATOR '++') from test;") + c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '3' in 'order clause'") + + // test Param Marker + tk.MustExec(`prepare s1 from "select 1, 2, 3, 4, 5 , group_concat(name, id ORDER BY floor(id/?) desc, name SEPARATOR '++') from test";`) + tk.MustExec("set @a=2;") + tk.MustQuery("execute s1 using @a;").Check(testkit.Rows("1 2 3 4 5 202++2003++5003++101++201++301")) + + tk.MustExec(`prepare s1 from "select 1, 2, 3, 4, 5 , group_concat(name, id ORDER BY ? desc, name SEPARATOR '++') from test";`) + tk.MustExec("set @a=2;") + tk.MustQuery("execute s1 using @a;").Check(testkit.Rows("1 2 3 4 5 2003++5003++202++101++201++301")) + tk.MustExec("set @a=3;") + err = tk.ExecToErr("execute s1 using @a;") + c.Assert(err.Error(), Equals, "[planner:1054]Unknown column '?' 
in 'order clause'") + tk.MustExec("set @a=3.0;") + tk.MustQuery("execute s1 using @a;").Check(testkit.Rows("1 2 3 4 5 101++202++201++301++2003++5003")) + + // test partition table + tk.MustExec("drop table if exists ptest;") + tk.MustExec("CREATE TABLE ptest (id int,name int) PARTITION BY RANGE ( id ) " + + "(PARTITION `p0` VALUES LESS THAN (2), PARTITION `p1` VALUES LESS THAN (11))") + tk.MustExec("insert into ptest select * from test;") + + for j := 0; j <= 1; j++ { + tk.MustExec(fmt.Sprintf("set session tidb_opt_agg_push_down = %v", j)) + + result = tk.MustQuery("select /*+ agg_to_cop */ group_concat(name ORDER BY name desc SEPARATOR '++'), group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from ptest;") + result.Check(testkit.Rows("500++200++30++20++20++10 3--3--1--1--2--1")) + + result = tk.MustQuery("select /*+ agg_to_cop */ group_concat(distinct name order by name desc) from ptest;") + result.Check(testkit.Rows("500,200,30,20,10")) + } + // issue #9920 tk.MustQuery("select group_concat(123, null)").Check(testkit.Rows("")) } diff --git a/executor/builder.go b/executor/builder.go index 2472d6449dbe0..cc0cc1aed4e39 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -1073,7 +1073,7 @@ func (b *executorBuilder) buildHashAgg(v *plannercore.PhysicalHashAgg) Executor e.defaultVal = chunk.NewChunkWithCapacity(retTypes(e), 1) } for _, aggDesc := range v.AggFuncs { - if aggDesc.HasDistinct { + if aggDesc.HasDistinct || len(aggDesc.OrderByItems) > 0 { e.isUnparallelExec = true } } diff --git a/executor/executor_required_rows_test.go b/executor/executor_required_rows_test.go index 804ccec90031d..ee30b99c2a488 100644 --- a/executor/executor_required_rows_test.go +++ b/executor/executor_required_rows_test.go @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -253,10 +254,10 @@ func (s *testExecSuite) TestSortRequiredRows(c *C) { sctx := defaultCtx() ctx := context.Background() ds := newRequiredRowsDataSource(sctx, testCase.totalRows, testCase.expectedRowsDS) - byItems := make([]*plannercore.ByItems, 0, len(testCase.groupBy)) + byItems := make([]*util.ByItems, 0, len(testCase.groupBy)) for _, groupBy := range testCase.groupBy { col := ds.Schema().Columns[groupBy] - byItems = append(byItems, &plannercore.ByItems{Expr: col}) + byItems = append(byItems, &util.ByItems{Expr: col}) } exec := buildSortExec(sctx, byItems, ds) c.Assert(exec.Open(ctx), IsNil) @@ -271,7 +272,7 @@ func (s *testExecSuite) TestSortRequiredRows(c *C) { } } -func buildSortExec(sctx sessionctx.Context, byItems []*plannercore.ByItems, src Executor) Executor { +func buildSortExec(sctx sessionctx.Context, byItems []*util.ByItems, src Executor) Executor { sortExec := SortExec{ baseExecutor: newBaseExecutor(sctx, src.Schema(), nil, src), ByItems: byItems, @@ -360,10 +361,10 @@ func (s *testExecSuite) TestTopNRequiredRows(c *C) { sctx := defaultCtx() ctx := context.Background() ds := newRequiredRowsDataSource(sctx, testCase.totalRows, testCase.expectedRowsDS) - byItems := make([]*plannercore.ByItems, 0, len(testCase.groupBy)) + byItems := make([]*util.ByItems, 0, len(testCase.groupBy)) for _, groupBy := range testCase.groupBy { col := ds.Schema().Columns[groupBy] - byItems = append(byItems, &plannercore.ByItems{Expr: col}) + byItems = append(byItems, 
&util.ByItems{Expr: col}) } exec := buildTopNExec(sctx, testCase.topNOffset, testCase.topNCount, byItems, ds) c.Assert(exec.Open(ctx), IsNil) @@ -378,7 +379,7 @@ func (s *testExecSuite) TestTopNRequiredRows(c *C) { } } -func buildTopNExec(ctx sessionctx.Context, offset, count int, byItems []*plannercore.ByItems, src Executor) Executor { +func buildTopNExec(ctx sessionctx.Context, offset, count int, byItems []*util.ByItems, src Executor) Executor { sortExec := SortExec{ baseExecutor: newBaseExecutor(ctx, src.Schema(), nil, src), ByItems: byItems, diff --git a/executor/sort.go b/executor/sort.go index fefac4cda4c1a..3867de92b7e2e 100644 --- a/executor/sort.go +++ b/executor/sort.go @@ -23,6 +23,7 @@ import ( "github.com/opentracing/opentracing-go" "github.com/pingcap/tidb/expression" plannercore "github.com/pingcap/tidb/planner/core" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/memory" @@ -35,7 +36,7 @@ var rowChunksLabel fmt.Stringer = stringutil.StringerStr("rowChunks") type SortExec struct { baseExecutor - ByItems []*plannercore.ByItems + ByItems []*util.ByItems Idx int fetched bool schema *expression.Schema diff --git a/expression/aggregation/agg_to_pb.go b/expression/aggregation/agg_to_pb.go index 59d09db237701..a6a003a01a86f 100644 --- a/expression/aggregation/agg_to_pb.go +++ b/expression/aggregation/agg_to_pb.go @@ -26,6 +26,9 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag if aggFunc.HasDistinct { return nil } + if len(aggFunc.OrderByItems) > 0 { + return nil + } pc := expression.NewPBConverter(client, sc) var tp tipb.ExprType switch aggFunc.Name { diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index 66f5f3346c805..24a7e7f5fc8d8 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" @@ -34,6 +35,8 @@ type AggFuncDesc struct { Mode AggFunctionMode // HasDistinct represents whether the aggregation function contains distinct attribute. HasDistinct bool + // OrderByItems represents the order by clause used in GROUP_CONCAT + OrderByItems []*util.ByItems } // NewAggFuncDesc creates an aggregation function signature descriptor. 
@@ -50,6 +53,14 @@ func (a *AggFuncDesc) Equal(ctx sessionctx.Context, other *AggFuncDesc) bool { if a.HasDistinct != other.HasDistinct { return false } + if len(a.OrderByItems) != len(other.OrderByItems) { + return false + } + for i := range a.OrderByItems { + if !a.OrderByItems[i].Equal(ctx, other.OrderByItems[i]) { + return false + } + } return a.baseFuncDesc.equal(ctx, &other.baseFuncDesc) } @@ -57,6 +68,10 @@ func (a *AggFuncDesc) Equal(ctx sessionctx.Context, other *AggFuncDesc) bool { func (a *AggFuncDesc) Clone() *AggFuncDesc { clone := *a clone.baseFuncDesc = *a.baseFuncDesc.clone() + clone.OrderByItems = make([]*util.ByItems, len(a.OrderByItems)) + for i, byItem := range a.OrderByItems { + clone.OrderByItems[i] = byItem.Clone() + } return &clone } diff --git a/expression/aggregation/explain.go b/expression/aggregation/explain.go index 0a6a01a4ed8b7..b001a21c23d1e 100644 --- a/expression/aggregation/explain.go +++ b/expression/aggregation/explain.go @@ -16,6 +16,8 @@ package aggregation import ( "bytes" "fmt" + + "github.com/pingcap/parser/ast" ) // ExplainAggFunc generates explain information for a aggregation function. @@ -26,10 +28,25 @@ func ExplainAggFunc(agg *AggFuncDesc) string { buffer.WriteString("distinct ") } for i, arg := range agg.Args { - buffer.WriteString(arg.ExplainInfo()) - if i+1 < len(agg.Args) { + if agg.Name == ast.AggFuncGroupConcat && i == len(agg.Args)-1 { + if len(agg.OrderByItems) > 0 { + buffer.WriteString(" order by ") + for i, item := range agg.OrderByItems { + order := "asc" + if item.Desc { + order = "desc" + } + fmt.Fprintf(&buffer, "%s %s", item.Expr.ExplainInfo(), order) + if i+1 < len(agg.OrderByItems) { + buffer.WriteString(", ") + } + } + } + buffer.WriteString(" separator ") + } else if i != 0 { buffer.WriteString(", ") } + buffer.WriteString(arg.ExplainInfo()) } buffer.WriteString(")") return buffer.String() diff --git a/go.mod b/go.mod index f4d1169553772..ef73fefd09710 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/pingcap/goleveldb v0.0.0-20171020122428-b9ff6c35079e github.com/pingcap/kvproto v0.0.0-20200317043902-2838e21ca222 github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd - github.com/pingcap/parser v3.1.0-beta.2.0.20200425032215-994651e9b6df+incompatible + github.com/pingcap/parser v3.1.2-0.20200507065358-a5eade012146+incompatible github.com/pingcap/pd/v3 v3.1.0-beta.2.0.20200312100832-1206736bd050 github.com/pingcap/tidb-tools v4.0.0-beta.1.0.20200317092225-ed6b2a87af54+incompatible github.com/pingcap/tipb v0.0.0-20200426072603-ce17d2d03251 diff --git a/go.sum b/go.sum index b70465a031138..99599962f30f5 100644 --- a/go.sum +++ b/go.sum @@ -248,8 +248,8 @@ github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9 h1:AJD9pZYm72vMgPcQDww github.com/pingcap/log v0.0.0-20191012051959-b742a5d432e9/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd h1:CV3VsP3Z02MVtdpTMfEgRJ4T9NGgGTxdHpJerent7rM= github.com/pingcap/log v0.0.0-20200117041106-d28c14d3b1cd/go.mod h1:4rbK1p9ILyIfb6hU7OG2CiWSqMXnp3JMbiaVJ6mvoY8= -github.com/pingcap/parser v3.1.0-beta.2.0.20200425032215-994651e9b6df+incompatible h1:9kFDvyd1YTGin3xHetc68xAChSOgQD2vThXxBR9ccH0= -github.com/pingcap/parser v3.1.0-beta.2.0.20200425032215-994651e9b6df+incompatible/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= +github.com/pingcap/parser v3.1.2-0.20200507065358-a5eade012146+incompatible h1:yIsvQ1+8J1ZEPkOMRIThrSn3UPyD2iiHXBi0Z3uOmXI= +github.com/pingcap/parser 
v3.1.2-0.20200507065358-a5eade012146+incompatible/go.mod h1:1FNvfp9+J0wvc4kl8eGNh7Rqrxveg15jJoWo/a0uHwA= github.com/pingcap/pd/v3 v3.1.0-beta.2.0.20200312100832-1206736bd050 h1:mxPdR0pxnUcRfRGX2JnaLyAd9SZWeR42SzvMp4Zv3YI= github.com/pingcap/pd/v3 v3.1.0-beta.2.0.20200312100832-1206736bd050/go.mod h1:0HfF1LfWLMuGpui0PKhGvkXxfjv1JslMRY6B+cae3dg= github.com/pingcap/tidb-tools v4.0.0-beta.1.0.20200317092225-ed6b2a87af54+incompatible h1:tYADqdmWwgDOwf/qEN0trJAy6H3c3Tt/QZx1z4qVrRQ= diff --git a/planner/cascades/enforcer_rules.go b/planner/cascades/enforcer_rules.go index e3d41b0a88719..c706f9942234d 100644 --- a/planner/cascades/enforcer_rules.go +++ b/planner/cascades/enforcer_rules.go @@ -18,6 +18,7 @@ import ( "github.com/pingcap/tidb/planner/implementation" "github.com/pingcap/tidb/planner/memo" "github.com/pingcap/tidb/planner/property" + "github.com/pingcap/tidb/planner/util" ) // Enforcer defines the interface for enforcer rules. @@ -56,10 +57,10 @@ func (e *OrderEnforcer) NewProperty(prop *property.PhysicalProperty) (newProp *p // OnEnforce adds sort operator to satisfy required order property. func (e *OrderEnforcer) OnEnforce(reqProp *property.PhysicalProperty, child memo.Implementation) (impl memo.Implementation) { sort := &plannercore.PhysicalSort{ - ByItems: make([]*plannercore.ByItems, 0, len(reqProp.Items)), + ByItems: make([]*util.ByItems, 0, len(reqProp.Items)), } for _, item := range reqProp.Items { - item := &plannercore.ByItems{ + item := &util.ByItems{ Expr: item.Col, Desc: item.Desc, } diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index e70077f7169f3..237a8118ef89d 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -26,6 +26,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/property" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -1155,7 +1156,7 @@ func (lt *LogicalTopN) getPhysLimits() []PhysicalPlan { } // Check if this prop's columns can match by items totally. 
-func matchItems(p *property.PhysicalProperty, items []*ByItems) bool {
+func matchItems(p *property.PhysicalProperty, items []*util.ByItems) bool {
 	if len(items) < len(p.Items) {
 		return false
 	}
diff --git a/planner/core/explain.go b/planner/core/explain.go
index 6214c228ebe38..e06f20a77f79f 100644
--- a/planner/core/explain.go
+++ b/planner/core/explain.go
@@ -20,6 +20,7 @@ import (
 	"github.com/pingcap/parser/ast"
 	"github.com/pingcap/tidb/expression"
 	"github.com/pingcap/tidb/expression/aggregation"
+	"github.com/pingcap/tidb/planner/util"
 	"github.com/pingcap/tidb/statistics"
 )
 
@@ -459,7 +460,7 @@ func (p *PhysicalWindow) formatFrameBound(buffer *bytes.Buffer, bound *FrameBoun
 	}
 }
 
-func explainNormalizedByItems(buffer *bytes.Buffer, byItems []*ByItems) *bytes.Buffer {
+func explainNormalizedByItems(buffer *bytes.Buffer, byItems []*util.ByItems) *bytes.Buffer {
 	for i, item := range byItems {
 		order := "asc"
 		if item.Desc {
diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go
index ec5423ea0ecf6..c0a6e3137d460 100644
--- a/planner/core/find_best_task.go
+++ b/planner/core/find_best_task.go
@@ -22,6 +22,7 @@ import (
 	"github.com/pingcap/tidb/infoschema"
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/planner/property"
+	"github.com/pingcap/tidb/planner/util"
 	"github.com/pingcap/tidb/sessionctx/stmtctx"
 	"github.com/pingcap/tidb/statistics"
 	"github.com/pingcap/tidb/types"
@@ -44,7 +45,7 @@ var invalidTask = &rootTask{cst: math.MaxFloat64}
 
 // getPropByOrderByItems will check if this sort property can be pushed or not. In order to simplify the problem, we only
 // consider the case that all expression are columns.
-func getPropByOrderByItems(items []*ByItems) (*property.PhysicalProperty, bool) {
+func getPropByOrderByItems(items []*util.ByItems) (*property.PhysicalProperty, bool) {
 	propItems := make([]property.Item, 0, len(items))
 	for _, item := range items {
 		col, ok := item.Expr.(*expression.Column)
diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go
index 56e96bb19b50c..6b4af1e40b157 100644
--- a/planner/core/logical_plan_builder.go
+++ b/planner/core/logical_plan_builder.go
@@ -20,6 +20,7 @@ import (
 	"math/bits"
 	"reflect"
 	"sort"
+	"strconv"
 	"strings"
 	"unicode"
 
@@ -38,6 +39,7 @@ import (
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/metrics"
 	"github.com/pingcap/tidb/planner/property"
+	"github.com/pingcap/tidb/planner/util"
 	"github.com/pingcap/tidb/privilege"
 	"github.com/pingcap/tidb/sessionctx"
 	"github.com/pingcap/tidb/statistics"
@@ -96,6 +98,54 @@ func (la *LogicalAggregation) collectGroupByColumns() {
 	}
 }
 
+// aggOrderByResolver resolves expressions of the order by clause
+// in the aggregate function GROUP_CONCAT.
+type aggOrderByResolver struct {
+	ctx       sessionctx.Context
+	err       error
+	args      []ast.ExprNode
+	exprDepth int // exprDepth is the depth of the current expression in the expression tree.
+}
+
+func (a *aggOrderByResolver) Enter(inNode ast.Node) (ast.Node, bool) {
+	a.exprDepth++
+	switch n := inNode.(type) {
+	case *driver.ParamMarkerExpr:
+		if a.exprDepth == 1 {
+			_, isNull, isExpectedType := getUintFromNode(a.ctx, n)
+			// A constant uint expression at the top level should be treated as a position expression.
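+			// e.g. with GROUP_CONCAT(name, id ORDER BY ? SEPARATOR '++') and
+			// the marker bound to 2, the clause acts like ORDER BY 2 and sorts
+			// by the second GROUP_CONCAT argument rather than by the constant.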
+			if !isNull && isExpectedType {
+				return expression.ConstructPositionExpr(n), true
+			}
+		}
+	}
+	return inNode, false
+}
+
+func (a *aggOrderByResolver) Leave(inNode ast.Node) (ast.Node, bool) {
+	switch v := inNode.(type) {
+	case *ast.PositionExpr:
+		pos, isNull, err := expression.PosFromPositionExpr(a.ctx, v)
+		if err != nil {
+			a.err = err
+		}
+		if err != nil || isNull {
+			return inNode, false
+		}
+		if pos < 1 || pos > len(a.args) {
+			errPos := strconv.Itoa(pos)
+			if v.P != nil {
+				errPos = "?"
+			}
+			a.err = ErrUnknownColumn.FastGenByArgs(errPos, "order clause")
+			return inNode, false
+		}
+		ret := a.args[pos-1]
+		return ret, true
+	}
+	return inNode, true
+}
+
 func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int, error) {
 	b.optFlag = b.optFlag | flagBuildKeyInfo
 	b.optFlag = b.optFlag | flagPushDownAgg
@@ -130,6 +180,27 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
 		if err != nil {
 			return nil, nil, err
 		}
+		if aggFunc.Order != nil {
+			trueArgs := aggFunc.Args[:len(aggFunc.Args)-1] // the last argument is SEPARATOR, remove it.
+			resolver := &aggOrderByResolver{
+				ctx:  b.ctx,
+				args: trueArgs,
+			}
+			for _, byItem := range aggFunc.Order.Items {
+				resolver.exprDepth = 0
+				resolver.err = nil
+				retExpr, _ := byItem.Expr.Accept(resolver)
+				if resolver.err != nil {
+					return nil, nil, errors.Trace(resolver.err)
+				}
+				newByItem, np, err := b.rewrite(ctx, retExpr.(ast.ExprNode), p, nil, true)
+				if err != nil {
+					return nil, nil, err
+				}
+				p = np
+				newFunc.OrderByItems = append(newFunc.OrderByItems, &util.ByItems{Expr: newByItem, Desc: byItem.Desc})
+			}
+		}
 		combined := false
 		for j, oldFunc := range plan4Agg.AggFuncs {
 			if oldFunc.Equal(b.ctx, newFunc) {
@@ -1035,25 +1106,6 @@ func (b *PlanBuilder) buildUnionAll(ctx context.Context, subPlan []LogicalPlan)
 	return u
 }
 
-// ByItems wraps a "by" item.
-type ByItems struct {
-	Expr expression.Expression
-	Desc bool
-}
-
-// String implements fmt.Stringer interface.
-func (by *ByItems) String() string {
-	if by.Desc {
-		return fmt.Sprintf("%s true", by.Expr)
-	}
-	return by.Expr.String()
-}
-
-// Clone makes a copy of ByItems.
-func (by *ByItems) Clone() *ByItems { - return &ByItems{Expr: by.Expr.Clone(), Desc: by.Desc} -} - // itemTransformer transforms ParamMarkerExpr to PositionExpr in the context of ByItem type itemTransformer struct { } @@ -1078,7 +1130,7 @@ func (b *PlanBuilder) buildSort(ctx context.Context, p LogicalPlan, byItems []*a b.curClause = orderByClause } sort := LogicalSort{}.Init(b.ctx) - exprs := make([]*ByItems, 0, len(byItems)) + exprs := make([]*util.ByItems, 0, len(byItems)) transformer := &itemTransformer{} for _, item := range byItems { newExpr, _ := item.Expr.Accept(transformer) @@ -1089,7 +1141,7 @@ func (b *PlanBuilder) buildSort(ctx context.Context, p LogicalPlan, byItems []*a } p = np - exprs = append(exprs, &ByItems{Expr: it, Desc: item.Desc}) + exprs = append(exprs, &util.ByItems{Expr: it, Desc: item.Desc}) } sort.ByItems = exprs sort.SetChildren(p) diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index 85374da8f8ac4..ad40169c0202c 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -30,6 +30,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/planner/property" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/testleak" ) @@ -2593,7 +2594,7 @@ func (s *testPlanSuite) optimize(ctx context.Context, sql string) (PhysicalPlan, return p.(PhysicalPlan), stmt, err } -func byItemsToProperty(byItems []*ByItems) *property.PhysicalProperty { +func byItemsToProperty(byItems []*util.ByItems) *property.PhysicalProperty { pp := &property.PhysicalProperty{} for _, item := range byItems { pp.Items = append(pp.Items, property.Item{Col: item.Expr.(*expression.Column), Desc: item.Desc}) @@ -2680,7 +2681,7 @@ func (s *testPlanSuite) TestSkylinePruning(c *C) { _, err = lp.recursiveDeriveStats() c.Assert(err, IsNil) var ds *DataSource - var byItems []*ByItems + var byItems []*util.ByItems for ds == nil { switch v := lp.(type) { case *DataSource: diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 705f14b60a78d..970c423f9e729 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/property" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/types" @@ -663,7 +664,7 @@ type LogicalUnionAll struct { type LogicalSort struct { baseLogicalPlan - ByItems []*ByItems + ByItems []*util.ByItems } func (ls *LogicalSort) extractCorrelatedCols() []*expression.CorrelatedColumn { @@ -678,7 +679,7 @@ func (ls *LogicalSort) extractCorrelatedCols() []*expression.CorrelatedColumn { type LogicalTopN struct { baseLogicalPlan - ByItems []*ByItems + ByItems []*util.ByItems Offset uint64 Count uint64 } diff --git a/planner/core/physical_plan_test.go b/planner/core/physical_plan_test.go index b35449f2ec6ad..988e4774a11f9 100644 --- a/planner/core/physical_plan_test.go +++ b/planner/core/physical_plan_test.go @@ -15,6 +15,7 @@ package core_test import ( "context" + "fmt" . 
"github.com/pingcap/check" "github.com/pingcap/parser" @@ -616,7 +617,49 @@ func (s *testPlanSuite) TestUnmatchedTableInHint(c *C) { } } } +func (s *testPlanSuite) TestGroupConcatOrderby(c *C) { + var ( + input []string + output []struct { + SQL string + Plan []string + Result []string + } + ) + s.testData.GetTestCases(c, &input, &output) + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + tk := testkit.NewTestKit(c, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists test;") + tk.MustExec("create table test(id int, name int)") + tk.MustExec("insert into test values(1, 10);") + tk.MustExec("insert into test values(1, 20);") + tk.MustExec("insert into test values(1, 30);") + tk.MustExec("insert into test values(2, 20);") + tk.MustExec("insert into test values(3, 200);") + tk.MustExec("insert into test values(3, 500);") + + tk.MustExec("drop table if exists ptest;") + tk.MustExec("CREATE TABLE ptest (id int,name int) PARTITION BY RANGE ( id ) " + + "(PARTITION `p0` VALUES LESS THAN (2), PARTITION `p1` VALUES LESS THAN (11))") + tk.MustExec("insert into ptest select * from test;") + tk.MustExec(fmt.Sprintf("set session tidb_opt_agg_push_down = %v", 1)) + for i, ts := range input { + s.testData.OnRecord(func() { + output[i].SQL = ts + output[i].Plan = s.testData.ConvertRowsToStrings(tk.MustQuery("explain " + ts).Rows()) + output[i].Result = s.testData.ConvertRowsToStrings(tk.MustQuery(ts).Sort().Rows()) + }) + tk.MustQuery("explain " + ts).Check(testkit.Rows(output[i].Plan...)) + tk.MustQuery(ts).Check(testkit.Rows(output[i].Result...)) + } +} func (s *testPlanSuite) TestJoinHints(c *C) { defer testleak.AfterTest(c)() store, dom, err := newStoreWithBootstrap() diff --git a/planner/core/physical_plans.go b/planner/core/physical_plans.go index 1ab40ab606fea..e047679906a35 100644 --- a/planner/core/physical_plans.go +++ b/planner/core/physical_plans.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/property" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/table" @@ -201,7 +202,7 @@ type PhysicalProjection struct { type PhysicalTopN struct { basePhysicalPlan - ByItems []*ByItems + ByItems []*util.ByItems Offset uint64 Count uint64 } @@ -363,7 +364,7 @@ type PhysicalStreamAgg struct { type PhysicalSort struct { basePhysicalPlan - ByItems []*ByItems + ByItems []*util.ByItems } // NominalSort asks sort properties for its child. 
It is a fake operator that will not diff --git a/planner/core/plan.go b/planner/core/plan.go index 901b67b9bbc7a..0403a8ee1fdf9 100644 --- a/planner/core/plan.go +++ b/planner/core/plan.go @@ -22,6 +22,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/planner/property" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/stringutil" "github.com/pingcap/tipb/go-tipb" @@ -56,9 +57,9 @@ func enforceProperty(p *property.PhysicalProperty, tsk task, ctx sessionctx.Cont } tsk = finishCopTask(ctx, tsk) sortReqProp := &property.PhysicalProperty{TaskTp: property.RootTaskType, Items: p.Items, ExpectedCnt: math.MaxFloat64} - sort := PhysicalSort{ByItems: make([]*ByItems, 0, len(p.Items))}.Init(ctx, tsk.plan().statsInfo(), sortReqProp) + sort := PhysicalSort{ByItems: make([]*util.ByItems, 0, len(p.Items))}.Init(ctx, tsk.plan().statsInfo(), sortReqProp) for _, col := range p.Items { - sort.ByItems = append(sort.ByItems, &ByItems{col.Col, col.Desc}) + sort.ByItems = append(sort.ByItems, &util.ByItems{Expr: col.Col, Desc: col.Desc}) } return sort.attach2Task(tsk) } diff --git a/planner/core/property_cols_prune.go b/planner/core/property_cols_prune.go index 1ec2933485ec8..4c22ac39fab5f 100644 --- a/planner/core/property_cols_prune.go +++ b/planner/core/property_cols_prune.go @@ -15,6 +15,7 @@ package core import ( "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/planner/util" ) func (ds *DataSource) preparePossibleProperties() [][]*expression.Column { @@ -57,7 +58,7 @@ func (p *LogicalTopN) preparePossibleProperties() [][]*expression.Column { return [][]*expression.Column{propCols} } -func getPossiblePropertyFromByItems(items []*ByItems) []*expression.Column { +func getPossiblePropertyFromByItems(items []*util.ByItems) []*expression.Column { cols := make([]*expression.Column, 0, len(items)) for _, item := range items { if col, ok := item.Expr.(*expression.Column); ok { diff --git a/planner/core/resolve_indices.go b/planner/core/resolve_indices.go index 419b4b631845e..ee691e1f533f8 100644 --- a/planner/core/resolve_indices.go +++ b/planner/core/resolve_indices.go @@ -292,6 +292,12 @@ func (p *basePhysicalAgg) ResolveIndices() (err error) { return err } } + for _, byItem := range aggFun.OrderByItems { + byItem.Expr, err = byItem.Expr.ResolveIndices(p.children[0].Schema()) + if err != nil { + return err + } + } } for i, item := range p.GroupByItems { p.GroupByItems[i], err = item.ResolveIndices(p.children[0].Schema()) diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index caa6a8dd4fcfe..4f10b83d278e4 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -35,7 +35,10 @@ type aggregationPushDownSolver struct { // It's easy to see that max, min, first row is decomposable, no matter whether it's distinct, but sum(distinct) and // count(distinct) is not. // Currently we don't support avg and concat. -func (a *aggregationPushDownSolver) isDecomposable(fun *aggregation.AggFuncDesc) bool { +func (a *aggregationPushDownSolver) isDecomposableWithJoin(fun *aggregation.AggFuncDesc) bool { + if len(fun.OrderByItems) > 0 { + return false + } switch fun.Name { case ast.AggFuncAvg, ast.AggFuncGroupConcat: // TODO: Support avg push down. 
@@ -49,6 +52,22 @@ func (a *aggregationPushDownSolver) isDecomposable(fun *aggregation.AggFuncDesc) } } +func (a *aggregationPushDownSolver) isDecomposableWithUnion(fun *aggregation.AggFuncDesc) bool { + if len(fun.OrderByItems) > 0 { + return false + } + switch fun.Name { + case ast.AggFuncGroupConcat, ast.AggFuncVarPop: + return false + case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow: + return true + case ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncAvg: + return true + default: + return false + } +} + // getAggFuncChildIdx gets which children it belongs to, 0 stands for left, 1 stands for right, -1 stands for both. func (a *aggregationPushDownSolver) getAggFuncChildIdx(aggFunc *aggregation.AggFuncDesc, schema *expression.Schema) int { fromLeft, fromRight := false, false @@ -76,7 +95,7 @@ func (a *aggregationPushDownSolver) collectAggFuncs(agg *LogicalAggregation, joi valid = true leftChild := join.children[0] for _, aggFunc := range agg.AggFuncs { - if !a.isDecomposable(aggFunc) { + if !a.isDecomposableWithJoin(aggFunc) { return false, nil, nil } index := a.getAggFuncChildIdx(aggFunc, leftChild.Schema()) @@ -380,6 +399,11 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e projChild := proj.children[0] agg.SetChildren(projChild) } else if union, ok1 := child.(*LogicalUnionAll); ok1 { + for _, aggFunc := range agg.AggFuncs { + if !a.isDecomposableWithUnion(aggFunc) { + return p, nil + } + } var gbyCols []*expression.Column gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil) pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols, agg.aggHints) diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 595cd6d04e30e..c3e5338be1e52 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/types" ) @@ -120,6 +121,10 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) var selfUsedCols []*expression.Column for _, aggrFunc := range la.AggFuncs { selfUsedCols = expression.ExtractColumnsFromExpressions(selfUsedCols, aggrFunc.Args, nil) + + var cols []*expression.Column + aggrFunc.OrderByItems, cols = pruneByItems(aggrFunc.OrderByItems) + selfUsedCols = append(selfUsedCols, cols...) } if len(la.AggFuncs) == 0 { // If all the aggregate functions are pruned, we should add an aggregate function to keep the correctness. @@ -154,22 +159,32 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) return child.PruneColumns(selfUsedCols) } -// PruneColumns implements LogicalPlan interface. -func (ls *LogicalSort) PruneColumns(parentUsedCols []*expression.Column) error { - child := ls.children[0] - for i := len(ls.ByItems) - 1; i >= 0; i-- { - cols := expression.ExtractColumns(ls.ByItems[i].Expr) +func pruneByItems(old []*util.ByItems) (new []*util.ByItems, parentUsedCols []*expression.Column) { + new = make([]*util.ByItems, 0, len(old)) + for _, byItem := range old { + cols := expression.ExtractColumns(byItem.Expr) if len(cols) == 0 { - if !expression.IsRuntimeConstExpr(ls.ByItems[i].Expr) { - continue + if !expression.IsRuntimeConstExpr(byItem.Expr) { + new = append(new, byItem) } - ls.ByItems = append(ls.ByItems[:i], ls.ByItems[i+1:]...) 
- } else if ls.ByItems[i].Expr.GetType().Tp == mysql.TypeNull {
- ls.ByItems = append(ls.ByItems[:i], ls.ByItems[i+1:]...)
+ } else if byItem.Expr.GetType().Tp == mysql.TypeNull {
+ // do nothing; by-items of NULL type are filtered out
 } else {
 parentUsedCols = append(parentUsedCols, cols...)
+ new = append(new, byItem)
 }
 }
+ return
+}
+
+// PruneColumns implements the LogicalPlan interface.
+// If an expression can be viewed as a constant at execution time, such as a correlated column or a constant,
+// we prune it. Note that we cannot prune expressions that contain non-deterministic functions, such as rand().
+func (ls *LogicalSort) PruneColumns(parentUsedCols []*expression.Column) error {
+ child := ls.children[0]
+ var cols []*expression.Column
+ ls.ByItems, cols = pruneByItems(ls.ByItems)
+ parentUsedCols = append(parentUsedCols, cols...)
 return child.PruneColumns(parentUsedCols)
 }
diff --git a/planner/core/rule_inject_extra_projection.go b/planner/core/rule_inject_extra_projection.go
index a2bcbcf28e3eb..f4533c412a69d 100644
--- a/planner/core/rule_inject_extra_projection.go
+++ b/planner/core/rule_inject_extra_projection.go
@@ -19,6 +19,7 @@ import (
 "github.com/pingcap/parser/model"
 "github.com/pingcap/tidb/expression"
 "github.com/pingcap/tidb/expression/aggregation"
+ "github.com/pingcap/tidb/planner/util"
 "github.com/pingcap/tidb/sessionctx"
 )
 
@@ -79,6 +80,10 @@ func injectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDes
 _, isScalarFunc := arg.(*expression.ScalarFunction)
 hasScalarFunc = hasScalarFunc || isScalarFunc
 }
+ for _, byItem := range aggFuncs[i].OrderByItems {
+ _, isScalarFunc := byItem.Expr.(*expression.ScalarFunction)
+ hasScalarFunc = hasScalarFunc || isScalarFunc
+ }
 }
 for i := 0; !hasScalarFunc && i < len(groupByItems); i++ {
 _, isScalarFunc := groupByItems[i].(*expression.ScalarFunction)
@@ -108,6 +113,21 @@ func injectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDes
 f.Args[i] = newArg
 cursor++
 }
+ for _, byItem := range f.OrderByItems {
+ if _, isCnst := byItem.Expr.(*expression.Constant); isCnst {
+ continue
+ }
+ projExprs = append(projExprs, byItem.Expr)
+ newArg := &expression.Column{
+ UniqueID: aggPlan.context().GetSessionVars().AllocPlanColumnID(),
+ RetType: byItem.Expr.GetType(),
+ ColName: model.NewCIStr(fmt.Sprintf("col_%d", len(projSchemaCols))),
+ Index: cursor,
+ }
+ projSchemaCols = append(projSchemaCols, newArg)
+ byItem.Expr = newArg
+ cursor++
+ }
 }
 
 for i, item := range groupByItems {
@@ -146,7 +166,7 @@ func injectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDes
 // PhysicalTopN, some extra columns will be added into the schema of the
 // Projection, thus we need to add another Projection upon them to prune the
 // redundant columns.
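 //
 // An illustrative example (not part of the upstream comment): given
 //     SELECT * FROM t ORDER BY a + 1
 // the projection injected below the sort computes a + 1 once per row into
 // a fresh column, and the sort's by-item is rewritten to reference that
 // column, so the scalar function is not re-evaluated on every comparison.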
-func injectProjBelowSort(p PhysicalPlan, orderByItems []*ByItems) PhysicalPlan { +func injectProjBelowSort(p PhysicalPlan, orderByItems []*util.ByItems) PhysicalPlan { hasScalarFunc, numOrderByItems := false, len(orderByItems) for i := 0; !hasScalarFunc && i < numOrderByItems; i++ { _, isScalarFunc := orderByItems[i].Expr.(*expression.ScalarFunction) diff --git a/planner/core/rule_max_min_eliminate.go b/planner/core/rule_max_min_eliminate.go index 8ca1774d3e0ff..b67980c4bb712 100644 --- a/planner/core/rule_max_min_eliminate.go +++ b/planner/core/rule_max_min_eliminate.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/ranger" ) @@ -180,7 +181,7 @@ func (a *maxMinEliminator) eliminateSingleMaxMin(agg *LogicalAggregation) *Logic desc := f.Name == ast.AggFuncMax // Compose Sort operator. sort := LogicalSort{}.Init(ctx) - sort.ByItems = append(sort.ByItems, &ByItems{f.Args[0], desc}) + sort.ByItems = append(sort.ByItems, &util.ByItems{Expr: f.Args[0], Desc: desc}) sort.SetChildren(child) child = sort } diff --git a/planner/core/rule_topn_push_down.go b/planner/core/rule_topn_push_down.go index 1db0820ab77cc..06cc238326d54 100644 --- a/planner/core/rule_topn_push_down.go +++ b/planner/core/rule_topn_push_down.go @@ -18,6 +18,7 @@ import ( "github.com/cznic/mathutil" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/planner/util" ) // pushDownTopNOptimizer pushes down the topN or limit. In the future we will remove the limit from `requiredProperty` in CBO phase. @@ -95,7 +96,7 @@ func (p *LogicalUnionAll) pushDownTopN(topN *LogicalTopN) LogicalPlan { if topN != nil { newTopN = LogicalTopN{Count: topN.Count + topN.Offset}.Init(p.ctx) for _, by := range topN.ByItems { - newTopN.ByItems = append(newTopN.ByItems, &ByItems{by.Expr, by.Desc}) + newTopN.ByItems = append(newTopN.ByItems, &util.ByItems{Expr: by.Expr, Desc: by.Desc}) } } p.children[i] = child.pushDownTopN(newTopN) @@ -141,7 +142,7 @@ func (p *LogicalJoin) pushDownTopNToChild(topN *LogicalTopN, idx int) LogicalPla newTopN := LogicalTopN{ Count: topN.Count + topN.Offset, - ByItems: make([]*ByItems, len(topN.ByItems)), + ByItems: make([]*util.ByItems, len(topN.ByItems)), }.Init(topN.ctx) for i := range topN.ByItems { newTopN.ByItems[i] = topN.ByItems[i].Clone() diff --git a/planner/core/task.go b/planner/core/task.go index 0525109988977..60ea670ef1edd 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -25,6 +25,7 @@ import ( "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/property" + "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/plancodec" @@ -634,7 +635,7 @@ func (p *NominalSort) attach2Task(tasks ...task) task { } func (p *PhysicalTopN) getPushedDownTopN(childPlan PhysicalPlan) *PhysicalTopN { - newByItems := make([]*ByItems, 0, len(p.ByItems)) + newByItems := make([]*util.ByItems, 0, len(p.ByItems)) for _, expr := range p.ByItems { newByItems = append(newByItems, expr.Clone()) } diff --git a/planner/core/testdata/plan_suite_in.json b/planner/core/testdata/plan_suite_in.json index af45fa3e1c7c3..e4ed71ff2a998 100644 --- a/planner/core/testdata/plan_suite_in.json +++ b/planner/core/testdata/plan_suite_in.json @@ -406,6 +406,15 @@ "select /*+ TIDB_INLJ(t2) */ 
t1.b, t2.a from t2 t1, t2 t2 where t1.b=t2.b and t2.c=-1;" ] }, + { + "name": "TestGroupConcatOrderby", + "cases": [ + "select /*+ agg_to_cop */ group_concat(name ORDER BY name desc SEPARATOR '++'), group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from test;", + "select /*+ agg_to_cop */ group_concat(name ORDER BY name desc SEPARATOR '++'), group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from ptest;", + "select /*+ agg_to_cop */ group_concat(distinct name order by name desc) from test;", + "select /*+ agg_to_cop */ group_concat(distinct name order by name desc) from ptest;" + ] + }, { "name": "TestIndexJoinUnionScan", "cases": [ diff --git a/planner/core/testdata/plan_suite_out.json b/planner/core/testdata/plan_suite_out.json index a3803adc2ff09..af425d219f65f 100644 --- a/planner/core/testdata/plan_suite_out.json +++ b/planner/core/testdata/plan_suite_out.json @@ -990,6 +990,65 @@ } ] }, + { + "Name": "TestGroupConcatOrderby", + "Cases": [ + { + "SQL": "select /*+ agg_to_cop */ group_concat(name ORDER BY name desc SEPARATOR '++'), group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from test;", + "Plan": [ + "HashAgg_5 1.00 root funcs:group_concat(col_0 order by col_1 desc separator \"++\"), group_concat(col_2 order by col_3 desc, col_4 asc separator \"--\")", + "└─Projection_18 10000.00 root cast(test.test.name), test.test.name, cast(test.test.id), test.test.name, test.test.id", + " └─TableReader_11 10000.00 root data:TableScan_10", + " └─TableScan_10 10000.00 cop[tikv] table:test, range:[-inf,+inf], keep order:false, stats:pseudo" + ], + "Result": [ + "500++200++30++20++20++10 3--3--1--1--2--1" + ] + }, + { + "SQL": "select /*+ agg_to_cop */ group_concat(name ORDER BY name desc SEPARATOR '++'), group_concat(id ORDER BY name desc, id asc SEPARATOR '--') from ptest;", + "Plan": [ + "HashAgg_10 1.00 root funcs:group_concat(col_0 order by col_1 desc separator \"++\"), group_concat(col_2 order by col_3 desc, col_4 asc separator \"--\")", + "└─Projection_23 20000.00 root cast(test.ptest.name), test.ptest.name, cast(test.ptest.id), test.ptest.name, test.ptest.id", + " └─Union_13 20000.00 root ", + " ├─TableReader_15 10000.00 root data:TableScan_14", + " │ └─TableScan_14 10000.00 cop[tikv] table:ptest, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + " └─TableReader_17 10000.00 root data:TableScan_16", + " └─TableScan_16 10000.00 cop[tikv] table:ptest, partition:p1, range:[-inf,+inf], keep order:false, stats:pseudo" + ], + "Result": [ + "500++200++30++20++20++10 3--3--1--1--2--1" + ] + }, + { + "SQL": "select /*+ agg_to_cop */ group_concat(distinct name order by name desc) from test;", + "Plan": [ + "StreamAgg_8 1.00 root funcs:group_concat(distinct col_0 order by col_1 desc separator \",\")", + "└─Projection_18 10000.00 root cast(test.test.name), test.test.name", + " └─TableReader_15 10000.00 root data:TableScan_14", + " └─TableScan_14 10000.00 cop[tikv] table:test, range:[-inf,+inf], keep order:false, stats:pseudo" + ], + "Result": [ + "500,200,30,20,10" + ] + }, + { + "SQL": "select /*+ agg_to_cop */ group_concat(distinct name order by name desc) from ptest;", + "Plan": [ + "StreamAgg_12 1.00 root funcs:group_concat(distinct col_0 order by col_1 desc separator \",\")", + "└─Projection_23 20000.00 root cast(test.ptest.name), test.ptest.name", + " └─Union_18 20000.00 root ", + " ├─TableReader_20 10000.00 root data:TableScan_19", + " │ └─TableScan_19 10000.00 cop[tikv] table:ptest, partition:p0, range:[-inf,+inf], keep order:false, stats:pseudo", + 
" └─TableReader_22 10000.00 root data:TableScan_21", + " └─TableScan_21 10000.00 cop[tikv] table:ptest, partition:p1, range:[-inf,+inf], keep order:false, stats:pseudo" + ], + "Result": [ + "500,200,30,20,10" + ] + } + ] + }, { "Name": "TestIndexJoinUnionScan", "Cases": [ diff --git a/planner/util/byitem.go b/planner/util/byitem.go new file mode 100644 index 0000000000000..550bb93572cbe --- /dev/null +++ b/planner/util/byitem.go @@ -0,0 +1,45 @@ +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "fmt" + + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/sessionctx" +) + +// ByItems wraps a "by" item. +type ByItems struct { + Expr expression.Expression + Desc bool +} + +// String implements fmt.Stringer interface. +func (by *ByItems) String() string { + if by.Desc { + return fmt.Sprintf("%s true", by.Expr) + } + return by.Expr.String() +} + +// Clone makes a copy of ByItems. +func (by *ByItems) Clone() *ByItems { + return &ByItems{Expr: by.Expr.Clone(), Desc: by.Desc} +} + +// Equal checks whether two ByItems are equal. +func (by *ByItems) Equal(ctx sessionctx.Context, other *ByItems) bool { + return by.Expr.Equal(ctx, other.Expr) && by.Desc == other.Desc +} diff --git a/sessionctx/variable/varsutil.go b/sessionctx/variable/varsutil.go index 9348abc17d63e..7b0798c134fc7 100644 --- a/sessionctx/variable/varsutil.go +++ b/sessionctx/variable/varsutil.go @@ -22,6 +22,7 @@ import ( "sync/atomic" "time" + "github.com/cznic/mathutil" "github.com/pingcap/errors" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" @@ -314,9 +315,15 @@ func ValidateSetSystemVar(vars *SessionVars, name string, value string) (string, } return value, ErrWrongValueForVar.GenWithStackByArgs(name, value) case GroupConcatMaxLen: - // The reasonable range of 'group_concat_max_len' is 4~18446744073709551615(64-bit platforms) - // See https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_group_concat_max_len for details - return checkUInt64SystemVar(name, value, 4, math.MaxUint64, vars) + // https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_group_concat_max_len + // Minimum Value 4 + // Maximum Value (64-bit platforms) 18446744073709551615 + // Maximum Value (32-bit platforms) 4294967295 + maxLen := uint64(math.MaxUint64) + if mathutil.IntBits == 32 { + maxLen = uint64(math.MaxUint32) + } + return checkUInt64SystemVar(name, value, 4, maxLen, vars) case InteractiveTimeout: return checkUInt64SystemVar(name, value, 1, secondsPerYear, vars) case InnodbCommitConcurrency: diff --git a/types/datum.go b/types/datum.go index 4e17f14319af4..4fa5b4a608775 100644 --- a/types/datum.go +++ b/types/datum.go @@ -342,6 +342,50 @@ func (d *Datum) SetAutoID(id int64, flag uint) { } } +// String returns a human-readable description of Datum. It is intended only for debugging. 
+func (d Datum) String() string { + var t string + switch d.k { + case KindNull: + t = "KindNull" + case KindInt64: + t = "KindInt64" + case KindUint64: + t = "KindUint64" + case KindFloat32: + t = "KindFloat32" + case KindFloat64: + t = "KindFloat64" + case KindString: + t = "KindString" + case KindBytes: + t = "KindBytes" + case KindMysqlDecimal: + t = "KindMysqlDecimal" + case KindMysqlDuration: + t = "KindMysqlDuration" + case KindMysqlEnum: + t = "KindMysqlEnum" + case KindBinaryLiteral: + t = "KindBinaryLiteral" + case KindMysqlBit: + t = "KindMysqlBit" + case KindMysqlSet: + t = "KindMysqlSet" + case KindMysqlJSON: + t = "KindMysqlJSON" + case KindMysqlTime: + t = "KindMysqlTime" + default: + t = "Unknown" + } + v := d.GetValue() + if b, ok := v.([]byte); ok && d.k == KindBytes { + v = string(b) + } + return fmt.Sprintf("%v %v", t, v) +} + // GetValue gets the value of the datum of any kind. func (d *Datum) GetValue() interface{} { switch d.k { diff --git a/types/datum_test.go b/types/datum_test.go index d525082990ad7..4e2d464ad0eda 100644 --- a/types/datum_test.go +++ b/types/datum_test.go @@ -14,6 +14,7 @@ package types import ( + "fmt" "reflect" "testing" "time" @@ -47,6 +48,7 @@ func (ts *testDatumSuite) TestDatum(c *C) { d.SetCollation(d.Collation()) c.Assert(d.Collation(), NotNil) c.Assert(d.Length(), Equals, int(d.length)) + c.Assert(fmt.Sprint(d), Equals, d.String()) } }
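Taken together, the feature is easiest to see end to end. A minimal testkit-style fragment in the spirit of the tests above (the table, its data, and the expected rows are illustrative, not taken from the patch):

    tk := testkit.NewTestKit(c, store)
    tk.MustExec("use test")
    tk.MustExec("create table t(id int, name varchar(20))")
    tk.MustExec("insert into t values (1, 'b'), (1, 'a'), (2, 'c')")
    // Each group is concatenated in the order given by the aggregate's own
    // ORDER BY clause, independent of the input row order; the result is
    // still capped by group_concat_max_len.
    tk.MustQuery("select group_concat(name order by name desc separator '|') from t group by id").
        Sort().Check(testkit.Rows("b|a", "c"))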