diff --git a/message.go b/message.go
index c320006ba..0539e6038 100644
--- a/message.go
+++ b/message.go
@@ -49,6 +49,24 @@ func (msg *Message) size() int32 {
 	return 4 + 1 + 1 + sizeofBytes(msg.Key) + sizeofBytes(msg.Value) + timestampSize
 }
 
+// headerSize returns varArrayLen evaluated over msg.Headers, where each header
+// contributes varStringLen of its key plus varBytesLen of its value.
+// NOTE(review): presumably this is the encoded size of the headers in bytes;
+// confirm against the varArrayLen/varStringLen/varBytesLen helpers.
+func (msg *Message) headerSize() int {
+	return varArrayLen(len(msg.Headers), func(i int) int {
+		h := &msg.Headers[i]
+		return varStringLen(h.Key) + varBytesLen(h.Value)
+	})
+}
+
+// totalSize returns headerSize plus size. size does not account for
+// msg.Headers, so the writer's batch byte accounting calls totalSize to have
+// header data counted against the limit as well.
+func (msg *Message) totalSize() int32 {
+	return int32(msg.headerSize()) + msg.size()
+}
+
 type message struct {
 	CRC        int32
 	MagicByte  int8
diff --git a/writer.go b/writer.go
index f5d6fc2c5..1264ae718 100644
--- a/writer.go
+++ b/writer.go
@@ -621,7 +621,7 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error {
 	batchBytes := w.batchBytes()
 
 	for i := range msgs {
-		n := int64(msgs[i].size())
+		n := int64(msgs[i].totalSize())
 		if n > batchBytes {
 			// This error is left for backward compatibility with historical
 			// behavior, but it can yield O(N^2) behaviors. The expectations
@@ -1216,7 +1216,7 @@ func newWriteBatch(now time.Time, timeout time.Duration) *writeBatch {
 }
 
 func (b *writeBatch) add(msg Message, maxSize int, maxBytes int64) bool {
-	bytes := int64(msg.size())
+	bytes := int64(msg.totalSize())
 
 	if b.size > 0 && (b.bytes+bytes) > maxBytes {
 		return false
diff --git a/writer_test.go b/writer_test.go
index 70a44ca8d..b1f6246c8 100644
--- a/writer_test.go
+++ b/writer_test.go
@@ -7,6 +7,7 @@ import (
 	"io"
 	"math"
 	"strconv"
+	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -134,6 +135,10 @@ func TestWriter(t *testing.T) {
 			scenario: "writing messages with a small batch byte size",
 			function: testWriterSmallBatchBytes,
 		},
+		{
+			scenario: "writing messages with headers",
+			function: testWriterBatchBytesHeaders,
+		},
 		{
 			scenario: "setting a non default balancer on the writer",
 			function: testWriterSetsRightBalancer,
@@ -449,7 +454,7 @@ func testWriterBatchBytes(t *testing.T) {
 
 	w := newTestWriter(WriterConfig{
 		Topic:        topic,
-		BatchBytes:   48,
+		BatchBytes:   50,
 		BatchTimeout: math.MaxInt32 * time.Second,
 		Balancer:     &RoundRobin{},
 	})
@@ -458,10 +463,10 @@ func testWriterBatchBytes(t *testing.T) {
 	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 	defer cancel()
 	if err := w.WriteMessages(ctx, []Message{
-		{Value: []byte("M0")}, // 24 Bytes
-		{Value: []byte("M1")}, // 24 Bytes
-		{Value: []byte("M2")}, // 24 Bytes
-		{Value: []byte("M3")}, // 24 Bytes
+		{Value: []byte("M0")}, // 25 Bytes
+		{Value: []byte("M1")}, // 25 Bytes
+		{Value: []byte("M2")}, // 25 Bytes
+		{Value: []byte("M3")}, // 25 Bytes
 	}...); err != nil {
 		t.Error(err)
 		return
@@ -592,6 +597,73 @@ func testWriterSmallBatchBytes(t *testing.T) {
 	}
 }
 
+// testWriterBatchBytesHeaders writes two messages, each carrying one header,
+// through a writer configured with BatchBytes: 100. It expects the writer
+// stats to report two writes and both messages to be readable from the
+// partition afterwards.
+func testWriterBatchBytesHeaders(t *testing.T) {
+	topic := makeTopic()
+	createTopic(t, topic, 1)
+	defer deleteTopic(t, topic)
+
+	offset, err := readOffset(topic, 0)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	w := newTestWriter(WriterConfig{
+		Topic:        topic,
+		BatchBytes:   100,
+		BatchTimeout: 50 * time.Millisecond,
+		Balancer:     &RoundRobin{},
+	})
+	defer w.Close()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	if err := w.WriteMessages(ctx, []Message{
+		{
+			Value: []byte("Hello World 1"),
+			Headers: []Header{
+				{Key: "User-Agent", Value: []byte("abc/xyz")},
+			},
+		},
+		{
+			Value: []byte("Hello World 2"),
+			Headers: []Header{
+				{Key: "User-Agent", Value: []byte("abc/xyz")},
+			},
+		},
+	}...); err != nil {
+		t.Error(err)
+		return
+	}
+	ws := w.Stats()
+	// NOTE(review): presumably the header bytes push the two messages past
+	// BatchBytes, so each one is expected to go out in its own write.
+	if ws.Writes != 2 {
+		t.Error("didn't batch messages; Writes: ", ws.Writes)
+		return
+	}
+	msgs, err := readPartition(topic, 0, offset)
+	if err != nil {
+		t.Error("error reading partition", err)
+		return
+	}
+
+	if len(msgs) != 2 {
+		t.Error("bad messages in partition", msgs)
+		return
+	}
+
+	for _, m := range msgs {
+		if strings.HasPrefix(string(m.Value), "Hello World") {
+			continue
+		}
+		t.Error("bad messages in partition", msgs)
+	}
+}
+
 func testWriterMultipleTopics(t *testing.T) {
 	topic1 := makeTopic()
 	createTopic(t, topic1, 1)