Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Determine the records type based on the magic number not API version #990

Merged
merged 1 commit into from
Nov 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,11 +519,10 @@ func (child *partitionConsumer) parseMessages(msgSet *MessageSet) ([]*ConsumerMe
return messages, nil
}

func (child *partitionConsumer) parseRecords(block *FetchResponseBlock) ([]*ConsumerMessage, error) {
func (child *partitionConsumer) parseRecords(batch *RecordBatch) ([]*ConsumerMessage, error) {
var messages []*ConsumerMessage
var incomplete bool
prelude := true
batch := block.Records.recordBatch

for _, rec := range batch.Records {
offset := batch.FirstOffset + rec.OffsetDelta
Expand Down Expand Up @@ -599,10 +598,10 @@ func (child *partitionConsumer) parseResponse(response *FetchResponse) ([]*Consu
return nil, err
}

if response.Version < 4 {
if block.Records.recordsType == legacyRecords {
return child.parseMessages(block.Records.msgSet)
}
return child.parseRecords(block)
return child.parseRecords(block.Records.recordBatch)
}

// brokerConsumer
Expand Down
43 changes: 43 additions & 0 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,49 @@ func TestConsumerExtraOffsets(t *testing.T) {
}
}

func TestConsumeMessageWithNewerFetchAPIVersion(t *testing.T) {
// Given
fetchResponse1 := &FetchResponse{Version: 4}
fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 1)
fetchResponse1.AddMessage("my_topic", 0, nil, testMsg, 2)

cfg := NewConfig()
cfg.Version = V0_11_0_0

broker0 := NewMockBroker(t, 0)
fetchResponse2 := &FetchResponse{}
fetchResponse2.Version = 4
fetchResponse2.AddError("my_topic", 0, ErrNoError)
broker0.SetHandlerByMap(map[string]MockResponse{
"MetadataRequest": NewMockMetadataResponse(t).
SetBroker(broker0.Addr(), broker0.BrokerID()).
SetLeader("my_topic", 0, broker0.BrokerID()),
"OffsetRequest": NewMockOffsetResponse(t).
SetVersion(1).
SetOffset("my_topic", 0, OffsetNewest, 1234).
SetOffset("my_topic", 0, OffsetOldest, 0),
"FetchRequest": NewMockSequence(fetchResponse1, fetchResponse2),
})

master, err := NewConsumer([]string{broker0.Addr()}, cfg)
if err != nil {
t.Fatal(err)
}

// When
consumer, err := master.ConsumePartition("my_topic", 0, 1)
if err != nil {
t.Fatal(err)
}

assertMessageOffset(t, <-consumer.Messages(), 1)
assertMessageOffset(t, <-consumer.Messages(), 2)

safeClose(t, consumer)
safeClose(t, master)
broker0.Close()
}

// It is fine if offsets of fetched messages are not sequential (although
// strictly increasing!).
func TestConsumerNonSequentialOffsets(t *testing.T) {
Expand Down
9 changes: 1 addition & 8 deletions fetch_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,18 +79,11 @@ func (b *FetchResponseBlock) decode(pd packetDecoder, version int16) (err error)
if err != nil {
return err
}
var records Records
if version >= 4 {
records = newDefaultRecords(nil)
} else {
records = newLegacyRecords(nil)
}
if recordsSize > 0 {
if err = records.decode(recordsDecoder); err != nil {
if err = b.Records.decode(recordsDecoder); err != nil {
return err
}
}
b.Records = records

return nil
}
Expand Down
77 changes: 75 additions & 2 deletions fetch_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,28 @@ var (
0x06, 0x05, 0x06, 0x07,
0x02,
0x06, 0x08, 0x09, 0x0A,
0x04, 0x0B, 0x0C,
}
0x04, 0x0B, 0x0C}

oneMessageFetchResponseV4 = []byte{
0x00, 0x00, 0x00, 0x00, // ThrottleTime
0x00, 0x00, 0x00, 0x01, // Number of Topics
0x00, 0x05, 't', 'o', 'p', 'i', 'c', // Topic
0x00, 0x00, 0x00, 0x01, // Number of Partitions
0x00, 0x00, 0x00, 0x05, // Partition
0x00, 0x01, // Error
0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // High Watermark Offset
0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x10, 0x10, // Last Stable Offset
0x00, 0x00, 0x00, 0x00, // Number of Aborted Transactions
0x00, 0x00, 0x00, 0x1C,
// messageSet
0x00, 0x00, 0x00, 0x00, 0x00, 0x55, 0x00, 0x00,
0x00, 0x00, 0x00, 0x10,
// message
0x23, 0x96, 0x4a, 0xf7, // CRC
0x00,
0x00,
0xFF, 0xFF, 0xFF, 0xFF,
0x00, 0x00, 0x00, 0x02, 0x00, 0xEE}
)

func TestEmptyFetchResponse(t *testing.T) {
Expand Down Expand Up @@ -173,3 +193,56 @@ func TestOneRecordFetchResponse(t *testing.T) {
t.Error("Decoding produced incorrect record value.")
}
}

func TestOneMessageFetchResponseV4(t *testing.T) {
response := FetchResponse{}
testVersionDecodable(t, "one message v4", &response, oneMessageFetchResponseV4, 4)

if len(response.Blocks) != 1 {
t.Fatal("Decoding produced incorrect number of topic blocks.")
}

if len(response.Blocks["topic"]) != 1 {
t.Fatal("Decoding produced incorrect number of partition blocks for topic.")
}

block := response.GetBlock("topic", 5)
if block == nil {
t.Fatal("GetBlock didn't return block.")
}
if block.Err != ErrOffsetOutOfRange {
t.Error("Decoding didn't produce correct error code.")
}
if block.HighWaterMarkOffset != 0x10101010 {
t.Error("Decoding didn't produce correct high water mark offset.")
}
partial, err := block.Records.isPartial()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if partial {
t.Error("Decoding detected a partial trailing record where there wasn't one.")
}

n, err := block.Records.numRecords()
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if n != 1 {
t.Fatal("Decoding produced incorrect number of records.")
}
msgBlock := block.Records.msgSet.Messages[0]
if msgBlock.Offset != 0x550000 {
t.Error("Decoding produced incorrect message offset.")
}
msg := msgBlock.Msg
if msg.Codec != CompressionNone {
t.Error("Decoding produced incorrect message compression.")
}
if msg.Key != nil {
t.Error("Decoding produced message key where there was none.")
}
if !bytes.Equal(msg.Value, []byte{0x00, 0xEE}) {
t.Error("Decoding produced incorrect message value.")
}
}
1 change: 1 addition & 0 deletions packet_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type packetDecoder interface {
// Subsets
remaining() int
getSubset(length int) (packetDecoder, error)
peek(offset, length int) (packetDecoder, error) // similar to getSubset, but it doesn't advance the offset

// Stacks, see PushDecoder
push(in pushDecoder) error
Expand Down
5 changes: 0 additions & 5 deletions produce_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,6 @@ func (r *ProduceRequest) decode(pd packetDecoder, version int16) error {
return err
}
var records Records
if version >= 3 {
records = newDefaultRecords(nil)
} else {
records = newLegacyRecords(nil)
}
if err := records.decode(recordsDecoder); err != nil {
return err
}
Expand Down
8 changes: 8 additions & 0 deletions real_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ func (rd *realDecoder) getRawBytes(length int) ([]byte, error) {
return rd.raw[start:rd.off], nil
}

func (rd *realDecoder) peek(offset, length int) (packetDecoder, error) {
if rd.remaining() < offset+length {
return nil, ErrInsufficientData
}
off := rd.off + offset
return &realDecoder{raw: rd.raw[off : off+length]}, nil
}

// stacks

func (rd *realDecoder) push(in pushDecoder) error {
Expand Down
73 changes: 72 additions & 1 deletion records.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ package sarama
import "fmt"

const (
legacyRecords = iota
unknownRecords = iota
legacyRecords
defaultRecords

magicOffset = 16
magicLength = 1
)

// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
Expand All @@ -22,7 +26,30 @@ func newDefaultRecords(batch *RecordBatch) Records {
return Records{recordsType: defaultRecords, recordBatch: batch}
}

// setTypeFromFields sets type of Records depending on which of msgSet or recordBatch is not nil.
// The first return value indicates whether both fields are nil (and the type is not set).
// If both fields are not nil, it returns an error.
func (r *Records) setTypeFromFields() (bool, error) {
if r.msgSet == nil && r.recordBatch == nil {
return true, nil
}
if r.msgSet != nil && r.recordBatch != nil {
return false, fmt.Errorf("both msgSet and recordBatch are set, but record type is unknown")
}
r.recordsType = defaultRecords
if r.msgSet != nil {
r.recordsType = legacyRecords
}
return false, nil
}

func (r *Records) encode(pe packetEncoder) error {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return err
}
}

switch r.recordsType {
case legacyRecords:
if r.msgSet == nil {
Expand All @@ -38,7 +65,31 @@ func (r *Records) encode(pe packetEncoder) error {
return fmt.Errorf("unknown records type: %v", r.recordsType)
}

func (r *Records) setTypeFromMagic(pd packetDecoder) error {
dec, err := pd.peek(magicOffset, magicLength)
if err != nil {
return err
}

magic, err := dec.getInt8()
if err != nil {
return err
}

r.recordsType = defaultRecords
if magic < 2 {
r.recordsType = legacyRecords
}
return nil
}

func (r *Records) decode(pd packetDecoder) error {
if r.recordsType == unknownRecords {
if err := r.setTypeFromMagic(pd); err != nil {
return nil
}
}

switch r.recordsType {
case legacyRecords:
r.msgSet = &MessageSet{}
Expand All @@ -51,6 +102,12 @@ func (r *Records) decode(pd packetDecoder) error {
}

func (r *Records) numRecords() (int, error) {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return 0, err
}
}

switch r.recordsType {
case legacyRecords:
if r.msgSet == nil {
Expand All @@ -67,7 +124,15 @@ func (r *Records) numRecords() (int, error) {
}

func (r *Records) isPartial() (bool, error) {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return false, err
}
}

switch r.recordsType {
case unknownRecords:
return false, nil
case legacyRecords:
if r.msgSet == nil {
return false, nil
Expand All @@ -83,6 +148,12 @@ func (r *Records) isPartial() (bool, error) {
}

func (r *Records) isControl() (bool, error) {
if r.recordsType == unknownRecords {
if empty, err := r.setTypeFromFields(); err != nil || empty {
return false, err
}
}

switch r.recordsType {
case legacyRecords:
return false, nil
Expand Down
10 changes: 8 additions & 2 deletions records_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestLegacyRecords(t *testing.T) {
}

set = &MessageSet{}
r = newLegacyRecords(nil)
r = Records{}

err = decode(exp, set)
if err != nil {
Expand All @@ -42,6 +42,9 @@ func TestLegacyRecords(t *testing.T) {
t.Fatal(err)
}

if r.recordsType != legacyRecords {
t.Fatalf("Wrong records type %v, expected %v", r.recordsType, legacyRecords)
}
if !reflect.DeepEqual(set, r.msgSet) {
t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
}
Expand Down Expand Up @@ -96,7 +99,7 @@ func TestDefaultRecords(t *testing.T) {
}

batch = &RecordBatch{}
r = newDefaultRecords(nil)
r = Records{}

err = decode(exp, batch)
if err != nil {
Expand All @@ -107,6 +110,9 @@ func TestDefaultRecords(t *testing.T) {
t.Fatal(err)
}

if r.recordsType != defaultRecords {
t.Fatalf("Wrong records type %v, expected %v", r.recordsType, defaultRecords)
}
if !reflect.DeepEqual(batch, r.recordBatch) {
t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
}
Expand Down