executor: make memory tracker for aggregate more accurate. (#22463)
wshwsh12 authored Feb 18, 2021
1 parent c9af430 commit fb84db1
Showing 3 changed files with 245 additions and 10 deletions.
64 changes: 54 additions & 10 deletions executor/aggregate.go
@@ -51,15 +51,33 @@ type baseHashAggWorker struct {
aggFuncs []aggfuncs.AggFunc
maxChunkSize int
stats *AggWorkerStat

memTracker *memory.Tracker
BInMap int // indicates that there are 2^BInMap buckets in the Go map.
}

func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc, maxChunkSize int) baseHashAggWorker {
return baseHashAggWorker{
const (
// ref https://github.com/golang/go/blob/go1.15.6/src/reflect/type.go#L2162.
// defBucketMemoryUsage = bucketSize*(1+unsafe.Sizeof(string)+unsafe.Sizeof(slice)) + 2*ptrSize
// = 8*(1+16+24) + 16 = 344 bytes per bucket on 64-bit platforms.
// The bucket size may change in future Go implementations.
defBucketMemoryUsage = 8*(1+16+24) + 16
// Maximum average load of a bucket that triggers growth is 6.5.
// Represented as loadFactorNum/loadFactorDen, to allow integer math.
loadFactorNum = 13
loadFactorDen = 2
)
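
As a sketch of how these constants drive the accounting below (a hypothetical helper, not part of this commit; it assumes the constants defined above), the tracked bucket memory starts at one bucket and doubles every time the load factor is exceeded:

func estimateBucketMemory(rows int) int64 {
	bInMap := 0                        // the map starts with 2^0 = 1 bucket
	mem := int64(defBucketMemoryUsage) // memory of that initial bucket
	for count := 1; count <= rows; count++ {
		if count > (1<<bInMap)*loadFactorNum/loadFactorDen {
			mem += defBucketMemoryUsage * (1 << bInMap) // growth doubles the bucket array
			bInMap++
		}
	}
	return mem // always defBucketMemoryUsage * 2^bInMap
}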

func newBaseHashAggWorker(ctx sessionctx.Context, finishCh <-chan struct{}, aggFuncs []aggfuncs.AggFunc,
maxChunkSize int, memTrack *memory.Tracker) baseHashAggWorker {
baseWorker := baseHashAggWorker{
ctx: ctx,
finishCh: finishCh,
aggFuncs: aggFuncs,
maxChunkSize: maxChunkSize,
memTracker: memTrack,
BInMap: 0,
}
return baseWorker
}

// HashAggPartialWorker indicates the partial workers of parallel hash agg execution,
@@ -76,8 +94,7 @@ type HashAggPartialWorker struct {
groupKey [][]byte
// chk stores the input data from child,
// and is reused by childExec and partial worker.
chk *chunk.Chunk
memTracker *memory.Tracker
chk *chunk.Chunk
}

// HashAggFinalWorker indicates the final workers of parallel hash agg execution,
@@ -296,7 +313,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
// Init partial workers.
for i := 0; i < partialConcurrency; i++ {
w := HashAggPartialWorker{
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.PartialAggFuncs, e.maxChunkSize),
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.PartialAggFuncs, e.maxChunkSize, e.memTracker),
inputCh: e.partialInputChs[i],
outputChs: e.partialOutputChs,
giveBackCh: e.inputCh,
@@ -305,8 +322,9 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
groupByItems: e.GroupByItems,
chk: newFirstChunk(e.children[0]),
groupKey: make([][]byte, 0, 8),
memTracker: e.memTracker,
}
// An empty partialResultsMap already contains one bucket, so account for its memory up front.
e.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap))
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.PartialStats = append(e.stats.PartialStats, w.stats)
@@ -324,7 +342,7 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
// Init final workers.
for i := 0; i < finalConcurrency; i++ {
w := HashAggFinalWorker{
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.FinalAggFuncs, e.maxChunkSize),
baseHashAggWorker: newBaseHashAggWorker(e.ctx, e.finishCh, e.FinalAggFuncs, e.maxChunkSize, e.memTracker),
partialResultMap: make(aggPartialResultMapper),
groupSet: set.NewStringSet(),
inputCh: e.partialOutputChs[i],
@@ -334,6 +352,8 @@ func (e *HashAggExec) initForParallelExec(ctx sessionctx.Context) {
mutableRow: chunk.MutRowFromTypes(retTypes(e)),
groupKeys: make([][]byte, 0, 8),
}
// An empty partialResultsMap already contains one bucket, so account for its memory up front.
e.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap))
if e.stats != nil {
w.stats = &AggWorkerStat{}
e.stats.FinalStats = append(e.stats.FinalStats, w.stats)
@@ -406,8 +426,19 @@ func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitG
}
}

// getGroupKeyMemUsage returns the memory held by groupKey: the capacity of
// each key plus a fixed per-element overhead for the outer slice.
func getGroupKeyMemUsage(groupKey [][]byte) int64 {
mem := int64(0)
for _, key := range groupKey {
mem += int64(cap(key))
}
mem += 12 * int64(cap(groupKey))
return mem
}

func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *stmtctx.StatementContext, chk *chunk.Chunk, finalConcurrency int) (err error) {
memSize := getGroupKeyMemUsage(w.groupKey)
w.groupKey, err = getGroupKey(w.ctx, chk, w.groupKey, w.groupByItems)
w.memTracker.Consume(getGroupKeyMemUsage(w.groupKey) - memSize)
if err != nil {
return err
}
@@ -418,9 +449,11 @@ func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *s
for i := 0; i < numRows; i++ {
for j, af := range w.aggFuncs {
rows[0] = chk.GetRow(i)
if _, err := af.UpdatePartialResult(ctx, rows, partialResults[i][j]); err != nil {
memDelta, err := af.UpdatePartialResult(ctx, rows, partialResults[i][j])
if err != nil {
return err
}
w.memTracker.Consume(memDelta)
}
}
return nil
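
The pattern above — snapshot the tracked size of a reused buffer, mutate it, then Consume only the difference — is what keeps the tracker accurate without per-allocation hooks. A minimal, hypothetical helper capturing it (consumeGroupKeyDelta is not part of this commit; memory.Tracker.Consume accepts negative deltas, which releases memory):

func consumeGroupKeyDelta(tracker *memory.Tracker, groupKey *[][]byte, mutate func() error) error {
	before := getGroupKeyMemUsage(*groupKey) // snapshot before the buffer is rewritten
	err := mutate()
	tracker.Consume(getGroupKeyMemUsage(*groupKey) - before) // report only the change
	return err
}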
@@ -487,7 +520,7 @@ func getGroupKey(ctx sessionctx.Context, input *chunk.Chunk, groupKey [][]byte,
return groupKey, nil
}

func (w baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, groupKey [][]byte, mapper aggPartialResultMapper) [][]aggfuncs.PartialResult {
func (w *baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, groupKey [][]byte, mapper aggPartialResultMapper) [][]aggfuncs.PartialResult {
n := len(groupKey)
partialResults := make([][]aggfuncs.PartialResult, n)
for i := 0; i < n; i++ {
@@ -496,10 +529,17 @@ func (w baseHashAggWorker) getPartialResult(sc *stmtctx.StatementContext, groupK
continue
}
for _, af := range w.aggFuncs {
partialResult, _ := af.AllocPartialResult()
partialResult, memDelta := af.AllocPartialResult()
partialResults[i] = append(partialResults[i], partialResult)
w.memTracker.Consume(memDelta)
}
mapper[string(groupKey[i])] = partialResults[i]
w.memTracker.Consume(int64(len(groupKey[i])))
// The map grows when count > bucketNum * loadFactor, and its memory usage doubles:
// e.g. with B=4 (16 buckets), growth triggers once the map holds more than 16*6.5 = 104 entries.
if len(mapper) > (1<<w.BInMap)*loadFactorNum/loadFactorDen {
w.memTracker.Consume(defBucketMemoryUsage * (1 << w.BInMap))
w.BInMap++
}
}
return partialResults
}
@@ -541,10 +581,12 @@ func (w *HashAggFinalWorker) consumeIntermData(sctx sessionctx.Context) (err err
for reachEnd := false; !reachEnd; {
intermDataBuffer, groupKeys, reachEnd = input.getPartialResultBatch(sc, intermDataBuffer[:0], w.aggFuncs, w.maxChunkSize)
groupKeysLen := len(groupKeys)
memSize := getGroupKeyMemUsage(w.groupKeys)
w.groupKeys = w.groupKeys[:0]
for i := 0; i < groupKeysLen; i++ {
w.groupKeys = append(w.groupKeys, []byte(groupKeys[i]))
}
w.memTracker.Consume(getGroupKeyMemUsage(w.groupKeys) - memSize)
finalPartialResults := w.getPartialResult(sc, w.groupKeys, w.partialResultMap)
for i, groupKey := range groupKeys {
if !w.groupSet.Exist(groupKey) {
@@ -575,10 +617,12 @@ func (w *HashAggFinalWorker) getFinalResult(sctx sessionctx.Context) {
return
}
execStart := time.Now()
memSize := getGroupKeyMemUsage(w.groupKeys)
w.groupKeys = w.groupKeys[:0]
for groupKey := range w.groupSet {
w.groupKeys = append(w.groupKeys, []byte(groupKey))
}
w.memTracker.Consume(getGroupKeyMemUsage(w.groupKeys) - memSize)
partialResults := w.getPartialResult(sctx.GetSessionVars().StmtCtx, w.groupKeys, w.partialResultMap)
for i := 0; i < len(w.groupSet); i++ {
for j, af := range w.aggFuncs {
57 changes: 57 additions & 0 deletions executor/benchmark_test.go
@@ -21,6 +21,7 @@ import (
"math/rand"
"os"
"sort"
"strconv"
"strings"
"sync"
"testing"
@@ -29,6 +30,7 @@ import (
"github.com/pingcap/log"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/planner/core"
@@ -2016,3 +2018,58 @@ func BenchmarkReadLastLinesOfHugeLine(b *testing.B) {
}
}
}

func BenchmarkAggPartialResultMapperMemoryUsage(b *testing.B) {
b.ReportAllocs()
type testCase struct {
rowNum int
expectedB int
}
cases := []testCase{
{
rowNum: 0,
expectedB: 0,
},
{
rowNum: 100,
expectedB: 4,
},
{
rowNum: 10000,
expectedB: 11,
},
{
rowNum: 1000000,
expectedB: 18,
},
{
rowNum: 851968, // 6.5 * (1 << 17)
expectedB: 17,
},
{
rowNum: 851969, // 6.5 * (1 << 17) + 1
expectedB: 18,
},
{
rowNum: 425984, // 6.5 * (1 << 16)
expectedB: 16,
},
{
rowNum: 425985, // 6.5 * (1 << 16) + 1
expectedB: 17,
},
}

for _, c := range cases {
b.Run(fmt.Sprintf("MapRows %v", c.rowNum), func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
aggMap := make(aggPartialResultMapper)
tempSlice := make([]aggfuncs.PartialResult, 10)
for num := 0; num < c.rowNum; num++ {
aggMap[strconv.Itoa(num)] = tempSlice
}
}
})
}
}
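
This benchmark can be run in isolation with the standard go test flags, e.g. go test -run=NONE -bench=BenchmarkAggPartialResultMapperMemoryUsage -benchmem ./executor.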
134 changes: 134 additions & 0 deletions executor/executor_pkg_test.go
@@ -16,14 +16,18 @@ package executor
import (
"context"
"crypto/tls"
"runtime"
"strconv"
"time"
"unsafe"

. "github.com/pingcap/check"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/config"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
plannerutil "github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/sessionctx/variable"
@@ -446,3 +450,133 @@ func (s *pkgTestSuite) TestSlowQueryRuntimeStats(c *C) {
stats.Merge(stats.Clone())
c.Assert(stats.String(), Equals, "initialize: 2ms, read_file: 2s, parse_log: {time:200ms, concurrency:15}, total_file: 4, read_file: 4, read_size: 2 GB")
}

// Tests whether the actual number of buckets in the Go map matches the estimated number.
// The test relies on the Go map implementation; ref https://github.com/golang/go/blob/go1.13/src/runtime/map.go#L114
func (s *pkgTestSuite) TestAggPartialResultMapperB(c *C) {
if runtime.Version() < `go1.13` {
c.Skip("Unsupported version")
}
type testCase struct {
rowNum int
expectedB int
expectedGrowing bool
}
cases := []testCase{
{
rowNum: 0,
expectedB: 0,
expectedGrowing: false,
},
{
rowNum: 100,
expectedB: 4,
expectedGrowing: false,
},
{
rowNum: 10000,
expectedB: 11,
expectedGrowing: false,
},
{
rowNum: 1000000,
expectedB: 18,
expectedGrowing: false,
},
{
rowNum: 851968, // 6.5 * (1 << 17)
expectedB: 17,
expectedGrowing: false,
},
{
rowNum: 851969, // 6.5 * (1 << 17) + 1
expectedB: 18,
expectedGrowing: true,
},
{
rowNum: 425984, // 6.5 * (1 << 16)
expectedB: 16,
expectedGrowing: false,
},
{
rowNum: 425985, // 6.5 * (1 << 16) + 1
expectedB: 17,
expectedGrowing: true,
},
}

for _, tc := range cases {
aggMap := make(aggPartialResultMapper)
tempSlice := make([]aggfuncs.PartialResult, 10)
for num := 0; num < tc.rowNum; num++ {
aggMap[strconv.Itoa(num)] = tempSlice
}

c.Assert(getB(aggMap), Equals, tc.expectedB)
c.Assert(getGrowing(aggMap), Equals, tc.expectedGrowing)
}
}

// A header for a Go map.
type hmap struct {
// Note: the format of the hmap is also encoded in cmd/compile/internal/gc/reflect.go.
// Make sure this stays in sync with the compiler's definition.
count int // # live cells == size of map. Must be first (used by len() builtin)
flags uint8
B uint8 // log_2 of # of buckets (can hold up to loadFactor * 2^B items)
noverflow uint16 // approximate number of overflow buckets; see incrnoverflow for details
hash0 uint32 // hash seed

buckets unsafe.Pointer // array of 2^B Buckets. may be nil if count==0.
oldbuckets unsafe.Pointer // previous bucket array of half the size, non-nil only when growing
nevacuate uintptr // progress counter for evacuation (buckets less than this have been evacuated)

extra *mapextra // optional fields
}

// mapextra holds fields that are not present on all maps.
type mapextra struct {
// If both key and elem do not contain pointers and are inline, then we mark bucket
// type as containing no pointers. This avoids scanning such maps.
// However, bmap.overflow is a pointer. In order to keep overflow buckets
// alive, we store pointers to all overflow buckets in hmap.extra.overflow and hmap.extra.oldoverflow.
// overflow and oldoverflow are only used if key and elem do not contain pointers.
// overflow contains overflow buckets for hmap.buckets.
// oldoverflow contains overflow buckets for hmap.oldbuckets.
// The indirection allows to store a pointer to the slice in hiter.
overflow *[]*bmap
oldoverflow *[]*bmap

// nextOverflow holds a pointer to a free overflow bucket.
nextOverflow *bmap
}

const (
bucketCntBits = 3
bucketCnt = 1 << bucketCntBits
)

// A bucket for a Go map.
type bmap struct {
// tophash generally contains the top byte of the hash value
// for each key in this bucket. If tophash[0] < minTopHash,
// tophash[0] is a bucket evacuation state instead.
tophash [bucketCnt]uint8
// Followed by bucketCnt keys and then bucketCnt elems.
// NOTE: packing all the keys together and then all the elems together makes the
// code a bit more complicated than alternating key/elem/key/elem/... but it allows
// us to eliminate padding which would be needed for, e.g., map[int64]int8.
// Followed by an overflow pointer.
}

func getB(m aggPartialResultMapper) int {
point := (**hmap)(unsafe.Pointer(&m))
value := *point
return int(value.B)
}

func getGrowing(m aggPartialResultMapper) bool {
point := (**hmap)(unsafe.Pointer(&m))
value := *point
return value.oldbuckets != nil
}
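
As a usage sketch (hypothetical, valid only for the Go 1.13-era hmap layout mirrored above, and assuming a fmt import alongside those added in this file), the helpers can be exercised directly to watch a map grow:

// With B=4 the map holds at most 16*6.5 = 104 entries, so the 105th insertion
// should grow it to B=5 and leave oldbuckets non-nil until incremental
// evacuation completes.
func demoMapGrowth() {
	m := make(aggPartialResultMapper)
	for i := 0; i < 105; i++ {
		m[strconv.Itoa(i)] = nil
	}
	fmt.Println(getB(m), getGrowing(m)) // expected: 5 true
}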
