diff --git a/tests/functional/builtins/folding/test_abs.py b/tests/functional/builtins/folding/test_abs.py index 68131678fa..c954380def 100644 --- a/tests/functional/builtins/folding/test_abs.py +++ b/tests/functional/builtins/folding/test_abs.py @@ -2,8 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType @@ -19,9 +18,9 @@ def foo(a: int256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"abs({a})") + vyper_ast = parse_and_fold(f"abs({a})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a) == new_node.value == abs(a) diff --git a/tests/functional/builtins/folding/test_addmod_mulmod.py b/tests/functional/builtins/folding/test_addmod_mulmod.py index 1d789f1655..e6a9fc193f 100644 --- a/tests/functional/builtins/folding/test_addmod_mulmod.py +++ b/tests/functional/builtins/folding/test_addmod_mulmod.py @@ -2,8 +2,7 @@ from hypothesis import assume, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256 - 1) @@ -22,8 +21,8 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})") + vyper_ast = parse_and_fold(f"{fn_name}({a}, {b}, {c})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b, c) == new_node.value diff --git a/tests/functional/builtins/folding/test_bitwise.py b/tests/functional/builtins/folding/test_bitwise.py index 53a6d333a0..c1ff7674bb 100644 --- a/tests/functional/builtins/folding/test_bitwise.py +++ b/tests/functional/builtins/folding/test_bitwise.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import InvalidType, OverflowException from vyper.semantics.analysis.utils import validate_expected_type from vyper.semantics.types.shortcuts import INT256_T, UINT256_T @@ -29,7 +29,7 @@ def foo(a: uint256, b: uint256) -> uint256: contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") + vyper_ast = parse_and_fold(f"{a} {op} {b}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -48,10 +48,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() # force bounds check, no-op because validate_numeric_bounds # already does this, but leave in for hygiene (in case @@ -78,10 +77,9 @@ def foo(a: int256, b: uint256) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}") - old_node = vyper_ast.body[0].value - try: + vyper_ast = parse_and_fold(f"{a} {op} {b}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() validate_expected_type(new_node, INT256_T) # force bounds check # compile time behavior does not match runtime behavior. @@ -105,7 +103,7 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"~{value}") + vyper_ast = parse_and_fold(f"~{value}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/functional/builtins/folding/test_epsilon.py b/tests/functional/builtins/folding/test_epsilon.py index 4f5e9434ec..7bc2afe757 100644 --- a/tests/functional/builtins/folding/test_epsilon.py +++ b/tests/functional/builtins/folding/test_epsilon.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("typ_name", ["decimal"]) @@ -13,8 +12,8 @@ def foo() -> {typ_name}: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})") + vyper_ast = parse_and_fold(f"epsilon({typ_name})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo() == new_node.value diff --git a/tests/functional/builtins/folding/test_floor_ceil.py b/tests/functional/builtins/folding/test_floor_ceil.py index 04921e504e..9e63c7b099 100644 --- a/tests/functional/builtins/folding/test_floor_ceil.py +++ b/tests/functional/builtins/folding/test_floor_ceil.py @@ -4,8 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_decimals = st.decimals( min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10 @@ -28,8 +27,8 @@ def foo(a: decimal) -> int256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_fold_as_wei_value.py b/tests/functional/builtins/folding/test_fold_as_wei_value.py index 4287615bab..01af646a16 100644 --- a/tests/functional/builtins/folding/test_fold_as_wei_value.py +++ b/tests/functional/builtins/folding/test_fold_as_wei_value.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.builtins import functions as vy_fn from vyper.utils import SizeLimits @@ -30,9 +30,9 @@ def foo(a: decimal) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value:.10f}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -49,8 +49,8 @@ def foo(a: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')") + vyper_ast = parse_and_fold(f"as_wei_value({value}, '{denom}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.AsWeiValue()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_keccak_sha.py b/tests/functional/builtins/folding/test_keccak_sha.py index 8da420538f..3b5f99891f 100644 --- a/tests/functional/builtins/folding/test_keccak_sha.py +++ b/tests/functional/builtins/folding/test_keccak_sha.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&()*+,-./:;<=>?@[]^_`{|}~' # NOQA: E501 @@ -20,9 +19,9 @@ def foo(a: String[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')") + vyper_ast = parse_and_fold(f"{fn_name}('''{value}''')") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -39,9 +38,9 @@ def foo(a: Bytes[100]) -> bytes32: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value @@ -60,8 +59,8 @@ def foo(a: Bytes[100]) -> bytes32: value = f"0x{value.hex()}" - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})") + vyper_ast = parse_and_fold(f"{fn_name}({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert f"0x{contract.foo(value).hex()}" == new_node.value diff --git a/tests/functional/builtins/folding/test_len.py b/tests/functional/builtins/folding/test_len.py index 967f906555..6d59751748 100644 --- a/tests/functional/builtins/folding/test_len.py +++ b/tests/functional/builtins/folding/test_len.py @@ -1,7 +1,6 @@ import pytest -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold @pytest.mark.parametrize("length", [0, 1, 32, 33, 64, 65, 1024]) @@ -15,9 +14,9 @@ def foo(a: String[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len('{value}')") + vyper_ast = parse_and_fold(f"len('{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value @@ -33,9 +32,9 @@ def foo(a: Bytes[1024]) -> uint256: value = "a" * length - vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')") + vyper_ast = parse_and_fold(f"len(b'{value}')") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value.encode()) == new_node.value @@ -51,8 +50,8 @@ def foo(a: Bytes[1024]) -> uint256: value = f"0x{'00' * length}" - vyper_ast = vy_ast.parse_to_ast(f"len({value})") + vyper_ast = parse_and_fold(f"len({value})") old_node = vyper_ast.body[0].value - new_node = vy_fn.Len()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(value) == new_node.value diff --git a/tests/functional/builtins/folding/test_min_max.py b/tests/functional/builtins/folding/test_min_max.py index 36a611fa1b..752b64eb04 100644 --- a/tests/functional/builtins/folding/test_min_max.py +++ b/tests/functional/builtins/folding/test_min_max.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold from vyper.utils import SizeLimits st_decimals = st.decimals( @@ -29,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -48,9 +47,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value @@ -67,8 +66,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})") + vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})") old_node = vyper_ast.body[0].value - new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(left, right) == new_node.value diff --git a/tests/functional/builtins/folding/test_powmod.py b/tests/functional/builtins/folding/test_powmod.py index a3c2567f58..ad1197e8e3 100644 --- a/tests/functional/builtins/folding/test_powmod.py +++ b/tests/functional/builtins/folding/test_powmod.py @@ -2,8 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast -from vyper.builtins import functions as vy_fn +from tests.utils import parse_and_fold st_uint256 = st.integers(min_value=0, max_value=2**256) @@ -19,8 +18,8 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})") + vyper_ast = parse_and_fold(f"pow_mod256({a}, {b})") old_node = vyper_ast.body[0].value - new_node = vy_fn.PowMod256()._try_fold(old_node) + new_node = old_node.get_folded_value() assert contract.foo(a, b) == new_node.value diff --git a/tests/functional/grammar/test_grammar.py b/tests/functional/grammar/test_grammar.py index 652102c376..351793b28e 100644 --- a/tests/functional/grammar/test_grammar.py +++ b/tests/functional/grammar/test_grammar.py @@ -4,7 +4,7 @@ import hypothesis import hypothesis.strategies as st import pytest -from hypothesis import assume, given +from hypothesis import HealthCheck, assume, given from hypothesis.extra.lark import LarkStrategy from vyper.ast import Module, parse_to_ast @@ -103,7 +103,7 @@ def has_no_docstrings(c): @pytest.mark.fuzzing @given(code=from_grammar().filter(lambda c: utf8_encodable(c))) -@hypothesis.settings(max_examples=500) +@hypothesis.settings(max_examples=500, suppress_health_check=[HealthCheck.too_slow]) def test_grammar_bruteforce(code): if utf8_encodable(code): _, _, _, reformatted_code = pre_parse(code + "\n") diff --git a/tests/functional/syntax/test_bool.py b/tests/functional/syntax/test_bool.py index 48ed37321a..5388a92b95 100644 --- a/tests/functional/syntax/test_bool.py +++ b/tests/functional/syntax/test_bool.py @@ -37,7 +37,7 @@ def foo(): def foo() -> bool: return (1 == 2) <= (1 == 1) """, - TypeMismatch, + InvalidOperation, ), """ @external diff --git a/tests/unit/ast/nodes/test_fold_binop_decimal.py b/tests/unit/ast/nodes/test_fold_binop_decimal.py index e426a11de9..a75d114f88 100644 --- a/tests/unit/ast/nodes/test_fold_binop_decimal.py +++ b/tests/unit/ast/nodes/test_fold_binop_decimal.py @@ -4,7 +4,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException st_decimals = st.decimals( @@ -28,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -45,11 +45,8 @@ def foo(a: decimal, b: decimal) -> decimal: def test_binop_pow(): # raises because Vyper does not support decimal exponentiation - vyper_ast = vy_ast.parse_to_ast("3.1337 ** 4.2") - old_node = vyper_ast.body[0].value - with pytest.raises(TypeMismatch): - old_node.get_folded_value() + _ = parse_and_fold("3.1337 ** 4.2") @pytest.mark.fuzzing @@ -72,8 +69,8 @@ def foo({input_value}) -> decimal: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = -(2**127) <= expected < 2**127 diff --git a/tests/unit/ast/nodes/test_fold_binop_int.py b/tests/unit/ast/nodes/test_fold_binop_int.py index 904b36c167..d9340927fe 100644 --- a/tests/unit/ast/nodes/test_fold_binop_int.py +++ b/tests/unit/ast/nodes/test_fold_binop_int.py @@ -2,7 +2,7 @@ from hypothesis import example, given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import ZeroDivisionException st_int32 = st.integers(min_value=-(2**32), max_value=2**32) @@ -24,9 +24,9 @@ def foo(a: int128, b: int128) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = True except ZeroDivisionException: @@ -54,9 +54,9 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") - old_node = vyper_ast.body[0].value try: + vyper_ast = parse_and_fold(f"{left} {op} {right}") + old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() is_valid = new_node.value >= 0 except ZeroDivisionException: @@ -83,7 +83,7 @@ def foo(a: uint256, b: uint256) -> uint256: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} ** {right}") + vyper_ast = parse_and_fold(f"{left} ** {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -112,9 +112,8 @@ def foo({input_value}) -> int128: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) - try: + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value is_valid = True diff --git a/tests/unit/ast/nodes/test_fold_boolop.py b/tests/unit/ast/nodes/test_fold_boolop.py index 3c42da0d26..082e6f35c3 100644 --- a/tests/unit/ast/nodes/test_fold_boolop.py +++ b/tests/unit/ast/nodes/test_fold_boolop.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold variables = "abcdefghij" @@ -24,7 +24,7 @@ def foo({input_value}) -> bool: literal_op = f" {comparator} ".join(str(i) for i in values) - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -52,7 +52,7 @@ def foo({input_value}) -> bool: literal_op = " ".join(f"{a} {b}" for a, b in zip(values, comparators)) literal_op = literal_op.rsplit(maxsplit=1)[0] - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/unit/ast/nodes/test_fold_compare.py b/tests/unit/ast/nodes/test_fold_compare.py index 2b7c0f09d7..aab8ac0b2d 100644 --- a/tests/unit/ast/nodes/test_fold_compare.py +++ b/tests/unit/ast/nodes/test_fold_compare.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold from vyper.exceptions import UnfoldableNode @@ -19,7 +19,7 @@ def foo(a: int128, b: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -39,7 +39,7 @@ def foo(a: uint128, b: uint128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}") + vyper_ast = parse_and_fold(f"{left} {op} {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -63,7 +63,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} in {right}") + vyper_ast = parse_and_fold(f"{left} in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -92,7 +92,7 @@ def bar(a: int128) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{left} not in {right}") + vyper_ast = parse_and_fold(f"{left} not in {right}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -106,7 +106,7 @@ def bar(a: int128) -> bool: @pytest.mark.parametrize("op", ["==", "!=", "<", "<=", ">=", ">"]) def test_compare_type_mismatch(op): - vyper_ast = vy_ast.parse_to_ast(f"1 {op} 1.0") + vyper_ast = parse_and_fold(f"1 {op} 1.0") old_node = vyper_ast.body[0].value with pytest.raises(UnfoldableNode): old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_subscript.py b/tests/unit/ast/nodes/test_fold_subscript.py index 1884abf73b..3ed26d07b7 100644 --- a/tests/unit/ast/nodes/test_fold_subscript.py +++ b/tests/unit/ast/nodes/test_fold_subscript.py @@ -2,7 +2,7 @@ from hypothesis import given, settings from hypothesis import strategies as st -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.fuzzing @@ -19,7 +19,7 @@ def foo(array: int128[10], idx: uint256) -> int128: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"{array}[{idx}]") + vyper_ast = parse_and_fold(f"{array}[{idx}]") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() diff --git a/tests/unit/ast/nodes/test_fold_unaryop.py b/tests/unit/ast/nodes/test_fold_unaryop.py index ff48adfe71..af72f5f8b0 100644 --- a/tests/unit/ast/nodes/test_fold_unaryop.py +++ b/tests/unit/ast/nodes/test_fold_unaryop.py @@ -1,6 +1,6 @@ import pytest -from vyper import ast as vy_ast +from tests.utils import parse_and_fold @pytest.mark.parametrize("bool_cond", [True, False]) @@ -12,7 +12,7 @@ def foo(a: bool) -> bool: """ contract = get_contract(source) - vyper_ast = vy_ast.parse_to_ast(f"not {bool_cond}") + vyper_ast = parse_and_fold(f"not {bool_cond}") old_node = vyper_ast.body[0].value new_node = old_node.get_folded_value() @@ -30,7 +30,7 @@ def foo(a: bool) -> bool: contract = get_contract(source) literal_op = f"{'not ' * count}{bool_cond}" - vyper_ast = vy_ast.parse_to_ast(literal_op) + vyper_ast = parse_and_fold(literal_op) new_node = vyper_ast.body[0].value.get_folded_value() expected = new_node.value diff --git a/tests/utils.py b/tests/utils.py index 0c89c39ff3..b8a6b493d8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,9 @@ import contextlib import os +from vyper import ast as vy_ast +from vyper.semantics.analysis.pre_typecheck import pre_typecheck + @contextlib.contextmanager def working_directory(directory): @@ -10,3 +13,9 @@ def working_directory(directory): yield finally: os.chdir(tmp) + + +def parse_and_fold(source_code): + ast = vy_ast.parse_to_ast(source_code) + pre_typecheck(ast) + return ast diff --git a/vyper/ast/nodes.py b/vyper/ast/nodes.py index 7a8c7443b7..90365c63d5 100644 --- a/vyper/ast/nodes.py +++ b/vyper/ast/nodes.py @@ -400,21 +400,11 @@ def get_folded_value(self) -> "VyperNode": """ Attempt to get the folded value, bubbling up UnfoldableNode if the node is not foldable. - - - The returned value is cached on `_metadata["folded_value"]`. - - For constant/literal nodes, the node should be directly returned - without caching to the metadata. """ - if self.is_literal_value: - return self - - if "folded_value" not in self._metadata: - res = self._try_fold() # possibly throws UnfoldableNode - self._set_folded_value(res) - - return self._metadata["folded_value"] + try: + return self._metadata["folded_value"] + except KeyError: + raise UnfoldableNode("not foldable", self) def _set_folded_value(self, node: "VyperNode") -> None: # sanity check this is only called once @@ -422,7 +412,9 @@ def _set_folded_value(self, node: "VyperNode") -> None: # set the "original node" so that exceptions can point to the original # node and not the folded node - node = copy.copy(node) + cls = node.__class__ + # make a fresh copy so that the node metadata is fresh. + node = cls(**{i: getattr(node, i) for i in node.get_fields() if hasattr(node, i)}) node._original_node = self self._metadata["folded_value"] = node @@ -430,17 +422,6 @@ def _set_folded_value(self, node: "VyperNode") -> None: def get_original_node(self) -> "VyperNode": return self._original_node or self - def _try_fold(self) -> "VyperNode": - """ - Attempt to constant-fold the content of a node, returning the result of - constant-folding if possible. - - If a node cannot be folded, it should raise `UnfoldableNode`. This - base implementation acts as a catch-all to raise on any inherited - classes that do not implement the method. - """ - raise UnfoldableNode(f"{type(self)} cannot be folded") - def validate(self) -> None: """ Validate the content of a node. @@ -919,10 +900,6 @@ class List(ExprNode): def is_literal_value(self): return all(e.is_literal_value for e in self.elements) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class Tuple(ExprNode): __slots__ = ("elements",) @@ -936,10 +913,6 @@ def validate(self): if not self.elements: raise InvalidLiteral("Cannot have an empty tuple", self) - def _try_fold(self) -> ExprNode: - elements = [e.get_folded_value() for e in self.elements] - return type(self).from_node(self, elements=elements) - class NameConstant(Constant): __slots__ = () @@ -960,10 +933,6 @@ class Dict(ExprNode): def is_literal_value(self): return all(v.is_literal_value for v in self.values) - def _try_fold(self) -> ExprNode: - values = [v.get_folded_value() for v in self.values] - return type(self).from_node(self, values=values) - class Name(ExprNode): __slots__ = ("id",) @@ -972,27 +941,6 @@ class Name(ExprNode): class UnaryOp(ExprNode): __slots__ = ("op", "operand") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the unary operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - operand = self.operand.get_folded_value() - - if isinstance(self.op, Not) and not isinstance(operand, NameConstant): - raise UnfoldableNode("not a boolean!", self.operand) - if isinstance(self.op, USub) and not isinstance(operand, Num): - raise UnfoldableNode("not a number!", self.operand) - if isinstance(self.op, Invert) and not isinstance(operand, Int): - raise UnfoldableNode("not an int!", self.operand) - - value = self.op._op(operand.value) - return type(operand).from_node(self, value=value) - class Operator(VyperNode): pass @@ -1021,30 +969,6 @@ def _op(self, value): class BinOp(ExprNode): __slots__ = ("left", "op", "right") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the arithmetic operation. - - Returns - ------- - Int | Decimal - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if type(left) is not type(right): - raise UnfoldableNode("invalid operation", self) - if not isinstance(left, Num): - raise UnfoldableNode("not a number!", self.left) - - # this validation is performed to prevent the compiler from hanging - # on very large shifts and improve the error message for negative - # values. - if isinstance(self.op, (LShift, RShift)) and not (0 <= right.value <= 256): - raise InvalidLiteral("Shift bits must be between 0 and 256", self.right) - - value = self.op._op(left.value, right.value) - return type(left).from_node(self, value=value) - class Add(Operator): __slots__ = () @@ -1170,24 +1094,6 @@ class RShift(Operator): class BoolOp(ExprNode): __slots__ = ("op", "values") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the boolean operation. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - values = [v.get_folded_value() for v in self.values] - - if any(not isinstance(v, NameConstant) for v in values): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - values = [v.value for v in values] - value = self.op._op(values) - return NameConstant.from_node(self, value=value) - class And(Operator): __slots__ = () @@ -1225,40 +1131,6 @@ def __init__(self, *args, **kwargs): kwargs["right"] = kwargs.pop("comparators")[0] super().__init__(*args, **kwargs) - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the comparison. - - Returns - ------- - NameConstant - Node representing the result of the evaluation. - """ - left, right = [i.get_folded_value() for i in (self.left, self.right)] - if not isinstance(left, Constant): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - - # CMC 2022-08-04 we could probably remove these evaluation rules as they - # are taken care of in the IR optimizer now. - if isinstance(self.op, (In, NotIn)): - if not isinstance(right, List): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if next((i for i in right.elements if not isinstance(i, Constant)), None): - raise UnfoldableNode("Node contains invalid field(s) for evaluation") - if len(set([type(i) for i in right.elements])) > 1: - raise UnfoldableNode("List contains multiple literal types") - value = self.op._op(left.value, [i.value for i in right.elements]) - return NameConstant.from_node(self, value=value) - - if not isinstance(left, type(right)): - raise UnfoldableNode("Cannot compare different literal types") - - if not isinstance(self.op, (Eq, NotEq)) and not isinstance(left, (Int, Decimal)): - raise TypeMismatch(f"Invalid literal types for {self.op.description} comparison", self) - - value = self.op._op(left.value, right.value) - return NameConstant.from_node(self, value=value) - class Eq(Operator): __slots__ = () @@ -1315,21 +1187,6 @@ def _op(self, left, right): class Call(ExprNode): __slots__ = ("func", "args", "keywords") - # try checking if this is a builtin, which is foldable - def _try_fold(self): - if not isinstance(self.func, Name): - raise UnfoldableNode("not a builtin", self) - - # cursed import cycle! - from vyper.builtins.functions import DISPATCH_TABLE - - func_name = self.func.id - if func_name not in DISPATCH_TABLE: - raise UnfoldableNode("not a builtin", self) - - builtin_t = DISPATCH_TABLE[func_name] - return builtin_t._try_fold(self) - class keyword(VyperNode): __slots__ = ("arg", "value") @@ -1342,37 +1199,6 @@ class Attribute(ExprNode): class Subscript(ExprNode): __slots__ = ("slice", "value") - def _try_fold(self) -> ExprNode: - """ - Attempt to evaluate the subscript. - - This method reduces an indexed reference to a literal array into the value - within the array, e.g. `["foo", "bar"][1]` becomes `"bar"` - - Returns - ------- - ExprNode - Node representing the result of the evaluation. - """ - slice_ = self.slice.value.get_folded_value() - value = self.value.get_folded_value() - - if not isinstance(value, List): - raise UnfoldableNode("Subscript object is not a literal list") - - elements = value.elements - if len(set([type(i) for i in elements])) > 1: - raise UnfoldableNode("List contains multiple node types") - - if not isinstance(slice_, Int): - raise UnfoldableNode("invalid index type", slice_) - - idx = slice_.value - if idx < 0 or idx >= len(elements): - raise UnfoldableNode("invalid index value") - - return elements[idx] - class Index(VyperNode): __slots__ = ("value",) diff --git a/vyper/ast/nodes.pyi b/vyper/ast/nodes.pyi index 8bc4a4eb57..4a5bc0d001 100644 --- a/vyper/ast/nodes.pyi +++ b/vyper/ast/nodes.pyi @@ -31,7 +31,6 @@ class VyperNode: @classmethod def get_fields(cls: Any) -> set: ... def get_folded_value(self) -> VyperNode: ... - def _try_fold(self) -> VyperNode: ... def _set_folded_value(self, node: VyperNode) -> None: ... @classmethod def from_node(cls, node: VyperNode, **kwargs: Any) -> Any: ... diff --git a/vyper/builtins/functions.py b/vyper/builtins/functions.py index 39d97c4abe..4f8101dfbe 100644 --- a/vyper/builtins/functions.py +++ b/vyper/builtins/functions.py @@ -90,6 +90,7 @@ ceil32, fourbytes_to_int, keccak256, + method_id, method_id_int, vyper_warn, ) @@ -723,12 +724,12 @@ def _try_fold(self, node): raise InvalidLiteral("Invalid function signature - no spaces allowed.", node.args[0]) return_type = self.infer_kwarg_types(node)["output_type"].typedef - value = method_id_int(value.value) + value = method_id(value.value) if return_type.compare_type(BYTES4_T): - return vy_ast.Hex.from_node(node, value=hex(value)) + return vy_ast.Hex.from_node(node, value="0x" + value.hex()) else: - return vy_ast.Bytes.from_node(node, value=value.to_bytes(4, "big")) + return vy_ast.Bytes.from_node(node, value=value) def fetch_call_return(self, node): validate_call_args(node, 1, ["output_type"]) diff --git a/vyper/exceptions.py b/vyper/exceptions.py index 51f3fea14c..04667aaa59 100644 --- a/vyper/exceptions.py +++ b/vyper/exceptions.py @@ -373,7 +373,7 @@ def tag_exceptions(node, fallback_exception_type=CompilerPanic, note=None): raise e from None except Exception as e: tb = e.__traceback__ - fallback_message = "unhandled exception" + fallback_message = f"unhandled exception {e}" if note: fallback_message += f", {note}" raise fallback_exception_type(fallback_message, node).with_traceback(tb) diff --git a/vyper/semantics/analysis/local.py b/vyper/semantics/analysis/local.py index 169c71269d..cc8ddaf98d 100644 --- a/vyper/semantics/analysis/local.py +++ b/vyper/semantics/analysis/local.py @@ -510,8 +510,7 @@ def visit(self, node, typ): # validate and annotate folded value if node.has_folded_value: folded_node = node.get_folded_value() - validate_expected_type(folded_node, typ) - folded_node._metadata["type"] = typ + self.visit(folded_node, typ) def visit_Attribute(self, node: vy_ast.Attribute, typ: VyperType) -> None: _validate_msg_data_attribute(node) diff --git a/vyper/semantics/analysis/module.py b/vyper/semantics/analysis/module.py index 8e435f870f..4a7e33e848 100644 --- a/vyper/semantics/analysis/module.py +++ b/vyper/semantics/analysis/module.py @@ -26,11 +26,7 @@ from vyper.semantics.analysis.import_graph import ImportGraph from vyper.semantics.analysis.local import ExprVisitor, validate_functions from vyper.semantics.analysis.pre_typecheck import pre_typecheck -from vyper.semantics.analysis.utils import ( - check_modifiability, - get_exact_type_from_node, - validate_expected_type, -) +from vyper.semantics.analysis.utils import check_modifiability, get_exact_type_from_node from vyper.semantics.data_locations import DataLocation from vyper.semantics.namespace import Namespace, get_namespace, override_global_namespace from vyper.semantics.types import EventT, FlagT, InterfaceT, StructT @@ -315,12 +311,11 @@ def _validate_self_namespace(): if node.is_constant: assert node.value is not None # checked in VariableDecl.validate() - ExprVisitor().visit(node.value, type_) + ExprVisitor().visit(node.value, type_) # performs validate_expected_type if not check_modifiability(node.value, Modifiability.CONSTANT): raise StateAccessViolation("Value must be a literal", node.value) - validate_expected_type(node.value, type_) _validate_self_namespace() return _finalize() diff --git a/vyper/semantics/analysis/pre_typecheck.py b/vyper/semantics/analysis/pre_typecheck.py index a1302ce9c9..1c2a5392c3 100644 --- a/vyper/semantics/analysis/pre_typecheck.py +++ b/vyper/semantics/analysis/pre_typecheck.py @@ -1,94 +1,210 @@ from vyper import ast as vy_ast -from vyper.exceptions import UnfoldableNode - - -# try to fold a node, swallowing exceptions. this function is very similar to -# `VyperNode.get_folded_value()` but additionally checks in the constants -# table if the node is a `Name` node. -# -# CMC 2023-12-30 a potential refactor would be to move this function into -# `Name._try_fold` (which would require modifying the signature of _try_fold to -# take an optional constants table as parameter). this would remove the -# need to use this function in conjunction with `get_descendants` since -# `VyperNode._try_fold()` already recurses. it would also remove the need -# for `VyperNode._set_folded_value()`. -def _fold_with_constants(node: vy_ast.VyperNode, constants: dict[str, vy_ast.VyperNode]): - if node.has_folded_value: - return - - if isinstance(node, vy_ast.Name): - # check if it's in constants table - var_name = node.id - - if var_name not in constants: - return - - res = constants[var_name] - node._set_folded_value(res) - return - - try: - # call get_folded_value for its side effects - node.get_folded_value() - except UnfoldableNode: - pass - - -def _get_constants(node: vy_ast.Module) -> dict: - constants: dict[str, vy_ast.VyperNode] = {} - const_var_decls = node.get_children(vy_ast.VariableDecl, {"is_constant": True}) - - while True: - n_processed = 0 - - for c in const_var_decls.copy(): - assert c.value is not None # guaranteed by VariableDecl.validate() - - for n in c.get_descendants(reverse=True): - _fold_with_constants(n, constants) - +from vyper.exceptions import InvalidLiteral, UnfoldableNode +from vyper.semantics.analysis.base import VarInfo +from vyper.semantics.analysis.common import VyperNodeVisitorBase +from vyper.semantics.namespace import get_namespace + + +def pre_typecheck(module_ast: vy_ast.Module): + ConstantFolder(module_ast).run() + + +class ConstantFolder(VyperNodeVisitorBase): + def __init__(self, module_ast): + self._constants = {} + self._module_ast = module_ast + + def run(self): + self._get_constants() + self.visit(self._module_ast) + + def _get_constants(self): + module = self._module_ast + const_var_decls = module.get_children(vy_ast.VariableDecl, {"is_constant": True}) + + while True: + n_processed = 0 + + for c in const_var_decls.copy(): + # visit the entire constant node in case its type annotation + # has unfolded constants in it. + self.visit(c) + + assert c.value is not None # guaranteed by VariableDecl.validate() + try: + val = c.value.get_folded_value() + except UnfoldableNode: + # not foldable, maybe it depends on other constants + # so try again later + continue + + # note that if a constant is redefined, its value will be + # overwritten, but it is okay because the error is handled + # downstream + name = c.target.id + self._constants[name] = val + + n_processed += 1 + const_var_decls.remove(c) + + if n_processed == 0: + # this condition means that there are some constant vardecls + # whose values are not foldable. this can happen for struct + # and interface constants for instance. these are valid constant + # declarations, but we just can't fold them at this stage. + break + + def visit(self, node): + if node.has_folded_value: + return node.get_folded_value() + + for c in node.get_children(): try: - val = c.value.get_folded_value() + self.visit(c) except UnfoldableNode: - # not foldable, maybe it depends on other constants - # so try again later - continue - - # note that if a constant is redefined, its value will be - # overwritten, but it is okay because the error is handled - # downstream - name = c.target.id - constants[name] = val - - n_processed += 1 - const_var_decls.remove(c) - - if n_processed == 0: - # this condition means that there are some constant vardecls - # whose values are not foldable. this can happen for struct - # and interface constants for instance. these are valid constant - # declarations, but we just can't fold them at this stage. - break - - return constants - - -# perform constant folding on a module AST -def pre_typecheck(node: vy_ast.Module) -> None: - """ - Perform pre-typechecking steps on a Module AST node. - At this point, this is limited to performing constant folding. - """ - constants = _get_constants(node) - - # note: use reverse to get descendants in leaf-first order - for n in node.get_descendants(reverse=True): - # try folding every single node. note this should be done before - # type checking because the typechecker requires literals or - # foldable nodes in type signatures and some other places (e.g. - # certain builtin kwargs). - # - # note we could limit to only folding nodes which are required - # during type checking, but it's easier to just fold everything - # and be done with it! - _fold_with_constants(n, constants) + # ignore bubbled up exceptions + pass + + try: + for class_ in node.__class__.mro(): + ast_type = class_.__name__ + + visitor_fn = getattr(self, f"visit_{ast_type}", None) + if visitor_fn: + folded_value = visitor_fn(node) + node._set_folded_value(folded_value) + return folded_value + except UnfoldableNode: + # ignore bubbled up exceptions + pass + + return node + + def visit_Constant(self, node) -> vy_ast.ExprNode: + return node + + def visit_Name(self, node) -> vy_ast.ExprNode: + try: + return self._constants[node.id] + except KeyError: + raise UnfoldableNode("unknown name", node) + + def visit_UnaryOp(self, node): + operand = node.operand.get_folded_value() + + if isinstance(node.op, vy_ast.Not) and not isinstance(operand, vy_ast.NameConstant): + raise UnfoldableNode("not a boolean!", node.operand) + if isinstance(node.op, vy_ast.USub) and not isinstance(operand, vy_ast.Num): + raise UnfoldableNode("not a number!", node.operand) + if isinstance(node.op, vy_ast.Invert) and not isinstance(operand, vy_ast.Int): + raise UnfoldableNode("not an int!", node.operand) + + value = node.op._op(operand.value) + return type(operand).from_node(node, value=value) + + def visit_BinOp(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if type(left) is not type(right): + raise UnfoldableNode("invalid operation", node) + if not isinstance(left, vy_ast.Num): + raise UnfoldableNode("not a number!", node.left) + + # this validation is performed to prevent the compiler from hanging + # on very large shifts and improve the error message for negative + # values. + if isinstance(node.op, (vy_ast.LShift, vy_ast.RShift)) and not (0 <= right.value <= 256): + raise InvalidLiteral("Shift bits must be between 0 and 256", node.right) + + value = node.op._op(left.value, right.value) + return type(left).from_node(node, value=value) + + def visit_BoolOp(self, node): + values = [v.get_folded_value() for v in node.values] + + if any(not isinstance(v, vy_ast.NameConstant) for v in values): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + values = [v.value for v in values] + value = node.op._op(values) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_Compare(self, node): + left, right = [i.get_folded_value() for i in (node.left, node.right)] + if not isinstance(left, vy_ast.Constant): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + + # CMC 2022-08-04 we could probably remove these evaluation rules as they + # are taken care of in the IR optimizer now. + if isinstance(node.op, (vy_ast.In, vy_ast.NotIn)): + if not isinstance(right, vy_ast.List): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if next((i for i in right.elements if not isinstance(i, vy_ast.Constant)), None): + raise UnfoldableNode("Node contains invalid field(s) for evaluation") + if len(set([type(i) for i in right.elements])) > 1: + raise UnfoldableNode("List contains multiple literal types") + value = node.op._op(left.value, [i.value for i in right.elements]) + return vy_ast.NameConstant.from_node(node, value=value) + + if not isinstance(left, type(right)): + raise UnfoldableNode("Cannot compare different literal types") + + # this is maybe just handled in the type checker. + if not isinstance(node.op, (vy_ast.Eq, vy_ast.NotEq)) and not isinstance(left, vy_ast.Num): + raise UnfoldableNode( + f"Invalid literal types for {node.op.description} comparison", node + ) + + value = node.op._op(left.value, right.value) + return vy_ast.NameConstant.from_node(node, value=value) + + def visit_List(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Tuple(self, node) -> vy_ast.ExprNode: + elements = [e.get_folded_value() for e in node.elements] + return type(node).from_node(node, elements=elements) + + def visit_Dict(self, node) -> vy_ast.ExprNode: + values = [v.get_folded_value() for v in node.values] + return type(node).from_node(node, values=values) + + def visit_Call(self, node) -> vy_ast.ExprNode: + if not isinstance(node.func, vy_ast.Name): + raise UnfoldableNode("not a builtin", node) + + namespace = get_namespace() + + func_name = node.func.id + if func_name not in namespace: + raise UnfoldableNode("unknown", node) + + varinfo = namespace[func_name] + if not isinstance(varinfo, VarInfo): + raise UnfoldableNode("unfoldable", node) + + typ = varinfo.typ + # TODO: rename to vyper_type.try_fold_call_expr + if not hasattr(typ, "_try_fold"): + raise UnfoldableNode("unfoldable", node) + return typ._try_fold(node) # type: ignore + + def visit_Subscript(self, node) -> vy_ast.ExprNode: + slice_ = node.slice.value.get_folded_value() + value = node.value.get_folded_value() + + if not isinstance(value, vy_ast.List): + raise UnfoldableNode("Subscript object is not a literal list") + + elements = value.elements + if len(set([type(i) for i in elements])) > 1: + raise UnfoldableNode("List contains multiple node types") + + if not isinstance(slice_, vy_ast.Int): + raise UnfoldableNode("invalid index type", slice_) + + idx = slice_.value + if idx < 0 or idx >= len(elements): + raise UnfoldableNode("invalid index value") + + return elements[idx] diff --git a/vyper/semantics/analysis/utils.py b/vyper/semantics/analysis/utils.py index ba1b02b8d6..359b51b71e 100644 --- a/vyper/semantics/analysis/utils.py +++ b/vyper/semantics/analysis/utils.py @@ -650,6 +650,8 @@ def check_modifiability(node: vy_ast.VyperNode, modifiability: Modifiability) -> return all(check_modifiability(v, modifiability) for v in args[0].values) call_type = get_exact_type_from_node(node.func) + + # builtins call_type_modifiability = getattr(call_type, "_modifiability", Modifiability.MODIFIABLE) return call_type_modifiability >= modifiability diff --git a/vyper/semantics/types/base.py b/vyper/semantics/types/base.py index 429ba807e1..14949f693f 100644 --- a/vyper/semantics/types/base.py +++ b/vyper/semantics/types/base.py @@ -19,7 +19,7 @@ # type of type `type_` class _GenericTypeAcceptor: def __repr__(self): - return repr(self.type_) + return f"GenericTypeAcceptor({self.type_})" def __init__(self, type_): self.type_ = type_ diff --git a/vyper/semantics/types/module.py b/vyper/semantics/types/module.py index b0d7800011..8f1a5cc0dc 100644 --- a/vyper/semantics/types/module.py +++ b/vyper/semantics/types/module.py @@ -4,7 +4,12 @@ from vyper import ast as vy_ast from vyper.abi_types import ABI_Address, ABIType from vyper.ast.validation import validate_call_args -from vyper.exceptions import InterfaceViolation, NamespaceCollision, StructureException +from vyper.exceptions import ( + InterfaceViolation, + NamespaceCollision, + StructureException, + UnfoldableNode, +) from vyper.semantics.analysis.base import VarInfo from vyper.semantics.analysis.utils import validate_expected_type, validate_unique_method_ids from vyper.semantics.namespace import get_namespace @@ -53,6 +58,15 @@ def abi_type(self) -> ABIType: def __repr__(self): return f"interface {self._id}" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + arg = node.args[0].get_folded_value() + if not isinstance(arg, vy_ast.Hex): + raise UnfoldableNode("not an address", arg) + + return node + # when using the type itself (not an instance) in the call position def _ctor_call_return(self, node: vy_ast.Call) -> "InterfaceT": self._ctor_arg_types(node) diff --git a/vyper/semantics/types/user.py b/vyper/semantics/types/user.py index a4e782349d..8ef9aa8d4a 100644 --- a/vyper/semantics/types/user.py +++ b/vyper/semantics/types/user.py @@ -10,6 +10,7 @@ InvalidAttribute, NamespaceCollision, StructureException, + UnfoldableNode, UnknownAttribute, VariableDeclarationException, ) @@ -357,6 +358,16 @@ def from_StructDef(cls, base_node: vy_ast.StructDef) -> "StructT": def __repr__(self): return f"{self._id} declaration object" + def _try_fold(self, node): + if len(node.args) != 1: + raise UnfoldableNode("wrong number of args", node.args) + args = [arg.get_folded_value() for arg in node.args] + if not isinstance(args[0], vy_ast.Dict): + raise UnfoldableNode("not a dict") + + # it can't be reduced, but this lets upstream code know it's constant + return node + @property def size_in_bytes(self): return sum(i.size_in_bytes for i in self.member_types.values())