diff --git a/gen.go b/gen.go index 3bf7eb1..323c198 100644 --- a/gen.go +++ b/gen.go @@ -36,7 +36,12 @@ func doTemplate(w io.Writer, info interface{}, templ string) error { "ReadHeader": func(rdr string) string { return fmt.Sprintf(`%s.ReadHeader()`, rdr) }, - //todo do it here + "MaxLen": func(val int, def string) string { + if val <= 0 { + return def + } + return fmt.Sprintf("%d", val) + }, }).Parse(templ)) return t.Execute(w, info) @@ -181,14 +186,21 @@ func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) { mapk := f.Name usrMaxLen := NoUsrMaxLen tagval := f.Tag.Get("cborgen") - if len(tagval) > len(MaxLenTag) && tagval[0:len(MaxLenTag)] == MaxLenTag { - var err error - usrMaxLen, err = strconv.Atoi(tagval[len(MaxLenTag):]) + tags, err := tagparse(tagval) + if err != nil { + return nil, fmt.Errorf("invalid tag format: %w", err) + } + + if tags["name"] != "" { + mapk = tags["name"] + } + if msize := tags["maxlen"]; msize != "" { + val, err := strconv.Atoi(msize) if err != nil { - return nil, fmt.Errorf("failed to parse specified length in max len tag %w", err) + return nil, fmt.Errorf("maxsize tag value was not valid: %w", err) } - } else if tagval != "" { - mapk = tagval + + usrMaxLen = val } out.Fields = append(out.Fields, Field{ @@ -204,6 +216,30 @@ func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) { return &out, nil } +func tagparse(v string) (map[string]string, error) { + out := make(map[string]string) + for _, elem := range strings.Split(v, ",") { + elem = strings.TrimSpace(elem) + if elem == "" { + continue + } + + if strings.Contains(elem, "=") { + parts := strings.Split(elem, "=") + if len(parts) != 2 { + return nil, fmt.Errorf("struct tags with params must be of form X=Y") + } + + out[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1]) + } else { + out["name"] = elem + } + + } + + return out, nil +} + func (gti GenTypeInfo) TupleHeader() []byte { return CborEncodeMajorType(MajArray, uint64(len(gti.Fields))) } @@ -238,7 +274,7 @@ func emitCborMarshalStringField(w io.Writer, f Field) error { } return doTemplate(w, f, ` - if len({{ .Name }}) > {{ .MaxLen }} { + if len({{ .Name }}) > {{ MaxLen .MaxLen "cbg.MaxLength" }} { return xerrors.Errorf("Value in field {{ .Name | js }} was too long") } @@ -412,7 +448,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { // Note: this re-slices the slice to deal with arrays. if e.Kind() == reflect.Uint8 { return doTemplate(w, f, ` - if len({{ .Name }}) > {{ .MaxLen }} { + if len({{ .Name }}) > {{ MaxLen .MaxLen "cbg.ByteArrayMaxLen" }} { return xerrors.Errorf("Byte array in field {{ .Name }} was too long") } @@ -429,7 +465,7 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error { } err := doTemplate(w, f, ` - if len({{ .Name }}) > {{ .MaxLen }} { + if len({{ .Name }}) > {{ MaxLen .MaxLen "cbg.MaxLength" }} { return xerrors.Errorf("Slice value in field {{ .Name }} was too long") } @@ -884,7 +920,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { if e.Kind() == reflect.Uint8 { return doTemplate(w, f, ` - if extra > {{ .MaxLen }} { + if extra > {{ MaxLen .MaxLen "cbg.ByteArrayMaxLen" }} { return fmt.Errorf("{{ .Name }}: byte array too large (%d)", extra) } if maj != cbg.MajByteString { @@ -908,7 +944,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error { } if err := doTemplate(w, f, ` - if extra > {{ .MaxLen }} { + if extra > {{ MaxLen .MaxLen "cbg.MaxLength" }} { return fmt.Errorf("{{ .Name }}: array too large (%d)", extra) } `); err != nil { @@ -1208,7 +1244,7 @@ func (t *{{ .Name}}) UnmarshalCBOR(r io.Reader) (err error) { return fmt.Errorf("cbor input should be of type map") } - if extra > {{ .MaxLen }} { + if extra > cbg.MaxLength { return fmt.Errorf("{{ .Name }}: map struct too large (%d)", extra) } diff --git a/testing/cbor_gen.go b/testing/cbor_gen.go index d53913e..fd2a7b3 100644 --- a/testing/cbor_gen.go +++ b/testing/cbor_gen.go @@ -1095,3 +1095,79 @@ func (t *ThingWithSomeTime) UnmarshalCBOR(r io.Reader) (err error) { } return nil } + +var lengthBufBigField = []byte{129} + +func (t *BigField) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + + if _, err := cw.Write(lengthBufBigField); err != nil { + return err + } + + // t.LargeBytes ([]uint8) (slice) + if len(t.LargeBytes) > 10000000 { + return xerrors.Errorf("Byte array in field t.LargeBytes was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajByteString, uint64(len(t.LargeBytes))); err != nil { + return err + } + + if _, err := cw.Write(t.LargeBytes[:]); err != nil { + return err + } + return nil +} + +func (t *BigField) UnmarshalCBOR(r io.Reader) (err error) { + *t = BigField{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 1 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.LargeBytes ([]uint8) (slice) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > 10000000 { + return fmt.Errorf("t.LargeBytes: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra > 0 { + t.LargeBytes = make([]uint8, extra) + } + + if _, err := io.ReadFull(cr, t.LargeBytes[:]); err != nil { + return err + } + return nil +} diff --git a/testing/roundtrip_test.go b/testing/roundtrip_test.go index 63b93cd..326701a 100644 --- a/testing/roundtrip_test.go +++ b/testing/roundtrip_test.go @@ -361,17 +361,17 @@ func TestLargeField(t *testing.T) { } enc := buf.Bytes() typ.LargeBytes = make([]byte, 0) // reset - if err := typ.LargeBytes.UnmarshalCBOR(bytes.NewReader(enc)); err != nil { + if err := typ.UnmarshalCBOR(bytes.NewReader(enc)); err != nil { t.Error(err) } // 16 MiB > 10, fails - bs := make([]byte, 2<<23) - badType = BigField{ + bs = make([]byte, 2<<23) + badType := BigField{ LargeBytes: bs, } - buf := new(bytes.Buffer) - err := badTyp.MarshalCBOR(buf) + buf = new(bytes.Buffer) + err := badType.MarshalCBOR(buf) if err == nil { t.Fatal("buffer bigger than specified in struct tag should fail") }