diff --git a/gen.go b/gen.go index 48d832a..ca57890 100644 --- a/gen.go +++ b/gen.go @@ -5,16 +5,20 @@ import ( "io" "math/big" "reflect" + "strconv" "strings" "text/template" - "github.com/ipfs/go-cid" + cid "github.com/ipfs/go-cid" ) const MaxLength = 8192 const ByteArrayMaxLen = 2 << 20 +const MaxLenTag = "maxlen" +const NoUsrMaxLen = -1 + var ( cidType = reflect.TypeOf(cid.Cid{}) bigIntType = reflect.TypeOf(big.Int{}) @@ -32,6 +36,12 @@ func doTemplate(w io.Writer, info interface{}, templ string) error { "ReadHeader": func(rdr string) string { return fmt.Sprintf(`%s.ReadHeader()`, rdr) }, + "MaxLen": func(val int, def string) string { + if val <= 0 { + return def + } + return fmt.Sprintf("%d", val) + }, }).Parse(templ)) return t.Execute(w, info) @@ -81,6 +91,8 @@ type Field struct { Pkg string IterLabel string + + MaxLen int } func typeName(pkg string, t reflect.Type) string { @@ -172,9 +184,23 @@ func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) { } mapk := f.Name + usrMaxLen := NoUsrMaxLen tagval := f.Tag.Get("cborgen") - if tagval != "" { - mapk = tagval + tags, err := tagparse(tagval) + if err != nil { + return nil, fmt.Errorf("invalid tag format: %w", err) + } + + if tags["name"] != "" { + mapk = tags["name"] + } + if msize := tags["maxlen"]; msize != "" { + val, err := strconv.Atoi(msize) + if err != nil { + return nil, fmt.Errorf("maxsize tag value was not valid: %w", err) + } + + usrMaxLen = val } out.Fields = append(out.Fields, Field{ @@ -183,12 +209,37 @@ func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) { Pointer: pointer, Type: ft, Pkg: pkg, + MaxLen: usrMaxLen, }) } return &out, nil } +func tagparse(v string) (map[string]string, error) { + out := make(map[string]string) + for _, elem := range strings.Split(v, ",") { + elem = strings.TrimSpace(elem) + if elem == "" { + continue + } + + if strings.Contains(elem, "=") { + parts := strings.Split(elem, "=") + if len(parts) != 2 { + return nil, fmt.Errorf("struct tags with params must be of form X=Y") + } + + out[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } else { + out["name"] = elem + } + + } + + return out, nil +} + func (gti GenTypeInfo) TupleHeader() []byte { return CborEncodeMajorType(MajArray, uint64(len(gti.Fields))) } @@ -223,7 +274,7 @@ func emitCborMarshalStringField(w io.Writer, f Field) error { } return doTemplate(w, f, ` - if len({{ .Name }}) > cbg.MaxLength { + if len({{ .Name }}) > {{ MaxLen .MaxLen "cbg.MaxLength" }} { return xerrors.Errorf("Value in field {{ .Name | js }} was too long") } @@ -397,7 +448,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { // Note: this re-slices the slice to deal with arrays. if e.Kind() == reflect.Uint8 { return doTemplate(w, f, ` - if len({{ .Name }}) > cbg.ByteArrayMaxLen { + if len({{ .Name }}) > {{ MaxLen .MaxLen "cbg.ByteArrayMaxLen" }} { return xerrors.Errorf("Byte array in field {{ .Name }} was too long") } @@ -414,7 +465,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { } err := doTemplate(w, f, ` - if len({{ .Name }}) > cbg.MaxLength { + if len({{ .Name }}) > {{ MaxLen .MaxLen "cbg.MaxLength" }} { return xerrors.Errorf("Slice value in field {{ .Name }} was too long") } @@ -869,7 +920,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { if e.Kind() == reflect.Uint8 { return doTemplate(w, f, ` - if extra > cbg.ByteArrayMaxLen { + if extra > {{ MaxLen .MaxLen "cbg.ByteArrayMaxLen" }} { return fmt.Errorf("{{ .Name }}: byte array too large (%d)", extra) } if maj != cbg.MajByteString { @@ -893,7 +944,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { } if err := doTemplate(w, f, ` - if extra > cbg.MaxLength { + if extra > {{ MaxLen .MaxLen "cbg.MaxLength" }} { return fmt.Errorf("{{ .Name }}: array too large (%d)", extra) } `); err != nil { diff --git a/testgen/main.go b/testgen/main.go index 90f2224..ac5ddf4 100644 --- a/testgen/main.go +++ b/testgen/main.go @@ -13,6 +13,7 @@ func main() { types.DeferredContainer{}, types.FixedArrays{}, types.ThingWithSomeTime{}, + types.BigField{}, ); err != nil { panic(err) } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index d53913e..fd2a7b3 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -1095,3 +1095,79 @@ func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) (err error) { } return nil } + +var lengthBufBigField = []byte{129} + +func (t *BigField) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + + if _, err := cw.Write(lengthBufBigField); err != nil { + return err + } + + // t.LargeBytes ([]uint8) (slice) + if len(t.LargeBytes) > 10000000 { + return xerrors.Errorf("Byte array in field t.LargeBytes was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajByteString, uint64(len(t.LargeBytes))); err != nil { + return err + } + + if _, err := cw.Write(t.LargeBytes[:]); err != nil { + return err + } + return nil +} + +func (t *BigField) UnmarshalCBOR(r io.Reader) (err error) { + *t = BigField{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 1 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.LargeBytes ([]uint8) (slice) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > 10000000 { + return fmt.Errorf("t.LargeBytes: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra > 0 { + t.LargeBytes = make([]uint8, extra) + } + + if _, err := io.ReadFull(cr, t.LargeBytes[:]); err != nil { + return err + } + return nil +} diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index 3695142..326701a 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -345,3 +345,36 @@ func TestErrUnexpectedEOF(t *testing.T) { t.Error(err) } } + +func TestLargeField(t *testing.T) { + // 10 MB of data is the specified max so 4 MiB should work + bs := make([]byte, 2<<21) + bs[2<<20] = 0xaa // flags to check that serialization works + bs[2<<20+2<<19] = 0xbb + bs[2<<21-1] = 0xcc + typ := BigField{ + LargeBytes: bs, + } + buf := new(bytes.Buffer) + if err := typ.MarshalCBOR(buf); err != nil { + t.Error(err) + } + enc := buf.Bytes() + typ.LargeBytes = make([]byte, 0) // reset + if err := typ.UnmarshalCBOR(bytes.NewReader(enc)); err != nil { + t.Error(err) + } + + // 16 MiB > 10, fails + bs = make([]byte, 2<<23) + badType := BigField{ + LargeBytes: bs, + } + buf = new(bytes.Buffer) + err := badType.MarshalCBOR(buf) + if err == nil { + t.Fatal("buffer bigger than specified in struct tag should fail") + } +} + +//TODO same for strings diff --git a/testing/types.go b/testing/types.go index 7941ab5..c5fe80c 100644 --- a/testing/types.go +++ b/testing/types.go @@ -104,3 +104,7 @@ type RenamedFields struct { Foo int64 `cborgen:"foo"` Bar string `cborgen:"beep"` } + +type BigField struct { + LargeBytes []byte `cborgen:"maxlen=10000000"` +}