diff --git a/pkg/executor/sortexec/parallel_sort_spill_test.go b/pkg/executor/sortexec/parallel_sort_spill_test.go
index 26d2fc4cf9b81..44378dcf0613b 100644
--- a/pkg/executor/sortexec/parallel_sort_spill_test.go
+++ b/pkg/executor/sortexec/parallel_sort_spill_test.go
@@ -29,7 +29,7 @@ import (
 var hardLimit1 = int64(100000)
 var hardLimit2 = hardLimit1 * 10
 
-func oneSpillCase(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
+func oneSpillCase(t *testing.T, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
 	if exe == nil {
 		exe = buildSortExec(sortCase, dataSource)
 	}
@@ -60,7 +60,7 @@ func inMemoryThenSpill(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec,
 	require.True(t, checkCorrectness(schema, exe, dataSource, resultChunks))
 }
 
-func failpointNoMemoryDataTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
+func failpointNoMemoryDataTest(t *testing.T, exe *sortexec.SortExec, sortCase *testutil.SortCase, dataSource *testutil.MockDataSource) {
 	if exe == nil {
 		exe = buildSortExec(sortCase, dataSource)
 	}
@@ -68,7 +68,7 @@ func failpointNoMemoryDataTest(t *testing.T, ctx *mock.Context, exe *sortexec.So
 	executeInFailpoint(t, exe, 0, nil)
 }
 
-func failpointDataInMemoryThenSpillTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
+func failpointDataInMemoryThenSpillTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, dataSource *testutil.MockDataSource) {
 	if exe == nil {
 		exe = buildSortExec(sortCase, dataSource)
 	}
@@ -91,15 +91,13 @@ func TestParallelSortSpillDisk(t *testing.T) {
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit1)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
 
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = true
 	schema := expression.NewSchema(sortCase.Columns()...)
 	dataSource := buildDataSource(sortCase, schema)
 	exe := buildSortExec(sortCase, dataSource)
 	for i := 0; i < 10; i++ {
-		oneSpillCase(t, ctx, nil, sortCase, schema, dataSource)
-		oneSpillCase(t, ctx, exe, sortCase, schema, dataSource)
+		oneSpillCase(t, nil, sortCase, schema, dataSource)
+		oneSpillCase(t, exe, sortCase, schema, dataSource)
 	}
 
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
@@ -129,21 +127,19 @@ func TestParallelSortSpillDiskFailpoint(t *testing.T) {
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit1)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
 
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = true
 	schema := expression.NewSchema(sortCase.Columns()...)
 	dataSource := buildDataSource(sortCase, schema)
 	exe := buildSortExec(sortCase, dataSource)
 	for i := 0; i < 20; i++ {
-		failpointNoMemoryDataTest(t, ctx, nil, sortCase, schema, dataSource)
-		failpointNoMemoryDataTest(t, ctx, exe, sortCase, schema, dataSource)
+		failpointNoMemoryDataTest(t, nil, sortCase, dataSource)
+		failpointNoMemoryDataTest(t, exe, sortCase, dataSource)
 	}
 
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
 	for i := 0; i < 20; i++ {
-		failpointDataInMemoryThenSpillTest(t, ctx, nil, sortCase, schema, dataSource)
-		failpointDataInMemoryThenSpillTest(t, ctx, exe, sortCase, schema, dataSource)
+		failpointDataInMemoryThenSpillTest(t, ctx, nil, sortCase, dataSource)
+		failpointDataInMemoryThenSpillTest(t, ctx, exe, sortCase, dataSource)
 	}
 }
diff --git a/pkg/executor/sortexec/parallel_sort_test.go b/pkg/executor/sortexec/parallel_sort_test.go
index 096cd41026899..36d6858c28a45 100644
--- a/pkg/executor/sortexec/parallel_sort_test.go
+++ b/pkg/executor/sortexec/parallel_sort_test.go
@@ -36,8 +36,6 @@ func executeInFailpoint(t *testing.T, exe *sortexec.SortExec, hardLimit int64, t
 	tmpCtx := context.Background()
 	err := exe.Open(tmpCtx)
 	require.NoError(t, err)
-	exe.IsUnparallel = false
-	exe.InitInParallelModeForTest()
 
 	goRoutineWaiter := sync.WaitGroup{}
 	goRoutineWaiter.Add(1)
@@ -85,8 +83,6 @@ func parallelSortTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, s
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = true
 
 	if exe == nil {
 		exe = buildSortExec(sortCase, dataSource)
@@ -105,8 +101,6 @@ func failpointTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sort
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = true
 	if exe == nil {
 		exe = buildSortExec(sortCase, dataSource)
 	}
diff --git a/pkg/executor/sortexec/parallel_sort_worker.go b/pkg/executor/sortexec/parallel_sort_worker.go
index ecf43fd3b8500..082b935060ecc 100644
--- a/pkg/executor/sortexec/parallel_sort_worker.go
+++ b/pkg/executor/sortexec/parallel_sort_worker.go
@@ -79,6 +79,14 @@ func newParallelSortWorker(
 	}
 }
 
+func (p *parallelSortWorker) reset() {
+	p.batchRows = nil
+	p.localSortedRows = nil
+	p.sortedRowsIter = nil
+	p.merger = nil
+	p.memTracker.ReplaceBytesUsed(0)
+}
+
 func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor int32) {
 	injectParallelSortRandomFail(triggerFactor)
 	failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) {
diff --git a/pkg/executor/sortexec/sort.go b/pkg/executor/sortexec/sort.go
index 122e7ccb86a0e..3ba107d3ff042 100644
--- a/pkg/executor/sortexec/sort.go
+++ b/pkg/executor/sortexec/sort.go
@@ -20,6 +20,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/pingcap/errors"
 	"github.com/pingcap/failpoint"
 	"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/expression" @@ -55,6 +56,7 @@ type SortExec struct { memTracker *memory.Tracker diskTracker *disk.Tracker + // TODO delete this variable in the future and remove the unparallel sort IsUnparallel bool finishCh chan struct{} @@ -124,11 +126,9 @@ func (e *SortExec) Close() error { // will use `e.Parallel.workers` and `e.Parallel.merger`. channel.Clear(e.Parallel.resultChannel) for i := range e.Parallel.workers { - e.Parallel.workers[i].batchRows = nil - e.Parallel.workers[i].localSortedRows = nil - e.Parallel.workers[i].sortedRowsIter = nil - e.Parallel.workers[i].merger = nil - e.Parallel.workers[i].memTracker.ReplaceBytesUsed(0) + if e.Parallel.workers[i] != nil { + e.Parallel.workers[i].reset() + } } e.Parallel.merger = nil if e.Parallel.spillAction != nil { @@ -160,7 +160,7 @@ func (e *SortExec) Open(ctx context.Context) error { e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker) } - e.IsUnparallel = true + e.IsUnparallel = false if e.IsUnparallel { e.Unparallel.Idx = 0 e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0] @@ -185,24 +185,10 @@ func (e *SortExec) Open(ctx context.Context) error { return exec.Open(ctx, e.Children(0)) } -// InitInParallelModeForTest is a function for test -// After system variable is added, we can delete this function -func (e *SortExec) InitInParallelModeForTest() { - e.Parallel.workers = make([]*parallelSortWorker, e.Ctx().GetSessionVars().ExecutorConcurrency) - e.Parallel.chunkChannel = make(chan *chunkWithMemoryUsage, e.Ctx().GetSessionVars().ExecutorConcurrency) - e.Parallel.fetcherAndWorkerSyncer = &sync.WaitGroup{} - e.Parallel.sortedRowsIters = make([]*chunk.Iterator4Slice, len(e.Parallel.workers)) - e.Parallel.resultChannel = make(chan rowWithError, e.MaxChunkSize()) - e.Parallel.closeSync = make(chan struct{}) - e.Parallel.merger = newMultiWayMerger(&memorySource{sortedRowsIters: e.Parallel.sortedRowsIters}, e.lessRow) - e.Parallel.spillHelper = newParallelSortSpillHelper(e, exec.RetTypes(e), e.finishCh, e.lessRow, e.Parallel.resultChannel) - e.Parallel.spillAction = newParallelSortSpillDiskAction(e.Parallel.spillHelper) - for i := range e.Parallel.sortedRowsIters { - e.Parallel.sortedRowsIters[i] = chunk.NewIterator4Slice(nil) - } - if e.enableTmpStorageOnOOM { - e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.Parallel.spillAction) - } +// InitUnparallelModeForTest is for unit test +func (e *SortExec) InitUnparallelModeForTest() { + e.Unparallel.Idx = 0 + e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0] } // Next implements the Executor Next interface. 
@@ -272,9 +258,13 @@ func (e *SortExec) InitInParallelModeForTest() {
  */
 func (e *SortExec) Next(ctx context.Context, req *chunk.Chunk) error {
 	if e.fetched.CompareAndSwap(false, true) {
-		e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
+		err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
+		if err != nil {
+			return err
+		}
+
 		e.buildKeyColumns()
-		err := e.fetchChunks(ctx)
+		err = e.fetchChunks(ctx)
 		if err != nil {
 			return err
 		}
@@ -710,6 +700,14 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) {
 			e.Parallel.resultChannel <- rowWithError{err: err}
 		}
 
+		failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) {
+			if val.(bool) {
+				if e.Ctx().GetSessionVars().ConnectionID == 123456 {
+					e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded)
+				}
+			}
+		})
+
 		// We must place it after the spill as workers will process its received
 		// chunks after channel is closed and this will cause data race.
 		close(e.Parallel.chunkChannel)
@@ -753,12 +751,16 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) {
 	}
 }
 
-func (e *SortExec) initCompareFuncs(ctx expression.EvalContext) {
+func (e *SortExec) initCompareFuncs(ctx expression.EvalContext) error {
 	e.keyCmpFuncs = make([]chunk.CompareFunc, len(e.ByItems))
 	for i := range e.ByItems {
 		keyType := e.ByItems[i].Expr.GetType(ctx)
 		e.keyCmpFuncs[i] = chunk.GetCompareFunc(keyType)
+		if e.keyCmpFuncs[i] == nil {
+			return errors.Errorf("Sort executor not supports type %s", types.TypeStr(keyType.GetType()))
+		}
 	}
+	return nil
 }
 
 func (e *SortExec) buildKeyColumns() {
diff --git a/pkg/executor/sortexec/sort_spill_test.go b/pkg/executor/sortexec/sort_spill_test.go
index bff19e1a2ff9c..b72a777a444d7 100644
--- a/pkg/executor/sortexec/sort_spill_test.go
+++ b/pkg/executor/sortexec/sort_spill_test.go
@@ -177,9 +177,9 @@ func executeSortExecutor(t *testing.T, exe *sortexec.SortExec, isParallelSort bo
 	tmpCtx := context.Background()
 	err := exe.Open(tmpCtx)
 	require.NoError(t, err)
-	if isParallelSort {
-		exe.IsUnparallel = false
-		exe.InitInParallelModeForTest()
+	if !isParallelSort {
+		exe.IsUnparallel = true
+		exe.InitUnparallelModeForTest()
 	}
 
 	resultChunks := make([]*chunk.Chunk, 0)
@@ -199,9 +199,9 @@ func executeSortExecutorAndManullyTriggerSpill(t *testing.T, exe *sortexec.SortE
 	tmpCtx := context.Background()
 	err := exe.Open(tmpCtx)
 	require.NoError(t, err)
-	if isParallelSort {
-		exe.IsUnparallel = false
-		exe.InitInParallelModeForTest()
+	if !isParallelSort {
+		exe.IsUnparallel = true
+		exe.InitUnparallelModeForTest()
 	}
 
 	resultChunks := make([]*chunk.Chunk, 0)
@@ -239,8 +239,6 @@ func onePartitionAndAllDataInMemoryCase(t *testing.T, ctx *mock.Context, sortCas
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 1048576)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = false
 	schema := expression.NewSchema(sortCase.Columns()...)
 	dataSource := buildDataSource(sortCase, schema)
 	exe := buildSortExec(sortCase, dataSource)
@@ -262,8 +260,6 @@ func onePartitionAndAllDataInDiskCase(t *testing.T, ctx *mock.Context, sortCase 
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 50000)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = false
 	schema := expression.NewSchema(sortCase.Columns()...)
 	dataSource := buildDataSource(sortCase, schema)
 	exe := buildSortExec(sortCase, dataSource)
@@ -292,8 +288,6 @@ func multiPartitionCase(t *testing.T, ctx *mock.Context, sortCase *testutil.Sort
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 10000)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = false
 	schema := expression.NewSchema(sortCase.Columns()...)
 	dataSource := buildDataSource(sortCase, schema)
 	exe := buildSortExec(sortCase, dataSource)
@@ -333,8 +327,6 @@ func inMemoryThenSpillCase(t *testing.T, ctx *mock.Context, sortCase *testutil.S
 	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit)
 	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
 	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
-	// TODO use variable to choose parallel mode after system variable is added
-	// ctx.GetSessionVars().EnableParallelSort = false
 	schema := expression.NewSchema(sortCase.Columns()...)
 	dataSource := buildDataSource(sortCase, schema)
 	exe := buildSortExec(sortCase, dataSource)
diff --git a/pkg/executor/sortexec/topn.go b/pkg/executor/sortexec/topn.go
index e6b6104d936a3..f9c7ed779c51c 100644
--- a/pkg/executor/sortexec/topn.go
+++ b/pkg/executor/sortexec/topn.go
@@ -261,7 +261,11 @@ func (e *TopNExec) fetchChunks(ctx context.Context) error {
 }
 
 func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error {
-	e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
+	err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
+	if err != nil {
+		return err
+	}
+
 	e.buildKeyColumns()
 	e.chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, int(e.Limit.Offset), e.greaterRow, e.RetFieldTypes())
 	for uint64(e.chkHeap.rowChunks.Len()) < e.chkHeap.totalLimit {
diff --git a/pkg/util/chunk/compare.go b/pkg/util/chunk/compare.go
index b2f6f6bd0d0d4..86ab092f66f84 100644
--- a/pkg/util/chunk/compare.go
+++ b/pkg/util/chunk/compare.go
@@ -53,6 +53,8 @@ func GetCompareFunc(tp *types.FieldType) CompareFunc {
 		return cmpBit
 	case mysql.TypeJSON:
 		return cmpJSON
+	case mysql.TypeNull:
+		return cmpNullConst
 	}
 	return nil
 }
@@ -169,6 +171,10 @@ func cmpJSON(l Row, lCol int, r Row, rCol int) int {
 	return types.CompareBinaryJSON(lJ, rJ)
 }
 
+func cmpNullConst(_ Row, _ int, _ Row, _ int) int {
+	return 0
+}
+
 // Compare compares the value with ad.
 // We assume that the collation information of the column is the same with the datum.
 func Compare(row Row, colIdx int, ad *types.Datum) int {
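
Reviewer note: the sort.go and compare.go hunks are coupled. chunk.GetCompareFunc gains a mysql.TypeNull branch (all NULL constants compare equal), and initCompareFuncs now turns a nil comparator into an error at initialization instead of leaving a nil function to be invoked during sorting. The standalone Go sketch below only illustrates that pattern under assumed names (compareFuncFor, initCompare); it is not the TiDB code itself.

package main

import "fmt"

type compareFunc func(a, b any) int

// compareFuncFor mimics a lookup like chunk.GetCompareFunc: it returns nil
// for types it does not know how to order.
func compareFuncFor(typeName string) compareFunc {
	switch typeName {
	case "int":
		return func(a, b any) int {
			x, y := a.(int), b.(int)
			switch {
			case x < y:
				return -1
			case x > y:
				return 1
			}
			return 0
		}
	case "null":
		// A NULL-typed constant column: every value is NULL, so all rows
		// compare equal, mirroring the new cmpNullConst.
		return func(_, _ any) int { return 0 }
	}
	return nil
}

// initCompare mirrors the error-returning initCompareFuncs: a missing
// comparator is reported at initialization rather than failing later.
func initCompare(typeNames []string) ([]compareFunc, error) {
	funcs := make([]compareFunc, len(typeNames))
	for i, name := range typeNames {
		funcs[i] = compareFuncFor(name)
		if funcs[i] == nil {
			return nil, fmt.Errorf("sort does not support type %s", name)
		}
	}
	return funcs, nil
}

func main() {
	if _, err := initCompare([]string{"int", "set"}); err != nil {
		fmt.Println(err) // sort does not support type set
	}
}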
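
The new SignalCheckpointForSort failpoint in fetchChunksFromChild only fires when the failpoint is enabled and the session's ConnectionID is 123456. A test exercising it would likely toggle the failpoint roughly as below; the test name and the full failpoint path (built from the package import path, as is usual for failpoints in this repo) are assumptions, not taken from this diff.

package sortexec_test

import (
	"testing"

	"github.com/pingcap/failpoint"
	"github.com/stretchr/testify/require"
)

// TestSignalCheckpointForSort is a hypothetical test showing only the
// enable/disable plumbing around the new failpoint.
func TestSignalCheckpointForSort(t *testing.T) {
	// Assumed path: the package import path plus the failpoint name.
	fpName := "github.com/pingcap/tidb/pkg/executor/sortexec/SignalCheckpointForSort"
	require.NoError(t, failpoint.Enable(fpName, `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable(fpName))
	}()

	// Build a SortExec whose session has ConnectionID == 123456, drive it with
	// Next, and assert that the query is killed with QueryMemoryExceeded.
}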