diff --git a/.changelog/e8196bb044a34f06804817bb03c87bff.json b/.changelog/e8196bb044a34f06804817bb03c87bff.json new file mode 100644 index 00000000000..6c67cbfdc39 --- /dev/null +++ b/.changelog/e8196bb044a34f06804817bb03c87bff.json @@ -0,0 +1,8 @@ +{ + "id": "e8196bb0-44a3-4f06-8048-17bb03c87bff", + "type": "feature", + "description": "Add load option for CredentialCache. Adds a new member to the LoadOptions struct, CredentialsCacheOptions. This member allows specifying a function that will be used to configure the CredentialsCache. The CredentialsCacheOptions will only be used if the configuration loader will wrap the underlying credential provider in the CredentialsCache.", + "modules": [ + "config" + ] +} \ No newline at end of file diff --git a/config/codegen/main.go b/config/codegen/main.go index 6f52eb29d62..80257ff0f39 100644 --- a/config/codegen/main.go +++ b/config/codegen/main.go @@ -26,6 +26,8 @@ var implAsserts = map[string][]string{ "regionProvider": {envConfigType, sharedConfigType, loadOptionsType, ec2IMDSRegionType}, "credentialsProviderProvider": {loadOptionsType}, "defaultRegionProvider": {loadOptionsType}, + "credentialsCacheOptionsProvider": {loadOptionsType}, + "processCredentialOptions": {loadOptionsType}, "ec2RoleCredentialOptionsProvider": {loadOptionsType}, "endpointCredentialOptionsProvider": {loadOptionsType}, "assumeRoleCredentialOptionsProvider": {loadOptionsType}, diff --git a/config/example_test.go b/config/example_test.go index 267bd6944f0..c78be204cd5 100644 --- a/config/example_test.go +++ b/config/example_test.go @@ -6,6 +6,7 @@ import ( "log" "net/http" "path/filepath" + "time" "github.com/aws/aws-sdk-go-v2/aws" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" @@ -16,6 +17,18 @@ import ( smithyhttp "github.com/aws/smithy-go/transport/http" ) +func ExampleWithCredentialsCacheOptions() { + cfg, err := config.LoadDefaultConfig(context.TODO(), + config.WithCredentialsCacheOptions(func(o *aws.CredentialsCacheOptions) { + o.ExpiryWindow = 10 * time.Minute + }), + ) + if err != nil { + log.Fatal(err) + } + _ = cfg +} + func ExampleWithSharedConfigProfile() { cfg, err := config.LoadDefaultConfig(context.TODO(), // Specify the shared configuration profile to load. diff --git a/config/load_options.go b/config/load_options.go index a02c6b08b87..a2756aeb376 100644 --- a/config/load_options.go +++ b/config/load_options.go @@ -101,6 +101,10 @@ type LoadOptions struct { // from the EC2 Metadata service UseEC2IMDSRegion *UseEC2IMDSRegion + // CredentialsCacheOptions is a function for setting the + // aws.CredentialsCacheOptions + CredentialsCacheOptions func(*aws.CredentialsCacheOptions) + // ProcessCredentialOptions is a function for setting // the processcreds.Options ProcessCredentialOptions func(*processcreds.Options) @@ -365,6 +369,29 @@ func WithCredentialsProvider(v aws.CredentialsProvider) LoadOptionsFunc { } } +// getCredentialsCacheOptionsProvider returns the wrapped function to set aws.CredentialsCacheOptions +func (o LoadOptions) getCredentialsCacheOptions(ctx context.Context) (func(*aws.CredentialsCacheOptions), bool, error) { + if o.CredentialsCacheOptions == nil { + return nil, false, nil + } + + return o.CredentialsCacheOptions, true, nil +} + +// WithCredentialsCacheOptions is a helper function to construct functional +// options that sets a function to modify the aws.CredentialsCacheOptions the +// aws.CredentialsCache will be configured with, if the CredentialsCache is used +// by the configuration loader. +// +// If multiple WithCredentialsCacheOptions calls are made, the last call +// overrides the previous call values. +func WithCredentialsCacheOptions(v func(*aws.CredentialsCacheOptions)) LoadOptionsFunc { + return func(o *LoadOptions) error { + o.CredentialsCacheOptions = v + return nil + } +} + // getProcessCredentialOptions returns the wrapped function to set processcreds.Options func (o LoadOptions) getProcessCredentialOptions(ctx context.Context) (func(*processcreds.Options), bool, error) { if o.ProcessCredentialOptions == nil { diff --git a/config/provider.go b/config/provider.go index 557db2c264a..ac4c1403668 100644 --- a/config/provider.go +++ b/config/provider.go @@ -162,6 +162,28 @@ func getCredentialsProvider(ctx context.Context, configs configs) (p aws.Credent return } +// credentialsCacheOptionsProvider is an interface for retrieving a function for setting +// the aws.CredentialsCacheOptions. +type credentialsCacheOptionsProvider interface { + getCredentialsCacheOptions(ctx context.Context) (func(*aws.CredentialsCacheOptions), bool, error) +} + +// getCredentialsCacheOptionsProvider is an interface for retrieving a function for setting +// the aws.CredentialsCacheOptions. +func getCredentialsCacheOptionsProvider(ctx context.Context, configs configs) ( + f func(*aws.CredentialsCacheOptions), found bool, err error, +) { + for _, config := range configs { + if p, ok := config.(credentialsCacheOptionsProvider); ok { + f, found, err = p.getCredentialsCacheOptions(ctx) + if err != nil || found { + break + } + } + } + return +} + // processCredentialOptions is an interface for retrieving a function for setting // the processcreds.Options. type processCredentialOptions interface { diff --git a/config/provider_assert_test.go b/config/provider_assert_test.go index 337614e0871..b0a8fa1cbb4 100644 --- a/config/provider_assert_test.go +++ b/config/provider_assert_test.go @@ -17,6 +17,11 @@ var ( _ clientLogModeProvider = &LoadOptions{} ) +// credentialsCacheOptionsProvider implementor assertions +var ( + _ credentialsCacheOptionsProvider = &LoadOptions{} +) + // credentialsProviderProvider implementor assertions var ( _ credentialsProviderProvider = &LoadOptions{} @@ -68,6 +73,11 @@ var ( _ loggerProvider = &LoadOptions{} ) +// processCredentialOptions implementor assertions +var ( + _ processCredentialOptions = &LoadOptions{} +) + // regionProvider implementor assertions var ( _ regionProvider = &EnvConfig{} diff --git a/config/resolve_credentials.go b/config/resolve_credentials.go index 7dfe7404745..6bac0bb4dd8 100644 --- a/config/resolve_credentials.go +++ b/config/resolve_credentials.go @@ -59,8 +59,8 @@ func resolveCredentials(ctx context.Context, cfg *aws.Config, configs configs) e // // Config providers used: // * credentialsProviderProvider -func resolveCredentialProvider(ctx context.Context, cfg *aws.Config, cfgs configs) (bool, error) { - credProvider, found, err := getCredentialsProvider(ctx, cfgs) +func resolveCredentialProvider(ctx context.Context, cfg *aws.Config, configs configs) (bool, error) { + credProvider, found, err := getCredentialsProvider(ctx, configs) if err != nil { return false, err } @@ -68,7 +68,10 @@ func resolveCredentialProvider(ctx context.Context, cfg *aws.Config, cfgs config return false, nil } - cfg.Credentials = wrapWithCredentialsCache(credProvider) + cfg.Credentials, err = wrapWithCredentialsCache(ctx, configs, credProvider) + if err != nil { + return false, err + } return true, nil } @@ -105,7 +108,10 @@ func resolveCredentialChain(ctx context.Context, cfg *aws.Config, configs config } // Wrap the resolved provider in a cache so the SDK will cache credentials. - cfg.Credentials = wrapWithCredentialsCache(cfg.Credentials) + cfg.Credentials, err = wrapWithCredentialsCache(ctx, configs, cfg.Credentials) + if err != nil { + return err + } return nil } @@ -248,9 +254,12 @@ func resolveHTTPCredProvider(ctx context.Context, cfg *aws.Config, url, authToke provider := endpointcreds.New(url, optFns...) - cfg.Credentials = wrapWithCredentialsCache(provider, func(options *aws.CredentialsCacheOptions) { + cfg.Credentials, err = wrapWithCredentialsCache(ctx, configs, provider, func(options *aws.CredentialsCacheOptions) { options.ExpiryWindow = 5 * time.Minute }) + if err != nil { + return err + } return nil } @@ -296,9 +305,12 @@ func resolveEC2RoleCredentials(ctx context.Context, cfg *aws.Config, configs con provider := ec2rolecreds.New(optFns...) - cfg.Credentials = wrapWithCredentialsCache(provider, func(options *aws.CredentialsCacheOptions) { + cfg.Credentials, err = wrapWithCredentialsCache(ctx, configs, provider, func(options *aws.CredentialsCacheOptions) { options.ExpiryWindow = 5 * time.Minute }) + if err != nil { + return err + } return nil } @@ -430,12 +442,31 @@ func credsFromAssumeRole(ctx context.Context, cfg *aws.Config, sharedCfg *Shared return nil } -// wrapWithCredentialsCache will wrap provider with an aws.CredentialsCache with the provided options if the provider is not already a aws.CredentialsCache. -func wrapWithCredentialsCache(provider aws.CredentialsProvider, optFns ...func(options *aws.CredentialsCacheOptions)) aws.CredentialsProvider { +// wrapWithCredentialsCache will wrap provider with an aws.CredentialsCache +// with the provided options if the provider is not already a +// aws.CredentialsCache. +func wrapWithCredentialsCache( + ctx context.Context, + cfgs configs, + provider aws.CredentialsProvider, + optFns ...func(options *aws.CredentialsCacheOptions), +) (aws.CredentialsProvider, error) { _, ok := provider.(*aws.CredentialsCache) if ok { - return provider + return provider, nil + } + + credCacheOptions, found, err := getCredentialsCacheOptionsProvider(ctx, cfgs) + if err != nil { + return nil, err + } + + // force allocation of a new slice if the additional options are + // needed, to prevent overwriting the passed in slice of options. + optFns = optFns[:len(optFns):len(optFns)] + if found { + optFns = append(optFns, credCacheOptions) } - return aws.NewCredentialsCache(provider, optFns...) + return aws.NewCredentialsCache(provider, optFns...), nil } diff --git a/config/resolve_credentials_test.go b/config/resolve_credentials_test.go index d6663a0b77a..268d3a6ae61 100644 --- a/config/resolve_credentials_test.go +++ b/config/resolve_credentials_test.go @@ -452,3 +452,22 @@ func TestSharedConfigCredentialSource(t *testing.T) { }) } } + +func TestResolveCredentialsCacheOptions(t *testing.T) { + var cfg aws.Config + var optionsFnCalled bool + + err := resolveCredentials(context.Background(), &cfg, configs{LoadOptions{ + CredentialsCacheOptions: func(o *aws.CredentialsCacheOptions) { + optionsFnCalled = true + o.ExpiryWindow = time.Minute * 5 + }, + }}) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !optionsFnCalled { + t.Errorf("expect options to be called") + } +}