Skip to content

Commit

Permalink
*: fix wrong result when to concurrency merge global stats (#48852)
Browse files Browse the repository at this point in the history
close #48713
  • Loading branch information
hawkingrei authored Nov 24, 2023
1 parent 6ca9813 commit 26db590
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 54 deletions.
59 changes: 26 additions & 33 deletions pkg/statistics/handle/globalstats/merge_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ type topnStatsMergeWorker struct {
respCh chan<- *TopnStatsMergeResponse
// the stats in the wrapper should only be read during the worker
statsWrapper *StatsWrapper
// Different TopN structures may hold the same value, we have to merge them.
counter map[hack.MutableString]float64
// shardMutex is used to protect `statsWrapper.AllHg`
shardMutex []sync.Mutex
mu sync.Mutex
}

// NewTopnStatsMergeWorker returns topn merge worker
Expand All @@ -54,8 +57,9 @@ func NewTopnStatsMergeWorker(
wrapper *StatsWrapper,
killer *sqlkiller.SQLKiller) *topnStatsMergeWorker {
worker := &topnStatsMergeWorker{
taskCh: taskCh,
respCh: respCh,
taskCh: taskCh,
respCh: respCh,
counter: make(map[hack.MutableString]float64),
}
worker.statsWrapper = wrapper
worker.shardMutex = make([]sync.Mutex, len(wrapper.AllHg))
Expand All @@ -79,33 +83,24 @@ func NewTopnStatsMergeTask(start, end int) *TopnStatsMergeTask {

// TopnStatsMergeResponse indicates topn merge worker response
type TopnStatsMergeResponse struct {
Err error
TopN *statistics.TopN
PopedTopn []statistics.TopNMeta
Err error
}

// Run runs topn merge like statistics.MergePartTopN2GlobalTopN
func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool,
n uint32,
version int) {
func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool, version int) {
for task := range worker.taskCh {
start := task.start
end := task.end
checkTopNs := worker.statsWrapper.AllTopN[start:end]
allTopNs := worker.statsWrapper.AllTopN
allHists := worker.statsWrapper.AllHg
resp := &TopnStatsMergeResponse{}
if statistics.CheckEmptyTopNs(checkTopNs) {
worker.respCh <- resp
return
}

partNum := len(allTopNs)
// Different TopN structures may hold the same value, we have to merge them.
counter := make(map[hack.MutableString]float64)

// datumMap is used to store the mapping from the string type to datum type.
// The datum is used to find the value in the histogram.
datumMap := statistics.NewDatumMapCache()

for i, topN := range checkTopNs {
i = i + start
if err := worker.killer.HandleSignal(); err != nil {
Expand All @@ -118,12 +113,15 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool,
}
for _, val := range topN.TopN {
encodedVal := hack.String(val.Encoded)
_, exists := counter[encodedVal]
counter[encodedVal] += float64(val.Count)
worker.mu.Lock()
_, exists := worker.counter[encodedVal]
worker.counter[encodedVal] += float64(val.Count)
if exists {
worker.mu.Unlock()
// We have already calculated the encodedVal from the histogram, so just continue to next topN value.
continue
}
worker.mu.Unlock()
// We need to check whether the value corresponding to encodedVal is contained in other partition-level stats.
// 1. Check the topN first.
// 2. If the topN doesn't contain the value corresponding to encodedVal. We should check the histogram.
Expand All @@ -147,31 +145,26 @@ func (worker *topnStatsMergeWorker) Run(timeZone *time.Location, isIndex bool,
}
datum = d
}
worker.shardMutex[j].Lock()
// Get the row count which the value is equal to the encodedVal from histogram.
count, _ := allHists[j].EqualRowCount(nil, datum, isIndex)
if count != 0 {
counter[encodedVal] += count
// Remove the value corresponding to encodedVal from the histogram.
worker.shardMutex[j].Lock()
worker.statsWrapper.AllHg[j].BinarySearchRemoveVal(statistics.TopNMeta{Encoded: datum.GetBytes(), Count: uint64(count)})
worker.shardMutex[j].Unlock()
}
worker.shardMutex[j].Unlock()
if count != 0 {
worker.mu.Lock()
worker.counter[encodedVal] += count
worker.mu.Unlock()
}
}
}
}
numTop := len(counter)
if numTop == 0 {
worker.respCh <- resp
continue
}
sorted := make([]statistics.TopNMeta, 0, numTop)
for value, cnt := range counter {
data := hack.Slice(string(value))
sorted = append(sorted, statistics.TopNMeta{Encoded: data, Count: uint64(cnt)})
}
globalTopN, leftTopN := statistics.GetMergedTopNFromSortedSlice(sorted, n)
resp.TopN = globalTopN
resp.PopedTopn = leftTopN
worker.respCh <- resp
}
}

func (worker *topnStatsMergeWorker) Result() map[hack.MutableString]float64 {
return worker.counter
}
34 changes: 13 additions & 21 deletions pkg/statistics/handle/globalstats/topn.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@ import (
func mergeGlobalStatsTopN(gp *gp.Pool, sc sessionctx.Context, wrapper *StatsWrapper,
timeZone *time.Location, version int, n uint32, isIndex bool) (*statistics.TopN,
[]statistics.TopNMeta, []*statistics.Histogram, error) {
if statistics.CheckEmptyTopNs(wrapper.AllTopN) {
return nil, nil, wrapper.AllHg, nil
}
mergeConcurrency := sc.GetSessionVars().AnalyzePartitionMergeConcurrency
killer := &sc.GetSessionVars().SQLKiller

// use original method if concurrency equals 1 or for version1
if mergeConcurrency < 2 {
return MergePartTopN2GlobalTopN(timeZone, version, wrapper.AllTopN, n, wrapper.AllHg, isIndex, killer)
Expand Down Expand Up @@ -78,12 +82,12 @@ func MergeGlobalStatsTopNByConcurrency(
taskNum := len(tasks)
taskCh := make(chan *TopnStatsMergeTask, taskNum)
respCh := make(chan *TopnStatsMergeResponse, taskNum)
worker := NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killer)
for i := 0; i < mergeConcurrency; i++ {
worker := NewTopnStatsMergeWorker(taskCh, respCh, wrapper, killer)
wg.Add(1)
gp.Go(func() {
defer wg.Done()
worker.Run(timeZone, isIndex, n, version)
worker.Run(timeZone, isIndex, version)
})
}
for _, task := range tasks {
Expand All @@ -92,8 +96,6 @@ func MergeGlobalStatsTopNByConcurrency(
close(taskCh)
wg.Wait()
close(respCh)
resps := make([]*TopnStatsMergeResponse, 0)

// handle Error
hasErr := false
errMsg := make([]string, 0)
Expand All @@ -102,27 +104,21 @@ func MergeGlobalStatsTopNByConcurrency(
hasErr = true
errMsg = append(errMsg, resp.Err.Error())
}
resps = append(resps, resp)
}
if hasErr {
return nil, nil, nil, errors.New(strings.Join(errMsg, ","))
}

// fetch the response from each worker and merge them into global topn stats
sorted := make([]statistics.TopNMeta, 0, mergeConcurrency)
leftTopn := make([]statistics.TopNMeta, 0)
for _, resp := range resps {
if resp.TopN != nil {
sorted = append(sorted, resp.TopN.TopN...)
}
leftTopn = append(leftTopn, resp.PopedTopn...)
counter := worker.Result()
numTop := len(counter)
sorted := make([]statistics.TopNMeta, 0, numTop)
for value, cnt := range counter {
data := hack.Slice(string(value))
sorted = append(sorted, statistics.TopNMeta{Encoded: data, Count: uint64(cnt)})
}

globalTopN, popedTopn := statistics.GetMergedTopNFromSortedSlice(sorted, n)

result := append(leftTopn, popedTopn...)
statistics.SortTopnMeta(result)
return globalTopN, result, wrapper.AllHg, nil
return globalTopN, popedTopn, wrapper.AllHg, nil
}

// MergePartTopN2GlobalTopN is used to merge the partition-level topN to global-level topN.
Expand All @@ -149,10 +145,6 @@ func MergePartTopN2GlobalTopN(
isIndex bool,
killer *sqlkiller.SQLKiller,
) (*statistics.TopN, []statistics.TopNMeta, []*statistics.Histogram, error) {
if statistics.CheckEmptyTopNs(topNs) {
return nil, nil, hists, nil
}

partNum := len(topNs)
// Different TopN structures may hold the same value, we have to merge them.
counter := make(map[hack.MutableString]float64)
Expand Down

0 comments on commit 26db590

Please sign in to comment.