Skip to content

Commit

Permalink
Merge branch 'master' into feat/cse
Browse files Browse the repository at this point in the history
  • Loading branch information
HodanPlodky authored Sep 30, 2024
2 parents 9ce9dd8 + e21f3e8 commit a3585ed
Show file tree
Hide file tree
Showing 10 changed files with 167 additions and 50 deletions.
17 changes: 17 additions & 0 deletions tests/unit/ast/nodes/test_fold_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,20 @@ def test_compare_type_mismatch(op):
old_node = vyper_ast.body[0].value
with pytest.raises(UnfoldableNode):
old_node.get_folded_value()


@pytest.mark.parametrize("op", ["==", "!="])
def test_compare_eq_bytes(get_contract, op):
left, right = "0xA1AAB33F", "0xa1aab33f"
source = f"""
@external
def foo(a: bytes4, b: bytes4) -> bool:
return a {op} b
"""
contract = get_contract(source)

vyper_ast = parse_and_fold(f"{left} {op} {right}")
old_node = vyper_ast.body[0].value
new_node = old_node.get_folded_value()

assert contract.foo(left, right) == new_node.value
31 changes: 31 additions & 0 deletions tests/unit/compiler/venom/test_sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,34 @@ def test_cont_phi_const_case():
assert sccp.lattice[IRVariable("%5", version=1)].value == 106
assert sccp.lattice[IRVariable("%5", version=2)].value == 97
assert sccp.lattice[IRVariable("%5")].value == 2


def test_phi_reduction_after_unreachable_block():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
join = IRBasicBlock(IRLabel("join"), fn)
fn.append_basic_block(join)

op = bb.append_instruction("store", 1)
true = IRLiteral(1)
bb.append_instruction("jnz", true, br1.label, join.label)

op1 = br1.append_instruction("store", 2)

br1.append_instruction("jmp", join.label)

join.append_instruction("phi", bb.label, op, br1.label, op1)
join.append_instruction("stop")

ac = IRAnalysesCache(fn)
SCCP(ac, fn).run_pass()

assert join.instructions[0].opcode == "store", join.instructions[0]
assert join.instructions[0].operands == [op1]

assert join.instructions[1].opcode == "stop"
49 changes: 49 additions & 0 deletions tests/unit/compiler/venom/test_simplify_cfg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.basicblock import IRBasicBlock, IRLabel, IRLiteral
from vyper.venom.context import IRContext
from vyper.venom.passes.sccp import SCCP
from vyper.venom.passes.simplify_cfg import SimplifyCFGPass


def test_phi_reduction_after_block_pruning():
ctx = IRContext()
fn = ctx.create_function("_global")

bb = fn.get_basic_block()

br1 = IRBasicBlock(IRLabel("then"), fn)
fn.append_basic_block(br1)
br2 = IRBasicBlock(IRLabel("else"), fn)
fn.append_basic_block(br2)

join = IRBasicBlock(IRLabel("join"), fn)
fn.append_basic_block(join)

true = IRLiteral(1)
bb.append_instruction("jnz", true, br1.label, br2.label)

op1 = br1.append_instruction("store", 1)
op2 = br2.append_instruction("store", 2)

br1.append_instruction("jmp", join.label)
br2.append_instruction("jmp", join.label)

join.append_instruction("phi", br1.label, op1, br2.label, op2)
join.append_instruction("stop")

ac = IRAnalysesCache(fn)
SCCP(ac, fn).run_pass()
SimplifyCFGPass(ac, fn).run_pass()

bbs = list(fn.get_basic_blocks())

assert len(bbs) == 1
final_bb = bbs[0]

inst0, inst1, inst2 = final_bb.instructions

assert inst0.opcode == "store"
assert inst0.operands == [IRLiteral(1)]
assert inst1.opcode == "store"
assert inst1.operands == [inst0.output]
assert inst2.opcode == "stop"
7 changes: 5 additions & 2 deletions vyper/semantics/analysis/constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,11 @@ def visit_Compare(self, node):
raise UnfoldableNode(
f"Invalid literal types for {node.op.description} comparison", node
)

value = node.op._op(left.value, right.value)
lvalue, rvalue = left.value, right.value
if isinstance(left, vy_ast.Hex):
# Hex values are str, convert to be case-unsensitive.
lvalue, rvalue = lvalue.lower(), rvalue.lower()
value = node.op._op(lvalue, rvalue)
return vy_ast.NameConstant.from_node(node, value=value)

def visit_List(self, node) -> vy_ast.ExprNode:
Expand Down
18 changes: 10 additions & 8 deletions vyper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ class OrderedSet(Generic[_T]):
"""

def __init__(self, iterable=None):
self._data = dict()
if iterable is not None:
self.update(iterable)
if iterable is None:
self._data = dict()
else:
self._data = dict.fromkeys(iterable)

def __repr__(self):
keys = ", ".join(repr(k) for k in self)
Expand Down Expand Up @@ -57,6 +58,7 @@ def pop(self):
def add(self, item: _T) -> None:
self._data[item] = None

# NOTE to refactor: duplicate of self.update()
def addmany(self, iterable):
for item in iterable:
self._data[item] = None
Expand Down Expand Up @@ -109,11 +111,11 @@ def intersection(cls, *sets):
if len(sets) == 0:
raise ValueError("undefined: intersection of no sets")

ret = sets[0].copy()
for e in sets[0]:
if any(e not in s for s in sets[1:]):
ret.remove(e)
return ret
tmp = sets[0]._data.keys()
for s in sets[1:]:
tmp &= s._data.keys()

return cls(tmp)


class StringEnum(enum.Enum):
Expand Down
15 changes: 0 additions & 15 deletions vyper/venom/analysis/dup_requirements.py

This file was deleted.

2 changes: 0 additions & 2 deletions vyper/venom/basicblock.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ class IRInstruction:
output: Optional[IROperand]
# set of live variables at this instruction
liveness: OrderedSet[IRVariable]
dup_requirements: OrderedSet[IRVariable]
parent: "IRBasicBlock"
fence_id: int
annotation: Optional[str]
Expand All @@ -228,7 +227,6 @@ def __init__(
self.operands = list(operands) # in case we get an iterator
self.output = output
self.liveness = OrderedSet()
self.dup_requirements = OrderedSet()
self.fence_id = -1
self.annotation = None
self.ast_source = None
Expand Down
38 changes: 34 additions & 4 deletions vyper/venom/passes/sccp/sccp.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ class SCCP(IRPass):
uses: dict[IRVariable, OrderedSet[IRInstruction]]
lattice: Lattice
work_list: list[WorkListItem]
cfg_dirty: bool
cfg_in_exec: dict[IRBasicBlock, OrderedSet[IRBasicBlock]]

cfg_dirty: bool

def __init__(self, analyses_cache: IRAnalysesCache, function: IRFunction):
super().__init__(analyses_cache, function)
self.lattice = {}
Expand All @@ -72,9 +73,9 @@ def run_pass(self):
self._calculate_sccp(self.fn.entry)
self._propagate_constants()

# self._propagate_variables()

self.analyses_cache.invalidate_analysis(CFGAnalysis)
if self.cfg_dirty:
self.analyses_cache.force_analysis(CFGAnalysis)
self._fix_phi_nodes()

def _calculate_sccp(self, entry: IRBasicBlock):
"""
Expand Down Expand Up @@ -304,6 +305,7 @@ def _replace_constants(self, inst: IRInstruction):
target = inst.operands[1]
inst.opcode = "jmp"
inst.operands = [target]

self.cfg_dirty = True

elif inst.opcode in ("assert", "assert_unreachable"):
Expand All @@ -329,6 +331,34 @@ def _replace_constants(self, inst: IRInstruction):
if isinstance(lat, IRLiteral):
inst.operands[i] = lat

def _fix_phi_nodes(self):
# fix basic blocks whose cfg in was changed
# maybe this should really be done in _visit_phi
needs_sort = False

for bb in self.fn.get_basic_blocks():
cfg_in_labels = OrderedSet(in_bb.label for in_bb in bb.cfg_in)

for inst in bb.instructions:
if inst.opcode != "phi":
break
needs_sort |= self._fix_phi_inst(inst, cfg_in_labels)

# move phi instructions to the top of the block
if needs_sort:
bb.instructions.sort(key=lambda inst: inst.opcode != "phi")

def _fix_phi_inst(self, inst: IRInstruction, cfg_in_labels: OrderedSet):
operands = [op for label, op in inst.phi_operands if label in cfg_in_labels]

if len(operands) != 1:
return False

assert inst.output is not None
inst.opcode = "store"
inst.operands = operands
return True


def _meet(x: LatticeItem, y: LatticeItem) -> LatticeItem:
if x == LatticeEnum.TOP:
Expand Down
16 changes: 7 additions & 9 deletions vyper/venom/passes/simplify_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,21 @@ class SimplifyCFGPass(IRPass):
visited: OrderedSet

def _merge_blocks(self, a: IRBasicBlock, b: IRBasicBlock):
a.instructions.pop()
a.instructions.pop() # pop terminating instruction
for inst in b.instructions:
assert inst.opcode != "phi", "Not implemented yet"
if inst.opcode == "phi":
a.instructions.insert(0, inst)
else:
inst.parent = a
a.instructions.append(inst)
assert inst.opcode != "phi", f"Instruction should never be phi {b}"
inst.parent = a
a.instructions.append(inst)

# Update CFG
a.cfg_out = b.cfg_out
if len(b.cfg_out) > 0:
next_bb = b.cfg_out.first()

for next_bb in a.cfg_out:
next_bb.remove_cfg_in(b)
next_bb.add_cfg_in(a)

for inst in next_bb.instructions:
# assume phi instructions are at beginning of bb
if inst.opcode != "phi":
break
inst.operands[inst.operands.index(b.label)] = a.label
Expand Down
24 changes: 14 additions & 10 deletions vyper/venom/venom_to_assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
)
from vyper.utils import MemoryPositions, OrderedSet
from vyper.venom.analysis.analysis import IRAnalysesCache
from vyper.venom.analysis.dup_requirements import DupRequirementsAnalysis
from vyper.venom.analysis.liveness import LivenessAnalysis
from vyper.venom.basicblock import (
IRBasicBlock,
Expand Down Expand Up @@ -153,7 +152,6 @@ def generate_evm(self, no_optimize: bool = False) -> list[str]:

NormalizationPass(ac, fn).run_pass()
self.liveness_analysis = ac.request_analysis(LivenessAnalysis)
ac.request_analysis(DupRequirementsAnalysis)

assert fn.normalized, "Non-normalized CFG!"

Expand Down Expand Up @@ -231,7 +229,12 @@ def _stack_reorder(
return cost

def _emit_input_operands(
self, assembly: list, inst: IRInstruction, ops: list[IROperand], stack: StackModel
self,
assembly: list,
inst: IRInstruction,
ops: list[IROperand],
stack: StackModel,
next_liveness: OrderedSet[IRVariable],
) -> None:
# PRE: we already have all the items on the stack that have
# been scheduled to be killed. now it's just a matter of emitting
Expand All @@ -241,7 +244,7 @@ def _emit_input_operands(
# it with something that is wanted
if ops and stack.height > 0 and stack.peek(0) not in ops:
for op in ops:
if isinstance(op, IRVariable) and op not in inst.dup_requirements:
if isinstance(op, IRVariable) and op not in next_liveness:
self.swap_op(assembly, stack, op)
break

Expand All @@ -264,7 +267,7 @@ def _emit_input_operands(
stack.push(op)
continue

if op in inst.dup_requirements and op not in emitted_ops:
if op in next_liveness and op not in emitted_ops:
self.dup_op(assembly, stack, op)

if op in emitted_ops:
Expand All @@ -288,7 +291,9 @@ def _generate_evm_for_basicblock_r(
all_insts = sorted(basicblock.instructions, key=lambda x: x.opcode != "param")

for i, inst in enumerate(all_insts):
next_liveness = all_insts[i + 1].liveness if i + 1 < len(all_insts) else OrderedSet()
next_liveness = (
all_insts[i + 1].liveness if i + 1 < len(all_insts) else basicblock.out_vars
)

asm.extend(self._generate_evm_for_instruction(inst, stack, next_liveness))

Expand Down Expand Up @@ -327,10 +332,9 @@ def clean_stack_from_cfg_in(
self.pop(asm, stack)

def _generate_evm_for_instruction(
self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet = None
self, inst: IRInstruction, stack: StackModel, next_liveness: OrderedSet
) -> list[str]:
assembly: list[str | int] = []
next_liveness = next_liveness or OrderedSet()
opcode = inst.opcode

#
Expand Down Expand Up @@ -375,7 +379,7 @@ def _generate_evm_for_instruction(
# example, for `%56 = %label1 %13 %label2 %14`, we will
# find an instance of %13 *or* %14 in the stack and replace it with %56.
to_be_replaced = stack.peek(depth)
if to_be_replaced in inst.dup_requirements:
if to_be_replaced in next_liveness:
# %13/%14 is still live(!), so we make a copy of it
self.dup(assembly, stack, depth)
stack.poke(0, ret)
Expand All @@ -390,7 +394,7 @@ def _generate_evm_for_instruction(
return apply_line_numbers(inst, assembly)

# Step 2: Emit instruction's input operands
self._emit_input_operands(assembly, inst, operands, stack)
self._emit_input_operands(assembly, inst, operands, stack, next_liveness)

# Step 3: Reorder stack before join points
if opcode == "jmp":
Expand Down

0 comments on commit a3585ed

Please sign in to comment.