From 60801010475969b411d192574c9a47b5acc81bb9 Mon Sep 17 00:00:00 2001 From: vijtrip2 <62766875+vijtrip2@users.noreply.github.com> Date: Tue, 28 Sep 2021 12:47:28 -0700 Subject: [PATCH] Only use sts GetCallerIdentity to find AWS AccountID --- pkg/config/config.go | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/pkg/config/config.go b/pkg/config/config.go index 56f1c39..16a9370 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -15,6 +15,7 @@ package config import ( "errors" + "fmt" "net/url" "github.com/aws/aws-sdk-go/aws/session" @@ -30,7 +31,6 @@ const ( flagEnableLeaderElection = "enable-leader-election" flagMetricAddr = "metrics-addr" flagEnableDevLogging = "enable-development-logging" - flagAWSAccountID = "aws-account-id" flagAWSRegion = "aws-region" flagAWSEndpointURL = "aws-endpoint-url" flagLogLevel = "log-level" @@ -38,6 +38,7 @@ const ( flagWatchNamespace = "watch-namespace" flagEnableWebhookServer = "enable-webhook-server" flagWebhookServerAddr = "webhook-server-addr" + envVarAWSRegion = "AWS_REGION" ) // Config contains configuration otpions for ACK service controllers @@ -84,14 +85,9 @@ func (cfg *Config) BindFlags() { "Configures the logger to use a Zap development config (encoder=consoleEncoder,logLevel=Debug,stackTraceLevel=Warn, no sampling), "+ "otherwise a Zap production config will be used (encoder=jsonEncoder,logLevel=Info,stackTraceLevel=Error), sampling).", ) - flag.StringVar( - &cfg.AccountID, flagAWSAccountID, - envutil.WithDefault("AWS_ACCOUNT_ID", ""), - "The AWS Account ID in which the service controller will create resources", - ) flag.StringVar( &cfg.Region, flagAWSRegion, - envutil.WithDefault("AWS_REGION", ""), + envutil.WithDefault(envVarAWSRegion, ""), "The AWS Region in which the service controller will create its resources", ) flag.StringVar( @@ -137,27 +133,29 @@ func (cfg *Config) SetupLogger() { ctrlrt.SetLogger(zap.New(zap.UseFlagOptions(&zapOptions))) } -// PopulateAccountIdIfMissing uses sts GetCallerIdentity API to find -// AWS AccountId when Config.AccountId is empty -func (cfg *Config) PopulateAccountIdIfMissing() error { - if cfg.AccountID == "" { - // use sts to find AWS AccountId - session := session.Must(session.NewSession()) - client := sts.New(session) - res, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{}) - if err == nil { - cfg.AccountID = *res.Account - } - return err +// SetAWSAccountID uses sts GetCallerIdentity API to find AWS AccountId and set +// in Config +func (cfg *Config) SetAWSAccountID() error { + // use sts to find AWS AccountId + session, err := session.NewSession() + if err != nil { + return fmt.Errorf("unable to create session: %v", err) } + client := sts.New(session) + res, err := client.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err != nil { + return fmt.Errorf("unable to get caller identity: %v", err) + } + cfg.AccountID = *res.Account return nil } // Validate ensures the options are valid func (cfg *Config) Validate() error { - if cfg.AccountID == "" { - return errors.New("unable to start service controller as account ID is missing. Please pass --aws-account-id flag or set AWS_ACCOUNT_ID environment variable") + if err := cfg.SetAWSAccountID(); err != nil { + return errors.New("unable to determine account ID. Please make sure AWS credentials are setup in controller pod") } + if cfg.Region == "" { return errors.New("unable to start service controller as AWS region is missing. Please pass --aws-region flag or set AWS_REGION environment variable") }