Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: native hex string literals #4271

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
22 changes: 22 additions & 0 deletions tests/functional/codegen/types/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,28 @@ def test2(l: bytes{m} = {vyper_literal}) -> bool:
assert c.test2(vyper_literal) is True


@pytest.mark.parametrize("m,val", [(2, "ab"), (3, "ab"), (4, "abcd")])
def test_native_hex_literals(get_contract, m, val):
    """
    Check that a native hex literal (``x"..."``) equals the matching bytes value.

    ``m`` is the ``Bytes[m]`` bound used in the contract and ``val`` is the hex
    digit string placed inside the ``x"..."`` literal. The literal is exercised
    both as a local assignment and as a default argument value.
    """
    # the Python bytes value the hex literal is expected to decode to
    vyper_literal = bytes.fromhex(val)

    # `{vyper_literal}` is interpolated as the textual form of the Python bytes
    # object (e.g. b'\xab'), which the contract source uses as the right-hand
    # side of the comparison.
    code = f"""
@external
def test() -> bool:
    l: Bytes[{m}] = x"{val}"
    return l == {vyper_literal}

@external
def test2(l: Bytes[{m}] = x"{val}") -> bool:
    return l == {vyper_literal}
    """

    c = get_contract(code)

    assert c.test() is True
    # default argument path
    assert c.test2() is True
    # explicitly supplied argument path
    assert c.test2(vyper_literal) is True


def test_zero_padding_with_private(get_contract):
code = """
counter: uint256
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/grammar/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,6 @@ def has_no_docstrings(c):
max_examples=500, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much]
)
def test_grammar_bruteforce(code):
_, _, _, reformatted_code = pre_parse(code + "\n")
_, _, _, _, reformatted_code = pre_parse(code + "\n")
tree = parse_to_ast(reformatted_code)
assert isinstance(tree, Module)
9 changes: 9 additions & 0 deletions tests/functional/syntax/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def test() -> Bytes[1]:
(
"""
@external
def test() -> Bytes[2]:
a: Bytes[2] = x"abc"
return a
""",
SyntaxException,
),
(
"""
@external
def foo():
a: Bytes = b"abc"
""",
Expand Down
12 changes: 10 additions & 2 deletions tests/unit/ast/test_annotate_and_optimize_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,18 @@ def foo() -> int128:


def get_contract_info(source_code):
_, loop_var_annotations, class_types, reformatted_code = pre_parse(source_code)
(
_,
loop_var_annotations,
native_hex_literal_locations,
class_types,
reformatted_code,
) = pre_parse(source_code)
py_ast = python_ast.parse(reformatted_code)

annotate_python_ast(py_ast, reformatted_code, loop_var_annotations, class_types)
annotate_python_ast(
py_ast, reformatted_code, loop_var_annotations, native_hex_literal_locations, class_types
)

return py_ast, reformatted_code

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/ast/test_pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version):
@pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples)
def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version):
mock_version("0.3.10")
settings, _, _, _ = pre_parse(code)
settings, _, _, _, _ = pre_parse(code)

assert settings == pre_parse_settings

Expand Down
2 changes: 1 addition & 1 deletion vyper/ast/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ COMMENT: /#[^\n\r]*/
_NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+


STRING: /b?("(?!"").*?(?<!\\)(\\\\)*?"|'(?!'').*?(?<!\\)(\\\\)*?')/i
STRING: /x?b?("(?!"").*?(?<!\\)(\\\\)*?"|'(?!'').*?(?<!\\)(\\\\)*?')/i
DOCSTRING: /(""".*?(?<!\\)(\\\\)*?"""|'''.*?(?<!\\)(\\\\)*?''')/is

DEC_NUMBER: /0|[1-9]\d*/i
Expand Down
29 changes: 27 additions & 2 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,13 @@ def parse_to_ast_with_settings(
"""
if "\x00" in vyper_source:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
settings, class_types, for_loop_annotations, python_source = pre_parse(vyper_source)
(
settings,
class_types,
for_loop_annotations,
native_hex_literal_locations,
python_source,
) = pre_parse(vyper_source)
try:
py_ast = python_ast.parse(python_source)
except SyntaxError as e:
Expand All @@ -75,6 +81,7 @@ def parse_to_ast_with_settings(
vyper_source,
class_types,
for_loop_annotations,
native_hex_literal_locations,
source_id=source_id,
module_path=module_path,
resolved_path=resolved_path,
Expand Down Expand Up @@ -120,6 +127,7 @@ def annotate_python_ast(
vyper_source: str,
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
native_hex_literal_locations: list,
source_id: int = 0,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
Expand Down Expand Up @@ -150,6 +158,7 @@ def annotate_python_ast(
vyper_source,
modification_offsets,
for_loop_annotations,
native_hex_literal_locations,
tokens,
source_id,
module_path=module_path,
Expand All @@ -170,6 +179,7 @@ def __init__(
source_code: str,
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
native_hex_literal_locations: list,
tokens: asttokens.ASTTokens,
source_id: int,
module_path: Optional[str] = None,
Expand All @@ -182,6 +192,7 @@ def __init__(
self._source_code = source_code
self._modification_offsets = modification_offsets
self._for_loop_annotations = for_loop_annotations
self._native_hex_literal_locations = native_hex_literal_locations

self.counter: int = 0

Expand Down Expand Up @@ -401,7 +412,21 @@ def visit_Constant(self, node):
if node.value is None or isinstance(node.value, bool):
node.ast_type = "NameConstant"
elif isinstance(node.value, str):
node.ast_type = "Str"
if (node.lineno, node.col_offset) in self._native_hex_literal_locations:
if len(node.value) % 2 != 0:
raise SyntaxException(
"Native hex string must have an even number of characters",
self._source_code,
node.lineno,
node.col_offset,
)

byte_val = bytes.fromhex(node.value)

node.ast_type = "Bytes"
node.value = byte_val
else:
node.ast_type = "Str"
elif isinstance(node.value, bytes):
node.ast_type = "Bytes"
elif isinstance(node.value, Ellipsis.__class__):
Expand Down
67 changes: 55 additions & 12 deletions vyper/ast/pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import io
import re
from collections import defaultdict
from tokenize import COMMENT, NAME, OP, TokenError, TokenInfo, tokenize, untokenize
from tokenize import COMMENT, NAME, OP, STRING, TokenError, TokenInfo, tokenize, untokenize

from packaging.specifiers import InvalidSpecifier, SpecifierSet

Expand Down Expand Up @@ -48,7 +48,7 @@ def validate_version_pragma(version_str: str, full_source_code: str, start: Pars
)


class ForParserState(enum.Enum):
class CustomParserState(enum.Enum):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe just ParserState

NOT_RUNNING = enum.auto()
START_SOON = enum.auto()
RUNNING = enum.auto()
Expand All @@ -63,23 +63,23 @@ def __init__(self, code):
self.annotations = {}
self._current_annotation = None

self._state = ForParserState.NOT_RUNNING
self._state = CustomParserState.NOT_RUNNING
self._current_for_loop = None

def consume(self, token):
# state machine: we can start slurping tokens soon
if token.type == NAME and token.string == "for":
# note: self._state should be NOT_RUNNING here, but we don't sanity
# check here as that should be an error the parser will handle.
self._state = ForParserState.START_SOON
self._state = CustomParserState.START_SOON
self._current_for_loop = token.start

if self._state == ForParserState.NOT_RUNNING:
if self._state == CustomParserState.NOT_RUNNING:
return False

# state machine: start slurping tokens
if token.type == OP and token.string == ":":
self._state = ForParserState.RUNNING
self._state = CustomParserState.RUNNING

# sanity check -- this should never really happen, but if it does,
# try to raise an exception which pinpoints the source.
Expand All @@ -93,19 +93,55 @@ def consume(self, token):

# state machine: end slurping tokens
if token.type == NAME and token.string == "in":
self._state = ForParserState.NOT_RUNNING
self._state = CustomParserState.NOT_RUNNING
self.annotations[self._current_for_loop] = self._current_annotation or []
self._current_annotation = None
return False

if self._state != ForParserState.RUNNING:
if self._state != CustomParserState.RUNNING:
return False

# slurp the token
self._current_annotation.append(token)
return True


class NativeHexParser:
    """
    Token filter that recognises native hex string literals such as ``x"abcd"``.

    The tokenizer yields such a literal as a NAME token ``x`` followed by a
    STRING token. This class holds back every ``x`` NAME token for one step;
    if the next token is a STRING that starts exactly where the ``x`` ended,
    the ``x`` is dropped and the STRING is re-emitted starting at the position
    of the ``x``. The start position of each rewritten literal is recorded in
    ``locations``.
    """

    def __init__(self):
        # start positions (as reported by the tokenizer) of the `x` prefix of
        # every native hex literal that has been rewritten so far
        self.locations = []
        # `x` NAME token which is being held back until the next token is seen
        self._current_x = None
        self._state = CustomParserState.NOT_RUNNING

    def consume(self, token, result):
        """
        Feed one token to the parser.

        Parameters
        ----------
        token : TokenInfo
            The token currently being processed.
        result : list
            Output token list. Held-back or rewritten tokens are appended to
            it directly by this method.

        Returns
        -------
        bool
            True if ``token`` was handled here (held back, or merged into a
            rewritten STRING token) and must not be emitted by the caller;
            False if the caller should emit ``token`` itself.
        """
        pending_x = self._current_x
        self._current_x = None
        self._state = CustomParserState.NOT_RUNNING

        if pending_x is not None:
            # only treat `x` as a literal prefix when the string follows it
            # with no whitespace in between, like any other string prefix
            if token.type == STRING and token.start == pending_x.end:
                self.locations.append(pending_x.start)
                # drop the leading `x` and widen the STRING token so that it
                # starts where the `x` was, to avoid a python parser error
                result.append(
                    TokenInfo(STRING, token.string, pending_x.start, token.end, token.line)
                )
                return True

            # not a hex literal: put the held-back `x` back into the output.
            # this also covers the case where the current token is another
            # `x`, so that the earlier one is never silently lost.
            result.append(pending_x)

        # prepare to check if the next token is a STRING
        if token.type == NAME and token.string == "x":
            self._state = CustomParserState.RUNNING
            self._current_x = token
            return True

        return False


# compound statements that are replaced with `class`
# TODO remove enum in favor of flag
VYPER_CLASS_TYPES = {
Expand All @@ -122,7 +158,7 @@ def consume(self, token):
CUSTOM_EXPRESSION_TYPES = {"extcall": "ExtCall", "staticcall": "StaticCall"}


def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]:
def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, list, str]:
"""
Re-formats a vyper source string into a python source string and performs
some validation. More specifically,
Expand Down Expand Up @@ -153,10 +189,11 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]:
str
Reformatted python source string.
"""
result = []
result: list[TokenInfo] = []
modification_offsets: ModificationOffsets = {}
settings = Settings()
for_parser = ForParser(code)
native_hex_parser = NativeHexParser()

_col_adjustments: dict[int, int] = defaultdict(lambda: 0)

Expand Down Expand Up @@ -264,7 +301,7 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]:
if (typ, string) == (OP, ";"):
raise SyntaxException("Semi-colon statements not allowed", code, start[0], start[1])

if not for_parser.consume(token):
if not for_parser.consume(token) and not native_hex_parser.consume(token, result):
result.extend(toks)

except TokenError as e:
Expand All @@ -274,4 +311,10 @@ def pre_parse(code: str) -> tuple[Settings, ModificationOffsets, dict, str]:
for k, v in for_parser.annotations.items():
for_loop_annotations[k] = v.copy()

return settings, modification_offsets, for_loop_annotations, untokenize(result).decode("utf-8")
return (
settings,
modification_offsets,
for_loop_annotations,
native_hex_parser.locations,
untokenize(result).decode("utf-8"),
)
Loading