Skip to content

Commit

Permalink
Merge pull request #237 from coveooss/DT-4520/aws-sdk-v2
Browse files Browse the repository at this point in the history
Update up AWS SDK v2 (with golang update to v1.17)
  • Loading branch information
dotboris authored Feb 3, 2022
2 parents ca644d4 + 38017cd commit c2f69df
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 129 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- name: Setup Go
uses: actions/setup-go@v1
with:
go-version: 1.13
go-version: 1.17
id: go

- name: Checkout
Expand All @@ -25,5 +25,4 @@ jobs:
AWS_REGION: us-east-1
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
run: |
go test ./...
run: go test -v ./...
191 changes: 114 additions & 77 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"archive/zip"
"bytes"
"context"
"crypto/md5"
"encoding/json"
"errors"
Expand All @@ -20,18 +21,17 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
awsSession "github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go-v2/aws"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/iam"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/blang/semver"
"github.com/coveooss/gotemplate/v3/collections"
"github.com/fatih/color"
"github.com/hashicorp/go-getter"
"github.com/inconshreveable/go-update"
"golang.org/x/crypto/ssh/terminal"
yaml "gopkg.in/yaml.v2"
)

Expand Down Expand Up @@ -89,12 +89,11 @@ type TGFConfigBuild struct {

var (
cachedAWSConfigExistCheck *bool
cachedSession *session.Session
)

func resetCache() {
cachedAWSConfigExistCheck = nil
cachedSession = nil
cachedAwsConfig = nil
}

func (cb TGFConfigBuild) hash() string {
Expand Down Expand Up @@ -161,85 +160,123 @@ func (config TGFConfig) String() string {
return string(bytes)
}

func (config *TGFConfig) getAwsSession(duration int64) (*session.Session, error) {
if cachedSession != nil {
return cachedSession, nil
}
askedForMfa := false
options := awsSession.Options{
Profile: config.tgf.AwsProfile,
SharedConfigState: awsSession.SharedConfigEnable,
AssumeRoleTokenProvider: func() (string, error) {
askedForMfa = true
fmt.Fprintf(os.Stderr, "Assume Role MFA token code: ")
v, err := terminal.ReadPassword(int(os.Stdin.Fd()))
fmt.Fprintln(os.Stderr)
return string(v), err
},
}
if duration > 0 {
options.AssumeRoleDuration = time.Duration(duration) * time.Second
var cachedAwsConfig *aws.Config

func (tgfConfig *TGFConfig) getAwsConfig(assumeRoleDuration time.Duration) (*aws.Config, error) {
if cachedAwsConfig != nil {
log.Debug("Using cached AWS config")
return cachedAwsConfig, nil
}

session, err := awsSession.NewSessionWithOptions(options)
log.Debugf("Creating new AWS config (assumeRoleDuration=%s)", assumeRoleDuration)
_config, err := awsConfig.LoadDefaultConfig(
context.TODO(),
awsConfig.WithSharedConfigProfile(tgfConfig.tgf.AwsProfile),
awsConfig.WithAssumeRoleCredentialOptions(func(o *stscreds.AssumeRoleOptions) {
o.TokenProvider = stscreds.StdinTokenProvider
if assumeRoleDuration > 0 {
o.Duration = assumeRoleDuration
}
}),
)

if err == nil {
// We must get the current credentials before verifying the expiration
_, err = session.Config.Credentials.Get()
if err != nil {
return nil, err
}
config := &_config

log.Debug("Fetching credentials for current AWS config")
creds, err := config.Credentials.Retrieve(context.TODO())
if err != nil {
return session, err
return nil, err
}

expiration, _ := session.Config.Credentials.ExpiresAt()
if duration := time.Until(expiration).Round(time.Minute); duration > 0 && duration < 55*time.Minute {
// The duration is less that 1 hour, we try to extend the session
expiresIn := time.Until(creds.Expires)
if creds.CanExpire && expiresIn < (1*time.Hour) {
newDuration := guessAwsMaxAssumeRoleDuration(*config)

// We try to find the maximum role session duration allowed (but not complain if not successful)
maxDuration := int64(3600)
roleRegex := regexp.MustCompile(".*:assumed-role/(.*)/.*")
if identity, err := sts.New(session).GetCallerIdentity(&sts.GetCallerIdentityInput{}); err == nil {
if matches := roleRegex.FindStringSubmatch(*identity.Arn); len(matches) > 0 {
if role, err := iam.New(session).GetRole(&iam.GetRoleInput{RoleName: &matches[1]}); err == nil {
maxDuration = *role.Role.MaxSessionDuration
}
}
}
var profile string
if profile = config.tgf.AwsProfile; profile == "" {
if profile = os.Getenv("AWS_PROFILE"); profile == "" {
profile = "default"
}
}
if askedForMfa {
log.Warningf("Your AWS configuration is set to expire your session in %v. This timeout could not be automatically extended due to the session's MFA",
duration)
} else {
session, err = config.getAwsSession(maxDuration)
log.Warningf("Your AWS configuration is set to expire your session in %v (automatically extended to %v)",
duration,
time.Duration(maxDuration)*time.Second)
log.Warningf(
"Credentials for current AWS session are set to expire in less than one hour (%s). Will extend to %s.",
expiresIn,
newDuration,
)

log.Warningf(
color.WhiteString("You should consider defining %s in your AWS config profile %s"),
color.HiBlueString("duration_seconds = %d", newDuration/time.Second),
color.HiBlueString(getPrettyAwsProfileName(*tgfConfig)),
)

shortConfig := config
config, err = tgfConfig.getAwsConfig(newDuration)
if err != nil {
log.Warning("Failed to extend current AWS session, will use the current short duration.", err)
config = shortConfig
}
}

log.Debug("Caching newly created AWS config for future calls")
cachedAwsConfig = config

log.Warningf(color.WhiteString("You should consider defining %s in your AWS config profile %s"),
color.HiBlueString("duration_seconds = %d", maxDuration), color.HiBlueString(profile))
return config, nil
}

func guessAwsMaxAssumeRoleDuration(awsConfig aws.Config) time.Duration {
fallback := 1 * time.Hour
log.Debugf("Trying to figure out the max duration of an AWS assume role operation (fallback=%s)", fallback)

roleRegex := regexp.MustCompile(".*:assumed-role/(.*)/.*")

identity, err := sts.NewFromConfig(awsConfig).GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{})
if err != nil {
log.Debug("Failed, using fallback:", err)
return fallback
}
if err == nil {
cachedSession = session

matches := roleRegex.FindStringSubmatch(*identity.Arn)
if len(matches) == 0 {
log.Debug("Failed, using fallback: Current role is not an assumed role")
return fallback
}
return session, err

role, err := iam.NewFromConfig(awsConfig).GetRole(
context.TODO(),
&iam.GetRoleInput{
RoleName: &matches[1],
},
)
if err != nil {
log.Debug("Failed, using fallback:", err)
return fallback
}

maxDuration := time.Duration(*role.Role.MaxSessionDuration) * time.Second
log.Debugf("Max duration for current role (%s) is %s", *role.Role.Arn, maxDuration)
return maxDuration
}

func getPrettyAwsProfileName(tgfConfig TGFConfig) string {
if profile := tgfConfig.tgf.AwsProfile; profile != "" {
return profile
}

if profile := os.Getenv("AWS_PROFILE"); profile != "" {
return profile
}

return "default"
}

// InitAWS tries to open an AWS session and init AWS environment variable on success
func (config *TGFConfig) InitAWS() error {
if config.tgf.AwsProfile == "" && os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_PROFILE") != "" {
log.Warning("You set both AWS_ACCESS_KEY_ID and AWS_PROFILE, AWS_PROFILE will be ignored")
}
session, err := config.getAwsSession(0)
awsConfig, err := config.getAwsConfig(0)
if err != nil {
return err
}
creds, err := session.Config.Credentials.Get()
creds, err := awsConfig.Credentials.Retrieve(context.TODO())
if err != nil {
return err
}
Expand All @@ -249,7 +286,7 @@ func (config *TGFConfig) InitAWS() error {
"AWS_ACCESS_KEY_ID": creds.AccessKeyID,
"AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey,
"AWS_SESSION_TOKEN": creds.SessionToken,
"AWS_REGION": *session.Config.Region,
"AWS_REGION": awsConfig.Region,
} {
os.Setenv(key, value)
if !config.tgf.ConfigDump {
Expand Down Expand Up @@ -385,15 +422,15 @@ func (config *TGFConfig) validate() (errors []error) {

if config.RecommendedTGFVersion != "" && version != locallyBuilt {
if valid, err := CheckVersionRange(version, config.RecommendedTGFVersion); err != nil {
errors = append(errors, fmt.Errorf("Unable to check recommended tgf version %s vs %s: %v", version, config.RecommendedTGFVersion, err))
errors = append(errors, fmt.Errorf("unable to check recommended tgf version %s vs %s: %v", version, config.RecommendedTGFVersion, err))
} else if !valid {
errors = append(errors, ConfigWarning(fmt.Sprintf("TGF v%s does not meet the recommended version range %s", version, config.RecommendedTGFVersion)))
}
}

if config.RequiredVersionRange != "" && config.ImageVersion != nil && *config.ImageVersion != "" && reVersion.MatchString(*config.ImageVersion) {
if valid, err := CheckVersionRange(*config.ImageVersion, config.RequiredVersionRange); err != nil {
errors = append(errors, fmt.Errorf("Unable to check recommended image version %s vs %s: %v", *config.ImageVersion, config.RequiredVersionRange, err))
errors = append(errors, fmt.Errorf("unable to check recommended image version %s vs %s: %v", *config.ImageVersion, config.RequiredVersionRange, err))
return
} else if !valid {
errors = append(errors, VersionMistmatchError(fmt.Sprintf("Image %s does not meet the required version range %s", config.GetImageName(), config.RequiredVersionRange)))
Expand All @@ -403,7 +440,7 @@ func (config *TGFConfig) validate() (errors []error) {

if config.RecommendedImageVersion != "" && config.ImageVersion != nil && *config.ImageVersion != "" && reVersion.MatchString(*config.ImageVersion) {
if valid, err := CheckVersionRange(*config.ImageVersion, config.RecommendedImageVersion); err != nil {
errors = append(errors, fmt.Errorf("Unable to check recommended image version %s vs %s: %v", *config.ImageVersion, config.RecommendedImageVersion, err))
errors = append(errors, fmt.Errorf("unable to check recommended image version %s vs %s: %v", *config.ImageVersion, config.RecommendedImageVersion, err))
} else if !valid {
errors = append(errors, ConfigWarning(fmt.Sprintf("Image %s does not meet the recommended version range %s", config.GetImageName(), config.RecommendedImageVersion)))
}
Expand Down Expand Up @@ -487,17 +524,17 @@ func (config *TGFConfig) ParseAliases() {

func (config *TGFConfig) readSSMParameterStore(ssmParameterFolder string) map[string]string {
values := make(map[string]string)
session, err := config.getAwsSession(0)
log.Debugf("Reading configuration from SSM %s in %s", ssmParameterFolder, *session.Config.Region)
awsConfig, err := config.getAwsConfig(0)
log.Debugf("Reading configuration from SSM %s in %s", ssmParameterFolder, awsConfig.Region)
if err != nil {
log.Warningf("Caught an error while creating an AWS session: %v", err)
return values
}
svc := ssm.New(session)
response, err := svc.GetParametersByPath(&ssm.GetParametersByPathInput{
svc := ssm.NewFromConfig(*awsConfig)
response, err := svc.GetParametersByPath(context.TODO(), &ssm.GetParametersByPathInput{
Path: aws.String(ssmParameterFolder),
Recursive: aws.Bool(true),
WithDecryption: aws.Bool(true),
Recursive: true,
WithDecryption: true,
})
if err != nil {
log.Warningf("Caught an error while reading from `%s` in SSM: %v", ssmParameterFolder, err)
Expand Down Expand Up @@ -538,7 +575,7 @@ func (config *TGFConfig) findRemoteConfigFiles(location, files string) []string
if err == nil {
_, err = os.Stat(destConfigPath)
if os.IsNotExist(err) {
err = errors.New("Config file was not found at the source")
err = errors.New("config file was not found at the source")
}
}

Expand Down
28 changes: 16 additions & 12 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io/ioutil"
"math/rand"
Expand All @@ -16,9 +17,10 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ssm"
"github.com/aws/aws-sdk-go-v2/service/ssm/types"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -445,11 +447,11 @@ func writeSSMConfig(parameterFolder, parameterKey, parameterValue string) {
putParameterInput := &ssm.PutParameterInput{
Name: aws.String(fullParameterKey),
Value: aws.String(parameterValue),
Overwrite: aws.Bool(true),
Type: aws.String(ssm.ParameterTypeString),
Overwrite: true,
Type: types.ParameterTypeString,
}

if _, err := client.PutParameter(putParameterInput); err != nil {
if _, err := client.PutParameter(context.TODO(), putParameterInput); err != nil {
panic(err)
}
}
Expand All @@ -462,16 +464,18 @@ func deleteSSMConfig(parameterFolder, parameterKey string) {
Name: aws.String(fullParameterKey),
}

if _, err := client.DeleteParameter(deleteParameterInput); err != nil {
if _, err := client.DeleteParameter(context.TODO(), deleteParameterInput); err != nil {
panic(err)
}
}

func getSSMClient() *ssm.SSM {
awsSession := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
}))
return ssm.New(awsSession, &aws.Config{Region: aws.String("us-east-1")})
func getSSMClient() *ssm.Client {
awsConfig, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion("us-east-1"))
if err != nil {
panic(err)
}

return ssm.NewFromConfig(awsConfig)
}

func randInt() int {
Expand Down
Loading

0 comments on commit c2f69df

Please sign in to comment.