diff --git a/protocol/protocol.go b/protocol/protocol.go
index ebf91a798..3d0a7b8dd 100644
--- a/protocol/protocol.go
+++ b/protocol/protocol.go
@@ -213,6 +213,37 @@ func Register(req, res Message) {
 	}
 }
 
+// OverrideTypeMessage is an interface implemented by messages that want to override the standard
+// request/response types for a given API.
+type OverrideTypeMessage interface {
+	TypeKey() OverrideTypeKey
+}
+
+type OverrideTypeKey int16
+
+const (
+	RawProduceOverride OverrideTypeKey = 0
+)
+
+var overrideApiTypes [numApis]map[OverrideTypeKey]apiType
+
+func RegisterOverride(req, res Message, key OverrideTypeKey) {
+	k1 := req.ApiKey()
+	k2 := res.ApiKey()
+
+	if k1 != k2 {
+		panic(fmt.Sprintf("[%T/%T]: request and response API keys mismatch: %d != %d", req, res, k1, k2))
+	}
+
+	if overrideApiTypes[k1] == nil {
+		overrideApiTypes[k1] = make(map[OverrideTypeKey]apiType)
+	}
+	overrideApiTypes[k1][key] = apiType{
+		requests:  typesOf(req),
+		responses: typesOf(res),
+	}
+}
+
 func typesOf(v interface{}) []messageType {
 	return makeTypes(reflect.TypeOf(v).Elem())
 }
diff --git a/protocol/prototest/reflect.go b/protocol/prototest/reflect.go
index 5c3d0a1d7..a266d688f 100644
--- a/protocol/prototest/reflect.go
+++ b/protocol/prototest/reflect.go
@@ -1,6 +1,7 @@
 package prototest
 
 import (
+	"bytes"
 	"errors"
 	"io"
 	"reflect"
@@ -49,6 +50,13 @@ func loadValue(v reflect.Value) (reset func()) {
 			}
 			resetFunc()
 			resets = append(resets, resetFunc)
+		case io.Reader:
+			buf, _ := io.ReadAll(x)
+			resetFunc := func() {
+				f.Set(reflect.ValueOf(bytes.NewBuffer(buf)))
+			}
+			resetFunc()
+			resets = append(resets, resetFunc)
 		}
 	})
diff --git a/protocol/prototest/request.go b/protocol/prototest/request.go
index 15c6e79c8..c0197f25d 100644
--- a/protocol/prototest/request.go
+++ b/protocol/prototest/request.go
@@ -46,6 +46,39 @@ func TestRequest(t *testing.T, version int16, msg protocol.Message) {
 	})
 }
 
+// TestRequestWithOverride validates requests that have an overridden type. The request is encoded,
+// decoded back through the standard type, and re-encoded, to ensure both encodings are identical.
+func TestRequestWithOverride(t *testing.T, version int16, msg protocol.Message) {
+	reset := load(msg)
+
+	t.Run(fmt.Sprintf("v%d", version), func(t *testing.T) {
+		b1 := &bytes.Buffer{}
+
+		if err := protocol.WriteRequest(b1, version, 1234, "me", msg); err != nil {
+			t.Fatal(err)
+		}
+
+		reset()
+		t.Logf("\n%s\n", hex.Dump(b1.Bytes()))
+
+		_, _, _, req, err := protocol.ReadRequest(b1)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		b2 := &bytes.Buffer{}
+		if err := protocol.WriteRequest(b2, version, 1234, "me", req); err != nil {
+			t.Fatal(err)
+		}
+
+		if !deepEqual(b1, b2) {
+			t.Errorf("request message mismatch:")
+			t.Logf("expected: %+v", hex.Dump(b1.Bytes()))
+			t.Logf("found:    %+v", hex.Dump(b2.Bytes()))
+		}
+	})
+}
+
 func BenchmarkRequest(b *testing.B, version int16, msg protocol.Message) {
 	reset := load(msg)
 
diff --git a/protocol/rawproduce/rawproduce.go b/protocol/rawproduce/rawproduce.go
new file mode 100644
index 000000000..bad83138d
--- /dev/null
+++ b/protocol/rawproduce/rawproduce.go
@@ -0,0 +1,91 @@
+package rawproduce
+
+import (
+	"fmt"
+
+	"github.com/segmentio/kafka-go/protocol"
+	"github.com/segmentio/kafka-go/protocol/produce"
+)
+
+func init() {
+	// Register a type override so that raw produce requests will be encoded with the correct type.
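+	// The override is keyed by (api key, OverrideTypeKey): WriteRequest and
+	// WriteResponse (see the protocol/request.go and protocol/response.go hunks
+	// below) detect messages implementing OverrideTypeMessage and swap in the
+	// registered types, so a rawproduce.Request is encoded with its own layout
+	// while still being sent as a Produce API call.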
+	req := &Request{}
+	protocol.RegisterOverride(req, &produce.Response{}, req.TypeKey())
+}
+
+type Request struct {
+	TransactionalID string         `kafka:"min=v3,max=v8,nullable"`
+	Acks            int16          `kafka:"min=v0,max=v8"`
+	Timeout         int32          `kafka:"min=v0,max=v8"`
+	Topics          []RequestTopic `kafka:"min=v0,max=v8"`
+}
+
+func (r *Request) ApiKey() protocol.ApiKey { return protocol.Produce }
+
+func (r *Request) TypeKey() protocol.OverrideTypeKey { return protocol.RawProduceOverride }
+
+func (r *Request) Broker(cluster protocol.Cluster) (protocol.Broker, error) {
+	broker := protocol.Broker{ID: -1}
+
+	for i := range r.Topics {
+		t := &r.Topics[i]
+
+		topic, ok := cluster.Topics[t.Topic]
+		if !ok {
+			return broker, NewError(protocol.NewErrNoTopic(t.Topic))
+		}
+
+		for j := range t.Partitions {
+			p := &t.Partitions[j]
+
+			partition, ok := topic.Partitions[p.Partition]
+			if !ok {
+				return broker, NewError(protocol.NewErrNoPartition(t.Topic, p.Partition))
+			}
+
+			if b, ok := cluster.Brokers[partition.Leader]; !ok {
+				return broker, NewError(protocol.NewErrNoLeader(t.Topic, p.Partition))
+			} else if broker.ID < 0 {
+				broker = b
+			} else if b.ID != broker.ID {
+				return broker, NewError(fmt.Errorf("mismatching leaders (%d!=%d)", b.ID, broker.ID))
+			}
+		}
+	}
+
+	return broker, nil
+}
+
+func (r *Request) HasResponse() bool {
+	return r.Acks != 0
+}
+
+type RequestTopic struct {
+	Topic      string             `kafka:"min=v0,max=v8"`
+	Partitions []RequestPartition `kafka:"min=v0,max=v8"`
+}
+
+type RequestPartition struct {
+	Partition int32                 `kafka:"min=v0,max=v8"`
+	RecordSet protocol.RawRecordSet `kafka:"min=v0,max=v8"`
+}
+
+var (
+	_ protocol.BrokerMessage = (*Request)(nil)
+)
+
+type Error struct {
+	Err error
+}
+
+func NewError(err error) *Error {
+	return &Error{Err: err}
+}
+
+func (e *Error) Error() string {
+	return fmt.Sprintf("raw produce request error: %v", e.Err)
+}
+
+func (e *Error) Unwrap() error {
+	return e.Err
+}
diff --git a/protocol/rawproduce/rawproduce_test.go b/protocol/rawproduce/rawproduce_test.go
new file mode 100644
index 000000000..2d987e711
--- /dev/null
+++ b/protocol/rawproduce/rawproduce_test.go
@@ -0,0 +1,201 @@
+package rawproduce_test
+
+import (
+	"bytes"
+	"testing"
+	"time"
+
+	"github.com/segmentio/kafka-go/protocol"
+	"github.com/segmentio/kafka-go/protocol/prototest"
+	"github.com/segmentio/kafka-go/protocol/rawproduce"
+)
+
+const (
+	v0 = 0
+	v3 = 3
+	v5 = 5
+)
+
+func TestRawProduceRequest(t *testing.T) {
+	t0 := time.Now().Truncate(time.Millisecond)
+	t1 := t0.Add(1 * time.Millisecond)
+	t2 := t0.Add(2 * time.Millisecond)
+
+	prototest.TestRequestWithOverride(t, v0, &rawproduce.Request{
+		Acks:    1,
+		Timeout: 500,
+		Topics: []rawproduce.RequestTopic{
+			{
+				Topic: "topic-1",
+				Partitions: []rawproduce.RequestPartition{
+					{
+						Partition: 0,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: nil},
+						), 1, 0),
+					},
+					{
+						Partition: 1,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")},
+							protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")},
+							protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")},
+						), 1, 0),
+					},
+				},
+			},
+
+			{
+				Topic: "topic-2",
+				Partitions: []rawproduce.RequestPartition{
+					{
+						Partition: 0,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")},
+							protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")},
+							protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")},
+						), 1, protocol.Gzip),
+					},
+				},
+			},
+		},
+	})
+
+	prototest.TestRequestWithOverride(t, v3, &rawproduce.Request{
+		TransactionalID: "1234",
+		Acks:            1,
+		Timeout:         500,
+		Topics: []rawproduce.RequestTopic{
+			{
+				Topic: "topic-1",
+				Partitions: []rawproduce.RequestPartition{
+					{
+						Partition: 0,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: nil},
+						), 1, 0),
+					},
+					{
+						Partition: 1,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")},
+							protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")},
+							protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")},
+						), 1, 0),
+					},
+				},
+			},
+		},
+	})
+
+	headers := []protocol.Header{
+		{Key: "key-1", Value: []byte("value-1")},
+		{Key: "key-2", Value: []byte("value-2")},
+		{Key: "key-3", Value: []byte("value-3")},
+	}
+
+	prototest.TestRequestWithOverride(t, v5, &rawproduce.Request{
+		TransactionalID: "1234",
+		Acks:            1,
+		Timeout:         500,
+		Topics: []rawproduce.RequestTopic{
+			{
+				Topic: "topic-1",
+				Partitions: []rawproduce.RequestPartition{
+					{
+						Partition: 1,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers},
+							protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")},
+							protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")},
+						), 2, 0),
+					},
+				},
+			},
+
+			{
+				Topic: "topic-2",
+				Partitions: []rawproduce.RequestPartition{
+					{
+						Partition: 1,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers},
+							protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")},
+							protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")},
+						), 2, protocol.Snappy),
+					},
+				},
+			},
+		},
+	})
+}
+
+func NewRawRecordSet(reader protocol.RecordReader, version int8, attr protocol.Attributes) protocol.RawRecordSet {
+	rs := protocol.RecordSet{Version: version, Attributes: attr, Records: reader}
+	buf := &bytes.Buffer{}
+	rs.WriteTo(buf)
+
+	return protocol.RawRecordSet{
+		Reader: buf,
+	}
+}
+
+func BenchmarkProduceRequest(b *testing.B) {
+	t0 := time.Now().Truncate(time.Millisecond)
+	t1 := t0.Add(1 * time.Millisecond)
+	t2 := t0.Add(2 * time.Millisecond)
+
+	prototest.BenchmarkRequest(b, v3, &rawproduce.Request{
+		TransactionalID: "1234",
+		Acks:            1,
+		Timeout:         500,
+		Topics: []rawproduce.RequestTopic{
+			{
+				Topic: "topic-1",
+				Partitions: []rawproduce.RequestPartition{
+					{
+						Partition: 0,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: nil},
+						), 1, 0),
+					},
+					{
+						Partition: 1,
+						RecordSet: NewRawRecordSet(protocol.NewRecordReader(
+							protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0")},
+							protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")},
+							protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")},
+						), 1, 0),
+					},
+				},
+			},
+		},
+	})
+
+	headers := []protocol.Header{
{Key: "key-1", Value: []byte("value-1")}, + {Key: "key-2", Value: []byte("value-2")}, + {Key: "key-3", Value: []byte("value-3")}, + } + + prototest.BenchmarkRequest(b, v5, &rawproduce.Request{ + TransactionalID: "1234", + Acks: 1, + Timeout: 500, + Topics: []rawproduce.RequestTopic{ + { + Topic: "topic-1", + Partitions: []rawproduce.RequestPartition{ + { + Partition: 1, + RecordSet: NewRawRecordSet(protocol.NewRecordReader( + protocol.Record{Offset: 0, Time: t0, Key: nil, Value: prototest.String("msg-0"), Headers: headers}, + protocol.Record{Offset: 1, Time: t1, Key: nil, Value: prototest.String("msg-1")}, + protocol.Record{Offset: 2, Time: t2, Key: prototest.Bytes([]byte{1}), Value: prototest.String("msg-2")}, + ), 2, 0), + }, + }, + }, + }, + }) +} diff --git a/protocol/record.go b/protocol/record.go index 84594868b..e11af4dcc 100644 --- a/protocol/record.go +++ b/protocol/record.go @@ -292,6 +292,46 @@ func (rs *RecordSet) WriteTo(w io.Writer) (int64, error) { return n, nil } +// RawRecordSet represents a record set for a RawProduce request. The record set is +// represented as a raw sequence of pre-encoded record set bytes. +type RawRecordSet struct { + // Reader exposes the raw sequence of record set bytes. + Reader io.Reader +} + +// ReadFrom reads the representation of a record set from r into rrs. It re-uses the +// existing RecordSet.ReadFrom implementation to first read/decode data into a RecordSet, +// then writes/encodes the RecordSet to a buffer referenced by the RawRecordSet. +// +// Note: re-using the RecordSet.ReadFrom implementation makes this suboptimal from a +// performance standpoint as it require an extra copy of the record bytes. Holding off +// on optimizing, as this code path is only invoked in tests. +func (rrs *RawRecordSet) ReadFrom(r io.Reader) (int64, error) { + rs := &RecordSet{} + n, err := rs.ReadFrom(r) + if err != nil { + return 0, err + } + + buf := &bytes.Buffer{} + rs.WriteTo(buf) + *rrs = RawRecordSet{ + Reader: buf, + } + + return n, nil +} + +// WriteTo writes the RawRecordSet to an io.Writer. Since this is a raw record set representation, all that is +// done here is copying bytes from the underlying reader to the specified writer. 
+func (rrs *RawRecordSet) WriteTo(w io.Writer) (int64, error) {
+	if rrs.Reader == nil {
+		return 0, ErrNoRecord
+	}
+
+	return io.Copy(w, rrs.Reader)
+}
+
 func makeTime(t int64) time.Time {
 	return time.Unix(t/1000, (t%1000)*int64(time.Millisecond))
 }
diff --git a/protocol/request.go b/protocol/request.go
index 8b99e0537..135b938bb 100644
--- a/protocol/request.go
+++ b/protocol/request.go
@@ -81,6 +81,12 @@ func WriteRequest(w io.Writer, apiVersion int16, correlationID int32, clientID s
 		return fmt.Errorf("unsupported api: %s", apiNames[apiKey])
 	}
 
+	if typedMessage, ok := msg.(OverrideTypeMessage); ok {
+		typeKey := typedMessage.TypeKey()
+		overrideType := overrideApiTypes[apiKey][typeKey]
+		t = &overrideType
+	}
+
 	minVersion := t.minVersion()
 	maxVersion := t.maxVersion()
 
diff --git a/protocol/response.go b/protocol/response.go
index 619480313..a43bd0237 100644
--- a/protocol/response.go
+++ b/protocol/response.go
@@ -95,6 +95,12 @@ func WriteResponse(w io.Writer, apiVersion int16, correlationID int32, msg Messa
 		return fmt.Errorf("unsupported api: %s", apiNames[apiKey])
 	}
 
+	if typedMessage, ok := msg.(OverrideTypeMessage); ok {
+		typeKey := typedMessage.TypeKey()
+		overrideType := overrideApiTypes[apiKey][typeKey]
+		t = &overrideType
+	}
+
 	minVersion := t.minVersion()
 	maxVersion := t.maxVersion()
 
diff --git a/rawproduce.go b/rawproduce.go
new file mode 100644
index 000000000..5928cb2f8
--- /dev/null
+++ b/rawproduce.go
@@ -0,0 +1,103 @@
+package kafka
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+
+	"github.com/segmentio/kafka-go/protocol"
+	produceAPI "github.com/segmentio/kafka-go/protocol/produce"
+	"github.com/segmentio/kafka-go/protocol/rawproduce"
+)
+
+// RawProduceRequest represents a request sent to a kafka broker to produce records
+// to a topic partition. The request contains a pre-encoded/raw record set.
+type RawProduceRequest struct {
+	// Address of the kafka broker to send the request to.
+	Addr net.Addr
+
+	// The topic to produce the records to.
+	Topic string
+
+	// The partition to produce the records to.
+	Partition int
+
+	// The level of required acknowledgements to ask the kafka broker for.
+	RequiredAcks RequiredAcks
+
+	// The message format version used when encoding the records.
+	//
+	// By default, the client automatically determines which version should be
+	// used based on the version of the Produce API supported by the server.
+	MessageVersion int
+
+	// An optional transaction id to set when the produce request is part of
+	// a transaction.
+	TransactionalID string
+
+	// The pre-encoded record set to produce to the topic partition.
+	RawRecords protocol.RawRecordSet
+}
+
+// RawProduce sends a raw produce request to a kafka broker and returns the response.
+//
+// If the request contained no records, an error wrapping protocol.ErrNoRecord
+// is returned.
+//
+// When the request is configured with RequiredAcks=none, both the response and
+// the error will be nil on success.
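+//
+// An abridged example, adapted from the tests in this change (rawRecords is a
+// pre-encoded protocol.RawRecordSet, for instance built with RecordSet.WriteTo):
+//
+//	res, err := client.RawProduce(ctx, &kafka.RawProduceRequest{
+//		Topic:        "my-topic",
+//		Partition:    0,
+//		RequiredAcks: kafka.RequireAll,
+//		RawRecords:   rawRecords,
+//	})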
+func (c *Client) RawProduce(ctx context.Context, req *RawProduceRequest) (*ProduceResponse, error) {
+	m, err := c.roundTrip(ctx, req.Addr, &rawproduce.Request{
+		TransactionalID: req.TransactionalID,
+		Acks:            int16(req.RequiredAcks),
+		Timeout:         c.timeoutMs(ctx, defaultProduceTimeout),
+		Topics: []rawproduce.RequestTopic{{
+			Topic: req.Topic,
+			Partitions: []rawproduce.RequestPartition{{
+				Partition: int32(req.Partition),
+				RecordSet: req.RawRecords,
+			}},
+		}},
+	})
+
+	switch {
+	case err == nil:
+	case errors.Is(err, protocol.ErrNoRecord):
+		return new(ProduceResponse), nil
+	default:
+		return nil, fmt.Errorf("kafka.(*Client).RawProduce: %w", err)
+	}
+
+	if req.RequiredAcks == RequireNone {
+		return nil, nil
+	}
+
+	res := m.(*produceAPI.Response)
+	if len(res.Topics) == 0 {
+		return nil, fmt.Errorf("kafka.(*Client).RawProduce: %w", protocol.ErrNoTopic)
+	}
+	topic := &res.Topics[0]
+	if len(topic.Partitions) == 0 {
+		return nil, fmt.Errorf("kafka.(*Client).RawProduce: %w", protocol.ErrNoPartition)
+	}
+	partition := &topic.Partitions[0]
+
+	ret := &ProduceResponse{
+		Throttle:       makeDuration(res.ThrottleTimeMs),
+		Error:          makeError(partition.ErrorCode, partition.ErrorMessage),
+		BaseOffset:     partition.BaseOffset,
+		LogAppendTime:  makeTime(partition.LogAppendTime),
+		LogStartOffset: partition.LogStartOffset,
+	}
+
+	if len(partition.RecordErrors) != 0 {
+		ret.RecordErrors = make(map[int]error, len(partition.RecordErrors))
+
+		for _, recErr := range partition.RecordErrors {
+			ret.RecordErrors[int(recErr.BatchIndex)] = errors.New(recErr.BatchIndexErrorMessage)
+		}
+	}
+
+	return ret, nil
+}
diff --git a/rawproduce_test.go b/rawproduce_test.go
new file mode 100644
index 000000000..2c7fed782
--- /dev/null
+++ b/rawproduce_test.go
@@ -0,0 +1,123 @@
+package kafka
+
+import (
+	"bytes"
+	"context"
+	"testing"
+	"time"
+
+	"github.com/segmentio/kafka-go/protocol"
+	ktesting "github.com/segmentio/kafka-go/testing"
+)
+
+func TestClientRawProduce(t *testing.T) {
+	// The RawProduce request records are encoded in the format introduced in Kafka 0.11.0.
+	if !ktesting.KafkaIsAtLeast("0.11.0") {
+		t.Skip("Skipping because the RawProduce request is not supported by Kafka versions below 0.11.0")
+	}
+
+	client, topic, shutdown := newLocalClientAndTopic()
+	defer shutdown()
+
+	now := time.Now()
+
+	res, err := client.RawProduce(context.Background(), &RawProduceRequest{
+		Topic:        topic,
+		Partition:    0,
+		RequiredAcks: -1,
+		RawRecords: NewRawRecordSet(NewRecordReader(
+			Record{Time: now, Value: NewBytes([]byte(`hello-1`))},
+			Record{Time: now, Value: NewBytes([]byte(`hello-2`))},
+			Record{Time: now, Value: NewBytes([]byte(`hello-3`))},
+		), 0),
+	})
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if res.Error != nil {
+		t.Error(res.Error)
+	}
+
+	for index, err := range res.RecordErrors {
+		t.Errorf("record at index %d produced an error: %v", index, err)
+	}
+}
+
+func TestClientRawProduceCompressed(t *testing.T) {
+	// The RawProduce request records are encoded in the format introduced in Kafka 0.11.0.
+	if !ktesting.KafkaIsAtLeast("0.11.0") {
+		t.Skip("Skipping because the RawProduce request is not supported by Kafka versions below 0.11.0")
+	}
+
+	client, topic, shutdown := newLocalClientAndTopic()
+	defer shutdown()
+
+	now := time.Now()
+
+	res, err := client.RawProduce(context.Background(), &RawProduceRequest{
+		Topic:        topic,
+		Partition:    0,
+		RequiredAcks: -1,
+		RawRecords: NewRawRecordSet(NewRecordReader(
+			Record{Time: now, Value: NewBytes([]byte(`hello-1`))},
+			Record{Time: now, Value: NewBytes([]byte(`hello-2`))},
+			Record{Time: now, Value: NewBytes([]byte(`hello-3`))},
+		), protocol.Gzip),
+	})
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if res.Error != nil {
+		t.Error(res.Error)
+	}
+
+	for index, err := range res.RecordErrors {
+		t.Errorf("record at index %d produced an error: %v", index, err)
+	}
+}
+
+func TestClientRawProduceNilRecords(t *testing.T) {
+	client, topic, shutdown := newLocalClientAndTopic()
+	defer shutdown()
+
+	_, err := client.RawProduce(context.Background(), &RawProduceRequest{
+		Topic:        topic,
+		Partition:    0,
+		RequiredAcks: -1,
+		RawRecords:   protocol.RawRecordSet{Reader: nil},
+	})
+
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestClientRawProduceEmptyRecords(t *testing.T) {
+	client, topic, shutdown := newLocalClientAndTopic()
+	defer shutdown()
+
+	_, err := client.Produce(context.Background(), &ProduceRequest{
+		Topic:        topic,
+		Partition:    0,
+		RequiredAcks: -1,
+		Records:      NewRecordReader(),
+	})
+
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func NewRawRecordSet(reader protocol.RecordReader, attr protocol.Attributes) protocol.RawRecordSet {
+	rs := protocol.RecordSet{Version: 2, Attributes: attr, Records: reader}
+	buf := &bytes.Buffer{}
+	rs.WriteTo(buf)
+
+	return protocol.RawRecordSet{
+		Reader: buf,
+	}
+}
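
Putting the pieces together, a minimal end-to-end sketch of the new API (not part of the diff; the broker address, topic name, and record contents are placeholder assumptions):

```go
package main

import (
	"bytes"
	"context"
	"log"
	"time"

	kafka "github.com/segmentio/kafka-go"
	"github.com/segmentio/kafka-go/protocol"
)

func main() {
	// Encode the record set once; RawProduce copies these bytes verbatim
	// instead of re-encoding the records on every request.
	buf := &bytes.Buffer{}
	rs := protocol.RecordSet{
		Version: 2,
		Records: protocol.NewRecordReader(
			protocol.Record{Time: time.Now(), Value: protocol.NewBytes([]byte("hello"))},
		),
	}
	if _, err := rs.WriteTo(buf); err != nil {
		log.Fatal(err)
	}

	client := &kafka.Client{Addr: kafka.TCP("localhost:9092")}

	res, err := client.RawProduce(context.Background(), &kafka.RawProduceRequest{
		Topic:        "my-topic",
		Partition:    0,
		RequiredAcks: kafka.RequireAll,
		RawRecords:   protocol.RawRecordSet{Reader: buf},
	})
	if err != nil {
		log.Fatal(err)
	}
	if res.Error != nil {
		log.Fatal(res.Error)
	}
}
```

The point of the raw path is that `rs.WriteTo` runs once, no matter how many requests carry the same pre-encoded batch.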