From 92b5621bf8338d68f64c79cd094969d289428723 Mon Sep 17 00:00:00 2001 From: rittneje Date: Mon, 3 Jan 2022 19:43:31 -0500 Subject: [PATCH] add options for customizing credential providers (#4174) Fixes #4160. It's a little different than what I originally proposed, since (1) this avoid an import cycle, and (2) this is a little more flexible. Adds a NewWebIdentityRoleProviderWithOptions constructor to be similar to the other credential providers defined by the SDK. --- CHANGELOG_PENDING.md | 2 + .../stscreds/web_identity_provider.go | 40 ++- .../stscreds/web_identity_provider_test.go | 244 +++++++++++------- aws/session/credentials.go | 33 ++- aws/session/credentials_test.go | 51 ++++ aws/session/session.go | 5 + 6 files changed, 261 insertions(+), 114 deletions(-) diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index 8a1927a39ca..d229672aebe 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -2,4 +2,6 @@ ### SDK Enhancements +* `aws/session`: Add options for customizing the construction of credential providers. Currently only supported for `stscreds.WebIdentityRoleProvider`. + ### SDK Bugs diff --git a/aws/credentials/stscreds/web_identity_provider.go b/aws/credentials/stscreds/web_identity_provider.go index cefe2a76d4d..19ad619aa3d 100644 --- a/aws/credentials/stscreds/web_identity_provider.go +++ b/aws/credentials/stscreds/web_identity_provider.go @@ -28,7 +28,7 @@ const ( // compare test values. var now = time.Now -// TokenFetcher shuold return WebIdentity token bytes or an error +// TokenFetcher should return WebIdentity token bytes or an error type TokenFetcher interface { FetchToken(credentials.Context) ([]byte, error) } @@ -50,6 +50,8 @@ func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) { // an OIDC token. type WebIdentityRoleProvider struct { credentials.Expiry + + // The policy ARNs to use with the web identity assumed role. PolicyArns []*sts.PolicyDescriptorType // Duration the STS credentials will be valid for. Truncated to seconds. @@ -74,6 +76,9 @@ type WebIdentityRoleProvider struct { // NewWebIdentityCredentials will return a new set of credentials with a given // configuration, role arn, and token file path. +// +// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible +// functional options, and wrap with credentials.NewCredentials helper. func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials { svc := sts.New(c) p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path) @@ -82,19 +87,42 @@ func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName // NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the // provided stsiface.STSAPI +// +// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible +// functional options. func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider { - return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path)) + return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, FetchTokenPath(path)) } // NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the // provided stsiface.STSAPI and a TokenFetcher +// +// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible +// functional options. func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider { - return &WebIdentityRoleProvider{ + return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, tokenFetcher) +} + +// NewWebIdentityRoleProviderWithOptions will return an initialize +// WebIdentityRoleProvider with the provided stsiface.STSAPI, role ARN, and a +// TokenFetcher. Additional options can be provided as functional options. +// +// TokenFetcher is the implementation that will retrieve the JWT token from to +// assume the role with. Use the provided FetchTokenPath implementation to +// retrieve the JWT token using a file system path. +func NewWebIdentityRoleProviderWithOptions(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher, optFns ...func(*WebIdentityRoleProvider)) *WebIdentityRoleProvider { + p := WebIdentityRoleProvider{ client: svc, tokenFetcher: tokenFetcher, roleARN: roleARN, roleSessionName: roleSessionName, } + + for _, fn := range optFns { + fn(&p) + } + + return &p } // Retrieve will attempt to assume a role from a token which is located at @@ -104,9 +132,9 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) { return p.RetrieveWithContext(aws.BackgroundContext()) } -// RetrieveWithContext will attempt to assume a role from a token which is located at -// 'WebIdentityTokenFilePath' specified destination and if that is empty an -// error will be returned. +// RetrieveWithContext will attempt to assume a role from a token which is +// located at 'WebIdentityTokenFilePath' specified destination and if that is +// empty an error will be returned. func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) { b, err := p.tokenFetcher.FetchToken(ctx) if err != nil { diff --git a/aws/credentials/stscreds/web_identity_provider_test.go b/aws/credentials/stscreds/web_identity_provider_test.go index 8f78e93ed0a..7ea53dc3133 100644 --- a/aws/credentials/stscreds/web_identity_provider_test.go +++ b/aws/credentials/stscreds/web_identity_provider_test.go @@ -1,9 +1,10 @@ //go:build go1.7 // +build go1.7 -package stscreds_test +package stscreds import ( + "fmt" "net/http" "reflect" "strings" @@ -11,46 +12,48 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/corehandlers" + "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/awstesting/unit" "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go/service/sts/stsiface" ) func TestWebIdentityProviderRetrieve(t *testing.T) { - var reqCount int cases := map[string]struct { - onSendReq func(*testing.T, *request.Request) roleARN string - tokenFilepath string + tokenPath string sessionName string + newClient func(t *testing.T) stsiface.STSAPI duration time.Duration expectedError string expectedCredValue credentials.Value }{ "session name case": { - roleARN: "arn01234567890123456789", - tokenFilepath: "testdata/token.jwt", - sessionName: "foo", - onSendReq: func(t *testing.T, r *request.Request) { - input := r.Params.(*sts.AssumeRoleWithWebIdentityInput) - if e, a := "foo", *input.RoleSessionName; e != a { - t.Errorf("expected %v, but received %v", e, a) - } - if input.DurationSeconds != nil { - t.Errorf("expect no duration, got %v", *input.DurationSeconds) - } + roleARN: "arn01234567890123456789", + tokenPath: "testdata/token.jwt", + sessionName: "foo", + newClient: func(t *testing.T) stsiface.STSAPI { + return mockAssumeRoleWithWebIdentityClient{ + t: t, + doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) ( + *sts.AssumeRoleWithWebIdentityOutput, error, + ) { + if e, a := "foo", *input.RoleSessionName; e != a { + t.Errorf("expected %v, but received %v", e, a) + } + if input.DurationSeconds != nil { + t.Errorf("expect no duration, got %v", *input.DurationSeconds) + } - data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput) - *data = sts.AssumeRoleWithWebIdentityOutput{ - Credentials: &sts.Credentials{ - Expiration: aws.Time(time.Now()), - AccessKeyId: aws.String("access-key-id"), - SecretAccessKey: aws.String("secret-access-key"), - SessionToken: aws.String("session-token"), + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &sts.Credentials{ + Expiration: aws.Time(time.Now()), + AccessKeyId: aws.String("access-key-id"), + SecretAccessKey: aws.String("secret-access-key"), + SessionToken: aws.String("session-token"), + }, + }, nil }, } }, @@ -58,61 +61,32 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { AccessKeyID: "access-key-id", SecretAccessKey: "secret-access-key", SessionToken: "session-token", - ProviderName: stscreds.WebIdentityProviderName, + ProviderName: WebIdentityProviderName, }, }, "with duration": { - roleARN: "arn01234567890123456789", - tokenFilepath: "testdata/token.jwt", - sessionName: "foo", - duration: 15 * time.Minute, - onSendReq: func(t *testing.T, r *request.Request) { - input := r.Params.(*sts.AssumeRoleWithWebIdentityInput) - if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a { - t.Errorf("expect %v duration, got %v", e, a) - } - - data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput) - *data = sts.AssumeRoleWithWebIdentityOutput{ - Credentials: &sts.Credentials{ - Expiration: aws.Time(time.Now()), - AccessKeyId: aws.String("access-key-id"), - SecretAccessKey: aws.String("secret-access-key"), - SessionToken: aws.String("session-token"), - }, - } - }, - expectedCredValue: credentials.Value{ - AccessKeyID: "access-key-id", - SecretAccessKey: "secret-access-key", - SessionToken: "session-token", - ProviderName: stscreds.WebIdentityProviderName, - }, - }, - "invalid token retry": { - roleARN: "arn01234567890123456789", - tokenFilepath: "testdata/token.jwt", - sessionName: "foo", - onSendReq: func(t *testing.T, r *request.Request) { - input := r.Params.(*sts.AssumeRoleWithWebIdentityInput) - if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) { - t.Errorf("expected %v, but received %v", e, a) - } - - if reqCount == 0 { - r.HTTPResponse.StatusCode = 400 - r.Error = awserr.New(sts.ErrCodeInvalidIdentityTokenException, - "some error message", nil) - return - } + roleARN: "arn01234567890123456789", + tokenPath: "testdata/token.jwt", + sessionName: "foo", + duration: 15 * time.Minute, + newClient: func(t *testing.T) stsiface.STSAPI { + return mockAssumeRoleWithWebIdentityClient{ + t: t, + doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) ( + *sts.AssumeRoleWithWebIdentityOutput, error, + ) { + if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a { + t.Errorf("expect %v duration, got %v", e, a) + } - data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput) - *data = sts.AssumeRoleWithWebIdentityOutput{ - Credentials: &sts.Credentials{ - Expiration: aws.Time(time.Now()), - AccessKeyId: aws.String("access-key-id"), - SecretAccessKey: aws.String("secret-access-key"), - SessionToken: aws.String("session-token"), + return &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &sts.Credentials{ + Expiration: aws.Time(time.Now()), + AccessKeyId: aws.String("access-key-id"), + SecretAccessKey: aws.String("secret-access-key"), + SessionToken: aws.String("session-token"), + }, + }, nil }, } }, @@ -120,33 +94,14 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { AccessKeyID: "access-key-id", SecretAccessKey: "secret-access-key", SessionToken: "session-token", - ProviderName: stscreds.WebIdentityProviderName, + ProviderName: WebIdentityProviderName, }, }, } for name, c := range cases { t.Run(name, func(t *testing.T) { - reqCount = 0 - - svc := sts.New(unit.Session, &aws.Config{ - Logger: t, - }) - svc.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{ - Name: "custom send stub handler", - Fn: func(r *request.Request) { - r.HTTPResponse = &http.Response{ - StatusCode: 200, Header: http.Header{}, - } - c.onSendReq(t, r) - reqCount++ - }, - }) - svc.Handlers.UnmarshalMeta.Clear() - svc.Handlers.Unmarshal.Clear() - svc.Handlers.UnmarshalError.Clear() - - p := stscreds.NewWebIdentityRoleProvider(svc, c.roleARN, c.sessionName, c.tokenFilepath) + p := NewWebIdentityRoleProvider(c.newClient(t), c.roleARN, c.sessionName, c.tokenPath) p.Duration = c.duration credValue, err := p.Retrieve() @@ -169,3 +124,96 @@ func TestWebIdentityProviderRetrieve(t *testing.T) { }) } } + +type mockAssumeRoleWithWebIdentityClient struct { + stsiface.STSAPI + + t *testing.T + doRequest func(*testing.T, *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error) +} + +func (c mockAssumeRoleWithWebIdentityClient) AssumeRoleWithWebIdentityRequest(input *sts.AssumeRoleWithWebIdentityInput) ( + *request.Request, *sts.AssumeRoleWithWebIdentityOutput, +) { + output, err := c.doRequest(c.t, input) + + req := &request.Request{ + HTTPRequest: &http.Request{}, + Retryer: client.DefaultRetryer{}, + } + req.Handlers.Send.PushBack(func(r *request.Request) { + r.HTTPResponse = &http.Response{} + r.Data = output + r.Error = err + + var found bool + for _, retryCode := range req.RetryErrorCodes { + if retryCode == sts.ErrCodeInvalidIdentityTokenException { + found = true + break + } + } + if !found { + c.t.Errorf("expect ErrCodeInvalidIdentityTokenException error code to be retry-able") + } + }) + + return req, output +} + +func TestNewWebIdentityRoleProviderWithOptions(t *testing.T) { + const roleARN = "a-role-arn" + const roleSessionName = "a-session-name" + + cases := map[string]struct { + options []func(*WebIdentityRoleProvider) + expect WebIdentityRoleProvider + }{ + "no options": { + expect: WebIdentityRoleProvider{ + client: stubClient{}, + tokenFetcher: stubTokenFetcher{}, + roleARN: roleARN, + roleSessionName: roleSessionName, + }, + }, + "with options": { + options: []func(*WebIdentityRoleProvider){ + func(o *WebIdentityRoleProvider) { + o.Duration = 10 * time.Minute + o.ExpiryWindow = time.Minute + }, + }, + expect: WebIdentityRoleProvider{ + client: stubClient{}, + tokenFetcher: stubTokenFetcher{}, + roleARN: roleARN, + roleSessionName: roleSessionName, + Duration: 10 * time.Minute, + ExpiryWindow: time.Minute, + }, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + + p := NewWebIdentityRoleProviderWithOptions( + stubClient{}, roleARN, roleSessionName, + stubTokenFetcher{}, c.options...) + + if !reflect.DeepEqual(c.expect, *p) { + t.Errorf("expect:\n%v\nactual:\n%v", c.expect, *p) + } + }) + } +} + +type stubClient struct { + stsiface.STSAPI +} +type stubTokenFetcher struct{} + +func (stubTokenFetcher) FetchToken(credentials.Context) ([]byte, error) { + return nil, fmt.Errorf("stubTokenFetcher should not be called") +} diff --git a/aws/session/credentials.go b/aws/session/credentials.go index 3efdac29ff4..1d3f4c3adc3 100644 --- a/aws/session/credentials.go +++ b/aws/session/credentials.go @@ -14,8 +14,17 @@ import ( "github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/internal/shareddefaults" + "github.com/aws/aws-sdk-go/service/sts" ) +// CredentialsProviderOptions specifies additional options for configuring +// credentials providers. +type CredentialsProviderOptions struct { + // WebIdentityRoleProviderOptions configures a WebIdentityRoleProvider, + // such as setting its ExpiryWindow. + WebIdentityRoleProviderOptions func(*stscreds.WebIdentityRoleProvider) +} + func resolveCredentials(cfg *aws.Config, envCfg envConfig, sharedCfg sharedConfig, handlers request.Handlers, @@ -40,6 +49,7 @@ func resolveCredentials(cfg *aws.Config, envCfg.WebIdentityTokenFilePath, envCfg.RoleARN, envCfg.RoleSessionName, + sessOpts.CredentialsProviderOptions, ) default: @@ -59,6 +69,7 @@ var WebIdentityEmptyTokenFilePathErr = awserr.New(stscreds.ErrCodeWebIdentity, " func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers, filepath string, roleARN, sessionName string, + credOptions *CredentialsProviderOptions, ) (*credentials.Credentials, error) { if len(filepath) == 0 { @@ -69,17 +80,18 @@ func assumeWebIdentity(cfg *aws.Config, handlers request.Handlers, return nil, WebIdentityEmptyRoleARNErr } - creds := stscreds.NewWebIdentityCredentials( - &Session{ - Config: cfg, - Handlers: handlers.Copy(), - }, - roleARN, - sessionName, - filepath, - ) + svc := sts.New(&Session{ + Config: cfg, + Handlers: handlers.Copy(), + }) - return creds, nil + var optFns []func(*stscreds.WebIdentityRoleProvider) + if credOptions != nil && credOptions.WebIdentityRoleProviderOptions != nil { + optFns = append(optFns, credOptions.WebIdentityRoleProviderOptions) + } + + p := stscreds.NewWebIdentityRoleProviderWithOptions(svc, roleARN, sessionName, stscreds.FetchTokenPath(filepath), optFns...) + return credentials.NewCredentials(p), nil } func resolveCredsFromProfile(cfg *aws.Config, @@ -114,6 +126,7 @@ func resolveCredsFromProfile(cfg *aws.Config, sharedCfg.WebIdentityTokenFile, sharedCfg.RoleARN, sharedCfg.RoleSessionName, + sessOpts.CredentialsProviderOptions, ) case sharedCfg.hasSSOConfiguration(): diff --git a/aws/session/credentials_test.go b/aws/session/credentials_test.go index 914ee8a986e..f1a8f332f36 100644 --- a/aws/session/credentials_test.go +++ b/aws/session/credentials_test.go @@ -19,6 +19,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/defaults" "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" @@ -814,6 +815,56 @@ func TestSessionAssumeRole_WithMFA_ExtendedDuration(t *testing.T) { } } +func TestSessionAssumeRoleWithWebIdentity_Options(t *testing.T) { + restoreEnvFn := initSessionTestEnv() + defer restoreEnvFn() + + os.Setenv("AWS_REGION", "us-east-1") + os.Setenv("AWS_ROLE_ARN", "web_identity_role_arn") + os.Setenv("AWS_WEB_IDENTITY_TOKEN_FILE", "./testdata/wit.txt") + + endpointResolver, teardown := setupCredentialsEndpoints(t) + defer teardown() + + optionsCalled := false + + sess, err := NewSessionWithOptions(Options{ + Config: aws.Config{ + EndpointResolver: endpointResolver, + }, + CredentialsProviderOptions: &CredentialsProviderOptions{ + WebIdentityRoleProviderOptions: func(*stscreds.WebIdentityRoleProvider) { + optionsCalled = true + }, + }, + }) + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if !optionsCalled { + t.Errorf("expect options func to be called") + } + + creds, err := sess.Config.Credentials.Get() + if err != nil { + t.Fatalf("expect no error, got %v", err) + } + + if e, a := "WEB_IDENTITY_AKID", creds.AccessKeyID; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "WEB_IDENTITY_SECRET", creds.SecretAccessKey; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := "WEB_IDENTITY_SESSION_TOKEN", creds.SessionToken; e != a { + t.Errorf("expect %v, got %v", e, a) + } + if e, a := stscreds.WebIdentityProviderName, creds.ProviderName; e != a { + t.Errorf("expect %v,got %v", e, a) + } +} + func ssoTestSetup() (func(), error) { dir, err := ioutil.TempDir("", "sso-test") if err != nil { diff --git a/aws/session/session.go b/aws/session/session.go index ebace4bb79d..4293dbe10bd 100644 --- a/aws/session/session.go +++ b/aws/session/session.go @@ -304,6 +304,11 @@ type Options struct { // // AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6 EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState + + // Specifies options for creating credential providers. + // These are only used if the aws.Config does not already + // include credentials. + CredentialsProviderOptions *CredentialsProviderOptions } // NewSessionWithOptions returns a new Session created from SDK defaults, config files,