From f5733f022441bd01e661f703c02c7574b20d4f3e Mon Sep 17 00:00:00 2001 From: Daejun Park Date: Wed, 6 Dec 2023 23:40:47 -0800 Subject: [PATCH] fix: normalize extracts for overflow checks (#225) --- src/halmos/sevm.py | 116 +++++++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/src/halmos/sevm.py b/src/halmos/sevm.py index cab9e49f..7dfc75be 100644 --- a/src/halmos/sevm.py +++ b/src/halmos/sevm.py @@ -258,6 +258,58 @@ def wstore_bytes( mem[loc + i] = arr[i] +def normalize(expr: Any) -> Any: + # Concat(Extract(255, 8, op(x, y)), op(Extract(7, 0, x), Extract(7, 0, y))) => op(x, y) + def normalize_extract(arg0, arg1): + if ( + arg0.decl().name() == "extract" + and arg0.num_args() == 1 + and arg0.params() == [255, 8] + ): + target = arg0.arg(0) # op(x, y) + + # this form triggers the partial inward-propagation of extracts in simplify() + # that is, `Extract(7, 0, op(x, y))` => `op(Extract(7, 0, x), Extract(7, 0, y))`, followed by further simplification + target_equivalent = Concat(Extract(255, 8, target), Extract(7, 0, target)) + + given = Concat(arg0, arg1) + + # since target_equivalent and given may not be structurally equal, we compare their fully simplified forms + if eq(simplify(given), simplify(target_equivalent)): + # here we have: given == target_equivalent == target + return target + + return None + + if expr.decl().name() == "concat" and expr.num_args() >= 2: + new_args = [] + + i = 0 + n = expr.num_args() + + # apply normalize_extract for each pair of adjacent arguments + while i < n - 1: + arg0 = expr.arg(i) + arg1 = expr.arg(i + 1) + + arg0_arg1 = normalize_extract(arg0, arg1) + + if arg0_arg1 is None: # not simplified + new_args.append(arg0) + i += 1 + else: # simplified into a single term + new_args.append(arg0_arg1) + i += 2 + + # handle the last element + if i == n - 1: + new_args.append(expr.arg(i)) + + return concat(new_args) + + return expr + + @dataclass(frozen=True) class EventLog: """ @@ -995,59 +1047,7 @@ def resolve_libs(self, creation_hexcode, deployed_hexcode, lib_references) -> st class Storage: - @classmethod - def normalize(cls, expr: Any) -> Any: - # Concat(Extract(255, 8, bvadd(x, y)), bvadd(Extract(7, 0, x), Extract(7, 0, y))) => x + y - def normalize_extract(arg0, arg1): - if ( - arg0.decl().name() == "extract" - and arg0.num_args() == 1 - and arg0.params() == [255, 8] - ): - target = arg0.arg(0) # bvadd(x, y) - - # this form triggers the partial inward-propagation of extracts in simplify() - # that is, `Extract(7, 0, bvadd(x, y))` => `bvadd(Extract(7, 0, x), Extract(7, 0, y))`, followed by further simplification - target_equivalent = Concat( - Extract(255, 8, target), Extract(7, 0, target) - ) - - given = Concat(arg0, arg1) - - # since target_equivalent and given may not be structurally equal, we compare their fully simplified forms - if eq(simplify(given), simplify(target_equivalent)): - # here we have: given == target_equivalent == target - return target - - return None - - if expr.decl().name() == "concat" and expr.num_args() >= 2: - new_args = [] - - i = 0 - n = expr.num_args() - - # apply normalize_extract for each pair of adjacent arguments - while i < n - 1: - arg0 = expr.arg(i) - arg1 = expr.arg(i + 1) - - arg0_arg1 = normalize_extract(arg0, arg1) - - if arg0_arg1 is None: # not simplified - new_args.append(arg0) - i += 1 - else: # simplified into a single term - new_args.append(arg0_arg1) - i += 2 - - # handle the last element - if i == n - 1: - new_args.append(expr.arg(i)) - - return concat(new_args) - - return expr + pass class SolidityStorage(Storage): @@ -1133,7 +1133,7 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: @classmethod def decode(cls, loc: Any) -> Any: - loc = cls.normalize(loc) + loc = normalize(loc) # m[k] : hash(k.m) if loc.decl().name() == "sha3_512": args = loc.arg(0) @@ -1146,7 +1146,7 @@ def decode(cls, loc: Any) -> Any: return cls.decode(base) + (con(0),) # m[k] : hash(k.m) where |k| != 256-bit elif loc.decl().name().startswith("sha3_"): - sha3_input = cls.normalize(loc.arg(0)) + sha3_input = normalize(loc.arg(0)) if sha3_input.decl().name() == "concat" and sha3_input.num_args() == 2: offset = simplify(sha3_input.arg(0)) base = simplify(sha3_input.arg(1)) @@ -1235,14 +1235,14 @@ def store(cls, ex: Exec, addr: Any, loc: Any, val: Any) -> None: @classmethod def decode(cls, loc: Any) -> Any: - loc = cls.normalize(loc) + loc = normalize(loc) if loc.decl().name() == "sha3_512": # hash(hi,lo), recursively args = loc.arg(0) hi = cls.decode(simplify(Extract(511, 256, args))) lo = cls.decode(simplify(Extract(255, 0, args))) return cls.simple_hash(Concat(hi, lo)) elif loc.decl().name().startswith("sha3_"): - sha3_input = cls.normalize(loc.arg(0)) + sha3_input = normalize(loc.arg(0)) if sha3_input.decl().name() == "concat": decoded_sha3_input_args = [ cls.decode(sha3_input.arg(i)) for i in range(sha3_input.num_args()) @@ -1386,6 +1386,8 @@ def bitsize(w: Word) -> int: return 256 - w.arg(0).size() return 256 + w1 = normalize(w1) + if w1.decl().name() == "bvmul" and w1.num_args() == 2: x = w1.arg(0) y = w1.arg(1)