diff --git a/session/session.go b/session/session.go index 1c307d16c9fc9..f6c5adf5d7ca0 100644 --- a/session/session.go +++ b/session/session.go @@ -3647,8 +3647,10 @@ func createSessionWithOpt(store kv.Storage, opt *Opt) (*session, error) { // attachStatsCollector attaches the stats collector in the dom for the session func attachStatsCollector(s *session, dom *domain.Domain) *session { if dom.StatsHandle() != nil && dom.StatsUpdating() { - s.statsCollector = dom.StatsHandle().NewSessionStatsCollector() - if GetIndexUsageSyncLease() > 0 { + if s.statsCollector == nil { + s.statsCollector = dom.StatsHandle().NewSessionStatsCollector() + } + if s.idxUsageCollector == nil && GetIndexUsageSyncLease() > 0 { s.idxUsageCollector = dom.StatsHandle().NewSessionIndexUsageCollector() } } @@ -3658,9 +3660,14 @@ func attachStatsCollector(s *session, dom *domain.Domain) *session { // detachStatsCollector removes the stats collector in the session func detachStatsCollector(s *session) *session { - s.statsCollector = nil - s.idxUsageCollector = nil - + if s.statsCollector != nil { + s.statsCollector.Delete() + s.statsCollector = nil + } + if s.idxUsageCollector != nil { + s.idxUsageCollector.Delete() + s.idxUsageCollector = nil + } return s } diff --git a/ttl/ttlworker/task_manager.go b/ttl/ttlworker/task_manager.go index 956e14abf3c13..20f70d7fc49ac 100644 --- a/ttl/ttlworker/task_manager.go +++ b/ttl/ttlworker/task_manager.go @@ -306,6 +306,7 @@ loop: err = idleWorker.Schedule(task.ttlScanTask) if err != nil { logger.Warn("fail to schedule task", zap.Error(err)) + task.cancel() continue } @@ -457,6 +458,8 @@ func (m *taskManager) checkFinishedTask(se session.Session, now time.Time) { stillRunningTasks = append(stillRunningTasks, task) continue } + // we should cancel task to release inner context and avoid memory leak + task.cancel() err := m.reportTaskFinished(se, now, task) if err != nil { logutil.Logger(m.ctx).Error("fail to report finished task", zap.Error(err)) @@ -579,6 +582,11 @@ type runningScanTask struct { result *ttlScanTaskExecResult } +// Context returns context for the task and is only used by test now +func (t *runningScanTask) Context() context.Context { + return t.ctx +} + func (t *runningScanTask) finished() bool { return t.result != nil && t.statistics.TotalRows.Load() == t.statistics.ErrorRows.Load()+t.statistics.SuccessRows.Load() } diff --git a/ttl/ttlworker/task_manager_integration_test.go b/ttl/ttlworker/task_manager_integration_test.go index 209f7c7febc46..e66305806e079 100644 --- a/ttl/ttlworker/task_manager_integration_test.go +++ b/ttl/ttlworker/task_manager_integration_test.go @@ -131,6 +131,7 @@ func TestParallelSchedule(t *testing.T) { require.NoError(t, isc.Update(sessionFactory())) now := time.Now() scheduleWg := sync.WaitGroup{} + finishTasks := make([]func(), 0, 4) for i := 0; i < 4; i++ { workers := []ttlworker.Worker{} for j := 0; j < 4; j++ { @@ -139,7 +140,8 @@ func TestParallelSchedule(t *testing.T) { workers = append(workers, scanWorker) } - m := ttlworker.NewTaskManager(context.Background(), nil, isc, fmt.Sprintf("task-manager-%d", i), store) + managerID := fmt.Sprintf("task-manager-%d", i) + m := ttlworker.NewTaskManager(context.Background(), nil, isc, managerID, store) m.SetScanWorkers4Test(workers) scheduleWg.Add(1) go func() { @@ -147,6 +149,15 @@ func TestParallelSchedule(t *testing.T) { m.RescheduleTasks(se, now) scheduleWg.Done() }() + finishTasks = append(finishTasks, func() { + se := sessionFactory() + for _, task := range m.GetRunningTasks() { + require.Nil(t, task.Context().Err(), fmt.Sprintf("%s %d", managerID, task.ScanID)) + task.SetResult(nil) + m.CheckFinishedTask(se, time.Now()) + require.NotNil(t, task.Context().Err(), fmt.Sprintf("%s %d", managerID, task.ScanID)) + } + }) } scheduleWg.Wait() // all tasks should have been scheduled @@ -154,6 +165,9 @@ func TestParallelSchedule(t *testing.T) { for i := 0; i < 4; i++ { sql := fmt.Sprintf("select count(1) from mysql.tidb_ttl_task where status = 'running' AND owner_id = 'task-manager-%d'", i) tk.MustQuery(sql).Check(testkit.Rows("4")) + finishTasks[i]() + sql = fmt.Sprintf("select count(1) from mysql.tidb_ttl_task where status = 'finished' AND owner_id = 'task-manager-%d'", i) + tk.MustQuery(sql).Check(testkit.Rows("4")) } }