Skip to content

Commit

Permalink
Merge pull request #9 from mdawar/encoding
Browse files Browse the repository at this point in the history
feat: add support for text-based encoding
  • Loading branch information
quagmt authored Oct 16, 2024
2 parents 9f1afce + 5eb136f commit 9d5eef5
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 17 deletions.
5 changes: 0 additions & 5 deletions bint.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ func parseBint(s []byte) (bool, bint, uint8, error) {
return false, bint{}, 0, ErrMaxStrLen
}

// unQuote if the string is quoted, usually when unmarshalling from JSON
if len(s) > 2 && s[0] == '"' && s[len(s)-1] == '"' {
s = s[1 : len(s)-1]
}

// if s has less than 41 characters, it can fit into u128
// 41 chars = maxLen(u128) + dot + sign = 39 + 1 + 1
if len(s) <= 41 {
Expand Down
39 changes: 37 additions & 2 deletions codec.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
package udecimal

import (
"database/sql"
"database/sql/driver"
"encoding"
"encoding/binary"
"encoding/json"
"fmt"
"math/big"
"math/bits"
"unsafe"
)

var (
_ fmt.Stringer = (*Decimal)(nil)
_ sql.Scanner = (*Decimal)(nil)
_ driver.Valuer = (*Decimal)(nil)
_ encoding.TextMarshaler = (*Decimal)(nil)
_ encoding.TextUnmarshaler = (*Decimal)(nil)
_ json.Marshaler = (*Decimal)(nil)
_ json.Unmarshaler = (*Decimal)(nil)
)

// String returns the string representation of the decimal.
// Trailing zeros will be removed.
func (d Decimal) String() string {
Expand Down Expand Up @@ -214,6 +227,7 @@ func unssafeStringToBytes(s string) []byte {
return unsafe.Slice(unsafe.StringData(s), len(s))
}

// MarshalJSON implements the [json.Marshaler] interface.
func (d Decimal) MarshalJSON() ([]byte, error) {
if !d.coef.overflow() {
return d.bytesU128(true, true), nil
Expand All @@ -222,13 +236,34 @@ func (d Decimal) MarshalJSON() ([]byte, error) {
return []byte(`"` + d.stringBigInt(true) + `"`), nil
}

// UnmarshalJSON implements the [json.Unmarshaler] interface.
func (d *Decimal) UnmarshalJSON(data []byte) error {
// Remove quotes if they exist.
if len(data) >= 2 && data[0] == '"' && data[len(data)-1] == '"' {
data = data[1 : len(data)-1]
}

return d.UnmarshalText(data)
}

// MarshalText implements the [encoding.TextMarshaler] interface.
func (d Decimal) MarshalText() ([]byte, error) {
if !d.coef.overflow() {
// Return without quotes.
return d.bytesU128(true, false), nil
}

return []byte(d.stringBigInt(true)), nil
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (d *Decimal) UnmarshalText(data []byte) error {
var err error
*d, err = parseBytes(data)
return err
}

// MarshalBinary implements encoding.BinaryMarshaler interface with custom binary format.
// MarshalBinary implements [encoding.BinaryMarshaler] interface with custom binary format.
//
// Binary format: [overflow + neg] [prec] [total bytes] [coef]
//
Expand Down Expand Up @@ -386,7 +421,7 @@ func (d *Decimal) Scan(src any) error {
return err
}

// Value implements driver.Valuer interface.
// Value implements [driver.Valuer] interface.
func (d Decimal) Value() (driver.Value, error) {
return d.String(), nil
}
Expand Down
74 changes: 66 additions & 8 deletions codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,63 @@ func TestStringFixed(t *testing.T) {
}
}

func TestMarshalText(t *testing.T) {
testcases := []struct {
in string
}{
{"123456789.123456789"},
{"0"},
{"1"},
{"-1"},
{"-123456789.123456789"},
{"0.000000001"},
{"-0.000000001"},
{"123.123"},
{"-123.123"},
{"12345678901234567890123456789.1234567890123456789"},
{"-12345678901234567890123456789.1234567890123456789"},
}

for _, tc := range testcases {
t.Run(tc.in, func(t *testing.T) {
a := MustParse(tc.in)

b, err := a.MarshalText()
require.NoError(t, err)

var c Decimal
require.NoError(t, c.UnmarshalText(b))

require.Equal(t, a, c)
})
}
}

func TestUnmarshalText(t *testing.T) {
testcases := []struct {
in string
wantErr error
}{
{"", ErrEmptyString},
{" ", ErrInvalidFormat},
{"abc", ErrInvalidFormat},
{"1234567890123.1234567890123", nil},
{"1234567890123.12345678901234567899", ErrPrecOutOfRange},
}

for _, tc := range testcases {
t.Run(tc.in, func(t *testing.T) {
var d Decimal
err := d.UnmarshalText([]byte(tc.in))
require.ErrorIs(t, err, tc.wantErr)

if tc.wantErr == nil {
require.Equal(t, MustParse(tc.in), d)
}
})
}
}

type A struct {
P Decimal `json:"a"`
}
Expand Down Expand Up @@ -95,13 +152,16 @@ type Test struct {
Test Decimal `json:"price"`
}

func TestUnmarshalNumber(t *testing.T) {
func TestUnmarshalJSON(t *testing.T) {
testcases := []struct {
in string
wantErr error
}{
{`""`, ErrEmptyString},
{`" "`, ErrInvalidFormat},
{`"abc"`, ErrInvalidFormat},
{"1234567890123.1234567890123", nil},
{"1234567890123.12345678901234567899", fmt.Errorf("precision out of range. Only support maximum 19 digits after the decimal point")},
{"1234567890123.12345678901234567899", ErrPrecOutOfRange},
{`"1234567890123.1234567890123"`, nil},
}

Expand All @@ -111,13 +171,11 @@ func TestUnmarshalNumber(t *testing.T) {

var test Test
err := json.Unmarshal([]byte(s), &test)
if tc.wantErr != nil {
require.Equal(t, tc.wantErr, err)
return
}
require.ErrorIs(t, err, tc.wantErr)

require.NoError(t, err)
require.Equal(t, strings.Trim(tc.in, `"`), test.Test.String())
if tc.wantErr == nil {
require.Equal(t, strings.Trim(tc.in, `"`), test.Test.String())
}
})
}
}
Expand Down
2 changes: 0 additions & 2 deletions decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,14 +275,12 @@ func Parse(s string) (Decimal, error) {
}

func parseBytes(b []byte) (Decimal, error) {

neg, bint, prec, err := parseBint(b)
if err != nil {
return Decimal{}, err
}

return newDecimal(neg, bint, prec), nil

}

// MustParse similars to Parse, but pacnis instead of returning error.
Expand Down
21 changes: 21 additions & 0 deletions doc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,19 @@ func ExampleDecimal_MarshalJSON() {
// "1234567890123456789.1234567890123456789"
}

func ExampleDecimal_MarshalText() {
a, _ := MustParse("1.23").MarshalText()
b, _ := MustParse("-1.2345").MarshalText()
c, _ := MustParse("1234567890123456789.1234567890123456789").MarshalText()
fmt.Println(string(a))
fmt.Println(string(b))
fmt.Println(string(c))
// Output:
// 1.23
// -1.2345
// 1234567890123456789.1234567890123456789
}

func ExampleDecimal_Neg() {
fmt.Println(MustParse("1.23").Neg())
fmt.Println(MustParse("-1.23").Neg())
Expand Down Expand Up @@ -399,6 +412,14 @@ func ExampleDecimal_UnmarshalJSON() {
// 1.23
}

func ExampleDecimal_UnmarshalText() {
var a Decimal
_ = a.UnmarshalText([]byte("1.23"))
fmt.Println(a)
// Output:
// 1.23
}

func ExampleDecimal_Value() {
fmt.Println(MustParse("1.2345").Value())
// Output:
Expand Down

0 comments on commit 9d5eef5

Please sign in to comment.