diff --git a/adaptive_ratelimit_test.go b/adaptive_ratelimit_test.go deleted file mode 100644 index b507e78..0000000 --- a/adaptive_ratelimit_test.go +++ /dev/null @@ -1,17 +0,0 @@ -package ratelimit_test - -// func TestAdaptiveRateLimit(t *testing.T) { -// limiter := ratelimit.NewUnlimited(context.Background()) -// start := time.Now() - -// for i := 0; i < 132; i++ { -// limiter.Take() -// // got 429 / hit ratelimit after 100 -// if i == 100 { -// // Retry-After and new limiter (calibrate using different statergies) -// // new expected ratelimit 30req every 5 sec -// limiter.SleepandReset(time.Duration(5)*time.Second, 30, time.Duration(5)*time.Second) -// } -// } -// require.Equal(t, time.Since(start).Round(time.Second), time.Duration(10)*time.Second) -// } diff --git a/ratelimit.go b/ratelimit.go index f8974ba..fbb29fa 100644 --- a/ratelimit.go +++ b/ratelimit.go @@ -12,7 +12,7 @@ var minusOne = ^uint32(0) // Limiter allows a burst of request during the defined duration type Limiter struct { - maxCount uint32 + maxCount atomic.Uint32 count atomic.Uint32 ticker *time.Ticker tokens chan struct{} @@ -26,7 +26,7 @@ func (limiter *Limiter) run(ctx context.Context) { for { if limiter.count.Load() == 0 { <-limiter.ticker.C - limiter.count.Store(limiter.maxCount) + limiter.count.Store(limiter.maxCount.Load()) } select { case <-ctx.Done(): @@ -39,7 +39,7 @@ func (limiter *Limiter) run(ctx context.Context) { case limiter.tokens <- struct{}{}: limiter.count.Add(minusOne) case <-limiter.ticker.C: - limiter.count.Store(limiter.maxCount) + limiter.count.Store(limiter.maxCount.Load()) } } } @@ -56,29 +56,18 @@ func (limiter *Limiter) CanTake() bool { // GetLimit returns current rate limit per given duration func (limiter *Limiter) GetLimit() uint { - return uint(limiter.maxCount) + return uint(limiter.maxCount.Load()) } -// TODO: SleepandReset should be able to handle multiple calls without resetting multiple times -// Which is not possible in this implementation -// // SleepandReset stops timer removes all tokens and resets with new limit (used for Adaptive Ratelimiting) -// func (ratelimiter *Limiter) SleepandReset(sleepTime time.Duration, newLimit uint, duration time.Duration) { -// // stop existing Limiter using internalContext -// ratelimiter.cancelFunc() -// // drain any token -// close(ratelimiter.tokens) -// <-ratelimiter.tokens -// // sleep -// time.Sleep(sleepTime) -// //reset and start -// ratelimiter.maxCount = newLimit -// ratelimiter.count = newLimit -// ratelimiter.ticker = time.NewTicker(duration) -// ratelimiter.tokens = make(chan struct{}) -// ctx, cancel := context.WithCancel(context.TODO()) -// ratelimiter.cancelFunc = cancel -// go ratelimiter.run(ctx) -// } +// GetLimit returns current rate limit per given duration +func (limiter *Limiter) SetLimit(max uint) { + limiter.maxCount.Store(uint32(max)) +} + +// GetLimit returns current rate limit per given duration +func (limiter *Limiter) SetDuration(d time.Duration) { + limiter.ticker.Reset(d) +} // Stop the rate limiter canceling the internal context func (limiter *Limiter) Stop() { @@ -91,8 +80,10 @@ func (limiter *Limiter) Stop() { func New(ctx context.Context, max uint, duration time.Duration) *Limiter { internalctx, cancel := context.WithCancel(context.TODO()) + var maxCount atomic.Uint32 + maxCount.Store(uint32(max)) limiter := &Limiter{ - maxCount: uint32(max), + maxCount: maxCount, ticker: time.NewTicker(duration), tokens: make(chan struct{}), ctx: ctx, @@ -108,8 +99,11 @@ func New(ctx context.Context, max uint, duration time.Duration) *Limiter { func NewUnlimited(ctx context.Context) *Limiter { internalctx, cancel := context.WithCancel(context.TODO()) + var maxCount atomic.Uint32 + maxCount.Store(math.MaxUint32) + limiter := &Limiter{ - maxCount: math.MaxUint32, + maxCount: maxCount, ticker: time.NewTicker(time.Millisecond), tokens: make(chan struct{}), ctx: ctx,