From e90e78714eea4305fa89b6af9d3d809d5b5db64a Mon Sep 17 00:00:00 2001 From: Paul Wells Date: Mon, 15 Apr 2024 17:38:54 -0700 Subject: [PATCH] Reuse time.Timer in ack and rtx timer utilities --- ack_timer.go | 76 +++++++++++++++++++++-------------------------- ack_timer_test.go | 18 ++++++----- rtx_timer.go | 5 +++- 3 files changed, 48 insertions(+), 51 deletions(-) diff --git a/ack_timer.go b/ack_timer.go index 3d9b43e0..879e86df 100644 --- a/ack_timer.go +++ b/ack_timer.go @@ -4,6 +4,7 @@ package sctp import ( + "math" "sync" "time" ) @@ -17,23 +18,37 @@ type ackTimerObserver interface { onAckTimeout() } +type ackTimerState int + +const ( + ackTimerStopped ackTimerState = iota + ackTimerStarted + ackTimerClosed +) + // ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type ackTimer struct { observer ackTimerObserver - interval time.Duration - stopFunc stopAckTimerLoop - closed bool mutex sync.RWMutex + state ackTimerState + timer *time.Timer } -type stopAckTimerLoop func() - // newAckTimer creates a new acknowledgement timer used to enable delayed ack. func newAckTimer(observer ackTimerObserver) *ackTimer { - return &ackTimer{ - observer: observer, - interval: ackInterval, + t := &ackTimer{observer: observer} + t.timer = time.AfterFunc(math.MaxInt64, t.timeout) + t.timer.Stop() + return t +} + +func (t *ackTimer) timeout() { + t.mutex.Lock() + if t.state == ackTimerStarted { + t.state = ackTimerStopped + defer t.observer.onAckTimeout() } + t.mutex.Unlock() } // start starts the timer. @@ -41,34 +56,13 @@ func (t *ackTimer) start() bool { t.mutex.Lock() defer t.mutex.Unlock() - // this timer is already closed - if t.closed { + // this timer is already closed or already running + if t.state != ackTimerStopped { return false } - // this is a noop if the timer is already running - if t.stopFunc != nil { - return false - } - - cancelCh := make(chan struct{}) - - go func() { - timer := time.NewTimer(t.interval) - - select { - case <-timer.C: - t.stop() - t.observer.onAckTimeout() - case <-cancelCh: - timer.Stop() - } - }() - - t.stopFunc = func() { - close(cancelCh) - } - + t.state = ackTimerStarted + t.timer.Reset(ackInterval) return true } @@ -78,9 +72,9 @@ func (t *ackTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == ackTimerStarted { + t.timer.Stop() + t.state = ackTimerStopped } } @@ -90,12 +84,10 @@ func (t *ackTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() - if t.stopFunc != nil { - t.stopFunc() - t.stopFunc = nil + if t.state == ackTimerStarted { + t.timer.Stop() } - - t.closed = true + t.state = ackTimerClosed } // isRunning tests if the timer is running. @@ -104,5 +96,5 @@ func (t *ackTimer) isRunning() bool { t.mutex.RLock() defer t.mutex.RUnlock() - return (t.stopFunc != nil) + return t.state == ackTimerStarted } diff --git a/ack_timer_test.go b/ack_timer_test.go index 145b6cde..575b3ecc 100644 --- a/ack_timer_test.go +++ b/ack_timer_test.go @@ -70,14 +70,16 @@ func TestAckTimer(t *testing.T) { }, }) - // should start ok - ok := rt.start() - assert.True(t, ok, "start() should succeed") - assert.True(t, rt.isRunning(), "should be running") + for i := 0; i < 2; i++ { + // should start ok + ok := rt.start() + assert.True(t, ok, "start() should succeed") + assert.True(t, rt.isRunning(), "should be running") - // stop immedidately - rt.stop() - assert.False(t, rt.isRunning(), "should not be running") + // stop immedidately + rt.stop() + assert.False(t, rt.isRunning(), "should not be running") + } // Sleep more than 200msec of interval to test if it never times out time.Sleep(ackInterval + 50*time.Millisecond) @@ -86,7 +88,7 @@ func TestAckTimer(t *testing.T) { "should not be timed out (actual: %d)", atomic.LoadUint32(&nCbs)) // can start again - ok = rt.start() + ok := rt.start() assert.True(t, ok, "start() should succeed again") assert.True(t, rt.isRunning(), "should be running") diff --git a/rtx_timer.go b/rtx_timer.go index 42848abc..354825b5 100644 --- a/rtx_timer.go +++ b/rtx_timer.go @@ -175,9 +175,12 @@ func (t *rtxTimer) start(rto float64) bool { go func() { canceling := false + timer := time.NewTimer(math.MaxInt64) + timer.Stop() + for !canceling { timeout := calculateNextTimeout(rto, nRtos, t.rtoMax) - timer := time.NewTimer(time.Duration(timeout) * time.Millisecond) + timer.Reset(time.Duration(timeout) * time.Millisecond) select { case <-timer.C: