Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for prank(sender, origin) and startPrank(sender, origin) cheatcodes #336

Merged
merged 24 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f1e803b
WIP: beginning of startPrank(address sender, address origin)
karmacoma-eth Jul 31, 2024
022e2df
add tests/test_prank.py
karmacoma-eth Aug 1, 2024
c1d5f9e
WIP: rework Prank class
karmacoma-eth Aug 1, 2024
15ffcdb
add an origin field to Exec and wire pranks to it
karmacoma-eth Aug 1, 2024
73f4478
finish wiring up new prank cheatcodes
karmacoma-eth Aug 2, 2024
d62c1eb
Merge branch 'main' into start-prank-origin
karmacoma-eth Aug 2, 2024
1c3373b
simplify prank Target: recordCaller() -> reset()
karmacoma-eth Aug 2, 2024
9939ca5
update tests/lib/multicaller to v1.3.2
karmacoma-eth Aug 2, 2024
3ec4901
remove shallow from .gitmodules
karmacoma-eth Aug 2, 2024
324f77d
delete multicaller
karmacoma-eth Aug 2, 2024
b6c1e48
replace multicaller submodule with a snapshot of the file we need
karmacoma-eth Aug 2, 2024
c051be4
update tests/lib/openzeppelin-contracts@v5.0.2
karmacoma-eth Aug 2, 2024
8bb0d6f
test.yml: recursively checkout submodules
karmacoma-eth Aug 2, 2024
707b9e1
test.yml: add --debug to halmos options
karmacoma-eth Aug 2, 2024
656c359
test.yml: get back to a single pytest worker
karmacoma-eth Aug 2, 2024
de64d95
Revert "test.yml: recursively checkout submodules"
karmacoma-eth Aug 2, 2024
09b6438
Merge branch 'main' into start-prank-origin
karmacoma-eth Aug 2, 2024
1689032
add Prank test with nested contexts
karmacoma-eth Aug 3, 2024
09b8066
add more Prank tests with nested contexts
karmacoma-eth Aug 3, 2024
df2c7a3
wire Prank inside CallContext rather than Exec
karmacoma-eth Aug 9, 2024
1ea690c
add a startPrank in constructor test
karmacoma-eth Aug 9, 2024
c1c64e6
add test_prank_in_context
karmacoma-eth Aug 9, 2024
17eba86
less convoluted code in prank/startPrank
karmacoma-eth Aug 13, 2024
24f1493
Merge branch 'main' into start-prank-origin
karmacoma-eth Aug 13, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ jobs:
run: python -m pip install -e .

- name: Run pytest
run: pytest -n 4 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="-st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }}
run: pytest -n 1 -v -k "not long and not ffi" --ignore=tests/lib --halmos-options="--debug -st ${{ matrix.parallel }} --storage-layout ${{ matrix.storage-layout }} --solver-timeout-assertion 0 ${{ inputs.halmos-options }}" ${{ inputs.pytest-options }}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pytest -n 1 is what fixed CI, it seems that the issue was running multiple forge commands in parallel would result in race conditions trying to clone submodules or install solc

9 changes: 0 additions & 9 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
[submodule "tests/lib/forge-std"]
path = tests/lib/forge-std
url = https://github.com/foundry-rs/forge-std
shallow = true
[submodule "tests/lib/halmos-cheatcodes"]
path = tests/lib/halmos-cheatcodes
url = https://github.com/a16z/halmos-cheatcodes
shallow = true
[submodule "tests/lib/openzeppelin-contracts"]
path = tests/lib/openzeppelin-contracts
url = https://github.com/OpenZeppelin/openzeppelin-contracts
shallow = true
[submodule "tests/lib/solmate"]
path = tests/lib/solmate
url = https://github.com/transmissions11/solmate
shallow = true
[submodule "tests/lib/solady"]
path = tests/lib/solady
url = https://github.com/Vectorized/solady
shallow = true
[submodule "tests/lib/multicaller"]
path = tests/lib/multicaller
url = https://github.com/Vectorized/multicaller
shallow = true
2 changes: 1 addition & 1 deletion examples/simple/remappings.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
multicaller/=../../tests/lib/multicaller/src/
multicaller/=src/multicaller/
151 changes: 151 additions & 0 deletions examples/simple/src/multicaller/MulticallerWithSender.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

/// from Vectorized/multicaller@v1.3.2

/**
* @title MulticallerWithSender
* @author vectorized.eth
* @notice Contract that allows for efficient aggregation of multiple calls
* in a single transaction, while "forwarding" the `msg.sender`.
*/
contract MulticallerWithSender {
// =============================================================
// ERRORS
// =============================================================

/**
* @dev The lengths of the input arrays are not the same.
*/
error ArrayLengthsMismatch();

/**
* @dev This function does not support reentrancy.
*/
error Reentrancy();

// =============================================================
// CONSTRUCTOR
// =============================================================

constructor() payable {
assembly {
// Throughout this code, we will abuse returndatasize
// in place of zero anywhere before a call to save a bit of gas.
// We will use storage slot zero to store the caller at
// bits [0..159] and reentrancy guard flag at bit 160.
sstore(returndatasize(), shl(160, 1))
}
}

// =============================================================
// AGGREGATION OPERATIONS
// =============================================================

/**
* @dev Returns the address that called `aggregateWithSender` on this contract.
* The value is always the zero address outside a transaction.
*/
receive() external payable {
assembly {
mstore(returndatasize(), and(sub(shl(160, 1), 1), sload(returndatasize())))
return(returndatasize(), 0x20)
}
}

/**
* @dev Aggregates multiple calls in a single transaction.
* This method will set `sender` to the `msg.sender` temporarily
* for the span of its execution.
* This method does not support reentrancy.
* @param targets An array of addresses to call.
* @param data An array of calldata to forward to the targets.
* @param values How much ETH to forward to each target.
* @return An array of the returndata from each call.
*/
function aggregateWithSender(
address[] calldata targets,
bytes[] calldata data,
uint256[] calldata values
) external payable returns (bytes[] memory) {
assembly {
if iszero(and(eq(targets.length, data.length), eq(data.length, values.length))) {
// Store the function selector of `ArrayLengthsMismatch()`.
mstore(returndatasize(), 0x3b800a46)
// Revert with (offset, size).
revert(0x1c, 0x04)
}

if iszero(and(sload(returndatasize()), shl(160, 1))) {
// Store the function selector of `Reentrancy()`.
mstore(returndatasize(), 0xab143c06)
// Revert with (offset, size).
revert(0x1c, 0x04)
}

mstore(returndatasize(), 0x20) // Store the memory offset of the `results`.
mstore(0x20, data.length) // Store `data.length` into `results`.
// Early return if no data.
if iszero(data.length) { return(returndatasize(), 0x40) }

// Set the sender slot temporarily for the span of this transaction.
sstore(returndatasize(), caller())

let results := 0x40
// Left shift by 5 is equivalent to multiplying by 0x20.
data.length := shl(5, data.length)
// Copy the offsets from calldata into memory.
calldatacopy(results, data.offset, data.length)
// Offset into `results`.
let resultsOffset := data.length
// Pointer to the end of `results`.
// Recycle `data.length` to avoid stack too deep.
data.length := add(results, data.length)

for {} 1 {} {
// The offset of the current bytes in the calldata.
let o := add(data.offset, mload(results))
let memPtr := add(resultsOffset, 0x40)
// Copy the current bytes from calldata to the memory.
calldatacopy(
memPtr,
add(o, 0x20), // The offset of the current bytes' bytes.
calldataload(o) // The length of the current bytes.
)
if iszero(
call(
gas(), // Remaining gas.
calldataload(targets.offset), // Address to call.
calldataload(values.offset), // ETH to send.
memPtr, // Start of input calldata in memory.
calldataload(o), // Size of input calldata.
0x00, // We will use returndatacopy instead.
0x00 // We will use returndatacopy instead.
)
) {
// Bubble up the revert if the call reverts.
returndatacopy(0x00, 0x00, returndatasize())
revert(0x00, returndatasize())
}
// Advance the `targets.offset`.
targets.offset := add(targets.offset, 0x20)
// Advance the `values.offset`.
values.offset := add(values.offset, 0x20)
// Append the current `resultsOffset` into `results`.
mstore(results, resultsOffset)
results := add(results, 0x20)
// Append the returndatasize, and the returndata.
mstore(memPtr, returndatasize())
returndatacopy(add(memPtr, 0x20), 0x00, returndatasize())
// Advance the `resultsOffset` by `returndatasize() + 0x20`,
// rounded up to the next multiple of 0x20.
resultsOffset := and(add(add(resultsOffset, returndatasize()), 0x3f), not(0x1f))
if iszero(lt(results, data.length)) { break }
}
// Restore the `sender` slot.
sstore(0, shl(160, 1))
// Direct return.
return(0x00, add(resultsOffset, 0x40))
}
}
}
1 change: 1 addition & 0 deletions src/halmos/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ def run(
jumpis={},
symbolic=args.symbolic_storage,
prank=Prank(), # prank is reset after setUp()
origin=setup_ex.origin,
#
path=path,
alias=setup_ex.alias.copy(),
Expand Down
131 changes: 95 additions & 36 deletions src/halmos/cheatcodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,58 +75,93 @@ def stringified_bytes_to_bytes(hexstring: str) -> ByteVec:
return ByteVec(ret_bytes)


@dataclass(frozen=True)
class PrankResult:
sender: Address | None = None
origin: Address | None = None

def __bool__(self) -> bool:
"""
True iff either sender or origin is set.
"""
return self.sender is not None or self.origin is not None

def __str__(self) -> str:
return f"{hexify(self.sender)}, {hexify(self.origin)}"


NO_PRANK = PrankResult()


@dataclass
class Prank:
addr: Any # prank address
keep: bool # start / stop prank
"""
A mutable object to store current prank context, one per execution context.

Because it's mutable, it must be copied across contexts.

def __init__(self, addr: Any = None, keep: bool = False) -> None:
if addr is not None:
assert_address(addr)
self.addr = addr
self.keep = keep
Can test for the existence of an active prank with `if prank: ...`

A prank is active if either sender or origin is set.
Technically supports pranking origin but not sender, which is not
possible with the current cheatcodes:
- prank(address) sets sender
- prank(address, address) sets both sender and origin
"""

active: PrankResult = NO_PRANK # active prank context
keep: bool = False # start / stop prank

def __bool__(self) -> bool:
"""
True iff either sender or origin is set.
"""
return bool(self.active)

def __str__(self) -> str:
if self.addr:
if self.keep:
return f"startPrank({str(self.addr)})"
else:
return f"prank({str(self.addr)})"
else:
return "None"
if not self:
return "no active prank"

fn_name = "startPrank" if self.keep else "prank"
return f"{fn_name}({str(self.active)})"

def lookup(self, to: Address) -> PrankResult:
"""
If `to` is an eligible prank destination, return the active prank context.

If `keep` is False, this resets the prank context.
"""

def lookup(self, this: Any, to: Any) -> Any:
assert_address(this)
assert_address(to)
caller = this
if (
self.addr is not None
self
and not eq(to, hevm_cheat_code.address)
and not eq(to, halmos_cheat_code.address)
):
caller = self.addr
result = self.active
if not self.keep:
self.addr = None
return caller
self.stopPrank()
return result

def prank(self, addr: Any) -> bool:
assert_address(addr)
if self.addr is not None:
return NO_PRANK

def prank(self, sender: Address, origin: Address | None = None) -> bool:
assert_address(sender)
if self.active:
return False
self.addr = addr

self.active = PrankResult(sender=sender, origin=origin)
self.keep = False
return True

def startPrank(self, addr: Any) -> bool:
assert_address(addr)
if self.addr is not None:
return False
self.addr = addr
self.keep = True
return True
def startPrank(self, sender: Address, origin: Address | None = None) -> bool:
result = self.prank(sender, origin)
self.keep = result if result else self.keep
Copy link
Collaborator

@daejunpark daejunpark Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is accurate but a bit convoluted. For better code readability, I'd suggest having prank() take an additional argument (defaulting to False if not provided) that is assigned to self.keep. Then, have startPrank() call prank(..., keep=True).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah you're right, in 17eba86

return result

def stopPrank(self) -> bool:
# stopPrank is allowed to call even when no active prank exists
self.addr = None
# stopPrank calls are allowed even when no active prank exists
self.active = NO_PRANK
self.keep = False
return True

Expand Down Expand Up @@ -282,9 +317,15 @@ class hevm_cheat_code:
# bytes4(keccak256("prank(address)"))
prank_sig: int = 0xCA669FA7

# bytes4(keccak256("prank(address,address)"))
prank_addr_addr_sig: int = 0x47E50CCE

# bytes4(keccak256("startPrank(address)"))
start_prank_sig: int = 0x06447D56

# bytes4(keccak256("startPrank(address,address)"))
start_prank_addr_addr_sig: int = 0x45B56078

# bytes4(keccak256("stopPrank()"))
stop_prank_sig: int = 0x90C5013B

Expand Down Expand Up @@ -381,8 +422,17 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]:

# vm.prank(address)
elif funsig == hevm_cheat_code.prank_sig:
address = uint160(arg.get_word(4))
result = ex.prank.prank(address)
sender = uint160(arg.get_word(4))
result = ex.prank.prank(sender)
if not result:
raise HalmosException("You have an active prank already.")
return ret

# vm.prank(address sender, address origin)
elif funsig == hevm_cheat_code.prank_addr_addr_sig:
sender = uint160(arg.get_word(4))
origin = uint160(arg.get_word(36))
result = ex.prank.prank(sender, origin)
if not result:
raise HalmosException("You have an active prank already.")
return ret
Expand All @@ -395,6 +445,15 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]:
raise HalmosException("You have an active prank already.")
return ret

# vm.startPrank(address sender, address origin)
elif funsig == hevm_cheat_code.start_prank_addr_addr_sig:
sender = uint160(arg.get_word(4))
origin = uint160(arg.get_word(36))
result = ex.prank.startPrank(sender, origin)
if not result:
raise HalmosException("You have an active prank already.")
return ret

# vm.stopPrank()
elif funsig == hevm_cheat_code.stop_prank_sig:
ex.prank.stopPrank()
Expand Down
Loading
Loading