Skip to content

Commit

Permalink
Add SigV4 support (#78)
Browse files Browse the repository at this point in the history
Changes:
- Connector:
    - `enable_sigv4_auth` environment variable has been added. If this environment variable is true, basic auth headers will be ignored and the connector will attempt to use local credentials.
- SAM template:
    - `EnableSigV4Auth` parameter has been added, which sets authentication to `AWS_IAM` for the routes (`/write` and `/read`) and passes in the `enable_sigv4_auth` environment variable to the connector Lambda function with the value `true`.
    - Events have been removed from the Lambda function and replaced with routes to configure SigV4 authentication.
    - The `APIGateway` resource has been changed to an `AWS::ApiGatewayV2::Api` with protocol set to `HTTP`. This was done because using API Gateway resources (such as `AWS::ApiGatewayV2::Route` resources) on an `AWS::Serverless::HttpApi` resource has undefined behaviour.
- Documentation:
    - `Launch (SigV4)` links have been added to the DEVELOPER_README, which set the `EnableSigV4Auth` parameter to `true`.

- [x] Integration tests passed (`go test -v ./integration/`).
- [x] Unit tests passed (`go test -tags=unit -cover -v ./timestream ./`).
- [x] TLS tests passed (`go test -v ./integration/tls`).
- [x] Correctness tests passed (`go test -v ./correctness`).
- [x] Stack deployment tested with SigV4 enabled using the OpenTelemetry collector.
- [x] Stack deployment tested with SigV4 disabled using Prometheus for reading and writing.
  • Loading branch information
trevorbonas authored Nov 20, 2024
1 parent 7eb0cd5 commit 0c51013
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 49 deletions.
1 change: 1 addition & 0 deletions configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
maxRetriesConfig = &configuration{flag: "max-retries", envFlag: "max_retries", defaultValue: strconv.Itoa(awsClient.DefaultRetryerMaxNumRetries)}
defaultDatabaseConfig = &configuration{flag: "default-database", envFlag: "default_database", defaultValue: ""}
defaultTableConfig = &configuration{flag: "default-table", envFlag: "default_table", defaultValue: ""}
enableSigV4AuthConfig = &configuration{flag: "enable-sigv4-auth", envFlag: "enable_sigv4_auth", defaultValue: "true"}
listenAddrConfig = &configuration{flag: "web.listen-address", envFlag: "", defaultValue: ":9201"}
telemetryPathConfig = &configuration{flag: "web.telemetry-path", envFlag: "", defaultValue: "/metrics"}
failOnLabelConfig = &configuration{flag: "fail-on-long-label", envFlag: "fail_on_long_label", defaultValue: "false"}
Expand Down
34 changes: 28 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/go-kit/log"
"github.com/gogo/protobuf/proto"
"github.com/golang/snappy"
Expand Down Expand Up @@ -83,6 +84,7 @@ type connectionConfig struct {
defaultDatabase string
defaultTable string
enableLogging bool
enableSigV4Auth bool
failOnLongMetricLabelName bool
failOnInvalidSample bool
listenAddr string
Expand Down Expand Up @@ -145,9 +147,20 @@ func lambdaHandler(req events.APIGatewayProxyRequest) (events.APIGatewayProxyRes

logger := cfg.createLogger()

awsCredentials, ok := parseBasicAuth(req.Headers[basicAuthHeader])
if !ok {
return createErrorResponse(errors.NewParseBasicAuthHeaderError().(*errors.ParseBasicAuthHeaderError).Message())
var awsCredentials *credentials.Credentials
var ok bool

// If SigV4 authentication has been enabled, such as when write requests originate
// from the OpenTelemetry collector, credentials will be taken from the local environment.
// Otherwise, basic auth is used for AWS credentials
if cfg.enableSigV4Auth {
sess := session.Must(session.NewSession())
awsCredentials = sess.Config.Credentials
} else {
awsCredentials, ok = parseBasicAuth(req.Headers[basicAuthHeader])
if !ok {
return createErrorResponse(errors.NewParseBasicAuthHeaderError().(*errors.ParseBasicAuthHeaderError).Message())
}
}

awsConfigs := cfg.buildAWSConfig()
Expand Down Expand Up @@ -280,7 +293,7 @@ func (cfg *connectionConfig) createLogger() (logger log.Logger) {
}

// parseBoolFromStrings parses the boolean configuration options from the strings in connectionConfig.
func (cfg *connectionConfig) parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample string) error {
func (cfg *connectionConfig) parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample, enableSigV4Auth string) error {
var err error

cfg.enableLogging, err = strconv.ParseBool(enableLogging)
Expand All @@ -304,6 +317,13 @@ func (cfg *connectionConfig) parseBoolFromStrings(enableLogging, failOnLongMetri
return timestreamError
}

cfg.enableSigV4Auth, err = strconv.ParseBool(enableSigV4Auth)
if err != nil {
timestreamError := errors.NewParseSampleOptionError(failOnInvalidSample)
fmt.Println(timestreamError.Error())
return timestreamError
}

return nil
}

Expand All @@ -328,7 +348,7 @@ func parseEnvironmentVariables() (*connectionConfig, error) {
cfg.defaultTable = getOrDefault(defaultTableConfig)

var err error
err = cfg.parseBoolFromStrings(getOrDefault(enableLogConfig), getOrDefault(failOnLabelConfig), getOrDefault(failOnInvalidSampleConfig))
err = cfg.parseBoolFromStrings(getOrDefault(enableLogConfig), getOrDefault(failOnLabelConfig), getOrDefault(failOnInvalidSampleConfig), getOrDefault(enableSigV4AuthConfig))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -357,6 +377,7 @@ func parseFlags() *connectionConfig {
}

var enableLogging string
var enableSigV4Auth string
var failOnLongMetricLabelName string
var failOnInvalidSample string

Expand All @@ -373,6 +394,7 @@ func parseFlags() *connectionConfig {
Default(failOnInvalidSampleConfig.defaultValue).StringVar(&failOnInvalidSample)
a.Flag(certificateConfig.flag, "TLS server certificate file.").Default(certificateConfig.defaultValue).StringVar(&cfg.certificate)
a.Flag(keyConfig.flag, "TLS server private key file.").Default(keyConfig.defaultValue).StringVar(&cfg.key)
a.Flag(enableSigV4AuthConfig.flag, "Whether to enable SigV4 authentication with the API Gateway. Default to 'false'.").Default(enableSigV4AuthConfig.defaultValue).StringVar(&enableSigV4Auth)

flag.AddFlags(a, &cfg.promlogConfig)

Expand All @@ -381,7 +403,7 @@ func parseFlags() *connectionConfig {
os.Exit(1)
}

if err := cfg.parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample); err != nil {
if err := cfg.parseBoolFromStrings(enableLogging, failOnLongMetricLabelName, failOnInvalidSample, enableSigV4Auth); err != nil {
os.Exit(1)
}

Expand Down
4 changes: 4 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ func setUp() ([]string, *connectionConfig) {
defaultDatabase: "foo",
defaultTable: "bar",
enableLogging: true,
enableSigV4Auth: true,
listenAddr: ":9201",
maxRetries: 3,
telemetryPath: "/metrics",
Expand Down Expand Up @@ -365,6 +366,7 @@ func TestLambdaHandlerPrepareRequest(t *testing.T) {
lambdaOptions: []lambdaEnvOptions{
{key: defaultTableConfig.envFlag, value: tableValue},
{key: defaultDatabaseConfig.envFlag, value: databaseValue},
{key: enableSigV4AuthConfig.envFlag, value: "false"},
},
inputRequest: events.APIGatewayProxyRequest{
IsBase64Encoded: true,
Expand All @@ -379,6 +381,7 @@ func TestLambdaHandlerPrepareRequest(t *testing.T) {
lambdaOptions: []lambdaEnvOptions{
{key: defaultTableConfig.envFlag, value: tableValue},
{key: defaultDatabaseConfig.envFlag, value: databaseValue},
{key: enableSigV4AuthConfig.envFlag, value: "false"},
},
inputRequest: events.APIGatewayProxyRequest{
IsBase64Encoded: true,
Expand Down Expand Up @@ -658,6 +661,7 @@ func TestParseEnvironmentVariables(t *testing.T) {
clientConfig: &clientConfig{region: "us-east-1"},
promlogConfig: defaultLogConfig,
enableLogging: true,
enableSigV4Auth: true,
failOnInvalidSample: false,
failOnLongMetricLabelName: false,
maxRetries: 3,
Expand Down
Loading

0 comments on commit 0c51013

Please sign in to comment.