Skip to content

Commit

Permalink
Merge pull request #105 from hashicorp/max-attempts-envvar
Browse files Browse the repository at this point in the history
Use `AWS_MAX_ATTEMPTS` environment variable to configure AWS SDK max attempts
  • Loading branch information
gdavison authored Feb 8, 2022
2 parents 5bdc32b + 7511c15 commit 654d4ef
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 2 deletions.
6 changes: 6 additions & 0 deletions aws_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log"
"net"
"os"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -41,6 +42,11 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) {

var retryer aws.Retryer
retryer = retry.NewStandard()
if maxAttempts := os.Getenv("AWS_MAX_ATTEMPTS"); maxAttempts != "" {
if i, err := strconv.Atoi(maxAttempts); err == nil {
retryer = retry.AddWithMaxAttempts(retryer, i)
}
}
if c.MaxRetries != 0 {
retryer = retry.AddWithMaxAttempts(retryer, c.MaxRetries)
}
Expand Down
128 changes: 128 additions & 0 deletions aws_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,134 @@ func fullValueTypeName(v reflect.Value) string {
return fmt.Sprintf("%s.%s", requestType.PkgPath(), requestType.Name())
}

func TestMaxAttempts(t *testing.T) {
testCases := map[string]struct {
Config *Config
EnvironmentVariables map[string]string
SharedConfigurationFile string
ExpectedMaxAttempts int
}{
"no configuration": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectedMaxAttempts: retry.DefaultMaxAttempts,
},

"config": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
MaxRetries: 5,
},
ExpectedMaxAttempts: 5,
},

"AWS_MAX_ATTEMPTS": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
EnvironmentVariables: map[string]string{
"AWS_MAX_ATTEMPTS": "5",
},
ExpectedMaxAttempts: 5,
},

// "shared configuration file": {
// Config: &Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// SharedConfigurationFile: `
// [default]
// max_attempts = 5
// `,
// ExpectedMaxAttempts: 5,
// },

"config overrides AWS_MAX_ATTEMPTS": {
Config: &Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
MaxRetries: 10,
},
EnvironmentVariables: map[string]string{
"AWS_MAX_ATTEMPTS": "5",
},
ExpectedMaxAttempts: 10,
},

// "AWS_MAX_ATTEMPTS overrides shared configuration": {
// Config: &Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// EnvironmentVariables: map[string]string{
// "AWS_MAX_ATTEMPTS": "5",
// },
// SharedConfigurationFile: `
// [default]
// max_attempts = 10
// `,
// ExpectedMaxAttempts: 5,
// },
}

for testName, testCase := range testCases {
testCase := testCase

t.Run(testName, func(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)

for k, v := range testCase.EnvironmentVariables {
os.Setenv(k, v)
}

if testCase.SharedConfigurationFile != "" {
file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file")

if err != nil {
t.Fatalf("unexpected error creating temporary shared configuration file: %s", err)
}

defer os.Remove(file.Name())

err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600)

if err != nil {
t.Fatalf("unexpected error writing shared configuration file: %s", err)
}

testCase.Config.SharedConfigFiles = []string{file.Name()}
}

testCase.Config.SkipCredsValidation = true

awsConfig, err := GetAwsConfig(context.Background(), testCase.Config)
if err != nil {
t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err)
}

retryer := awsConfig.Retryer()
if retryer == nil {
t.Fatal("no retryer set")
}
if a, e := retryer.MaxAttempts(), testCase.ExpectedMaxAttempts; a != e {
t.Errorf(`expected MaxAttempts "%d", got: "%d"`, e, a)
}
})
}
}

func TestServiceEndpointTypes(t *testing.T) {
testCases := map[string]struct {
Config *Config
Expand Down
5 changes: 3 additions & 2 deletions v2/awsv1shim/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,9 @@ func GetSession(awsC *awsv2.Config, c *awsbase.Config) (*session.Session, error)
return nil, fmt.Errorf("Error creating AWS session: %w", err)
}

if c.MaxRetries > 0 {
sess = sess.Copy(&aws.Config{MaxRetries: aws.Int(c.MaxRetries)})
// Set retries after resolving credentials to prevent retries during resolution
if retryer := awsC.Retryer(); retryer != nil {
sess = sess.Copy(&aws.Config{MaxRetries: aws.Int(retryer.MaxAttempts())})
}

SetSessionUserAgent(sess, c.APNInfo, c.UserAgent)
Expand Down
129 changes: 129 additions & 0 deletions v2/awsv1shim/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go-v2/aws/retry"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/client"
Expand Down Expand Up @@ -1205,6 +1206,134 @@ func awsSdkGoUserAgent() string {
return fmt.Sprintf("%s/%s (%s; %s; %s)", aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH)
}

func TestMaxAttempts(t *testing.T) {
testCases := map[string]struct {
Config *awsbase.Config
EnvironmentVariables map[string]string
SharedConfigurationFile string
ExpectedMaxAttempts int
}{
"no configuration": {
Config: &awsbase.Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
ExpectedMaxAttempts: retry.DefaultMaxAttempts,
},

"config": {
Config: &awsbase.Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
MaxRetries: 5,
},
ExpectedMaxAttempts: 5,
},

"AWS_MAX_ATTEMPTS": {
Config: &awsbase.Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
},
EnvironmentVariables: map[string]string{
"AWS_MAX_ATTEMPTS": "5",
},
ExpectedMaxAttempts: 5,
},

// "shared configuration file": {
// Config: &awsbase.Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// SharedConfigurationFile: `
// [default]
// max_attempts = 5
// `,
// ExpectedMaxAttempts: 5,
// },

"config overrides AWS_MAX_ATTEMPTS": {
Config: &awsbase.Config{
AccessKey: servicemocks.MockStaticAccessKey,
Region: "us-east-1",
SecretKey: servicemocks.MockStaticSecretKey,
MaxRetries: 10,
},
EnvironmentVariables: map[string]string{
"AWS_MAX_ATTEMPTS": "5",
},
ExpectedMaxAttempts: 10,
},

// "AWS_MAX_ATTEMPTS overrides shared configuration": {
// Config: &awsbase.Config{
// AccessKey: servicemocks.MockStaticAccessKey,
// Region: "us-east-1",
// SecretKey: servicemocks.MockStaticSecretKey,
// },
// EnvironmentVariables: map[string]string{
// "AWS_MAX_ATTEMPTS": "5",
// },
// SharedConfigurationFile: `
// [default]
// max_attempts = 10
// `,
// ExpectedMaxAttempts: 5,
// },
}

for testName, testCase := range testCases {
testCase := testCase

t.Run(testName, func(t *testing.T) {
oldEnv := servicemocks.InitSessionTestEnv()
defer servicemocks.PopEnv(oldEnv)

for k, v := range testCase.EnvironmentVariables {
os.Setenv(k, v)
}

if testCase.SharedConfigurationFile != "" {
file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file")

if err != nil {
t.Fatalf("unexpected error creating temporary shared configuration file: %s", err)
}

defer os.Remove(file.Name())

err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600)

if err != nil {
t.Fatalf("unexpected error writing shared configuration file: %s", err)
}

testCase.Config.SharedConfigFiles = []string{file.Name()}
}

testCase.Config.SkipCredsValidation = true

awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config)
if err != nil {
t.Fatalf("GetAwsConfig() returned error: %s", err)
}
actualSession, err := GetSession(&awsConfig, testCase.Config)
if err != nil {
t.Fatalf("error in GetSession() '%[1]T': %[1]s", err)
}

if a, e := *actualSession.Config.MaxRetries, testCase.ExpectedMaxAttempts; a != e {
t.Errorf(`expected MaxAttempts "%d", got: "%d"`, e, a)
}
})
}
}

func TestServiceEndpointTypes(t *testing.T) {
testCases := map[string]struct {
Config *awsbase.Config
Expand Down

0 comments on commit 654d4ef

Please sign in to comment.