Skip to content

Commit

Permalink
flag to control zeroing of target structs before decode/unmarshal
Browse files Browse the repository at this point in the history
  • Loading branch information
glycerine committed Oct 16, 2016
1 parent 886f034 commit d8d0f82
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 15 deletions.
3 changes: 3 additions & 0 deletions cfg/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ type MsgpConfig struct {
Marshal bool
Tests bool
Unexported bool

SkipStructZeroingOnDecode bool
}

// call DefineFlags before myflags.Parse()
Expand All @@ -21,6 +23,7 @@ func (c *MsgpConfig) DefineFlags(fs *flag.FlagSet) {
fs.BoolVar(&c.Marshal, "marshal", true, "create Marshal and Unmarshal methods")
fs.BoolVar(&c.Tests, "tests", true, "create tests and benchmarks")
fs.BoolVar(&c.Unexported, "unexported", false, "also process unexported types")
fs.BoolVar(&c.SkipStructZeroingOnDecode, "skip-decode-struct-zeroing", false, "possibly dangerous option: don't zero out the target struct before decoding/unmarshalling into it.")
}

// call c.ValidateConfig() after myflags.Parse()
Expand Down
14 changes: 10 additions & 4 deletions gen/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ package gen
import (
"io"
"strconv"

"github.com/tinylib/msgp/cfg"
)

func decode(w io.Writer) *decodeGen {
func decode(w io.Writer, cfg *cfg.MsgpConfig) *decodeGen {
return &decodeGen{
p: printer{w: w},
hasfield: false,
cfg: cfg,
}
}

Expand All @@ -17,6 +20,7 @@ type decodeGen struct {
p printer
hasfield bool
depth int
cfg *cfg.MsgpConfig
}

func (d *decodeGen) Method() Method { return Decode }
Expand Down Expand Up @@ -93,9 +97,11 @@ func (d *decodeGen) structAsTuple(s *Struct) {
}

func (d *decodeGen) structAsMap(s *Struct) {
if d.depth == 1 {
d.p.printf("\n\n// zero the target:\n")
d.p.printf("*%s = %s{}\n", s.vname, s.TypeName())
if !d.cfg.SkipStructZeroingOnDecode {
if d.depth == 1 {
d.p.printf("\n\n// zero the target:\n")
d.p.printf("*%s = %s{}\n", s.vname, s.TypeName())
}
}
d.needsField()
sz := randIdent()
Expand Down
13 changes: 8 additions & 5 deletions gen/spec.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gen

import (
"fmt"
"github.com/tinylib/msgp/cfg"
"io"
)

Expand Down Expand Up @@ -101,17 +102,19 @@ const (

type Printer struct {
gens []generator
cfg *cfg.MsgpConfig
}

func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
func NewPrinter(m Method, out io.Writer, tests io.Writer, cfg *cfg.MsgpConfig) *Printer {
if m.isset(Test) && tests == nil {
panic("cannot print tests with 'nil' tests argument!")
}
gens := make([]generator, 0, 8)
if m.isset(Decode) {
gens = append(gens, decode(out))
gens = append(gens, decode(out, cfg))
}
// must run FieldsEmpty before Encode/Marshal, to set Struct.hasOmitEmptyTags
// must run FieldsEmpty before Encode/Marshal, so as
// to set Struct.hasOmitEmptyTags
if m.isset(FieldsEmpty) {
gens = append(gens, fieldsempty(out))
}
Expand All @@ -122,7 +125,7 @@ func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
gens = append(gens, marshal(out))
}
if m.isset(Unmarshal) {
gens = append(gens, unmarshal(out))
gens = append(gens, unmarshal(out, cfg))
}
if m.isset(Size) {
gens = append(gens, sizes(out))
Expand All @@ -136,7 +139,7 @@ func NewPrinter(m Method, out io.Writer, tests io.Writer) *Printer {
if len(gens) == 0 {
panic("NewPrinter called with invalid method flags")
}
return &Printer{gens: gens}
return &Printer{gens: gens, cfg: cfg}
}

// TransformPass is a pass that transforms individual
Expand Down
16 changes: 11 additions & 5 deletions gen/unmarshal.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ package gen
import (
"io"
"strconv"

"github.com/tinylib/msgp/cfg"
)

func unmarshal(w io.Writer) *unmarshalGen {
func unmarshal(w io.Writer, cfg *cfg.MsgpConfig) *unmarshalGen {
return &unmarshalGen{
p: printer{w: w},
p: printer{w: w},
cfg: cfg,
}
}

Expand All @@ -16,6 +19,7 @@ type unmarshalGen struct {
p printer
hasfield bool
depth int
cfg *cfg.MsgpConfig
}

func (u *unmarshalGen) Method() Method { return Unmarshal }
Expand Down Expand Up @@ -90,9 +94,11 @@ func (u *unmarshalGen) tuple(s *Struct) {
}

func (u *unmarshalGen) mapstruct(s *Struct) {
if u.depth == 1 {
u.p.printf("\n\n// zero the target:\n")
u.p.printf("*%s = %s{}\n", s.vname, s.TypeName())
if !u.cfg.SkipStructZeroingOnDecode {
if u.depth == 1 {
u.p.printf("\n\n// zero the target:\n")
u.p.printf("*%s = %s{}\n", s.vname, s.TypeName())
}
}
u.needsField()
sz := randIdent()
Expand Down
2 changes: 1 addition & 1 deletion printer/print.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func generate(f *parse.FileSet, mode gen.Method) (*bytes.Buffer, *bytes.Buffer,
}
testwr = testbuf
}
return outbuf, testbuf, f.PrintTo(gen.NewPrinter(mode, outbuf, testwr))
return outbuf, testbuf, f.PrintTo(gen.NewPrinter(mode, outbuf, testwr, f.Cfg))
}

func writePkgHeader(b *bytes.Buffer, name string) {
Expand Down

0 comments on commit d8d0f82

Please sign in to comment.