Skip to content

Commit

Permalink
fix: normalize extracts for overflow checks (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
daejunpark authored Dec 7, 2023
1 parent b9bc79e commit f5733f0
Showing 1 changed file with 59 additions and 57 deletions.
116 changes: 59 additions & 57 deletions src/halmos/sevm.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,58 @@ def wstore_bytes(
mem[loc + i] = arr[i]


def normalize(expr: Any) -> Any:
# Concat(Extract(255, 8, op(x, y)), op(Extract(7, 0, x), Extract(7, 0, y))) => op(x, y)
def normalize_extract(arg0, arg1):
if (
arg0.decl().name() == "extract"
and arg0.num_args() == 1
and arg0.params() == [255, 8]
):
target = arg0.arg(0) # op(x, y)

# this form triggers the partial inward-propagation of extracts in simplify()
# that is, `Extract(7, 0, op(x, y))` => `op(Extract(7, 0, x), Extract(7, 0, y))`, followed by further simplification
target_equivalent = Concat(Extract(255, 8, target), Extract(7, 0, target))

given = Concat(arg0, arg1)

# since target_equivalent and given may not be structurally equal, we compare their fully simplified forms
if eq(simplify(given), simplify(target_equivalent)):
# here we have: given == target_equivalent == target
return target

return None

if expr.decl().name() == "concat" and expr.num_args() >= 2:
new_args = []

i = 0
n = expr.num_args()

# apply normalize_extract for each pair of adjacent arguments
while i < n - 1:
arg0 = expr.arg(i)
arg1 = expr.arg(i + 1)

arg0_arg1 = normalize_extract(arg0, arg1)

if arg0_arg1 is None: # not simplified
new_args.append(arg0)
i += 1
else: # simplified into a single term
new_args.append(arg0_arg1)
i += 2

# handle the last element
if i == n - 1:
new_args.append(expr.arg(i))

return concat(new_args)

return expr


@dataclass(frozen=True)
class EventLog:
"""
Expand Down Expand Up @@ -995,59 +1047,7 @@ def resolve_libs(self, creation_hexcode, deployed_hexcode, lib_references) -> st


class Storage:
@classmethod
def normalize(cls, expr: Any) -> Any:
# Concat(Extract(255, 8, bvadd(x, y)), bvadd(Extract(7, 0, x), Extract(7, 0, y))) => x + y
def normalize_extract(arg0, arg1):
if (
arg0.decl().name() == "extract"
and arg0.num_args() == 1
and arg0.params() == [255, 8]
):
target = arg0.arg(0) # bvadd(x, y)

# this form triggers the partial inward-propagation of extracts in simplify()
# that is, `Extract(7, 0, bvadd(x, y))` => `bvadd(Extract(7, 0, x), Extract(7, 0, y))`, followed by further simplification
target_equivalent = Concat(
Extract(255, 8, target), Extract(7, 0, target)
)

given = Concat(arg0, arg1)

# since target_equivalent and given may not be structurally equal, we compare their fully simplified forms
if eq(simplify(given), simplify(target_equivalent)):
# here we have: given == target_equivalent == target
return target

return None

if expr.decl().name() == "concat" and expr.num_args() >= 2:
new_args = []

i = 0
n = expr.num_args()

# apply normalize_extract for each pair of adjacent arguments
while i < n - 1:
arg0 = expr.arg(i)
arg1 = expr.arg(i + 1)

arg0_arg1 = normalize_extract(arg0, arg1)

if arg0_arg1 is None: # not simplified
new_args.append(arg0)
i += 1
else: # simplified into a single term
new_args.append(arg0_arg1)
i += 2

# handle the last element
if i == n - 1:
new_args.append(expr.arg(i))

return concat(new_args)

return expr
pass


class SolidityStorage(Storage):
Expand Down Expand Up @@ -1133,7 +1133,7 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:

@classmethod
def decode(cls, loc: Any) -> Any:
loc = cls.normalize(loc)
loc = normalize(loc)
# m[k] : hash(k.m)
if loc.decl().name() == "sha3_512":
args = loc.arg(0)
Expand All @@ -1146,7 +1146,7 @@ def decode(cls, loc: Any) -> Any:
return cls.decode(base) + (con(0),)
# m[k] : hash(k.m) where |k| != 256-bit
elif loc.decl().name().startswith("sha3_"):
sha3_input = cls.normalize(loc.arg(0))
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2:
offset = simplify(sha3_input.arg(0))
base = simplify(sha3_input.arg(1))
Expand Down Expand Up @@ -1235,14 +1235,14 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None:

@classmethod
def decode(cls, loc: Any) -> Any:
loc = cls.normalize(loc)
loc = normalize(loc)
if loc.decl().name() == "sha3_512": # hash(hi,lo), recursively
args = loc.arg(0)
hi = cls.decode(simplify(Extract(511, 256, args)))
lo = cls.decode(simplify(Extract(255, 0, args)))
return cls.simple_hash(Concat(hi, lo))
elif loc.decl().name().startswith("sha3_"):
sha3_input = cls.normalize(loc.arg(0))
sha3_input = normalize(loc.arg(0))
if sha3_input.decl().name() == "concat":
decoded_sha3_input_args = [
cls.decode(sha3_input.arg(i)) for i in range(sha3_input.num_args())
Expand Down Expand Up @@ -1386,6 +1386,8 @@ def bitsize(w: Word) -> int:
return 256 - w.arg(0).size()
return 256

w1 = normalize(w1)

if w1.decl().name() == "bvmul" and w1.num_args() == 2:
x = w1.arg(0)
y = w1.arg(1)
Expand Down

0 comments on commit f5733f0

Please sign in to comment.