From 7511c15b168db4813b63a467e346f4808f1f496a Mon Sep 17 00:00:00 2001 From: Graham Davison Date: Tue, 8 Feb 2022 12:28:20 -0800 Subject: [PATCH] Uses `AWS_MAX_ATTEMPTS` environment variable to configure AWS SDK max attempts --- aws_config.go | 6 ++ aws_config_test.go | 128 ++++++++++++++++++++++++++++++++++ v2/awsv1shim/session.go | 5 +- v2/awsv1shim/session_test.go | 129 +++++++++++++++++++++++++++++++++++ 4 files changed, 266 insertions(+), 2 deletions(-) diff --git a/aws_config.go b/aws_config.go index bd3c53b1..4aea2606 100644 --- a/aws_config.go +++ b/aws_config.go @@ -7,6 +7,7 @@ import ( "log" "net" "os" + "strconv" "strings" "time" @@ -41,6 +42,11 @@ func GetAwsConfig(ctx context.Context, c *Config) (aws.Config, error) { var retryer aws.Retryer retryer = retry.NewStandard() + if maxAttempts := os.Getenv("AWS_MAX_ATTEMPTS"); maxAttempts != "" { + if i, err := strconv.Atoi(maxAttempts); err == nil { + retryer = retry.AddWithMaxAttempts(retryer, i) + } + } if c.MaxRetries != 0 { retryer = retry.AddWithMaxAttempts(retryer, c.MaxRetries) } diff --git a/aws_config_test.go b/aws_config_test.go index 707af08f..f0df937e 100644 --- a/aws_config_test.go +++ b/aws_config_test.go @@ -1185,6 +1185,134 @@ func fullValueTypeName(v reflect.Value) string { return fmt.Sprintf("%s.%s", requestType.PkgPath(), requestType.Name()) } +func TestMaxAttempts(t *testing.T) { + testCases := map[string]struct { + Config *Config + EnvironmentVariables map[string]string + SharedConfigurationFile string + ExpectedMaxAttempts int + }{ + "no configuration": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedMaxAttempts: retry.DefaultMaxAttempts, + }, + + "config": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + MaxRetries: 5, + }, + ExpectedMaxAttempts: 5, + }, + + "AWS_MAX_ATTEMPTS": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + EnvironmentVariables: map[string]string{ + "AWS_MAX_ATTEMPTS": "5", + }, + ExpectedMaxAttempts: 5, + }, + + // "shared configuration file": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // SharedConfigurationFile: ` + // [default] + // max_attempts = 5 + // `, + // ExpectedMaxAttempts: 5, + // }, + + "config overrides AWS_MAX_ATTEMPTS": { + Config: &Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + MaxRetries: 10, + }, + EnvironmentVariables: map[string]string{ + "AWS_MAX_ATTEMPTS": "5", + }, + ExpectedMaxAttempts: 10, + }, + + // "AWS_MAX_ATTEMPTS overrides shared configuration": { + // Config: &Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_MAX_ATTEMPTS": "5", + // }, + // SharedConfigurationFile: ` + // [default] + // max_attempts = 10 + // `, + // ExpectedMaxAttempts: 5, + // }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } + + if testCase.SharedConfigurationFile != "" { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("error in GetAwsConfig() '%[1]T': %[1]s", err) + } + + retryer := awsConfig.Retryer() + if retryer == nil { + t.Fatal("no retryer set") + } + if a, e := retryer.MaxAttempts(), testCase.ExpectedMaxAttempts; a != e { + t.Errorf(`expected MaxAttempts "%d", got: "%d"`, e, a) + } + }) + } +} + func TestServiceEndpointTypes(t *testing.T) { testCases := map[string]struct { Config *Config diff --git a/v2/awsv1shim/session.go b/v2/awsv1shim/session.go index 01ecb5f7..abc04b96 100644 --- a/v2/awsv1shim/session.go +++ b/v2/awsv1shim/session.go @@ -81,8 +81,9 @@ func GetSession(awsC *awsv2.Config, c *awsbase.Config) (*session.Session, error) return nil, fmt.Errorf("Error creating AWS session: %w", err) } - if c.MaxRetries > 0 { - sess = sess.Copy(&aws.Config{MaxRetries: aws.Int(c.MaxRetries)}) + // Set retries after resolving credentials to prevent retries during resolution + if retryer := awsC.Retryer(); retryer != nil { + sess = sess.Copy(&aws.Config{MaxRetries: aws.Int(retryer.MaxAttempts())}) } SetSessionUserAgent(sess, c.APNInfo, c.UserAgent) diff --git a/v2/awsv1shim/session_test.go b/v2/awsv1shim/session_test.go index e82c1f1d..ccc2c58b 100644 --- a/v2/awsv1shim/session_test.go +++ b/v2/awsv1shim/session_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws/retry" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/client" @@ -1205,6 +1206,134 @@ func awsSdkGoUserAgent() string { return fmt.Sprintf("%s/%s (%s; %s; %s)", aws.SDKName, aws.SDKVersion, runtime.Version(), runtime.GOOS, runtime.GOARCH) } +func TestMaxAttempts(t *testing.T) { + testCases := map[string]struct { + Config *awsbase.Config + EnvironmentVariables map[string]string + SharedConfigurationFile string + ExpectedMaxAttempts int + }{ + "no configuration": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + ExpectedMaxAttempts: retry.DefaultMaxAttempts, + }, + + "config": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + MaxRetries: 5, + }, + ExpectedMaxAttempts: 5, + }, + + "AWS_MAX_ATTEMPTS": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + }, + EnvironmentVariables: map[string]string{ + "AWS_MAX_ATTEMPTS": "5", + }, + ExpectedMaxAttempts: 5, + }, + + // "shared configuration file": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // SharedConfigurationFile: ` + // [default] + // max_attempts = 5 + // `, + // ExpectedMaxAttempts: 5, + // }, + + "config overrides AWS_MAX_ATTEMPTS": { + Config: &awsbase.Config{ + AccessKey: servicemocks.MockStaticAccessKey, + Region: "us-east-1", + SecretKey: servicemocks.MockStaticSecretKey, + MaxRetries: 10, + }, + EnvironmentVariables: map[string]string{ + "AWS_MAX_ATTEMPTS": "5", + }, + ExpectedMaxAttempts: 10, + }, + + // "AWS_MAX_ATTEMPTS overrides shared configuration": { + // Config: &awsbase.Config{ + // AccessKey: servicemocks.MockStaticAccessKey, + // Region: "us-east-1", + // SecretKey: servicemocks.MockStaticSecretKey, + // }, + // EnvironmentVariables: map[string]string{ + // "AWS_MAX_ATTEMPTS": "5", + // }, + // SharedConfigurationFile: ` + // [default] + // max_attempts = 10 + // `, + // ExpectedMaxAttempts: 5, + // }, + } + + for testName, testCase := range testCases { + testCase := testCase + + t.Run(testName, func(t *testing.T) { + oldEnv := servicemocks.InitSessionTestEnv() + defer servicemocks.PopEnv(oldEnv) + + for k, v := range testCase.EnvironmentVariables { + os.Setenv(k, v) + } + + if testCase.SharedConfigurationFile != "" { + file, err := ioutil.TempFile("", "aws-sdk-go-base-shared-configuration-file") + + if err != nil { + t.Fatalf("unexpected error creating temporary shared configuration file: %s", err) + } + + defer os.Remove(file.Name()) + + err = ioutil.WriteFile(file.Name(), []byte(testCase.SharedConfigurationFile), 0600) + + if err != nil { + t.Fatalf("unexpected error writing shared configuration file: %s", err) + } + + testCase.Config.SharedConfigFiles = []string{file.Name()} + } + + testCase.Config.SkipCredsValidation = true + + awsConfig, err := awsbase.GetAwsConfig(context.Background(), testCase.Config) + if err != nil { + t.Fatalf("GetAwsConfig() returned error: %s", err) + } + actualSession, err := GetSession(&awsConfig, testCase.Config) + if err != nil { + t.Fatalf("error in GetSession() '%[1]T': %[1]s", err) + } + + if a, e := *actualSession.Config.MaxRetries, testCase.ExpectedMaxAttempts; a != e { + t.Errorf(`expected MaxAttempts "%d", got: "%d"`, e, a) + } + }) + } +} + func TestServiceEndpointTypes(t *testing.T) { testCases := map[string]struct { Config *awsbase.Config