diff --git a/cmd/forwarder/main.go b/cmd/forwarder/main.go index 7a905c86..429dfb61 100644 --- a/cmd/forwarder/main.go +++ b/cmd/forwarder/main.go @@ -10,7 +10,7 @@ import ( "github.com/go-logr/logr" "github.com/sethvargo/go-envconfig" - "github.com/observeinc/aws-sam-testing/handlers/forwarder" + "github.com/observeinc/aws-sam-testing/handler/forwarder" "github.com/observeinc/aws-sam-testing/logging" ) diff --git a/handlers/forwarder/config.go b/handler/forwarder/config.go similarity index 100% rename from handlers/forwarder/config.go rename to handler/forwarder/config.go diff --git a/handlers/forwarder/config_test.go b/handler/forwarder/config_test.go similarity index 93% rename from handlers/forwarder/config_test.go rename to handler/forwarder/config_test.go index d3177dee..cf53697d 100644 --- a/handlers/forwarder/config_test.go +++ b/handler/forwarder/config_test.go @@ -4,7 +4,7 @@ import ( "fmt" "testing" - "github.com/observeinc/aws-sam-testing/handlers/forwarder" + "github.com/observeinc/aws-sam-testing/handler/forwarder" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" diff --git a/handlers/forwarder/handler.go b/handler/forwarder/handler.go similarity index 92% rename from handlers/forwarder/handler.go rename to handler/forwarder/handler.go index 25327b69..1205d53e 100644 --- a/handlers/forwarder/handler.go +++ b/handler/forwarder/handler.go @@ -15,6 +15,8 @@ import ( "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/go-logr/logr" + + "github.com/observeinc/aws-sam-testing/handler" ) var errNoLambdaContext = fmt.Errorf("no lambda context found") @@ -25,10 +27,11 @@ type S3Client interface { } type Handler struct { + handler.Mux + DestinationURI *url.URL LogPrefix string S3Client S3Client - Logger logr.Logger } // GetCopyObjectInput constructs the input struct for CopyObject. @@ -80,14 +83,7 @@ func (h *Handler) Handle(ctx context.Context, request events.SQSEvent) (response return response, errNoLambdaContext } - logger := h.Logger.WithValues("requestId", lctx.AwsRequestID) - - logger.V(3).Info("handling request") - defer func() { - if err != nil { - logger.Error(err, "failed to process request", "payload", request) - } - }() + logger := logr.FromContextOrDiscard(ctx) var messages bytes.Buffer defer func() { @@ -131,12 +127,15 @@ func New(cfg *Config) (*Handler, error) { DestinationURI: u, LogPrefix: cfg.LogPrefix, S3Client: cfg.S3Client, - Logger: logr.Discard(), } if cfg.Logger != nil { h.Logger = *cfg.Logger } + if err := h.Register(h.Handle); err != nil { + return nil, fmt.Errorf("failed to register handler: %w", err) + } + return h, nil } diff --git a/handlers/forwarder/handler_test.go b/handler/forwarder/handler_test.go similarity index 99% rename from handlers/forwarder/handler_test.go rename to handler/forwarder/handler_test.go index d2aeedc0..69ea0ee1 100644 --- a/handlers/forwarder/handler_test.go +++ b/handler/forwarder/handler_test.go @@ -18,7 +18,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/observeinc/aws-sam-testing/handlers/forwarder" + "github.com/observeinc/aws-sam-testing/handler/forwarder" ) type MockS3Client struct { diff --git a/handlers/forwarder/message.go b/handler/forwarder/message.go similarity index 100% rename from handlers/forwarder/message.go rename to handler/forwarder/message.go diff --git a/handlers/forwarder/message_test.go b/handler/forwarder/message_test.go similarity index 99% rename from handlers/forwarder/message_test.go rename to handler/forwarder/message_test.go index f7aadb5f..2168e080 100644 --- a/handlers/forwarder/message_test.go +++ b/handler/forwarder/message_test.go @@ -5,7 +5,7 @@ import ( "fmt" "testing" - "github.com/observeinc/aws-sam-testing/handlers/forwarder" + "github.com/observeinc/aws-sam-testing/handler/forwarder" "github.com/google/go-cmp/cmp" ) diff --git a/handlers/forwarder/testdata/event.json b/handler/forwarder/testdata/event.json similarity index 100% rename from handlers/forwarder/testdata/event.json rename to handler/forwarder/testdata/event.json diff --git a/handlers/router/router.go b/handler/mux.go similarity index 57% rename from handlers/router/router.go rename to handler/mux.go index df0705ca..d6ebbe26 100644 --- a/handlers/router/router.go +++ b/handler/mux.go @@ -1,12 +1,16 @@ -package router +package handler import ( + "bytes" "context" "encoding/json" "errors" "fmt" "reflect" "sync" + + "github.com/aws/aws-lambda-go/lambdacontext" + "github.com/go-logr/logr" ) var ( @@ -19,16 +23,32 @@ var ( ErrHandlerRequireError = errors.New("last return value must be an error") ) -type Router struct { +// Mux for multiple lambda handler entrypoints. +// +// This is a common helper to bridge between the convenience of declaring +// strongly typed lambda handlers, and the flexibility of routing payloads +// via the baseline Invoke method. +type Mux struct { + Logger logr.Logger + handlers map[reflect.Type]reflect.Value sync.Mutex } -// Register a lambda handler. -func (r *Router) Register(fs ...any) error { - r.Lock() - defer r.Unlock() - for _, f := range fs { +var _ interface { + Invoke(context.Context, []byte) ([]byte, error) +} = &Mux{} + +// Register a set of lambda handlers. +func (m *Mux) Register(fns ...any) error { + m.Lock() + defer m.Unlock() + + if m.handlers == nil { + m.handlers = make(map[reflect.Type]reflect.Value) + } + + for _, f := range fns { handler := reflect.ValueOf(f) handlerType := reflect.TypeOf(f) if k := handlerType.Kind(); k != reflect.Func { @@ -52,20 +72,36 @@ func (r *Router) Register(fs ...any) error { } eventType := handlerType.In(handlerType.NumIn() - 1) - if _, ok := r.handlers[eventType]; ok { + if _, ok := m.handlers[eventType]; ok { return fmt.Errorf("event type %s: %w", eventType, ErrHandlerAlreadyRegistered) } - r.handlers[eventType] = handler + m.handlers[eventType] = handler } return nil } -func (r *Router) Handle(ctx context.Context, v json.RawMessage) (json.RawMessage, error) { - for eventType, handler := range r.handlers { +func (m *Mux) Invoke(ctx context.Context, payload []byte) (response []byte, err error) { + logger := m.Logger + if lctx, ok := lambdacontext.FromContext(ctx); ok { + logger = m.Logger.WithValues("requestId", lctx.AwsRequestID) + ctx = logr.NewContext(ctx, logger) + } + + logger.V(3).Info("handling request") + defer func() { + if err != nil { + logger.Error(err, "failed to process request", "payload", string(payload)) + } + }() + + for eventType, handler := range m.handlers { event := reflect.New(eventType) - if err := json.Unmarshal(v, event.Interface()); err != nil { + dec := json.NewDecoder(bytes.NewReader(payload)) + dec.DisallowUnknownFields() + + if err := dec.Decode(event.Interface()); err != nil { // assume event was destined for a different handler continue } @@ -80,13 +116,7 @@ func (r *Router) Handle(ctx context.Context, v json.RawMessage) (json.RawMessage if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) } - return json.RawMessage(data), nil + return data, nil } return nil, ErrNoHandler } - -func New() *Router { - return &Router{ - handlers: make(map[reflect.Type]reflect.Value), - } -} diff --git a/handlers/router/router_test.go b/handler/mux_test.go similarity index 62% rename from handlers/router/router_test.go rename to handler/mux_test.go index 9b4e2daa..74ab3e0f 100644 --- a/handlers/router/router_test.go +++ b/handler/mux_test.go @@ -1,33 +1,34 @@ -package router_test +package handler_test import ( "context" - "encoding/json" "fmt" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/observeinc/aws-sam-testing/handlers/router" + "github.com/observeinc/aws-sam-testing/handler" ) -func TestRouter(t *testing.T) { +func TestHandler(t *testing.T) { t.Parallel() testcases := []struct { - Handlers []any - Checks map[string]string + HandlerFuncs []any + Checks map[string]string }{ { - Handlers: []any{ + HandlerFuncs: []any{ func(_ context.Context, _ string) (string, error) { return "string", nil }, func(_ context.Context, _ int) (string, error) { return "int", nil }, - func(_ context.Context, _ struct{ V string }) (string, error) { return "custom", nil }, + func(_ context.Context, _ struct{ V string }) (string, error) { return "v", nil }, + func(_ context.Context, _ struct{ W string }) (string, error) { return "w", nil }, }, Checks: map[string]string{ `1`: `"int"`, `"1"`: `"string"`, - `{"v": "test"}`: `"custom"`, + `{"v": "test"}`: `"v"`, + `{"w": "test"}`: `"w"`, }, }, } @@ -36,14 +37,14 @@ func TestRouter(t *testing.T) { tc := tc t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Parallel() - r := router.New() - if err := r.Register(tc.Handlers...); err != nil { + var h handler.Mux + if err := h.Register(tc.HandlerFuncs...); err != nil { t.Fatal(err) } for input, output := range tc.Checks { - result, err := r.Handle(context.Background(), json.RawMessage(input)) + result, err := h.Invoke(context.Background(), []byte(input)) if err != nil { t.Fatalf("failed to validate %s: %s", input, err) } @@ -56,55 +57,55 @@ func TestRouter(t *testing.T) { } } -func TestRouterErrors(t *testing.T) { +func TestHandlerErrors(t *testing.T) { t.Parallel() testcases := []struct { - Handlers []any - ExpectErr error + HandlerFuncs []any + ExpectErr error }{ { // no handlers, no problem }, { - Handlers: []any{ + HandlerFuncs: []any{ "1", }, - ExpectErr: router.ErrHandlerType, + ExpectErr: handler.ErrHandlerType, }, { - Handlers: []any{ + HandlerFuncs: []any{ func() {}, }, - ExpectErr: router.ErrHandlerArgsCount, + ExpectErr: handler.ErrHandlerArgsCount, }, { - Handlers: []any{ + HandlerFuncs: []any{ func(int, int) {}, }, - ExpectErr: router.ErrHandlerRequireContext, + ExpectErr: handler.ErrHandlerRequireContext, }, { - Handlers: []any{ + HandlerFuncs: []any{ func(context.Context, int) {}, }, - ExpectErr: router.ErrHandlerReturnCount, + ExpectErr: handler.ErrHandlerReturnCount, }, { - Handlers: []any{ + HandlerFuncs: []any{ func(context.Context, int) (int, int) { return 1, 1 }, }, - ExpectErr: router.ErrHandlerRequireError, + ExpectErr: handler.ErrHandlerRequireError, }, { - Handlers: []any{ + HandlerFuncs: []any{ func(context.Context, int) (int, error) { return 1, nil }, func(context.Context, int) (int, error) { return 1, nil }, }, - ExpectErr: router.ErrHandlerAlreadyRegistered, + ExpectErr: handler.ErrHandlerAlreadyRegistered, }, { - Handlers: []any{ + HandlerFuncs: []any{ func(context.Context, string) (int, error) { return 1, nil }, func(context.Context, int) (int, error) { return 1, nil }, func(context.Context, float64) (int, error) { return 1, nil }, @@ -116,9 +117,8 @@ func TestRouterErrors(t *testing.T) { tc := tc t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { t.Parallel() - r := router.New() - - err := r.Register(tc.Handlers...) + var h handler.Mux + err := h.Register(tc.HandlerFuncs...) if diff := cmp.Diff(err, tc.ExpectErr, cmpopts.EquateErrors()); diff != "" { t.Error("unexpected error", diff) }