diff --git a/executor/executor.go b/executor/executor.go index 706348f80ecb5..e03f6fe7e5697 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -28,6 +28,7 @@ import ( "github.com/cznic/mathutil" "github.com/opentracing/opentracing-go" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/auth" "github.com/pingcap/parser/model" @@ -1439,6 +1440,12 @@ type UnionExec struct { results []*chunk.Chunk wg sync.WaitGroup initialized bool + mu struct { + *sync.Mutex + maxOpenedChildID int + } + + childInFlightForTest int32 } // unionWorkerResult stores the result for a union worker. @@ -1458,12 +1465,11 @@ func (e *UnionExec) waitAllFinished() { // Open implements the Executor Open interface. func (e *UnionExec) Open(ctx context.Context) error { - if err := e.baseExecutor.Open(ctx); err != nil { - return err - } e.stopFetchData.Store(false) e.initialized = false e.finished = make(chan struct{}) + e.mu.Mutex = &sync.Mutex{} + e.mu.maxOpenedChildID = -1 return nil } @@ -1509,6 +1515,19 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.wg.Done() }() for childID := range e.childIDChan { + e.mu.Lock() + if childID > e.mu.maxOpenedChildID { + e.mu.maxOpenedChildID = childID + } + e.mu.Unlock() + if err := e.children[childID].Open(ctx); err != nil { + result.err = err + e.stopFetchData.Store(true) + e.resultPool <- result + } + failpoint.Inject("issue21441", func() { + atomic.AddInt32(&e.childInFlightForTest, 1) + }) for { if e.stopFetchData.Load().(bool) { return @@ -1523,12 +1542,20 @@ func (e *UnionExec) resultPuller(ctx context.Context, workerID int) { e.resourcePools[workerID] <- result.chk break } + failpoint.Inject("issue21441", func() { + if int(atomic.LoadInt32(&e.childInFlightForTest)) > e.concurrency { + panic("the count of child in flight is larger than e.concurrency unexpectedly") + } + }) e.resultPool <- result if result.err != nil { e.stopFetchData.Store(true) return } } + failpoint.Inject("issue21441", func() { + atomic.AddInt32(&e.childInFlightForTest, -1) + }) } } @@ -1567,7 +1594,15 @@ func (e *UnionExec) Close() error { for range e.childIDChan { } } - return e.baseExecutor.Close() + // We do not need to acquire the e.mu.Lock since all the resultPuller can be + // promised to exit when reaching here (e.childIDChan been closed). + var firstErr error + for i := 0; i <= e.mu.maxOpenedChildID; i++ { + if err := e.children[i].Close(); err != nil && firstErr == nil { + firstErr = err + } + } + return firstErr } // ResetContextOfStmt resets the StmtContext and session variables. diff --git a/executor/executor_test.go b/executor/executor_test.go index 2affecc1728c7..7f5055a1334ba 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -7067,6 +7067,36 @@ func (s *testSuite) TestOOMActionPriority(c *C) { c.Assert(action.GetPriority(), Equals, int64(memory.DefLogPriority)) } +func (s *testSerialSuite) TestIssue21441(c *C) { + failpoint.Enable("github.com/pingcap/tidb/executor/issue21441", `return`) + defer failpoint.Disable("github.com/pingcap/tidb/executor/issue21441") + + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int)") + tk.MustExec(`insert into t values(1),(2),(3)`) + tk.Se.GetSessionVars().InitChunkSize = 1 + tk.Se.GetSessionVars().MaxChunkSize = 1 + sql := ` +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t union all +select a from t` + tk.MustQuery(sql).Sort().Check(testkit.Rows( + "1", "1", "1", "1", "1", "1", "1", "1", + "2", "2", "2", "2", "2", "2", "2", "2", + "3", "3", "3", "3", "3", "3", "3", "3", + )) + + tk.MustQuery("select a from (" + sql + ") t order by a limit 4").Check(testkit.Rows("1", "1", "1", "1")) + tk.MustQuery("select a from (" + sql + ") t order by a limit 7, 4").Check(testkit.Rows("1", "2", "2", "2")) +} + func (s *testSuite) Test17780(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go index 05ddf49c033c9..c99a771ec64c5 100644 --- a/store/tikv/snapshot.go +++ b/store/tikv/snapshot.go @@ -502,9 +502,13 @@ func (s *tikvSnapshot) SetOption(opt kv.Option, val interface{}) { case kv.SnapshotTS: s.setSnapshotTS(val.(uint64)) case kv.ReplicaRead: + s.mu.Lock() s.replicaRead = val.(kv.ReplicaReadType) + s.mu.Unlock() case kv.TaskID: + s.mu.Lock() s.taskID = val.(uint64) + s.mu.Unlock() case kv.CollectRuntimeStats: s.mu.Lock() s.mu.stats = val.(*SnapshotRuntimeStats)