Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: native hex string literals #4271

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion docs/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,12 @@ A byte array with a max size.
The syntax being ``Bytes[maxLen]``, where ``maxLen`` is an integer which denotes the maximum number of bytes.
On the ABI level the Fixed-size bytes array is annotated as ``bytes``.

Bytes literals may be given as bytes strings.
Bytes literals may be given as bytes strings or as hex strings.

.. code-block:: vyper

bytes_string: Bytes[100] = b"\x01"
hex_string: Bytes[100] = x"01"

.. index:: !string

Expand Down
22 changes: 22 additions & 0 deletions tests/functional/codegen/types/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,28 @@ def test2(l: bytes{m} = {vyper_literal}) -> bool:
assert c.test2(vyper_literal) is True


@pytest.mark.parametrize("m,val", [(2, "ab"), (3, "ab"), (4, "abcd")])
def test_native_hex_literals(get_contract, m, val):
    """
    Check that native hex string literals (x-prefixed) produce the same
    value as the equivalent bytes literal.

    The literal is exercised both as a local variable assignment (`test`)
    and as a default argument value (`test2`). The `(3, "ab")` case covers
    a literal that is shorter than the declared maximum length.
    """
    # python-side value the hex literal is expected to decode to; its repr
    # (e.g. b'\xab') is interpolated into the contract as a bytes literal
    vyper_literal = bytes.fromhex(val)
    code = f"""
@external
def test() -> bool:
    l: Bytes[{m}] = x"{val}"
    return l == {vyper_literal}

@external
def test2(l: Bytes[{m}] = x"{val}") -> bool:
    return l == {vyper_literal}
    """

    c = get_contract(code)

    assert c.test() is True
    # default argument path, then the same value passed explicitly
    assert c.test2() is True
    assert c.test2(vyper_literal) is True


def test_zero_padding_with_private(get_contract):
code = """
counter: uint256
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/grammar/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,6 @@ def has_no_docstrings(c):
max_examples=500, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much]
)
def test_grammar_bruteforce(code):
_, _, _, reformatted_code = pre_parse(code + "\n")
tree = parse_to_ast(reformatted_code)
pre_parse_result = pre_parse(code + "\n")
tree = parse_to_ast(pre_parse_result.reformatted_code)
assert isinstance(tree, Module)
9 changes: 9 additions & 0 deletions tests/functional/syntax/test_bytes.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,15 @@ def test() -> Bytes[1]:
(
"""
@external
def test() -> Bytes[2]:
a: Bytes[2] = x"abc"
return a
""",
SyntaxException,
),
(
"""
@external
def foo():
a: Bytes = b"abc"
""",
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/ast/test_annotate_and_optimize_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@


def get_contract_info(source_code):
    """
    Pre-parse `source_code`, parse the reformatted result into a Python AST
    and annotate that AST.

    Returns
    -------
    tuple
        `(py_ast, reformatted_code)`: the annotated Python AST and the
        reformatted source produced by pre-parsing.
    """
    pre_parse_result = pre_parse(source_code)
    reformatted_code = pre_parse_result.reformatted_code
    py_ast = python_ast.parse(reformatted_code)

    # annotate_python_ast expects (parsed_ast, vyper_source, pre_parse_result);
    # the source text must be passed explicitly, as it was before the
    # pre-parse outputs were bundled into a single result object.
    annotate_python_ast(py_ast, reformatted_code, pre_parse_result)

    return py_ast, reformatted_code


def test_it_annotates_ast_with_source_code():
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ast/test_pre_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def test_prerelease_invalid_version_pragma(file_version, mock_version):
@pytest.mark.parametrize("code, pre_parse_settings, compiler_data_settings", pragma_examples)
def test_parse_pragmas(code, pre_parse_settings, compiler_data_settings, mock_version):
mock_version("0.3.10")
settings, _, _, _ = pre_parse(code)
pre_parse_result = pre_parse(code)

assert settings == pre_parse_settings
assert pre_parse_result.settings == pre_parse_settings

compiler_data = CompilerData(code)

Expand Down
2 changes: 1 addition & 1 deletion vyper/ast/grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ COMMENT: /#[^\n\r]*/
_NEWLINE: ( /\r?\n[\t ]*/ | COMMENT )+


STRING: /b?("(?!"").*?(?<!\\)(\\\\)*?"|'(?!'').*?(?<!\\)(\\\\)*?')/i
STRING: /x?b?("(?!"").*?(?<!\\)(\\\\)*?"|'(?!'').*?(?<!\\)(\\\\)*?')/i
DOCSTRING: /(""".*?(?<!\\)(\\\\)*?"""|'''.*?(?<!\\)(\\\\)*?''')/is

DEC_NUMBER: /0|[1-9]\d*/i
Expand Down
50 changes: 29 additions & 21 deletions vyper/ast/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import asttokens

from vyper.ast import nodes as vy_ast
from vyper.ast.pre_parser import pre_parse
from vyper.ast.pre_parser import PreParseResult, pre_parse
from vyper.compiler.settings import Settings
from vyper.exceptions import CompilerPanic, ParserException, SyntaxException
from vyper.typing import ModificationOffsets
Expand Down Expand Up @@ -55,9 +55,9 @@ def parse_to_ast_with_settings(
"""
if "\x00" in vyper_source:
raise ParserException("No null bytes (\\x00) allowed in the source code.")
settings, class_types, for_loop_annotations, python_source = pre_parse(vyper_source)
pre_parse_result = pre_parse(vyper_source)
try:
py_ast = python_ast.parse(python_source)
py_ast = python_ast.parse(pre_parse_result.reformatted_code)
except SyntaxError as e:
# TODO: Ensure 1-to-1 match of source_code:reformatted_code SyntaxErrors
raise SyntaxException(str(e), vyper_source, e.lineno, e.offset) from None
Expand All @@ -73,21 +73,20 @@ def parse_to_ast_with_settings(
annotate_python_ast(
py_ast,
vyper_source,
class_types,
for_loop_annotations,
pre_parse_result,
source_id=source_id,
module_path=module_path,
resolved_path=resolved_path,
)

# postcondition: consumed all the for loop annotations
assert len(for_loop_annotations) == 0
assert len(pre_parse_result.for_loop_annotations) == 0

# Convert to Vyper AST.
module = vy_ast.get_node(py_ast)
assert isinstance(module, vy_ast.Module) # mypy hint

return settings, module
return pre_parse_result.settings, module


def ast_to_dict(ast_struct: Union[vy_ast.VyperNode, List]) -> Union[Dict, List]:
Expand Down Expand Up @@ -118,8 +117,7 @@ def dict_to_ast(ast_struct: Union[Dict, List]) -> Union[vy_ast.VyperNode, List]:
def annotate_python_ast(
parsed_ast: python_ast.AST,
vyper_source: str,
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
pre_parse_result: PreParseResult,
source_id: int = 0,
module_path: Optional[str] = None,
resolved_path: Optional[str] = None,
Expand All @@ -133,11 +131,8 @@ def annotate_python_ast(
The AST to be annotated and optimized.
vyper_source: str
The original vyper source code
loop_var_annotations: dict
A mapping of line numbers of `For` nodes to the tokens of the type
annotation of the iterator extracted during pre-parsing.
modification_offsets : dict
A mapping of class names to their original class types.
pre_parse_result: PreParseResult
Outputs from pre-parsing.

Returns
-------
Expand All @@ -148,8 +143,7 @@ def annotate_python_ast(
tokens.mark_tokens(parsed_ast)
visitor = AnnotatingVisitor(
vyper_source,
modification_offsets,
for_loop_annotations,
pre_parse_result,
tokens,
source_id,
module_path=module_path,
Expand All @@ -168,8 +162,7 @@ class AnnotatingVisitor(python_ast.NodeTransformer):
def __init__(
self,
source_code: str,
modification_offsets: ModificationOffsets,
for_loop_annotations: dict,
pre_parse_result: PreParseResult,
tokens: asttokens.ASTTokens,
source_id: int,
module_path: Optional[str] = None,
Expand All @@ -180,8 +173,9 @@ def __init__(
self._module_path = module_path
self._resolved_path = resolved_path
self._source_code = source_code
self._modification_offsets = modification_offsets
self._for_loop_annotations = for_loop_annotations
self._modification_offsets = pre_parse_result.modification_offsets
self._for_loop_annotations = pre_parse_result.for_loop_annotations
self._native_hex_literal_locations = pre_parse_result.native_hex_literal_locations

self.counter: int = 0

Expand Down Expand Up @@ -401,7 +395,21 @@ def visit_Constant(self, node):
if node.value is None or isinstance(node.value, bool):
node.ast_type = "NameConstant"
elif isinstance(node.value, str):
node.ast_type = "Str"
if (node.lineno, node.col_offset) in self._native_hex_literal_locations:
if len(node.value) % 2 != 0:
raise SyntaxException(
"Native hex string must have an even number of characters",
self._source_code,
node.lineno,
node.col_offset,
)

byte_val = bytes.fromhex(node.value)

node.ast_type = "Bytes"
node.value = byte_val
else:
node.ast_type = "Str"
elif isinstance(node.value, bytes):
node.ast_type = "Bytes"
elif isinstance(node.value, Ellipsis.__class__):
Expand Down
Loading
Loading