diff --git a/rewrite/rewrite/python/_parser_visitor.py b/rewrite/rewrite/python/_parser_visitor.py index 248dd20..df07910 100644 --- a/rewrite/rewrite/python/_parser_visitor.py +++ b/rewrite/rewrite/python/_parser_visitor.py @@ -3,8 +3,7 @@ from functools import lru_cache from io import BytesIO from pathlib import Path -from random import random -from tokenize import tokenize +from tokenize import tokenize, TokenInfo from typing import Optional, TypeVar, cast, Callable, List, Tuple, Dict, Type, Sequence from rewrite import random_id, Markers @@ -974,7 +973,7 @@ def _map_comprehension_condition(self, i): ) def visit_Module(self, node: ast.Module) -> py.CompilationUnit: - return py.CompilationUnit( + cu = py.CompilationUnit( random_id(), Space.EMPTY, Markers.EMPTY, @@ -988,6 +987,8 @@ def visit_Module(self, node: ast.Module) -> py.CompilationUnit: self.__pad_right(j.Empty(random_id(), Space.EMPTY, Markers.EMPTY), Space.EMPTY)], self.__whitespace() ) + # assert self._cursor == len(self._source) + return cu def visit_Name(self, node): return j.Identifier( @@ -1350,8 +1351,7 @@ def _map_assignment_operator(self, op): raise ValueError(f"Unsupported operator: {op}") return self.__pad_left(self.__source_before(op_str), op) - def __map_fstring(self, node, prefix, tok, tokens): - consume_end_delim = False + def __map_fstring(self, node: ast.JoinedStr, prefix: Space, tok: TokenInfo, tokens): if tok.type != token.FSTRING_START: if len(node.values) == 1 and isinstance(node.values[0], ast.Constant): # format specifiers are stored as f-strings in the AST; e.g. `f'{1:n}'` @@ -1368,6 +1368,7 @@ def __map_fstring(self, node, prefix, tok, tokens): ), next(tokens)) else: delimiter = '' + consume_end_delim = False else: delimiter = tok.string self._cursor += len(delimiter) @@ -1377,10 +1378,11 @@ def __map_fstring(self, node, prefix, tok, tokens): # tokenizer tokens: FSTRING_START, FSTRING_MIDDLE, OP, ..., OP, FSTRING_MIDDLE, FSTRING_END parts = [] for value in node.values: - if tok.type == token.OP: + if tok.type == token.OP and tok.string == '{': self._cursor += len(tok.string) - if isinstance(value.value, ast.JoinedStr): - nested, tok = self.__map_fstring(value.value, Space.EMPTY, next(tokens), tokens) + tok = next(tokens) + if isinstance(cast(ast.FormattedValue, value).value, ast.JoinedStr): + nested, tok = self.__map_fstring(cast(ast.JoinedStr, cast(ast.FormattedValue, value).value), Space.EMPTY, tok, tokens) expr = self.__pad_right( nested, Space.EMPTY @@ -1390,7 +1392,7 @@ def __map_fstring(self, node, prefix, tok, tokens): tok = next(tokens) else: expr = self.__pad_right( - self.__convert(value.value), + self.__convert(cast(ast.FormattedValue, value).value), self.__whitespace() ) prev_tok = tok