Skip to content

Commit

Permalink
feat(subscriber): add subscription diff logic (#80)
Browse files Browse the repository at this point in the history
First pass at fleshing out the `HandleSubscribe` endpoint. This is
unlikely to have all the safeguards we need, but for now:

- we only manage subscription filters which have our filter name as a
  prefix.
- if there are already 2 filters, we do nothing

The diffing function is broken out so that we can test it without
involving AWS API stubs.
  • Loading branch information
jta authored Nov 14, 2023
1 parent 569840f commit c68faae
Show file tree
Hide file tree
Showing 6 changed files with 249 additions and 8 deletions.
9 changes: 9 additions & 0 deletions apps/subscriber/events/subscribe.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"subscribe": {
"logGroups": [
{
"logGroupName": "does_not_exist"
}
]
}
}
4 changes: 0 additions & 4 deletions apps/subscriber/events/sync.json

This file was deleted.

3 changes: 3 additions & 0 deletions apps/subscriber/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ Resources:
- Effect: Allow
Action:
- logs:DescribeLogGroups
- logs:DescribeSubscriptionFilters
- logs:DeleteSubscriptionFilter
- logs:PutSubscriptionFilter
Resource: "*"
LogGroup:
Type: 'AWS::Logs::LogGroup'
Expand Down
120 changes: 117 additions & 3 deletions handler/subscriber/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,26 @@ import (
"context"
"errors"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
"github.com/go-logr/logr"

"github.com/observeinc/aws-sam-testing/handler"
)

var ErrNotImplemented = errors.New("not implemented")
var (
MaxSubscriptionFilterCount = 2
ErrNotImplemented = errors.New("not implemented")
)

type CloudWatchLogsClient interface {
DescribeLogGroups(context.Context, *cloudwatchlogs.DescribeLogGroupsInput, ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.DescribeLogGroupsOutput, error)
DescribeSubscriptionFilters(context.Context, *cloudwatchlogs.DescribeSubscriptionFiltersInput, ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.DescribeSubscriptionFiltersOutput, error)
PutSubscriptionFilter(context.Context, *cloudwatchlogs.PutSubscriptionFilterInput, ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.PutSubscriptionFilterOutput, error)
DeleteSubscriptionFilter(context.Context, *cloudwatchlogs.DeleteSubscriptionFilterInput, ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.DeleteSubscriptionFilterOutput, error)
}

type Queue interface {
Expand All @@ -25,6 +35,8 @@ type Handler struct {

Queue Queue
Client CloudWatchLogsClient

subscriptionFilter types.SubscriptionFilter
}

func (h *Handler) HandleDiscoveryRequest(ctx context.Context, discoveryReq *DiscoveryRequest) (*Response, error) {
Expand All @@ -46,11 +58,107 @@ func (h *Handler) HandleDiscoveryRequest(ctx context.Context, discoveryReq *Disc
return &Response{DiscoveryResponse: &discoveryResp}, nil
}

func (h *Handler) HandleSubscriptionRequest(_ context.Context, _ *SubscriptionRequest) (*Response, error) {
// to be implemented
func (h *Handler) HandleSubscriptionRequest(ctx context.Context, subReq *SubscriptionRequest) (*Response, error) {
for _, logGroup := range subReq.LogGroups {
if err := h.SubscribeLogGroup(ctx, logGroup); err != nil {
return nil, fmt.Errorf("failed to subscribe log group: %w", err)
}
}

return nil, nil
}

func (h *Handler) SubscribeLogGroup(ctx context.Context, logGroup *LogGroup) error {
logger := logr.FromContextOrDiscard(ctx).WithValues("logGroup", logGroup.LogGroupName)

logger.V(6).Info("describing subscription filters")
output, err := h.Client.DescribeSubscriptionFilters(ctx, &cloudwatchlogs.DescribeSubscriptionFiltersInput{
LogGroupName: &logGroup.LogGroupName,
})
if err != nil {
var exc *types.ResourceNotFoundException
if errors.As(err, &exc) {
logger.Info("skipping log group")
return nil
}
return fmt.Errorf("failed to retrieve subscription filters: %w", err)
}

for _, action := range h.SubscriptionFilterDiff(output.SubscriptionFilters) {
switch v := action.(type) {
case *cloudwatchlogs.DeleteSubscriptionFilterInput:
v.LogGroupName = &logGroup.LogGroupName
logger.V(3).Info("deleting subscription filter", "filterName", *v.FilterName)
if _, err := h.Client.DeleteSubscriptionFilter(ctx, v); err != nil {
return fmt.Errorf("failed to delete subscription filter: %w", err)
}
case *cloudwatchlogs.PutSubscriptionFilterInput:
v.LogGroupName = &logGroup.LogGroupName
logger.V(3).Info("updating subscription filter")
if _, err := h.Client.PutSubscriptionFilter(ctx, v); err != nil {
return fmt.Errorf("failed to put subscription filter: %w", err)
}
}
}

return nil
}

func subscriptionFilterEquals(a, b types.SubscriptionFilter) bool {
switch {
case aws.ToString(a.FilterName) != aws.ToString(b.FilterName):
case aws.ToString(a.FilterPattern) != aws.ToString(b.FilterPattern):
case aws.ToString(a.DestinationArn) != aws.ToString(b.DestinationArn):
case aws.ToString(a.RoleArn) != aws.ToString(b.RoleArn):
// do not match log group, since one of the arguments will be the config
// intended for all log groups.
default:
return true
}
return false
}

// SubscriptionFilterDiff returns a list of actions to execute against
// cloudwatch API in order to converge to our intended configuration state.
func (h *Handler) SubscriptionFilterDiff(subscriptionFilters []types.SubscriptionFilter) (actions []any) {
var (
deleted, updated int
deleteOnly = aws.ToString(h.subscriptionFilter.DestinationArn) == ""
)

for _, f := range subscriptionFilters {
if !strings.HasPrefix(aws.ToString(f.FilterName), aws.ToString(h.subscriptionFilter.FilterName)) {
// subscription filter not managed by this handler
continue
}
if deleteOnly || aws.ToString(h.subscriptionFilter.FilterName) != aws.ToString(f.FilterName) {
deleted++
actions = append(actions, &cloudwatchlogs.DeleteSubscriptionFilterInput{
FilterName: f.FilterName,
})
} else if !subscriptionFilterEquals(h.subscriptionFilter, f) {
updated++
actions = append(actions, &cloudwatchlogs.PutSubscriptionFilterInput{
FilterName: h.subscriptionFilter.FilterName,
FilterPattern: h.subscriptionFilter.FilterPattern,
DestinationArn: h.subscriptionFilter.DestinationArn,
RoleArn: h.subscriptionFilter.LogGroupName,
})
}
}

if !deleteOnly && updated == 0 && len(subscriptionFilters)-deleted < MaxSubscriptionFilterCount {
actions = append(actions, &cloudwatchlogs.PutSubscriptionFilterInput{
FilterName: h.subscriptionFilter.FilterName,
FilterPattern: h.subscriptionFilter.FilterPattern,
DestinationArn: h.subscriptionFilter.DestinationArn,
RoleArn: h.subscriptionFilter.LogGroupName,
})
}

return actions
}

func (h *Handler) HandleRequest(ctx context.Context, req *Request) (*Response, error) {
if err := req.Validate(); err != nil {
return nil, fmt.Errorf("failed to validate request: %w", err)
Expand All @@ -74,6 +182,12 @@ func New(cfg *Config) (*Handler, error) {
h := &Handler{
Client: cfg.CloudWatchLogsClient,
Queue: cfg.Queue,
subscriptionFilter: types.SubscriptionFilter{
FilterName: aws.String(cfg.FilterName),
FilterPattern: aws.String(cfg.FilterPattern),
DestinationArn: aws.String(cfg.DestinationARN),
RoleArn: aws.String(cfg.RoleARN),
},
}

if cfg.Logger != nil {
Expand Down
119 changes: 119 additions & 0 deletions handler/subscriber/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"

"github.com/observeinc/aws-sam-testing/handler/handlertest"
"github.com/observeinc/aws-sam-testing/handler/subscriber"
Expand Down Expand Up @@ -142,3 +144,120 @@ func TestHandleDiscovery(t *testing.T) {
})
}
}

func TestSubscriptionFilterDiff(t *testing.T) {
t.Parallel()

testcases := []struct {
Configure types.SubscriptionFilter
Existing []types.SubscriptionFilter
ExpectedActions []any
}{
{
/*
In the absence of a destination ARN, we delete all subscription
filters that contain our filter name as a prefix.
*/
Configure: types.SubscriptionFilter{
FilterName: aws.String("observe"),
},
Existing: []types.SubscriptionFilter{
{
FilterName: aws.String("foo"),
},
{
FilterName: aws.String("observe-logs-subscription"),
},
},
ExpectedActions: []any{
&cloudwatchlogs.DeleteSubscriptionFilterInput{
FilterName: aws.String("observe-logs-subscription"),
},
},
},
{
Configure: types.SubscriptionFilter{
FilterName: aws.String("observe"),
DestinationArn: aws.String("arn:aws:lambda:us-west-2:123456789012:function:example"),
},
Existing: []types.SubscriptionFilter{
{
FilterName: aws.String("foo"),
},
{
FilterName: aws.String("observe-logs-subscription"),
},
},
ExpectedActions: []any{
&cloudwatchlogs.DeleteSubscriptionFilterInput{
FilterName: aws.String("observe-logs-subscription"),
},
&cloudwatchlogs.PutSubscriptionFilterInput{
FilterName: aws.String("observe"),
FilterPattern: aws.String(""),
DestinationArn: aws.String("arn:aws:lambda:us-west-2:123456789012:function:example"),
},
},
},
{
/*
Do nothing if we exceed the two subscription filter limit
*/
Configure: types.SubscriptionFilter{
FilterName: aws.String("observe"),
DestinationArn: aws.String("arn:aws:lambda:us-west-2:123456789012:function:example"),
},
Existing: []types.SubscriptionFilter{},
ExpectedActions: []any{
&cloudwatchlogs.PutSubscriptionFilterInput{
FilterName: aws.String("observe"),
FilterPattern: aws.String(""),
DestinationArn: aws.String("arn:aws:lambda:us-west-2:123456789012:function:example"),
},
},
},
{
Configure: types.SubscriptionFilter{
FilterName: aws.String("observe"),
DestinationArn: aws.String("arn:aws:lambda:us-west-2:123456789012:function:example"),
},
Existing: []types.SubscriptionFilter{
{
FilterName: aws.String("foo"),
},
{
FilterName: aws.String("bar"),
},
},
// no expected actions
},
}

for i, tt := range testcases {
tt := tt
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
s, err := subscriber.New(
&subscriber.Config{
CloudWatchLogsClient: &handlertest.CloudWatchLogsClient{},
Queue: &MockQueue{},
FilterName: aws.ToString(tt.Configure.FilterName),
DestinationARN: aws.ToString(tt.Configure.DestinationArn),
RoleARN: aws.ToString(tt.Configure.RoleArn),
})
if err != nil {
t.Fatal(err)
}

output := s.SubscriptionFilterDiff(tt.Existing)

opts := cmpopts.IgnoreUnexported(
cloudwatchlogs.PutSubscriptionFilterInput{},
cloudwatchlogs.DeleteSubscriptionFilterInput{},
)
if diff := cmp.Diff(output, tt.ExpectedActions, opts); diff != "" {
t.Error(diff)
}
})
}
}
2 changes: 1 addition & 1 deletion integration/scripts/check_subscriber
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ check_result() {
fi
}

echo '{"subscribe": {}}' > ${TMPFILE}
echo '{"subscribe": {"logGroups": [{"logGroupName": "does_not_exist"}]}}' > ${TMPFILE}
RESULT=$(aws lambda invoke \
--function-name ${FUNCTION_NAME} \
--payload fileb://${TMPFILE} ${TMPFILE} \
Expand Down

0 comments on commit c68faae

Please sign in to comment.