Skip to content

Commit

Permalink
Add support to assume an AWS role and renew expired credentials
Browse files Browse the repository at this point in the history
Co-authored-by: Christoph Burmeister <christoph.burmeister@idealo.de>
  • Loading branch information
steveteuber and chburmeister committed Nov 18, 2022
1 parent 416fa22 commit 612ab14
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 16 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ elasticsearch_exporter --help
```

| Argument | Introduced in Version | Description | Default |
| -------- | --------------------- | ----------- | ----------- |
| ----------------------- | --------------------- | ----------- | ----------- |
| es.uri | 1.0.2 | Address (host and port) of the Elasticsearch node we should connect to. This could be a local node (`localhost:9200`, for instance), or the address of a remote Elasticsearch server. When basic auth is needed, specify as: `<proto>://<user>:<password>@<host>:<port>`. E.G., `http://admin:pass@localhost:9200`. Special characters in the user credentials need to be URL-encoded. | <http://localhost:9200> |
| es.all | 1.0.2 | If true, query stats for all nodes in the cluster, rather than just the node we connect to. | false |
| es.cluster_settings | 1.1.0rc1 | If true, query stats for cluster settings. | false |
Expand All @@ -69,7 +69,8 @@ elasticsearch_exporter --help
| es.ssl-skip-verify | 1.0.4rc1 | Skip SSL verification when connecting to Elasticsearch. | false |
| web.listen-address | 1.0.2 | Address to listen on for web interface and telemetry. | :9114 |
| web.telemetry-path | 1.0.2 | Path under which to expose metrics. | /metrics |
| aws.region | 1.5.0 | Region for AWS elasticsearch | |
| aws.region | 1.5.0 | AWS region to send STS requests to. | |
| aws.role-arn | 1.6.0 | AWS role ARN of an IAM role to assume. | |
| version | 1.0.2 | Show version info on stdout and exit. | |

Commandline parameters start with a single `-` for versions less than `1.1.0rc1`.
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ go 1.19
require (
github.com/aws/aws-sdk-go-v2 v1.16.16
github.com/aws/aws-sdk-go-v2/config v1.17.8
github.com/aws/aws-sdk-go-v2/credentials v1.12.21
github.com/aws/aws-sdk-go-v2/service/sts v1.16.19
github.com/blang/semver/v4 v4.0.0
github.com/go-kit/log v0.2.1
github.com/imdario/mergo v0.3.13
Expand All @@ -17,15 +19,13 @@ require (
require (
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.12.21 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.17 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.3.24 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.17 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.11.23 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.6 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.16.19 // indirect
github.com/aws/smithy-go v1.13.3 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect
Expand Down
7 changes: 5 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,10 @@ func main() {
"Sets the log output. Valid outputs are stdout and stderr").
Default("stdout").String()
awsRegion = kingpin.Flag("aws.region",
"Region for AWS elasticsearch").
"AWS region to send STS requests to.").
Default("").String()
awsRoleArn = kingpin.Flag("aws.role-arn",
"AWS role ARN of an IAM role to assume.").
Default("").String()
)

Expand Down Expand Up @@ -171,7 +174,7 @@ func main() {
}

if *awsRegion != "" {
httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, logger)
httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, *awsRoleArn, logger)
if err != nil {
_ = level.Error(logger).Log("msg", "failed to create AWS transport", "err", err)
os.Exit(1)
Expand Down
29 changes: 19 additions & 10 deletions pkg/roundtripper/roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
)
Expand All @@ -36,24 +38,24 @@ const (

type AWSSigningTransport struct {
t http.RoundTripper
creds aws.Credentials
creds aws.CredentialsProvider
region string
log log.Logger
}

func NewAWSSigningTransport(transport http.RoundTripper, region string, log log.Logger) (*AWSSigningTransport, error) {
func NewAWSSigningTransport(transport http.RoundTripper, region string, roleArn string, log log.Logger) (*AWSSigningTransport, error) {
cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region))
if err != nil {
_ = level.Error(log).Log("msg", "fail to load aws default config", "err", err)
_ = level.Error(log).Log("msg", "failed to load aws default config", "err", err)
return nil, err
}

creds, err := cfg.Credentials.Retrieve(context.Background())
if err != nil {
_ = level.Error(log).Log("msg", "fail to retrive aws credentials", "err", err)
return nil, err
if roleArn != "" {
cfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), roleArn)
}

creds := aws.NewCredentialsCache(cfg.Credentials)

return &AWSSigningTransport{
t: transport,
region: region,
Expand All @@ -66,13 +68,20 @@ func (a *AWSSigningTransport) RoundTrip(req *http.Request) (*http.Response, erro
signer := v4.NewSigner()
payloadHash, newReader, err := hashPayload(req.Body)
if err != nil {
_ = level.Error(a.log).Log("msg", "fail to hash request body", "err", err)
_ = level.Error(a.log).Log("msg", "failed to hash request body", "err", err)
return nil, err
}
req.Body = newReader
err = signer.SignHTTP(context.Background(), a.creds, req, payloadHash, service, a.region, time.Now())

creds, err := a.creds.Retrieve(context.Background())
if err != nil {
_ = level.Error(a.log).Log("msg", "failed to retrieve aws credentials", "err", err)
return nil, err
}

err = signer.SignHTTP(context.Background(), creds, req, payloadHash, service, a.region, time.Now())
if err != nil {
_ = level.Error(a.log).Log("msg", "fail to sign request body", "err", err)
_ = level.Error(a.log).Log("msg", "failed to sign request body", "err", err)
return nil, err
}
return a.t.RoundTrip(req)
Expand Down

0 comments on commit 612ab14

Please sign in to comment.