Skip to content

Commit

Permalink
feat: throttling with retryAfter (#422)
Browse files Browse the repository at this point in the history
  • Loading branch information
fracasula authored Apr 11, 2024
1 parent c992d13 commit c2904bf
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 47 deletions.
6 changes: 3 additions & 3 deletions throttling/lua/gcra.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ if remaining < 0 then
current_time_micro,
0, -- allowed
0, -- remaining
tostring(retry_after),
tostring(reset_after),
tonumber(retry_after),
tonumber(reset_after),
}
end

Expand All @@ -62,4 +62,4 @@ if reset_after > 0 then
end

local retry_after = -1
return { current_time_micro, cost, remaining, tostring(retry_after), tostring(reset_after) }
return { current_time_micro, cost, remaining, tonumber(retry_after), tonumber(reset_after) }
11 changes: 9 additions & 2 deletions throttling/lua/sortedset.lua
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,10 @@ local used_tokens = redis.call('ZCARD', key)

-- If the number of requests is greater than the max requests we hit the limit
if (used_tokens + cost) > tonumber(rate) then
return { current_time_micro, "" }
local next_to_expire = redis.call('ZRANGE', key, 0, 0, 'WITHSCORES')[2]
local retry_after = next_to_expire + period - current_time_micro

return { current_time_micro, "", retry_after }
end

-- seed needed to generate random members in case of collision
Expand Down Expand Up @@ -63,4 +66,8 @@ end
redis.call('EXPIRE', key, period)

members = members:sub(1, -2) -- remove the last comma
return { current_time_micro, members }
return {
current_time_micro,
members,
0 -- no retry_after
}
10 changes: 5 additions & 5 deletions throttling/memory_gcra.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ type gcra struct {
}

func (g *gcra) limit(ctx context.Context, key string, cost, burst, rate, period int64) (
bool, error,
bool, time.Duration, error,
) {
rl, err := g.getLimiter(key, burst, rate, period)
if err != nil {
return false, err
return false, 0, err
}

limited, _, err := rl.RateLimitCtx(ctx, "key", int(cost))
limited, res, err := rl.RateLimitCtx(ctx, "key", int(cost))
if err != nil {
return false, fmt.Errorf("could not rate limit: %w", err)
return false, 0, fmt.Errorf("could not rate limit: %w", err)
}

return !limited, nil
return !limited, res.RetryAfter, nil
}

func (g *gcra) getLimiter(key string, burst, rate, period int64) (*throttled.GCRARateLimiterCtx, error) {
Expand Down
4 changes: 2 additions & 2 deletions throttling/memory_gcra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ func TestMemoryGCRA(t *testing.T) {
rate := int64(1)
period := int64(1)

allowed, err := l.limit(context.Background(), "key", burst+rate, burst, rate, period)
allowed, _, err := l.limit(context.Background(), "key", burst+rate, burst, rate, period)
require.NoError(t, err)
require.True(t, allowed, "it should be able to fill the bucket (burst)")

// next request should be allowed after 5 seconds
start := time.Now()

require.Eventually(t, func() bool {
allowed, err := l.limit(context.Background(), "key", burst, burst, rate, period)
allowed, _, err := l.limit(context.Background(), "key", burst, burst, rate, period)
if err != nil {
t.Logf("Memory GCRA error: %v", err)
return false
Expand Down
88 changes: 56 additions & 32 deletions throttling/throttling.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,128 +78,152 @@ func New(options ...Option) (*Limiter, error) {
// Allow returns true if the limit is not exceeded, false otherwise.
func (l *Limiter) Allow(ctx context.Context, cost, rate, window int64, key string) (
bool, func(context.Context) error, error,
) {
allowed, _, tr, err := l.allow(ctx, cost, rate, window, key)
return allowed, tr, err
}

// AllowAfter returns true if the limit is not exceeded, false otherwise.
// Additionally, it returns the time.Duration until the next allowed request.
func (l *Limiter) AllowAfter(ctx context.Context, cost, rate, window int64, key string) (
bool, time.Duration, func(context.Context) error, error,
) {
return l.allow(ctx, cost, rate, window, key)
}

func (l *Limiter) allow(ctx context.Context, cost, rate, window int64, key string) (
bool, time.Duration, func(context.Context) error, error,
) {
if cost < 1 {
return false, nil, fmt.Errorf("cost must be greater than 0")
return false, 0, nil, fmt.Errorf("cost must be greater than 0")
}
if rate < 1 {
return false, nil, fmt.Errorf("rate must be greater than 0")
return false, 0, nil, fmt.Errorf("rate must be greater than 0")
}
if window < 1 {
return false, nil, fmt.Errorf("window must be greater than 0")
return false, 0, nil, fmt.Errorf("window must be greater than 0")
}
if key == "" {
return false, nil, fmt.Errorf("key must not be empty")
return false, 0, nil, fmt.Errorf("key must not be empty")
}

if l.redisSpeaker != nil {
if l.useGCRA {
defer l.getTimer(key, "redis-gcra", rate, window)()
_, allowed, tr, err := l.redisGCRA(ctx, cost, rate, window, key)
return allowed, tr, err
_, allowed, retryAfter, tr, err := l.redisGCRA(ctx, cost, rate, window, key)
return allowed, retryAfter, tr, err
}

defer l.getTimer(key, "redis-sorted-set", rate, window)()
_, allowed, tr, err := l.redisSortedSet(ctx, cost, rate, window, key)
return allowed, tr, err
_, allowed, retryAfter, tr, err := l.redisSortedSet(ctx, cost, rate, window, key)
return allowed, retryAfter, tr, err
}

defer l.getTimer(key, "gcra", rate, window)()
return l.gcraLimit(ctx, cost, rate, window, key)
allowed, retryAfter, tr, err := l.gcraLimit(ctx, cost, rate, window, key)
return allowed, retryAfter, tr, err
}

func (l *Limiter) redisSortedSet(ctx context.Context, cost, rate, window int64, key string) (
time.Duration, bool, func(context.Context) error, error,
time.Duration, bool, time.Duration, func(context.Context) error, error,
) {
res, err := sortedSetScript.Run(ctx, l.redisSpeaker, []string{key}, cost, rate, window).Result()
if err != nil {
return 0, false, nil, fmt.Errorf("could not run SortedSet Redis script: %v", err)
return 0, false, 0, nil, fmt.Errorf("could not run SortedSet Redis script: %v", err)
}

result, ok := res.([]interface{})
if !ok {
return 0, false, nil, fmt.Errorf("unexpected result from SortedSet Redis script of type %T: %v", res, res)
return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of type %T: %v", res, res)
}
if len(result) != 2 {
return 0, false, nil, fmt.Errorf("unexpected result from SortedSet Redis script of length %d: %+v", len(result), result)
if len(result) != 3 {
return 0, false, 0, nil, fmt.Errorf("unexpected result from SortedSet Redis script of length %d: %+v", len(result), result)
}

t, ok := result[0].(int64)
if !ok {
return 0, false, nil, fmt.Errorf("unexpected result[0] from SortedSet Redis script of type %T: %v", result[0], result[0])
return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from SortedSet Redis script of type %T: %v", result[0], result[0])
}
redisTime := time.Duration(t) * time.Microsecond

members, ok := result[1].(string)
if !ok {
return redisTime, false, nil, fmt.Errorf("unexpected result[1] from SortedSet Redis script of type %T: %v", result[1], result[1])
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from SortedSet Redis script of type %T: %v", result[1], result[1])
}
if members == "" { // limit exceeded
return redisTime, false, nil, nil
retryAfter, ok := result[2].(int64)
if !ok {
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[2] from SortedSet Redis script of type %T: %v", result[2], result[2])
}
return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil
}

r := &sortedSetRedisReturn{
key: key,
members: strings.Split(members, ","),
remover: l.redisSpeaker,
}
return redisTime, true, r.Return, nil
return redisTime, true, 0, r.Return, nil
}

func (l *Limiter) redisGCRA(ctx context.Context, cost, rate, window int64, key string) (
time.Duration, bool, func(context.Context) error, error,
time.Duration, bool, time.Duration, func(context.Context) error, error,
) {
burst := rate
if l.gcraBurst > 0 {
burst = l.gcraBurst
}
res, err := gcraRedisScript.Run(ctx, l.redisSpeaker, []string{key}, burst, rate, window, cost).Result()
if err != nil {
return 0, false, nil, fmt.Errorf("could not run GCRA Redis script: %v", err)
return 0, false, 0, nil, fmt.Errorf("could not run GCRA Redis script: %v", err)
}

result, ok := res.([]interface{})
result, ok := res.([]any)
if !ok {
return 0, false, nil, fmt.Errorf("unexpected result from GCRA Redis script of type %T: %v", res, res)
return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis script of type %T: %v", res, res)
}
if len(result) != 5 {
return 0, false, nil, fmt.Errorf("unexpected result from GCRA Redis scrip of length %d: %+v", len(result), result)
return 0, false, 0, nil, fmt.Errorf("unexpected result from GCRA Redis scrip of length %d: %+v", len(result), result)
}

t, ok := result[0].(int64)
if !ok {
return 0, false, nil, fmt.Errorf("unexpected result[0] from GCRA Redis script of type %T: %v", result[0], result[0])
return 0, false, 0, nil, fmt.Errorf("unexpected result[0] from GCRA Redis script of type %T: %v", result[0], result[0])
}
redisTime := time.Duration(t) * time.Microsecond

allowed, ok := result[1].(int64)
if !ok {
return redisTime, false, nil, fmt.Errorf("unexpected result[1] from GCRA Redis script of type %T: %v", result[1], result[1])
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[1] from GCRA Redis script of type %T: %v", result[1], result[1])
}
if allowed < 1 { // limit exceeded
return redisTime, false, nil, nil
retryAfter, ok := result[3].(int64)
if !ok {
return redisTime, false, 0, nil, fmt.Errorf("unexpected result[3] from GCRA Redis script of type %T: %v", result[3], result[3])
}
return redisTime, false, time.Duration(retryAfter) * time.Microsecond, nil, nil
}

r := &unsupportedReturn{}
return redisTime, true, r.Return, nil
return redisTime, true, 0, r.Return, nil
}

func (l *Limiter) gcraLimit(ctx context.Context, cost, rate, window int64, key string) (
bool, func(context.Context) error, error,
bool, time.Duration, func(context.Context) error, error,
) {
burst := rate
if l.gcraBurst > 0 {
burst = l.gcraBurst
}
allowed, err := l.gcra.limit(ctx, key, cost, burst, rate, window)
allowed, retryAfter, err := l.gcra.limit(ctx, key, cost, burst, rate, window)
if err != nil {
return false, nil, fmt.Errorf("could not limit: %w", err)
return false, 0, nil, fmt.Errorf("could not limit: %w", err)
}
if !allowed {
return false, nil, nil // limit exceeded
return false, retryAfter, nil, nil // limit exceeded
}
r := &unsupportedReturn{}
return true, r.Return, nil
return true, 0, r.Return, nil
}

func (l *Limiter) getTimer(key, algo string, rate, window int64) func() {
Expand Down
Loading

0 comments on commit c2904bf

Please sign in to comment.