diff --git a/README.md b/README.md index f9527ce3..5058045a 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ elasticsearch_exporter --help ``` | Argument | Introduced in Version | Description | Default | -| -------- | --------------------- | ----------- | ----------- | +| ----------------------- | --------------------- | ----------- | ----------- | | es.uri | 1.0.2 | Address (host and port) of the Elasticsearch node we should connect to. This could be a local node (`localhost:9200`, for instance), or the address of a remote Elasticsearch server. When basic auth is needed, specify as: `://:@:`. E.G., `http://admin:pass@localhost:9200`. Special characters in the user credentials need to be URL-encoded. | | | es.all | 1.0.2 | If true, query stats for all nodes in the cluster, rather than just the node we connect to. | false | | es.cluster_settings | 1.1.0rc1 | If true, query stats for cluster settings. | false | @@ -69,7 +69,8 @@ elasticsearch_exporter --help | es.ssl-skip-verify | 1.0.4rc1 | Skip SSL verification when connecting to Elasticsearch. | false | | web.listen-address | 1.0.2 | Address to listen on for web interface and telemetry. | :9114 | | web.telemetry-path | 1.0.2 | Path under which to expose metrics. | /metrics | -| aws.region | 1.5.0 | Region for AWS elasticsearch | | +| aws.region | 1.5.0 | AWS region to send STS requests to. | | +| aws.role-arn | 1.6.0 | AWS role ARN of an IAM role to assume. | | | version | 1.0.2 | Show version info on stdout and exit. | | Commandline parameters start with a single `-` for versions less than `1.1.0rc1`. diff --git a/go.mod b/go.mod index 1924e0ff..1f90b958 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.19 require ( github.com/aws/aws-sdk-go-v2 v1.16.16 github.com/aws/aws-sdk-go-v2/config v1.17.8 + github.com/aws/aws-sdk-go-v2/credentials v1.12.21 + github.com/aws/aws-sdk-go-v2/service/sts v1.16.19 github.com/blang/semver/v4 v4.0.0 github.com/go-kit/log v0.2.1 github.com/imdario/mergo v0.3.13 @@ -17,7 +19,6 @@ require ( require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.12.21 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.17 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.23 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.17 // indirect @@ -25,7 +26,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.17 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.11.23 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.6 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.16.19 // indirect github.com/aws/smithy-go v1.13.3 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect diff --git a/main.go b/main.go index 57db3aaa..aced8513 100644 --- a/main.go +++ b/main.go @@ -120,7 +120,10 @@ func main() { "Sets the log output. Valid outputs are stdout and stderr"). Default("stdout").String() awsRegion = kingpin.Flag("aws.region", - "Region for AWS elasticsearch"). + "AWS region to send STS requests to."). + Default("").String() + awsRoleArn = kingpin.Flag("aws.role-arn", + "AWS role ARN of an IAM role to assume."). Default("").String() ) @@ -171,7 +174,7 @@ func main() { } if *awsRegion != "" { - httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, logger) + httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, *awsRoleArn, logger) if err != nil { _ = level.Error(logger).Log("msg", "failed to create AWS transport", "err", err) os.Exit(1) diff --git a/pkg/roundtripper/roundtripper.go b/pkg/roundtripper/roundtripper.go index 824b77ed..10479cb9 100644 --- a/pkg/roundtripper/roundtripper.go +++ b/pkg/roundtripper/roundtripper.go @@ -26,6 +26,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/go-kit/log" "github.com/go-kit/log/level" ) @@ -36,24 +38,24 @@ const ( type AWSSigningTransport struct { t http.RoundTripper - creds aws.Credentials + creds aws.CredentialsProvider region string log log.Logger } -func NewAWSSigningTransport(transport http.RoundTripper, region string, log log.Logger) (*AWSSigningTransport, error) { +func NewAWSSigningTransport(transport http.RoundTripper, region string, roleArn string, log log.Logger) (*AWSSigningTransport, error) { cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region)) if err != nil { - _ = level.Error(log).Log("msg", "fail to load aws default config", "err", err) + _ = level.Error(log).Log("msg", "failed to load aws default config", "err", err) return nil, err } - creds, err := cfg.Credentials.Retrieve(context.Background()) - if err != nil { - _ = level.Error(log).Log("msg", "fail to retrive aws credentials", "err", err) - return nil, err + if roleArn != "" { + cfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), roleArn) } + creds := aws.NewCredentialsCache(cfg.Credentials) + return &AWSSigningTransport{ t: transport, region: region, @@ -66,13 +68,20 @@ func (a *AWSSigningTransport) RoundTrip(req *http.Request) (*http.Response, erro signer := v4.NewSigner() payloadHash, newReader, err := hashPayload(req.Body) if err != nil { - _ = level.Error(a.log).Log("msg", "fail to hash request body", "err", err) + _ = level.Error(a.log).Log("msg", "failed to hash request body", "err", err) return nil, err } req.Body = newReader - err = signer.SignHTTP(context.Background(), a.creds, req, payloadHash, service, a.region, time.Now()) + + creds, err := a.creds.Retrieve(context.Background()) + if err != nil { + _ = level.Error(a.log).Log("msg", "failed to retrieve aws credentials", "err", err) + return nil, err + } + + err = signer.SignHTTP(context.Background(), creds, req, payloadHash, service, a.region, time.Now()) if err != nil { - _ = level.Error(a.log).Log("msg", "fail to sign request body", "err", err) + _ = level.Error(a.log).Log("msg", "failed to sign request body", "err", err) return nil, err } return a.t.RoundTrip(req)