Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure proper order for obtaining credentials, assuming roles, using profiles #5

Merged
merged 2 commits into from
Oct 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions awsauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ import (
"github.com/hashicorp/go-multierror"
)

const (
// errMsgNoValidCredentialSources error getting credentials
errMsgNoValidCredentialSources = `No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`
)

var (
// ErrNoValidCredentialSources indicates that no credentials source could be found
ErrNoValidCredentialSources = errNoValidCredentialSources()
)

func errNoValidCredentialSources() error { return errors.New(errMsgNoValidCredentialSources) }

// GetAccountIDAndPartition gets the account ID and associated partition.
func GetAccountIDAndPartition(iamconn *iam.IAM, stsconn *sts.STS, authProviderName string) (string, string, error) {
var accountID, partition string
var err, errors error
Expand Down Expand Up @@ -51,6 +66,8 @@ func GetAccountIDAndPartition(iamconn *iam.IAM, stsconn *sts.STS, authProviderNa
return accountID, partition, errors
}

// GetAccountIDAndPartitionFromEC2Metadata gets the account ID and associated
// partition from EC2 metadata.
func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
log.Println("[DEBUG] Trying to get account information via EC2 Metadata")

Expand All @@ -75,6 +92,8 @@ func GetAccountIDAndPartitionFromEC2Metadata() (string, string, error) {
return parseAccountIDAndPartitionFromARN(info.InstanceProfileArn)
}

// GetAccountIDAndPartitionFromIAMGetUser gets the account ID and associated
// partition from IAM.
func GetAccountIDAndPartitionFromIAMGetUser(iamconn *iam.IAM) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via iam:GetUser")

Expand Down Expand Up @@ -102,6 +121,8 @@ func GetAccountIDAndPartitionFromIAMGetUser(iamconn *iam.IAM) (string, string, e
return parseAccountIDAndPartitionFromARN(aws.StringValue(output.User.Arn))
}

// GetAccountIDAndPartitionFromIAMListRoles gets the account ID and associated
// partition from listing IAM roles.
func GetAccountIDAndPartitionFromIAMListRoles(iamconn *iam.IAM) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via iam:ListRoles")

Expand All @@ -123,6 +144,8 @@ func GetAccountIDAndPartitionFromIAMListRoles(iamconn *iam.IAM) (string, string,
return parseAccountIDAndPartitionFromARN(aws.StringValue(output.Roles[0].Arn))
}

// GetAccountIDAndPartitionFromSTSGetCallerIdentity gets the account ID and associated
// partition from STS caller identity.
func GetAccountIDAndPartitionFromSTSGetCallerIdentity(stsconn *sts.STS) (string, string, error) {
log.Println("[DEBUG] Trying to get account information via sts:GetCallerIdentity")

Expand All @@ -148,9 +171,54 @@ func parseAccountIDAndPartitionFromARN(inputARN string) (string, string, error)
return arn.AccountID, arn.Partition, nil
}

// This function is responsible for reading credentials from the
// environment in the case that they're not explicitly specified
// in the Terraform configuration.
// GetCredentialsFromSession returns credentials derived from a session. A
// session uses the AWS SDK Go chain of providers so may use a provider (e.g.,
// ProcessProvider) that is not part of the Terraform provider chain.
func GetCredentialsFromSession(c *Config) (*awsCredentials.Credentials, error) {
log.Printf("[INFO] Attempting to use session-derived credentials")

var sess *session.Session
var err error
if c.Profile == "" {
sess, err = session.NewSession()
if err != nil {
return nil, ErrNoValidCredentialSources
}
} else {
options := &session.Options{
Config: aws.Config{
HTTPClient: cleanhttp.DefaultClient(),
MaxRetries: aws.Int(0),
Region: aws.String(c.Region),
},
}
options.Profile = c.Profile
options.SharedConfigState = session.SharedConfigEnable

sess, err = session.NewSessionWithOptions(*options)
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
return nil, ErrNoValidCredentialSources
}
return nil, fmt.Errorf("Error creating AWS session: %s", err)
}
}

creds := sess.Config.Credentials
cp, err := sess.Config.Credentials.Get()
if err != nil {
return nil, ErrNoValidCredentialSources
}

log.Printf("[INFO] Successfully derived credentials from session")
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
return creds, nil
}

// GetCredentials gets credentials from the environment, shared credentials,
// or the session (which may include a credential process). GetCredentials also
// validates the credentials and the ability to assume a role or will return an
// error if unsuccessful.
func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
// build a chain provider, lazy-evaluated by aws-sdk
providers := []awsCredentials.Provider{
Expand Down Expand Up @@ -225,30 +293,32 @@ func GetCredentials(c *Config) (*awsCredentials.Credentials, error) {
}
}

// Validate the credentials before returning them
creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
creds, err = GetCredentialsFromSession(c)
if err != nil {
return nil, err
}
} else {
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}
} else {
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
}

// This is the "normal" flow (i.e. not assuming a role)
if c.AssumeRoleARN == "" {
return awsCredentials.NewChainCredentials(providers), nil
return creds, nil
}

// Otherwise we need to construct an STS client with the main credentials, and verify
// that we can assume the defined role.
log.Printf("[INFO] Attempting to AssumeRole %s (SessionName: %q, ExternalId: %q, Policy: %q)",
c.AssumeRoleARN, c.AssumeRoleSessionName, c.AssumeRoleExternalID, c.AssumeRolePolicy)

creds := awsCredentials.NewChainCredentials(providers)
cp, err := creds.Get()
if err != nil {
if awsErr, ok := err.(awserr.Error); ok && awsErr.Code() == "NoCredentialProviders" {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
}

return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}

log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)

awsConfig := &aws.Config{
Credentials: creds,
Region: aws.String(c.Region),
Expand Down
25 changes: 4 additions & 21 deletions awsauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os"
"testing"

"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/sts"
Expand Down Expand Up @@ -423,20 +422,12 @@ func TestAWSGetCredentials_shouldErrorWhenBlank(t *testing.T) {
defer resetEnv()

cfg := Config{}
c, err := GetCredentials(&cfg)
_, err := GetCredentials(&cfg)

if err != nil {
if err != ErrNoValidCredentialSources {
t.Fatalf("Unexpected error: %s", err)
}

_, err = c.Get()
if awsErr, ok := err.(awserr.Error); ok {
if awsErr.Code() != "NoCredentialProviders" {
t.Fatal("Expected NoCredentialProviders error")
}
} else {
t.Fatal("Expected AWS error")
}
if err == nil {
t.Fatal("Expected an error given empty env, keys, and IAM in AWS Config")
}
Expand Down Expand Up @@ -586,22 +577,14 @@ func TestAWSGetCredentials_shouldErrorWithInvalidEndpoint(t *testing.T) {
ts := invalidAwsEnv(t)
defer ts()

creds, err := GetCredentials(&Config{})
if err != nil {
_, err := GetCredentials(&Config{})
if err != ErrNoValidCredentialSources {
t.Fatalf("Error gettings creds: %s", err)
}
if creds == nil {
t.Fatal("Expected a static creds provider to be returned")
}

v, err := creds.Get()
if err == nil {
t.Fatal("Expected error returned when getting creds w/ invalid EC2 endpoint")
}

if v.ProviderName != "" {
t.Fatalf("Expected provider name to be empty, %q given", v.ProviderName)
}
}

func TestAWSGetCredentials_shouldIgnoreInvalidEndpoint(t *testing.T) {
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ require (
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd // indirect
golang.org/x/text v0.3.0 // indirect
)

go 1.13
46 changes: 6 additions & 40 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package awsbase

import (
"crypto/tls"
"errors"
"fmt"
"log"
"net/http"
Expand All @@ -28,45 +27,14 @@ func GetSessionOptions(c *Config) (*session.Options, error) {
},
}

// get and validate credentials
creds, err := GetCredentials(c)
if err != nil {
return nil, err
}

// Call Get to check for credential provider. If nothing found, we'll get an
// error, and we can present it nicely to the user
cp, err := creds.Get()
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
// If a profile wasn't specified, the session may still be able to resolve credentials from shared config.
if c.Profile == "" {
sess, err := session.NewSession()
if err != nil {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
}
_, err = sess.Config.Credentials.Get()
if err != nil {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
}
log.Printf("[INFO] Using session-derived AWS Auth")
options.Config.Credentials = sess.Config.Credentials
} else {
log.Printf("[INFO] AWS Auth using Profile: %q", c.Profile)
options.Profile = c.Profile
options.SharedConfigState = session.SharedConfigEnable
}
} else {
return nil, fmt.Errorf("Error loading credentials for AWS Provider: %s", err)
}
} else {
// add the validated credentials to the session options
log.Printf("[INFO] AWS Auth provider used: %q", cp.ProviderName)
options.Config.Credentials = creds
}
// add the validated credentials to the session options
options.Config.Credentials = creds

if c.Insecure {
transport := options.Config.HTTPClient.Transport.(*http.Transport)
Expand All @@ -83,7 +51,7 @@ func GetSessionOptions(c *Config) (*session.Options, error) {
return options, nil
}

// GetSession attempts to return valid AWS Go SDK session
// GetSession attempts to return valid AWS Go SDK session.
func GetSession(c *Config) (*session.Session, error) {
options, err := GetSessionOptions(c)

Expand All @@ -94,9 +62,7 @@ func GetSession(c *Config) (*session.Session, error) {
sess, err := session.NewSessionWithOptions(*options)
if err != nil {
if IsAWSErr(err, "NoCredentialProviders", "") {
return nil, errors.New(`No valid credential sources found for AWS Provider.
Please see https://terraform.io/docs/providers/aws/index.html for more information on
providing credentials for the AWS Provider`)
return nil, ErrNoValidCredentialSources
}
return nil, fmt.Errorf("Error creating AWS session: %s", err)
}
Expand Down Expand Up @@ -138,7 +104,7 @@ func GetSession(c *Config) (*session.Session, error) {
if !c.SkipCredsValidation {
stsClient := sts.New(sess.Copy(&aws.Config{Endpoint: aws.String(c.StsEndpoint)}))
if _, _, err := GetAccountIDAndPartitionFromSTSGetCallerIdentity(stsClient); err != nil {
return nil, fmt.Errorf("error validating provider credentials: %s", err)
return nil, fmt.Errorf("error using credentials to get account ID: %s", err)
}
}

Expand Down