diff --git a/executor/builder.go b/executor/builder.go index 4ed4988dd7c92..285c308b87987 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -4790,6 +4790,9 @@ func buildKvRangesForIndexJoin(ctx sessionctx.Context, tableID, indexID int64, l } } if len(kvRanges) != 0 && memTracker != nil { + failpoint.Inject("testIssue49033", func() { + panic("testIssue49033") + }) memTracker.Consume(int64(2 * cap(kvRanges[0].StartKey) * len(kvRanges))) } if len(tmpDatumRanges) != 0 && memTracker != nil { diff --git a/executor/index_lookup_hash_join.go b/executor/index_lookup_hash_join.go index 2224dcb2a1583..7f4ee90582735 100644 --- a/executor/index_lookup_hash_join.go +++ b/executor/index_lookup_hash_join.go @@ -74,6 +74,10 @@ type IndexNestedLoopHashJoin struct { stats *indexLookUpJoinRuntimeStats prepared bool + // panicErr records the error generated by panic recover. This is introduced to + // return the actual error message instead of `context cancelled` to the client. + panicErr error + ctxWithCancel context.Context } type indexHashJoinOuterWorker struct { @@ -149,7 +153,7 @@ func (e *IndexNestedLoopHashJoin) startWorkers(ctx context.Context) { e.stats.concurrency = concurrency } workerCtx, cancelFunc := context.WithCancel(ctx) - e.cancelFunc = cancelFunc + e.ctxWithCancel, e.cancelFunc = workerCtx, cancelFunc innerCh := make(chan *indexHashJoinTask, concurrency) if e.keepOuterOrder { e.taskCh = make(chan *indexHashJoinTask, concurrency) @@ -162,7 +166,7 @@ func (e *IndexNestedLoopHashJoin) startWorkers(ctx context.Context) { e.joinChkResourceCh = make([]chan *chunk.Chunk, concurrency) e.workerWg.Add(1) ow := e.newOuterWorker(innerCh) - go util.WithRecovery(func() { ow.run(workerCtx) }, e.finishJoinWorkers) + go util.WithRecovery(func() { ow.run(e.ctxWithCancel) }, e.finishJoinWorkers) for i := 0; i < concurrency; i++ { if !e.keepOuterOrder { @@ -179,7 +183,7 @@ func (e *IndexNestedLoopHashJoin) startWorkers(ctx context.Context) { e.workerWg.Add(concurrency) for i := 0; i < concurrency; i++ { workerID := i - go util.WithRecovery(func() { e.newInnerWorker(innerCh, workerID).run(workerCtx, cancelFunc) }, e.finishJoinWorkers) + go util.WithRecovery(func() { e.newInnerWorker(innerCh, workerID).run(e.ctxWithCancel, cancelFunc) }, e.finishJoinWorkers) } go e.wait4JoinWorkers() } @@ -194,6 +198,7 @@ func (e *IndexNestedLoopHashJoin) finishJoinWorkers(r interface{}) { task := &indexHashJoinTask{err: err} e.taskCh <- task } + e.panicErr = err if e.cancelFunc != nil { e.cancelFunc() } @@ -219,59 +224,39 @@ func (e *IndexNestedLoopHashJoin) Next(ctx context.Context, req *chunk.Chunk) er } req.Reset() if e.keepOuterOrder { - return e.runInOrder(ctx, req) + return e.runInOrder(e.ctxWithCancel, req) } - // unordered run - var ( - result *indexHashJoinResult - ok bool - ) - select { - case result, ok = <-e.resultCh: - if !ok { - return nil - } - if result.err != nil { - return result.err - } - case <-ctx.Done(): - return ctx.Err() - } - req.SwapColumns(result.chk) - result.src <- result.chk - return nil + return e.runUnordered(e.ctxWithCancel, req) } func (e *IndexNestedLoopHashJoin) runInOrder(ctx context.Context, req *chunk.Chunk) error { - var ( - result *indexHashJoinResult - ok bool - ) for { if e.isDryUpTasks(ctx) { - return nil + return e.panicErr } if e.curTask.err != nil { return e.curTask.err } - select { - case result, ok = <-e.curTask.resultCh: - if !ok { - e.curTask = nil - continue - } - if result.err != nil { - return result.err - } - case <-ctx.Done(): - return ctx.Err() + result, err := e.getResultFromChannel(ctx, e.curTask.resultCh) + if err != nil { + return err } - req.SwapColumns(result.chk) - result.src <- result.chk - return nil + if result == nil { + e.curTask = nil + continue + } + return e.handleResult(req, result) } } +func (e *IndexNestedLoopHashJoin) runUnordered(ctx context.Context, req *chunk.Chunk) error { + result, err := e.getResultFromChannel(ctx, e.resultCh) + if err != nil { + return err + } + return e.handleResult(req, result) +} + // isDryUpTasks indicates whether all the tasks have been processed. func (e *IndexNestedLoopHashJoin) isDryUpTasks(ctx context.Context) bool { if e.curTask != nil { @@ -289,6 +274,38 @@ func (e *IndexNestedLoopHashJoin) isDryUpTasks(ctx context.Context) bool { return false } +func (e *IndexNestedLoopHashJoin) getResultFromChannel(ctx context.Context, resultCh <-chan *indexHashJoinResult) (*indexHashJoinResult, error) { + var ( + result *indexHashJoinResult + ok bool + ) + select { + case result, ok = <-resultCh: + if !ok { + return nil, nil + } + if result.err != nil { + return nil, result.err + } + case <-ctx.Done(): + err := e.panicErr + if err == nil { + err = ctx.Err() + } + return nil, err + } + return result, nil +} + +func (*IndexNestedLoopHashJoin) handleResult(req *chunk.Chunk, result *indexHashJoinResult) error { + if result == nil { + return nil + } + req.SwapColumns(result.chk) + result.src <- result.chk + return nil +} + // Close implements the IndexNestedLoopHashJoin Executor interface. func (e *IndexNestedLoopHashJoin) Close() error { if e.stats != nil { @@ -311,6 +328,7 @@ func (e *IndexNestedLoopHashJoin) Close() error { e.joinChkResourceCh = nil e.finished.Store(false) e.prepared = false + e.ctxWithCancel = nil return e.baseExecutor.Close() } diff --git a/executor/join_test.go b/executor/join_test.go index e6523d020b769..e2113e256bd07 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -15,10 +15,14 @@ package executor_test import ( + "context" "testing" + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/testkit/testdata" + "github.com/stretchr/testify/require" ) func TestNaturalJoin(t *testing.T) { @@ -119,3 +123,36 @@ func TestIssue48991(t *testing.T) { res := tk.MustQuery("SELECT `col_14` FROM `test`.`tbl_3` WHERE ((`tbl_3`.`col_15` < 'dV') AND `tbl_3`.`col_12` IN (SELECT `col_12` FROM `test`.`tbl_3` WHERE NOT (ISNULL(`tbl_3`.`col_12`)))) ORDER BY IF(ISNULL(`col_14`),0,1),`col_14`;") res.Check(testkit.Rows("1984-06-10 00:00:00", "1984-07-31 00:00:00", "2017-06-07 00:00:00")) } + +func TestIssue49033(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test;") + tk.MustExec("drop table if exists t, s;") + tk.MustExec("create table t(a int, index(a));") + tk.MustExec("create table s(a int, index(a));") + tk.MustExec("insert into t values(1), (2), (3), (4), (5), (6), (7), (8), (9), (10), (11), (12), (13), (14), (15), (16), (17), (18), (19), (20), (21), (22), (23), (24), (25), (26), (27), (28), (29), (30), (31), (32), (33), (34), (35), (36), (37), (38), (39), (40), (41), (42), (43), (44), (45), (46), (47), (48), (49), (50), (51), (52), (53), (54), (55), (56), (57), (58), (59), (60), (61), (62), (63), (64), (65), (66), (67), (68), (69), (70), (71), (72), (73), (74), (75), (76), (77), (78), (79), (80), (81), (82), (83), (84), (85), (86), (87), (88), (89), (90), (91), (92), (93), (94), (95), (96), (97), (98), (99), (100), (101), (102), (103), (104), (105), (106), (107), (108), (109), (110), (111), (112), (113), (114), (115), (116), (117), (118), (119), (120), (121), (122), (123), (124), (125), (126), (127), (128);") + tk.MustExec("insert into s values(1), (128);") + tk.MustExec("set @@tidb_max_chunk_size=32;") + tk.MustExec("set @@tidb_index_lookup_join_concurrency=1;") + tk.MustExec("set @@tidb_index_join_batch_size=32;") + tk.MustQuery("select /*+ INL_HASH_JOIN(s) */ * from t join s on t.a=s.a;") + tk.MustQuery("select /*+ INL_HASH_JOIN(s) */ * from t join s on t.a=s.a order by t.a;") + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/executor/testIssue49033", "return")) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/executor/testIssue49033")) + }() + + rs, err := tk.Exec("select /*+ INL_HASH_JOIN(s) */ * from t join s on t.a=s.a order by t.a;") + require.NoError(t, err) + _, err = session.GetRows4Test(context.Background(), nil, rs) + require.EqualError(t, err, "testIssue49033") + require.NoError(t, rs.Close()) + + rs, err = tk.Exec("select /*+ INL_HASH_JOIN(s) */ * from t join s on t.a=s.a") + require.NoError(t, err) + _, err = session.GetRows4Test(context.Background(), nil, rs) + require.EqualError(t, err, "testIssue49033") + require.NoError(t, rs.Close()) +}