Skip to content

Commit

Permalink
refactor: constant folding (#3719)
Browse files Browse the repository at this point in the history
refactor constant folding into a visitor class, clean up a couple passes

this moves responsibility for knowing how to fold a node off the
individual AST node implementations and into the ConstantFolder visitor.

by adding a dependency to get_namespace() it also makes constant folding
more generic; soon we can rely on more things being in the global
namespace at constant folding time.
  • Loading branch information
charles-cooper authored Jan 10, 2024
1 parent a1fd228 commit 06fa46a
Show file tree
Hide file tree
Showing 30 changed files with 337 additions and 379 deletions.
7 changes: 3 additions & 4 deletions tests/functional/builtins/folding/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from hypothesis import example, given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold
from vyper.exceptions import InvalidType


Expand All @@ -19,9 +18,9 @@ def foo(a: int256) -> int256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"abs({a})")
vyper_ast = parse_and_fold(f"abs({a})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["abs"]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(a) == new_node.value == abs(a)

Expand Down
7 changes: 3 additions & 4 deletions tests/functional/builtins/folding/test_addmod_mulmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from hypothesis import assume, given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold

st_uint256 = st.integers(min_value=0, max_value=2**256 - 1)

Expand All @@ -22,8 +21,8 @@ def foo(a: uint256, b: uint256, c: uint256) -> uint256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({a}, {b}, {c})")
vyper_ast = parse_and_fold(f"{fn_name}({a}, {b}, {c})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(a, b, c) == new_node.value
16 changes: 7 additions & 9 deletions tests/functional/builtins/folding/test_bitwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from hypothesis import given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from tests.utils import parse_and_fold
from vyper.exceptions import InvalidType, OverflowException
from vyper.semantics.analysis.utils import validate_expected_type
from vyper.semantics.types.shortcuts import INT256_T, UINT256_T
Expand All @@ -29,7 +29,7 @@ def foo(a: uint256, b: uint256) -> uint256:

contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
vyper_ast = parse_and_fold(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()

Expand All @@ -48,10 +48,9 @@ def foo(a: uint256, b: uint256) -> uint256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value

try:
vyper_ast = parse_and_fold(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()
# force bounds check, no-op because validate_numeric_bounds
# already does this, but leave in for hygiene (in case
Expand All @@ -78,10 +77,9 @@ def foo(a: int256, b: uint256) -> int256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value

try:
vyper_ast = parse_and_fold(f"{a} {op} {b}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()
validate_expected_type(new_node, INT256_T) # force bounds check
# compile time behavior does not match runtime behavior.
Expand All @@ -105,7 +103,7 @@ def foo(a: uint256) -> uint256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"~{value}")
vyper_ast = parse_and_fold(f"~{value}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()

Expand Down
7 changes: 3 additions & 4 deletions tests/functional/builtins/folding/test_epsilon.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold


@pytest.mark.parametrize("typ_name", ["decimal"])
Expand All @@ -13,8 +12,8 @@ def foo() -> {typ_name}:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"epsilon({typ_name})")
vyper_ast = parse_and_fold(f"epsilon({typ_name})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE["epsilon"]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo() == new_node.value
7 changes: 3 additions & 4 deletions tests/functional/builtins/folding/test_floor_ceil.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from hypothesis import example, given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold

st_decimals = st.decimals(
min_value=-(2**32), max_value=2**32, allow_nan=False, allow_infinity=False, places=10
Expand All @@ -28,8 +27,8 @@ def foo(a: decimal) -> int256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})")
vyper_ast = parse_and_fold(f"{fn_name}({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(value) == new_node.value
10 changes: 5 additions & 5 deletions tests/functional/builtins/folding/test_fold_as_wei_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from hypothesis import given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from tests.utils import parse_and_fold
from vyper.builtins import functions as vy_fn
from vyper.utils import SizeLimits

Expand Down Expand Up @@ -30,9 +30,9 @@ def foo(a: decimal) -> uint256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value:.10f}, '{denom}')")
vyper_ast = parse_and_fold(f"as_wei_value({value:.10f}, '{denom}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.AsWeiValue()._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(value) == new_node.value

Expand All @@ -49,8 +49,8 @@ def foo(a: uint256) -> uint256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"as_wei_value({value}, '{denom}')")
vyper_ast = parse_and_fold(f"as_wei_value({value}, '{denom}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.AsWeiValue()._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(value) == new_node.value
15 changes: 7 additions & 8 deletions tests/functional/builtins/folding/test_keccak_sha.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from hypothesis import given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold

alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&()*+,-./:;<=>?@[]^_`{|}~' # NOQA: E501

Expand All @@ -20,9 +19,9 @@ def foo(a: String[100]) -> bytes32:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}('''{value}''')")
vyper_ast = parse_and_fold(f"{fn_name}('''{value}''')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert f"0x{contract.foo(value).hex()}" == new_node.value

Expand All @@ -39,9 +38,9 @@ def foo(a: Bytes[100]) -> bytes32:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})")
vyper_ast = parse_and_fold(f"{fn_name}({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert f"0x{contract.foo(value).hex()}" == new_node.value

Expand All @@ -60,8 +59,8 @@ def foo(a: Bytes[100]) -> bytes32:

value = f"0x{value.hex()}"

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({value})")
vyper_ast = parse_and_fold(f"{fn_name}({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert f"0x{contract.foo(value).hex()}" == new_node.value
15 changes: 7 additions & 8 deletions tests/functional/builtins/folding/test_len.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold


@pytest.mark.parametrize("length", [0, 1, 32, 33, 64, 65, 1024])
Expand All @@ -15,9 +14,9 @@ def foo(a: String[1024]) -> uint256:

value = "a" * length

vyper_ast = vy_ast.parse_to_ast(f"len('{value}')")
vyper_ast = parse_and_fold(f"len('{value}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.Len()._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(value) == new_node.value

Expand All @@ -33,9 +32,9 @@ def foo(a: Bytes[1024]) -> uint256:

value = "a" * length

vyper_ast = vy_ast.parse_to_ast(f"len(b'{value}')")
vyper_ast = parse_and_fold(f"len(b'{value}')")
old_node = vyper_ast.body[0].value
new_node = vy_fn.Len()._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(value.encode()) == new_node.value

Expand All @@ -51,8 +50,8 @@ def foo(a: Bytes[1024]) -> uint256:

value = f"0x{'00' * length}"

vyper_ast = vy_ast.parse_to_ast(f"len({value})")
vyper_ast = parse_and_fold(f"len({value})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.Len()._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(value) == new_node.value
15 changes: 7 additions & 8 deletions tests/functional/builtins/folding/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from hypothesis import given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold
from vyper.utils import SizeLimits

st_decimals = st.decimals(
Expand All @@ -29,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})")
vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(left, right) == new_node.value

Expand All @@ -48,9 +47,9 @@ def foo(a: int128, b: int128) -> int128:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})")
vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(left, right) == new_node.value

Expand All @@ -67,8 +66,8 @@ def foo(a: uint256, b: uint256) -> uint256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{fn_name}({left}, {right})")
vyper_ast = parse_and_fold(f"{fn_name}({left}, {right})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.DISPATCH_TABLE[fn_name]._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(left, right) == new_node.value
7 changes: 3 additions & 4 deletions tests/functional/builtins/folding/test_powmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from hypothesis import given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from vyper.builtins import functions as vy_fn
from tests.utils import parse_and_fold

st_uint256 = st.integers(min_value=0, max_value=2**256)

Expand All @@ -19,8 +18,8 @@ def foo(a: uint256, b: uint256) -> uint256:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"pow_mod256({a}, {b})")
vyper_ast = parse_and_fold(f"pow_mod256({a}, {b})")
old_node = vyper_ast.body[0].value
new_node = vy_fn.PowMod256()._try_fold(old_node)
new_node = old_node.get_folded_value()

assert contract.foo(a, b) == new_node.value
4 changes: 2 additions & 2 deletions tests/functional/grammar/test_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import hypothesis
import hypothesis.strategies as st
import pytest
from hypothesis import assume, given
from hypothesis import HealthCheck, assume, given
from hypothesis.extra.lark import LarkStrategy

from vyper.ast import Module, parse_to_ast
Expand Down Expand Up @@ -103,7 +103,7 @@ def has_no_docstrings(c):

@pytest.mark.fuzzing
@given(code=from_grammar().filter(lambda c: utf8_encodable(c)))
@hypothesis.settings(max_examples=500)
@hypothesis.settings(max_examples=500, suppress_health_check=[HealthCheck.too_slow])
def test_grammar_bruteforce(code):
if utf8_encodable(code):
_, _, _, reformatted_code = pre_parse(code + "\n")
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/syntax/test_bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def foo():
def foo() -> bool:
return (1 == 2) <= (1 == 1)
""",
TypeMismatch,
InvalidOperation,
),
"""
@external
Expand Down
13 changes: 5 additions & 8 deletions tests/unit/ast/nodes/test_fold_binop_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from hypothesis import example, given, settings
from hypothesis import strategies as st

from vyper import ast as vy_ast
from tests.utils import parse_and_fold
from vyper.exceptions import OverflowException, TypeMismatch, ZeroDivisionException

st_decimals = st.decimals(
Expand All @@ -28,9 +28,9 @@ def foo(a: decimal, b: decimal) -> decimal:
"""
contract = get_contract(source)

vyper_ast = vy_ast.parse_to_ast(f"{left} {op} {right}")
old_node = vyper_ast.body[0].value
try:
vyper_ast = parse_and_fold(f"{left} {op} {right}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()
is_valid = True
except ZeroDivisionException:
Expand All @@ -45,11 +45,8 @@ def foo(a: decimal, b: decimal) -> decimal:

def test_binop_pow():
# raises because Vyper does not support decimal exponentiation
vyper_ast = vy_ast.parse_to_ast("3.1337 ** 4.2")
old_node = vyper_ast.body[0].value

with pytest.raises(TypeMismatch):
old_node.get_folded_value()
_ = parse_and_fold("3.1337 ** 4.2")


@pytest.mark.fuzzing
Expand All @@ -72,8 +69,8 @@ def foo({input_value}) -> decimal:

literal_op = " ".join(f"{a} {b}" for a, b in zip(values, ops))
literal_op = literal_op.rsplit(maxsplit=1)[0]
vyper_ast = vy_ast.parse_to_ast(literal_op)
try:
vyper_ast = parse_and_fold(literal_op)
new_node = vyper_ast.body[0].value.get_folded_value()
expected = new_node.value
is_valid = -(2**127) <= expected < 2**127
Expand Down
Loading

0 comments on commit 06fa46a

Please sign in to comment.