diff --git a/go.mod b/go.mod index eec68df2..0ed76f29 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/ThreeDotsLabs/watermill-sql/v2 v2.0.0 github.com/alecthomas/chroma v0.10.0 github.com/dustin/go-humanize v1.0.1 + github.com/failsafe-go/failsafe-go v0.6.1 github.com/getzep/sprig/v3 v3.0.0-20230930153539-1d7fce7d845e github.com/hashicorp/go-retryablehttp v0.7.4 github.com/invopop/jsonschema v0.12.0 @@ -132,7 +133,7 @@ require ( golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/net v0.21.0 // indirect - golang.org/x/sync v0.4.0 // indirect + golang.org/x/sync v0.6.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.14.0 // indirect diff --git a/go.sum b/go.sum index 913fdbd8..34d531df 100644 --- a/go.sum +++ b/go.sum @@ -93,6 +93,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/failsafe-go/failsafe-go v0.6.1 h1:BQhD3FnmEVJ54Dke6nJqp7tsMjXnhEh55Yp0vMLzRi8= +github.com/failsafe-go/failsafe-go v0.6.1/go.mod h1:3QEdMHQN8p1XMbrOSZHeacu6XaEByX5u+h5lg/UOWnY= github.com/fatih/color v1.15.0 h1:kOqh6YHBtK8aywxGerMG2Eq3H6Qgoqeo13Bk2Mv/nBs= github.com/fatih/color v1.15.0/go.mod h1:0h5ZqXfHYED7Bhv2ZJamyIOUej9KtShiJESRwBDUSsw= github.com/felixge/httpsnoop v1.0.2/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= @@ -571,8 +573,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.4.0 h1:zxkM55ReGkDlKSM+Fu41A+zmbZuaPVbGMzvvdUPznYQ= -golang.org/x/sync v0.4.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/models/errors.go b/pkg/models/errors.go index b6583201..123d3b13 100644 --- a/pkg/models/errors.go +++ b/pkg/models/errors.go @@ -44,3 +44,24 @@ func (e *BadRequestError) Unwrap() error { func NewBadRequestError(message string) error { return &BadRequestError{Message: message} } + +var ErrLockAcquisitionFailed = errors.New("failed to acquire advisory lock") + +type AdvisoryLockError struct { + Err error +} + +func (e AdvisoryLockError) Error() string { + if e.Err != nil { + return fmt.Sprintf("failed to acquire advisory lock: %v", e.Err) + } + return ErrLockAcquisitionFailed.Error() +} + +func (e AdvisoryLockError) Unwrap() error { + return ErrLockAcquisitionFailed +} + +func NewAdvisoryLockError(err error) error { + return &AdvisoryLockError{Err: err} +} diff --git a/pkg/store/postgres/memorystore.go b/pkg/store/postgres/memorystore.go index e0657a02..8ebf2ea5 100644 --- a/pkg/store/postgres/memorystore.go +++ b/pkg/store/postgres/memorystore.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/getzep/zep/pkg/store" "github.com/google/uuid" @@ -360,15 +361,33 @@ func (pms *PostgresMemoryStore) PurgeDeleted(ctx 
context.Context) error { return nil } -// acquireAdvisoryLock acquires a PostgreSQL advisory lock for the given key. -// The lock needs to be released manually by calling releaseAdvisoryLock. -// Accepts a bun.IDB, which can be either a *bun.DB or *bun.Tx. -// Returns the lock ID. -func acquireAdvisoryLock(ctx context.Context, db bun.IDB, key string) (uint64, error) { +func generateLockID(key string) uint64 { hasher := sha256.New() hasher.Write([]byte(key)) hash := hasher.Sum(nil) - lockID := binary.BigEndian.Uint64(hash[:8]) + return binary.BigEndian.Uint64(hash[:8]) +} + +// tryAcquireAdvisoryLock attempts to acquire a PostgreSQL advisory lock using pg_try_advisory_lock. +// This function will fail if it's unable to immediately acquire a lock. +// Accepts a bun.IDB, which can be either a *bun.DB or *bun.Tx. +// Returns the lock ID on success. If the lock was not acquired, returns 0 and an error built by models.NewAdvisoryLockError; if the query or scan fails, returns 0 and the wrapped error. +func tryAcquireAdvisoryLock(ctx context.Context, db bun.IDB, key string) (uint64, error) { + lockID := generateLockID(key) + + var acquired bool + if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock(?)", lockID).Scan(&acquired); err != nil { + return 0, fmt.Errorf("tryAcquireAdvisoryLock: %w", err) + } + if !acquired { + return 0, models.NewAdvisoryLockError(fmt.Errorf("failed to acquire advisory lock for %s", key)) + } + return lockID, nil +} + +// acquireAdvisoryLock acquires a PostgreSQL advisory lock for the given key. 
+func acquireAdvisoryLock(ctx context.Context, db bun.IDB, key string) (uint64, error) { + lockID := generateLockID(key) if _, err := db.ExecContext(ctx, "SELECT pg_advisory_lock(?)", lockID); err != nil { return 0, store.NewStorageError("failed to acquire advisory lock", err) diff --git a/pkg/store/postgres/userstore.go b/pkg/store/postgres/userstore.go index b907eaa1..dd7254a5 100644 --- a/pkg/store/postgres/userstore.go +++ b/pkg/store/postgres/userstore.go @@ -6,7 +6,10 @@ import ( "errors" "fmt" "sync" + "time" + "github.com/failsafe-go/failsafe-go" + "github.com/failsafe-go/failsafe-go/retrypolicy" "github.com/getzep/zep/pkg/models" "github.com/uptrace/bun" "github.com/uptrace/bun/driver/pgdriver" @@ -94,10 +97,24 @@ func (dao *UserStoreDAO) Update( // Acquire a lock for this UserID. This is to prevent concurrent updates // to the session metadata. - lockID, err := acquireAdvisoryLock(ctx, dao.db, user.UserID) + lockRetryPolicy := retrypolicy.Builder[any](). + HandleErrors(models.ErrLockAcquisitionFailed). + WithDelay(200 * time.Millisecond). + WithMaxRetries(3). + Build() + + lockIDVal, err := failsafe.Get(func() (any, error) { + return tryAcquireAdvisoryLock(ctx, dao.db, user.UserID) + }, lockRetryPolicy) if err != nil { return nil, fmt.Errorf("failed to acquire advisory lock: %w", err) } + + lockID, ok := lockIDVal.(uint64) + if !ok { + return nil, fmt.Errorf("failed to acquire advisory lock: %w", models.ErrLockAcquisitionFailed) + } + defer func(ctx context.Context, db bun.IDB, lockID uint64) { err := releaseAdvisoryLock(ctx, db, lockID) if err != nil {