diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index 515be967f8085..312ff9f7cad81 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -151,6 +151,9 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { finalFunc := aggfuncs.Build(s.ctx, finalDesc, 0) finalPr := finalFunc.AllocPartialResult() resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, 1) + if p.funcName == ast.AggFuncApproxCountDistinct { + resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeString)}, 1) + } // update partial result. for row := iter.Begin(); row != iter.End(); row = iter.Next() { @@ -159,6 +162,9 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { p.messUpChunk(srcChk) partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) dt := resultChk.GetRow(0).GetDatum(0, p.dataType) + if p.funcName == ast.AggFuncApproxCountDistinct { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString)) + } result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) c.Assert(err, IsNil) c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[0])) @@ -178,17 +184,26 @@ func (s *testSuite) testMergePartialResult(c *C, p aggTest) { resultChk.Reset() partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) dt = resultChk.GetRow(0).GetDatum(0, p.dataType) + if p.funcName == ast.AggFuncApproxCountDistinct { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString)) + } result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) c.Assert(err, IsNil) c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[1])) err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr) c.Assert(err, IsNil) + if p.funcName == ast.AggFuncApproxCountDistinct { + resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) + } resultChk.Reset() err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) c.Assert(err, IsNil) dt = resultChk.GetRow(0).GetDatum(0, p.dataType) + if p.funcName == ast.AggFuncApproxCountDistinct { + dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeLonglong)) + } result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2]) c.Assert(err, IsNil) c.Assert(result, Equals, 0, Commentf("%v != %v", dt.String(), p.results[2])) diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index a3270a7da4d3a..26126a9208f41 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -39,6 +39,12 @@ var ( _ AggFunc = (*countOriginalWithDistinct4String)(nil) _ AggFunc = (*countOriginalWithDistinct)(nil) + // All the AggFunc implementations for "APPROX_COUNT_DISTINCT" are listed here. + _ AggFunc = (*approxCountDistinctOriginal)(nil) + _ AggFunc = (*approxCountDistinctPartial1)(nil) + _ AggFunc = (*approxCountDistinctPartial2)(nil) + _ AggFunc = (*approxCountDistinctFinal)(nil) + // All the AggFunc implementations for "FIRSTROW" are listed here. _ AggFunc = (*firstRow4Decimal)(nil) _ AggFunc = (*firstRow4Int)(nil) diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index 1eea035c0553b..799dc9d9e4d34 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -55,6 +55,8 @@ func Build(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal return buildVarPop(aggFuncDesc, ordinal) case ast.AggFuncJsonObjectAgg: return buildJSONObjectAgg(aggFuncDesc, ordinal) + case ast.AggFuncApproxCountDistinct: + return buildApproxCountDistinct(aggFuncDesc, ordinal) } return nil } @@ -89,6 +91,39 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag } } +func buildApproxCountDistinct(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { + base := baseApproxCountDistinct{baseAggFunc{ + args: aggFuncDesc.Args, + ordinal: ordinal, + }} + + // In partition table, union need to compute partial result into partial result. + // We can detect and handle this case by checking whether return type is string. + + switch aggFuncDesc.RetTp.Tp { + case mysql.TypeLonglong: + switch aggFuncDesc.Mode { + case aggregation.CompleteMode: + return &approxCountDistinctOriginal{base} + case aggregation.Partial1Mode: + return &approxCountDistinctPartial1{approxCountDistinctOriginal{base}} + case aggregation.Partial2Mode: + return &approxCountDistinctPartial2{approxCountDistinctPartial1{approxCountDistinctOriginal{base}}} + case aggregation.FinalMode: + return &approxCountDistinctFinal{approxCountDistinctPartial2{approxCountDistinctPartial1{approxCountDistinctOriginal{base}}}} + } + case mysql.TypeString: + switch aggFuncDesc.Mode { + case aggregation.CompleteMode, aggregation.Partial1Mode: + return &approxCountDistinctPartial1{approxCountDistinctOriginal{base}} + case aggregation.Partial2Mode, aggregation.FinalMode: + return &approxCountDistinctPartial2{approxCountDistinctPartial1{approxCountDistinctOriginal{base}}} + } + } + + return nil +} + // buildCount builds the AggFunc implementation for function "COUNT". func buildCount(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { // If mode is DedupMode, we return nil for not implemented. diff --git a/executor/aggfuncs/func_count_distinct.go b/executor/aggfuncs/func_count_distinct.go index 6dc180f05b16a..b7b883d4c05a0 100644 --- a/executor/aggfuncs/func_count_distinct.go +++ b/executor/aggfuncs/func_count_distinct.go @@ -15,8 +15,10 @@ package aggfuncs import ( "encoding/binary" + "math" "unsafe" + "github.com/dgryski/go-farm" "github.com/pingcap/errors" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/sessionctx" @@ -302,7 +304,7 @@ func (e *countOriginalWithDistinct) UpdatePartialResult(sctx sessionctx.Context, encodedBytes = encodedBytes[:0] for i := 0; i < len(e.args) && !hasNull; i++ { - encodedBytes, isNull, err = e.evalAndEncode(sctx, e.args[i], row, buf, encodedBytes) + encodedBytes, isNull, err = evalAndEncode(sctx, e.args[i], row, buf, encodedBytes) if err != nil { return } @@ -322,7 +324,7 @@ func (e *countOriginalWithDistinct) UpdatePartialResult(sctx sessionctx.Context, } // evalAndEncode eval one row with an expression and encode value to bytes. -func (e *countOriginalWithDistinct) evalAndEncode( +func evalAndEncode( sctx sessionctx.Context, arg expression.Expression, row chunk.Row, buf, encodedBytes []byte, ) (_ []byte, isNull bool, err error) { @@ -433,3 +435,388 @@ func appendJSON(encodedBytes, _ []byte, val json.BinaryJSON) []byte { encodedBytes = append(encodedBytes, val.Value...) return encodedBytes } + +func intHash64(x uint64) uint64 { + x ^= x >> 33 + x *= 0xff51afd7ed558ccd + x ^= x >> 33 + x *= 0xc4ceb9fe1a85ec53 + x ^= x >> 33 + return x +} + +type baseApproxCountDistinct struct { + baseAggFunc +} + +const ( + // The maximum degree of buffer size before the values are discarded + uniquesHashMaxSizeDegree uint8 = 17 + // The maximum number of elements before the values are discarded + uniquesHashMaxSize = uint32(1) << (uniquesHashMaxSizeDegree - 1) + // Initial buffer size degree + uniquesHashSetInitialSizeDegree uint8 = 4 + // The number of least significant bits used for thinning. The remaining high-order bits are used to determine the position in the hash table. + uniquesHashBitsForSkip = 32 - uniquesHashMaxSizeDegree +) + +type approxCountDistinctHashValue uint32 + +// partialResult4ApproxCountDistinct use `BJKST` algorithm to compute approximate result of count distinct. +// According to an experimental survey http://www.vldb.org/pvldb/vol11/p499-harmouch.pdf, the error guarantee of BJKST +// was even better than the theoretical lower bounds. +// For the calculation state, it uses a sample of element hash values with a size up to uniquesHashMaxSize. Compared +// with the widely known HyperLogLog algorithm, this algorithm is less effective in terms of accuracy and +// memory consumption (even up to proportionality), but it is adaptive. This means that with fairly high accuracy, it +// consumes less memory during simultaneous computation of cardinality for a large number of data sets whose cardinality +// has power law distribution (i.e. in cases when most of the data sets are small). +// This algorithm is also very accurate for data sets with small cardinality and very efficient on CPU. If number of +// distinct element is more than 2^32, relative error may be high. +type partialResult4ApproxCountDistinct struct { + size uint32 /// Number of elements. + sizeDegree uint8 /// The size of the table as a power of 2. + skipDegree uint8 /// Skip elements not divisible by 2 ^ skipDegree. + hasZero bool /// The hash table contains an element with a hash value of 0. + buf []approxCountDistinctHashValue +} + +// NewPartialResult4ApproxCountDistinct builds a partial result for agg function ApproxCountDistinct. +func NewPartialResult4ApproxCountDistinct() *partialResult4ApproxCountDistinct { + p := &partialResult4ApproxCountDistinct{} + p.reset() + return p +} + +func (p *partialResult4ApproxCountDistinct) InsertHash64(x uint64) { + // no need to rehash, just cast into uint32 + p.insertHash(approxCountDistinctHashValue(x)) +} + +func (p *partialResult4ApproxCountDistinct) alloc(newSizeDegree uint8) { + p.size = 0 + p.skipDegree = 0 + p.hasZero = false + p.buf = make([]approxCountDistinctHashValue, uint32(1)< b { + return a + } + + return b +} + +func (p *partialResult4ApproxCountDistinct) bufSize() uint32 { + return uint32(1) << p.sizeDegree +} + +func (p *partialResult4ApproxCountDistinct) mask() uint32 { + return p.bufSize() - 1 +} + +func (p *partialResult4ApproxCountDistinct) place(x approxCountDistinctHashValue) uint32 { + return uint32(x>>uniquesHashBitsForSkip) & p.mask() +} + +// Increase the size of the buffer 2 times or up to new size degree. +func (p *partialResult4ApproxCountDistinct) resize(newSizeDegree uint8) { + oldSize := p.bufSize() + oldBuf := p.buf + + if 0 == newSizeDegree { + newSizeDegree = p.sizeDegree + 1 + } + + p.buf = make([]approxCountDistinctHashValue, uint32(1)< p.skipDegree { + p.skipDegree = rhsSkipDegree + p.rehash() + } + + rb, rhsSize, err := codec.DecodeUvarint(rb) + + if err != nil { + return err + } + + if rhsSize > uint64(uniquesHashMaxSize) { + return errors.New("Cannot read partialResult4ApproxCountDistinct: too large size degree") + } + + if p.bufSize() < uint32(rhsSize) { + newSizeDegree := max(uniquesHashSetInitialSizeDegree, uint8(math.Log2(float64(rhsSize-1)))+2) + p.resize(newSizeDegree) + } + + for i := uint32(0); i < uint32(rhsSize); i++ { + x := *(*approxCountDistinctHashValue)(unsafe.Pointer(&rb[0])) + rb = rb[4:] + p.insertHash(x) + } + + return err +} + +// Correct system errors due to collisions during hashing in uint32. +func (p *partialResult4ApproxCountDistinct) fixedSize() uint64 { + if 0 == p.skipDegree { + return uint64(p.size) + } + + res := uint64(p.size) * (uint64(1) << p.skipDegree) + + // Pseudo-random remainder. + res += intHash64(uint64(p.size)) & ((uint64(1) << p.skipDegree) - 1) + + // When different elements randomly scattered across 2^32 buckets, filled buckets with average of `res` obtained. + p32 := uint64(1) << 32 + fixedRes := math.Round(float64(p32) * (math.Log(float64(p32)) - math.Log(float64(p32-res)))) + return uint64(fixedRes) +} + +func (p *partialResult4ApproxCountDistinct) insertHash(hashValue approxCountDistinctHashValue) { + if !p.good(hashValue) { + return + } + + p.insertImpl(hashValue) + p.shrinkIfNeed() +} + +// The value is divided by 2 ^ skip_degree +func (p *partialResult4ApproxCountDistinct) good(hash approxCountDistinctHashValue) bool { + return hash == ((hash >> p.skipDegree) << p.skipDegree) +} + +// Insert a value +func (p *partialResult4ApproxCountDistinct) insertImpl(x approxCountDistinctHashValue) { + if x == 0 { + if !p.hasZero { + p.size += 1 + } + p.hasZero = true + return + } + + placeValue := p.place(x) + for p.buf[placeValue] != 0 && p.buf[placeValue] != x { + placeValue++ + placeValue &= p.mask() + } + + if p.buf[placeValue] == x { + return + } + + p.buf[placeValue] = x + p.size++ +} + +// If the hash table is full enough, then do resize. +// If there are too many items, then throw half the pieces until they are small enough. +func (p *partialResult4ApproxCountDistinct) shrinkIfNeed() { + if p.size > p.maxFill() { + if p.size > uniquesHashMaxSize { + for p.size > uniquesHashMaxSize { + p.skipDegree++ + p.rehash() + } + } else { + p.resize(0) + } + } +} + +func (p *partialResult4ApproxCountDistinct) maxFill() uint32 { + return uint32(1) << (p.sizeDegree - 1) +} + +// Delete all values whose hashes do not divide by 2 ^ skip_degree +func (p *partialResult4ApproxCountDistinct) rehash() { + for i := uint32(0); i < p.bufSize(); i++ { + if p.buf[i] != 0 && !p.good(p.buf[i]) { + p.buf[i] = 0 + p.size-- + } + } + + for i := uint32(0); i < p.bufSize(); i++ { + if p.buf[i] != 0 && i != p.place(p.buf[i]) { + x := p.buf[i] + p.buf[i] = 0 + p.reinsertImpl(x) + } + } +} + +// Insert a value into the new buffer that was in the old buffer. +// Used when increasing the size of the buffer, as well as when reading from a file. +func (p *partialResult4ApproxCountDistinct) reinsertImpl(x approxCountDistinctHashValue) { + placeValue := p.place(x) + for p.buf[placeValue] != 0 { + placeValue++ + placeValue &= p.mask() + } + + p.buf[placeValue] = x +} + +func (p *partialResult4ApproxCountDistinct) merge(tar *partialResult4ApproxCountDistinct) { + if tar.skipDegree > p.skipDegree { + p.skipDegree = tar.skipDegree + p.rehash() + } + + if !p.hasZero && tar.hasZero { + p.hasZero = true + p.size++ + p.shrinkIfNeed() + } + + for i := uint32(0); i < tar.bufSize(); i++ { + if tar.buf[i] != 0 && p.good(tar.buf[i]) { + p.insertImpl(tar.buf[i]) + p.shrinkIfNeed() + } + } +} + +func (p *partialResult4ApproxCountDistinct) Serialize() []byte { + var buf [4]byte + res := make([]byte, 0, 1+binary.MaxVarintLen64+p.size*4) + + res = append(res, p.skipDegree) + res = codec.EncodeUvarint(res, uint64(p.size)) + + if p.hasZero { + binary.LittleEndian.PutUint32(buf[:], 0) + res = append(res, buf[:]...) + } + + for i := uint32(0); i < p.bufSize(); i++ { + if p.buf[i] != 0 { + binary.LittleEndian.PutUint32(buf[:], uint32(p.buf[i])) + res = append(res, buf[:]...) + } + } + return res +} + +func (e *baseApproxCountDistinct) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4ApproxCountDistinct)(pr) + chk.AppendInt64(e.ordinal, int64(p.fixedSize())) + return nil +} + +func (e *baseApproxCountDistinct) AllocPartialResult() PartialResult { + return (PartialResult)(NewPartialResult4ApproxCountDistinct()) +} + +func (e *baseApproxCountDistinct) ResetPartialResult(pr PartialResult) { + p := (*partialResult4ApproxCountDistinct)(pr) + p.reset() +} + +func (e *baseApproxCountDistinct) MergePartialResult(sctx sessionctx.Context, src PartialResult, dst PartialResult) error { + p1, p2 := (*partialResult4ApproxCountDistinct)(src), (*partialResult4ApproxCountDistinct)(dst) + p2.merge(p1) + return nil +} + +type approxCountDistinctOriginal struct { + baseApproxCountDistinct +} + +func (e *approxCountDistinctOriginal) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (err error) { + p := (*partialResult4ApproxCountDistinct)(pr) + encodedBytes := make([]byte, 0) + // Decimal struct is the biggest type we will use. + buf := make([]byte, types.MyDecimalStructSize) + + for _, row := range rowsInGroup { + var hasNull, isNull bool + encodedBytes = encodedBytes[:0] + + for i := 0; i < len(e.args) && !hasNull; i++ { + encodedBytes, isNull, err = evalAndEncode(sctx, e.args[i], row, buf, encodedBytes) + if err != nil { + return + } + if isNull { + hasNull = true + break + } + } + if hasNull { + continue + } + + x := farm.Hash64(encodedBytes) + p.InsertHash64(x) + } + + return nil +} + +type approxCountDistinctPartial1 struct { + approxCountDistinctOriginal +} + +func (e *approxCountDistinctPartial1) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + p := (*partialResult4ApproxCountDistinct)(pr) + chk.AppendBytes(e.ordinal, p.Serialize()) + return nil +} + +type approxCountDistinctPartial2 struct { + approxCountDistinctPartial1 +} + +func (e *approxCountDistinctPartial2) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error { + p := (*partialResult4ApproxCountDistinct)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalString(sctx, row) + if err != nil { + return err + } + + if isNull { + continue + } + + err = p.readAndMerge(hack.Slice(input)) + if err != nil { + return err + } + } + return nil +} + +type approxCountDistinctFinal struct { + approxCountDistinctPartial2 +} + +func (e *approxCountDistinctFinal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { + return e.baseApproxCountDistinct.AppendFinalResult2Chunk(sctx, pr, chk) +} diff --git a/executor/aggfuncs/func_count_test.go b/executor/aggfuncs/func_count_test.go index f647b98467025..2036b0145d5ee 100644 --- a/executor/aggfuncs/func_count_test.go +++ b/executor/aggfuncs/func_count_test.go @@ -14,16 +14,33 @@ package aggfuncs_test import ( + "encoding/binary" "testing" + "github.com/dgryski/go-farm" . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/executor/aggfuncs" ) +func genApproxDistinctMergePartialResult(begin, end uint64) string { + o := aggfuncs.NewPartialResult4ApproxCountDistinct() + encodedBytes := make([]byte, 8) + for i := begin; i < end; i++ { + binary.LittleEndian.PutUint64(encodedBytes, i) + x := farm.Hash64(encodedBytes) + o.InsertHash64(x) + } + return string(o.Serialize()) +} + func (s *testSuite) TestMergePartialResult4Count(c *C) { tester := buildAggTester(ast.AggFuncCount, mysql.TypeLonglong, 5, 5, 3, 8) s.testMergePartialResult(c, tester) + + tester = buildAggTester(ast.AggFuncApproxCountDistinct, mysql.TypeLonglong, 5, genApproxDistinctMergePartialResult(0, 5), genApproxDistinctMergePartialResult(2, 5), 5) + s.testMergePartialResult(c, tester) } func (s *testSuite) TestCount(c *C) { @@ -53,6 +70,35 @@ func (s *testSuite) TestCount(c *C) { for _, test := range tests2 { s.testMultiArgsAggFunc(c, test) } + + tests3 := []aggTest{ + buildAggTester(ast.AggFuncCount, mysql.TypeLonglong, 5, 0, 5), + buildAggTester(ast.AggFuncCount, mysql.TypeFloat, 5, 0, 5), + buildAggTester(ast.AggFuncCount, mysql.TypeDouble, 5, 0, 5), + buildAggTester(ast.AggFuncCount, mysql.TypeNewDecimal, 5, 0, 5), + buildAggTester(ast.AggFuncCount, mysql.TypeString, 5, 0, 5), + buildAggTester(ast.AggFuncCount, mysql.TypeDate, 5, 0, 5), + buildAggTester(ast.AggFuncCount, mysql.TypeDuration, 5, 0, 5), + buildAggTester(ast.AggFuncCount, mysql.TypeJSON, 5, 0, 5), + } + for _, test := range tests3 { + s.testAggFunc(c, test) + } + + tests4 := []multiArgsAggTest{ + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeLonglong, mysql.TypeLonglong}, mysql.TypeLonglong, 5, 0, 5), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeFloat, mysql.TypeFloat}, mysql.TypeLonglong, 5, 0, 5), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeDouble, mysql.TypeDouble}, mysql.TypeLonglong, 5, 0, 5), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeNewDecimal, mysql.TypeNewDecimal}, mysql.TypeLonglong, 5, 0, 5), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeString, mysql.TypeString}, mysql.TypeLonglong, 5, 0, 5), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeDate, mysql.TypeDate}, mysql.TypeLonglong, 5, 0, 5), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeDuration, mysql.TypeDuration}, mysql.TypeLonglong, 5, 0, 5), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeJSON, mysql.TypeJSON}, mysql.TypeLonglong, 5, 0, 5), + } + + for _, test := range tests4 { + s.testMultiArgsAggFunc(c, test) + } } func BenchmarkCount(b *testing.B) { @@ -87,4 +133,18 @@ func BenchmarkCount(b *testing.B) { for _, test := range tests2 { s.benchmarkMultiArgsAggFunc(b, test) } + + tests3 := []multiArgsAggTest{ + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeLonglong, mysql.TypeLonglong}, mysql.TypeLonglong, rowNum, 0, rowNum), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeFloat, mysql.TypeFloat}, mysql.TypeLonglong, rowNum, 0, rowNum), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeDouble, mysql.TypeDouble}, mysql.TypeLonglong, rowNum, 0, rowNum), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeNewDecimal, mysql.TypeNewDecimal}, mysql.TypeLonglong, rowNum, 0, rowNum), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeString, mysql.TypeString}, mysql.TypeLonglong, rowNum, 0, rowNum), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeDate, mysql.TypeDate}, mysql.TypeLonglong, rowNum, 0, rowNum), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeDuration, mysql.TypeDuration}, mysql.TypeLonglong, rowNum, 0, rowNum), + buildMultiArgsAggTester(ast.AggFuncApproxCountDistinct, []byte{mysql.TypeJSON, mysql.TypeJSON}, mysql.TypeLonglong, rowNum, 0, rowNum), + } + for _, test := range tests3 { + s.benchmarkMultiArgsAggFunc(b, test) + } } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index fee57a7239037..aae9b29620d1e 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -58,6 +58,7 @@ func (s *testSuiteAgg) TestAggregation(c *C) { tk.MustQuery("select bit_and(c) from t where NULL").Check(testkit.Rows("18446744073709551615")) tk.MustQuery("select bit_or(c) from t where NULL").Check(testkit.Rows("0")) tk.MustQuery("select bit_xor(c) from t where NULL").Check(testkit.Rows("0")) + tk.MustQuery("select approx_count_distinct(c) from t where NULL").Check(testkit.Rows("0")) result := tk.MustQuery("select count(*) from t") result.Check(testkit.Rows("7")) result = tk.MustQuery("select count(*) from t group by d order by c") @@ -82,12 +83,16 @@ func (s *testSuiteAgg) TestAggregation(c *C) { result.Check(testkit.Rows("")) result = tk.MustQuery("select count(distinct c) from t group by d order by c") result.Check(testkit.Rows("1", "2", "2")) + result = tk.MustQuery("select approx_count_distinct(c) from t group by d order by c") + result.Check(testkit.Rows("1", "2", "2")) result = tk.MustQuery("select sum(c) as a from t group by d order by a") result.Check(testkit.Rows("2", "4", "5")) result = tk.MustQuery("select sum(c) as a, sum(c+1), sum(c), sum(c+1) from t group by d order by a") result.Check(testkit.Rows("2 4 2 4", "4 6 4 6", "5 7 5 7")) result = tk.MustQuery("select count(distinct c,d) from t") result.Check(testkit.Rows("5")) + result = tk.MustQuery("select approx_count_distinct(c,d) from t") + result.Check(testkit.Rows("5")) err := tk.ExecToErr("select count(c,d) from t") c.Assert(err, NotNil) result = tk.MustQuery("select d*2 as ee, sum(c) from t group by ee order by ee") @@ -100,12 +105,16 @@ func (s *testSuiteAgg) TestAggregation(c *C) { result.Check(testkit.Rows("1", "3", "4")) result = tk.MustQuery("select avg(c) as a from t group by d order by a") result.Check(testkit.Rows("1.0000", "2.0000", "2.5000")) + result = tk.MustQuery("select c, approx_count_distinct(d) as a from t group by c order by a, c") + result.Check(testkit.Rows(" 1", "3 1", "4 1", "1 3")) result = tk.MustQuery("select d, d + 1 from t group by d order by d") result.Check(testkit.Rows("1 2", "2 3", "3 4")) result = tk.MustQuery("select count(*) from t") result.Check(testkit.Rows("7")) result = tk.MustQuery("select count(distinct d) from t") result.Check(testkit.Rows("3")) + result = tk.MustQuery("select approx_count_distinct(d) from t") + result.Check(testkit.Rows("3")) result = tk.MustQuery("select count(*) as a from t group by d having sum(c) > 3 order by a") result.Check(testkit.Rows("2", "2")) result = tk.MustQuery("select max(c) from t group by d having sum(c) > 3 order by avg(c) desc") @@ -228,6 +237,8 @@ func (s *testSuiteAgg) TestAggregation(c *C) { result.Check(testkit.Rows("5", "2", "8")) result = tk.MustQuery("select count(distinct b) from (select * from t1) t group by a order by a") result.Check(testkit.Rows("2", "1", "2")) + result = tk.MustQuery("select approx_count_distinct(b) from (select * from t1) t group by a order by a") + result.Check(testkit.Rows("2", "1", "2")) result = tk.MustQuery("select max(distinct b) from (select * from t1) t group by a order by a") result.Check(testkit.Rows("4", "2", "5")) result = tk.MustQuery("select min(distinct b) from (select * from t1) t group by a order by a") @@ -278,11 +289,15 @@ func (s *testSuiteAgg) TestAggregation(c *C) { tk.MustExec("insert into t values(1, 2, 3), (1, 2, 4)") result = tk.MustQuery("select count(distinct c), count(distinct a,b) from t") result.Check(testkit.Rows("2 1")) + result = tk.MustQuery("select approx_count_distinct( c), approx_count_distinct( a,b) from t") + result.Check(testkit.Rows("2 1")) tk.MustExec("drop table if exists t") tk.MustExec("create table t (a float)") tk.MustExec("insert into t values(966.36), (363.97), (569.99), (453.33), (376.45), (321.93), (12.12), (45.77), (9.66), (612.17)") result = tk.MustQuery("select distinct count(distinct a) from t") result.Check(testkit.Rows("10")) + result = tk.MustQuery("select distinct approx_count_distinct( a) from t") + result.Check(testkit.Rows("10")) tk.MustExec("create table idx_agg (a int, b int, index (b))") tk.MustExec("insert idx_agg values (1, 1), (1, 2), (2, 2)") @@ -324,11 +339,11 @@ func (s *testSuiteAgg) TestAggregation(c *C) { result.Check(testkit.Rows(" 0 0 18446744073709551615 0 ")) tk.MustExec("truncate table t") tk.MustExec("create table s(id int)") - result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t left join s on t.id = s.id") - result.Check(testkit.Rows(" 0 0 18446744073709551615 0 ")) + result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95), approx_count_distinct(95) from t left join s on t.id = s.id") + result.Check(testkit.Rows(" 0 0 18446744073709551615 0 0")) tk.MustExec(`insert into t values (1, '{"i": 1, "n": "n1"}')`) - result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95) from t left join s on t.id = s.id") - result.Check(testkit.Rows("1 1 95 95.0000 95 95 95 95 95 95")) + result = tk.MustQuery("select t.id, count(95), sum(95), avg(95), bit_or(95), bit_and(95), bit_or(95), max(95), min(95), group_concat(95), approx_count_distinct(95) from t left join s on t.id = s.id") + result.Check(testkit.Rows("1 1 95 95.0000 95 95 95 95 95 95 1")) tk.MustExec("set @@tidb_hash_join_concurrency=5") // test agg bit col @@ -353,6 +368,7 @@ func (s *testSuiteAgg) TestAggregation(c *C) { tk.MustExec("set @@session.tidb_opt_distinct_agg_push_down = 1") tk.MustQuery("select count(distinct a) from t;").Check(testkit.Rows("2")) tk.MustExec("set @@session.tidb_opt_distinct_agg_push_down = 0") + tk.MustQuery("select approx_count_distinct( a) from t;").Check(testkit.Rows("2")) tk.MustExec("drop table t") tk.MustExec("create table t(a decimal(10, 4))") @@ -431,6 +447,7 @@ func (s *testSuiteAgg) TestAggPrune(c *C) { tk.MustExec("create table t(id int primary key, b float, c float, d float)") tk.MustExec("insert into t values(1, 1, 3, NULL), (2, 1, NULL, 6), (3, NULL, 1, 2), (4, NULL, NULL, 1), (5, NULL, 2, NULL), (6, 3, NULL, NULL), (7, NULL, NULL, NULL), (8, 1, 2 ,3)") tk.MustQuery("select count(distinct b, c, d) from t group by id").Check(testkit.Rows("0", "0", "0", "0", "0", "0", "0", "1")) + tk.MustQuery("select approx_count_distinct( b, c, d) from t group by id order by id").Check(testkit.Rows("0", "0", "0", "0", "0", "0", "0", "1")) tk.MustExec("drop table if exists t") tk.MustExec("create table t(a int primary key, b varchar(10))") @@ -928,6 +945,7 @@ func (s *testSuiteAgg) TestIssue10099(c *C) { tk.MustExec("create table t(a char(10), b char(10))") tk.MustExec("insert into t values('1', '222'), ('12', '22')") tk.MustQuery("select count(distinct a, b) from t").Check(testkit.Rows("2")) + tk.MustQuery("select approx_count_distinct( a, b) from t").Check(testkit.Rows("2")) } func (s *testSuiteAgg) TestIssue10098(c *C) { diff --git a/expression/aggregation/agg_to_pb.go b/expression/aggregation/agg_to_pb.go index f82982447190d..98bb1feded596 100644 --- a/expression/aggregation/agg_to_pb.go +++ b/expression/aggregation/agg_to_pb.go @@ -37,6 +37,8 @@ func AggFuncToPBExpr(sc *stmtctx.StatementContext, client kv.Client, aggFunc *Ag switch aggFunc.Name { case ast.AggFuncCount: tp = tipb.ExprType_Count + case ast.AggFuncApproxCountDistinct: + tp = tipb.ExprType_ApproxCountDistinct case ast.AggFuncFirstRow: tp = tipb.ExprType_First case ast.AggFuncGroupConcat: @@ -81,6 +83,8 @@ func PBExprToAggFuncDesc(ctx sessionctx.Context, aggFunc *tipb.Expr, fieldTps [] switch aggFunc.Tp { case tipb.ExprType_Count: name = ast.AggFuncCount + case tipb.ExprType_ApproxCountDistinct: + name = ast.AggFuncApproxCountDistinct case tipb.ExprType_First: name = ast.AggFuncFirstRow case tipb.ExprType_GroupConcat: diff --git a/expression/aggregation/aggregation.go b/expression/aggregation/aggregation.go index 12cca5873f5bb..b76233af80e73 100644 --- a/expression/aggregation/aggregation.go +++ b/expression/aggregation/aggregation.go @@ -208,7 +208,7 @@ func CheckAggPushDown(aggFunc *AggFuncDesc, storeType kv.StoreType) bool { // CheckAggPushFlash checks whether an agg function can be pushed to flash storage. func CheckAggPushFlash(aggFunc *AggFuncDesc) bool { switch aggFunc.Name { - case ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncMin, ast.AggFuncMax, ast.AggFuncAvg, ast.AggFuncFirstRow: + case ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncMin, ast.AggFuncMax, ast.AggFuncAvg, ast.AggFuncFirstRow, ast.AggFuncApproxCountDistinct: return true } return false diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index 17762632454c4..6604f4ca144a2 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -86,6 +86,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) error { switch a.Name { case ast.AggFuncCount: a.typeInfer4Count(ctx) + case ast.AggFuncApproxCountDistinct: + a.typeInfer4ApproxCountDistinct(ctx) case ast.AggFuncSum: a.typeInfer4Sum(ctx) case ast.AggFuncAvg: @@ -124,6 +126,10 @@ func (a *baseFuncDesc) typeInfer4Count(ctx sessionctx.Context) { types.SetBinChsClnFlag(a.RetTp) } +func (a *baseFuncDesc) typeInfer4ApproxCountDistinct(ctx sessionctx.Context) { + a.typeInfer4Count(ctx) +} + // typeInfer4Sum should returns a "decimal", otherwise it returns a "double". // Because child returns integer or decimal type. func (a *baseFuncDesc) typeInfer4Sum(ctx sessionctx.Context) { @@ -260,16 +266,21 @@ func (a *baseFuncDesc) typeInfer4VarPop(ctx sessionctx.Context) { // | t | a | int(11) | // +-------+---------+---------+ // -// Query: `select a, avg(a), sum(a), count(a), bit_xor(a), bit_or(a), bit_and(a), max(a), min(a), group_concat(a) from t;` -// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+ -// | a | avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a) | max(a) | min(a) | group_concat(a) | -// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+ -// | NULL | NULL | NULL | 0 | 0 | 0 | 18446744073709551615 | NULL | NULL | NULL | -// +------+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+ +// Query: `select avg(a), sum(a), count(a), bit_xor(a), bit_or(a), bit_and(a), max(a), min(a), group_concat(a), approx_count_distinct(a) from test.t;` +//+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+ +//| avg(a) | sum(a) | count(a) | bit_xor(a) | bit_or(a) | bit_and(a) | max(a) | min(a) | group_concat(a) | approx_count_distinct(a) | +//+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+ +//| NULL | NULL | 0 | 0 | 0 | 18446744073709551615 | NULL | NULL | NULL | 0 | +//+--------+--------+----------+------------+-----------+----------------------+--------+--------+-----------------+--------------------------+ + func (a *baseFuncDesc) GetDefaultValue() (v types.Datum) { switch a.Name { case ast.AggFuncCount, ast.AggFuncBitOr, ast.AggFuncBitXor: v = types.NewIntDatum(0) + case ast.AggFuncApproxCountDistinct: + if a.RetTp.Tp != mysql.TypeString { + v = types.NewIntDatum(0) + } case ast.AggFuncFirstRow, ast.AggFuncAvg, ast.AggFuncSum, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncGroupConcat: v = types.Datum{} @@ -282,12 +293,13 @@ func (a *baseFuncDesc) GetDefaultValue() (v types.Datum) { // We do not need to wrap cast upon these functions, // since the EvalXXX method called by the arg is determined by the corresponding arg type. var noNeedCastAggFuncs = map[string]struct{}{ - ast.AggFuncCount: {}, - ast.AggFuncMax: {}, - ast.AggFuncMin: {}, - ast.AggFuncFirstRow: {}, - ast.WindowFuncNtile: {}, - ast.AggFuncJsonObjectAgg: {}, + ast.AggFuncCount: {}, + ast.AggFuncApproxCountDistinct: {}, + ast.AggFuncMax: {}, + ast.AggFuncMin: {}, + ast.AggFuncFirstRow: {}, + ast.WindowFuncNtile: {}, + ast.AggFuncJsonObjectAgg: {}, } // WrapCastForAggArgs wraps the args of an aggregate function with a cast function. diff --git a/expression/aggregation/descriptor.go b/expression/aggregation/descriptor.go index c3aab50bd66cc..97471cdcb8774 100644 --- a/expression/aggregation/descriptor.go +++ b/expression/aggregation/descriptor.go @@ -123,6 +123,13 @@ func (a *AggFuncDesc) Split(ordinal []int) (partialAggDesc, finalAggDesc *AggFun RetType: a.RetTp, }) finalAggDesc.Args = args + case ast.AggFuncApproxCountDistinct: + args := make([]expression.Expression, 0, 1) + args = append(args, &expression.Column{ + Index: ordinal[0], + RetType: types.NewFieldType(mysql.TypeString), + }) + finalAggDesc.Args = args default: args := make([]expression.Expression, 0, 1) args = append(args, &expression.Column{ diff --git a/expression/expression.go b/expression/expression.go index a854a5cee3cf5..b4e055b8e0ec0 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -1014,6 +1014,12 @@ func IsPushDownEnabled(name string, storeType kv.StoreType) bool { mask := storeTypeMask(storeType) return !(value&mask == mask) } + + if storeType != kv.TiFlash && name == ast.AggFuncApproxCountDistinct { + // Can not push down approx_count_distinct to other store except tiflash by now. + return false + } + return true } diff --git a/go.mod b/go.mod index 8892648bb068f..e938b83cc7dae 100644 --- a/go.mod +++ b/go.mod @@ -30,11 +30,11 @@ require ( github.com/pingcap/goleveldb v0.0.0-20191226122134-f82aafb29989 github.com/pingcap/kvproto v0.0.0-20200518112156-d4aeb467de29 github.com/pingcap/log v0.0.0-20200511115504-543df19646ad - github.com/pingcap/parser v0.0.0-20200616060540-a87fd8a746fc + github.com/pingcap/parser v0.0.0-20200618113039-789c193fe4b7 github.com/pingcap/pd/v4 v4.0.0-rc.2.0.20200520083007-2c251bd8f181 github.com/pingcap/sysutil v0.0.0-20200408114249-ed3bd6f7fdb1 github.com/pingcap/tidb-tools v4.0.0-rc.1.0.20200514040632-f76b3e428e19+incompatible - github.com/pingcap/tipb v0.0.0-20200417094153-7316d94df1ee + github.com/pingcap/tipb v0.0.0-20200522051215-f31a15d98fce github.com/prometheus/client_golang v1.5.1 github.com/prometheus/client_model v0.2.0 github.com/prometheus/common v0.9.1 diff --git a/go.sum b/go.sum index 9a930b695fcc4..8b4ee5c8fc39e 100644 --- a/go.sum +++ b/go.sum @@ -406,8 +406,8 @@ github.com/pingcap/log v0.0.0-20200511115504-543df19646ad/go.mod h1:4rbK1p9ILyIf github.com/pingcap/parser v0.0.0-20200424075042-8222d8b724a4/go.mod h1:9v0Edh8IbgjGYW2ArJr19E+bvL8zKahsFp+ixWeId+4= github.com/pingcap/parser v0.0.0-20200507022230-f3bf29096657/go.mod h1:9v0Edh8IbgjGYW2ArJr19E+bvL8zKahsFp+ixWeId+4= github.com/pingcap/parser v0.0.0-20200603032439-c4ecb4508d2f/go.mod h1:9v0Edh8IbgjGYW2ArJr19E+bvL8zKahsFp+ixWeId+4= -github.com/pingcap/parser v0.0.0-20200616060540-a87fd8a746fc h1:Fddt2tbar2sT3PBw31S2akfLpzUuuPlEC9KrYDNwqAE= -github.com/pingcap/parser v0.0.0-20200616060540-a87fd8a746fc/go.mod h1:vQdbJqobJAgFyiRNNtXahpMoGWwPEuWciVEK5A20NS0= +github.com/pingcap/parser v0.0.0-20200618113039-789c193fe4b7 h1:8b1EZ/BcRb9bDTjgGd/QobO6U+9nSyv81U0qGCKmj6U= +github.com/pingcap/parser v0.0.0-20200618113039-789c193fe4b7/go.mod h1:vQdbJqobJAgFyiRNNtXahpMoGWwPEuWciVEK5A20NS0= github.com/pingcap/pd/v4 v4.0.0-rc.1.0.20200422143320-428acd53eba2/go.mod h1:s+utZtXDznOiL24VK0qGmtoHjjXNsscJx3m1n8cC56s= github.com/pingcap/pd/v4 v4.0.0-rc.2.0.20200520083007-2c251bd8f181 h1:FM+PzdoR3fmWAJx3ug+p5aOgs5aZYwFkoDL7Potdsz0= github.com/pingcap/pd/v4 v4.0.0-rc.2.0.20200520083007-2c251bd8f181/go.mod h1:q4HTx/bA8aKBa4S7L+SQKHvjRPXCRV0tA0yRw0qkZSA= @@ -422,8 +422,9 @@ github.com/pingcap/tidb-tools v4.0.0-rc.1.0.20200421113014-507d2bb3a15e+incompat github.com/pingcap/tidb-tools v4.0.0-rc.1.0.20200514040632-f76b3e428e19+incompatible h1:/JKsYjsa5Ug8v5CN4zIbJGIqsvgBUkGwaP/rEScVvWM= github.com/pingcap/tidb-tools v4.0.0-rc.1.0.20200514040632-f76b3e428e19+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= github.com/pingcap/tipb v0.0.0-20190428032612-535e1abaa330/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= -github.com/pingcap/tipb v0.0.0-20200417094153-7316d94df1ee h1:XJQ6/LGzOSc/jo33AD8t7jtc4GohxcyODsYnb+kZXJM= github.com/pingcap/tipb v0.0.0-20200417094153-7316d94df1ee/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= +github.com/pingcap/tipb v0.0.0-20200522051215-f31a15d98fce h1:LDyY6Xh/Z/SHVQ10erWtoOwIxHSTtlpPQ9cvS+BfRMY= +github.com/pingcap/tipb v0.0.0-20200522051215-f31a15d98fce/go.mod h1:RtkHW8WbcNxj8lsbzjaILci01CtYnYbIkQhjyZWrWVI= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/kv/checker.go b/kv/checker.go index 038f3d4e043c1..abcb748d07cab 100644 --- a/kv/checker.go +++ b/kv/checker.go @@ -44,7 +44,7 @@ func (d RequestTypeSupportedChecker) supportExpr(exprType tipb.ExprType) bool { return true // aggregate functions. case tipb.ExprType_Count, tipb.ExprType_First, tipb.ExprType_Max, tipb.ExprType_Min, tipb.ExprType_Sum, tipb.ExprType_Avg, - tipb.ExprType_Agg_BitXor, tipb.ExprType_Agg_BitAnd, tipb.ExprType_Agg_BitOr: + tipb.ExprType_Agg_BitXor, tipb.ExprType_Agg_BitAnd, tipb.ExprType_Agg_BitOr, tipb.ExprType_ApproxCountDistinct: return true case ReqSubTypeDesc: return true diff --git a/planner/cascades/testdata/integration_suite_in.json b/planner/cascades/testdata/integration_suite_in.json index 547c2cee8e486..c95409facdbb9 100644 --- a/planner/cascades/testdata/integration_suite_in.json +++ b/planner/cascades/testdata/integration_suite_in.json @@ -44,6 +44,7 @@ "select count(case when a > 0 and a <= 1000 then b end) from t", "select count(case when a <= 0 or a > 1000 then null else b end) from t", "select count(distinct case when a > 0 and a <= 1000 then b end) from t", + "select approx_count_distinct(case when a > 0 and a <= 1000 then b end) from t", "select count(b), sum(b), avg(b), b, max(b), min(b), bit_and(b), bit_or(b), bit_xor(b) from t group by a having sum(b) >= 0 and count(b) >= 0 order by b" ] }, diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json index 3a9a926f85c8a..f8953319f1053 100644 --- a/planner/cascades/testdata/integration_suite_out.json +++ b/planner/cascades/testdata/integration_suite_out.json @@ -435,6 +435,17 @@ "4" ] }, + { + "SQL": "select approx_count_distinct(case when a > 0 and a <= 1000 then b end) from t", + "Plan": [ + "HashAgg_10 1.00 root funcs:approx_count_distinct(test.t.b)->Column#3", + "└─TableReader_11 250.00 root data:TableRangeScan_12", + " └─TableRangeScan_12 250.00 cop[tikv] table:t range:(0,1000], keep order:false, stats:pseudo" + ], + "Result": [ + "4" + ] + }, { "SQL": "select count(b), sum(b), avg(b), b, max(b), min(b), bit_and(b), bit_or(b), bit_xor(b) from t group by a having sum(b) >= 0 and count(b) >= 0 order by b", "Plan": [ diff --git a/planner/cascades/testdata/transformation_rules_suite_in.json b/planner/cascades/testdata/transformation_rules_suite_in.json index f68fbd25699ba..4edf3f5eeb14f 100644 --- a/planner/cascades/testdata/transformation_rules_suite_in.json +++ b/planner/cascades/testdata/transformation_rules_suite_in.json @@ -9,6 +9,7 @@ "select b, @i:=@i+1 as ii from (select a, b, @i:=0 from t as t1) as t2 where @i < 10 and a > 10", "select a, max(b) from t group by a having a > 1", "select a, avg(b) from t group by a having a > 1 and max(b) > 10", + "select a, approx_count_distinct(b) from t group by a having a > 1 and max(b) > 10", "select t1.a, t1.b, t2.b from t t1, t t2 where t1.a = t2.a and t2.b = t1.b and t1.a > 10 and t2.b > 10 and t1.a > t2.b", "select t1.a, t1.b from t t1, t t2 where t1.a = t2.a and t1.a = 10 and t2.a = 5", "select a, f from t where f > 1", @@ -180,6 +181,7 @@ "select sum(case when a > 10 then 0 else c end) from t", "select sum(case when a > 10 then 2 else 1 end) from t", "select count(DISTINCT case when a > 10 then null else c end) from t", + "select approx_count_distinct(case when a > 10 then null else c end) from t", "select sum(DISTINCT case when a > 10 then c else 0 end) from t", "select case when c > 10 then c end from t", "select count(case when a > 10 then c end), c from t", diff --git a/planner/cascades/testdata/transformation_rules_suite_out.json b/planner/cascades/testdata/transformation_rules_suite_out.json index ba8bbcae8e221..c78557711b199 100644 --- a/planner/cascades/testdata/transformation_rules_suite_out.json +++ b/planner/cascades/testdata/transformation_rules_suite_out.json @@ -109,6 +109,23 @@ " TableScan_12 table:t, pk col:test.t.a, cond:[gt(test.t.a, 1)]" ] }, + { + "SQL": "select a, approx_count_distinct(b) from t group by a having a > 1 and max(b) > 10", + "Result": [ + "Group#0 Schema:[test.t.a,Column#16]", + " Projection_5 input:[Group#1], test.t.a, Column#13", + "Group#1 Schema:[test.t.a,Column#13,Column#14]", + " Projection_3 input:[Group#2], test.t.a, Column#13, Column#14", + "Group#2 Schema:[Column#13,Column#14,test.t.a]", + " Selection_8 input:[Group#3], gt(Column#14, 10)", + "Group#3 Schema:[Column#13,Column#14,test.t.a]", + " Aggregation_2 input:[Group#4], group by:test.t.a, funcs:approx_count_distinct(test.t.b), max(test.t.b), firstrow(test.t.a)", + "Group#4 Schema:[test.t.a,test.t.b]", + " TiKVSingleGather_10 input:[Group#5], table:t", + "Group#5 Schema:[test.t.a,test.t.b]", + " TableScan_12 table:t, pk col:test.t.a, cond:[gt(test.t.a, 1)]" + ] + }, { "SQL": "select t1.a, t1.b, t2.b from t t1, t t2 where t1.a = t2.a and t2.b = t1.b and t1.a > 10 and t2.b > 10 and t1.a > t2.b", "Result": [ @@ -2178,6 +2195,19 @@ " DataSource_1 table:t" ] }, + { + "SQL": "select approx_count_distinct(case when a > 10 then null else c end) from t", + "Result": [ + "Group#0 Schema:[Column#13]", + " Projection_3 input:[Group#1], Column#13", + "Group#1 Schema:[Column#13]", + " Aggregation_5 input:[Group#2], funcs:approx_count_distinct(test.t.c)", + "Group#2 Schema:[test.t.a,test.t.c]", + " Selection_4 input:[Group#3], not(gt(test.t.a, 10))", + "Group#3 Schema:[test.t.a,test.t.c]", + " DataSource_1 table:t" + ] + }, { "SQL": "select sum(DISTINCT case when a > 10 then c else 0 end) from t", "Result": [ diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index d513156a140e3..26494242cf09e 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -344,6 +344,42 @@ func (s *testIntegrationSerialSuite) TestSelPushDownTiFlash(c *C) { } } +func (s *testIntegrationSerialSuite) TestAggPushDownEngine(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b varchar(20))") + + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Se) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + c.Assert(exists, IsTrue) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tiflash'") + + tk.MustQuery("desc select approx_count_distinct(a) from t").Check(testkit.Rows( + "StreamAgg_16 1.00 root funcs:approx_count_distinct(Column#5)->Column#3", + "└─TableReader_17 1.00 root data:StreamAgg_8", + " └─StreamAgg_8 1.00 cop[tiflash] funcs:approx_count_distinct(test.t.a)->Column#5", + " └─TableFullScan_15 10000.00 cop[tiflash] table:t keep order:false, stats:pseudo")) + + tk.MustExec("set @@session.tidb_isolation_read_engines = 'tikv'") + + tk.MustQuery("desc select approx_count_distinct(a) from t").Check(testkit.Rows( + "HashAgg_5 1.00 root funcs:approx_count_distinct(test.t.a)->Column#3", + "└─TableReader_11 10000.00 root data:TableFullScan_10", + " └─TableFullScan_10 10000.00 cop[tikv] table:t keep order:false, stats:pseudo")) +} + func (s *testIntegrationSerialSuite) TestIssue15110(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") @@ -792,6 +828,26 @@ func (s *testIntegrationSuite) TestIssue15546(c *C) { tk.MustQuery("select * from pt, vt where pt.a = vt.a").Check(testkit.Rows("1 1 1 1")) } +func (s *testIntegrationSuite) TestApproxCountDistinctInPartitionTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int(11), b int) partition by range (a) (partition p0 values less than (3), partition p1 values less than maxvalue);") + tk.MustExec("insert into t values(1, 1), (2, 1), (3, 1), (4, 2), (4, 2)") + tk.MustExec(fmt.Sprintf("set session tidb_opt_agg_push_down=1")) + tk.MustQuery("explain select approx_count_distinct(a), b from t group by b order by b desc").Check(testkit.Rows("Sort_11 16000.00 root test.t.b:desc", + "└─HashAgg_16 16000.00 root group by:test.t.b, funcs:approx_count_distinct(Column#5)->Column#4, funcs:firstrow(Column#6)->test.t.b", + " └─PartitionUnion_17 16000.00 root ", + " ├─HashAgg_18 8000.00 root group by:test.t.b, funcs:approx_count_distinct(test.t.a)->Column#5, funcs:firstrow(test.t.b)->Column#6, funcs:firstrow(test.t.b)->test.t.b", + " │ └─TableReader_22 10000.00 root data:TableFullScan_21", + " │ └─TableFullScan_21 10000.00 cop[tikv] table:t, partition:p0 keep order:false, stats:pseudo", + " └─HashAgg_25 8000.00 root group by:test.t.b, funcs:approx_count_distinct(test.t.a)->Column#5, funcs:firstrow(test.t.b)->Column#6, funcs:firstrow(test.t.b)->test.t.b", + " └─TableReader_29 10000.00 root data:TableFullScan_28", + " └─TableFullScan_28 10000.00 cop[tikv] table:t, partition:p1 keep order:false, stats:pseudo")) + tk.MustQuery("select approx_count_distinct(a), b from t group by b order by b desc").Check(testkit.Rows("1 2", "3 1")) +} + func (s *testIntegrationSuite) TestHintWithRequiredProperty(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index a3c76f43a2ef0..e7d371f321c42 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -59,7 +59,7 @@ func (a *aggregationPushDownSolver) isDecomposableWithUnion(fun *aggregation.Agg return false case ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncFirstRow: return true - case ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncAvg: + case ast.AggFuncSum, ast.AggFuncCount, ast.AggFuncAvg, ast.AggFuncApproxCountDistinct: return true default: return false diff --git a/planner/core/rule_join_elimination.go b/planner/core/rule_join_elimination.go index f4fccf6d6c1c8..71b457844cb9b 100644 --- a/planner/core/rule_join_elimination.go +++ b/planner/core/rule_join_elimination.go @@ -164,7 +164,8 @@ func GetDupAgnosticAggCols( if !aggDesc.HasDistinct && aggDesc.Name != ast.AggFuncFirstRow && aggDesc.Name != ast.AggFuncMax && - aggDesc.Name != ast.AggFuncMin { + aggDesc.Name != ast.AggFuncMin && + aggDesc.Name != ast.AggFuncApproxCountDistinct { // If not all aggregate functions are duplicate agnostic, // we should clean the aggCols, so `return true, newAggCols[:0]`. return true, newAggCols[:0] diff --git a/planner/core/task.go b/planner/core/task.go index 6f5ec293c326b..b38bc98de3d78 100644 --- a/planner/core/task.go +++ b/planner/core/task.go @@ -1098,6 +1098,17 @@ func BuildFinalModeAggregation( args = append(args, partial.Schema.Columns[partialCursor]) partialCursor++ } + if finalAggFunc.Name == ast.AggFuncApproxCountDistinct { + ft := types.NewFieldType(mysql.TypeString) + ft.Charset, ft.Collate = charset.CharsetBin, charset.CollationBin + ft.Flag |= mysql.NotNullFlag + partial.Schema.Append(&expression.Column{ + UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), + RetType: ft, + }) + args = append(args, partial.Schema.Columns[partialCursor]) + partialCursor++ + } if aggregation.NeedValue(finalAggFunc.Name) { partial.Schema.Append(&expression.Column{ UniqueID: sctx.GetSessionVars().AllocPlanColumnID(), @@ -1115,6 +1126,11 @@ func BuildFinalModeAggregation( sumAgg.Name = ast.AggFuncSum sumAgg.RetTp = partial.Schema.Columns[partialCursor-1].GetType() partial.AggFuncs = append(partial.AggFuncs, &cntAgg, &sumAgg) + } else if aggFunc.Name == ast.AggFuncApproxCountDistinct { + approxCountDistinctAgg := *aggFunc + approxCountDistinctAgg.Name = ast.AggFuncApproxCountDistinct + approxCountDistinctAgg.RetTp = partial.Schema.Columns[partialCursor-1].GetType() + partial.AggFuncs = append(partial.AggFuncs, &approxCountDistinctAgg) } else { partial.AggFuncs = append(partial.AggFuncs, aggFunc) } @@ -1236,17 +1252,26 @@ func RemoveUnnecessaryFirstRow( continue } } - if aggregation.NeedCount(aggFunc.Name) { - partialCursor++ - } - if aggregation.NeedValue(aggFunc.Name) { - partialCursor++ - } + partialCursor += computePartialCursorOffset(aggFunc.Name) newAggFuncs = append(newAggFuncs, aggFunc) } return newAggFuncs } +func computePartialCursorOffset(name string) int { + offset := 0 + if aggregation.NeedCount(name) { + offset++ + } + if aggregation.NeedValue(name) { + offset++ + } + if name == ast.AggFuncApproxCountDistinct { + offset++ + } + return offset +} + func (p *PhysicalStreamAgg) attach2Task(tasks ...task) task { t := tasks[0].copy() inputRows := t.count() diff --git a/planner/core/testdata/plan_suite_unexported_in.json b/planner/core/testdata/plan_suite_unexported_in.json index 8b2fe43e88f01..b71d9f9960da7 100644 --- a/planner/core/testdata/plan_suite_unexported_in.json +++ b/planner/core/testdata/plan_suite_unexported_in.json @@ -24,7 +24,9 @@ "select t1.a, count(t2.b) from t t1, t t2 where t1.a = t2.a group by t1.a", "select avg(a) from (select * from t t1 union all select * from t t2) t", "select count(distinct a) from (select * from t t1 union all select * from t t2) t", - "select count(distinct b) from (select * from t t1 union all select * from t t2) t" + "select count(distinct b) from (select * from t t1 union all select * from t t2) t", + "select approx_count_distinct(a) from (select * from t t1 union all select * from t t2) t", + "select approx_count_distinct(b) from (select * from t t1 union all select * from t t2) t" ] }, { @@ -371,7 +373,8 @@ "select tt.a, sum(tt.b) from (select a, b from t) tt group by tt.a", "select count(1) from (select count(1), a as b from t group by a) tt group by b", "select a, count(b) from t group by a", - "select a, count(distinct a, b) from t group by a" + "select a, count(distinct a, b) from t group by a", + "select a, approx_count_distinct(a, b) from t group by a" ] }, { @@ -512,6 +515,7 @@ "select max(t1.a), min(test.t1.b) from t t1 left join t t2 on t1.b = t2.b", "select sum(distinct t1.a) from t t1 left join t t2 on t1.a = t2.a and t1.b = t2.b", "select count(distinct t1.a, t1.b) from t t1 left join t t2 on t1.b = t2.b", + "select approx_count_distinct(t1.a, t1.b) from t t1 left join t t2 on t1.b = t2.b", // Test left outer join "select t1.b from t t1 left outer join t t2 on t1.a = t2.a", // Test right outer join diff --git a/planner/core/testdata/plan_suite_unexported_out.json b/planner/core/testdata/plan_suite_unexported_out.json index 548280e498344..388ab97d628d7 100644 --- a/planner/core/testdata/plan_suite_unexported_out.json +++ b/planner/core/testdata/plan_suite_unexported_out.json @@ -24,7 +24,9 @@ "Join{DataScan(t1)->DataScan(t2)}(test.t.a,test.t.a)->Projection->Projection", "UnionAll{DataScan(t1)->Projection->Aggr(count(test.t.a),sum(test.t.a))->DataScan(t2)->Projection->Aggr(count(test.t.a),sum(test.t.a))}->Aggr(avg(Column#38, Column#39))->Projection", "UnionAll{DataScan(t1)->Projection->Projection->Projection->DataScan(t2)->Projection->Projection->Projection}->Aggr(count(distinct Column#25))->Projection", - "UnionAll{DataScan(t1)->Projection->Aggr(firstrow(test.t.b),firstrow(test.t.b))->DataScan(t2)->Projection->Aggr(firstrow(test.t.b),firstrow(test.t.b))}->Aggr(count(distinct Column#26))->Projection" + "UnionAll{DataScan(t1)->Projection->Aggr(firstrow(test.t.b),firstrow(test.t.b))->DataScan(t2)->Projection->Aggr(firstrow(test.t.b),firstrow(test.t.b))}->Aggr(count(distinct Column#26))->Projection", + "UnionAll{DataScan(t1)->Projection->Aggr(approx_count_distinct(test.t.a))->DataScan(t2)->Projection->Aggr(approx_count_distinct(test.t.a))}->Aggr(approx_count_distinct(Column#38))->Projection", + "UnionAll{DataScan(t1)->Projection->Aggr(approx_count_distinct(test.t.b))->DataScan(t2)->Projection->Aggr(approx_count_distinct(test.t.b))}->Aggr(approx_count_distinct(Column#38))->Projection" ] }, { @@ -561,7 +563,8 @@ "DataScan(t)->Projection", "DataScan(t)->Projection", "DataScan(t)->Projection", - "DataScan(t)->Projection" + "DataScan(t)->Projection", + "DataScan(t)->Aggr(approx_count_distinct(test.t.a, test.t.b),firstrow(test.t.a))->Projection" ] }, { @@ -902,6 +905,7 @@ "DataScan(t1)->Aggr(max(test.t.a),min(test.t.b))->Projection", "DataScan(t1)->Aggr(sum(distinct test.t.a))->Projection", "DataScan(t1)->Aggr(count(distinct test.t.a, test.t.b))->Projection", + "DataScan(t1)->Aggr(approx_count_distinct(test.t.a, test.t.b))->Projection", "DataScan(t1)->Projection", "DataScan(t2)->Projection", "Join{Join{DataScan(t1)->DataScan(t2)}(test.t.a,test.t.a)->DataScan(t3)->TopN([test.t.b true],0,1)}(test.t.b,test.t.b)->TopN([test.t.b true],0,1)->Aggr(max(test.t.b))->Projection",