Skip to content

Commit

Permalink
Only use sts GetCallerIdentity to find AWS AccountID
Browse files Browse the repository at this point in the history
  • Loading branch information
vijtrip2 committed Sep 30, 2021
1 parent 20deb45 commit 6080101
Showing 1 changed file with 19 additions and 21 deletions.
40 changes: 19 additions & 21 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package config

import (
"errors"
"fmt"
"net/url"

"github.com/aws/aws-sdk-go/aws/session"
Expand All @@ -30,14 +31,14 @@ const (
flagEnableLeaderElection = "enable-leader-election"
flagMetricAddr = "metrics-addr"
flagEnableDevLogging = "enable-development-logging"
flagAWSAccountID = "aws-account-id"
flagAWSRegion = "aws-region"
flagAWSEndpointURL = "aws-endpoint-url"
flagLogLevel = "log-level"
flagResourceTags = "resource-tags"
flagWatchNamespace = "watch-namespace"
flagEnableWebhookServer = "enable-webhook-server"
flagWebhookServerAddr = "webhook-server-addr"
envVarAWSRegion = "AWS_REGION"
)

// Config contains configuration otpions for ACK service controllers
Expand Down Expand Up @@ -84,14 +85,9 @@ func (cfg *Config) BindFlags() {
"Configures the logger to use a Zap development config (encoder=consoleEncoder,logLevel=Debug,stackTraceLevel=Warn, no sampling), "+
"otherwise a Zap production config will be used (encoder=jsonEncoder,logLevel=Info,stackTraceLevel=Error), sampling).",
)
flag.StringVar(
&cfg.AccountID, flagAWSAccountID,
envutil.WithDefault("AWS_ACCOUNT_ID", ""),
"The AWS Account ID in which the service controller will create resources",
)
flag.StringVar(
&cfg.Region, flagAWSRegion,
envutil.WithDefault("AWS_REGION", ""),
envutil.WithDefault(envVarAWSRegion, ""),
"The AWS Region in which the service controller will create its resources",
)
flag.StringVar(
Expand Down Expand Up @@ -137,27 +133,29 @@ func (cfg *Config) SetupLogger() {
ctrlrt.SetLogger(zap.New(zap.UseFlagOptions(&zapOptions)))
}

// PopulateAccountIdIfMissing uses sts GetCallerIdentity API to find
// AWS AccountId when Config.AccountId is empty
func (cfg *Config) PopulateAccountIdIfMissing() error {
if cfg.AccountID == "" {
// use sts to find AWS AccountId
session := session.Must(session.NewSession())
client := sts.New(session)
res, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err == nil {
cfg.AccountID = *res.Account
}
return err
// SetAWSAccountID uses sts GetCallerIdentity API to find AWS AccountId and set
// in Config
func (cfg *Config) SetAWSAccountID() error {
// use sts to find AWS AccountId
session, err := session.NewSession()
if err != nil {
return fmt.Errorf("unable to create session: %v", err)
}
client := sts.New(session)
res, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
if err != nil {
return fmt.Errorf("unable to get caller identity: %v", err)
}
cfg.AccountID = *res.Account
return nil
}

// Validate ensures the options are valid
func (cfg *Config) Validate() error {
if cfg.AccountID == "" {
return errors.New("unable to start service controller as account ID is missing. Please pass --aws-account-id flag or set AWS_ACCOUNT_ID environment variable")
if err := cfg.SetAWSAccountID(); err != nil {
return errors.New("unable to determine account ID. Please make sure AWS credentials are setup in controller pod")
}

if cfg.Region == "" {
return errors.New("unable to start service controller as AWS region is missing. Please pass --aws-region flag or set AWS_REGION environment variable")
}
Expand Down

0 comments on commit 6080101

Please sign in to comment.