Skip to content

Commit

Permalink
Merge pull request #4 from theplant/time
Browse files Browse the repository at this point in the history
Force the use of the driver's internal time to prevent time inconsistency among multiple nodes
  • Loading branch information
molon authored Aug 19, 2024
2 parents e04fab0 + 094ebe1 commit d3d06ef
Show file tree
Hide file tree
Showing 13 changed files with 691 additions and 169 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
gorm
__debug_*
29 changes: 18 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,17 @@ func runExample(limiter *ratelimiter.RateLimiter, key string) {
Key: key,
DurationPerToken: durationPerToken,
Burst: burst,
Now: now.Add(delta),
Tokens: 1,
MaxFutureReserve: 0,
}
r, err := limiter.Reserve(ctx, reserveReq)
advancedNow := now.Add(delta)
r, err := limiter.Reserve(
// only for test, you should not use this in production !!
ratelimiter.WithNowFuncForTest(ctx, func() time.Time {
return advancedNow
}),
reserveReq,
)
if err != nil {
panic(err)
}
Expand All @@ -40,7 +46,7 @@ func runExample(limiter *ratelimiter.RateLimiter, key string) {
return true
}

fmt.Printf("%v: allowed: %t , you can retry after %v\n", delta, false, r.RetryAfterFrom(reserveReq.Now))
fmt.Printf("%v: allowed: %t , you can retry after %v\n", delta, false, r.RetryAfterFrom(advancedNow))
return false
}

Expand All @@ -64,13 +70,14 @@ func runExample(limiter *ratelimiter.RateLimiter, key string) {
}
}

func ExampleDriverRedis() {

func ExampleInitRedisDriver() {
d, err := ratelimiter.InitRedisDriver(context.Background(), redisCli)
if err != nil {
panic(err)
}
limiter := ratelimiter.New(d)
runExample(limiter, "ExampleDriverRedis")
runExample(limiter, "ExampleInitRedisDriver")
// Output:
// 0s: allowed: true
// 1m0s: allowed: true
Expand Down Expand Up @@ -130,17 +137,17 @@ goarch: arm64
pkg: github.com/theplant/ratelimiter
BenchmarkDriverRedis_Reserve
BenchmarkDriverRedis_Reserve/Key1_Duration10ms_Burst5
BenchmarkDriverRedis_Reserve/Key1_Duration10ms_Burst5-12 6559 181420 ns/op 542 B/op 13 allocs/op
BenchmarkDriverRedis_Reserve/Key1_Duration10ms_Burst5-12 4045 274446 ns/op 692 B/op 16 allocs/op
BenchmarkDriverRedis_Reserve/Key2_Duration20ms_Burst10
BenchmarkDriverRedis_Reserve/Key2_Duration20ms_Burst10-12 6609 183268 ns/op 539 B/op 13 allocs/op
BenchmarkDriverRedis_Reserve/Key2_Duration20ms_Burst10-12 4638 277619 ns/op 688 B/op 16 allocs/op
BenchmarkDriverRedis_Reserve/Key3_Duration50ms_Burst3
BenchmarkDriverRedis_Reserve/Key3_Duration50ms_Burst3-12 6212 174383 ns/op 536 B/op 13 allocs/op
BenchmarkDriverRedis_Reserve/Key3_Duration50ms_Burst3-12 4406 274867 ns/op 688 B/op 16 allocs/op
BenchmarkDriverGORM_Reserve
BenchmarkDriverGORM_Reserve/Key1_Duration10ms_Burst5
BenchmarkDriverGORM_Reserve/Key1_Duration10ms_Burst5-12 1484 694646 ns/op 12276 B/op 160 allocs/op
BenchmarkDriverGORM_Reserve/Key1_Duration10ms_Burst5-12 1108 909553 ns/op 10346 B/op 139 allocs/op
BenchmarkDriverGORM_Reserve/Key2_Duration20ms_Burst10
BenchmarkDriverGORM_Reserve/Key2_Duration20ms_Burst10-12 1724 689210 ns/op 12313 B/op 160 allocs/op
BenchmarkDriverGORM_Reserve/Key2_Duration20ms_Burst10-12 1152 1061490 ns/op 12973 B/op 166 allocs/op
BenchmarkDriverGORM_Reserve/Key3_Duration50ms_Burst3
BenchmarkDriverGORM_Reserve/Key3_Duration50ms_Burst3-12 1712 695863 ns/op 12297 B/op 160 allocs/op
BenchmarkDriverGORM_Reserve/Key3_Duration50ms_Burst3-12 1156 1064777 ns/op 12999 B/op 165 allocs/op
```
18 changes: 9 additions & 9 deletions benchmark_test.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
package ratelimiter_test
package ratelimiter

import (
"context"
"testing"
"time"

"github.com/theplant/ratelimiter"
)

func runBenchmarks(b *testing.B, limiter *ratelimiter.RateLimiter) {
func runBenchmarks(b *testing.B, limiter *RateLimiter) {
ctx := context.Background()

tests := []struct {
Expand All @@ -28,14 +26,16 @@ func runBenchmarks(b *testing.B, limiter *ratelimiter.RateLimiter) {

b.ResetTimer()
for i := 0; i < b.N; i++ {
reserveReq := &ratelimiter.ReserveRequest{
reserveReq := &ReserveRequest{
Key: tt.key,
DurationPerToken: tt.durationPerToken,
Burst: tt.burst,
Now: now.Add(time.Duration(i) * time.Second),
Tokens: 1,
MaxFutureReserve: 0,
}
ctx := WithNowFuncForTest(ctx, func() time.Time {
return now.Add(time.Duration(i) * tt.durationPerToken)
})
_, err := limiter.Reserve(ctx, reserveReq)
if err != nil {
b.Fatalf("failed to reserve: %v", err)
Expand All @@ -46,15 +46,15 @@ func runBenchmarks(b *testing.B, limiter *ratelimiter.RateLimiter) {
}

func BenchmarkDriverRedis_Reserve(b *testing.B) {
driver, err := ratelimiter.InitRedisDriver(context.Background(), redisCli)
driver, err := InitRedisDriver(context.Background(), redisCli)
if err != nil {
b.Fatalf("failed to initialize Redis driver: %v", err)
}
limiter := ratelimiter.New(driver)
limiter := New(driver)
runBenchmarks(b, limiter)
}

func BenchmarkDriverGORM_Reserve(b *testing.B) {
limiter := ratelimiter.New(ratelimiter.DriverGORM(db))
limiter := New(NewGormDriver(db))
runBenchmarks(b, limiter)
}
208 changes: 154 additions & 54 deletions driver_gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,86 +2,186 @@ package ratelimiter

import (
"context"
"fmt"
"strconv"
"strings"
"time"

"github.com/go-sql-driver/mysql"
"github.com/jackc/pgx/v5/pgconn"
"github.com/pkg/errors"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)

type KV struct {
Key string `json:"key" gorm:"primaryKey;not null;"`
Value string `json:"value" gorm:"not null;"`
}

func DriverGORM(db *gorm.DB) Driver {
return DriverFunc(func(ctx context.Context, req *ReserveRequest) (*Reservation, error) {
now := req.Now.UTC() // stripMono
if req.Key == "" || now.IsZero() || req.DurationPerToken <= 0 || req.Burst <= 0 || req.Tokens <= 0 || req.Tokens > req.Burst {
return nil, errors.Wrapf(ErrInvalidParameters, "%v", req)
}
type GormDriver struct {
db *gorm.DB
rawQuery string
}

select {
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "ratelimiter: context done")
default:
}
// NewGormDriver returns a Driver that uses Gorm as the storage.
// Sometimes you may need to auto migrate the KV table, you can use `InitGormDriver` instead.
func NewGormDriver(db *gorm.DB) *GormDriver {
d := &GormDriver{
db: db,
}

resetValue := now.Add(-time.Duration(req.Burst) * req.DurationPerToken)
var currentTimestampQuery string
switch db.Dialector.Name() {
case "mysql":
currentTimestampQuery = "CURRENT_TIMESTAMP(6)"
case "postgres":
currentTimestampQuery = "clock_timestamp()"
default:
// Fallback to a generic solution or handle other databases if needed
currentTimestampQuery = "CURRENT_TIMESTAMP"
}

var timeBase time.Time
var timeToAct time.Time
var ok bool
d.rawQuery = fmt.Sprintf(`
WITH kv_select AS (
SELECT * FROM kvs WHERE key = ? FOR UPDATE
)
SELECT kv.*, %s AS now
FROM (SELECT 1) AS dummy
LEFT JOIN kv_select AS kv ON kv.key = ?;
`, currentTimestampQuery)
return d
}

err := db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var kv KV
// InitGormDriver initializes a GormDriver with the provided Gorm DB.
// Sometimes you may not need to auto migrate the KV table, you can use `NewGormDriver` instead.
func InitGormDriver(ctx context.Context, db *gorm.DB) (*GormDriver, error) {
if err := db.AutoMigrate(&KV{}); err != nil {
return nil, errors.Wrap(err, "ratelimiter: failed to migrate kv")
}

if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(&kv, "key = ?", req.Key).Error; err != nil {
if !errors.Is(err, gorm.ErrRecordNotFound) {
return errors.Wrap(err, "ratelimiter: failed to get kv")
}
return NewGormDriver(db), nil
}

timeBase = resetValue
kv = KV{Key: req.Key, Value: strconv.FormatInt(timeBase.UnixMicro(), 10)}
if err := tx.Create(&kv).Error; err != nil {
return errors.Wrap(err, "ratelimiter: failed to create kv")
}
} else {
unixMicroBase, err := strconv.ParseInt(kv.Value, 10, 64)
if err != nil {
return errors.Wrap(err, "ratelimiter: failed to parse base time")
}
timeBase = time.UnixMicro(unixMicroBase)

if timeBase.Before(resetValue) {
timeBase = resetValue
}
type kvWrapper struct {
KV
Now time.Time
}

type ctxKeyAfterQuery struct{}

func isDuplicateKeyError(err error) bool {
if err == nil {
return false
}

var pqErr *pgconn.PgError
if errors.As(err, &pqErr) {
return pqErr.Code == "23505"
}

var mysqlErr *mysql.MySQLError
if errors.As(err, &mysqlErr) {
return mysqlErr.Number == 1062
}

errMsg := err.Error()
return strings.Contains(errMsg, "SQLSTATE 23505")
}

func (d *GormDriver) Reserve(ctx context.Context, req *ReserveRequest) (*Reservation, error) {
return d.reserve(ctx, req, 0)
}

func (d *GormDriver) reserve(ctx context.Context, req *ReserveRequest, idx int) (*Reservation, error) {
if req.Key == "" || req.DurationPerToken <= 0 || req.Burst <= 0 || req.Tokens <= 0 || req.Tokens > req.Burst {
return nil, errors.Wrapf(ErrInvalidParameters, "%v", req)
}

select {
case <-ctx.Done():
return nil, errors.Wrap(ctx.Err(), "ratelimiter: context done")
default:
}

var now time.Time
if Test {
nowFunc, exists := NowFuncFromContextForTest(ctx)
if exists {
now = nowFunc().UTC() // stripMono
}
}

var timeBase time.Time
var timeToAct time.Time
var ok bool

err := d.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
var kv kvWrapper

if err := tx.Raw(d.rawQuery, req.Key, req.Key).Scan(&kv).Error; err != nil {
return errors.Wrap(err, "ratelimiter: failed to get kv")
}

if Test {
afterQuery, ok := ctx.Value(ctxKeyAfterQuery{}).(func(kv kvWrapper))
if ok {
afterQuery(kv)
}
}

tokensDuration := req.DurationPerToken * time.Duration(req.Tokens)
timeToAct = timeBase.Add(tokensDuration).UTC()
if now.IsZero() {
now = kv.Now // use db time
}
resetValue := now.Add(-time.Duration(req.Burst) * req.DurationPerToken)

if timeToAct.After(now.Add(req.MaxFutureReserve)) {
ok = false
return nil
if kv.Key == "" { // not found
timeBase = resetValue
if err := tx.Create(&KV{
Key: req.Key,
Value: strconv.FormatInt(timeBase.UnixMicro(), 10),
}).Error; err != nil {
return errors.Wrap(err, "ratelimiter: failed to create kv")
}
} else {
unixMicroBase, err := strconv.ParseInt(kv.Value, 10, 64)
if err != nil {
return errors.Wrap(err, "ratelimiter: failed to parse base time")
}
timeBase = time.UnixMicro(unixMicroBase)

kv.Value = strconv.FormatInt(timeToAct.UnixMicro(), 10)
if err := tx.Save(&kv).Error; err != nil {
return errors.Wrap(err, "ratelimiter: failed to save time to act")
if timeBase.Before(resetValue) {
timeBase = resetValue
}
ok = true
}

tokensDuration := req.DurationPerToken * time.Duration(req.Tokens)
timeToAct = timeBase.Add(tokensDuration).UTC()

if timeToAct.After(now.Add(req.MaxFutureReserve)) {
ok = false
return nil
})
if err != nil {
return nil, err
}

return &Reservation{
ReserveRequest: req,
OK: ok,
TimeToAct: timeToAct,
}, nil
if err := tx.Model(&KV{}).Where("key = ?", req.Key).Update(
"value", strconv.FormatInt(timeToAct.UnixMicro(), 10),
).Error; err != nil {
return errors.Wrap(err, "ratelimiter: failed to save time to act")
}
ok = true
return nil
})
if err != nil {
// retry once if duplicate key errorf
if idx == 0 && isDuplicateKeyError(err) {
return d.reserve(ctx, req, idx+1)
}
return nil, err
}

return &Reservation{
ReserveRequest: req,
OK: ok,
TimeToAct: timeToAct,
Now: now,
}, nil
}
Loading

0 comments on commit d3d06ef

Please sign in to comment.