cache AWS Config's CredentialsProvider to reduce STS calls
Signed-off-by: Erhan Cagirici <erhan@upbound.io>
erhancagirici committed Mar 22, 2024
1 parent cacdcfc commit 607c80f
Showing 2 changed files with 204 additions and 1 deletion.
5 changes: 4 additions & 1 deletion internal/clients/aws.go
@@ -50,7 +50,10 @@ func SelectTerraformSetup(config *SetupConfig) terraform.SetupFn { // nolint:goc
} else if awsCfg == nil {
return terraform.Setup{}, errors.New("obtained aws config cannot be nil")
}
creds, err := awsCfg.Credentials.Retrieve(ctx)

// only IRSA auth credentials are cached; other auth methods skip the
// cache and fall through to Credentials.Retrieve() on the given awsCfg
creds, err := GlobalAWSCredentialsProviderCache.RetrieveCredentials(ctx, pc, awsCfg)
if err != nil {
return terraform.Setup{}, errors.Wrap(err, "failed to retrieve aws credentials from aws config")
}
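For context on why reusing the provider helps: in aws-sdk-go-v2, credential caching via aws.NewCredentialsCache is per provider instance, so an aws.Config freshly loaded on each reconcile starts with a cold cache and triggers a new sts:AssumeRoleWithWebIdentity call. A minimal sketch of that SDK behavior, assuming IRSA-style web identity auth (the role ARN is hypothetical):

package main

import (
	"context"
	"fmt"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/config"
	"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
	"github.com/aws/aws-sdk-go-v2/service/sts"
)

func main() {
	ctx := context.Background()
	cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion("us-east-1"))
	if err != nil {
		panic(err)
	}

	// Retrieve on this provider performs sts:AssumeRoleWithWebIdentity.
	provider := stscreds.NewWebIdentityRoleProvider(
		sts.NewFromConfig(cfg),
		"arn:aws:iam::123456789012:role/example-irsa-role", // hypothetical
		stscreds.IdentityTokenFile("/var/run/secrets/eks.amazonaws.com/serviceaccount/token"),
	)

	// NewCredentialsCache memoizes per instance: a second Retrieve on the
	// same instance is served from memory, but a new instance built on the
	// next reconcile starts cold and calls STS again.
	cached := aws.NewCredentialsCache(provider)

	creds, err := cached.Retrieve(ctx) // first call goes to STS
	if err != nil {
		panic(err)
	}
	fmt.Println("expires:", creds.Expires)

	if _, err := cached.Retrieve(ctx); err != nil { // served from memory
		panic(err)
	}
}

Caching the CredentialsProvider keyed by ProviderConfig, as this commit does, keeps the same memoizing instance alive across reconciles.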
200 changes: 200 additions & 0 deletions internal/clients/creds_cache.go
@@ -0,0 +1,200 @@
// SPDX-FileCopyrightText: 2024 The Crossplane Authors <https://crossplane.io>
//
// SPDX-License-Identifier: Apache-2.0

package clients

import (
"context"
"crypto/sha256"
"fmt"
"io"
"os"
"strings"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/crossplane/crossplane-runtime/pkg/logging"
"github.com/pkg/errors"
"sigs.k8s.io/controller-runtime/pkg/log/zap"

"github.com/upbound/provider-aws/apis/v1beta1"
)

// GlobalAWSCredentialsProviderCache is a global AWS CredentialsProvider cache to be used by all controllers.
var GlobalAWSCredentialsProviderCache = NewAWSCredentialsProviderCache()

// AWSCredentialsProviderCacheOption lets you configure an *AWSCredentialsProviderCache.
type AWSCredentialsProviderCacheOption func(cache *AWSCredentialsProviderCache)

// WithCacheMaxSize lets you override the default MaxSize for AWS CredentialsProvider cache.
func WithCacheMaxSize(n int) AWSCredentialsProviderCacheOption {
return func(c *AWSCredentialsProviderCache) {
c.maxSize = n
}
}

// WithCacheStore lets you bootstrap the AWS CredentialsProvider cache with your own cache store.
func WithCacheStore(cache map[string]*awsCredentialsProviderCacheEntry) AWSCredentialsProviderCacheOption {
return func(c *AWSCredentialsProviderCache) {
c.cache = cache
}
}

// WithCacheLogger lets you configure the logger for the cache.
func WithCacheLogger(l logging.Logger) AWSCredentialsProviderCacheOption {
return func(c *AWSCredentialsProviderCache) {
c.logger = l
}
}

// NewAWSCredentialsProviderCache returns a new empty *AWSCredentialsProviderCache with default options.
func NewAWSCredentialsProviderCache(opts ...AWSCredentialsProviderCacheOption) *AWSCredentialsProviderCache {
logr := logging.NewLogrLogger(zap.New(zap.UseDevMode(false)).WithName("provider-aws-credentials-cache"))
c := &AWSCredentialsProviderCache{
cache: map[string]*awsCredentialsProviderCacheEntry{},
maxSize: 100,
mu: &sync.RWMutex{},
logger: logr,
}
for _, f := range opts {
f(c)
}
return c
}

// AWSCredentialsProviderCache holds aws.CredentialsProvider objects in memory so that
// we don't need to make API calls to AWS in every reconciliation of every
// resource. It has a maximum size; when that is reached, the entry with the
// oldest access time is removed from the cache, i.e. LRU eviction on last
// access time.
// Note that there is no need to invalidate cached values because they never
// change, so we don't need concurrency-safety to prevent access to an
// invalidated entry.
type AWSCredentialsProviderCache struct {
// cache holds the aws.CredentialsProvider with a unique cache key per
// provider configuration. Key content includes the ProviderConfig's UUID
// and ResourceVersion plus additional fields depending on the auth method.
cache map[string]*awsCredentialsProviderCacheEntry

// maxSize is the maximum number of elements this cache can ever have.
maxSize int

// mu is used to make sure the cache map is concurrency-safe.
mu *sync.RWMutex

// logger is the logger for cache operations
logger logging.Logger
}

type awsCredentialsProviderCacheEntry struct {
// credProvider is the cached aws.CredentialsProvider from the resolved AWS config.
credProvider aws.CredentialsProvider
// AccessedAt records the last (coarse-grained) access time, used for eviction.
AccessedAt time.Time
}

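// RetrieveCredentials returns AWS credentials for the given ProviderConfig,
// serving them via the cached aws.CredentialsProvider when the auth method
// supports caching (currently only IRSA).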
func (c *AWSCredentialsProviderCache) RetrieveCredentials(ctx context.Context, pc *v1beta1.ProviderConfig, awsCfg *aws.Config) (aws.Credentials, error) {
// cache key calculation tries to capture any parameter that could cause changes
// in the resulting AWS credentials, to ensure unique keys.
//
// Parameters that are directly available in the provider config generate
// unique cache keys through the UUID and ResourceVersion of the ProviderConfig's
// k8s object, as they change whenever the provider config is modified.
//
// Any other external parameter that affects the resulting credentials but
// does not appear in the ProviderConfig directly (i.e. the same provider config
// content produces different credentials) should be included in the cache key.
cacheKeyParams := []string{
string(pc.UID),
pc.ResourceVersion,
awsCfg.Region,
string(pc.Spec.Credentials.Source),
}

// Only IRSA authentication method credentials are cached currently
switch s := pc.Spec.Credentials.Source; s { //nolint:exhaustive
case authKeyIRSA:
tokenHash, err := hashTokenFile(os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE"))
if err != nil {
return aws.Credentials{}, errors.Wrap(err, "cannot calculate cache key for credentials")
}
cacheKeyParams = append(cacheKeyParams, authKeyIRSA, tokenHash, os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE"), os.Getenv("AWS_ROLE_ARN"))
default:
c.logger.Debug("skipping cache", "pc", pc.GroupVersionKind().String(), "authSource", s)
// skip cache for other/unimplemented credential types
return awsCfg.Credentials.Retrieve(ctx)
}

cacheKey := strings.Join(cacheKeyParams, ":")
c.logger.Debug("checking cache entry", "cacheKey", cacheKey, "pc", pc.GroupVersionKind().String())
c.mu.RLock()
cacheEntry, ok := c.cache[cacheKey]
c.mu.RUnlock()

// TODO: consider implementing a TTL even though the cached entry is valid
// cache hit
if ok {
c.logger.Debug("cache hit", "cacheKey", cacheKey, "pc", pc.GroupVersionKind().String())
// since this is a hot path in the execution, do not update the last
// access time on every hit; evicting the LRU entry at a coarser
// precision is fine.
if time.Since(cacheEntry.AccessedAt) > 10*time.Minute {
c.mu.Lock()
cacheEntry.AccessedAt = time.Now()
c.mu.Unlock()
}
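// the cached provider is typically the SDK's memoizing
// aws.CredentialsCache, so this Retrieve is served from memory and only
// goes to STS when the credentials are close to expiring.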
return cacheEntry.credProvider.Retrieve(ctx)
}

// cache miss
c.logger.Debug("cache miss", "cacheKey", cacheKey, "pc", pc.GroupVersionKind().String())
c.mu.Lock()
defer c.mu.Unlock()
c.makeRoom()
c.cache[cacheKey] = &awsCredentialsProviderCacheEntry{
credProvider: awsCfg.Credentials,
AccessedAt: time.Now(),
}
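// this first Retrieve primes the stored provider's internal cache; later
// hits on this entry reuse it instead of a freshly built provider.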
return awsCfg.Credentials.Retrieve(ctx)
}

// makeRoom ensures that there are at most maxSize-1 elements in the cache map
// so that a new entry can be added. It deletes the entry that was last accessed
// before all others.
func (c *AWSCredentialsProviderCache) makeRoom() {
if 1+len(c.cache) <= c.maxSize {
return
}
var dustiest string
for key, val := range c.cache {
if dustiest == "" {
dustiest = key
}
if val.AccessedAt.Before(c.cache[dustiest].AccessedAt) {
dustiest = key
}
}
delete(c.cache, dustiest)
}

// hashTokenFile calculates the sha256 checksum of the token file content at
// the supplied file path
func hashTokenFile(filename string) (string, error) {
if filename == "" {
return "", errors.New("token file name cannot be empty")
}
file, err := os.Open(filename) // lint (gosec G304): potential file inclusion via variable
if err != nil {
return "", err
}
defer file.Close() // lint (errcheck): error return value of file.Close is not checked

hash := sha256.New()
if _, err = io.Copy(hash, file); err != nil {
return "", err
}

checksum := hash.Sum(nil)
return fmt.Sprintf("%x", checksum), nil
}
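
For illustration, a minimal in-package test sketch (a hypothetical creds_cache_test.go, not part of this commit) that exercises the option functions and the eviction path; logging.NewNopLogger comes from crossplane-runtime:

package clients

import (
	"testing"
	"time"

	"github.com/crossplane/crossplane-runtime/pkg/logging"
)

func TestMakeRoomEvictsOldest(t *testing.T) {
	c := NewAWSCredentialsProviderCache(
		WithCacheMaxSize(2),
		WithCacheLogger(logging.NewNopLogger()),
	)

	// Fill the cache to its maximum size with entries of different ages.
	c.cache["old"] = &awsCredentialsProviderCacheEntry{AccessedAt: time.Now().Add(-time.Hour)}
	c.cache["new"] = &awsCredentialsProviderCacheEntry{AccessedAt: time.Now()}

	// makeRoom must evict exactly one entry: the least recently accessed.
	c.makeRoom()
	if _, ok := c.cache["old"]; ok {
		t.Error("expected the least recently accessed entry to be evicted")
	}
	if _, ok := c.cache["new"]; !ok {
		t.Error("expected the most recently accessed entry to survive")
	}
}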
