Skip to content

Commit

Permalink
fix: refactor router into mux (#56)
Browse files Browse the repository at this point in the history
Small refactor based off of work on future lambdas:

- rename `router.Handler` to `handler.Mux`
- embed Mux in Handler
- move common logging functionality into the Mux

This change allows us to reduce some boilerplate when knocking out more
functions.
  • Loading branch information
jta authored Nov 3, 2023
1 parent 56a7c53 commit 5ed31ce
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 64 deletions.
2 changes: 1 addition & 1 deletion cmd/forwarder/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/go-logr/logr"
"github.com/sethvargo/go-envconfig"

"github.com/observeinc/aws-sam-testing/handlers/forwarder"
"github.com/observeinc/aws-sam-testing/handler/forwarder"
"github.com/observeinc/aws-sam-testing/logging"
)

Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"
"testing"

"github.com/observeinc/aws-sam-testing/handlers/forwarder"
"github.com/observeinc/aws-sam-testing/handler/forwarder"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down
19 changes: 9 additions & 10 deletions handlers/forwarder/handler.go → handler/forwarder/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/go-logr/logr"

"github.com/observeinc/aws-sam-testing/handler"
)

var errNoLambdaContext = fmt.Errorf("no lambda context found")
Expand All @@ -25,10 +27,11 @@ type S3Client interface {
}

type Handler struct {
handler.Mux

DestinationURI *url.URL
LogPrefix string
S3Client S3Client
Logger logr.Logger
}

// GetCopyObjectInput constructs the input struct for CopyObject.
Expand Down Expand Up @@ -80,14 +83,7 @@ func (h *Handler) Handle(ctx context.Context, request events.SQSEvent) (response
return response, errNoLambdaContext
}

logger := h.Logger.WithValues("requestId", lctx.AwsRequestID)

logger.V(3).Info("handling request")
defer func() {
if err != nil {
logger.Error(err, "failed to process request", "payload", request)
}
}()
logger := logr.FromContextOrDiscard(ctx)

var messages bytes.Buffer
defer func() {
Expand Down Expand Up @@ -131,12 +127,15 @@ func New(cfg *Config) (*Handler, error) {
DestinationURI: u,
LogPrefix: cfg.LogPrefix,
S3Client: cfg.S3Client,
Logger: logr.Discard(),
}

if cfg.Logger != nil {
h.Logger = *cfg.Logger
}

if err := h.Register(h.Handle); err != nil {
return nil, fmt.Errorf("failed to register handler: %w", err)
}

return h, nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/observeinc/aws-sam-testing/handlers/forwarder"
"github.com/observeinc/aws-sam-testing/handler/forwarder"
)

type MockS3Client struct {
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"fmt"
"testing"

"github.com/observeinc/aws-sam-testing/handlers/forwarder"
"github.com/observeinc/aws-sam-testing/handler/forwarder"

"github.com/google/go-cmp/cmp"
)
Expand Down
File renamed without changes.
68 changes: 49 additions & 19 deletions handlers/router/router.go → handler/mux.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package router
package handler

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"sync"

"github.com/aws/aws-lambda-go/lambdacontext"
"github.com/go-logr/logr"
)

var (
Expand All @@ -19,16 +23,32 @@ var (
ErrHandlerRequireError = errors.New("last return value must be an error")
)

type Router struct {
// Mux for multiple lambda handler entrypoints.
//
// This is a common helper to bridge between the convenience of declaring
// strongly typed lambda handlers, and the flexibility of routing payloads
// via the baseline Invoke method.
type Mux struct {
Logger logr.Logger

handlers map[reflect.Type]reflect.Value
sync.Mutex
}

// Register a lambda handler.
func (r *Router) Register(fs ...any) error {
r.Lock()
defer r.Unlock()
for _, f := range fs {
var _ interface {
Invoke(context.Context, []byte) ([]byte, error)
} = &Mux{}

// Register a set of lambda handlers.
func (m *Mux) Register(fns ...any) error {
m.Lock()
defer m.Unlock()

if m.handlers == nil {
m.handlers = make(map[reflect.Type]reflect.Value)
}

for _, f := range fns {
handler := reflect.ValueOf(f)
handlerType := reflect.TypeOf(f)
if k := handlerType.Kind(); k != reflect.Func {
Expand All @@ -52,20 +72,36 @@ func (r *Router) Register(fs ...any) error {
}

eventType := handlerType.In(handlerType.NumIn() - 1)
if _, ok := r.handlers[eventType]; ok {
if _, ok := m.handlers[eventType]; ok {
return fmt.Errorf("event type %s: %w", eventType, ErrHandlerAlreadyRegistered)
}

r.handlers[eventType] = handler
m.handlers[eventType] = handler
}
return nil
}

func (r *Router) Handle(ctx context.Context, v json.RawMessage) (json.RawMessage, error) {
for eventType, handler := range r.handlers {
func (m *Mux) Invoke(ctx context.Context, payload []byte) (response []byte, err error) {
logger := m.Logger
if lctx, ok := lambdacontext.FromContext(ctx); ok {
logger = m.Logger.WithValues("requestId", lctx.AwsRequestID)
ctx = logr.NewContext(ctx, logger)
}

logger.V(3).Info("handling request")
defer func() {
if err != nil {
logger.Error(err, "failed to process request", "payload", string(payload))
}
}()

for eventType, handler := range m.handlers {
event := reflect.New(eventType)

if err := json.Unmarshal(v, event.Interface()); err != nil {
dec := json.NewDecoder(bytes.NewReader(payload))
dec.DisallowUnknownFields()

if err := dec.Decode(event.Interface()); err != nil {
// assume event was destined for a different handler
continue
}
Expand All @@ -80,13 +116,7 @@ func (r *Router) Handle(ctx context.Context, v json.RawMessage) (json.RawMessage
if err != nil {
return nil, fmt.Errorf("failed to marshal response: %w", err)
}
return json.RawMessage(data), nil
return data, nil
}
return nil, ErrNoHandler
}

func New() *Router {
return &Router{
handlers: make(map[reflect.Type]reflect.Value),
}
}
62 changes: 31 additions & 31 deletions handlers/router/router_test.go → handler/mux_test.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
package router_test
package handler_test

import (
"context"
"encoding/json"
"fmt"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/observeinc/aws-sam-testing/handlers/router"
"github.com/observeinc/aws-sam-testing/handler"
)

func TestRouter(t *testing.T) {
func TestHandler(t *testing.T) {
t.Parallel()
testcases := []struct {
Handlers []any
Checks map[string]string
HandlerFuncs []any
Checks map[string]string
}{
{
Handlers: []any{
HandlerFuncs: []any{
func(_ context.Context, _ string) (string, error) { return "string", nil },
func(_ context.Context, _ int) (string, error) { return "int", nil },
func(_ context.Context, _ struct{ V string }) (string, error) { return "custom", nil },
func(_ context.Context, _ struct{ V string }) (string, error) { return "v", nil },
func(_ context.Context, _ struct{ W string }) (string, error) { return "w", nil },
},
Checks: map[string]string{
`1`: `"int"`,
`"1"`: `"string"`,
`{"v": "test"}`: `"custom"`,
`{"v": "test"}`: `"v"`,
`{"w": "test"}`: `"w"`,
},
},
}
Expand All @@ -36,14 +37,14 @@ func TestRouter(t *testing.T) {
tc := tc
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
r := router.New()

if err := r.Register(tc.Handlers...); err != nil {
var h handler.Mux
if err := h.Register(tc.HandlerFuncs...); err != nil {
t.Fatal(err)
}

for input, output := range tc.Checks {
result, err := r.Handle(context.Background(), json.RawMessage(input))
result, err := h.Invoke(context.Background(), []byte(input))
if err != nil {
t.Fatalf("failed to validate %s: %s", input, err)
}
Expand All @@ -56,55 +57,55 @@ func TestRouter(t *testing.T) {
}
}

func TestRouterErrors(t *testing.T) {
func TestHandlerErrors(t *testing.T) {
t.Parallel()

testcases := []struct {
Handlers []any
ExpectErr error
HandlerFuncs []any
ExpectErr error
}{
{
// no handlers, no problem
},
{
Handlers: []any{
HandlerFuncs: []any{
"1",
},
ExpectErr: router.ErrHandlerType,
ExpectErr: handler.ErrHandlerType,
},
{
Handlers: []any{
HandlerFuncs: []any{
func() {},
},
ExpectErr: router.ErrHandlerArgsCount,
ExpectErr: handler.ErrHandlerArgsCount,
},
{
Handlers: []any{
HandlerFuncs: []any{
func(int, int) {},
},
ExpectErr: router.ErrHandlerRequireContext,
ExpectErr: handler.ErrHandlerRequireContext,
},
{
Handlers: []any{
HandlerFuncs: []any{
func(context.Context, int) {},
},
ExpectErr: router.ErrHandlerReturnCount,
ExpectErr: handler.ErrHandlerReturnCount,
},
{
Handlers: []any{
HandlerFuncs: []any{
func(context.Context, int) (int, int) { return 1, 1 },
},
ExpectErr: router.ErrHandlerRequireError,
ExpectErr: handler.ErrHandlerRequireError,
},
{
Handlers: []any{
HandlerFuncs: []any{
func(context.Context, int) (int, error) { return 1, nil },
func(context.Context, int) (int, error) { return 1, nil },
},
ExpectErr: router.ErrHandlerAlreadyRegistered,
ExpectErr: handler.ErrHandlerAlreadyRegistered,
},
{
Handlers: []any{
HandlerFuncs: []any{
func(context.Context, string) (int, error) { return 1, nil },
func(context.Context, int) (int, error) { return 1, nil },
func(context.Context, float64) (int, error) { return 1, nil },
Expand All @@ -116,9 +117,8 @@ func TestRouterErrors(t *testing.T) {
tc := tc
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
r := router.New()

err := r.Register(tc.Handlers...)
var h handler.Mux
err := h.Register(tc.HandlerFuncs...)
if diff := cmp.Diff(err, tc.ExpectErr, cmpopts.EquateErrors()); diff != "" {
t.Error("unexpected error", diff)
}
Expand Down

0 comments on commit 5ed31ce

Please sign in to comment.