diff --git a/codec.go b/codec.go index 66ad77d..c209c61 100644 --- a/codec.go +++ b/codec.go @@ -86,13 +86,13 @@ func decoderOfType(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDecod // Handle eface case when it isnt a union if typ.Kind() == reflect.Interface && schema.Type() != Union { if _, ok := typ.(*reflect2.UnsafeIFaceType); !ok { - return &efaceDecoder{schema: schema} + return newEfaceDecoder(cfg, schema) } } switch schema.Type() { case String, Bytes, Int, Long, Float, Double, Boolean: - return createDecoderOfNative(schema, typ) + return createDecoderOfNative(schema.(*PrimitiveSchema), typ) case Record: return createDecoderOfRecord(cfg, schema, typ) diff --git a/codec_dynamic.go b/codec_dynamic.go index 4b6271d..229079f 100644 --- a/codec_dynamic.go +++ b/codec_dynamic.go @@ -9,31 +9,43 @@ import ( type efaceDecoder struct { schema Schema + typ reflect2.Type + dec ValDecoder +} + +func newEfaceDecoder(cfg *frozenConfig, schema Schema) *efaceDecoder { + typ, _ := genericReceiver(schema) + dec := decoderOfType(cfg, schema, typ) + + return &efaceDecoder{ + schema: schema, + typ: typ, + dec: dec, + } } func (d *efaceDecoder) Decode(ptr unsafe.Pointer, r *Reader) { pObj := (*any)(ptr) - obj := *pObj - if obj == nil { - *pObj = genericDecode(d.schema, r) + if *pObj == nil { + *pObj = genericDecode(d.typ, d.dec, r) return } - typ := reflect2.TypeOf(obj) + typ := reflect2.TypeOf(*pObj) if typ.Kind() != reflect.Ptr { - *pObj = genericDecode(d.schema, r) + *pObj = genericDecode(d.typ, d.dec, r) return } ptrType := typ.(*reflect2.UnsafePtrType) ptrElemType := ptrType.Elem() - if reflect2.IsNil(obj) { + if reflect2.IsNil(*pObj) { obj := ptrElemType.New() r.ReadVal(d.schema, obj) *pObj = obj return } - r.ReadVal(d.schema, obj) + r.ReadVal(d.schema, *pObj) } type interfaceEncoder struct { diff --git a/codec_generic.go b/codec_generic.go index 2091659..36bf9ee 100644 --- a/codec_generic.go +++ b/codec_generic.go @@ -1,124 +1,121 @@ package avro import ( - "fmt" + "errors" "math/big" "time" - "unsafe" "github.com/modern-go/reflect2" ) -func genericDecode(schema Schema, r *Reader) any { - rPtr, rTyp, err := genericReceiver(schema) - if err != nil { - r.ReportError("Read", err.Error()) - return nil - } - decoderOfType(r.cfg, schema, rTyp).Decode(rPtr, r) +func genericDecode(typ reflect2.Type, dec ValDecoder, r *Reader) any { + ptr := typ.UnsafeNew() + dec.Decode(ptr, r) if r.Error != nil { return nil } - obj := rTyp.UnsafeIndirect(rPtr) + + obj := typ.UnsafeIndirect(ptr) if reflect2.IsNil(obj) { return nil } // Generic reader returns a different result from the // codec in the case of a big.Rat. Handle this. - if rTyp.Type1() == ratType { + if typ.Type1() == ratType { dec := obj.(big.Rat) return &dec } - return obj } -func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) { +func genericReceiver(schema Schema) (reflect2.Type, error) { + if schema.Type() == Ref { + schema = schema.(*RefSchema).Schema() + } + var ls LogicalSchema lts, ok := schema.(LogicalTypeSchema) if ok { ls = lts.Logical() } - name := string(schema.Type()) + schemaName := string(schema.Type()) if ls != nil { - name += "." + string(ls.Type()) + schemaName += "." + string(ls.Type()) } switch schema.Type() { case Boolean: var v bool - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Int: if ls != nil { switch ls.Type() { case Date: var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case TimeMillis: var v time.Duration - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil } } var v int - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Long: if ls != nil { switch ls.Type() { case TimeMicros: var v time.Duration - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case TimestampMillis: var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case TimestampMicros: var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case LocalTimestampMillis: var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case LocalTimestampMicros: var v time.Time - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil } } var v int64 - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Float: var v float32 - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Double: var v float64 - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case String: var v string - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Bytes: if ls != nil && ls.Type() == Decimal { var v *big.Rat - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil } var v []byte - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Record: var v map[string]any - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil - case Ref: - return genericReceiver(schema.(*RefSchema).Schema()) + return reflect2.TypeOf(v), nil case Enum: var v string - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Array: v := make([]any, 0) - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Map: var v map[string]any - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Union: var v map[string]any - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Fixed: fixed := schema.(*FixedSchema) ls := fixed.Logical() @@ -126,15 +123,16 @@ func genericReceiver(schema Schema) (unsafe.Pointer, reflect2.Type, error) { switch ls.Type() { case Duration: var v LogicalDuration - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil case Decimal: var v big.Rat - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil } } v := byteSliceToArray(make([]byte, fixed.Size()), fixed.Size()) - return unsafe.Pointer(&v), reflect2.TypeOf(v), nil + return reflect2.TypeOf(v), nil default: - return nil, nil, fmt.Errorf("dynamic receiver not found for schema: %v", name) + // This should not be possible. + return nil, errors.New("dynamic receiver not found for schema " + schemaName) } } diff --git a/codec_generic_internal_test.go b/codec_generic_internal_test.go index 772bc3f..98a5855 100644 --- a/codec_generic_internal_test.go +++ b/codec_generic_internal_test.go @@ -213,10 +213,16 @@ func TestGenericDecode(t *testing.T) { for _, test := range tests { test := test t.Run(test.name, func(t *testing.T) { + defer ConfigTeardown() + schema := MustParse(test.schema) r := NewReader(bytes.NewReader(test.data), 10) - got := genericDecode(schema, r) + typ, err := genericReceiver(schema) + require.NoError(t, err) + dec := decoderOfType(DefaultConfig.(*frozenConfig), schema, typ) + + got := genericDecode(typ, dec, r) test.wantErr(t, r.Error) assert.Equal(t, test.want, got) @@ -224,11 +230,10 @@ func TestGenericDecode(t *testing.T) { } } -func TestGenericDecode_UnsupportedType(t *testing.T) { +func TestGenericReceiver_UnsupportedType(t *testing.T) { schema := NewPrimitiveSchema(Type("test"), nil) - r := NewReader(bytes.NewReader([]byte{0x01}), 10) - _ = genericDecode(schema, r) + _, err := genericReceiver(schema) - assert.Error(t, r.Error) + assert.Error(t, err) } diff --git a/codec_native.go b/codec_native.go index b99fb21..3129e29 100644 --- a/codec_native.go +++ b/codec_native.go @@ -11,9 +11,8 @@ import ( ) //nolint:maintidx // Splitting this would not make it simpler. -func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { - converter := resolveConverter(schema.(*PrimitiveSchema).actual) - +func createDecoderOfNative(schema *PrimitiveSchema, typ reflect2.Type) ValDecoder { + isConv := schema.actual != "" switch typ.Kind() { case reflect.Bool: if schema.Type() != Boolean { @@ -61,7 +60,10 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Long { break } - return &longCodec[uint32]{convert: converter.toLong} + if isConv { + return &longConvCodec[uint32]{convert: createLongConverter(schema.actual)} + } + return &longCodec[uint32]{} case reflect.Int64: st := schema.Type() @@ -72,13 +74,14 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { case st == Long && lt == TimeMicros: // time.Duration return &timeMicrosCodec{ - convert: converter.toLong, + convert: createLongConverter(schema.actual), } case st == Long: - return &longCodec[int64]{ - convert: converter.toLong, + if isConv { + return &longConvCodec[int64]{convert: createLongConverter(schema.actual)} } + return &longCodec[int64]{} default: break @@ -88,34 +91,31 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { if schema.Type() != Float { break } - return &float32Codec{ - convert: converter.toFloat, + if isConv { + return &float32ConvCodec{convert: createFloatConverter(schema.actual)} } + return &float32Codec{} case reflect.Float64: if schema.Type() != Double { break } - return &float64Codec{ - convert: converter.toDouble, + if isConv { + return &float64ConvCodec{convert: createDoubleConverter(schema.actual)} } + return &float64Codec{} case reflect.String: if schema.Type() != String { break } - return &stringCodec{ - convert: converter.toString, - } + return &stringCodec{} case reflect.Slice: if typ.(reflect2.SliceType).Elem().Kind() != reflect.Uint8 || schema.Type() != Bytes { break } - return &bytesCodec{ - sliceType: typ.(*reflect2.UnsafeSliceType), - convert: converter.toBytes, - } + return &bytesCodec{sliceType: typ.(*reflect2.UnsafeSliceType)} case reflect.Struct: st := schema.Type() @@ -127,28 +127,25 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { return &dateCodec{} case isTime && st == Long && lt == TimestampMillis: return ×tampMillisCodec{ - convert: converter.toLong, + convert: createLongConverter(schema.actual), } case isTime && st == Long && lt == TimestampMicros: return ×tampMicrosCodec{ - convert: converter.toLong, + convert: createLongConverter(schema.actual), } case isTime && st == Long && lt == LocalTimestampMillis: return ×tampMillisCodec{ local: true, - convert: converter.toLong, + convert: createLongConverter(schema.actual), } case isTime && st == Long && lt == LocalTimestampMicros: return ×tampMicrosCodec{ local: true, - convert: converter.toLong, + convert: createLongConverter(schema.actual), } case typ.Type1().ConvertibleTo(ratType) && st == Bytes && lt == Decimal: dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalCodec{ - prec: dec.Precision(), scale: dec.Scale(), - convert: converter.toBytes, - } + return &bytesDecimalCodec{prec: dec.Precision(), scale: dec.Scale()} default: break @@ -166,10 +163,7 @@ func createDecoderOfNative(schema Schema, typ reflect2.Type) ValDecoder { } dec := ls.(*DecimalLogicalSchema) - return &bytesDecimalPtrCodec{ - prec: dec.Precision(), scale: dec.Scale(), - convert: converter.toBytes, - } + return &bytesDecimalPtrCodec{prec: dec.Precision(), scale: dec.Scale()} } return &errorDecoder{err: fmt.Errorf("avro: %s is unsupported for Avro %s", typ.String(), schema.Type())} @@ -364,79 +358,70 @@ type largeInt interface { ~int32 | ~uint32 | int64 } -type longCodec[T largeInt] struct { - convert func(*Reader) int64 -} +type longCodec[T largeInt] struct{} func (c *longCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { - var v T - if c.convert != nil { - v = T(c.convert(r)) - } else { - v = T(r.ReadLong()) - } - *((*T)(ptr)) = v + *((*T)(ptr)) = T(r.ReadLong()) } func (*longCodec[T]) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteLong(int64(*((*T)(ptr)))) } -type float32Codec struct { - convert func(*Reader) float32 +type longConvCodec[T largeInt] struct { + convert func(*Reader) int64 } -func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { - var v float32 - if c.convert != nil { - v = c.convert(r) - } else { - v = r.ReadFloat() - } +func (c *longConvCodec[T]) Decode(ptr unsafe.Pointer, r *Reader) { + *((*T)(ptr)) = T(c.convert(r)) +} + +type float32Codec struct{} - *((*float32)(ptr)) = v +func (c *float32Codec) Decode(ptr unsafe.Pointer, r *Reader) { + *((*float32)(ptr)) = r.ReadFloat() } func (*float32Codec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteFloat(*((*float32)(ptr))) } +type float32ConvCodec struct { + convert func(*Reader) float32 +} + +func (c *float32ConvCodec) Decode(ptr unsafe.Pointer, r *Reader) { + *((*float32)(ptr)) = c.convert(r) +} + type float32DoubleCodec struct{} func (*float32DoubleCodec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteDouble(float64(*((*float32)(ptr)))) } -type float64Codec struct { - convert func(*Reader) float64 -} +type float64Codec struct{} func (c *float64Codec) Decode(ptr unsafe.Pointer, r *Reader) { - var v float64 - if c.convert != nil { - v = c.convert(r) - } else { - v = r.ReadDouble() - } - *((*float64)(ptr)) = v + *((*float64)(ptr)) = r.ReadDouble() } func (*float64Codec) Encode(ptr unsafe.Pointer, w *Writer) { w.WriteDouble(*((*float64)(ptr))) } -type stringCodec struct { - convert func(*Reader) string +type float64ConvCodec struct { + convert func(*Reader) float64 +} + +func (c *float64ConvCodec) Decode(ptr unsafe.Pointer, r *Reader) { + *((*float64)(ptr)) = c.convert(r) } +type stringCodec struct{} + func (c *stringCodec) Decode(ptr unsafe.Pointer, r *Reader) { - var v string - if c.convert != nil { - v = c.convert(r) - } else { - v = r.ReadString() - } - *((*string)(ptr)) = v + *((*string)(ptr)) = r.ReadString() } func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { @@ -445,16 +430,10 @@ func (*stringCodec) Encode(ptr unsafe.Pointer, w *Writer) { type bytesCodec struct { sliceType *reflect2.UnsafeSliceType - convert func(*Reader) []byte } func (c *bytesCodec) Decode(ptr unsafe.Pointer, r *Reader) { - var b []byte - if c.convert != nil { - b = c.convert(r) - } else { - b = r.ReadBytes() - } + b := r.ReadBytes() c.sliceType.UnsafeSet(ptr, reflect2.PtrOf(b)) } @@ -584,18 +563,12 @@ func (c *timeMicrosCodec) Encode(ptr unsafe.Pointer, w *Writer) { var one = big.NewInt(1) type bytesDecimalCodec struct { - prec int - scale int - convert func(*Reader) []byte + prec int + scale int } func (c *bytesDecimalCodec) Decode(ptr unsafe.Pointer, r *Reader) { - var b []byte - if c.convert != nil { - b = c.convert(r) - } else { - b = r.ReadBytes() - } + b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } @@ -636,19 +609,12 @@ func (c *bytesDecimalCodec) Encode(ptr unsafe.Pointer, w *Writer) { } type bytesDecimalPtrCodec struct { - prec int - scale int - convert func(*Reader) []byte + prec int + scale int } func (c *bytesDecimalPtrCodec) Decode(ptr unsafe.Pointer, r *Reader) { - var b []byte - if c.convert != nil { - b = c.convert(r) - } else { - b = r.ReadBytes() - } - + b := r.ReadBytes() if i := (&big.Int{}).SetBytes(b); len(b) > 0 && b[0]&0x80 > 0 { i.Sub(i, new(big.Int).Lsh(one, uint(len(b))*8)) } diff --git a/codec_record.go b/codec_record.go index 751e867..03bfed6 100644 --- a/codec_record.go +++ b/codec_record.go @@ -256,16 +256,15 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields := make([]recordMapDecoderField, len(rec.Fields())) for i, field := range rec.Fields() { - if field.action == FieldIgnore { + switch field.action { + case FieldIgnore: fields[i] = recordMapDecoderField{ name: field.Name(), decoder: createSkipDecoder(field.Type()), skip: true, } continue - } - - if field.action == FieldSetDefault { + case FieldSetDefault: if field.hasDef { fields[i] = recordMapDecoderField{ name: field.Name(), @@ -277,7 +276,7 @@ func decoderOfRecord(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValDec fields[i] = recordMapDecoderField{ name: field.Name(), - decoder: decoderOfType(cfg, field.Type(), mapType.Elem()), + decoder: newEfaceDecoder(cfg, field.Type()), } } @@ -302,17 +301,17 @@ type recordMapDecoder struct { func (d *recordMapDecoder) Decode(ptr unsafe.Pointer, r *Reader) { if d.mapType.UnsafeIsNil(ptr) { - d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) + d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(len(d.fields))) } for _, field := range d.fields { - elem := d.elemType.UnsafeNew() - field.decoder.Decode(elem, r) + elemPtr := d.elemType.UnsafeNew() + field.decoder.Decode(elemPtr, r) if field.skip { continue } - d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elem) + d.mapType.UnsafeSetIndex(ptr, reflect2.PtrOf(field), elemPtr) } if r.Error != nil && !errors.Is(r.Error, io.EOF) { diff --git a/codec_union.go b/codec_union.go index 7f864be..084f3c7 100644 --- a/codec_union.go +++ b/codec_union.go @@ -62,11 +62,20 @@ func decoderOfMapUnion(cfg *frozenConfig, schema Schema, typ reflect2.Type) ValD union := schema.(*UnionSchema) mapType := typ.(*reflect2.UnsafeMapType) + typeDecs := make([]ValDecoder, len(union.Types())) + for i, s := range union.Types() { + if s.Type() == Null { + continue + } + typeDecs[i] = newEfaceDecoder(cfg, s) + } + return &mapUnionDecoder{ cfg: cfg, schema: union, mapType: mapType, elemType: mapType.Elem(), + typeDecs: typeDecs, } } @@ -75,10 +84,11 @@ type mapUnionDecoder struct { schema *UnionSchema mapType *reflect2.UnsafeMapType elemType reflect2.Type + typeDecs []ValDecoder } func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { - _, resSchema := getUnionSchema(d.schema, r) + idx, resSchema := getUnionSchema(d.schema, r) if resSchema == nil { return } @@ -89,14 +99,14 @@ func (d *mapUnionDecoder) Decode(ptr unsafe.Pointer, r *Reader) { } if d.mapType.UnsafeIsNil(ptr) { - d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(0)) + d.mapType.UnsafeSet(ptr, d.mapType.UnsafeMakeMap(1)) } key := schemaTypeName(resSchema) keyPtr := reflect2.PtrOf(key) elemPtr := d.elemType.UnsafeNew() - decoderOfType(d.cfg, resSchema, d.elemType).Decode(elemPtr, r) + d.typeDecs[idx].Decode(elemPtr, r) d.mapType.UnsafeSetIndex(ptr, keyPtr, elemPtr) } @@ -294,7 +304,12 @@ func (d *unionResolvedDecoder) Decode(ptr unsafe.Pointer, r *Reader) { // We cannot resolve this, set it to the map type name := schemaTypeName(schema) obj := map[string]any{} - obj[name] = genericDecode(schema, r) + vTyp, err := genericReceiver(schema) + if err != nil { + r.ReportError("Union", err.Error()) + return + } + obj[name] = genericDecode(vTyp, decoderOfType(d.cfg, schema, vTyp), r) *pObj = obj return diff --git a/converter.go b/converter.go index 0a100a3..cc1f17c 100644 --- a/converter.go +++ b/converter.go @@ -1,98 +1,34 @@ package avro -import ( - "fmt" - "unsafe" - - "github.com/modern-go/reflect2" -) - -type converter struct { - toLong func(*Reader) int64 - toFloat func(*Reader) float32 - toDouble func(*Reader) float64 - toString func(*Reader) string - toBytes func(*Reader) []byte -} - -// resolveConverter returns a set of converter functions based on the actual type. -// Depending on the actual type value, some converter functions may be nil; -// thus, the downstream caller must first check the converter function value. -func resolveConverter(typ Type) converter { - cv := converter{} - cv.toLong, _ = createLongConverter(typ) - cv.toFloat, _ = createFloatConverter(typ) - cv.toDouble, _ = createDoubleConverter(typ) - cv.toString, _ = createStringConverter(typ) - cv.toBytes, _ = createBytesConverter(typ) - return cv -} - -func createLongConverter(typ Type) (func(*Reader) int64, error) { +func createLongConverter(typ Type) func(*Reader) int64 { switch typ { case Int: - return func(r *Reader) int64 { return int64(r.ReadInt()) }, nil - case Long: - return func(r *Reader) int64 { return r.ReadLong() }, nil + return func(r *Reader) int64 { return int64(r.ReadInt()) } default: - return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + return nil } } -func createFloatConverter(typ Type) (func(*Reader) float32, error) { +func createFloatConverter(typ Type) func(*Reader) float32 { switch typ { case Int: - return func(r *Reader) float32 { return float32(r.ReadInt()) }, nil + return func(r *Reader) float32 { return float32(r.ReadInt()) } case Long: - return func(r *Reader) float32 { return float32(r.ReadLong()) }, nil - case Float: - return func(r *Reader) float32 { return r.ReadFloat() }, nil + return func(r *Reader) float32 { return float32(r.ReadLong()) } default: - return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + return nil } } -func createDoubleConverter(typ Type) (func(*Reader) float64, error) { +func createDoubleConverter(typ Type) func(*Reader) float64 { switch typ { case Int: - return func(r *Reader) float64 { return float64(r.ReadInt()) }, nil + return func(r *Reader) float64 { return float64(r.ReadInt()) } case Long: - return func(r *Reader) float64 { return float64(r.ReadLong()) }, nil + return func(r *Reader) float64 { return float64(r.ReadLong()) } case Float: - return func(r *Reader) float64 { return float64(r.ReadFloat()) }, nil - case Double: - return func(r *Reader) float64 { return r.ReadDouble() }, nil - default: - return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) - } -} - -func createStringConverter(typ Type) (func(*Reader) string, error) { - switch typ { - case Bytes: - return func(r *Reader) string { - b := r.ReadBytes() - if len(b) == 0 { - return "" - } - return *(*string)(unsafe.Pointer(&b)) - }, nil - case String: - return func(r *Reader) string { return r.ReadString() }, nil - default: - return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) - } -} - -func createBytesConverter(typ Type) (func(*Reader) []byte, error) { - switch typ { - case String: - return func(r *Reader) []byte { - return reflect2.UnsafeCastString(r.ReadString()) - }, nil - case Bytes: - return func(r *Reader) []byte { return r.ReadBytes() }, nil + return func(r *Reader) float64 { return float64(r.ReadFloat()) } default: - return nil, fmt.Errorf("cannot promote from %q to %q", typ, Long) + return nil } } diff --git a/converter_test.go b/converter_test.go index 743164f..e39421d 100644 --- a/converter_test.go +++ b/converter_test.go @@ -6,106 +6,83 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func TestConverter(t *testing.T) { +func TestLongConverter(t *testing.T) { tests := []struct { - data []byte - want any - typ, wantTyp Type - wantErr require.ErrorAssertionFunc + data []byte + typ Type + want int64 }{ { - data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, - want: int64(987654321), - typ: Int, - wantTyp: Long, - wantErr: require.NoError, - }, - { - data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, - want: int64(9223372036854775807), - typ: Long, - wantTyp: Long, - wantErr: require.NoError, - }, - { - data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, - want: float32(987654321), - typ: Int, - wantTyp: Float, - wantErr: require.NoError, - }, - { - data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, - want: float32(9223372036854775807), - typ: Long, - wantTyp: Float, - wantErr: require.NoError, - }, - { - data: []byte{0x62, 0x20, 0x71, 0x49}, - want: float32(987654.124), - typ: Float, - wantTyp: Float, - wantErr: require.NoError, - }, - { - data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, - want: float64(987654321), - typ: Int, - wantTyp: Double, - wantErr: require.NoError, - }, - { - data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, - want: float64(9223372036854775807), - typ: Long, - wantTyp: Double, - wantErr: require.NoError, - }, - { - data: []byte{0x62, 0x20, 0x71, 0x49}, - want: float64(float32(987654.124)), - typ: Float, - wantTyp: Double, - wantErr: require.NoError, + data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, + typ: Int, + want: 987654321, }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + r := NewReader(bytes.NewReader(test.data), 10) + + got := createLongConverter(test.typ)(r) + + assert.Equal(t, test.want, got) + }) + } +} + +func TestFloatConverter(t *testing.T) { + tests := []struct { + data []byte + typ Type + want float32 + }{ { - data: []byte{0xB6, 0xF3, 0x7D, 0x54, 0x34, 0x6F, 0x9D, 0xC1}, - want: float64(-123456789.123), - typ: Double, - wantTyp: Double, - wantErr: require.NoError, + data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, + typ: Int, + want: 987654321, }, { - data: []byte{0x08, 0xEC, 0xAB, 0x44, 0x00}, - want: string([]byte{0xEC, 0xAB, 0x44, 0x00}), - typ: Bytes, - wantTyp: String, - wantErr: require.NoError, + data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, + typ: Long, + want: 9223372036854775807, }, + } + + for i, test := range tests { + test := test + t.Run(strconv.Itoa(i), func(t *testing.T) { + r := NewReader(bytes.NewReader(test.data), 10) + + got := createFloatConverter(test.typ)(r) + + assert.Equal(t, test.want, got) + }) + } +} + +func TestDoubleConverter(t *testing.T) { + tests := []struct { + data []byte + typ Type + want float64 + }{ { - data: []byte{0x28, 0x6F, 0x70, 0x70, 0x61, 0x6E, 0x20, 0x67, 0x61, 0x6E, 0x67, 0x6E, 0x61, 0x6D, 0x20, 0x73, 0x74, 0x79, 0x6C, 0x65, 0x21}, - want: "oppan gangnam style!", - typ: String, - wantTyp: String, - wantErr: require.NoError, + data: []byte{0xE2, 0xA2, 0xF3, 0xAD, 0x07}, + typ: Int, + want: 987654321, }, { - data: []byte{0x36, 0xD1, 0x87, 0xD0, 0xB5, 0x2D, 0xD1, 0x82, 0xD0, 0xBE, 0x20, 0xD0, 0xBF, 0xD0, 0xBE, 0x20, 0xD1, 0x80, 0xD1, 0x83, 0xD1, 0x81, 0xD1, 0x81, 0xD0, 0xBA, 0xD0, 0xB8}, - want: []byte("че-то по русски"), - typ: String, - wantTyp: Bytes, - wantErr: require.NoError, + data: []byte{0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01}, + typ: Long, + want: 9223372036854775807, }, { - data: []byte{0x0C, 0xAC, 0xDC, 0x01, 0x00, 0x10, 0x0F}, - want: []byte{0xAC, 0xDC, 0x01, 0x00, 0x10, 0x0F}, - typ: Bytes, - wantTyp: Bytes, - wantErr: require.NoError, + data: []byte{0x62, 0x20, 0x71, 0x49}, + typ: Float, + want: float64(float32(987654.124)), }, } @@ -113,24 +90,9 @@ func TestConverter(t *testing.T) { test := test t.Run(strconv.Itoa(i), func(t *testing.T) { r := NewReader(bytes.NewReader(test.data), 10) - conv := resolveConverter(test.typ) - var got any - switch test.wantTyp { - case Long: - got = conv.toLong(r) - case Float: - got = conv.toFloat(r) - case Double: - got = conv.toDouble(r) - case String: - got = conv.toString(r) - case Bytes: - got = conv.toBytes(r) - default: - } + got := createDoubleConverter(test.typ)(r) - test.wantErr(t, r.Error) assert.Equal(t, test.want, got) }) }