From 87ddfd193816b0d2554d0e995a210b66dfd1e14a Mon Sep 17 00:00:00 2001 From: Masaaki Goshima Date: Tue, 29 Oct 2024 20:00:48 +0900 Subject: [PATCH] returns InvalidToken if scanner encounters error (#486) --- lexer/lexer_test.go | 27 +++++ parser/parser.go | 3 + parser/parser_test.go | 11 ++ scanner/context.go | 15 +-- scanner/error.go | 19 ++++ scanner/scanner.go | 237 +++++++++++++++++++++++------------------- token/token.go | 37 ++++++- 7 files changed, 228 insertions(+), 121 deletions(-) create mode 100644 scanner/error.go diff --git a/lexer/lexer_test.go b/lexer/lexer_test.go index 3a6d9126..b6cc87a7 100644 --- a/lexer/lexer_test.go +++ b/lexer/lexer_test.go @@ -2393,3 +2393,30 @@ b: 1`, }) } } + +func TestInvalid(t *testing.T) { + tests := []struct { + name string + src string + }{ + { + name: "literal opt", + src: ` +a: |invalid + foo`, + }, + { + name: "literal opt", + src: ` +a: |invalid`, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := lexer.Tokenize(test.src) + if got.InvalidToken() == nil { + t.Fatal("expected contains invalid token") + } + }) + } +} diff --git a/parser/parser.go b/parser/parser.go index 82e066f6..ea8bb8c4 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -719,6 +719,9 @@ func ParseBytes(bytes []byte, mode Mode) (*ast.File, error) { // Parse parse from token instances, and returns ast.File func Parse(tokens token.Tokens, mode Mode) (*ast.File, error) { + if tk := tokens.InvalidToken(); tk != nil { + return nil, errors.ErrSyntax("found invalid token", tk) + } var p parser f, err := p.parse(tokens, mode) if err != nil { diff --git a/parser/parser_test.go b/parser/parser_test.go index 2ec0e7e6..09facb71 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -712,6 +712,17 @@ a: ^ `, }, + { + ` +a: |invalidopt + foo +`, + ` +[2:4] found invalid token +> 2 | a:|invalidopt + ^ + 3 | foo`, + }, } for _, test := range tests { t.Run(test.source, func(t *testing.T) { diff --git a/scanner/context.go b/scanner/context.go index b5ffabbc..1522e5ba 100644 --- a/scanner/context.go +++ b/scanner/context.go @@ -21,7 +21,6 @@ type Context struct { isRawFolded bool isLiteral bool isFolded bool - isSingleLine bool literalOpt string } @@ -35,9 +34,8 @@ var ( func createContext() *Context { return &Context{ - idx: 0, - tokens: token.Tokens{}, - isSingleLine: true, + idx: 0, + tokens: token.Tokens{}, } } @@ -58,7 +56,6 @@ func (c *Context) reset(src []rune) { c.tokens = c.tokens[:0] c.resetBuffer() c.isRawFolded = false - c.isSingleLine = true c.isLiteral = false c.isFolded = false c.literalOpt = "" @@ -71,10 +68,6 @@ func (c *Context) resetBuffer() { c.notSpaceOrgCharPos = 0 } -func (c *Context) isSaveIndentMode() bool { - return c.isLiteral || c.isFolded || c.isRawFolded -} - func (c *Context) breakLiteral() { c.isLiteral = false c.isRawFolded = false @@ -186,10 +179,6 @@ func (c *Context) progress(num int) { c.idx += num } -func (c *Context) nextPos() int { - return c.idx + 1 -} - func (c *Context) existsBuffer() bool { return len(c.bufferedSrc()) != 0 } diff --git a/scanner/error.go b/scanner/error.go new file mode 100644 index 00000000..4298be42 --- /dev/null +++ b/scanner/error.go @@ -0,0 +1,19 @@ +package scanner + +import "github.com/goccy/go-yaml/token" + +type InvalidTokenError struct { + Message string + Token *token.Token +} + +func (e *InvalidTokenError) Error() string { + return e.Message +} + +func ErrInvalidToken(msg string, tk *token.Token) *InvalidTokenError { + return &InvalidTokenError{ + Message: msg, + 
Token: tk, + } +} diff --git a/scanner/scanner.go b/scanner/scanner.go index a559555a..aa9bcc7d 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -2,6 +2,7 @@ package scanner import ( "errors" + "fmt" "io" "strings" @@ -70,7 +71,7 @@ func (s *Scanner) bufferedToken(ctx *Context) *token.Token { line := s.line column := s.column - len(ctx.buf) level := s.indentLevel - if ctx.isSaveIndentMode() { + if ctx.isDocument() { line -= s.newLineCount(ctx.buf) column = strings.Index(string(ctx.obuf), string(ctx.buf)) + 1 // Since we are in a literal, folded or raw folded @@ -92,7 +93,7 @@ func (s *Scanner) bufferedToken(ctx *Context) *token.Token { func (s *Scanner) progressColumn(ctx *Context, num int) { s.column += num s.offset += num - ctx.progress(num) + s.progress(ctx, num) } func (s *Scanner) progressLine(ctx *Context) { @@ -103,7 +104,12 @@ func (s *Scanner) progressLine(ctx *Context) { s.indentNum = 0 s.isFirstCharAtLine = true s.isAnchor = false - ctx.progress(1) + s.progress(ctx, 1) +} + +func (s *Scanner) progress(ctx *Context, num int) { + ctx.progress(num) + s.sourcePos += num } func (s *Scanner) isNewLineChar(c rune) bool { @@ -203,7 +209,7 @@ func (s *Scanner) breakLiteral(ctx *Context) { ctx.breakLiteral() } -func (s *Scanner) scanSingleQuote(ctx *Context) (tk *token.Token, pos int) { +func (s *Scanner) scanSingleQuote(ctx *Context) *token.Token { ctx.addOriginBuf('\'') srcpos := s.pos() startIndex := ctx.idx + 1 @@ -212,6 +218,8 @@ func (s *Scanner) scanSingleQuote(ctx *Context) (tk *token.Token, pos int) { value := []rune{} isFirstLineChar := false isNewLine := false + + var tk *token.Token for idx := startIndex; idx < size; idx++ { if !isNewLine { s.progressColumn(ctx, 1) @@ -219,7 +227,6 @@ func (s *Scanner) scanSingleQuote(ctx *Context) (tk *token.Token, pos int) { isNewLine = false } c := src[idx] - pos = idx + 1 ctx.addOriginBuf(c) if s.isNewLineChar(c) { value = append(value, ' ') @@ -239,14 +246,14 @@ func (s *Scanner) scanSingleQuote(ctx *Context) (tk *token.Token, pos int) { value = append(value, c) ctx.addOriginBuf(c) idx++ + s.progressColumn(ctx, 1) continue } s.progressColumn(ctx, 1) tk = token.SingleQuote(string(value), string(ctx.obuf), srcpos) - pos = idx - startIndex + 1 - return + return tk } - return + return tk } func hexToInt(b rune) int { @@ -267,7 +274,7 @@ func hexRunesToInt(b []rune) int { return sum } -func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { +func (s *Scanner) scanDoubleQuote(ctx *Context) *token.Token { ctx.addOriginBuf('"') srcpos := s.pos() startIndex := ctx.idx + 1 @@ -276,6 +283,8 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { value := []rune{} isFirstLineChar := false isNewLine := false + + var tk *token.Token for idx := startIndex; idx < size; idx++ { if !isNewLine { s.progressColumn(ctx, 1) @@ -283,7 +292,6 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { isNewLine = false } c := src[idx] - pos = idx + 1 ctx.addOriginBuf(c) if s.isNewLineChar(c) { value = append(value, ' ') @@ -347,32 +355,35 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { ctx.addOriginBuf(nextChar) value = append(value, nextChar) case 'x': - progress = 3 - if idx+progress >= size { - // TODO: need to return error - //err = errors.New("invalid escape character \\x") - return + if idx+3 >= size { + progress = 1 + ctx.addOriginBuf(nextChar) + value = append(value, nextChar) + } else { + progress = 3 + codeNum := hexRunesToInt(src[idx+2 : 
idx+progress+1]) + value = append(value, rune(codeNum)) } - codeNum := hexRunesToInt(src[idx+2 : idx+progress+1]) - value = append(value, rune(codeNum)) case 'u': - progress = 5 - if idx+progress >= size { - // TODO: need to return error - //err = errors.New("invalid escape character \\u") - return + if idx+5 >= size { + progress = 1 + ctx.addOriginBuf(nextChar) + value = append(value, nextChar) + } else { + progress = 5 + codeNum := hexRunesToInt(src[idx+2 : idx+progress+1]) + value = append(value, rune(codeNum)) } - codeNum := hexRunesToInt(src[idx+2 : idx+progress+1]) - value = append(value, rune(codeNum)) case 'U': - progress = 9 - if idx+progress >= size { - // TODO: need to return error - //err = errors.New("invalid escape character \\U") - return + if idx+9 >= size { + progress = 1 + ctx.addOriginBuf(nextChar) + value = append(value, nextChar) + } else { + progress = 9 + codeNum := hexRunesToInt(src[idx+2 : idx+progress+1]) + value = append(value, rune(codeNum)) } - codeNum := hexRunesToInt(src[idx+2 : idx+progress+1]) - value = append(value, rune(codeNum)) case '\\': progress = 1 ctx.addOriginBuf(nextChar) @@ -390,13 +401,12 @@ func (s *Scanner) scanDoubleQuote(ctx *Context) (tk *token.Token, pos int) { } s.progressColumn(ctx, 1) tk = token.DoubleQuote(string(value), string(ctx.obuf), srcpos) - pos = idx - startIndex + 1 - return + return tk } - return + return tk } -func (s *Scanner) scanQuote(ctx *Context, ch rune) (tk *token.Token, pos int) { +func (s *Scanner) scanQuote(ctx *Context, ch rune) *token.Token { if ch == '\'' { return s.scanSingleQuote(ctx) } @@ -427,26 +437,34 @@ func (s *Scanner) isMergeKey(ctx *Context) bool { return false } -func (s *Scanner) scanTag(ctx *Context) (tk *token.Token, pos int) { +func (s *Scanner) scanTag(ctx *Context) *token.Token { ctx.addOriginBuf('!') - ctx.progress(1) // skip '!' character + s.progress(ctx, 1) // skip '!' character + + var ( + tk *token.Token + progress int + ) for idx, c := range ctx.src[ctx.idx:] { - pos = idx + 1 + progress = idx + 1 ctx.addOriginBuf(c) switch c { case ' ', '\n', '\r': value := ctx.source(ctx.idx-1, ctx.idx+idx) tk = token.Tag(value, string(ctx.obuf), s.pos()) - pos = len([]rune(value)) - return + progress = len([]rune(value)) + goto END } } - return +END: + s.progressColumn(ctx, progress) + return tk } -func (s *Scanner) scanComment(ctx *Context) (tk *token.Token, pos int) { +func (s *Scanner) scanComment(ctx *Context) *token.Token { ctx.addOriginBuf('#') - ctx.progress(1) // skip '#' character + s.progress(ctx, 1) // skip '#' character + for idx, c := range ctx.src[ctx.idx:] { ctx.addOriginBuf(c) switch c { @@ -455,25 +473,32 @@ func (s *Scanner) scanComment(ctx *Context) (tk *token.Token, pos int) { continue } value := ctx.source(ctx.idx, ctx.idx+idx) - tk = token.Comment(value, string(ctx.obuf), s.pos()) - pos = len([]rune(value)) + 1 - return + progress := len([]rune(value)) + tk := token.Comment(value, string(ctx.obuf), s.pos()) + s.progressColumn(ctx, progress) + s.progressLine(ctx) + return tk } } // document ends with comment. 
value := string(ctx.src[ctx.idx:]) - tk = token.Comment(value, string(ctx.obuf), s.pos()) - pos = len([]rune(value)) + 1 - return + tk := token.Comment(value, string(ctx.obuf), s.pos()) + progress := len([]rune(value)) + s.progressColumn(ctx, progress) + s.progressLine(ctx) + return tk } -func trimCommentFromLiteralOpt(text string) (string, error) { +func (s *Scanner) trimCommentFromLiteralOpt(text string, header rune) (string, error) { idx := strings.Index(text, "#") if idx < 0 { return text, nil } if idx == 0 { - return "", errors.New("invalid literal header") + return "", ErrInvalidToken( + fmt.Sprintf("invalid literal header %s", text), + token.Invalid(string(header)+text, s.pos()), + ) } return text[:idx-1], nil } @@ -535,7 +560,7 @@ func (s *Scanner) scanNewLine(ctx *Context, c rune) { // > -- https://yaml.org/spec/1.2/spec.html if c == '\r' && ctx.nextChar() == '\n' { ctx.addOriginBuf('\r') - ctx.progress(1) + s.progress(ctx, 1) c = '\n' } @@ -546,7 +571,6 @@ func (s *Scanner) scanNewLine(ctx *Context, c rune) { } ctx.addBuf(' ') ctx.addOriginBuf(c) - ctx.isSingleLine = false s.progressLine(ctx) } @@ -675,7 +699,7 @@ func (s *Scanner) scanMergeKey(ctx *Context) bool { s.lastDelimColumn = s.column ctx.addToken(token.MergeKey(string(ctx.obuf)+"<<", s.pos())) - s.progressColumn(ctx, 1) + s.progressColumn(ctx, 2) return true } @@ -718,30 +742,28 @@ func (s *Scanner) scanLiteralHeader(ctx *Context) (bool, error) { return false, nil } - progress, err := s.scanLiteralHeaderOption(ctx) - if err != nil { + if err := s.scanLiteralHeaderOption(ctx); err != nil { return false, err } - s.progressColumn(ctx, progress) s.progressLine(ctx) return true, nil } -func (s *Scanner) scanLiteralHeaderOption(ctx *Context) (pos int, err error) { +func (s *Scanner) scanLiteralHeaderOption(ctx *Context) error { header := ctx.currentChar() ctx.addOriginBuf(header) - ctx.progress(1) // skip '|' or '>' character + s.progress(ctx, 1) // skip '|' or '>' character for idx, c := range ctx.src[ctx.idx:] { - pos = idx + progress := idx ctx.addOriginBuf(c) switch c { case '\n', '\r': value := ctx.source(ctx.idx, ctx.idx+idx) opt := strings.TrimRight(value, " ") orgOptLen := len(opt) - opt, err = trimCommentFromLiteralOpt(opt) + opt, err := s.trimCommentFromLiteralOpt(opt, header) if err != nil { - return + return err } switch opt { case "", "+", "-", @@ -781,12 +803,19 @@ func (s *Scanner) scanLiteralHeaderOption(ctx *Context) (pos int, err error) { s.indentState = IndentStateKeep ctx.resetBuffer() ctx.literalOpt = opt - return + s.progressColumn(ctx, progress) + return nil + default: + tk := token.Invalid(string(header)+opt, s.pos()) + s.progressColumn(ctx, progress) + return ErrInvalidToken(fmt.Sprintf("invalid literal header: %q", opt), tk) } } } - err = errors.New("invalid literal header") - return + text := string(ctx.src[ctx.idx:]) + tk := token.Invalid(string(header)+text, s.pos()) + s.progressColumn(ctx, len(text)) + return ErrInvalidToken(fmt.Sprintf("invalid literal header: %q", text), tk) } func (s *Scanner) scanMapKey(ctx *Context) bool { @@ -842,9 +871,8 @@ func (s *Scanner) scanAlias(ctx *Context) bool { return true } -func (s *Scanner) scan(ctx *Context) (pos int) { +func (s *Scanner) scan(ctx *Context) error { for ctx.next() { - pos = ctx.nextPos() c := ctx.currentChar() // First, change the IndentState. 
@@ -867,32 +895,29 @@ func (s *Scanner) scan(ctx *Context) (pos int) { switch c { case '{': if s.scanFlowMapStart(ctx) { - return + return nil } case '}': if s.scanFlowMapEnd(ctx) { - return + return nil } case '.': if s.scanDocumentEnd(ctx) { - pos += 2 - return + return nil } case '<': if s.scanMergeKey(ctx) { - pos++ - return + return nil } case '-': if s.scanDocumentStart(ctx) { - pos += 2 - return + return nil } if s.scanRawFoldedChar(ctx) { continue } if s.scanSequence(ctx) { - return + return nil } if ctx.existsBuffer() { // '-' is literal @@ -903,83 +928,73 @@ func (s *Scanner) scan(ctx *Context) (pos int) { } case '[': if s.scanFlowArrayStart(ctx) { - return + return nil } case ']': if s.scanFlowArrayEnd(ctx) { - return + return nil } case ',': if s.scanFlowEntry(ctx, c) { - return + return nil } case ':': if s.scanMapDelim(ctx) { - return + return nil } case '|', '>': scanned, err := s.scanLiteralHeader(ctx) if err != nil { - // TODO: returns syntax error object - return + return err } if scanned { continue } case '!': if !ctx.existsBuffer() { - token, progress := s.scanTag(ctx) + token := s.scanTag(ctx) ctx.addToken(token) - s.progressColumn(ctx, progress) - if c := ctx.previousChar(); s.isNewLineChar(c) { - s.progressLine(ctx) - } - pos += progress - return + return nil } case '%': if s.scanDirective(ctx) { - return + return nil } case '?': if s.scanMapKey(ctx) { - return + return nil } case '&': if s.scanAnchor(ctx) { - return + return nil } case '*': if s.scanAlias(ctx) { - return + return nil } case '#': if !ctx.existsBuffer() || ctx.previousChar() == ' ' { s.addBufferedTokenIfExists(ctx) - token, progress := s.scanComment(ctx) + token := s.scanComment(ctx) ctx.addToken(token) - s.progressColumn(ctx, progress) - s.progressLine(ctx) - pos += progress - return + return nil } case '\'', '"': if !ctx.existsBuffer() { - token, progress := s.scanQuote(ctx, c) + token := s.scanQuote(ctx, c) ctx.addToken(token) - pos += progress // If the non-whitespace character immediately following the quote is ':', the quote should be treated as a map key. // Therefore, do not return and continue processing as a normal map key. if ctx.currentCharWithSkipWhitespace() == ':' { continue } - return + return nil } case '\r', '\n': s.scanNewLine(ctx, c) continue case ' ': - if ctx.isSaveIndentMode() || (!s.isAnchor && !s.isFirstCharAtLine) { + if ctx.isDocument() || (!s.isAnchor && !s.isFirstCharAtLine) { ctx.addBuf(c) ctx.addOriginBuf(c) s.progressColumn(ctx, 1) @@ -991,16 +1006,16 @@ func (s *Scanner) scan(ctx *Context) (pos int) { continue } s.addBufferedTokenIfExists(ctx) - pos-- // to rescan white space at next scanning for adding white space to next buffer. s.isAnchor = false - return + // rescan white space at next scanning for adding white space to next buffer. + return nil } ctx.addBuf(c) ctx.addOriginBuf(c) s.progressColumn(ctx, 1) } s.addBufferedTokenIfExists(ctx) - return + return nil } // Init prepares the scanner s to tokenize the text src by setting the scanner at the beginning of src. @@ -1026,9 +1041,17 @@ func (s *Scanner) Scan() (token.Tokens, error) { } ctx := newContext(s.source[s.sourcePos:]) defer ctx.release() - progress := s.scan(ctx) - s.sourcePos += progress + var tokens token.Tokens + err := s.scan(ctx) tokens = append(tokens, ctx.tokens...) 
+ + if err != nil { + var invalidTokenErr *InvalidTokenError + if errors.As(err, &invalidTokenErr) { + tokens = append(tokens, invalidTokenErr.Token) + } + return tokens, err + } return tokens, nil } diff --git a/token/token.go b/token/token.go index bc8b531c..c2d9a4bc 100644 --- a/token/token.go +++ b/token/token.go @@ -117,6 +117,8 @@ const ( StringType // BoolType type for Bool token BoolType + // InvalidType type for invalid token + InvalidType ) // String type identifier to text @@ -186,6 +188,8 @@ func (t Type) String() string { return "Infinity" case NanType: return "Nan" + case InvalidType: + return "Invalid" } return "" } @@ -202,6 +206,8 @@ const ( CharacterTypeMiscellaneous // CharacterTypeEscaped type of escaped character CharacterTypeEscaped + // CharacterTypeInvalid type for a invalid token. + CharacterTypeInvalid ) // String character type identifier to text @@ -759,9 +765,26 @@ func (t *Token) Clone() *Token { return &copied } +// Dump outputs token information to stdout for debugging. +func (t *Token) Dump() { + fmt.Printf( + "[TYPE]:%q [CHARTYPE]:%q [INDICATOR]:%q [VALUE]:%q [ORG]:%q [POS(line:column:level)]: %d:%d:%d\n", + t.Type, t.CharacterType, t.Indicator, t.Value, t.Origin, t.Position.Line, t.Position.Column, t.Position.IndentLevel, + ) +} + // Tokens type of token collection type Tokens []*Token +func (t Tokens) InvalidToken() *Token { + for _, tt := range t { + if tt.Type == InvalidType { + return tt + } + } + return nil +} + func (t *Tokens) add(tk *Token) { tokens := *t if len(tokens) == 0 { @@ -785,7 +808,8 @@ func (t *Tokens) Add(tks ...*Token) { // Dump dump all token structures for debugging func (t Tokens) Dump() { for _, tk := range t { - fmt.Printf("- %+v\n", tk) + fmt.Print("- ") + tk.Dump() } } @@ -1057,6 +1081,17 @@ func DocumentEnd(org string, pos *Position) *Token { } } +func Invalid(org string, pos *Position) *Token { + return &Token{ + Type: InvalidType, + CharacterType: CharacterTypeInvalid, + Indicator: NotIndicator, + Value: org, + Origin: org, + Position: pos, + } +} + // DetectLineBreakCharacter detect line break character in only one inside scalar content scope. func DetectLineBreakCharacter(src string) string { nc := strings.Count(src, "\n")
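
A minimal usage sketch of the behavior this patch introduces (an illustration, not part of the diff), assuming only the public packages already exercised by the tests above:

package main

import (
	"fmt"

	"github.com/goccy/go-yaml/lexer"
	"github.com/goccy/go-yaml/parser"
)

func main() {
	src := "a: |invalid\n  foo"

	// lexer.Tokenize does not return an error; with this patch an
	// unrecognized literal header is emitted as a token of
	// token.InvalidType, which Tokens.InvalidToken() surfaces.
	tokens := lexer.Tokenize(src)
	if tk := tokens.InvalidToken(); tk != nil {
		fmt.Printf("invalid token at %d:%d: %q\n",
			tk.Position.Line, tk.Position.Column, tk.Value)
	}

	// parser.Parse/ParseBytes now reject such a token stream up front.
	if _, err := parser.ParseBytes([]byte(src), 0); err != nil {
		fmt.Println(err) // prints "found invalid token" with the annotated source
	}
}

Because Scanner.Scan appends the invalid token to its result before returning the InvalidTokenError, callers that drop scanner errors (such as lexer.Tokenize) still see the token and its position, while parser.Parse converts it into a syntax error via Tokens.InvalidToken().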