Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimise generic decodes #349

Merged
merged 3 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,13 @@ func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecod
// Handle eface case when it isnt a union
if typ.Kind() == reflect.Interface && schema.Type() != Union {
if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok {
return &efaceDecoder{schema: schema}
return newEfaceDecoder(cfg, schema)
}
}

switch schema.Type() {
case String, Bytes, Int, Long, Float, Double, Boolean:
return createDecoderOfNative(schema, typ)
return createDecoderOfNative(schema.(*PrimitiveSchema), typ)

case Record:
return createDecoderOfRecord(cfg, schema, typ)
Expand Down
26 changes: 19 additions & 7 deletions codec_dynamic.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,43 @@ import (

type efaceDecoder struct {
schema Schema
typ reflect2.Type
dec ValDecoder
}

func newEfaceDecoder(cfg *frozenConfig, schema Schema) *efaceDecoder {
typ, _ := genericReceiver(schema)
dec := decoderOfType(cfg, schema, typ)

return &efaceDecoder{
schema: schema,
typ: typ,
dec: dec,
}
}

func (d *efaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) {
pObj := (*any)(ptr)
obj := *pObj
if obj == nil {
*pObj = genericDecode(d.schema, r)
if *pObj == nil {
*pObj = genericDecode(d.typ, d.dec, r)
return
}

typ := reflect2.TypeOf(obj)
typ := reflect2.TypeOf(*pObj)
if typ.Kind() != reflect.Ptr {
*pObj = genericDecode(d.schema, r)
*pObj = genericDecode(d.typ, d.dec, r)
return
}

ptrType := typ.(*reflect2.UnsafePtrType)
ptrElemType := ptrType.Elem()
if reflect2.IsNil(obj) {
if reflect2.IsNil(*pObj) {
obj := ptrElemType.New()
r.ReadVal(d.schema, obj)
*pObj = obj
return
}
r.ReadVal(d.schema, obj)
r.ReadVal(d.schema, *pObj)
}

type interfaceEncoder struct {
Expand Down
80 changes: 39 additions & 41 deletions codec_generic.go
Original file line number Diff line number Diff line change
@@ -1,140 +1,138 @@
package avro

import (
"fmt"
"errors"
"math/big"
"time"
"unsafe"

"github.com/modern-go/reflect2"
)

func genericDecode(schema Schema, r *Reader) any {
rPtr, rTyp, err := genericReceiver(schema)
if err != nil {
r.ReportError("Read", err.Error())
return nil
}
decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r)
func genericDecode(typ reflect2.Type, dec ValDecoder, r *Reader) any {
ptr := typ.UnsafeNew()
dec.Decode(ptr, r)
if r.Error != nil {
return nil
}
obj := rTyp.UnsafeIndirect(rPtr)

obj := typ.UnsafeIndirect(ptr)
if reflect2.IsNil(obj) {
return nil
}

// Generic reader returns a different result from the
// codec in the case of a big.Rat. Handle this.
if rTyp.Type1() == ratType {
if typ.Type1() == ratType {
dec := obj.(big.Rat)
return &dec
}

return obj
}

func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) {
func genericReceiver(schema Schema) (reflect2.Type, error) {
if schema.Type() == Ref {
schema = schema.(*RefSchema).Schema()
}

var ls LogicalSchema
lts, ok := schema.(LogicalTypeSchema)
if ok {
ls = lts.Logical()
}

name := string(schema.Type())
schemaName := string(schema.Type())
if ls != nil {
name += "." + string(ls.Type())
schemaName += "." + string(ls.Type())
}

switch schema.Type() {
case Boolean:
var v bool
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Int:
if ls != nil {
switch ls.Type() {
case Date:
var v time.Time
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil

case TimeMillis:
var v time.Duration
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
}
}
var v int
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Long:
if ls != nil {
switch ls.Type() {
case TimeMicros:
var v time.Duration
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case TimestampMillis:
var v time.Time
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case TimestampMicros:
var v time.Time
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case LocalTimestampMillis:
var v time.Time
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case LocalTimestampMicros:
var v time.Time
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
}
}
var v int64
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Float:
var v float32
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Double:
var v float64
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case String:
var v string
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Bytes:
if ls != nil && ls.Type() == Decimal {
var v *big.Rat
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
}
var v []byte
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Record:
var v map[string]any
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
case Ref:
return genericReceiver(schema.(*RefSchema).Schema())
return reflect2.TypeOf(v), nil
case Enum:
var v string
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Array:
v := make([]any, 0)
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Map:
var v map[string]any
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Union:
var v map[string]any
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Fixed:
fixed := schema.(*FixedSchema)
ls := fixed.Logical()
if ls != nil {
switch ls.Type() {
case Duration:
var v LogicalDuration
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
case Decimal:
var v big.Rat
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
}
}
v := byteSliceToArray(make([]byte, fixed.Size()), fixed.Size())
return unsafe.Pointer(&v), reflect2.TypeOf(v), nil
return reflect2.TypeOf(v), nil
default:
return nil, nil, fmt.Errorf("dynamic receiver not found for schema: %v", name)
// This should not be possible.
return nil, errors.New("dynamic receiver not found for schema " + schemaName)
}
}
15 changes: 10 additions & 5 deletions codec_generic_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,27 @@ func TestGenericDecode(t *testing.T) {
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
defer ConfigTeardown()

schema := MustParse(test.schema)
r := NewReader(bytes.NewReader(test.data), 10)

got := genericDecode(schema, r)
typ, err := genericReceiver(schema)
require.NoError(t, err)
dec := decoderOfType(DefaultConfig.(*frozenConfig), schema, typ)

got := genericDecode(typ, dec, r)

test.wantErr(t, r.Error)
assert.Equal(t, test.want, got)
})
}
}

func TestGenericDecode_UnsupportedType(t *testing.T) {
func TestGenericReceiver_UnsupportedType(t *testing.T) {
schema := NewPrimitiveSchema(Type("test"), nil)
r := NewReader(bytes.NewReader([]byte{0x01}), 10)

_ = genericDecode(schema, r)
_, err := genericReceiver(schema)

assert.Error(t, r.Error)
assert.Error(t, err)
}
Loading
Loading