Skip to content

Commit

Permalink
Update web identity assume role constructor to take options
Browse files Browse the repository at this point in the history
  • Loading branch information
jasdel committed Jan 4, 2022
1 parent 0a9bede commit 15e70a6
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 108 deletions.
40 changes: 34 additions & 6 deletions aws/credentials/stscreds/web_identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const (
// compare test values.
var now = time.Now

// TokenFetcher shuold return WebIdentity token bytes or an error
// TokenFetcher should return WebIdentity token bytes or an error
type TokenFetcher interface {
FetchToken(credentials.Context) ([]byte, error)
}
Expand All @@ -50,6 +50,8 @@ func (f FetchTokenPath) FetchToken(ctx credentials.Context) ([]byte, error) {
// an OIDC token.
type WebIdentityRoleProvider struct {
credentials.Expiry

// The policy ARNs to use with the web identity assumed role.
PolicyArns []*sts.PolicyDescriptorType

// Duration the STS credentials will be valid for. Truncated to seconds.
Expand All @@ -74,6 +76,9 @@ type WebIdentityRoleProvider struct {

// NewWebIdentityCredentials will return a new set of credentials with a given
// configuration, role arn, and token file path.
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options, and wrap with credentials.NewCredentials helper.
func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName, path string) *credentials.Credentials {
svc := sts.New(c)
p := NewWebIdentityRoleProvider(svc, roleARN, roleSessionName, path)
Expand All @@ -82,19 +87,42 @@ func NewWebIdentityCredentials(c client.ConfigProvider, roleARN, roleSessionName

// NewWebIdentityRoleProvider will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProvider(svc stsiface.STSAPI, roleARN, roleSessionName, path string) *WebIdentityRoleProvider {
return NewWebIdentityRoleProviderWithToken(svc, roleARN, roleSessionName, FetchTokenPath(path))
return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, FetchTokenPath(path))
}

// NewWebIdentityRoleProviderWithToken will return a new WebIdentityRoleProvider with the
// provided stsiface.STSAPI and a TokenFetcher
//
// Deprecated: Use NewWebIdentityRoleProviderWithOptions for flexible
// functional options.
func NewWebIdentityRoleProviderWithToken(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher) *WebIdentityRoleProvider {
return &WebIdentityRoleProvider{
return NewWebIdentityRoleProviderWithOptions(svc, roleARN, roleSessionName, tokenFetcher)
}

// NewWebIdentityRoleProviderWithOptions will return an initialize
// WebIdentityRoleProvider with the provided stsiface.STSAPI, role ARN, and a
// TokenFetcher. Additional options can be provided as functional options.
//
// TokenFetcher is the implementation that will retrieve the JWT token from to
// assume the role with. Use the provided FetchTokenPath implementation to
// retrieve the JWT token using a file system path.
func NewWebIdentityRoleProviderWithOptions(svc stsiface.STSAPI, roleARN, roleSessionName string, tokenFetcher TokenFetcher, optFns ...func(*WebIdentityRoleProvider)) *WebIdentityRoleProvider {
p := WebIdentityRoleProvider{
client: svc,
tokenFetcher: tokenFetcher,
roleARN: roleARN,
roleSessionName: roleSessionName,
}

for _, fn := range optFns {
fn(&p)
}

return &p
}

// Retrieve will attempt to assume a role from a token which is located at
Expand All @@ -104,9 +132,9 @@ func (p *WebIdentityRoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(aws.BackgroundContext())
}

// RetrieveWithContext will attempt to assume a role from a token which is located at
// 'WebIdentityTokenFilePath' specified destination and if that is empty an
// error will be returned.
// RetrieveWithContext will attempt to assume a role from a token which is
// located at 'WebIdentityTokenFilePath' specified destination and if that is
// empty an error will be returned.
func (p *WebIdentityRoleProvider) RetrieveWithContext(ctx credentials.Context) (credentials.Value, error) {
b, err := p.tokenFetcher.FetchToken(ctx)
if err != nil {
Expand Down
244 changes: 146 additions & 98 deletions aws/credentials/stscreds/web_identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -1,152 +1,107 @@
//go:build go1.7
// +build go1.7

package stscreds_test
package stscreds

import (
"fmt"
"net/http"
"reflect"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/corehandlers"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/awstesting/unit"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
)

func TestWebIdentityProviderRetrieve(t *testing.T) {
var reqCount int
cases := map[string]struct {
onSendReq func(*testing.T, *request.Request)
roleARN string
tokenFilepath string
tokenPath string
sessionName string
newClient func(t *testing.T) stsiface.STSAPI
duration time.Duration
expectedError string
expectedCredValue credentials.Value
}{
"session name case": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := "foo", *input.RoleSessionName; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if input.DurationSeconds != nil {
t.Errorf("expect no duration, got %v", *input.DurationSeconds)
}
roleARN: "arn01234567890123456789",
tokenPath: "testdata/token.jwt",
sessionName: "foo",
newClient: func(t *testing.T) stsiface.STSAPI {
return mockAssumeRoleWithWebIdentityClient{
t: t,
doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := "foo", *input.RoleSessionName; e != a {
t.Errorf("expected %v, but received %v", e, a)
}
if input.DurationSeconds != nil {
t.Errorf("expect no duration, got %v", *input.DurationSeconds)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: stscreds.WebIdentityProviderName,
ProviderName: WebIdentityProviderName,
},
},
"with duration": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
duration: 15 * time.Minute,
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a {
t.Errorf("expect %v duration, got %v", e, a)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: stscreds.WebIdentityProviderName,
},
},
"invalid token retry": {
roleARN: "arn01234567890123456789",
tokenFilepath: "testdata/token.jwt",
sessionName: "foo",
onSendReq: func(t *testing.T, r *request.Request) {
input := r.Params.(*sts.AssumeRoleWithWebIdentityInput)
if e, a := "foo", *input.RoleSessionName; !reflect.DeepEqual(e, a) {
t.Errorf("expected %v, but received %v", e, a)
}

if reqCount == 0 {
r.HTTPResponse.StatusCode = 400
r.Error = awserr.New(sts.ErrCodeInvalidIdentityTokenException,
"some error message", nil)
return
}
roleARN: "arn01234567890123456789",
tokenPath: "testdata/token.jwt",
sessionName: "foo",
duration: 15 * time.Minute,
newClient: func(t *testing.T) stsiface.STSAPI {
return mockAssumeRoleWithWebIdentityClient{
t: t,
doRequest: func(t *testing.T, input *sts.AssumeRoleWithWebIdentityInput) (
*sts.AssumeRoleWithWebIdentityOutput, error,
) {
if e, a := int64((15*time.Minute)/time.Second), *input.DurationSeconds; e != a {
t.Errorf("expect %v duration, got %v", e, a)
}

data := r.Data.(*sts.AssumeRoleWithWebIdentityOutput)
*data = sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
return &sts.AssumeRoleWithWebIdentityOutput{
Credentials: &sts.Credentials{
Expiration: aws.Time(time.Now()),
AccessKeyId: aws.String("access-key-id"),
SecretAccessKey: aws.String("secret-access-key"),
SessionToken: aws.String("session-token"),
},
}, nil
},
}
},
expectedCredValue: credentials.Value{
AccessKeyID: "access-key-id",
SecretAccessKey: "secret-access-key",
SessionToken: "session-token",
ProviderName: stscreds.WebIdentityProviderName,
ProviderName: WebIdentityProviderName,
},
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {
reqCount = 0

svc := sts.New(unit.Session, &aws.Config{
Logger: t,
})
svc.Handlers.Send.Swap(corehandlers.SendHandler.Name, request.NamedHandler{
Name: "custom send stub handler",
Fn: func(r *request.Request) {
r.HTTPResponse = &http.Response{
StatusCode: 200, Header: http.Header{},
}
c.onSendReq(t, r)
reqCount++
},
})
svc.Handlers.UnmarshalMeta.Clear()
svc.Handlers.Unmarshal.Clear()
svc.Handlers.UnmarshalError.Clear()

p := stscreds.NewWebIdentityRoleProvider(svc, c.roleARN, c.sessionName, c.tokenFilepath)
p := NewWebIdentityRoleProvider(c.newClient(t), c.roleARN, c.sessionName, c.tokenPath)
p.Duration = c.duration

credValue, err := p.Retrieve()
Expand All @@ -169,3 +124,96 @@ func TestWebIdentityProviderRetrieve(t *testing.T) {
})
}
}

type mockAssumeRoleWithWebIdentityClient struct {
stsiface.STSAPI

t *testing.T
doRequest func(*testing.T, *sts.AssumeRoleWithWebIdentityInput) (*sts.AssumeRoleWithWebIdentityOutput, error)
}

func (c mockAssumeRoleWithWebIdentityClient) AssumeRoleWithWebIdentityRequest(input *sts.AssumeRoleWithWebIdentityInput) (
*request.Request, *sts.AssumeRoleWithWebIdentityOutput,
) {
output, err := c.doRequest(c.t, input)

req := &request.Request{
HTTPRequest: &http.Request{},
Retryer: client.DefaultRetryer{},
}
req.Handlers.Send.PushBack(func(r *request.Request) {
r.HTTPResponse = &http.Response{}
r.Data = output
r.Error = err

var found bool
for _, retryCode := range req.RetryErrorCodes {
if retryCode == sts.ErrCodeInvalidIdentityTokenException {
found = true
break
}
}
if !found {
c.t.Errorf("expect ErrCodeInvalidIdentityTokenException error code to be retry-able")
}
})

return req, output
}

func TestNewWebIdentityRoleProviderWithOptions(t *testing.T) {
const roleARN = "a-role-arn"
const roleSessionName = "a-session-name"

cases := map[string]struct {
options []func(*WebIdentityRoleProvider)
expect WebIdentityRoleProvider
}{
"no options": {
expect: WebIdentityRoleProvider{
client: stubClient{},
tokenFetcher: stubTokenFetcher{},
roleARN: roleARN,
roleSessionName: roleSessionName,
},
},
"with options": {
options: []func(*WebIdentityRoleProvider){
func(o *WebIdentityRoleProvider) {
o.Duration = 10 * time.Minute
o.ExpiryWindow = time.Minute
},
},
expect: WebIdentityRoleProvider{
client: stubClient{},
tokenFetcher: stubTokenFetcher{},
roleARN: roleARN,
roleSessionName: roleSessionName,
Duration: 10 * time.Minute,
ExpiryWindow: time.Minute,
},
},
}

for name, c := range cases {
t.Run(name, func(t *testing.T) {

p := NewWebIdentityRoleProviderWithOptions(
stubClient{}, roleARN, roleSessionName,
stubTokenFetcher{}, c.options...)

if !reflect.DeepEqual(c.expect, *p) {
t.Errorf("expect:\n%v\nactual:\n%v", c.expect, *p)
}
})
}
}

type stubClient struct {
stsiface.STSAPI
}
type stubTokenFetcher struct{}

func (stubTokenFetcher) FetchToken(credentials.Context) ([]byte, error) {
return nil, fmt.Errorf("stubTokenFetcher should not be called")
}
Loading

0 comments on commit 15e70a6

Please sign in to comment.