diff --git a/router/customdestinationmanager/customdestinationmanager.go b/router/customdestinationmanager/customdestinationmanager.go index ca34091b15..d5e5e54065 100644 --- a/router/customdestinationmanager/customdestinationmanager.go +++ b/router/customdestinationmanager/customdestinationmanager.go @@ -76,7 +76,7 @@ func Init() { } func loadConfig() { - ObjectStreamDestinations = []string{"KINESIS", "KAFKA", "AZURE_EVENT_HUB", "FIREHOSE", "EVENTBRIDGE", "GOOGLEPUBSUB", "CONFLUENT_CLOUD", "PERSONALIZE", "GOOGLESHEETS", "BQSTREAM", "LAMBDA", "GOOGLE_CLOUD_FUNCTION"} + ObjectStreamDestinations = []string{"KINESIS", "KAFKA", "AZURE_EVENT_HUB", "FIREHOSE", "EVENTBRIDGE", "GOOGLEPUBSUB", "CONFLUENT_CLOUD", "PERSONALIZE", "GOOGLESHEETS", "BQSTREAM", "LAMBDA", "GOOGLE_CLOUD_FUNCTION", "WUNDERKIND"} KVStoreDestinations = []string{"REDIS"} Destinations = append(ObjectStreamDestinations, KVStoreDestinations...) disableEgress = config.GetBoolVar(false, "disableEgress") diff --git a/services/streammanager/streammanager.go b/services/streammanager/streammanager.go index f46933db92..c616d4b594 100644 --- a/services/streammanager/streammanager.go +++ b/services/streammanager/streammanager.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" backendconfig "github.com/rudderlabs/rudder-server/backend-config" "github.com/rudderlabs/rudder-server/services/streammanager/bqstream" "github.com/rudderlabs/rudder-server/services/streammanager/common" @@ -16,6 +18,7 @@ import ( "github.com/rudderlabs/rudder-server/services/streammanager/kinesis" "github.com/rudderlabs/rudder-server/services/streammanager/lambda" "github.com/rudderlabs/rudder-server/services/streammanager/personalize" + "github.com/rudderlabs/rudder-server/services/streammanager/wunderkind" ) // NewProducer delegates the call to the appropriate based on parameter destination for creating producer @@ -48,6 +51,8 @@ func NewProducer(destination *backendconfig.DestinationT, opts common.Opts) (com return lambda.NewProducer(destination, opts) case "GOOGLE_CLOUD_FUNCTION": return googlecloudfunction.NewProducer(destination, opts) + case "WUNDERKIND": + return wunderkind.NewProducer(config.Default, logger.NewLogger().Child("streammanager")) default: return nil, fmt.Errorf("no provider configured for StreamManager") // 404, "No provider configured for StreamManager", "" } diff --git a/services/streammanager/wunderkind/wunderkindmanager.go b/services/streammanager/wunderkind/wunderkindmanager.go new file mode 100644 index 0000000000..8bcb1b7f60 --- /dev/null +++ b/services/streammanager/wunderkind/wunderkindmanager.go @@ -0,0 +1,142 @@ +package wunderkind + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + + "github.com/aws/aws-sdk-go/service/lambda" + jsoniter "github.com/json-iterator/go" + + "github.com/rudderlabs/rudder-go-kit/awsutil" + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + obskit "github.com/rudderlabs/rudder-observability-kit/go/labels" + "github.com/rudderlabs/rudder-server/services/streammanager/common" +) + +const ( + InvocationType = "RequestResponse" + WunderkindRegion = "WUNDERKIND_REGION" + WunderkindIamRoleArn = "WUNDERKIND_IAM_ROLE_ARN" + WunderkindExternalId = "WUNDERKIND_EXTERNAL_ID" + WunderkindLambda = "WUNDERKIND_LAMBDA" +) + +var jsonFast = jsoniter.ConfigCompatibleWithStandardLibrary + +type inputData struct { + Payload string `json:"payload"` +} + +type Producer struct { + conf *config.Config + client lambdaClient + logger logger.Logger +} + +type lambdaClient interface { + Invoke(input *lambda.InvokeInput) (*lambda.InvokeOutput, error) +} + +// NewProducer creates a producer based on destination config +func NewProducer(conf *config.Config, log logger.Logger) (*Producer, error) { + if err := validate(conf); err != nil { + return nil, fmt.Errorf("invalid environment config: %w", err) + } + sessionConfig := &awsutil.SessionConfig{ + Region: conf.GetString(WunderkindRegion, ""), + IAMRoleARN: conf.GetString(WunderkindIamRoleArn, ""), + ExternalID: conf.GetString(WunderkindExternalId, ""), + RoleBasedAuth: true, + } + awsSession, err := awsutil.CreateSession(sessionConfig) + if err != nil { + return nil, fmt.Errorf("creating session: %w", err) + } + + return &Producer{ + conf: conf, + client: lambda.New(awsSession), + logger: log.Child("wunderkind"), + }, nil +} + +// Produce creates a producer and send data to Lambda. +func (p *Producer) Produce(jsonData json.RawMessage, _ interface{}) (int, string, string) { + client := p.client + var input inputData + err := jsonFast.Unmarshal(jsonData, &input) + if err != nil { + returnMessage := "[Wunderkind] error while unmarshalling jsonData :: " + err.Error() + return http.StatusBadRequest, "Failure", returnMessage + } + if input.Payload == "" { + return http.StatusBadRequest, "Failure", "[Wunderkind] error :: Invalid payload" + } + + var invokeInput lambda.InvokeInput + wunderKindLambda := p.conf.GetString(WunderkindLambda, "") + invokeInput.SetFunctionName(wunderKindLambda) + invokeInput.SetPayload([]byte(input.Payload)) + invokeInput.SetInvocationType(InvocationType) + invokeInput.SetLogType("Tail") + + if err = invokeInput.Validate(); err != nil { + return http.StatusBadRequest, "Failure", "[Wunderkind] error :: Invalid invokeInput :: " + err.Error() + } + + response, err := client.Invoke(&invokeInput) + if err != nil { + statusCode, respStatus, responseMessage := common.ParseAWSError(err) + p.logger.Warnn("Invocation", + logger.NewStringField("statusCode", fmt.Sprint(statusCode)), + logger.NewStringField("respStatus", respStatus), + logger.NewStringField("responseMessage", responseMessage), + obskit.Error(err), + ) + return statusCode, respStatus, responseMessage + } + + // handle a case where lambda invocation is successful, but there is an issue with the payload. + if response.FunctionError != nil { + statusCode := http.StatusBadRequest + respStatus := "Failure" + responseMessage := string(response.Payload) + p.logger.Warnn("Function execution", + logger.NewStringField("statusCode", fmt.Sprint(statusCode)), + logger.NewStringField("respStatus", respStatus), + logger.NewStringField("responseMessage", responseMessage), + logger.NewStringField("functionError", *response.FunctionError), + ) + return statusCode, respStatus, responseMessage + } + + return http.StatusOK, "Success", "Event delivered to Wunderkind :: " + wunderKindLambda +} + +func validate(conf *config.Config) error { + if conf.GetString(WunderkindRegion, "") == "" { + return errors.New("region cannot be empty") + } + + if conf.GetString(WunderkindIamRoleArn, "") == "" { + return errors.New("iam role arn cannot be empty") + } + + if conf.GetString(WunderkindExternalId, "") == "" { + return errors.New("external id cannot be empty") + } + + if conf.GetString(WunderkindLambda, "") == "" { + return errors.New("lambda function cannot be empty") + } + + return nil +} + +func (*Producer) Close() error { + // no-op + return nil +} diff --git a/services/streammanager/wunderkind/wunderkindmanager_test.go b/services/streammanager/wunderkind/wunderkindmanager_test.go new file mode 100644 index 0000000000..cda2c066a7 --- /dev/null +++ b/services/streammanager/wunderkind/wunderkindmanager_test.go @@ -0,0 +1,142 @@ +package wunderkind + +import ( + "encoding/json" + "errors" + "net/http" + "testing" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/lambda" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/config" + "github.com/rudderlabs/rudder-go-kit/logger" + "github.com/rudderlabs/rudder-go-kit/logger/mock_logger" + mock_lambda "github.com/rudderlabs/rudder-server/mocks/services/streammanager/lambda" +) + +var ( + sampleMessage = "sample payload" + sampleFunction = "sampleLambdaFunction" + sampleExternalID = "sampleExternalID" + sampleIAMRoleARN = "sampleRoleArn" + invocationType = "RequestResponse" +) + +func TestNewProducer(t *testing.T) { + t.Run("valid", func(t *testing.T) { + conf := config.New() + conf.Set("WUNDERKIND_REGION", "us-east-1") + conf.Set("WUNDERKIND_IAM_ROLE_ARN", sampleIAMRoleARN) + conf.Set("WUNDERKIND_EXTERNAL_ID", sampleExternalID) + conf.Set("WUNDERKIND_LAMBDA", sampleFunction) + producer, err := NewProducer(conf, logger.NOP) + require.Nil(t, err) + require.NotNil(t, producer) + require.NotNil(t, producer.client) + }) + + t.Run("empty external id", func(t *testing.T) { + conf := config.New() + conf.Set("WUNDERKIND_REGION", "us-east-1") + conf.Set("WUNDERKIND_IAM_ROLE_ARN", sampleIAMRoleARN) + conf.Set("WUNDERKIND_EXTERNAL_ID", "") + conf.Set("WUNDERKIND_LAMBDA", sampleFunction) + producer, err := NewProducer(conf, logger.NOP) + require.Nil(t, producer) + require.Equal(t, "invalid environment config: external id cannot be empty", err.Error()) + }) +} + +func TestProduceWithInvalidData(t *testing.T) { + ctrl := gomock.NewController(t) + mockClient := mock_lambda.NewMockLambdaClient(ctrl) + producer := &Producer{client: mockClient} + + t.Run("Invalid input", func(t *testing.T) { + sampleEventJson := []byte("invalid json") + statusCode, statusMsg, respMsg := producer.Produce(sampleEventJson, map[string]string{}) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, "Failure", statusMsg) + require.Contains(t, respMsg, "[Wunderkind] error while unmarshalling jsonData ") + }) + + t.Run("Empty payload", func(t *testing.T) { + sampleEventJson, _ := json.Marshal(map[string]interface{}{ + "payload": "", + }) + statusCode, statusMsg, respMsg := producer.Produce(sampleEventJson, map[string]string{}) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, "Failure", statusMsg) + require.Contains(t, respMsg, "[Wunderkind] error :: Invalid payload") + }) +} + +func TestProduceWithServiceResponse(t *testing.T) { + conf := config.New() + conf.Set("WUNDERKIND_REGION", "us-east-1") + conf.Set("WUNDERKIND_IAM_ROLE_ARN", sampleIAMRoleARN) + conf.Set("WUNDERKIND_EXTERNAL_ID", sampleExternalID) + conf.Set("WUNDERKIND_LAMBDA", sampleFunction) + + ctrl := gomock.NewController(t) + mockClient := mock_lambda.NewMockLambdaClient(ctrl) + mockLogger := mock_logger.NewMockLogger(ctrl) + producer := &Producer{conf: conf, client: mockClient, logger: mockLogger} + + sampleEventJson, _ := json.Marshal(map[string]interface{}{ + "payload": sampleMessage, + }) + + destConfig := map[string]string{} + + var sampleInput lambda.InvokeInput + sampleInput.SetFunctionName(sampleFunction) + sampleInput.SetPayload([]byte(sampleMessage)) + sampleInput.SetInvocationType(invocationType) + sampleInput.SetLogType("Tail") + + t.Run("success", func(t *testing.T) { + mockClient.EXPECT().Invoke(&sampleInput).Return(&lambda.InvokeOutput{}, nil) + statusCode, statusMsg, respMsg := producer.Produce(sampleEventJson, destConfig) + require.Equal(t, http.StatusOK, statusCode) + require.Equal(t, "Success", statusMsg) + require.NotEmpty(t, respMsg) + }) + + t.Run("general error", func(t *testing.T) { + errorCode := "errorCode" + mockClient.EXPECT().Invoke(&sampleInput).Return(nil, errors.New(errorCode)) + mockLogger.EXPECT().Warnn(gomock.Any(), gomock.Any()).Times(1) + statusCode, statusMsg, respMsg := producer.Produce(sampleEventJson, destConfig) + require.Equal(t, http.StatusInternalServerError, statusCode) + require.Equal(t, "Failure", statusMsg) + require.NotEmpty(t, respMsg) + }) + + t.Run("when lambda invocation is successful, but there is an issue with the payload", func(t *testing.T) { + mockClient.EXPECT().Invoke(&sampleInput).Return(&lambda.InvokeOutput{ + StatusCode: aws.Int64(http.StatusOK), + FunctionError: aws.String("Unhandled"), + ExecutedVersion: aws.String("$LATEST"), + }, nil) + mockLogger.EXPECT().Warnn(gomock.Any(), gomock.Any()).Times(1) + statusCode, statusMsg, _ := producer.Produce(sampleEventJson, destConfig) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, "Failure", statusMsg) + }) + + t.Run("aws error", func(t *testing.T) { + errorCode := "errorCode" + mockClient.EXPECT().Invoke(&sampleInput).Return( + nil, awserr.NewRequestFailure(awserr.New(errorCode, errorCode, errors.New(errorCode)), http.StatusBadRequest, "request-id")) + mockLogger.EXPECT().Warnn(gomock.Any(), gomock.Any()).Times(1) + statusCode, statusMsg, respMsg := producer.Produce(sampleEventJson, destConfig) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, errorCode, statusMsg) + require.NotEmpty(t, respMsg) + }) +}