Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

executor: enable parallel sort #53537

Merged
merged 13 commits into from
Jun 6, 2024
22 changes: 9 additions & 13 deletions pkg/executor/sortexec/parallel_sort_spill_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
var hardLimit1 = int64(100000)
var hardLimit2 = hardLimit1 * 10

func oneSpillCase(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
func oneSpillCase(t *testing.T, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
if exe == nil {
exe = buildSortExec(sortCase, dataSource)
}
Expand Down Expand Up @@ -60,15 +60,15 @@ func inMemoryThenSpill(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec,
require.True(t, checkCorrectness(schema, exe, dataSource, resultChunks))
}

func failpointNoMemoryDataTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
// failpointNoMemoryDataTest runs the sort executor through executeInFailpoint
// with a hard limit of 0. When the caller passes a nil exe, a fresh executor is
// built from sortCase and dataSource; the data source's chunks are prepared
// before every run.
func failpointNoMemoryDataTest(t *testing.T, exe *sortexec.SortExec, sortCase *testutil.SortCase, dataSource *testutil.MockDataSource) {
	sortExec := exe
	if sortExec == nil {
		// No executor supplied: build one for this single run.
		sortExec = buildSortExec(sortCase, dataSource)
	}

	dataSource.PrepareChunks()
	executeInFailpoint(t, sortExec, 0, nil)
}

func failpointDataInMemoryThenSpillTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
func failpointDataInMemoryThenSpillTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, dataSource *testutil.MockDataSource) {
if exe == nil {
exe = buildSortExec(sortCase, dataSource)
}
Expand All @@ -91,15 +91,13 @@ func TestParallelSortSpillDisk(t *testing.T) {
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true

schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
for i := 0; i < 10; i++ {
oneSpillCase(t, ctx, nil, sortCase, schema, dataSource)
oneSpillCase(t, ctx, exe, sortCase, schema, dataSource)
oneSpillCase(t, nil, sortCase, schema, dataSource)
oneSpillCase(t, exe, sortCase, schema, dataSource)
}

ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
Expand Down Expand Up @@ -129,21 +127,19 @@ func TestParallelSortSpillDiskFailpoint(t *testing.T) {
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true

schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
for i := 0; i < 20; i++ {
failpointNoMemoryDataTest(t, ctx, nil, sortCase, schema, dataSource)
failpointNoMemoryDataTest(t, ctx, exe, sortCase, schema, dataSource)
failpointNoMemoryDataTest(t, nil, sortCase, dataSource)
failpointNoMemoryDataTest(t, exe, sortCase, dataSource)
}

ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
for i := 0; i < 20; i++ {
failpointDataInMemoryThenSpillTest(t, ctx, nil, sortCase, schema, dataSource)
failpointDataInMemoryThenSpillTest(t, ctx, exe, sortCase, schema, dataSource)
failpointDataInMemoryThenSpillTest(t, ctx, nil, sortCase, dataSource)
failpointDataInMemoryThenSpillTest(t, ctx, exe, sortCase, dataSource)
}
}
6 changes: 0 additions & 6 deletions pkg/executor/sortexec/parallel_sort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ func executeInFailpoint(t *testing.T, exe *sortexec.SortExec, hardLimit int64, t
tmpCtx := context.Background()
err := exe.Open(tmpCtx)
require.NoError(t, err)
exe.IsUnparallel = false
exe.InitInParallelModeForTest()

goRoutineWaiter := sync.WaitGroup{}
goRoutineWaiter.Add(1)
Expand Down Expand Up @@ -85,8 +83,6 @@ func parallelSortTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, s
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true

if exe == nil {
exe = buildSortExec(sortCase, dataSource)
Expand All @@ -105,8 +101,6 @@ func failpointTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sort
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true
if exe == nil {
exe = buildSortExec(sortCase, dataSource)
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/executor/sortexec/parallel_sort_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ func newParallelSortWorker(
}
}

// reset clears the worker's per-run state: batchRows, localSortedRows,
// sortedRowsIter and merger are set to nil, and ReplaceBytesUsed(0) is then
// called on the worker's memory tracker.
func (p *parallelSortWorker) reset() {
	// Drop the row buffers first, then the iterator/merger built on top of them.
	p.batchRows, p.localSortedRows = nil, nil
	p.sortedRowsIter, p.merger = nil, nil

	p.memTracker.ReplaceBytesUsed(0)
}

func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor int32) {
injectParallelSortRandomFail(triggerFactor)
failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) {
Expand Down
56 changes: 29 additions & 27 deletions pkg/executor/sortexec/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"sync/atomic"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/expression"
Expand Down Expand Up @@ -55,6 +56,7 @@ type SortExec struct {
memTracker *memory.Tracker
diskTracker *disk.Tracker

// TODO delete this variable in the future and remove the unparallel sort
IsUnparallel bool

finishCh chan struct{}
Expand Down Expand Up @@ -124,11 +126,9 @@ func (e *SortExec) Close() error {
// will use `e.Parallel.workers` and `e.Parallel.merger`.
channel.Clear(e.Parallel.resultChannel)
for i := range e.Parallel.workers {
e.Parallel.workers[i].batchRows = nil
e.Parallel.workers[i].localSortedRows = nil
e.Parallel.workers[i].sortedRowsIter = nil
e.Parallel.workers[i].merger = nil
e.Parallel.workers[i].memTracker.ReplaceBytesUsed(0)
if e.Parallel.workers[i] != nil {
e.Parallel.workers[i].reset()
}
}
e.Parallel.merger = nil
if e.Parallel.spillAction != nil {
Expand Down Expand Up @@ -160,7 +160,7 @@ func (e *SortExec) Open(ctx context.Context) error {
e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker)
}

e.IsUnparallel = true
e.IsUnparallel = false
windtalker marked this conversation as resolved.
Show resolved Hide resolved
if e.IsUnparallel {
e.Unparallel.Idx = 0
e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0]
Expand All @@ -185,24 +185,10 @@ func (e *SortExec) Open(ctx context.Context) error {
return exec.Open(ctx, e.Children(0))
}

// InitInParallelModeForTest is a function for test
// After system variable is added, we can delete this function
func (e *SortExec) InitInParallelModeForTest() {
e.Parallel.workers = make([]*parallelSortWorker, e.Ctx().GetSessionVars().ExecutorConcurrency)
e.Parallel.chunkChannel = make(chan *chunkWithMemoryUsage, e.Ctx().GetSessionVars().ExecutorConcurrency)
e.Parallel.fetcherAndWorkerSyncer = &sync.WaitGroup{}
e.Parallel.sortedRowsIters = make([]*chunk.Iterator4Slice, len(e.Parallel.workers))
e.Parallel.resultChannel = make(chan rowWithError, e.MaxChunkSize())
e.Parallel.closeSync = make(chan struct{})
e.Parallel.merger = newMultiWayMerger(&memorySource{sortedRowsIters: e.Parallel.sortedRowsIters}, e.lessRow)
e.Parallel.spillHelper = newParallelSortSpillHelper(e, exec.RetTypes(e), e.finishCh, e.lessRow, e.Parallel.resultChannel)
e.Parallel.spillAction = newParallelSortSpillDiskAction(e.Parallel.spillHelper)
for i := range e.Parallel.sortedRowsIters {
e.Parallel.sortedRowsIters[i] = chunk.NewIterator4Slice(nil)
}
if e.enableTmpStorageOnOOM {
e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.Parallel.spillAction)
}
// InitUnparallelModeForTest is a helper for unit tests. It truncates
// Unparallel.sortPartitions to length zero (re-slicing, so the existing
// backing array is kept) and sets Unparallel.Idx back to 0.
func (e *SortExec) InitUnparallelModeForTest() {
	e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0]
	e.Unparallel.Idx = 0
}

// Next implements the Executor Next interface.
Expand Down Expand Up @@ -272,9 +258,13 @@ func (e *SortExec) InitInParallelModeForTest() {
*/
func (e *SortExec) Next(ctx context.Context, req *chunk.Chunk) error {
if e.fetched.CompareAndSwap(false, true) {
e.initCompareFuncs()
err := e.initCompareFuncs()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should remember this error in SortExec, so if initCompareFuncs fails, every call of Next will return error, otherwise, the second call of Next will go to L266 directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should remember this error in SortExec, so if initCompareFuncs fails, every call of Next will return error, otherwise, the second call of Next will go to L266 directly?

Next is called by only one goroutine, and once it returns an error, Next will not be called again. I think it's unnecessary to remember the error in SortExec.

if err != nil {
return err
}

e.buildKeyColumns()
err := e.fetchChunks(ctx)
err = e.fetchChunks(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -710,6 +700,14 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) {
e.Parallel.resultChannel <- rowWithError{err: err}
}

failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) {
if val.(bool) {
if e.Ctx().GetSessionVars().ConnectionID == 123456 {
e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded)
}
}
})

// We must place it after the spill as workers will process its received
// chunks after channel is closed and this will cause data race.
close(e.Parallel.chunkChannel)
Expand Down Expand Up @@ -753,12 +751,16 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) {
}
}

func (e *SortExec) initCompareFuncs() {
// initCompareFuncs rebuilds e.keyCmpFuncs with one chunk.CompareFunc per entry
// of e.ByItems, selected by the type of that entry's expression.
//
// It returns an error as soon as chunk.GetCompareFunc yields nil for a key
// type; the nil entry has already been stored in e.keyCmpFuncs at that point
// and later entries are left unset. It returns nil when every key has a
// compare function.
func (e *SortExec) initCompareFuncs() error {
	e.keyCmpFuncs = make([]chunk.CompareFunc, len(e.ByItems))
	for idx, byItem := range e.ByItems {
		keyType := byItem.Expr.GetType()
		cmpFunc := chunk.GetCompareFunc(keyType)
		e.keyCmpFuncs[idx] = cmpFunc
		if cmpFunc != nil {
			continue
		}
		return errors.Errorf("Sort executor not supports type %s", types.TypeStr(keyType.GetType()))
	}
	return nil
}

func (e *SortExec) buildKeyColumns() {
Expand Down
20 changes: 6 additions & 14 deletions pkg/executor/sortexec/sort_spill_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ func executeSortExecutor(t *testing.T, exe *sortexec.SortExec, isParallelSort bo
tmpCtx := context.Background()
err := exe.Open(tmpCtx)
require.NoError(t, err)
if isParallelSort {
exe.IsUnparallel = false
exe.InitInParallelModeForTest()
if !isParallelSort {
exe.IsUnparallel = true
exe.InitUnparallelModeForTest()
}

resultChunks := make([]*chunk.Chunk, 0)
Expand All @@ -196,9 +196,9 @@ func executeSortExecutorAndManullyTriggerSpill(t *testing.T, exe *sortexec.SortE
tmpCtx := context.Background()
err := exe.Open(tmpCtx)
require.NoError(t, err)
if isParallelSort {
exe.IsUnparallel = false
exe.InitInParallelModeForTest()
if !isParallelSort {
exe.IsUnparallel = true
exe.InitUnparallelModeForTest()
}

resultChunks := make([]*chunk.Chunk, 0)
Expand Down Expand Up @@ -236,8 +236,6 @@ func onePartitionAndAllDataInMemoryCase(t *testing.T, ctx *mock.Context, sortCas
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 1048576)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand All @@ -259,8 +257,6 @@ func onePartitionAndAllDataInDiskCase(t *testing.T, ctx *mock.Context, sortCase
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 50000)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand Down Expand Up @@ -289,8 +285,6 @@ func multiPartitionCase(t *testing.T, ctx *mock.Context, sortCase *testutil.Sort
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 10000)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand Down Expand Up @@ -330,8 +324,6 @@ func inMemoryThenSpillCase(t *testing.T, ctx *mock.Context, sortCase *testutil.S
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand Down
6 changes: 5 additions & 1 deletion pkg/executor/sortexec/topn.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,11 @@ func (e *TopNExec) fetchChunks(ctx context.Context) error {
}

func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error {
e.initCompareFuncs()
err := e.initCompareFuncs()
if err != nil {
return err
}

e.buildKeyColumns()
e.chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, int(e.Limit.Offset), e.greaterRow, e.RetFieldTypes())
for uint64(e.chkHeap.rowChunks.Len()) < e.chkHeap.totalLimit {
Expand Down
6 changes: 6 additions & 0 deletions pkg/util/chunk/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func GetCompareFunc(tp *types.FieldType) CompareFunc {
return cmpBit
case mysql.TypeJSON:
return cmpJSON
case mysql.TypeNull:
return cmpNullConst
}
return nil
}
Expand Down Expand Up @@ -169,6 +171,10 @@ func cmpJSON(l Row, lCol int, r Row, rCol int) int {
return types.CompareBinaryJSON(lJ, rJ)
}

// cmpNullConst is the compare function GetCompareFunc returns for
// mysql.TypeNull. It ignores both rows and both column indexes and always
// reports the two values as equal.
func cmpNullConst(_ Row, _ int, _ Row, _ int) int {
	const equal = 0
	return equal
}

// Compare compares the value with ad.
// We assume that the collation information of the column is the same with the datum.
func Compare(row Row, colIdx int, ad *types.Datum) int {
Expand Down