diff --git a/clock/clock.go b/clock/clock.go index a3efed96..dd181ce8 100644 --- a/clock/clock.go +++ b/clock/clock.go @@ -53,6 +53,16 @@ type WithTicker interface { NewTicker(time.Duration) Ticker } +// WithDelayedExecution allows for injecting fake or real clocks into +// code that needs to make use of AfterFunc functionality. +type WithDelayedExecution interface { + Clock + // AfterFunc executes f in its own goroutine after waiting + // for d duration and returns a Timer whose channel can be + // closed by calling Stop() on the Timer. + AfterFunc(d time.Duration, f func()) Timer +} + // Ticker defines the Ticker interface. type Ticker interface { C() <-chan time.Time @@ -88,6 +98,13 @@ func (RealClock) NewTimer(d time.Duration) Timer { } } +// AfterFunc is the same as time.AfterFunc(d, f). +func (RealClock) AfterFunc(d time.Duration, f func()) Timer { + return &realTimer{ + timer: time.AfterFunc(d, f), + } +} + // Tick is the same as time.Tick(d) // This method does not allow to free/GC the backing ticker. Use // NewTicker instead. diff --git a/clock/testing/fake_clock.go b/clock/testing/fake_clock.go index 27d0dd58..fb493c4b 100644 --- a/clock/testing/fake_clock.go +++ b/clock/testing/fake_clock.go @@ -49,6 +49,7 @@ type fakeClockWaiter struct { skipIfBlocked bool destChan chan time.Time fired bool + afterFunc func() } // NewFakePassiveClock returns a new FakePassiveClock. @@ -116,6 +117,25 @@ func (f *FakeClock) NewTimer(d time.Duration) clock.Timer { return timer } +// AfterFunc is the Fake version of time.AfterFunc(d, cb). +func (f *FakeClock) AfterFunc(d time.Duration, cb func()) clock.Timer { + f.lock.Lock() + defer f.lock.Unlock() + stopTime := f.time.Add(d) + ch := make(chan time.Time, 1) // Don't block! + + timer := &fakeTimer{ + fakeClock: f, + waiter: fakeClockWaiter{ + targetTime: stopTime, + destChan: ch, + afterFunc: cb, + }, + } + f.waiters = append(f.waiters, &timer.waiter) + return timer +} + // Tick constructs a fake ticker, akin to time.Tick func (f *FakeClock) Tick(d time.Duration) <-chan time.Time { if d <= 0 { @@ -175,7 +195,6 @@ func (f *FakeClock) setTimeLocked(t time.Time) { for i := range f.waiters { w := f.waiters[i] if !w.targetTime.After(t) { - if w.skipIfBlocked { select { case w.destChan <- t: @@ -187,6 +206,10 @@ func (f *FakeClock) setTimeLocked(t time.Time) { w.fired = true } + if w.afterFunc != nil { + w.afterFunc() + } + if w.stepInterval > 0 { for !w.targetTime.After(t) { w.targetTime = w.targetTime.Add(w.stepInterval) @@ -201,7 +224,7 @@ func (f *FakeClock) setTimeLocked(t time.Time) { f.waiters = newWaiters } -// HasWaiters returns true if After has been called on f but not yet satisfied (so you can +// HasWaiters returns true if After or AfterFunc has been called on f but not yet satisfied (so you can // write race-free tests). func (f *FakeClock) HasWaiters() bool { f.lock.RLock() @@ -245,6 +268,12 @@ func (*IntervalClock) NewTimer(d time.Duration) clock.Timer { panic("IntervalClock doesn't implement NewTimer") } +// AfterFunc is unimplemented, will panic. +// TODO: make interval clock use FakeClock so this can be implemented. +func (*IntervalClock) AfterFunc(d time.Duration, f func()) clock.Timer { + panic("IntervalClock doesn't implement AfterFunc") +} + // Tick is unimplemented, will panic. // TODO: make interval clock use FakeClock so this can be implemented. func (*IntervalClock) Tick(d time.Duration) <-chan time.Time { diff --git a/clock/testing/fake_clock_test.go b/clock/testing/fake_clock_test.go index 5def1f63..db7ddbb5 100644 --- a/clock/testing/fake_clock_test.go +++ b/clock/testing/fake_clock_test.go @@ -137,6 +137,66 @@ func TestFakeAfter(t *testing.T) { } } +func TestFakeAfterFunc(t *testing.T) { + tc := NewFakeClock(time.Now()) + if tc.HasWaiters() { + t.Errorf("unexpected waiter?") + } + expectOneSecTimerFire := false + oneSecTimerFire := 0 + tc.AfterFunc(time.Second, func() { + if !expectOneSecTimerFire { + t.Errorf("oneSecTimer func fired") + } else { + oneSecTimerFire++ + } + }) + if !tc.HasWaiters() { + t.Errorf("unexpected lack of waiter?") + } + + expectOneOhOneSecTimerFire := false + oneOhOneSecTimerFire := 0 + tc.AfterFunc(time.Second+time.Millisecond, func() { + if !expectOneOhOneSecTimerFire { + t.Errorf("oneOhOneSecTimer func fired") + } else { + oneOhOneSecTimerFire++ + } + }) + + expectTwoSecTimerFire := false + twoSecTimerFire := 0 + twoSecTimer := tc.AfterFunc(2*time.Second, func() { + if !expectTwoSecTimerFire { + t.Errorf("twoSecTimer func fired") + } else { + twoSecTimerFire++ + } + }) + + tc.Step(999 * time.Millisecond) + + expectOneSecTimerFire = true + tc.Step(time.Millisecond) + if oneSecTimerFire != 1 { + t.Errorf("expected oneSecTimerFire=1, got %d", oneSecTimerFire) + } + expectOneSecTimerFire = false + + expectOneOhOneSecTimerFire = true + tc.Step(time.Millisecond) + if oneOhOneSecTimerFire != 1 { + // should not double-trigger! + t.Errorf("expected oneOhOneSecTimerFire=1, got %d", oneOhOneSecTimerFire) + } + expectOneOhOneSecTimerFire = false + + // ensure a canceled timer doesn't fire + twoSecTimer.Stop() + tc.Step(time.Second) +} + func TestFakeTick(t *testing.T) { tc := NewFakeClock(time.Now()) if tc.HasWaiters() {