executor: support GROUP_CONCAT(ORDER BY) (#16591) (#17184)
SunRunAway authored May 14, 2020
1 parent b8219e6 commit 4a350fe
Showing 38 changed files with 1,028 additions and 99 deletions.
4 changes: 2 additions & 2 deletions cmd/explaintest/r/explain.result
@@ -28,13 +28,13 @@ set session tidb_hashagg_partial_concurrency = 1;
set session tidb_hashagg_final_concurrency = 1;
explain select group_concat(a) from t group by id;
id count task operator info
StreamAgg_8 8000.00 root group by:col_1, funcs:group_concat(col_0, ",")
StreamAgg_8 8000.00 root group by:col_1, funcs:group_concat(col_0 separator ",")
└─Projection_18 10000.00 root cast(test.t.a), test.t.id
└─TableReader_15 10000.00 root data:TableScan_14
└─TableScan_14 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:true, stats:pseudo
explain select group_concat(a, b) from t group by id;
id count task operator info
StreamAgg_8 8000.00 root group by:col_2, funcs:group_concat(col_0, col_1, ",")
StreamAgg_8 8000.00 root group by:col_2, funcs:group_concat(col_0, col_1 separator ",")
└─Projection_18 10000.00 root cast(test.t.a), cast(test.t.b), test.t.id
└─TableReader_15 10000.00 root data:TableScan_14
└─TableScan_14 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:true, stats:pseudo
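The updated plan output prints the separator explicitly (separator ","), and the feature behind this commit is GROUP_CONCAT with an ORDER BY clause. The sketch below is an end-to-end illustration only, assuming a local TiDB listening on 127.0.0.1:4000 and an existing test.t(id, a) table; the DSN, schema, and driver choice are not part of this commit.

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql" // MySQL-protocol driver; TiDB speaks the same protocol
)

func main() {
	// Assumed setup: TiDB on 127.0.0.1:4000 with a table test.t(id, a).
	db, err := sql.Open("mysql", "root@tcp(127.0.0.1:4000)/test")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	rows, err := db.Query(
		"SELECT GROUP_CONCAT(a ORDER BY a DESC SEPARATOR ',') FROM t GROUP BY id")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		var concatenated sql.NullString // NULL when a group has no non-NULL values
		if err := rows.Scan(&concatenated); err != nil {
			log.Fatal(err)
		}
		fmt.Println(concatenated.String)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
```

Each returned row is the per-group concatenation of a, sorted descending and joined with the chosen separator.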
145 changes: 135 additions & 10 deletions executor/aggfuncs/aggfunc_test.go
@@ -25,6 +25,7 @@ import (
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/types/json"
@@ -68,6 +69,7 @@ type aggTest struct {
dataGen func(i int) types.Datum
funcName string
results []types.Datum
orderBy bool
}

type multiArgsAggTest struct {
@@ -77,6 +79,7 @@ type multiArgsAggTest struct {
dataGens []func(i int) types.Datum
funcName string
results []types.Datum
orderBy bool
}

func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
@@ -93,6 +96,11 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
}
desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
c.Assert(err, IsNil)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
partialDesc, finalDesc := desc.Split([]int{0, 1})

// build partial func for partial phase.
@@ -112,7 +120,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
dt := resultChk.GetRow(0).GetDatum(0, p.dataType)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0]))

err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
c.Assert(err, IsNil)
@@ -128,7 +136,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1]))
err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
c.Assert(err, IsNil)

@@ -139,7 +147,7 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[2]))
}

func buildAggTester(funcName string, tp byte, numRows int, results ...interface{}) aggTest {
@@ -159,6 +167,77 @@ func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows i
return pt
}

func (s *testSuite) testMultiArgsMergePartialResult(c *C, p multiArgsAggTest) {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
iter := chunk.NewIterator4Chunk(srcChk)

args := make([]expression.Expression, len(p.dataTypes))
for k := 0; k < len(p.dataTypes); k++ {
args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k}
}

desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
c.Assert(err, IsNil)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
partialDesc, finalDesc := desc.Split([]int{0, 1})

// build partial func for partial phase.
partialFunc := aggfuncs.Build(s.ctx, partialDesc, 0)
partialResult := partialFunc.AllocPartialResult()

// build final func for final phase.
finalFunc := aggfuncs.Build(s.ctx, finalDesc, 0)
finalPr := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.retType}, 1)

// update partial result.
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt := resultChk.GetRow(0).GetDatum(0, p.retType)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)

err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
c.Assert(err, IsNil)
partialFunc.ResetPartialResult(partialResult)

iter.Begin()
iter.Next()
for row := iter.Next(); row != iter.End(); row = iter.Next() {
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
}
resultChk.Reset()
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
c.Assert(err, IsNil)

resultChk.Reset()
err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
c.Assert(err, IsNil)

dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
}

// for multiple args in aggfuncs such as json_objectagg(c1, c2)
func buildMultiArgsAggTester(funcName string, tps []byte, rt byte, numRows int, results ...interface{}) multiArgsAggTest {
fts := make([]*types.FieldType, len(tps))
@@ -234,6 +313,11 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
}
desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
c.Assert(err, IsNil)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(s.ctx, desc, 0)
finalPr := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
@@ -246,7 +330,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1]))

// test the empty input
resultChk.Reset()
@@ -255,11 +339,16 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0]))

// test the agg func with distinct
desc, err = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true)
c.Assert(err, IsNil)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(s.ctx, desc, 0)
finalPr = finalFunc.AllocPartialResult()

@@ -275,7 +364,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1]))

// test the empty input
resultChk.Reset()
Expand All @@ -284,7 +373,7 @@ func (s *testSuite) testAggFunc(c *C, p aggTest) {
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0]))
}

func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
@@ -301,9 +390,17 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
for k := 0; k < len(p.dataTypes); k++ {
args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k}
}
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)})
}

desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
c.Assert(err, IsNil)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(s.ctx, desc, 0)
finalPr := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
@@ -316,7 +413,7 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1]))

// test the empty input
resultChk.Reset()
Expand All @@ -325,11 +422,16 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0]))

// test the agg func with distinct
desc, err = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true)
c.Assert(err, IsNil)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(s.ctx, desc, 0)
finalPr = finalFunc.AllocPartialResult()

@@ -345,7 +447,7 @@ func (s *testSuite) testMultiArgsAggFunc(c *C, p multiArgsAggTest) {
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
c.Assert(err, IsNil)
c.Assert(result, Equals, 0)
c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1]))

// test the empty input
resultChk.Reset()
@@ -373,6 +475,11 @@ func (s *testSuite) benchmarkAggFunc(b *testing.B, p aggTest) {
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(s.ctx, desc, 0)
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
iter := chunk.NewIterator4Chunk(srcChk)
@@ -388,6 +495,11 @@ func (s *testSuite) benchmarkAggFunc(b *testing.B, p aggTest) {
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(s.ctx, desc, 0)
resultChk.Reset()
b.Run(fmt.Sprintf("%v(distinct)/%v", p.funcName, p.dataType), func(b *testing.B) {
@@ -409,11 +521,19 @@ func (s *testSuite) benchmarkMultiArgsAggFunc(b *testing.B, p multiArgsAggTest)
for k := 0; k < len(p.dataTypes); k++ {
args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k}
}
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)})
}

desc, err := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(s.ctx, desc, 0)
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
iter := chunk.NewIterator4Chunk(srcChk)
@@ -429,6 +549,11 @@ func (s *testSuite) benchmarkMultiArgsAggFunc(b *testing.B, p multiArgsAggTest)
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(s.ctx, desc, 0)
resultChk.Reset()
b.Run(fmt.Sprintf("%v(distinct)/%v", p.funcName, p.dataTypes), func(b *testing.B) {
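The orderBy flag added to aggTest and multiArgsAggTest above is what each harness function translates into desc.OrderByItems (a descending sort on the first argument) before building the aggregate. The following sketch shows how a case could opt in; the wrapper name is hypothetical, and the committed tests may simply set the field on an existing tester.

```go
// Hypothetical wrapper around buildAggTester (defined earlier in this file):
// it flips the new orderBy flag, which the test harness converts into
// OrderByItems with {Expr: args[0], Desc: true} on the aggregate description.
func buildAggTesterWithOrderBy(funcName string, tp byte, numRows int, results ...interface{}) aggTest {
	pt := buildAggTester(funcName, tp, numRows, results...)
	pt.orderBy = true
	return pt
}
```

A group_concat case built this way is then expected to produce its per-group values in descending order of the first argument.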
24 changes: 18 additions & 6 deletions executor/aggfuncs/builder.go
@@ -313,10 +313,6 @@ func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDe
case aggregation.DedupMode:
return nil
default:
base := baseAggFunc{
args: aggFuncDesc.Args[:len(aggFuncDesc.Args)-1],
ordinal: ordinal,
}
// The last arg is promised to be a not-null string constant, so the error can be ignored.
c, _ := aggFuncDesc.Args[len(aggFuncDesc.Args)-1].(*expression.Constant)
sep, _, err := c.EvalString(nil, chunk.Row{})
@@ -335,10 +331,26 @@ func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDe
panic(fmt.Sprintf("Error happened when buildGroupConcat: %s", err.Error()))
}
var truncated int32
base := baseGroupConcat4String{
baseAggFunc: baseAggFunc{
args: aggFuncDesc.Args[:len(aggFuncDesc.Args)-1],
ordinal: ordinal,
},
byItems: aggFuncDesc.OrderByItems,
sep: sep,
maxLen: maxLen,
truncated: &truncated,
}
if aggFuncDesc.HasDistinct {
return &groupConcatDistinct{baseGroupConcat4String{baseAggFunc: base, sep: sep, maxLen: maxLen, truncated: &truncated}}
if len(aggFuncDesc.OrderByItems) > 0 {
return &groupConcatDistinctOrder{base}
}
return &groupConcatDistinct{base}
}
if len(aggFuncDesc.OrderByItems) > 0 {
return &groupConcatOrder{base}
}
return &groupConcat{baseGroupConcat4String{baseAggFunc: base, sep: sep, maxLen: maxLen, truncated: &truncated}}
return &groupConcat{base}
}
}
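The builder change above only chooses among four implementations (groupConcat, groupConcatOrder, groupConcatDistinct, groupConcatDistinctOrder); the ordered variants themselves live in files not included in this excerpt. As a rough mental model only, not TiDB's actual aggfuncs code, an ORDER BY-aware GROUP_CONCAT can buffer each value together with its sort key, sort at finalization time, join with the separator, and truncate to the configured maximum length:

```go
package main

import (
	"fmt"
	"sort"
	"strings"
)

// orderedGroupConcat illustrates the idea only: values are buffered with
// their ORDER BY key, sorted when the final result is produced, then joined
// with the separator and truncated.
type orderedGroupConcat struct {
	sep    string
	maxLen int // analogous to group_concat_max_len
	desc   bool
	items  []item
}

type item struct {
	key string // evaluated ORDER BY expression
	val string // evaluated GROUP_CONCAT argument(s)
}

func (g *orderedGroupConcat) update(key, val string) {
	g.items = append(g.items, item{key: key, val: val})
}

func (g *orderedGroupConcat) result() string {
	sort.SliceStable(g.items, func(i, j int) bool {
		if g.desc {
			return g.items[i].key > g.items[j].key
		}
		return g.items[i].key < g.items[j].key
	})
	vals := make([]string, 0, len(g.items))
	for _, it := range g.items {
		vals = append(vals, it.val)
	}
	s := strings.Join(vals, g.sep)
	if g.maxLen > 0 && len(s) > g.maxLen {
		s = s[:g.maxLen] // the real executor also records a truncation warning
	}
	return s
}

func main() {
	g := &orderedGroupConcat{sep: ",", maxLen: 1024, desc: true}
	for _, v := range []string{"10", "30", "20"} {
		g.update(v, v)
	}
	fmt.Println(g.result()) // 30,20,10
}
```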

Expand Down
Loading

0 comments on commit 4a350fe

Please sign in to comment.