diff --git a/README.md b/README.md index f9527ce3..ca3356fe 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ elasticsearch_exporter --help ``` | Argument | Introduced in Version | Description | Default | -| -------- | --------------------- | ----------- | ----------- | +| ----------------------- | --------------------- | ----------- | ----------- | | es.uri | 1.0.2 | Address (host and port) of the Elasticsearch node we should connect to. This could be a local node (`localhost:9200`, for instance), or the address of a remote Elasticsearch server. When basic auth is needed, specify as: `://:@:`. E.G., `http://admin:pass@localhost:9200`. Special characters in the user credentials need to be URL-encoded. | | | es.all | 1.0.2 | If true, query stats for all nodes in the cluster, rather than just the node we connect to. | false | | es.cluster_settings | 1.1.0rc1 | If true, query stats for cluster settings. | false | @@ -70,6 +70,7 @@ elasticsearch_exporter --help | web.listen-address | 1.0.2 | Address to listen on for web interface and telemetry. | :9114 | | web.telemetry-path | 1.0.2 | Path under which to expose metrics. | /metrics | | aws.region | 1.5.0 | Region for AWS elasticsearch | | +| aws.role-arn | 1.6.0 | Role ARN of an IAM role to assume. | | | version | 1.0.2 | Show version info on stdout and exit. | | Commandline parameters start with a single `-` for versions less than `1.1.0rc1`. diff --git a/go.mod b/go.mod index 5c003b24..4590801f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.19 require ( github.com/aws/aws-sdk-go-v2 v1.17.3 github.com/aws/aws-sdk-go-v2/config v1.18.7 + github.com/aws/aws-sdk-go-v2/credentials v1.13.7 + github.com/aws/aws-sdk-go-v2/service/sts v1.17.7 github.com/blang/semver/v4 v4.0.0 github.com/go-kit/log v0.2.1 github.com/imdario/mergo v0.3.13 @@ -17,7 +19,6 @@ require ( require ( github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751 // indirect github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.13.7 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.21 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21 // indirect @@ -25,7 +26,6 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.21 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.11.28 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.13.11 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.17.7 // indirect github.com/aws/smithy-go v1.13.5 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect diff --git a/main.go b/main.go index 857504d9..5680950f 100644 --- a/main.go +++ b/main.go @@ -125,6 +125,9 @@ func main() { awsRegion = kingpin.Flag("aws.region", "Region for AWS elasticsearch"). Default("").String() + awsRoleArn = kingpin.Flag("aws.role-arn", + "Role ARN of an IAM role to assume."). + Default("").String() ) kingpin.Version(version.Print(name)) @@ -174,7 +177,7 @@ func main() { } if *awsRegion != "" { - httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, logger) + httpClient.Transport, err = roundtripper.NewAWSSigningTransport(httpTransport, *awsRegion, *awsRoleArn, logger) if err != nil { _ = level.Error(logger).Log("msg", "failed to create AWS transport", "err", err) os.Exit(1) diff --git a/pkg/roundtripper/roundtripper.go b/pkg/roundtripper/roundtripper.go index 824b77ed..7c55005a 100644 --- a/pkg/roundtripper/roundtripper.go +++ b/pkg/roundtripper/roundtripper.go @@ -26,6 +26,8 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/go-kit/log" "github.com/go-kit/log/level" ) @@ -36,21 +38,28 @@ const ( type AWSSigningTransport struct { t http.RoundTripper - creds aws.Credentials + creds aws.CredentialsProvider region string log log.Logger } -func NewAWSSigningTransport(transport http.RoundTripper, region string, log log.Logger) (*AWSSigningTransport, error) { +func NewAWSSigningTransport(transport http.RoundTripper, region string, roleArn string, log log.Logger) (*AWSSigningTransport, error) { cfg, err := config.LoadDefaultConfig(context.Background(), config.WithRegion(region)) if err != nil { - _ = level.Error(log).Log("msg", "fail to load aws default config", "err", err) + _ = level.Error(log).Log("msg", "failed to load aws default config", "err", err) return nil, err } - creds, err := cfg.Credentials.Retrieve(context.Background()) + if roleArn != "" { + cfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), roleArn) + } + + creds := aws.NewCredentialsCache(cfg.Credentials) + // Run a single fetch credentials operation to ensure that the credentials + // are valid before returning the transport. + _, err = cfg.Credentials.Retrieve(context.Background()) if err != nil { - _ = level.Error(log).Log("msg", "fail to retrive aws credentials", "err", err) + _ = level.Error(log).Log("msg", "failed to retrive aws credentials", "err", err) return nil, err } @@ -66,13 +75,20 @@ func (a *AWSSigningTransport) RoundTrip(req *http.Request) (*http.Response, erro signer := v4.NewSigner() payloadHash, newReader, err := hashPayload(req.Body) if err != nil { - _ = level.Error(a.log).Log("msg", "fail to hash request body", "err", err) + _ = level.Error(a.log).Log("msg", "failed to hash request body", "err", err) return nil, err } req.Body = newReader - err = signer.SignHTTP(context.Background(), a.creds, req, payloadHash, service, a.region, time.Now()) + + creds, err := a.creds.Retrieve(context.Background()) + if err != nil { + _ = level.Error(a.log).Log("msg", "failed to retrieve aws credentials", "err", err) + return nil, err + } + + err = signer.SignHTTP(context.Background(), creds, req, payloadHash, service, a.region, time.Now()) if err != nil { - _ = level.Error(a.log).Log("msg", "fail to sign request body", "err", err) + _ = level.Error(a.log).Log("msg", "failed to sign request body", "err", err) return nil, err } return a.t.RoundTrip(req)