From 3e65e9b5c77aded6ebd7c1db65aa482bf5915b94 Mon Sep 17 00:00:00 2001 From: Weizhen Wang Date: Wed, 4 Jan 2023 10:34:20 +0800 Subject: [PATCH] util: gorotinue pool (#39872) close pingcap/tidb#38039 --- resourcemanager/pooltask/BUILD.bazel | 8 + resourcemanager/pooltask/task.go | 132 +++++++ util/gpool/BUILD.bazel | 11 + util/gpool/gpool.go | 69 ++++ util/gpool/spinlock.go | 47 +++ util/gpool/spmc/BUILD.bazel | 43 +++ util/gpool/spmc/main_test.go | 27 ++ util/gpool/spmc/option.go | 138 +++++++ util/gpool/spmc/spmcpool.go | 420 +++++++++++++++++++++ util/gpool/spmc/spmcpool_benchmark_test.go | 111 ++++++ util/gpool/spmc/spmcpool_test.go | 283 ++++++++++++++ util/gpool/spmc/worker.go | 74 ++++ util/gpool/spmc/worker_loop_queue.go | 192 ++++++++++ util/gpool/spmc/worker_loop_queue_test.go | 184 +++++++++ 14 files changed, 1739 insertions(+) create mode 100644 resourcemanager/pooltask/BUILD.bazel create mode 100644 resourcemanager/pooltask/task.go create mode 100644 util/gpool/BUILD.bazel create mode 100644 util/gpool/gpool.go create mode 100644 util/gpool/spinlock.go create mode 100644 util/gpool/spmc/BUILD.bazel create mode 100644 util/gpool/spmc/main_test.go create mode 100644 util/gpool/spmc/option.go create mode 100644 util/gpool/spmc/spmcpool.go create mode 100644 util/gpool/spmc/spmcpool_benchmark_test.go create mode 100644 util/gpool/spmc/spmcpool_test.go create mode 100644 util/gpool/spmc/worker.go create mode 100644 util/gpool/spmc/worker_loop_queue.go create mode 100644 util/gpool/spmc/worker_loop_queue_test.go diff --git a/resourcemanager/pooltask/BUILD.bazel b/resourcemanager/pooltask/BUILD.bazel new file mode 100644 index 0000000000000..c9e37436562ee --- /dev/null +++ b/resourcemanager/pooltask/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "pooltask", + srcs = ["task.go"], + importpath = "github.com/pingcap/tidb/resourcemanager/pooltask", + visibility = ["//visibility:public"], +) diff --git a/resourcemanager/pooltask/task.go b/resourcemanager/pooltask/task.go new file mode 100644 index 0000000000000..ef9b046c8ccba --- /dev/null +++ b/resourcemanager/pooltask/task.go @@ -0,0 +1,132 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pooltask + +import ( + "sync" +) + +// Context is a interface that can be used to create a context. +type Context[T any] interface { + GetContext() T +} + +// NilContext is to create a nil as context +type NilContext struct{} + +// GetContext is to get a nil as context +func (NilContext) GetContext() any { + return nil +} + +// TaskBox is a box which contains all info about pooltask. +type TaskBox[T any, U any, C any, CT any, TF Context[CT]] struct { + constArgs C + contextFunc TF + wg *sync.WaitGroup + task chan Task[T] + resultCh chan U + taskID uint64 +} + +// NewTaskBox is to create a task box for pool. 
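+// The task channel carries the work items for this task, the result channel
+// receives one result per work item, and wg is the WaitGroup that tracks
+// in-flight work items so that TaskController.Wait can block until all of them
+// are done.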
+func NewTaskBox[T any, U any, C any, CT any, TF Context[CT]](constArgs C, contextFunc TF, wg *sync.WaitGroup, taskCh chan Task[T], resultCh chan U, taskID uint64) TaskBox[T, U, C, CT, TF] { + return TaskBox[T, U, C, CT, TF]{ + constArgs: constArgs, + contextFunc: contextFunc, + wg: wg, + task: taskCh, + resultCh: resultCh, + taskID: taskID, + } +} + +// TaskID is to get the task id. +func (t TaskBox[T, U, C, CT, TF]) TaskID() uint64 { + return t.taskID +} + +// ConstArgs is to get the const args. +func (t *TaskBox[T, U, C, CT, TF]) ConstArgs() C { + return t.constArgs +} + +// GetTaskCh is to get the task channel. +func (t *TaskBox[T, U, C, CT, TF]) GetTaskCh() chan Task[T] { + return t.task +} + +// GetResultCh is to get result channel +func (t *TaskBox[T, U, C, CT, TF]) GetResultCh() chan U { + return t.resultCh +} + +// GetContextFunc is to get context func. +func (t *TaskBox[T, U, C, CT, TF]) GetContextFunc() TF { + return t.contextFunc +} + +// Done is to set the pooltask status to complete. +func (t *TaskBox[T, U, C, CT, TF]) Done() { + t.wg.Done() +} + +// Clone is to copy the box +func (t *TaskBox[T, U, C, CT, TF]) Clone() *TaskBox[T, U, C, CT, TF] { + newBox := NewTaskBox[T, U, C, CT, TF](t.constArgs, t.contextFunc, t.wg, t.task, t.resultCh, t.taskID) + return &newBox +} + +// GPool is a goroutine pool. +type GPool[T any, U any, C any, CT any, TF Context[CT]] interface { + Tune(size int) +} + +// TaskController is a controller that can control or watch the pool. +type TaskController[T any, U any, C any, CT any, TF Context[CT]] struct { + pool GPool[T, U, C, CT, TF] + close chan struct{} + wg *sync.WaitGroup + taskID uint64 + resultCh chan U +} + +// NewTaskController create a controller to deal with pooltask's status. +func NewTaskController[T any, U any, C any, CT any, TF Context[CT]](p GPool[T, U, C, CT, TF], taskID uint64, closeCh chan struct{}, wg *sync.WaitGroup, resultCh chan U) TaskController[T, U, C, CT, TF] { + return TaskController[T, U, C, CT, TF]{ + pool: p, + taskID: taskID, + close: closeCh, + wg: wg, + resultCh: resultCh, + } +} + +// Wait is to wait the pool task to stop. +func (t *TaskController[T, U, C, CT, TF]) Wait() { + <-t.close + t.wg.Wait() + close(t.resultCh) +} + +// TaskID is to get the task id. +func (t *TaskController[T, U, C, CT, TF]) TaskID() uint64 { + return t.taskID +} + +// Task is a task that can be executed. +type Task[T any] struct { + Task T +} diff --git a/util/gpool/BUILD.bazel b/util/gpool/BUILD.bazel new file mode 100644 index 0000000000000..04a3dc25e7cd0 --- /dev/null +++ b/util/gpool/BUILD.bazel @@ -0,0 +1,11 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "gpool", + srcs = [ + "gpool.go", + "spinlock.go", + ], + importpath = "github.com/pingcap/tidb/util/gpool", + visibility = ["//visibility:public"], +) diff --git a/util/gpool/gpool.go b/util/gpool/gpool.go new file mode 100644 index 0000000000000..7611d29542a31 --- /dev/null +++ b/util/gpool/gpool.go @@ -0,0 +1,69 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package gpool + +import ( + "errors" + "sync/atomic" + "time" +) + +const ( + // DefaultCleanIntervalTime is the interval time to clean up goroutines. + DefaultCleanIntervalTime = 5 * time.Second + + // OPENED represents that the pool is opened. + OPENED = iota + + // CLOSED represents that the pool is closed. + CLOSED +) + +var ( + // ErrPoolClosed will be returned when submitting task to a closed pool. + ErrPoolClosed = errors.New("this pool has been closed") + + // ErrPoolOverload will be returned when the pool is full and no workers available. + ErrPoolOverload = errors.New("too many goroutines blocked on submit or Nonblocking is set") + + // ErrProducerClosed will be returned when the producer is closed. + ErrProducerClosed = errors.New("this producer has been closed") +) + +// BasePool is base class of pool +type BasePool struct { + name string + generator atomic.Uint64 +} + +// NewBasePool is to create a new BasePool. +func NewBasePool() BasePool { + return BasePool{} +} + +// SetName is to set name. +func (p *BasePool) SetName(name string) { + p.name = name +} + +// Name is to get name. +func (p *BasePool) Name() string { + return p.name +} + +// NewTaskID is to get a new task ID. +func (p *BasePool) NewTaskID() uint64 { + return p.generator.Add(1) +} diff --git a/util/gpool/spinlock.go b/util/gpool/spinlock.go new file mode 100644 index 0000000000000..acf7d15192416 --- /dev/null +++ b/util/gpool/spinlock.go @@ -0,0 +1,47 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gpool + +import ( + "runtime" + "sync" + "sync/atomic" +) + +type spinLock uint32 + +const maxBackoff = 16 + +func (sl *spinLock) Lock() { + backoff := 1 + for !atomic.CompareAndSwapUint32((*uint32)(sl), 0, 1) { + // Leverage the exponential backoff algorithm, see https://en.wikipedia.org/wiki/Exponential_backoff. + for i := 0; i < backoff; i++ { + runtime.Gosched() + } + if backoff < maxBackoff { + backoff <<= 1 + } + } +} + +func (sl *spinLock) Unlock() { + atomic.StoreUint32((*uint32)(sl), 0) +} + +// NewSpinLock instantiates a spin-lock. 
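+// The zero value is unlocked; Lock spins with exponential backoff (yielding via
+// runtime.Gosched) instead of parking the goroutine, which suits the very short
+// critical sections in this package. A minimal usage sketch:
+//
+//	lock := NewSpinLock()
+//	lock.Lock()
+//	// ...short critical section...
+//	lock.Unlock()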
+func NewSpinLock() sync.Locker { + return new(spinLock) +} diff --git a/util/gpool/spmc/BUILD.bazel b/util/gpool/spmc/BUILD.bazel new file mode 100644 index 0000000000000..db48d9771cb17 --- /dev/null +++ b/util/gpool/spmc/BUILD.bazel @@ -0,0 +1,43 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "spmc", + srcs = [ + "option.go", + "spmcpool.go", + "worker.go", + "worker_loop_queue.go", + ], + importpath = "github.com/pingcap/tidb/util/gpool/spmc", + visibility = ["//visibility:public"], + deps = [ + "//resourcemanager/pooltask", + "//util/gpool", + "//util/logutil", + "@com_github_pingcap_errors//:errors", + "@com_github_pingcap_log//:log", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_zap//:zap", + ], +) + +go_test( + name = "spmc_test", + srcs = [ + "main_test.go", + "spmcpool_benchmark_test.go", + "spmcpool_test.go", + "worker_loop_queue_test.go", + ], + embed = [":spmc"], + race = "on", + deps = [ + "//resourcemanager/pooltask", + "//testkit/testsetup", + "//util", + "//util/gpool", + "@com_github_stretchr_testify//require", + "@org_uber_go_atomic//:atomic", + "@org_uber_go_goleak//:goleak", + ], +) diff --git a/util/gpool/spmc/main_test.go b/util/gpool/spmc/main_test.go new file mode 100644 index 0000000000000..381e5302598d5 --- /dev/null +++ b/util/gpool/spmc/main_test.go @@ -0,0 +1,27 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "testing" + + "github.com/pingcap/tidb/testkit/testsetup" + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + testsetup.SetupForCommonTest() + goleak.VerifyTestMain(m) +} diff --git a/util/gpool/spmc/option.go b/util/gpool/spmc/option.go new file mode 100644 index 0000000000000..e317ce157b93d --- /dev/null +++ b/util/gpool/spmc/option.go @@ -0,0 +1,138 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "time" +) + +// Option represents the optional function. +type Option func(opts *Options) + +func loadOptions(options ...Option) *Options { + opts := DefaultOption() + for _, option := range options { + option(opts) + } + return opts +} + +// Options contains all options which will be applied when instantiating an pool. +type Options struct { + // PanicHandler is used to handle panics from each worker goroutine. + // if nil, panics will be thrown out again from worker goroutines. 
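+	// For example, a handler can be installed through the WithPanicHandler option
+	// (illustrative sketch only; the pool parameters and handler body are example values):
+	//
+	//	pool, _ := NewSPMCPool[int, int, any, any, pooltask.NilContext]("example", 4,
+	//		WithPanicHandler(func(r interface{}) {
+	//			// log the recovered value instead of re-panicking
+	//		}))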
+ PanicHandler func(interface{})
+
+ // ExpiryDuration is a period for the scavenger goroutine to clean up those expired workers,
+ // the scavenger scans all workers every `ExpiryDuration` and cleans up those workers that haven't been
+ // used for more than `ExpiryDuration`.
+ ExpiryDuration time.Duration
+
+ // LimitDuration is a period in the limit mode.
+ LimitDuration time.Duration
+
+ // Max number of goroutines blocked waiting for an available worker.
+ // 0 (default value) means no such limit.
+ MaxBlockingTasks int
+
+ // When Nonblocking is true, Pool.AddProducer will never be blocked.
+ // ErrPoolOverload will be returned when a task cannot be dispatched at once.
+ // When Nonblocking is true, MaxBlockingTasks is inoperative.
+ Nonblocking bool
+}
+
+// DefaultOption is the default option.
+func DefaultOption() *Options {
+ return &Options{
+ LimitDuration: 200 * time.Millisecond,
+ Nonblocking: true,
+ }
+}
+
+// WithExpiryDuration sets up the interval time of cleaning up goroutines.
+func WithExpiryDuration(expiryDuration time.Duration) Option {
+ return func(opts *Options) {
+ opts.ExpiryDuration = expiryDuration
+ }
+}
+
+// WithMaxBlockingTasks sets up the maximum number of goroutines that are blocked when it reaches the capacity of the pool.
+func WithMaxBlockingTasks(maxBlockingTasks int) Option {
+ return func(opts *Options) {
+ opts.MaxBlockingTasks = maxBlockingTasks
+ }
+}
+
+// WithNonblocking indicates that the pool will not block when there are no available workers.
+func WithNonblocking(nonblocking bool) Option {
+ return func(opts *Options) {
+ opts.Nonblocking = nonblocking
+ }
+}
+
+// WithPanicHandler sets up the panic handler.
+func WithPanicHandler(panicHandler func(interface{})) Option {
+ return func(opts *Options) {
+ opts.PanicHandler = panicHandler
+ }
+}
+
+// TaskOption represents the optional function.
+type TaskOption func(opts *TaskOptions)
+
+func loadTaskOptions(options ...TaskOption) *TaskOptions {
+ opts := new(TaskOptions)
+ for _, option := range options {
+ option(opts)
+ }
+ if opts.Concurrency == 0 {
+ opts.Concurrency = 1
+ }
+ if opts.ResultChanLen == 0 {
+ opts.ResultChanLen = uint64(opts.Concurrency)
+ }
+ if opts.TaskChanLen == 0 {
+ opts.TaskChanLen = uint64(opts.Concurrency)
+ }
+ return opts
+}
+
+// TaskOptions contains all options for a task.
+type TaskOptions struct {
+ Concurrency int
+ ResultChanLen uint64
+ TaskChanLen uint64
+}
+
+// WithResultChanLen is to set the length of the result channel.
+func WithResultChanLen(resultChanLen uint64) TaskOption {
+ return func(opts *TaskOptions) {
+ opts.ResultChanLen = resultChanLen
+ }
+}
+
+// WithTaskChanLen is to set the length of the task channel.
+func WithTaskChanLen(taskChanLen uint64) TaskOption {
+ return func(opts *TaskOptions) {
+ opts.TaskChanLen = taskChanLen
+ }
+}
+
+// WithConcurrency is to set the concurrency of the task.
+func WithConcurrency(c int) TaskOption {
+ return func(opts *TaskOptions) {
+ opts.Concurrency = c
+ }
+}
diff --git a/util/gpool/spmc/spmcpool.go b/util/gpool/spmc/spmcpool.go
new file mode 100644
index 0000000000000..b69c7a05e0eca
--- /dev/null
+++ b/util/gpool/spmc/spmcpool.go
@@ -0,0 +1,420 @@
+// Copyright 2022 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/pingcap/log" + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/pingcap/tidb/util/gpool" + "github.com/pingcap/tidb/util/logutil" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// Pool is a single producer, multiple consumer goroutine pool. +// T is the type of the task. We can treat it as input. +// U is the type of the result. We can treat it as output. +// C is the type of the const parameter. if Our task look like y = ax + b, C acts like b as const parameter. +// CT is the type of the context. It needs to be read/written parallel. +// TF is the type of the context getter. It is used to get a context. +// if we don't need to use CT/TF, we can define CT as any and TF as NilContext. +type Pool[T any, U any, C any, CT any, TF pooltask.Context[CT]] struct { + gpool.BasePool + workerCache sync.Pool + workers *loopQueue[T, U, C, CT, TF] + lock sync.Locker + cond *sync.Cond + taskCh chan *pooltask.TaskBox[T, U, C, CT, TF] + options *Options + stopCh chan struct{} + consumerFunc func(T, C, CT) U + capacity atomic.Int32 + running atomic.Int32 + state atomic.Int32 + waiting atomic.Int32 // waiting is the number of goroutines that are waiting for the pool to be available. + heartbeatDone atomic.Bool + + waitingTask atomicutil.Uint32 // waitingTask is the number of tasks that are waiting for the pool to be available. +} + +// NewSPMCPool create a single producer, multiple consumer goroutine pool. +func NewSPMCPool[T any, U any, C any, CT any, TF pooltask.Context[CT]](name string, size int32, options ...Option) (*Pool[T, U, C, CT, TF], error) { + opts := loadOptions(options...) + if expiry := opts.ExpiryDuration; expiry <= 0 { + opts.ExpiryDuration = gpool.DefaultCleanIntervalTime + } + result := &Pool[T, U, C, CT, TF]{ + BasePool: gpool.NewBasePool(), + taskCh: make(chan *pooltask.TaskBox[T, U, C, CT, TF], 128), + stopCh: make(chan struct{}), + lock: gpool.NewSpinLock(), + options: opts, + } + result.SetName(name) + result.state.Store(int32(gpool.OPENED)) + result.workerCache.New = func() interface{} { + return &goWorker[T, U, C, CT, TF]{ + pool: result, + } + } + result.capacity.Add(size) + result.workers = newWorkerLoopQueue[T, U, C, CT, TF](int(size)) + result.cond = sync.NewCond(result.lock) + // Start a goroutine to clean up expired workers periodically. + go result.purgePeriodically() + return result, nil +} + +// purgePeriodically clears expired workers periodically which runs in an individual goroutine, as a scavenger. +func (p *Pool[T, U, C, CT, TF]) purgePeriodically() { + heartbeat := time.NewTicker(p.options.ExpiryDuration) + defer func() { + heartbeat.Stop() + p.heartbeatDone.Store(true) + }() + for { + select { + case <-heartbeat.C: + case <-p.stopCh: + return + } + + if p.IsClosed() { + break + } + + p.lock.Lock() + expiredWorkers := p.workers.retrieveExpiry(p.options.ExpiryDuration) + p.lock.Unlock() + + // Notify obsolete workers to stop. 
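+ // A nil task box is the stop signal: the worker's run loop exits as soon as it
+ // receives nil from its task channel.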
+ // This notification must be outside the p.lock, since w.task + // may be blocking and may consume a lot of time if many workers + // are located on non-local CPUs. + for i := range expiredWorkers { + expiredWorkers[i].taskBoxCh <- nil + expiredWorkers[i] = nil + } + + // There might be a situation where all workers have been cleaned up(no worker is running), + // or another case where the pool capacity has been Tuned up, + // while some invokers still get stuck in "p.cond.Wait()", + // then it ought to wake all those invokers. + if p.Running() == 0 || (p.Waiting() > 0 && p.Free() > 0) || p.waitingTask.Load() > 0 { + p.cond.Broadcast() + } + } +} + +// Tune changes the capacity of this pool, note that it is noneffective to the infinite or pre-allocation pool. +func (p *Pool[T, U, C, CT, TF]) Tune(size int) { + capacity := p.Cap() + if capacity == -1 || size <= 0 || size == capacity { + return + } + p.capacity.Store(int32(size)) + if size > capacity { + // boost + if size-capacity == 1 { + p.cond.Signal() + return + } + p.cond.Broadcast() + } +} + +// Running returns the number of workers currently running. +func (p *Pool[T, U, C, CT, TF]) Running() int { + return int(p.running.Load()) +} + +// Free returns the number of available goroutines to work, -1 indicates this pool is unlimited. +func (p *Pool[T, U, C, CT, TF]) Free() int { + c := p.Cap() + if c < 0 { + return -1 + } + return c - p.Running() +} + +// Waiting returns the number of tasks which are waiting be executed. +func (p *Pool[T, U, C, CT, TF]) Waiting() int { + return int(p.waiting.Load()) +} + +// IsClosed indicates whether the pool is closed. +func (p *Pool[T, U, C, CT, TF]) IsClosed() bool { + return p.state.Load() == gpool.CLOSED +} + +// Cap returns the capacity of this pool. +func (p *Pool[T, U, C, CT, TF]) Cap() int { + return int(p.capacity.Load()) +} + +func (p *Pool[T, U, C, CT, TF]) addRunning(delta int) { + p.running.Add(int32(delta)) +} + +func (p *Pool[T, U, C, CT, TF]) addWaiting(delta int) { + p.waiting.Add(int32(delta)) +} + +func (p *Pool[T, U, C, CT, TF]) addWaitingTask() { + p.waitingTask.Inc() +} + +func (p *Pool[T, U, C, CT, TF]) subWaitingTask() { + p.waitingTask.Dec() +} + +// release closes this pool and releases the worker queue. +func (p *Pool[T, U, C, CT, TF]) release() { + if !p.state.CompareAndSwap(gpool.OPENED, gpool.CLOSED) { + return + } + p.lock.Lock() + p.workers.reset() + p.lock.Unlock() + // There might be some callers waiting in retrieveWorker(), so we need to wake them up to prevent + // those callers blocking infinitely. + p.cond.Broadcast() + close(p.taskCh) +} + +func isClose(exitCh chan struct{}) bool { + select { + case <-exitCh: + return true + default: + } + return false +} + +// ReleaseAndWait is like Release, it waits all workers to exit. +func (p *Pool[T, U, C, CT, TF]) ReleaseAndWait() { + if p.IsClosed() || isClose(p.stopCh) { + return + } + + close(p.stopCh) + p.release() + for { + // Wait for all workers to exit and all task to be completed. + if p.Running() == 0 && p.heartbeatDone.Load() && p.waitingTask.Load() == 0 { + return + } + } +} + +// SetConsumerFunc is to set ConsumerFunc which is to process the task. +func (p *Pool[T, U, C, CT, TF]) SetConsumerFunc(consumerFunc func(T, C, CT) U) { + p.consumerFunc = consumerFunc +} + +// AddProduceBySlice is to add Produce by a slice. +// Producer need to return ErrProducerClosed when to exit. 
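+// An illustrative sketch, assuming a caller-defined batchCh chan []int that is
+// closed once production is finished (names and types are examples only):
+//
+//	producer := func() ([]int, error) {
+//		batch, ok := <-batchCh
+//		if !ok {
+//			return nil, gpool.ErrProducerClosed
+//		}
+//		return batch, nil
+//	}
+//	resultCh, ctl := pool.AddProduceBySlice(producer, constArgs, pooltask.NilContext{}, WithConcurrency(4))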
+func (p *Pool[T, U, C, CT, TF]) AddProduceBySlice(producer func() ([]T, error), constArg C, contextFn TF, options ...TaskOption) (<-chan U, pooltask.TaskController[T, U, C, CT, TF]) { + opt := loadTaskOptions(options...) + taskID := p.NewTaskID() + var wg sync.WaitGroup + result := make(chan U, opt.ResultChanLen) + closeCh := make(chan struct{}) + inputCh := make(chan pooltask.Task[T], opt.TaskChanLen) + tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result) + for i := 0; i < opt.Concurrency; i++ { + err := p.run() + if err == gpool.ErrPoolClosed { + break + } + taskBox := pooltask.NewTaskBox[T, U, C, CT, TF](constArg, contextFn, &wg, inputCh, result, taskID) + p.addWaitingTask() + p.taskCh <- &taskBox + } + go func() { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack")) + } + close(closeCh) + close(inputCh) + }() + for { + tasks, err := producer() + if err != nil { + if errors.Is(err, gpool.ErrProducerClosed) { + return + } + log.Error("producer error", zap.Error(err)) + return + } + for _, task := range tasks { + wg.Add(1) + task := pooltask.Task[T]{ + Task: task, + } + inputCh <- task + } + } + }() + return result, tc +} + +// AddProducer is to add producer. +// Producer need to return ErrProducerClosed when to exit. +func (p *Pool[T, U, C, CT, TF]) AddProducer(producer func() (T, error), constArg C, contextFn TF, options ...TaskOption) (<-chan U, pooltask.TaskController[T, U, C, CT, TF]) { + opt := loadTaskOptions(options...) + taskID := p.NewTaskID() + var wg sync.WaitGroup + result := make(chan U, opt.ResultChanLen) + closeCh := make(chan struct{}) + inputCh := make(chan pooltask.Task[T], opt.TaskChanLen) + tc := pooltask.NewTaskController[T, U, C, CT, TF](p, taskID, closeCh, &wg, result) + for i := 0; i < opt.Concurrency; i++ { + err := p.run() + if err == gpool.ErrPoolClosed { + break + } + p.addWaitingTask() + taskBox := pooltask.NewTaskBox[T, U, C, CT, TF](constArg, contextFn, &wg, inputCh, result, taskID) + p.taskCh <- &taskBox + } + go func() { + defer func() { + if r := recover(); r != nil { + logutil.BgLogger().Error("producer panic", zap.Any("recover", r), zap.Stack("stack")) + } + close(closeCh) + close(inputCh) + }() + for { + task, err := producer() + if err != nil { + if errors.Is(err, gpool.ErrProducerClosed) { + return + } + log.Error("producer error", zap.Error(err)) + return + } + wg.Add(1) + t := pooltask.Task[T]{ + Task: task, + } + inputCh <- t + } + }() + return result, tc +} + +func (p *Pool[T, U, C, CT, TF]) run() error { + if p.IsClosed() { + return gpool.ErrPoolClosed + } + var w *goWorker[T, U, C, CT, TF] + if w = p.retrieveWorker(); w == nil { + return gpool.ErrPoolOverload + } + return nil +} + +// retrieveWorker returns an available worker to run the tasks. +func (p *Pool[T, U, C, CT, TF]) retrieveWorker() (w *goWorker[T, U, C, CT, TF]) { + spawnWorker := func() { + w = p.workerCache.Get().(*goWorker[T, U, C, CT, TF]) + w.taskBoxCh = p.taskCh + w.run() + } + + p.lock.Lock() + + w = p.workers.detach() + if w != nil { // first try to fetch the worker from the queue + p.lock.Unlock() + } else if capacity := p.Cap(); capacity == -1 || capacity > p.Running() { + // if the worker queue is empty and we don't run out of the pool capacity, + // then just spawn a new worker goroutine. + p.lock.Unlock() + spawnWorker() + } else { // otherwise, we'll have to keep them blocked and wait for at least one worker to be put back into pool. 
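+ // In Nonblocking mode we return a nil worker immediately (run() reports this as
+ // gpool.ErrPoolOverload); otherwise we block on p.cond until a worker is put back
+ // or the pool is closed.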
+ if p.options.Nonblocking { + p.lock.Unlock() + return + } + retry: + if p.options.MaxBlockingTasks != 0 && p.Waiting() >= p.options.MaxBlockingTasks { + p.lock.Unlock() + return + } + p.addWaiting(1) + p.cond.Wait() // block and wait for an available worker + p.addWaiting(-1) + + if p.IsClosed() { + p.lock.Unlock() + return + } + + var nw int + if nw = p.Running(); nw == 0 { // awakened by the scavenger + p.lock.Unlock() + spawnWorker() + return + } + if w = p.workers.detach(); w == nil { + if nw < p.Cap() { + p.lock.Unlock() + spawnWorker() + return + } + goto retry + } + p.lock.Unlock() + } + return +} + +// revertWorker puts a worker back into free pool, recycling the goroutines. +func (p *Pool[T, U, C, CT, TF]) revertWorker(worker *goWorker[T, U, C, CT, TF]) bool { + if capacity := p.Cap(); capacity > 0 && p.Running() > capacity || p.IsClosed() { + p.cond.Broadcast() + return false + } + worker.recycleTime.Store(time.Now()) + p.lock.Lock() + + if p.IsClosed() { + p.lock.Unlock() + return false + } + + err := p.workers.insert(worker) + if err != nil { + p.lock.Unlock() + if err == errQueueIsFull && p.waitingTask.Load() > 0 { + return true + } + return false + } + + // Notify the invoker stuck in 'retrieveWorker()' of there is an available worker in the worker queue. + p.cond.Signal() + p.lock.Unlock() + return true +} diff --git a/util/gpool/spmc/spmcpool_benchmark_test.go b/util/gpool/spmc/spmcpool_benchmark_test.go new file mode 100644 index 0000000000000..db3a4f0824e78 --- /dev/null +++ b/util/gpool/spmc/spmcpool_benchmark_test.go @@ -0,0 +1,111 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package spmc + +import ( + "testing" + "time" + + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/gpool" +) + +const ( + RunTimes = 10000 + DefaultExpiredTime = 10 * time.Second +) + +func BenchmarkGPool(b *testing.B) { + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("test", 10) + if err != nil { + b.Fatal(err) + } + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + sema := make(chan struct{}, 10) + var wg util.WaitGroupWrapper + wg.Run(func() { + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(sema) + }) + producerFunc := func() (struct{}, error) { + _, ok := <-sema + if ok { + return struct{}{}, nil + } + return struct{}{}, gpool.ErrProducerClosed + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(6), WithResultChanLen(10)) + exitCh := make(chan struct{}) + wg.Run(func() { + for { + select { + case <-resultCh: + case <-exitCh: + return + } + } + }) + ctl.Wait() + close(exitCh) + wg.Wait() + } +} + +func BenchmarkGoCommon(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + var wg util.WaitGroupWrapper + var wgp util.WaitGroupWrapper + sema := make(chan struct{}, 10) + result := make(chan struct{}, 10) + wg.Run(func() { + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(sema) + }) + + for n := 0; n < 6; n++ { + wg.Run(func() { + item, ok := <-sema + if !ok { + return + } + result <- item + }) + } + exitCh := make(chan struct{}) + wgp.Run(func() { + for { + select { + case <-result: + case <-exitCh: + return + } + } + }) + wg.Wait() + close(exitCh) + wgp.Wait() + } +} diff --git a/util/gpool/spmc/spmcpool_test.go b/util/gpool/spmc/spmcpool_test.go new file mode 100644 index 0000000000000..984f501789c47 --- /dev/null +++ b/util/gpool/spmc/spmcpool_test.go @@ -0,0 +1,283 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package spmc + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/pingcap/tidb/util" + "github.com/pingcap/tidb/util/gpool" + "github.com/stretchr/testify/require" +) + +func TestPool(t *testing.T) { + type ConstArgs struct { + a int + } + myArgs := ConstArgs{a: 10} + // init the pool + // input type, output type, constArgs type + pool, err := NewSPMCPool[int, int, ConstArgs, any, pooltask.NilContext]("TestPool", 10) + require.NoError(t, err) + pool.SetConsumerFunc(func(task int, constArgs ConstArgs, ctx any) int { + return task + constArgs.a + }) + taskCh := make(chan int, 10) + for i := 1; i < 11; i++ { + taskCh <- i + } + pfunc := func() (int, error) { + select { + case task := <-taskCh: + return task, nil + default: + return 0, gpool.ErrProducerClosed + } + } + // add new task + resultCh, control := pool.AddProducer(pfunc, myArgs, pooltask.NilContext{}, WithConcurrency(4)) + + var count atomic.Uint32 + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for result := range resultCh { + count.Add(1) + require.Greater(t, result, 10) + } + }() + // Waiting task finishing + control.Wait() + wg.Wait() + require.Equal(t, uint32(10), count.Load()) + // close pool + pool.ReleaseAndWait() +} + +func TestPoolWithEnoughCapacity(t *testing.T) { + const ( + RunTimes = 1000 + poolsize = 30 + concurrency = 6 + ) + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestPoolWithEnoughCapa", poolsize, WithExpiryDuration(DefaultExpiredTime)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + var twg util.WaitGroupWrapper + for i := 0; i < 3; i++ { + twg.Run(func() { + sema := make(chan struct{}, 10) + var wg util.WaitGroupWrapper + exitCh := make(chan struct{}) + wg.Run(func() { + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }) + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(concurrency)) + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + }) + } + twg.Wait() +} + +func TestPoolWithoutEnoughCapacity(t *testing.T) { + const ( + RunTimes = 5 + concurrency = 2 + poolsize = 2 + ) + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestPoolWithoutEnoughCapa", poolsize, + WithExpiryDuration(DefaultExpiredTime)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + var twg sync.WaitGroup + for i := 0; i < 10; i++ { + func() { + sema := make(chan struct{}, 10) + var wg util.WaitGroupWrapper + exitCh := make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }() + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(concurrency)) + + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + 
}() + } + twg.Wait() +} + +func TestPoolWithoutEnoughCapacityParallel(t *testing.T) { + const ( + RunTimes = 5 + concurrency = 2 + poolsize = 2 + ) + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestPoolWithoutEnoughCapa", poolsize, + WithExpiryDuration(DefaultExpiredTime), WithNonblocking(true)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + var twg sync.WaitGroup + for i := 0; i < 10; i++ { + twg.Add(1) + go func() { + defer twg.Done() + sema := make(chan struct{}, 10) + var wg sync.WaitGroup + exitCh := make(chan struct{}) + wg.Add(1) + go func() { + wg.Done() + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }() + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(concurrency)) + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + }() + } + twg.Wait() +} + +func TestBenchPool(t *testing.T) { + p, err := NewSPMCPool[struct{}, struct{}, int, any, pooltask.NilContext]("TestBenchPool", 10, WithExpiryDuration(DefaultExpiredTime)) + require.NoError(t, err) + defer p.ReleaseAndWait() + p.SetConsumerFunc(func(a struct{}, b int, c any) struct{} { + return struct{}{} + }) + + for i := 0; i < 1000; i++ { + sema := make(chan struct{}, 10) + var wg sync.WaitGroup + exitCh := make(chan struct{}) + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < RunTimes; j++ { + sema <- struct{}{} + } + close(exitCh) + }() + producerFunc := func() (struct{}, error) { + for { + select { + case <-sema: + return struct{}{}, nil + default: + select { + case <-exitCh: + return struct{}{}, gpool.ErrProducerClosed + default: + } + } + } + } + resultCh, ctl := p.AddProducer(producerFunc, RunTimes, pooltask.NilContext{}, WithConcurrency(6)) + wg.Add(1) + go func() { + defer wg.Done() + for range resultCh { + } + }() + ctl.Wait() + wg.Wait() + } + p.ReleaseAndWait() +} diff --git a/util/gpool/spmc/worker.go b/util/gpool/spmc/worker.go new file mode 100644 index 0000000000000..32ff56a790dbd --- /dev/null +++ b/util/gpool/spmc/worker.go @@ -0,0 +1,74 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "github.com/pingcap/log" + "github.com/pingcap/tidb/resourcemanager/pooltask" + atomicutil "go.uber.org/atomic" + "go.uber.org/zap" +) + +// goWorker is the actual executor who runs the tasks, +// it starts a goroutine that accepts tasks and +// performs function calls. +type goWorker[T any, U any, C any, CT any, TF pooltask.Context[CT]] struct { + // pool who owns this worker. + pool *Pool[T, U, C, CT, TF] + + // taskBoxCh is a job should be done. 
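+ // A nil task box received on this channel tells the worker to exit; it is sent
+ // by the scavenger (purgePeriodically) and by the worker queue's reset.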
+ taskBoxCh chan *pooltask.TaskBox[T, U, C, CT, TF] + + // recycleTime will be updated when putting a worker back into queue. + recycleTime atomicutil.Time +} + +// run starts a goroutine to repeat the process +// that performs the function calls. +func (w *goWorker[T, U, C, CT, TF]) run() { + w.pool.addRunning(1) + go func() { + defer func() { + w.pool.addRunning(-1) + w.pool.workerCache.Put(w) + if p := recover(); p != nil { + if ph := w.pool.options.PanicHandler; ph != nil { + ph(p) + } else { + log.Error("worker exits from a panic", zap.Any("recover", p), zap.Stack("stack")) + } + } + // Call Signal() here in case there are goroutines waiting for available workers. + w.pool.cond.Signal() + }() + + for f := range w.taskBoxCh { + if f == nil { + return + } + w.pool.subWaitingTask() + ctx := f.GetContextFunc().GetContext() + if f.GetResultCh() != nil { + for t := range f.GetTaskCh() { + f.GetResultCh() <- w.pool.consumerFunc(t.Task, f.ConstArgs(), ctx) + f.Done() + } + } + if ok := w.pool.revertWorker(w); !ok { + return + } + } + }() +} diff --git a/util/gpool/spmc/worker_loop_queue.go b/util/gpool/spmc/worker_loop_queue.go new file mode 100644 index 0000000000000..59c7b97fd425a --- /dev/null +++ b/util/gpool/spmc/worker_loop_queue.go @@ -0,0 +1,192 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package spmc + +import ( + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/resourcemanager/pooltask" +) + +var ( + // errQueueIsFull will be returned when the worker queue is full. + errQueueIsFull = errors.New("the queue is full") + + // errQueueIsReleased will be returned when trying to insert item to a released worker queue. 
+ errQueueIsReleased = errors.New("the queue is released could not accept item anymore") +) + +type loopQueue[T any, U any, C any, CT any, TF pooltask.Context[CT]] struct { + items []*goWorker[T, U, C, CT, TF] + expiry []*goWorker[T, U, C, CT, TF] + head int + tail int + size int + isFull bool +} + +func newWorkerLoopQueue[T any, U any, C any, CT any, TF pooltask.Context[CT]](size int) *loopQueue[T, U, C, CT, TF] { + return &loopQueue[T, U, C, CT, TF]{ + items: make([]*goWorker[T, U, C, CT, TF], size), + size: size, + } +} + +func (wq *loopQueue[T, U, C, CT, TF]) len() int { + if wq.size == 0 { + return 0 + } + + if wq.head == wq.tail { + if wq.isFull { + return wq.size + } + return 0 + } + + if wq.tail > wq.head { + return wq.tail - wq.head + } + + return wq.size - wq.head + wq.tail +} + +func (wq *loopQueue[T, U, C, CT, TF]) isEmpty() bool { + return wq.head == wq.tail && !wq.isFull +} + +func (wq *loopQueue[T, U, C, CT, TF]) insert(worker *goWorker[T, U, C, CT, TF]) error { + if wq.size == 0 { + return errQueueIsReleased + } + + if wq.isFull { + return errQueueIsFull + } + wq.items[wq.tail] = worker + wq.tail++ + + if wq.tail == wq.size { + wq.tail = 0 + } + if wq.tail == wq.head { + wq.isFull = true + } + + return nil +} + +func (wq *loopQueue[T, U, C, CT, TF]) detach() *goWorker[T, U, C, CT, TF] { + if wq.isEmpty() { + return nil + } + + w := wq.items[wq.head] + wq.items[wq.head] = nil + wq.head++ + if wq.head == wq.size { + wq.head = 0 + } + wq.isFull = false + + return w +} + +func (wq *loopQueue[T, U, C, CT, TF]) retrieveExpiry(duration time.Duration) []*goWorker[T, U, C, CT, TF] { + expiryTime := time.Now().Add(-duration) + index := wq.binarySearch(expiryTime) + if index == -1 { + return nil + } + wq.expiry = wq.expiry[:0] + + if wq.head <= index { + wq.expiry = append(wq.expiry, wq.items[wq.head:index+1]...) + for i := wq.head; i < index+1; i++ { + wq.items[i] = nil + } + } else { + wq.expiry = append(wq.expiry, wq.items[0:index+1]...) + wq.expiry = append(wq.expiry, wq.items[wq.head:]...) + for i := 0; i < index+1; i++ { + wq.items[i] = nil + } + for i := wq.head; i < wq.size; i++ { + wq.items[i] = nil + } + } + head := (index + 1) % wq.size + wq.head = head + if len(wq.expiry) > 0 { + wq.isFull = false + } + + return wq.expiry +} + +// binarySearch is to find the first worker which is idle for more than duration. 
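+// It returns the ring-buffer index of the last worker whose recycleTime is not
+// after expiryTime, or -1 if no worker has been idle that long.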
+func (wq *loopQueue[T, U, C, CT, TF]) binarySearch(expiryTime time.Time) int { + var mid, nlen, basel, tmid int + nlen = len(wq.items) + + // if no need to remove work, return -1 + if wq.isEmpty() || expiryTime.Before(wq.items[wq.head].recycleTime.Load()) { + return -1 + } + + // example + // size = 8, head = 7, tail = 4 + // [ 2, 3, 4, 5, nil, nil, nil, 1] true position + // 0 1 2 3 4 5 6 7 + // tail head + // + // 1 2 3 4 nil nil nil 0 mapped position + // r l + + // base algorithm is a copy from worker_stack + // map head and tail to effective left and right + r := (wq.tail - 1 - wq.head + nlen) % nlen + basel = wq.head + l := 0 + for l <= r { + mid = l + ((r - l) >> 1) + // calculate true mid position from mapped mid position + tmid = (mid + basel + nlen) % nlen + if expiryTime.Before(wq.items[tmid].recycleTime.Load()) { + r = mid - 1 + } else { + l = mid + 1 + } + } + // return true position from mapped position + return (r + basel + nlen) % nlen +} + +func (wq *loopQueue[T, U, C, CT, TF]) reset() { + if wq.isEmpty() { + return + } + +Releasing: + if w := wq.detach(); w != nil { + w.taskBoxCh <- nil + goto Releasing + } + wq.items = wq.items[:0] + wq.size = 0 + wq.head = 0 + wq.tail = 0 +} diff --git a/util/gpool/spmc/worker_loop_queue_test.go b/util/gpool/spmc/worker_loop_queue_test.go new file mode 100644 index 0000000000000..da9bdc8dbc36c --- /dev/null +++ b/util/gpool/spmc/worker_loop_queue_test.go @@ -0,0 +1,184 @@ +// Copyright 2022 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package spmc + +import ( + "testing" + "time" + + "github.com/pingcap/tidb/resourcemanager/pooltask" + "github.com/stretchr/testify/require" + atomicutil "go.uber.org/atomic" +) + +func TestNewLoopQueue(t *testing.T) { + size := 100 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + require.EqualValues(t, 0, q.len(), "Len error") + require.Equal(t, true, q.isEmpty(), "IsEmpty error") + require.Nil(t, q.detach(), "Dequeue error") +} + +func TestLoopQueue(t *testing.T) { + size := 10 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + + for i := 0; i < 5; i++ { + err := q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + if err != nil { + break + } + } + require.EqualValues(t, 5, q.len(), "Len error") + _ = q.detach() + require.EqualValues(t, 4, q.len(), "Len error") + + time.Sleep(time.Second) + + for i := 0; i < 6; i++ { + err := q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + if err != nil { + break + } + } + require.EqualValues(t, 10, q.len(), "Len error") + + err := q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + require.Error(t, err, "Enqueue, error") + + q.retrieveExpiry(time.Second) + require.EqualValuesf(t, 6, q.len(), "Len error: %d", q.len()) +} + +func TestRotatedArraySearch(t *testing.T) { + size := 10 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + + expiry1 := time.Now() + + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + + require.EqualValues(t, 0, q.binarySearch(time.Now()), "index should be 0") + require.EqualValues(t, -1, q.binarySearch(expiry1), "index should be -1") + + expiry2 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + require.EqualValues(t, -1, q.binarySearch(expiry1), "index should be -1") + require.EqualValues(t, 0, q.binarySearch(expiry2), "index should be 0") + require.EqualValues(t, 1, q.binarySearch(time.Now()), "index should be 1") + + for i := 0; i < 5; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + + expiry3 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(expiry3)}) + + var err error + for err != errQueueIsFull { + err = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + + require.EqualValues(t, 7, q.binarySearch(expiry3), "index should be 7") + + // rotate + for i := 0; i < 6; i++ { + _ = q.detach() + } + + expiry4 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(expiry4)}) + + for i := 0; i < 4; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + // head = 6, tail = 5, insert direction -> + // [expiry4, time, time, time, time, nil/tail, time/head, time, time, time] + require.EqualValues(t, 0, q.binarySearch(expiry4), "index should be 0") + + for i := 0; i < 3; i++ { + _ = q.detach() + } + expiry5 := time.Now() + _ = q.insert(&goWorker[struct{}, struct{}, int, any, 
pooltask.NilContext]{recycleTime: *atomicutil.NewTime(expiry5)}) + + // head = 6, tail = 5, insert direction -> + // [expiry4, time, time, time, time, expiry5, nil/tail, nil, nil, time/head] + require.EqualValues(t, 5, q.binarySearch(expiry5), "index should be 5") + + for i := 0; i < 3; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + // head = 9, tail = 9, insert direction -> + // [expiry4, time, time, time, time, expiry5, time, time, time, time/head/tail] + require.EqualValues(t, -1, q.binarySearch(expiry2), "index should be -1") + + require.EqualValues(t, 9, q.binarySearch(q.items[9].recycleTime.Load()), "index should be 9") + require.EqualValues(t, 8, q.binarySearch(time.Now()), "index should be 8") +} + +func TestRetrieveExpiry(t *testing.T) { + size := 10 + q := newWorkerLoopQueue[struct{}, struct{}, int, any, pooltask.NilContext](size) + expirew := make([]*goWorker[struct{}, struct{}, int, any, pooltask.NilContext], 0) + u, _ := time.ParseDuration("1s") + + // test [ time+1s, time+1s, time+1s, time+1s, time+1s, time, time, time, time, time] + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + expirew = append(expirew, q.items[:size/2]...) + time.Sleep(u) + + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + workers := q.retrieveExpiry(u) + + require.EqualValues(t, expirew, workers, "expired workers aren't right") + + // test [ time, time, time, time, time, time+1s, time+1s, time+1s, time+1s, time+1s] + time.Sleep(u) + + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + expirew = expirew[:0] + expirew = append(expirew, q.items[size/2:]...) + + workers2 := q.retrieveExpiry(u) + + require.EqualValues(t, expirew, workers2, "expired workers aren't right") + + // test [ time+1s, time+1s, time+1s, nil, nil, time+1s, time+1s, time+1s, time+1s, time+1s] + for i := 0; i < size/2; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + for i := 0; i < size/2; i++ { + _ = q.detach() + } + for i := 0; i < 3; i++ { + _ = q.insert(&goWorker[struct{}, struct{}, int, any, pooltask.NilContext]{recycleTime: *atomicutil.NewTime(time.Now())}) + } + time.Sleep(u) + + expirew = expirew[:0] + expirew = append(expirew, q.items[0:3]...) + expirew = append(expirew, q.items[size/2:]...) + + workers3 := q.retrieveExpiry(u) + + require.EqualValues(t, expirew, workers3, "expired workers aren't right") +}