diff --git a/internal/clients/aws.go b/internal/clients/aws.go
index 13fa90b9ab..3b7032319b 100644
--- a/internal/clients/aws.go
+++ b/internal/clients/aws.go
@@ -50,7 +50,10 @@ func SelectTerraformSetup(config *SetupConfig) terraform.SetupFn { // nolint:goc
 	} else if awsCfg == nil {
 		return terraform.Setup{}, errors.Wrap(err, "obtained aws config cannot be nil")
 	}
-	creds, err := awsCfg.Credentials.Retrieve(ctx)
+
+	// Only IRSA auth credentials are cached; other auth methods skip the
+	// cache and fall through to Credentials.Retrieve() of the given awsCfg.
+	creds, err := GlobalAWSCredentialsProviderCache.RetrieveCredentials(ctx, pc, awsCfg)
 	if err != nil {
 		return terraform.Setup{}, errors.Wrap(err, "failed to retrieve aws credentials from aws config")
 	}
diff --git a/internal/clients/creds_cache.go b/internal/clients/creds_cache.go
new file mode 100644
index 0000000000..7eacd702c4
--- /dev/null
+++ b/internal/clients/creds_cache.go
@@ -0,0 +1,200 @@
+// SPDX-FileCopyrightText: 2024 The Crossplane Authors
+//
+// SPDX-License-Identifier: Apache-2.0
+
+package clients
+
+import (
+	"context"
+	"crypto/sha256"
+	"fmt"
+	"io"
+	"os"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go-v2/aws"
+	"github.com/crossplane/crossplane-runtime/pkg/logging"
+	"github.com/pkg/errors"
+	"sigs.k8s.io/controller-runtime/pkg/log/zap"
+
+	"github.com/upbound/provider-aws/apis/v1beta1"
+)
+
+// GlobalAWSCredentialsProviderCache is a global AWS CredentialsProvider cache
+// to be used by all controllers.
+var GlobalAWSCredentialsProviderCache = NewAWSCredentialsProviderCache()
+
+// AWSCredentialsProviderCacheOption lets you configure an
+// *AWSCredentialsProviderCache.
+type AWSCredentialsProviderCacheOption func(cache *AWSCredentialsProviderCache)
+
+// WithCacheMaxSize lets you override the default maximum size of the AWS
+// CredentialsProvider cache.
+func WithCacheMaxSize(n int) AWSCredentialsProviderCacheOption {
+	return func(c *AWSCredentialsProviderCache) {
+		c.maxSize = n
+	}
+}
+
+// WithCacheStore lets you bootstrap the AWS CredentialsProvider cache with
+// your own store.
+func WithCacheStore(cache map[string]*awsCredentialsProviderCacheEntry) AWSCredentialsProviderCacheOption {
+	return func(c *AWSCredentialsProviderCache) {
+		c.cache = cache
+	}
+}
+
+// WithCacheLogger lets you configure the logger for the cache.
+func WithCacheLogger(l logging.Logger) AWSCredentialsProviderCacheOption {
+	return func(c *AWSCredentialsProviderCache) {
+		c.logger = l
+	}
+}
+
+// NewAWSCredentialsProviderCache returns a new empty
+// *AWSCredentialsProviderCache configured with the supplied options.
+func NewAWSCredentialsProviderCache(opts ...AWSCredentialsProviderCacheOption) *AWSCredentialsProviderCache {
+	c := &AWSCredentialsProviderCache{
+		cache:   map[string]*awsCredentialsProviderCacheEntry{},
+		maxSize: 100,
+		mu:      &sync.RWMutex{},
+		logger:  logging.NewLogrLogger(zap.New(zap.UseDevMode(false)).WithName("provider-aws-credentials-cache")),
+	}
+	for _, f := range opts {
+		f(c)
+	}
+	return c
+}
+
+// AWSCredentialsProviderCache holds aws.CredentialsProvider objects in memory
+// so that we don't need to make API calls to AWS in every reconciliation of
+// every resource. It has a maximum size; when that size is reached, the entry
+// with the oldest access time is evicted, i.e. LRU eviction on last access
+// time.
+//
+// Note that cached entries never become invalid, so there is no need for
+// cache invalidation; the mutex only has to keep access to the cache map
+// itself concurrency-safe.
+type AWSCredentialsProviderCache struct {
+	// cache holds the cached aws.CredentialsProvider with a unique cache key
+	// per provider configuration. The key includes the ProviderConfig's UID
+	// and ResourceVersion plus additional fields depending on the auth
+	// method.
+	cache map[string]*awsCredentialsProviderCacheEntry
+
+	// maxSize is the maximum number of elements this cache can ever have.
+	maxSize int
+
+	// mu is used to make sure the cache map is concurrency-safe.
+	mu *sync.RWMutex
+
+	// logger is the logger for cache operations.
+	logger logging.Logger
+}
+
+// awsCredentialsProviderCacheEntry wraps a cached aws.CredentialsProvider
+// together with the time it was last accessed.
+type awsCredentialsProviderCacheEntry struct {
+	credProvider aws.CredentialsProvider
+	AccessedAt   time.Time
+}
+
+// RetrieveCredentials returns the AWS credentials for the given
+// ProviderConfig, serving them from the cache where possible.
+func (c *AWSCredentialsProviderCache) RetrieveCredentials(ctx context.Context, pc *v1beta1.ProviderConfig, awsCfg *aws.Config) (aws.Credentials, error) {
+	// Cache key calculation tries to capture any parameter that could cause
+	// changes in the resulting AWS credentials, to ensure unique keys.
+	//
+	// Parameters that are directly available in the provider config generate
+	// unique cache keys through the UID and ResourceVersion of the
+	// ProviderConfig's k8s object, as these change whenever the provider
+	// config is modified.
+	//
+	// Any other external parameter that affects the resulting credentials but
+	// does not appear in the ProviderConfig directly (i.e. the same provider
+	// config content produces different credentials) must also be included in
+	// the cache key.
+	cacheKeyParams := []string{
+		string(pc.UID),
+		pc.ResourceVersion,
+		awsCfg.Region,
+		string(pc.Spec.Credentials.Source),
+	}
+
+	// Only IRSA authentication method credentials are cached currently.
+	switch s := pc.Spec.Credentials.Source; s { //nolint:exhaustive
+	case authKeyIRSA:
+		tokenHash, err := hashTokenFile(os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE"))
+		if err != nil {
+			return aws.Credentials{}, errors.Wrap(err, "cannot calculate cache key for credentials")
+		}
+		cacheKeyParams = append(cacheKeyParams, authKeyIRSA, tokenHash, os.Getenv("AWS_WEB_IDENTITY_TOKEN_FILE"), os.Getenv("AWS_ROLE_ARN"))
+	default:
+		// Skip the cache for other/unimplemented credential source types.
+		c.logger.Debug("skipping cache", "pc", pc.GroupVersionKind().String(), "authSource", s)
+		return awsCfg.Credentials.Retrieve(ctx)
+	}
+
+	cacheKey := strings.Join(cacheKeyParams, ":")
+	c.logger.Debug("checking cache entry", "cacheKey", cacheKey, "pc", pc.GroupVersionKind().String())
+	c.mu.RLock()
+	cacheEntry, ok := c.cache[cacheKey]
+	// Read the last access time while still holding the read lock, to avoid
+	// racing with the write below.
+	var stale bool
+	if ok {
+		stale = time.Since(cacheEntry.AccessedAt) > 10*time.Minute
+	}
+	c.mu.RUnlock()
+
+	// TODO: consider implementing a TTL even though the cached entry is valid.
+	// cache hit
+	if ok {
+		c.logger.Debug("cache hit", "cacheKey", cacheKey, "pc", pc.GroupVersionKind().String())
+		// Since this is a hot path in the execution, do not update the last
+		// access time on every hit; it is fine to evict the LRU entry at a
+		// coarser precision.
+		if stale {
+			c.mu.Lock()
+			cacheEntry.AccessedAt = time.Now()
+			c.mu.Unlock()
+		}
+		return cacheEntry.credProvider.Retrieve(ctx)
+	}
+
+	// cache miss
+	c.logger.Debug("cache miss", "cacheKey", cacheKey, "pc", pc.GroupVersionKind().String())
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.makeRoom()
+	c.cache[cacheKey] = &awsCredentialsProviderCacheEntry{
+		credProvider: awsCfg.Credentials,
+		AccessedAt:   time.Now(),
+	}
+	return awsCfg.Credentials.Retrieve(ctx)
+}
+
+// makeRoom ensures that there are at most maxSize-1 elements in the cache map
+// so that a new entry can be added. It evicts the entry whose last access
+// time is oldest.
+func (c *AWSCredentialsProviderCache) makeRoom() {
+	if 1+len(c.cache) <= c.maxSize {
+		return
+	}
+	var oldest string
+	for key, val := range c.cache {
+		if oldest == "" || val.AccessedAt.Before(c.cache[oldest].AccessedAt) {
+			oldest = key
+		}
+	}
+	delete(c.cache, oldest)
+}
+
+// hashTokenFile calculates the sha256 checksum of the token file content at
+// the supplied file path.
+func hashTokenFile(filename string) (string, error) {
+	if filename == "" {
+		return "", errors.New("token file name cannot be empty")
+	}
+	file, err := os.Open(filename)
+	if err != nil {
+		return "", err
+	}
+	defer file.Close()
+
+	hash := sha256.New()
+	if _, err = io.Copy(hash, file); err != nil {
+		return "", err
+	}
+
+	return fmt.Sprintf("%x", hash.Sum(nil)), nil
+}
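To make the caching behavior above concrete, here is a minimal unit-test sketch, not part of the diff, that exercises the IRSA path end to end: it fakes the web identity token file and role ARN through the environment, calls RetrieveCredentials twice with the same ProviderConfig, and checks that exactly one cache entry was created. It assumes authKeyIRSA is an untyped string constant defined elsewhere in the clients package and that v1beta1.ProviderConfigSpec/ProviderCredentials have the field names shown; the static credentials, ARN, and file contents are placeholder fixtures.

// creds_cache_test.go (hypothetical sketch, same package as the cache)
package clients

import (
	"context"
	"os"
	"path/filepath"
	"testing"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/credentials"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"

	"github.com/upbound/provider-aws/apis/v1beta1"
)

func TestRetrieveCredentialsCachesIRSA(t *testing.T) {
	// Fake web identity token file so that hashTokenFile has content to hash.
	tokenPath := filepath.Join(t.TempDir(), "token")
	if err := os.WriteFile(tokenPath, []byte("dummy-token"), 0o600); err != nil {
		t.Fatal(err)
	}
	t.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", tokenPath)
	t.Setenv("AWS_ROLE_ARN", "arn:aws:iam::123456789012:role/example") // placeholder ARN

	pc := &v1beta1.ProviderConfig{
		ObjectMeta: metav1.ObjectMeta{UID: types.UID("uid-1"), ResourceVersion: "1"},
		Spec: v1beta1.ProviderConfigSpec{
			// Assumes authKeyIRSA is assignable to the Source field.
			Credentials: v1beta1.ProviderCredentials{Source: authKeyIRSA},
		},
	}
	awsCfg := &aws.Config{
		Region: "us-east-1",
		// A static provider stands in for the real IRSA credentials chain.
		Credentials: credentials.NewStaticCredentialsProvider("AKIAEXAMPLE", "secret", ""),
	}

	cache := NewAWSCredentialsProviderCache(WithCacheMaxSize(10))
	for i := 0; i < 2; i++ {
		// The first call populates the cache; the second is served from it.
		if _, err := cache.RetrieveCredentials(context.Background(), pc, awsCfg); err != nil {
			t.Fatal(err)
		}
	}
	if got := len(cache.cache); got != 1 {
		t.Fatalf("expected exactly one cache entry, got %d", got)
	}
}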
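The LRU eviction in makeRoom can be checked in isolation through the WithCacheStore and WithCacheMaxSize options, which exist precisely so the store can be seeded for tests. A sketch under the same assumptions, appended to the hypothetical test file above (the imports only grow by "time"):

func TestMakeRoomEvictsLRU(t *testing.T) {
	now := time.Now()
	cache := NewAWSCredentialsProviderCache(
		WithCacheMaxSize(2),
		WithCacheStore(map[string]*awsCredentialsProviderCacheEntry{
			"old": {AccessedAt: now.Add(-time.Hour)},
			"new": {AccessedAt: now},
		}),
	)

	// The cache is full (2 entries, maxSize 2), so makeRoom must evict the
	// entry with the oldest access time to admit a new one.
	cache.makeRoom()

	if _, ok := cache.cache["old"]; ok {
		t.Fatal("expected the least-recently-accessed entry to be evicted")
	}
	if _, ok := cache.cache["new"]; !ok {
		t.Fatal("expected the most-recently-accessed entry to be retained")
	}
}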