diff --git a/services/batch.go b/services/batch.go index 52af0b1..0eaf4a2 100644 --- a/services/batch.go +++ b/services/batch.go @@ -22,6 +22,8 @@ type BatchingService struct { batcher *batch.Executor[*models.AnchorRequestMessage, *uuid.UUID] metricService models.MetricService logger models.Logger + flushTicker *time.Ticker + flushInterval time.Duration } func NewBatchingService(ctx context.Context, batchPublisher models.QueuePublisher, batchStore models.KeyValueRepository, metricService models.MetricService, logger models.Logger) *BatchingService { @@ -37,11 +39,18 @@ func NewBatchingService(ctx context.Context, batchPublisher models.QueuePublishe anchorBatchLinger = parsedAnchorBatchLinger } } + anchorBatchFlushInterval := time.Hour // Default to 1 hour + if configAnchorBatchFlushInterval, found := os.LookupEnv("ANCHOR_BATCH_FLUSH_INTERVAL"); found { + if parsedAnchorBatchFlushInterval, err := time.ParseDuration(configAnchorBatchFlushInterval); err == nil { + anchorBatchFlushInterval = parsedAnchorBatchFlushInterval + } + } batchingService := BatchingService{ batchPublisher: batchPublisher, batchStore: batchStore, metricService: metricService, logger: logger, + flushInterval: anchorBatchFlushInterval, } beOpts := batch.Opts{MaxSize: anchorBatchSize, MaxLinger: anchorBatchLinger} batchingService.batcher = batch.New[*models.AnchorRequestMessage, *uuid.UUID]( @@ -51,10 +60,11 @@ func NewBatchingService(ctx context.Context, batchPublisher models.QueuePublishe return batchingService.batch(ctx, anchorReqs) }, ) + batchingService.startFlushTicker(ctx) return &batchingService } -func (b BatchingService) Batch(ctx context.Context, msgBody string) error { +func (b *BatchingService) Batch(ctx context.Context, msgBody string) error { anchorReq := new(models.AnchorRequestMessage) if err := json.Unmarshal([]byte(msgBody), anchorReq); err != nil { return err @@ -70,7 +80,7 @@ func (b BatchingService) Batch(ctx context.Context, msgBody string) error { } } -func (b BatchingService) batch(ctx context.Context, anchorReqs []*models.AnchorRequestMessage) ([]results.Result[*uuid.UUID], error) { +func (b *BatchingService) batch(ctx context.Context, anchorReqs []*models.AnchorRequestMessage) ([]results.Result[*uuid.UUID], error) { batchSize := len(anchorReqs) anchorReqBatch := models.AnchorBatchMessage{ Id: uuid.New(), @@ -108,9 +118,43 @@ func (b BatchingService) batch(ctx context.Context, anchorReqs []*models.AnchorR return batchResults, nil } -func (b BatchingService) Flush() { - // Flush the current batch however far along it's gotten in size or expiration. The caller needs to ensure that no - // more messages are sent to this service for processing once this function is called. Receiving more messages will - // cause workers to wait till the end of the batch expiration if there aren't enough messages to fill the batch. +func (b *BatchingService) startFlushTicker(ctx context.Context) { + // Calculate the duration until the next tick + now := time.Now().UTC() + nextTick := now.Truncate(b.flushInterval).Add(b.flushInterval) + tillNextTick := nextTick.Sub(now) + + // Wait for the initial duration before starting the ticker + time.AfterFunc(tillNextTick, func() { + b.Flush() + b.flushTicker = time.NewTicker(b.flushInterval) + go b.flushLoop(ctx) + }) +} + +func (b *BatchingService) flushLoop(ctx context.Context) { + for { + select { + case <-b.flushTicker.C: + b.Flush() + case <-ctx.Done(): + b.flushTicker.Stop() + return + } + } +} + +// Flush forces the batching service to flush any pending requests, however far along it's gotten in size or expiration. +// +// We're using this in two ways: +// 1. The top of the hour UTC: We want to flush any pending requests at the top of the hour so that C1 nodes can send +// in their Merkle Tree roots for anchoring right before the top of the hour. This ensures that the gap between the +// anchor request being sent and it being anchored is predictable and small. Without this logic, the gap could be +// upto 1 hour, i.e. till the Scheduler builds a new batch and sends it to the CAS. +// 2. At process shutdown: We want to flush any pending requests when the process is shutting down so that we don't +// lose any in-flight requests. The caller needs to ensure that no more messages are sent to this service for +// processing once this function is called. Receiving more messages will cause queue workers to wait till the end of +// the batch expiration if there aren't enough messages to fill the batch. +func (b *BatchingService) Flush() { b.batcher.Flush() } diff --git a/services/batch_test.go b/services/batch_test.go index ed49d3e..fdef41f 100644 --- a/services/batch_test.go +++ b/services/batch_test.go @@ -5,10 +5,12 @@ import ( "encoding/json" "sync" "testing" + "time" + + "github.com/google/uuid" "github.com/ceramicnetwork/go-cas/common/loggers" "github.com/ceramicnetwork/go-cas/models" - "github.com/google/uuid" ) func TestBatch(t *testing.T) { @@ -128,3 +130,82 @@ func TestBatch(t *testing.T) { }) } } + +func TestHourlyBatch(t *testing.T) { + t.Setenv("ANCHOR_BATCH_SIZE", "100") // High value so we know the batch is not flushed due to size + t.Setenv("ANCHOR_BATCH_LINGER", "1h") // High value so we know the batch is not flushed due to linger + t.Setenv("ANCHOR_BATCH_FLUSH_INTERVAL", "1s") + + requests := []*models.AnchorRequestMessage{ + {Id: uuid.New()}, + {Id: uuid.New()}, + {Id: uuid.New()}, + {Id: uuid.New()}, + {Id: uuid.New()}, + {Id: uuid.New()}, + } + numRequests := len(requests) + + encodedRequests := make([]string, len(requests)) + for i, request := range requests { + if requestMessage, err := json.Marshal(request); err != nil { + t.Fatalf("Failed to encode request %v", request) + } else { + encodedRequests[i] = string(requestMessage) + } + } + + logger := loggers.NewTestLogger() + testCtx := context.Background() + t.Run("flush batch at tick", func(t *testing.T) { + metricService := &MockMetricService{} + s3BatchStore := &MockS3BatchStore{} + publisher := &MockPublisher{messages: make(chan any, numRequests)} + + ctx, cancel := context.WithCancel(testCtx) + batchingServices := NewBatchingService(testCtx, publisher, s3BatchStore, metricService, logger) + + var wg sync.WaitGroup + for i := 1; i <= len(encodedRequests); i++ { + wg.Add(1) + go func() { + defer wg.Done() + + if err := batchingServices.Batch(ctx, encodedRequests[i-1]); err != nil { + t.Errorf("Unexpected error received %v", err) + } + }() + + // The flush interval is 1s so sleep for 2 seconds after every 2 requests to ensure the batch is flushed and + // contains 2 requests. + if i%2 == 0 { + <-time.After(2 * time.Second) + } + } + wg.Wait() + + // Each batch should have 2 requests in it, so the number of batches should be half the number of requests. + receivedMessages := waitForMesssages(publisher.messages, numRequests/2) + cancel() + + receivedBatches := make([]models.AnchorBatchMessage, len(receivedMessages)) + for i, message := range receivedMessages { + if batch, ok := message.(models.AnchorBatchMessage); !ok { + t.Fatalf("Received invalid anchor batch message: %v", message) + } else { + receivedBatches[i] = batch + } + } + + // Make sure each batch has 2 requests in it + for i := 0; i < len(receivedBatches); i++ { + if s3BatchStore.getBatchSize(receivedBatches[i].Id.String()) != 2 { + t.Errorf("Expected %v requests in batch %v. Contained %v requests", 2, i+1, len(receivedBatches[i].Ids)) + } + } + + Assert(t, numRequests, metricService.counts[models.MetricName_BatchIngressRequest], "Incorrect batch ingress request count") + Assert(t, numRequests/2, metricService.counts[models.MetricName_BatchCreated], "Incorrect created batch count") + Assert(t, numRequests/2, metricService.counts[models.MetricName_BatchStored], "Incorrect stored batch count") + }) +}