Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

separate unsigned integer decoding to support full range of uint64 #782

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 28 additions & 55 deletions decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,40 @@ import (
)

func parseInteger(b []byte) (int64, error) {
cleaned, base, err := cleanInteger(b)
if err != nil {
return 0, err
}
return strconv.ParseInt(string(cleaned), base, 64)
}

func parseUinteger(b []byte) (uint64, error) {
cleaned, base, err := cleanInteger(b)
if err != nil {
return 0, err
}
return strconv.ParseUint(string(cleaned), base, 64)
}

func cleanInteger(b []byte) (cleaned []byte, base int, err error) {
if len(b) > 2 && b[0] == '0' {
switch b[1] {
case 'x':
return parseIntHex(b)
base = 16
case 'b':
return parseIntBin(b)
base = 2
case 'o':
return parseIntOct(b)
base = 8
default:
panic(fmt.Errorf("invalid base '%c', should have been checked by scanIntOrFloat", b[1]))
}
cleaned, err = checkAndRemoveUnderscoresIntegers(b[2:])
return
}

return parseIntDec(b)
base = 10
cleaned, err = cleanIntDec(b)
return
}

func parseLocalDate(b []byte) (LocalDate, error) {
Expand Down Expand Up @@ -328,56 +348,14 @@ func parseFloat(b []byte) (float64, error) {
return f, nil
}

func parseIntHex(b []byte) (int64, error) {
cleaned, err := checkAndRemoveUnderscoresIntegers(b[2:])
if err != nil {
return 0, err
}

i, err := strconv.ParseInt(string(cleaned), 16, 64)
if err != nil {
return 0, newDecodeError(b, "couldn't parse hexadecimal number: %w", err)
}

return i, nil
}

func parseIntOct(b []byte) (int64, error) {
cleaned, err := checkAndRemoveUnderscoresIntegers(b[2:])
if err != nil {
return 0, err
}

i, err := strconv.ParseInt(string(cleaned), 8, 64)
if err != nil {
return 0, newDecodeError(b, "couldn't parse octal number: %w", err)
}

return i, nil
}

func parseIntBin(b []byte) (int64, error) {
cleaned, err := checkAndRemoveUnderscoresIntegers(b[2:])
if err != nil {
return 0, err
}

i, err := strconv.ParseInt(string(cleaned), 2, 64)
if err != nil {
return 0, newDecodeError(b, "couldn't parse binary number: %w", err)
}

return i, nil
}

func isSign(b byte) bool {
return b == '+' || b == '-'
}

func parseIntDec(b []byte) (int64, error) {
func cleanIntDec(b []byte) ([]byte, error) {
cleaned, err := checkAndRemoveUnderscoresIntegers(b)
if err != nil {
return 0, err
return nil, err
}

startIdx := 0
Expand All @@ -387,15 +365,10 @@ func parseIntDec(b []byte) (int64, error) {
}

if len(cleaned) > startIdx+1 && cleaned[startIdx] == '0' {
return 0, newDecodeError(b, "leading zero not allowed on decimal number")
return nil, newDecodeError(b, "leading zero not allowed on decimal number")
}

i, err := strconv.ParseInt(string(cleaned), 10, 64)
if err != nil {
return 0, newDecodeError(b, "couldn't parse decimal number: %w", err)
}

return i, nil
return cleaned, nil
}

func checkAndRemoveUnderscoresIntegers(b []byte) ([]byte, error) {
Expand Down
133 changes: 63 additions & 70 deletions unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -866,93 +866,86 @@ func (d *decoder) unmarshalFloat(value *ast.Node, v reflect.Value) error {
return nil
}

const (
maxInt = int64(^uint(0) >> 1)
minInt = -maxInt - 1
)

// Maximum value of uint for decoding. Currently the decoder parses the integer
// into an int64. As a result, on architectures where uint is 64 bits, the
// effective maximum uint we can decode is the maximum of int64. On
// architectures where uint is 32 bits, the maximum value we can decode is
// lower: the maximum of uint32. I didn't find a way to figure out this value at
// compile time, so it is computed during initialization.
var maxUint int64 = math.MaxInt64

func init() {
m := uint64(^uint(0))
if m < uint64(maxUint) {
maxUint = int64(m)
}
}

func (d *decoder) unmarshalInteger(value *ast.Node, v reflect.Value) error {
i, err := parseInteger(value.Data)
if err != nil {
return err
}

var r reflect.Value

switch v.Kind() {
case reflect.Int64:
v.SetInt(i)
return nil
case reflect.Int32:
if i < math.MinInt32 || i > math.MaxInt32 {
return fmt.Errorf("toml: number %d does not fit in an int32", i)
k := v.Kind()
switch k {
case reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, reflect.Int:
i, err := parseInteger(value.Data)
if err != nil {
return err
}
switch k {
case reflect.Int64:
v.SetInt(i)
return nil
case reflect.Int32:
if i < math.MinInt32 || i > math.MaxInt32 {
return fmt.Errorf("toml: number %d does not fit in an int32", i)
}

r = reflect.ValueOf(int32(i))
case reflect.Int16:
if i < math.MinInt16 || i > math.MaxInt16 {
return fmt.Errorf("toml: number %d does not fit in an int16", i)
}
r = reflect.ValueOf(int32(i))
case reflect.Int16:
if i < math.MinInt16 || i > math.MaxInt16 {
return fmt.Errorf("toml: number %d does not fit in an int16", i)
}

r = reflect.ValueOf(int16(i))
case reflect.Int8:
if i < math.MinInt8 || i > math.MaxInt8 {
return fmt.Errorf("toml: number %d does not fit in an int8", i)
}
r = reflect.ValueOf(int16(i))
case reflect.Int8:
if i < math.MinInt8 || i > math.MaxInt8 {
return fmt.Errorf("toml: number %d does not fit in an int8", i)
}

r = reflect.ValueOf(int8(i))
case reflect.Int:
if i < minInt || i > maxInt {
return fmt.Errorf("toml: number %d does not fit in an int", i)
}
r = reflect.ValueOf(int8(i))
case reflect.Int:
if i < math.MinInt || i > math.MaxInt {
return fmt.Errorf("toml: number %d does not fit in an int", i)
}

r = reflect.ValueOf(int(i))
case reflect.Uint64:
if i < 0 {
return fmt.Errorf("toml: negative number %d does not fit in an uint64", i)
r = reflect.ValueOf(int(i))
}

r = reflect.ValueOf(uint64(i))
case reflect.Uint32:
if i < 0 || i > math.MaxUint32 {
return fmt.Errorf("toml: negative number %d does not fit in an uint32", i)
case reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8, reflect.Uint:
u, err := parseUinteger(value.Data)
if err != nil {
return err
}
switch k {
case reflect.Uint64:
v.SetUint(u)
return nil
case reflect.Uint32:
if u > math.MaxUint32 {
return fmt.Errorf("toml: number %d does not fit in an uint32", u)
}

r = reflect.ValueOf(uint32(i))
case reflect.Uint16:
if i < 0 || i > math.MaxUint16 {
return fmt.Errorf("toml: negative number %d does not fit in an uint16", i)
}
r = reflect.ValueOf(uint32(u))
case reflect.Uint16:
if u > math.MaxUint16 {
return fmt.Errorf("toml: number %d does not fit in an uint16", u)
}

r = reflect.ValueOf(uint16(i))
case reflect.Uint8:
if i < 0 || i > math.MaxUint8 {
return fmt.Errorf("toml: negative number %d does not fit in an uint8", i)
}
r = reflect.ValueOf(uint16(u))
case reflect.Uint8:
if u > math.MaxUint8 {
return fmt.Errorf("toml: number %d does not fit in an uint8", u)
}

r = reflect.ValueOf(uint8(u))
case reflect.Uint:
if u > math.MaxUint {
return fmt.Errorf("toml: number %d does not fit in an uint", u)
}

r = reflect.ValueOf(uint8(i))
case reflect.Uint:
if i < 0 || i > maxUint {
return fmt.Errorf("toml: negative number %d does not fit in an uint", i)
r = reflect.ValueOf(uint(u))
}

r = reflect.ValueOf(uint(i))
case reflect.Interface:
i, err := parseInteger(value.Data)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This interface case is still ambiguous, and thus limited to int64 currently.

if err != nil {
return err
}
r = reflect.ValueOf(i)
default:
return d.typeMismatchError("integer", v.Type())
Expand Down
11 changes: 10 additions & 1 deletion unmarshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
"testing"
"time"

"github.com/pelletier/go-toml/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/pelletier/go-toml/v2"
)

func ExampleDecoder_DisallowUnknownFields() {
Expand Down Expand Up @@ -2423,6 +2424,14 @@ Host = 'main.domain.com'
require.Equal(t, expected, string(b))
}

func TestIssue781(t *testing.T) {
var v struct {
Uint64 uint64
}
assert.NoError(t, toml.Unmarshal([]byte(`Uint64 = 18446744073709551615`), &v))
assert.Equal(t, uint64(math.MaxUint64), v.Uint64)
}

func TestUnmarshalDecodeErrors(t *testing.T) {
examples := []struct {
desc string
Expand Down