From 834b74f0f3eec42055d1af6ecfe34d448f71d97b Mon Sep 17 00:00:00 2001 From: garethgeorge Date: Thu, 11 Apr 2024 01:46:50 -0700 Subject: [PATCH] fix: bugs in refactored task queue and improved coverage --- internal/queue/timepriorityqueue.go | 2 +- internal/queue/timepriorityqueue_test.go | 31 ++++++++++++++++++++ internal/queue/timequeue.go | 36 +++++++++++------------- 3 files changed, 49 insertions(+), 20 deletions(-) diff --git a/internal/queue/timepriorityqueue.go b/internal/queue/timepriorityqueue.go index 4c924197..8f235733 100644 --- a/internal/queue/timepriorityqueue.go +++ b/internal/queue/timepriorityqueue.go @@ -59,7 +59,7 @@ func (t *TimePriorityQueue[T]) Enqueue(at time.Time, priority int, v T) { func (t *TimePriorityQueue[T]) Dequeue(ctx context.Context) T { t.mu.Lock() for { - for t.tqueue.Len() > 0 { + for t.tqueue.heap.Len() > 0 { thead := t.tqueue.Peek() // peek at the head of the time queue if thead.at.Before(time.Now()) { tqe := heap.Pop(&t.tqueue.heap).(timeQueueEntry[priorityEntry[T]]) diff --git a/internal/queue/timepriorityqueue_test.go b/internal/queue/timepriorityqueue_test.go index 42752ddc..375d744d 100644 --- a/internal/queue/timepriorityqueue_test.go +++ b/internal/queue/timepriorityqueue_test.go @@ -2,6 +2,7 @@ package queue import ( "context" + "math/rand" "testing" "time" ) @@ -53,3 +54,33 @@ func TestTPQMixedReadinessStates(t *testing.T) { } } } + +func TestTPQStress(t *testing.T) { + tpq := NewTimePriorityQueue[int]() + start := time.Now() + + totalEnqueued := 0 + totalEnqueuedSum := 0 + + go func() { + ctx, _ := context.WithDeadline(context.Background(), start.Add(1*time.Second)) + for ctx.Err() == nil { + v := rand.Intn(100) + tpq.Enqueue(time.Now().Add(time.Duration(rand.Intn(1000)-500)*time.Millisecond), rand.Intn(5), v) + totalEnqueuedSum += v + totalEnqueued++ + } + }() + + ctx, _ := context.WithDeadline(context.Background(), start.Add(3*time.Second)) + totalDequeued := 0 + sum := 0 + for ctx.Err() == nil || totalDequeued < totalEnqueued { + sum += tpq.Dequeue(ctx) + totalDequeued++ + } + + if sum != totalEnqueuedSum { + t.Errorf("expected sum to be %d, got %d", totalEnqueuedSum, sum) + } +} diff --git a/internal/queue/timequeue.go b/internal/queue/timequeue.go index 45653b43..988b0e2d 100644 --- a/internal/queue/timequeue.go +++ b/internal/queue/timequeue.go @@ -4,6 +4,7 @@ import ( "container/heap" "context" "sync" + "sync/atomic" "time" ) @@ -13,7 +14,7 @@ type TimeQueue[T any] struct { dequeueMu sync.Mutex mu sync.Mutex - notify chan struct{} + notify atomic.Pointer[chan struct{}] } func NewTimeQueue[T any]() *TimeQueue[T] { @@ -25,10 +26,13 @@ func NewTimeQueue[T any]() *TimeQueue[T] { func (t *TimeQueue[T]) Enqueue(at time.Time, v T) { t.mu.Lock() heap.Push(&t.heap, timeQueueEntry[T]{at, v}) - if t.notify != nil { - t.notify <- struct{}{} - } t.mu.Unlock() + if n := t.notify.Load(); n != nil { + select { + case *n <- struct{}{}: + default: + } + } } func (t *TimeQueue[T]) Len() int { @@ -63,30 +67,24 @@ func (t *TimeQueue[T]) Dequeue(ctx context.Context) T { t.dequeueMu.Lock() defer t.dequeueMu.Unlock() - t.mu.Lock() - t.notify = make(chan struct{}, 1) - defer func() { - t.mu.Lock() - close(t.notify) - t.notify = nil - t.mu.Unlock() - }() - t.mu.Unlock() + notify := make(chan struct{}, 1) + t.notify.Store(¬ify) + defer t.notify.Store(nil) for { t.mu.Lock() - var wait time.Duration - if t.heap.Len() == 0 { - wait = 3 * time.Minute - } else { + if t.heap.Len() > 0 { val := t.heap.Peek() wait = time.Until(val.at) if wait <= 0 { - t.mu.Unlock() + defer t.mu.Unlock() return heap.Pop(&t.heap).(timeQueueEntry[T]).v } } + if wait == 0 || wait > 3*time.Minute { + wait = 3 * time.Minute + } t.mu.Unlock() timer := time.NewTimer(wait) @@ -101,7 +99,7 @@ func (t *TimeQueue[T]) Dequeue(ctx context.Context) T { } t.mu.Unlock() return val.v - case <-t.notify: // new task was added, loop again to ensure we have the earliest task. + case <-notify: // new task was added, loop again to ensure we have the earliest task. if !timer.Stop() { <-timer.C }