Skip to content

Commit

Permalink
Decode: unstable/Unmarshal interface (#940)
Browse files Browse the repository at this point in the history
Co-authored-by: Pavlos Karakalidis <pkarakal@pkarakal.com>
Co-authored-by: Thomas Pelletier <thomas@pelletier.codes>
  • Loading branch information
3 people authored Mar 19, 2024
1 parent 7dad877 commit 8ed6d13
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 0 deletions.
33 changes: 33 additions & 0 deletions unmarshaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ type Decoder struct {

// global settings
strict bool

// toggles unmarshaler interface
unmarshalerInterface bool
}

// NewDecoder creates a new Decoder that will read from r.
Expand All @@ -54,6 +57,24 @@ func (d *Decoder) DisallowUnknownFields() *Decoder {
return d
}

// EnableUnmarshalerInterface allows to enable unmarshaler interface.
//
// With this feature enabled, types implementing the unstable/Unmarshaler
// interface can be decoded from any structure of the document. It allows types
// that don't have a straightfoward TOML representation to provide their own
// decoding logic.
//
// Currently, types can only decode from a single value. Tables and array tables
// are not supported.
//
// *Unstable:* This method does not follow the compatibility guarantees of
// semver. It can be changed or removed without a new major version being
// issued.
func (d *Decoder) EnableUnmarshalerInterface() *Decoder {
d.unmarshalerInterface = true
return d
}

// Decode the whole content of r into v.
//
// By default, values in the document that don't exist in the target Go value
Expand Down Expand Up @@ -108,6 +129,7 @@ func (d *Decoder) Decode(v interface{}) error {
strict: strict{
Enabled: d.strict,
},
unmarshalerInterface: d.unmarshalerInterface,
}

return dec.FromParser(v)
Expand Down Expand Up @@ -143,6 +165,9 @@ type decoder struct {
// Strict mode
strict strict

// Flag that enables/disables unmarshaler interface.
unmarshalerInterface bool

// Current context for the error.
errorContext *errorContext
}
Expand Down Expand Up @@ -648,6 +673,14 @@ func (d *decoder) handleValue(value *unstable.Node, v reflect.Value) error {
v = initAndDereferencePointer(v)
}

if d.unmarshalerInterface {
if v.CanAddr() && v.Addr().CanInterface() {
if outi, ok := v.Addr().Interface().(unstable.Unmarshaler); ok {
return outi.UnmarshalTOML(value)
}
}
}

ok, err := d.tryTextUnmarshaler(value, v)
if ok || err != nil {
return err
Expand Down
93 changes: 93 additions & 0 deletions unmarshaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"time"

"github.com/pelletier/go-toml/v2"
"github.com/pelletier/go-toml/v2/unstable"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -3772,3 +3773,95 @@ func TestUnmarshal_Nil(t *testing.T) {
})
}
}

type CustomUnmarshalerKey struct {
A int64
}

func (k *CustomUnmarshalerKey) UnmarshalTOML(value *unstable.Node) error {
item, err := strconv.ParseInt(string(value.Data), 10, 64)
if err != nil {
return fmt.Errorf("error converting to int64, %v", err)
}
k.A = item
return nil

}

func TestUnmarshal_CustomUnmarshaler(t *testing.T) {
type MyConfig struct {
Unmarshalers []CustomUnmarshalerKey `toml:"unmarshalers"`
Foo *string `toml:"foo,omitempty"`
}

examples := []struct {
desc string
disableUnmarshalerInterface bool
input string
expected MyConfig
err bool
}{
{
desc: "empty",
input: ``,
expected: MyConfig{Unmarshalers: []CustomUnmarshalerKey{}, Foo: nil},
},
{
desc: "simple",
input: `unmarshalers = [1,2,3]`,
expected: MyConfig{
Unmarshalers: []CustomUnmarshalerKey{
{A: 1},
{A: 2},
{A: 3},
},
Foo: nil,
},
},
{
desc: "unmarshal string and custom unmarshaler",
input: `unmarshalers = [1,2,3]
foo = "bar"`,
expected: MyConfig{
Unmarshalers: []CustomUnmarshalerKey{
{A: 1},
{A: 2},
{A: 3},
},
Foo: func(v string) *string {
return &v
}("bar"),
},
},
{
desc: "simple example, but unmarshaler interface disabled",
disableUnmarshalerInterface: true,
input: `unmarshalers = [1,2,3]`,
err: true,
},
}

for _, ex := range examples {
e := ex
t.Run(e.desc, func(t *testing.T) {
foo := MyConfig{}

decoder := toml.NewDecoder(bytes.NewReader([]byte(e.input)))
if !ex.disableUnmarshalerInterface {
decoder.EnableUnmarshalerInterface()
}
err := decoder.Decode(&foo)

if e.err {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, len(foo.Unmarshalers), len(e.expected.Unmarshalers))
for i := 0; i < len(foo.Unmarshalers); i++ {
require.Equal(t, foo.Unmarshalers[i], e.expected.Unmarshalers[i])
}
require.Equal(t, foo.Foo, e.expected.Foo)
}
})
}
}
7 changes: 7 additions & 0 deletions unstable/unmarshaler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package unstable

// The Unmarshaler interface may be implemented by types to customize their
// behavior when being unmarshaled from a TOML document.
type Unmarshaler interface {
UnmarshalTOML(value *Node) error
}

0 comments on commit 8ed6d13

Please sign in to comment.