Skip to content

Commit

Permalink
support assume role (#130)
Browse files Browse the repository at this point in the history
* support assume role

* fix config
  • Loading branch information
kaplanelad authored Aug 13, 2020
1 parent 5c2f18d commit 0f47880
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 16 deletions.
16 changes: 11 additions & 5 deletions collector/aws/session.go → collector/aws/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package aws
import (
awsClient "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
log "github.com/sirupsen/logrus"
)

// CreateNewSession return new AWS session
func CreateNewSession(accessKey, secretKey, sessionToken, region string) *session.Session {
// CreateAuthConfiguration return aws auth configuration
func CreateAuthConfiguration(accessKey, secretKey, sessionToken, role, region string) (*session.Session, *awsClient.Config) {
var credentialsAWS *credentials.Credentials

// Use separate call for AWS credentials defined in config.yaml
Expand All @@ -18,11 +19,16 @@ func CreateNewSession(accessKey, secretKey, sessionToken, region string) *sessio
credentialsAWS = credentials.NewStaticCredentials(accessKey, secretKey, sessionToken)
}

sess := session.Must(session.NewSession(&awsClient.Config{
config := &awsClient.Config{
Region: &region,
Credentials: credentialsAWS,
}))
}

return sess
sess := session.Must(session.NewSession(config))

if role != "" {
log.WithField("role", role).Info("assume role provided")
config.Credentials = stscreds.NewCredentials(sess, role, func(p *stscreds.AssumeRoleProvider) {})
}
return sess, config
}
3 changes: 2 additions & 1 deletion collector/aws/common/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"finala/collector/aws/pricing"
"finala/collector/config"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
)
Expand All @@ -25,7 +26,7 @@ type AWSManager interface {
GetCloudWatchClient() *cloudwatch.CloudwatchManager
GetPricingClient() *pricing.PricingManager
GetRegion() string
GetSession() *session.Session
GetSession() (*session.Session, *aws.Config)
GetAccountIdentity() *sts.GetCallerIdentityOutput
SetGlobal(resourceName collector.ResourceIdentifier)
IsGlobalSet(resourceName collector.ResourceIdentifier) bool
Expand Down
15 changes: 9 additions & 6 deletions collector/aws/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"finala/collector/config"
"fmt"

awsClient "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
awsCloudwatch "github.com/aws/aws-sdk-go/service/cloudwatch"
awsPricing "github.com/aws/aws-sdk-go/service/pricing"
Expand All @@ -20,7 +21,7 @@ type DetectorDescriptor interface {
GetCloudWatchClient() *cloudwatch.CloudwatchManager
GetPricingClient() *pricing.PricingManager
GetRegion() string
GetSession() *session.Session
GetSession() (*session.Session, *awsClient.Config)
GetAccountIdentity() *sts.GetCallerIdentityOutput
}

Expand All @@ -35,6 +36,7 @@ type DetectorManager struct {
cloudWatchClient *cloudwatch.CloudwatchManager
pricing *pricing.PricingManager
session *session.Session
awsConfig *awsClient.Config
accountIdentity *sts.GetCallerIdentityOutput
region string
global map[string]struct{}
Expand All @@ -43,11 +45,11 @@ type DetectorManager struct {
// NewDetectorManager create new instance of detector manager
func NewDetectorManager(collector collector.CollectorDescriber, account config.AWSAccount, stsManager *STSManager, global map[string]struct{}, region string) *DetectorManager {

priceSession := CreateNewSession(account.AccessKey, account.SecretKey, account.SessionToken, defaultRegionPrice)
priceSession, _ := CreateAuthConfiguration(account.AccessKey, account.SecretKey, account.SessionToken, account.Role, defaultRegionPrice)
pricingManager := pricing.NewPricingManager(awsPricing.New(priceSession), defaultRegionPrice)

regionSession := CreateNewSession(account.AccessKey, account.SecretKey, account.SessionToken, region)
cloudWatchCLient := cloudwatch.NewCloudWatchManager(awsCloudwatch.New(regionSession))
regionSession, regionConfig := CreateAuthConfiguration(account.AccessKey, account.SecretKey, account.SessionToken, account.Role, region)
cloudWatchCLient := cloudwatch.NewCloudWatchManager(awsCloudwatch.New(regionSession, regionConfig))

callerIdentityOutput, _ := stsManager.client.GetCallerIdentity(&sts.GetCallerIdentityInput{})
return &DetectorManager{
Expand All @@ -56,6 +58,7 @@ func NewDetectorManager(collector collector.CollectorDescriber, account config.A
pricing: pricingManager,
region: region,
session: regionSession,
awsConfig: regionConfig,
accountIdentity: callerIdentityOutput,
global: global,
}
Expand Down Expand Up @@ -87,8 +90,8 @@ func (dm *DetectorManager) GetRegion() string {
}

// GetSession return the aws session
func (dm *DetectorManager) GetSession() *session.Session {
return dm.session
func (dm *DetectorManager) GetSession() (*session.Session, *awsClient.Config) {
return dm.session, dm.awsConfig
}

// GetAccountIdentity return the caller identity
Expand Down
4 changes: 2 additions & 2 deletions collector/aws/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ func (app *Analyze) All() {

for _, account := range app.awsAccounts {

globalsession := CreateNewSession(account.AccessKey, account.SecretKey, account.SessionToken, "")
stsManager := NewSTSManager(sts.New(globalsession))
globalsession, globalConfig := CreateAuthConfiguration(account.AccessKey, account.SecretKey, account.SessionToken, account.Role, "")
stsManager := NewSTSManager(sts.New(globalsession, globalConfig))

for _, region := range account.Regions {
resourcesDetection := NewDetectorManager(app.cl, account, stsManager, app.global, region)
Expand Down
5 changes: 3 additions & 2 deletions collector/aws/testutils/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"finala/collector/aws/pricing"
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
)
Expand Down Expand Up @@ -57,8 +58,8 @@ func (dm *MockAWSManager) GetRegion() string {
return dm.region
}

func (dm *MockAWSManager) GetSession() *session.Session {
return dm.session
func (dm *MockAWSManager) GetSession() (*session.Session, *aws.Config) {
return dm.session, &aws.Config{}
}

func (dm *MockAWSManager) GetAccountIdentity() *sts.GetCallerIdentityOutput {
Expand Down
1 change: 1 addition & 0 deletions collector/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type AWSAccount struct {
Name string `yaml:"name"`
AccessKey string `yaml:"access_key"`
SecretKey string `yaml:"secret_key"`
Role string `yaml:"role"`
SessionToken string `yaml:"session_token"`
Regions []string `yaml:"regions"`
}
Expand Down
1 change: 1 addition & 0 deletions configuration/collector.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ providers:
- name: <account_name>
# access_key: <access_key>
# secret_key: <secret_key>
# role:
regions:
- us-east-1
- us-west-2
Expand Down

0 comments on commit 0f47880

Please sign in to comment.