diff --git a/decode.go b/decode.go index a600981f..fe990e54 100644 --- a/decode.go +++ b/decode.go @@ -528,7 +528,7 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) { if ok { return mapNode, nil } - return nil, errUnexpectedNodeType(anchor.Value.Type(), ast.MappingType, node.GetToken()) + return nil, errors.ErrUnexpectedNodeType(anchor.Value.Type(), ast.MappingType, node.GetToken()) } if alias, ok := node.(*ast.AliasNode); ok { aliasName := alias.Value.GetToken().Value @@ -540,11 +540,11 @@ func (d *Decoder) getMapNode(node ast.Node) (ast.MapNode, error) { if ok { return mapNode, nil } - return nil, errUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken()) + return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken()) } mapNode, ok := node.(ast.MapNode) if !ok { - return nil, errUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken()) + return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.MappingType, node.GetToken()) } return mapNode, nil } @@ -559,7 +559,7 @@ func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) { return arrayNode, nil } - return nil, errUnexpectedNodeType(anchor.Value.Type(), ast.SequenceType, node.GetToken()) + return nil, errors.ErrUnexpectedNodeType(anchor.Value.Type(), ast.SequenceType, node.GetToken()) } if alias, ok := node.(*ast.AliasNode); ok { aliasName := alias.Value.GetToken().Value @@ -571,11 +571,11 @@ func (d *Decoder) getArrayNode(node ast.Node) (ast.ArrayNode, error) { if ok { return arrayNode, nil } - return nil, errUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken()) + return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken()) } arrayNode, ok := node.(ast.ArrayNode) if !ok { - return nil, errUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken()) + return nil, errors.ErrUnexpectedNodeType(node.Type(), ast.SequenceType, node.GetToken()) } return arrayNode, nil } @@ -598,7 +598,7 @@ func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node) // else, fall through to the error below } } - return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken()) + return reflect.Zero(typ), errors.ErrTypeMismatch(typ, v.Type(), src.GetToken()) } return v.Convert(typ), nil } @@ -614,43 +614,11 @@ func (d *Decoder) convertValue(v reflect.Value, typ reflect.Type, src ast.Node) return reflect.ValueOf(fmt.Sprint(v.Bool())), nil } if !v.Type().ConvertibleTo(typ) { - return reflect.Zero(typ), errTypeMismatch(typ, v.Type(), src.GetToken()) + return reflect.Zero(typ), errors.ErrTypeMismatch(typ, v.Type(), src.GetToken()) } return v.Convert(typ), nil } -func errTypeMismatch(dstType, srcType reflect.Type, token *token.Token) *errors.TypeError { - return &errors.TypeError{DstType: dstType, SrcType: srcType, Token: token} -} - -type unknownFieldError struct { - err error -} - -func (e *unknownFieldError) Error() string { - return e.err.Error() -} - -func errUnknownField(msg string, tk *token.Token) *unknownFieldError { - return &unknownFieldError{err: errors.ErrSyntax(msg, tk)} -} - -func errUnexpectedNodeType(actual, expected ast.NodeType, tk *token.Token) error { - return errors.ErrSyntax(fmt.Sprintf("%s was used where %s is expected", actual.YAMLName(), expected.YAMLName()), tk) -} - -type duplicateKeyError struct { - err error -} - -func (e *duplicateKeyError) Error() string { - return e.err.Error() -} - -func errDuplicateKey(msg string, tk *token.Token) *duplicateKeyError { - return &duplicateKeyError{err: errors.ErrSyntax(msg, tk)} -} - func (d *Decoder) deleteStructKeys(structType reflect.Type, unknownFields map[string]ast.Node) error { if structType.Kind() == reflect.Ptr { structType = structType.Elem() @@ -988,10 +956,10 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No return nil } } else { // couldn't be parsed as float - return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } default: - return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken()) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: @@ -1022,11 +990,11 @@ func (d *Decoder) decodeValue(ctx context.Context, dst reflect.Value, src ast.No return nil } } else { // couldn't be parsed as float - return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } default: - return errTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) + return errors.ErrTypeMismatch(valueType, reflect.TypeOf(v), src.GetToken()) } return errors.ErrOverflow(valueType, fmt.Sprint(v), src.GetToken()) } @@ -1215,7 +1183,7 @@ func (d *Decoder) castToTime(src ast.Node) (time.Time, error) { } s, ok := v.(string) if !ok { - return time.Time{}, errTypeMismatch(reflect.TypeOf(time.Time{}), reflect.TypeOf(v), src.GetToken()) + return time.Time{}, errors.ErrTypeMismatch(reflect.TypeOf(time.Time{}), reflect.TypeOf(v), src.GetToken()) } for _, format := range allowedTimestampFormats { t, err := time.Parse(format, s) @@ -1250,7 +1218,7 @@ func (d *Decoder) castToDuration(src ast.Node) (time.Duration, error) { } s, ok := v.(string) if !ok { - return 0, errTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v), src.GetToken()) + return 0, errors.ErrTypeMismatch(reflect.TypeOf(time.Duration(0)), reflect.TypeOf(v), src.GetToken()) } t, err := time.ParseDuration(s) if err != nil { @@ -1421,7 +1389,7 @@ func (d *Decoder) decodeStruct(ctx context.Context, dst reflect.Value, src ast.N // Unknown fields are expected (they could be fields from the parent struct). if len(unknownFields) != 0 && d.disallowUnknownField && src.GetToken() != nil { for key, node := range unknownFields { - return errUnknownField(fmt.Sprintf(`unknown field "%s"`, key), node.GetToken()) + return errors.ErrUnknownField(fmt.Sprintf(`unknown field "%s"`, key), node.GetToken()) } } @@ -1572,7 +1540,7 @@ func (d *Decoder) validateDuplicateKey(keyMap map[string]struct{}, key interface } if !d.allowDuplicateMapKey { if _, exists := keyMap[k]; exists { - return errDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken()) + return errors.ErrDuplicateKey(fmt.Sprintf(`duplicate key "%s"`, k), keyNode.GetToken()) } } keyMap[k] = struct{}{} diff --git a/error.go b/error.go index 36e4f06e..3d09712d 100644 --- a/error.go +++ b/error.go @@ -1,10 +1,10 @@ package yaml import ( - "errors" "fmt" "github.com/goccy/go-yaml/ast" + "github.com/goccy/go-yaml/internal/errors" ) var ( @@ -17,6 +17,15 @@ var ( ErrDecodeRequiredPointerType = errors.New("required pointer type value") ) +type ( + SyntaxError = errors.SyntaxError + TypeError = errors.TypeError + OverflowError = errors.OverflowError + DuplicateKeyError = errors.DuplicateKeyError + UnknownFieldError = errors.UnknownFieldError + UnexpectedNodeTypeError = errors.UnexpectedNodeTypeError +) + func ErrUnsupportedHeadPositionType(node ast.Node) error { return fmt.Errorf("unsupported comment head position for %s", node.Type()) } diff --git a/internal/errors/error.go b/internal/errors/error.go index 5d3ca4b3..d7f11247 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -1,167 +1,175 @@ package errors import ( - "bytes" "errors" "fmt" "reflect" + "github.com/goccy/go-yaml/ast" "github.com/goccy/go-yaml/printer" "github.com/goccy/go-yaml/token" ) +var ( + As = errors.As + Is = errors.Is + New = errors.New +) + const ( - defaultColorize = false + defaultFormatColor = false defaultIncludeSource = true ) -// ErrSyntax create syntax error instance with message and token -func ErrSyntax(msg string, tk *token.Token) *syntaxError { - return &syntaxError{ - msg: msg, - token: tk, - } +type PrettyFormatError interface { + FormatError(bool, bool) string } -// ErrOverflow creates an overflow error instance with message and a token. -func ErrOverflow(dstType reflect.Type, num string, tk *token.Token) *overflowError { - return &overflowError{dstType: dstType, srcNum: num, token: tk} +type SyntaxError struct { + Message string + Token *token.Token } -type Printer interface { - // Print appends args to the message output. - Print(args ...any) +type TypeError struct { + DstType reflect.Type + SrcType reflect.Type + StructFieldName *string + Token *token.Token } -type FormatErrorPrinter struct { - Printer - Colored bool - InclSource bool +type OverflowError struct { + DstType reflect.Type + SrcNum string + Token *token.Token } -var ( - As = errors.As - Is = errors.Is - New = errors.New -) +type DuplicateKeyError struct { + Message string + Token *token.Token +} -type overflowError struct { - dstType reflect.Type - srcNum string - token *token.Token +type UnknownFieldError struct { + Message string + Token *token.Token } -func (e *overflowError) Error() string { - return fmt.Sprintf("cannot unmarshal %s into Go value of type %s ( overflow )", e.srcNum, e.dstType) +type UnexpectedNodeTypeError struct { + Actual ast.NodeType + Expected ast.NodeType + Token *token.Token } -func (e *overflowError) PrettyPrint(p Printer, colored, inclSource bool) error { - return e.FormatError(&FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) +// ErrSyntax create syntax error instance with message and token +func ErrSyntax(msg string, tk *token.Token) *SyntaxError { + return &SyntaxError{ + Message: msg, + Token: tk, + } } -func (e *overflowError) FormatError(p Printer) error { - var pp printer.Printer +// ErrOverflow creates an overflow error instance with message and a token. +func ErrOverflow(dstType reflect.Type, num string, tk *token.Token) *OverflowError { + return &OverflowError{ + DstType: dstType, + SrcNum: num, + Token: tk, + } +} - var colored, inclSource bool - if fep, ok := p.(*FormatErrorPrinter); ok { - colored = fep.Colored - inclSource = fep.InclSource +// ErrTypeMismatch cerates an type mismatch error instance with token. +func ErrTypeMismatch(dstType, srcType reflect.Type, token *token.Token) *TypeError { + return &TypeError{ + DstType: dstType, + SrcType: srcType, + Token: token, } +} - pos := fmt.Sprintf("[%d:%d] ", e.token.Position.Line, e.token.Position.Column) - msg := pp.PrintErrorMessage(fmt.Sprintf("%s%s", pos, e.Error()), colored) - if inclSource { - msg += "\n" + pp.PrintErrorToken(e.token, colored) +// ErrDuplicateKey creates an duplicate key error instance with token. +func ErrDuplicateKey(msg string, tk *token.Token) *DuplicateKeyError { + return &DuplicateKeyError{ + Message: msg, + Token: tk, } - p.Print(msg) +} - return nil +// ErrUnknownField creates an unknown field error instance with token. +func ErrUnknownField(msg string, tk *token.Token) *UnknownFieldError { + return &UnknownFieldError{ + Message: msg, + Token: tk, + } } -type syntaxError struct { - msg string - token *token.Token +func ErrUnexpectedNodeType(actual, expected ast.NodeType, tk *token.Token) *UnexpectedNodeTypeError { + return &UnexpectedNodeTypeError{ + Actual: actual, + Expected: expected, + Token: tk, + } } -func (e *syntaxError) PrettyPrint(p Printer, colored, inclSource bool) error { - return e.FormatError(&FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) +func (e *SyntaxError) Error() string { + return e.FormatError(defaultFormatColor, defaultIncludeSource) } -func (e *syntaxError) FormatError(p Printer) error { - var pp printer.Printer +func (e *SyntaxError) FormatError(colored, inclSource bool) string { + return formatError(e.Message, e.Token, colored, inclSource) +} - var colored, inclSource bool - if fep, ok := p.(*FormatErrorPrinter); ok { - colored = fep.Colored - inclSource = fep.InclSource - } +func (e *OverflowError) Error() string { + return e.FormatError(defaultFormatColor, defaultIncludeSource) +} - pos := fmt.Sprintf("[%d:%d] ", e.token.Position.Line, e.token.Position.Column) - msg := pp.PrintErrorMessage(fmt.Sprintf("%s%s", pos, e.msg), colored) - if inclSource { - msg += "\n" + pp.PrintErrorToken(e.token, colored) - } - p.Print(msg) - return nil +func (e *OverflowError) FormatError(colored, inclSource bool) string { + return formatError(fmt.Sprintf("cannot unmarshal %s into Go value of type %s ( overflow )", e.SrcNum, e.DstType), e.Token, colored, inclSource) } -type PrettyPrinter interface { - PrettyPrint(Printer, bool, bool) error +func (e *TypeError) msg() string { + if e.StructFieldName != nil { + return fmt.Sprintf("cannot unmarshal %s into Go struct field %s of type %s", e.SrcType, *e.StructFieldName, e.DstType) + } + return fmt.Sprintf("cannot unmarshal %s into Go value of type %s", e.SrcType, e.DstType) } -type Sink struct{ *bytes.Buffer } +func (e *TypeError) Error() string { + return e.FormatError(defaultFormatColor, defaultIncludeSource) +} -func (es *Sink) Print(args ...interface{}) { - fmt.Fprint(es.Buffer, args...) +func (e *TypeError) FormatError(colored, inclSource bool) string { + return formatError(e.msg(), e.Token, colored, inclSource) } -func (es *Sink) Printf(f string, args ...interface{}) { - fmt.Fprintf(es.Buffer, f, args...) +func (e *DuplicateKeyError) Error() string { + return e.FormatError(defaultFormatColor, defaultIncludeSource) } -func (es *Sink) Detail() bool { - return false +func (e *DuplicateKeyError) FormatError(colored, inclSource bool) string { + return formatError(e.Message, e.Token, colored, inclSource) } -func (e *syntaxError) Error() string { - var buf bytes.Buffer - e.PrettyPrint(&Sink{&buf}, defaultColorize, defaultIncludeSource) - return buf.String() +func (e *UnknownFieldError) Error() string { + return e.FormatError(defaultFormatColor, defaultIncludeSource) } -type TypeError struct { - DstType reflect.Type - SrcType reflect.Type - StructFieldName *string - Token *token.Token +func (e *UnknownFieldError) FormatError(colored, inclSource bool) string { + return formatError(e.Message, e.Token, colored, inclSource) } -func (e *TypeError) Error() string { - if e.StructFieldName != nil { - return fmt.Sprintf("cannot unmarshal %s into Go struct field %s of type %s", e.SrcType, *e.StructFieldName, e.DstType) - } - return fmt.Sprintf("cannot unmarshal %s into Go value of type %s", e.SrcType, e.DstType) +func (e *UnexpectedNodeTypeError) Error() string { + return e.FormatError(defaultFormatColor, defaultIncludeSource) } -func (e *TypeError) PrettyPrint(p Printer, colored, inclSource bool) error { - return e.FormatError(&FormatErrorPrinter{Printer: p, Colored: colored, InclSource: inclSource}) +func (e *UnexpectedNodeTypeError) FormatError(colored, inclSource bool) string { + return formatError(fmt.Sprintf("%s was used where %s is expected", e.Actual.YAMLName(), e.Expected.YAMLName()), e.Token, colored, inclSource) } -func (e *TypeError) FormatError(p Printer) error { +func formatError(errMsg string, token *token.Token, colored, inclSource bool) string { var pp printer.Printer - - var colored, inclSource bool - if fep, ok := p.(*FormatErrorPrinter); ok { - colored = fep.Colored - inclSource = fep.InclSource - } - - pos := fmt.Sprintf("[%d:%d] ", e.Token.Position.Line, e.Token.Position.Column) - msg := pp.PrintErrorMessage(fmt.Sprintf("%s%s", pos, e.Error()), colored) + pos := fmt.Sprintf("[%d:%d] ", token.Position.Line, token.Position.Column) + msg := pp.PrintErrorMessage(fmt.Sprintf("%s%s", pos, errMsg), colored) if inclSource { - msg += "\n" + pp.PrintErrorToken(e.Token, colored) + msg += "\n" + pp.PrintErrorToken(token, colored) } - p.Print(msg) - - return nil + return msg } diff --git a/yaml.go b/yaml.go index a90560e1..ec2103b2 100644 --- a/yaml.go +++ b/yaml.go @@ -212,11 +212,9 @@ func NodeToValue(node ast.Node, v interface{}, opts ...DecodeOption) error { // If the third argument `inclSource` is true, the error message will // contain snippets of the YAML source that was used. func FormatError(e error, colored, inclSource bool) string { - var pp errors.PrettyPrinter - if errors.As(e, &pp) { - var buf bytes.Buffer - pp.PrettyPrint(&errors.Sink{Buffer: &buf}, colored, inclSource) - return buf.String() + var pe errors.PrettyFormatError + if errors.As(e, &pe) { + return pe.FormatError(colored, inclSource) } return e.Error()