diff --git a/pkg/accesscontrol/access_store.go b/pkg/accesscontrol/access_store.go index 6614b9fd..e219bc83 100644 --- a/pkg/accesscontrol/access_store.go +++ b/pkg/accesscontrol/access_store.go @@ -9,6 +9,7 @@ import ( "time" v1 "github.com/rancher/wrangler/v3/pkg/generated/controllers/rbac/v1" + "golang.org/x/sync/singleflight" rbacv1 "k8s.io/api/rbac/v1" "k8s.io/apimachinery/pkg/util/cache" "k8s.io/apiserver/pkg/authentication/user" @@ -39,10 +40,11 @@ type accessStoreCache interface { } type AccessStore struct { - usersPolicyRules policyRules - groupsPolicyRules policyRules - roles roleRevisions - cache accessStoreCache + usersPolicyRules policyRules + groupsPolicyRules policyRules + roles roleRevisions + cache accessStoreCache + concurrentAccessFor *singleflight.Group } type roleKey struct { @@ -52,9 +54,10 @@ type roleKey struct { func NewAccessStore(ctx context.Context, cacheResults bool, rbac v1.Interface) *AccessStore { as := &AccessStore{ - usersPolicyRules: newPolicyRuleIndex(true, rbac), - groupsPolicyRules: newPolicyRuleIndex(false, rbac), - roles: newRoleRevision(ctx, rbac), + usersPolicyRules: newPolicyRuleIndex(true, rbac), + groupsPolicyRules: newPolicyRuleIndex(false, rbac), + roles: newRoleRevision(ctx, rbac), + concurrentAccessFor: new(singleflight.Group), } if cacheResults { as.cache = cache.NewLRUExpireCache(50) @@ -69,16 +72,19 @@ func (l *AccessStore) AccessFor(user user.Info) *AccessSet { cacheKey := l.CacheKey(user) - if val, ok := l.cache.Get(cacheKey); ok { - as, _ := val.(*AccessSet) - return as - } + res, _, _ := l.concurrentAccessFor.Do(cacheKey, func() (interface{}, error) { + if val, ok := l.cache.Get(cacheKey); ok { + as, _ := val.(*AccessSet) + return as, nil + } - result := l.newAccessSet(user) - result.ID = cacheKey - l.cache.Add(cacheKey, result, 24*time.Hour) + result := l.newAccessSet(user) + result.ID = cacheKey + l.cache.Add(cacheKey, result, 24*time.Hour) - return result + return result, nil + }) + return res.(*AccessSet) } func (l *AccessStore) newAccessSet(user user.Info) *AccessSet { diff --git a/pkg/accesscontrol/access_store_test.go b/pkg/accesscontrol/access_store_test.go index 543723cd..5d81ec37 100644 --- a/pkg/accesscontrol/access_store_test.go +++ b/pkg/accesscontrol/access_store_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "golang.org/x/sync/singleflight" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -197,6 +198,7 @@ func TestAccessStore_AccessFor(t *testing.T) { } asCache := cache.NewLRUExpireCache(10) store := &AccessStore{ + concurrentAccessFor: new(singleflight.Group), usersPolicyRules: &policyRulesMock{ getRBFunc: func(s string) []*rbacv1.RoleBinding { return []*rbacv1.RoleBinding{ @@ -301,10 +303,10 @@ func (c *spyCache) observeAdd(k interface{}) { } func TestAccessStore_AccessFor_concurrent(t *testing.T) { - t.Skipf("TODO - Add a fix for this test") testUser := &user.DefaultInfo{Name: "test-user"} asCache := &spyCache{accessStoreCache: cache.NewLRUExpireCache(100)} store := &AccessStore{ + concurrentAccessFor: new(singleflight.Group), roles: roleRevisionsMock(func(ns, name string) string { return fmt.Sprintf("%s%srev", ns, name) }),