Skip to content

Commit

Permalink
pubsub/awssnssqs: Use PublishBatch for sending messages in SNS (#3288)
Browse files Browse the repository at this point in the history
  • Loading branch information
vangent authored Jul 31, 2023
1 parent 2693ff1 commit e6e3a0e
Show file tree
Hide file tree
Showing 92 changed files with 5,755 additions and 6,889 deletions.
237 changes: 162 additions & 75 deletions pubsub/awssnssqs/awssnssqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@
// - Topic: (V1) *sns.SNS for OpenSNSTopic, *sqs.SQS for OpenSQSTopic; (V2) *snsv2.Client for OpenSNSTopicV2, *sqsv2.Client for OpenSQSTopicV2
// - Subscription: (V1) *sqs.SQS; (V2) *sqsv2.Client
// - Message: (V1) *sqs.Message; (V2) sqstypesv2.Message
// - Message.BeforeSend: (V1) *sns.PublishInput for OpenSNSTopic, *sqs.SendMessageBatchRequestEntry or *sqs.SendMessageInput(deprecated) for OpenSQSTopic; (V2) *snsv2.PublishInput for OpenSNSTopicV2, *sqstypesv2.SendMessageBatchRequestEntry for OpenSQSTopicV2
// - Message.AfterSend: (V1) *sns.PublishOutput for OpenSNSTopic, *sqs.SendMessageBatchResultEntry for OpenSQSTopic; (V2) *snsv2.PublishOutput for OpenSNSTopicV2, sqstypesv2.SendMessageBatchResultEntry for OpenSQSTopicV2
// - Message.BeforeSend: (V1) *sns.PublishBatchRequestEntry or *sns.PublishInput (deprecated) for OpenSNSTopic, *sqs.SendMessageBatchRequestEntry or *sqs.SendMessageInput (deprecated) for OpenSQSTopic; (V2) *snsv2.PublishBatchRequestEntry or *snsv2.PublishInput (deprecated) for OpenSNSTopicV2, *sqstypesv2.SendMessageBatchRequestEntry for OpenSQSTopicV2
// - Message.AfterSend: (V1) sns.PublishBatchResultEntry or *sns.PublishOutput (deprecated) for OpenSNSTopic, *sqs.SendMessageBatchResultEntry for OpenSQSTopic; (V2) snstypesv2.PublishBatchResultEntry or *snsv2.PublishOutput (deprecated) for OpenSNSTopicV2, sqstypesv2.SendMessageBatchResultEntry for OpenSQSTopicV2
// - Error: (V1) awserr.Error, (V2) any error type returned by the service, notably smithy.APIError
package awssnssqs // import "gocloud.dev/pubsub/awssnssqs"

Expand Down Expand Up @@ -106,7 +106,7 @@ const (
)

var sendBatcherOptsSNS = &batcher.Options{
MaxBatchSize: 1, // SNS SendBatch only supports one message at a time
MaxBatchSize: 10, // SNS SendBatch supports 10 message at a time
MaxHandlers: 100, // max concurrency for sends
}

Expand Down Expand Up @@ -456,112 +456,199 @@ func maybeEncodeBody(body []byte, opt BodyBase64Encoding) (string, bool) {

// SendBatch implements driver.Topic.SendBatch.
func (t *snsTopic) SendBatch(ctx context.Context, dms []*driver.Message) error {
if len(dms) != 1 {
panic("snsTopic.SendBatch should only get one message at a time")
}
dm := dms[0]

if t.useV2 {
attrs := map[string]snstypesv2.MessageAttributeValue{}
req := &snsv2.PublishBatchInput{
TopicArn: &t.arn,
}
for _, dm := range dms {
attrs := map[string]snstypesv2.MessageAttributeValue{}
for k, v := range encodeMetadata(dm.Metadata) {
attrs[k] = snstypesv2.MessageAttributeValue{
DataType: stringDataType,
StringValue: aws.String(v),
}
}
body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding)
if didEncode {
attrs[base64EncodedKey] = snstypesv2.MessageAttributeValue{
DataType: stringDataType,
StringValue: aws.String("true"),
}
}
if len(attrs) == 0 {
attrs = nil
}
entry := &snstypesv2.PublishBatchRequestEntry{
Id: aws.String(strconv.Itoa(len(req.PublishBatchRequestEntries))),
MessageAttributes: attrs,
Message: aws.String(body),
}
if dm.BeforeSend != nil {
// A previous revision used the non-batch API PublishInput, which takes
// a *snsv2.PublishInput. For backwards compatibility for As, continue
// to support that type. If it is requested, create a PublishInput
// with the fields from PublishBatchRequestEntry that were set, and
// then copy all of the matching fields back after calling dm.BeforeSend.
var pi *snsv2.PublishInput
asFunc := func(i interface{}) bool {
if p, ok := i.(**snsv2.PublishInput); ok {
pi = &snsv2.PublishInput{
// Id does not exist on PublishInput.
MessageAttributes: entry.MessageAttributes,
Message: entry.Message,
}
*p = pi
return true
}
if p, ok := i.(**snstypesv2.PublishBatchRequestEntry); ok {
*p = entry
return true
}
return false
}
if err := dm.BeforeSend(asFunc); err != nil {
return err
}
if pi != nil {
// Copy all of the fields that may have been modified back to the entry.
entry.MessageAttributes = pi.MessageAttributes
entry.Message = pi.Message
entry.MessageDeduplicationId = pi.MessageDeduplicationId
entry.MessageGroupId = pi.MessageGroupId
entry.MessageStructure = pi.MessageStructure
entry.Subject = pi.Subject
}
}
req.PublishBatchRequestEntries = append(req.PublishBatchRequestEntries, *entry)
}
resp, err := t.clientV2.PublishBatch(ctx, req)
if err != nil {
return err
}
if numFailed := len(resp.Failed); numFailed > 0 {
first := resp.Failed[0]
return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sns.PublishBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil)
}
if len(resp.Successful) == len(dms) {
for n, dm := range dms {
if dm.AfterSend != nil {
asFunc := func(i interface{}) bool {
if p, ok := i.(*snstypesv2.PublishBatchResultEntry); ok {
*p = resp.Successful[n]
return true
}
if p, ok := i.(**snsv2.PublishOutput); ok {
// For backwards compability.
*p = &snsv2.PublishOutput{
MessageId: resp.Successful[n].MessageId,
SequenceNumber: resp.Successful[n].SequenceNumber,
}
return true
}
return false
}
if err := dm.AfterSend(asFunc); err != nil {
return err
}
}
}
}
return nil
}
req := &sns.PublishBatchInput{
TopicArn: &t.arn,
}
for _, dm := range dms {
attrs := map[string]*sns.MessageAttributeValue{}
for k, v := range encodeMetadata(dm.Metadata) {
attrs[k] = snstypesv2.MessageAttributeValue{
attrs[k] = &sns.MessageAttributeValue{
DataType: stringDataType,
StringValue: aws.String(v),
}
}
body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding)
if didEncode {
attrs[base64EncodedKey] = snstypesv2.MessageAttributeValue{
attrs[base64EncodedKey] = &sns.MessageAttributeValue{
DataType: stringDataType,
StringValue: aws.String("true"),
}
}
if len(attrs) == 0 {
attrs = nil
}
input := &snsv2.PublishInput{
Message: aws.String(body),
entry := &sns.PublishBatchRequestEntry{
Id: aws.String(strconv.Itoa(len(req.PublishBatchRequestEntries))),
MessageAttributes: attrs,
TopicArn: &t.arn,
Message: aws.String(body),
}
if dm.BeforeSend != nil {
// A previous revision used the non-batch API PublishInput, which takes
// a *snsv2.PublishInput. For backwards compatibility for As, continue
// to support that type. If it is requested, create a PublishInput
// with the fields from PublishBatchRequestEntry that were set, and
// then copy all of the matching fields back after calling dm.BeforeSend.
var pi *sns.PublishInput
asFunc := func(i interface{}) bool {
if p, ok := i.(**snsv2.PublishInput); ok {
*p = input
if p, ok := i.(**sns.PublishInput); ok {
pi = &sns.PublishInput{
// Id does not exist on PublishInput.
MessageAttributes: entry.MessageAttributes,
Message: entry.Message,
}
*p = pi
return true
}
return false
}
if err := dm.BeforeSend(asFunc); err != nil {
return err
}
}
po, err := t.clientV2.Publish(ctx, input)
if err != nil {
return err
}
if dm.AfterSend != nil {
asFunc := func(i interface{}) bool {
if p, ok := i.(**snsv2.PublishOutput); ok {
*p = po
if p, ok := i.(**sns.PublishBatchRequestEntry); ok {
*p = entry
return true
}
return false
}
if err := dm.AfterSend(asFunc); err != nil {
if err := dm.BeforeSend(asFunc); err != nil {
return err
}
}
return nil
}
attrs := map[string]*sns.MessageAttributeValue{}
for k, v := range encodeMetadata(dm.Metadata) {
attrs[k] = &sns.MessageAttributeValue{
DataType: stringDataType,
StringValue: aws.String(v),
}
}
body, didEncode := maybeEncodeBody(dm.Body, t.opts.BodyBase64Encoding)
if didEncode {
attrs[base64EncodedKey] = &sns.MessageAttributeValue{
DataType: stringDataType,
StringValue: aws.String("true"),
}
}
if len(attrs) == 0 {
attrs = nil
}
input := &sns.PublishInput{
Message: aws.String(body),
MessageAttributes: attrs,
TopicArn: &t.arn,
}
if dm.BeforeSend != nil {
asFunc := func(i interface{}) bool {
if p, ok := i.(**sns.PublishInput); ok {
*p = input
return true
if pi != nil {
// Copy all of the fields that may have been modified back to the entry.
entry.MessageAttributes = pi.MessageAttributes
entry.Message = pi.Message
entry.MessageDeduplicationId = pi.MessageDeduplicationId
entry.MessageGroupId = pi.MessageGroupId
entry.MessageStructure = pi.MessageStructure
entry.Subject = pi.Subject
}
return false
}
if err := dm.BeforeSend(asFunc); err != nil {
return err
}
req.PublishBatchRequestEntries = append(req.PublishBatchRequestEntries, entry)
}
po, err := t.client.PublishWithContext(ctx, input)
resp, err := t.client.PublishBatchWithContext(ctx, req)
if err != nil {
return err
}
if dm.AfterSend != nil {
asFunc := func(i interface{}) bool {
if p, ok := i.(**sns.PublishOutput); ok {
*p = po
return true
if numFailed := len(resp.Failed); numFailed > 0 {
first := resp.Failed[0]
return awserr.New(aws.StringValue(first.Code), fmt.Sprintf("sns.PublishBatch failed for %d message(s): %s", numFailed, aws.StringValue(first.Message)), nil)
}
if len(resp.Successful) == len(dms) {
for n, dm := range dms {
if dm.AfterSend != nil {
asFunc := func(i interface{}) bool {
if p, ok := i.(*sns.PublishBatchResultEntry); ok {
*p = *resp.Successful[n]
return true
}
if p, ok := i.(**sns.PublishOutput); ok {
// For backwards compability.
*p = &sns.PublishOutput{
MessageId: resp.Successful[n].MessageId,
SequenceNumber: resp.Successful[n].SequenceNumber,
}
return true
}
return false
}
if err := dm.AfterSend(asFunc); err != nil {
return err
}
}
return false
}
if err := dm.AfterSend(asFunc); err != nil {
return err
}
}
return nil
Expand Down
17 changes: 17 additions & 0 deletions pubsub/awssnssqs/awssnssqs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"testing"

snsv2 "github.com/aws/aws-sdk-go-v2/service/sns"
snstypesv2 "github.com/aws/aws-sdk-go-v2/service/sns/types"
sqsv2 "github.com/aws/aws-sdk-go-v2/service/sqs"
sqstypesv2 "github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -532,11 +533,19 @@ func (t awsAsTest) BeforeSend(as func(interface{}) bool) error {
if !as(&pub) {
return fmt.Errorf("cast failed for %T", &pub)
}
var entry *snstypesv2.PublishBatchRequestEntry
if !as(&entry) {
return fmt.Errorf("cast failed for %T", &entry)
}
} else {
var pub *sns.PublishInput
if !as(&pub) {
return fmt.Errorf("cast failed for %T", &pub)
}
var entry *sns.PublishBatchRequestEntry
if !as(&entry) {
return fmt.Errorf("cast failed for %T", &entry)
}
}
case topicKindSQS:
if t.useV2 {
Expand Down Expand Up @@ -568,11 +577,19 @@ func (t awsAsTest) AfterSend(as func(interface{}) bool) error {
if !as(&pub) {
return fmt.Errorf("cast failed for %T", &pub)
}
var entry snstypesv2.PublishBatchResultEntry
if !as(&entry) {
return fmt.Errorf("cast failed for %T", &entry)
}
} else {
var pub *sns.PublishOutput
if !as(&pub) {
return fmt.Errorf("cast failed for %T", &pub)
}
var entry sns.PublishBatchResultEntry
if !as(&entry) {
return fmt.Errorf("cast failed for %T", &entry)
}
}
case topicKindSQS:
if t.useV2 {
Expand Down
Loading

0 comments on commit e6e3a0e

Please sign in to comment.