Skip to content

Commit

Permalink
Ensure proper order for obtaining credentials, assuming roles, using …
Browse files Browse the repository at this point in the history
…profiles (#5)

* Adjust logic to validate creds before assumerole

Previously, a bug existed that prevented session-derived creds
from being used when assuming a role. This is because session-
derived creds would not be gathered until the very last moment.
Since all the assumerole logic was passed before this last moment,
assumerole could not work with session-derived creds.

Now, GetCredentials has a new contract - it provides and validates
credentials. Before, GetCredentials would sometimes return
unvalidated creds and sometimes validated creds. This meant that
more error handling logic needed to be included in GetSession and
GetSessionOptions. As part of validating creds, GetCredentials now
gets session-derived creds, if necessary, prior to assuming a role.

* Add error instance for better checking
  • Loading branch information
YakDriver authored and aeschright committed Oct 3, 2019
1 parent 516649a commit f3803c6
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 79 deletions.
106 changes: 88 additions & 18 deletions awsauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ import (
"github.com/hashicorp/go-multierror"
)

const (
// errMsgNoValidCredentialSources error getting credentials
errMsgNoValidCredentialSources = `No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`
)

var (
// ErrNoValidCredentialSources indicates that no credentials source could be found
ErrNoValidCredentialSources = errNoValidCredentialSources()
)

func errNoValidCredentialSources() error { return errors.New(errMsgNoValidCredentialSources) }

// GetAccountIDAndPartition gets the account ID and associated partition.
func GetAccountIDAndPartition(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) {
var accountID, partition string
var err, errors error
Expand Down Expand Up @@ -51,6 +66,8 @@ func GetAccountIDAndPartition(iamconn *iam.IAM, stsconn *sts.STS, authProviderNa
return accountID, partition, errors
}

// GetAccountIDAndPartitionFromEC2Metadata gets the account ID and associated
// partition from EC2 metadata.
func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
log.Println("[DEBUG] Trying to get account information via EC2 Metadata")

Expand All @@ -75,6 +92,8 @@ func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
return parseAccountIDAndPartitionFromARN(info.InstanceProfileArn)
}

// GetAccountIDAndPartitionFromIAMGetUser gets the account ID and associated
// partition from IAM.
func GetAccountIDAndPartitionFromIAMGetUser(iamconn *iam.IAM) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via iam:GetUser")

Expand Down Expand Up @@ -102,6 +121,8 @@ func GetAccountIDAndPartitionFromIAMGetUser(iamconn *iam.IAM) (string, string, e
return parseAccountIDAndPartitionFromARN(aws.StringValue(output.User.Arn))
}

// GetAccountIDAndPartitionFromIAMListRoles gets the account ID and associated
// partition from listing IAM roles.
func GetAccountIDAndPartitionFromIAMListRoles(iamconn *iam.IAM) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via iam:ListRoles")

Expand All @@ -123,6 +144,8 @@ func GetAccountIDAndPartitionFromIAMListRoles(iamconn *iam.IAM) (string, string,
return parseAccountIDAndPartitionFromARN(aws.StringValue(output.Roles[0].Arn))
}

// GetAccountIDAndPartitionFromSTSGetCallerIdentity gets the account ID and associated
// partition from STS caller identity.
func GetAccountIDAndPartitionFromSTSGetCallerIdentity(stsconn *sts.STS) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via sts:GetCallerIdentity")

Expand All @@ -148,9 +171,54 @@ func parseAccountIDAndPartitionFromARN(inputARN string) (string, string, error)
return arn.AccountID, arn.Partition, nil
}

// This function is responsible for reading credentials from the
// environment in the case that they're not explicitly specified
// in the Terraform configuration.
// GetCredentialsFromSession returns credentials derived from a session. A
// session uses the AWS SDK Go chain of providers so may use a provider (e.g.,
// ProcessProvider) that is not part of the Terraform provider chain.
func GetCredentialsFromSession(c *Config) (*awsCredentials.Credentials, error) {
log.Printf("[INFO] Attempting to use session-derived credentials")

var sess *session.Session
var err error
if c.Profile == "" {
sess, err = session.NewSession()
if err != nil {
return nil, ErrNoValidCredentialSources
}
} else {
options := &session.Options{
Config: aws.Config{
HTTPClient: cleanhttp.DefaultClient(),
MaxRetries: aws.Int(0),
Region: aws.String(c.Region),
},
}
options.Profile = c.Profile
options.SharedConfigState = session.SharedConfigEnable

sess, err = session.NewSessionWithOptions(*options)
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
return nil, ErrNoValidCredentialSources
}
return nil, fmt.Errorf("Error creating AWS session: %s", err)
}
}

creds := sess.Config.Credentials
cp, err := sess.Config.Credentials.Get()
if err != nil {
return nil, ErrNoValidCredentialSources
}

log.Printf("[INFO] Successfully derived credentials from session")
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
return creds, nil
}

// GetCredentials gets credentials from the environment, shared credentials,
// or the session (which may include a credential process). GetCredentials also
// validates the credentials and the ability to assume a role or will return an
// error if unsuccessful.
func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
// build a chain provider, lazy-evaluated by aws-sdk
providers := []awsCredentials.Provider{
Expand Down Expand Up @@ -225,30 +293,32 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
}
}

// Validate the credentials before returning them
creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
creds, err = GetCredentialsFromSession(c)
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}
} else {
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
}

// This is the "normal" flow (i.e. not assuming a role)
if c.AssumeRoleARN == "" {
return awsCredentials.NewChainCredentials(providers), nil
return creds, nil
}

// Otherwise we need to construct an STS client with the main credentials, and verify
// that we can assume the defined role.
log.Printf("[INFO] Attempting to AssumeRole %s (SessionName: %q, ExternalId: %q, Policy: %q)",
c.AssumeRoleARN, c.AssumeRoleSessionName, c.AssumeRoleExternalID, c.AssumeRolePolicy)

creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
}

return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}

log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
Expand Down
25 changes: 4 additions & 21 deletions awsauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os"
"testing"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
Expand Down Expand Up @@ -423,20 +422,12 @@ func TestAWSGetCredentials_shouldErrorWhenBlank(t *testing.T) {
defer resetEnv()

cfg := Config{}
c, err := GetCredentials(&cfg)
_, err := GetCredentials(&cfg)

if err != nil {
if err != ErrNoValidCredentialSources {
t.Fatalf("Unexpected error: %s", err)
}

_, err = c.Get()
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatal("Expected NoCredentialProviders error")
}
} else {
t.Fatal("Expected AWS error")
}
if err == nil {
t.Fatal("Expected an error given empty env, keys, and IAM in AWS Config")
}
Expand Down Expand Up @@ -586,22 +577,14 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t)
defer ts()

creds, err := GetCredentials(&Config{})
if err != nil {
_, err := GetCredentials(&Config{})
if err != ErrNoValidCredentialSources {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err == nil {
t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint")
}

if v.ProviderName != "" {
t.Fatalf("Expected provider name to be empty, %q given", v.ProviderName)
}
}

func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ require (
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd // indirect
golang.org/x/text v0.3.0 // indirect
)

go 1.13
46 changes: 6 additions & 40 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package awsbase

import (
"crypto/tls"
"errors"
"fmt"
"log"
"net/http"
Expand All @@ -28,45 +27,14 @@ func GetSessionOptions(c *Config) (*session.Options, error) {
},
}

// get and validate credentials
creds, err := GetCredentials(c)
if err != nil {
return nil, err
}

// Call Get to check for credential provider. If nothing found, we'll get an
// error, and we can present it nicely to the user
cp, err := creds.Get()
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
// If a profile wasn't specified, the session may still be able to resolve credentials from shared config.
if c.Profile == "" {
sess, err := session.NewSession()
if err != nil {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
}
_, err = sess.Config.Credentials.Get()
if err != nil {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
}
log.Printf("[INFO] Using session-derived AWS Auth")
options.Config.Credentials = sess.Config.Credentials
} else {
log.Printf("[INFO] AWS Auth using Profile: %q", c.Profile)
options.Profile = c.Profile
options.SharedConfigState = session.SharedConfigEnable
}
} else {
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}
} else {
// add the validated credentials to the session options
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
options.Config.Credentials = creds
}
// add the validated credentials to the session options
options.Config.Credentials = creds

if c.Insecure {
transport := options.Config.HTTPClient.Transport.(*http.Transport)
Expand All @@ -83,7 +51,7 @@ func GetSessionOptions(c *Config) (*session.Options, error) {
return options, nil
}

// GetSession attempts to return valid AWS Go SDK session
// GetSession attempts to return valid AWS Go SDK session.
func GetSession(c *Config) (*session.Session, error) {
options, err := GetSessionOptions(c)

Expand All @@ -94,9 +62,7 @@ func GetSession(c *Config) (*session.Session, error) {
sess, err := session.NewSessionWithOptions(*options)
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
return nil, ErrNoValidCredentialSources
}
return nil, fmt.Errorf("Error creating AWS session: %s", err)
}
Expand Down Expand Up @@ -138,7 +104,7 @@ func GetSession(c *Config) (*session.Session, error) {
if !c.SkipCredsValidation {
stsClient := sts.New(sess.Copy(&aws.Config{Endpoint: aws.String(c.StsEndpoint)}))
if _, _, err := GetAccountIDAndPartitionFromSTSGetCallerIdentity(stsClient); err != nil {
return nil, fmt.Errorf("error validating provider credentials: %s", err)
return nil, fmt.Errorf("error using credentials to get account ID: %s", err)
}
}

Expand Down

0 comments on commit f3803c6

Please sign in to comment.