Skip to content

Commit

Permalink
feat(subscriber): add ability to consume items from queue (#101)
Browse files Browse the repository at this point in the history
This adds a handler to process SQS messages. We were previously dumping
only subscription requests into the queue. This commit additional makes
it so that any request recognized by the handler can also be consumed
via the queue.
  • Loading branch information
jta authored Nov 21, 2023
1 parent b262e55 commit d57d354
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 2 deletions.
2 changes: 1 addition & 1 deletion handler/subscriber/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (h *Handler) HandleDiscoveryRequest(ctx context.Context, discoveryReq *Disc
return resp, fmt.Errorf("failed to handle subscription request: %w", err)
}
resp.Discovery.Subscription.Add(s.Subscription)
} else if err := h.Queue.Put(ctx, subscriptionRequest); err != nil {
} else if err := h.Queue.Put(ctx, &Request{SubscriptionRequest: subscriptionRequest}); err != nil {
return resp, fmt.Errorf("failed to write to queue: %w", err)
}
}
Expand Down
26 changes: 25 additions & 1 deletion handler/subscriber/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@ package subscriber

import (
"context"
"encoding/json"
"errors"
"fmt"
"runtime"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
"github.com/go-logr/logr"

"github.com/observeinc/aws-sam-testing/handler"
)
Expand Down Expand Up @@ -57,6 +60,27 @@ func (h *Handler) HandleRequest(ctx context.Context, req *Request) (*Response, e
}
}

func (h *Handler) HandleSQS(ctx context.Context, request events.SQSEvent) (response events.SQSEventResponse, err error) {
logger := logr.FromContextOrDiscard(ctx)
for _, record := range request.Records {
var req Request
var err error

if err = json.Unmarshal([]byte(record.Body), &req); err == nil {
_, err = h.HandleRequest(ctx, &req)
}

if err != nil {
// SQS record will be under 256KB, should be ok to log
logger.Error(err, "failed to process request", "body", record.Body)
response.BatchItemFailures = append(response.BatchItemFailures, events.SQSBatchItemFailure{
ItemIdentifier: record.MessageId,
})
}
}
return response, nil
}

func New(cfg *Config) (*Handler, error) {
if err := cfg.Validate(); err != nil {
return nil, err
Expand All @@ -83,7 +107,7 @@ func New(cfg *Config) (*Handler, error) {
h.Logger = *cfg.Logger
}

if err := h.Mux.Register(h.HandleRequest); err != nil {
if err := h.Mux.Register(h.HandleRequest, h.HandleSQS); err != nil {
return nil, fmt.Errorf("failed to register handler: %w", err)
}

Expand Down
74 changes: 74 additions & 0 deletions handler/subscriber/handler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package subscriber_test

import (
"context"
"fmt"
"testing"

"github.com/aws/aws-lambda-go/events"
"github.com/google/go-cmp/cmp"

"github.com/observeinc/aws-sam-testing/handler/handlertest"
"github.com/observeinc/aws-sam-testing/handler/subscriber"
)

func TestHandleSQS(t *testing.T) {
t.Parallel()

testcases := []struct {
Event events.SQSEvent
Expect events.SQSEventResponse
}{
{
Event: events.SQSEvent{
Records: []events.SQSMessage{
{
MessageId: "invalid request",
Body: "",
},
},
},
Expect: events.SQSEventResponse{
BatchItemFailures: []events.SQSBatchItemFailure{
{ItemIdentifier: "invalid request"},
},
},
},
{
Event: events.SQSEvent{
Records: []events.SQSMessage{
{
MessageId: "ok",
Body: "{\"subscribe\": {\"logGroups\":[{\"logGroupName\":\"/aws/hello\"}]}}",
},
},
},
Expect: events.SQSEventResponse{},
},
}

for i, tt := range testcases {
tt := tt
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()

s, err := subscriber.New(&subscriber.Config{
CloudWatchLogsClient: &handlertest.CloudWatchLogsClient{},
FilterName: "test",
LogGroupNamePrefixes: []string{"*"},
})
if err != nil {
t.Fatal(err)
}

resp, err := s.HandleSQS(context.Background(), tt.Event)
if err != nil {
t.Fatal(err)
}

if diff := cmp.Diff(resp, tt.Expect); diff != "" {
t.Error(diff)
}
})
}
}

0 comments on commit d57d354

Please sign in to comment.