-
Notifications
You must be signed in to change notification settings - Fork 71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add support for prank(sender, origin) and startPrank(sender, origin) cheatcodes #336
Changes from 19 commits
f1e803b
022e2df
c1d5f9e
15ffcdb
73f4478
d62c1eb
1c3373b
9939ca5
3ec4901
324f77d
b6c1e48
c051be4
8bb0d6f
707b9e1
656c359
de64d95
09b6438
1689032
09b8066
df2c7a3
1ea690c
c1c64e6
17eba86
24f1493
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,15 @@ | ||
[submodule "tests/lib/forge-std"] | ||
path = tests/lib/forge-std | ||
url = https://github.com/foundry-rs/forge-std | ||
shallow = true | ||
[submodule "tests/lib/halmos-cheatcodes"] | ||
path = tests/lib/halmos-cheatcodes | ||
url = https://github.com/a16z/halmos-cheatcodes | ||
shallow = true | ||
[submodule "tests/lib/openzeppelin-contracts"] | ||
path = tests/lib/openzeppelin-contracts | ||
url = https://github.com/OpenZeppelin/openzeppelin-contracts | ||
shallow = true | ||
[submodule "tests/lib/solmate"] | ||
path = tests/lib/solmate | ||
url = https://github.com/transmissions11/solmate | ||
shallow = true | ||
[submodule "tests/lib/solady"] | ||
path = tests/lib/solady | ||
url = https://github.com/Vectorized/solady | ||
shallow = true | ||
[submodule "tests/lib/multicaller"] | ||
path = tests/lib/multicaller | ||
url = https://github.com/Vectorized/multicaller | ||
shallow = true |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
multicaller/=../../tests/lib/multicaller/src/ | ||
multicaller/=src/multicaller/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
// SPDX-License-Identifier: MIT | ||
pragma solidity ^0.8.4; | ||
|
||
/// from Vectorized/multicaller@v1.3.2 | ||
|
||
/** | ||
* @title MulticallerWithSender | ||
* @author vectorized.eth | ||
* @notice Contract that allows for efficient aggregation of multiple calls | ||
* in a single transaction, while "forwarding" the `msg.sender`. | ||
*/ | ||
contract MulticallerWithSender { | ||
// ============================================================= | ||
// ERRORS | ||
// ============================================================= | ||
|
||
/** | ||
* @dev The lengths of the input arrays are not the same. | ||
*/ | ||
error ArrayLengthsMismatch(); | ||
|
||
/** | ||
* @dev This function does not support reentrancy. | ||
*/ | ||
error Reentrancy(); | ||
|
||
// ============================================================= | ||
// CONSTRUCTOR | ||
// ============================================================= | ||
|
||
constructor() payable { | ||
assembly { | ||
// Throughout this code, we will abuse returndatasize | ||
// in place of zero anywhere before a call to save a bit of gas. | ||
// We will use storage slot zero to store the caller at | ||
// bits [0..159] and reentrancy guard flag at bit 160. | ||
sstore(returndatasize(), shl(160, 1)) | ||
} | ||
} | ||
|
||
// ============================================================= | ||
// AGGREGATION OPERATIONS | ||
// ============================================================= | ||
|
||
/** | ||
* @dev Returns the address that called `aggregateWithSender` on this contract. | ||
* The value is always the zero address outside a transaction. | ||
*/ | ||
receive() external payable { | ||
assembly { | ||
mstore(returndatasize(), and(sub(shl(160, 1), 1), sload(returndatasize()))) | ||
return(returndatasize(), 0x20) | ||
} | ||
} | ||
|
||
/** | ||
* @dev Aggregates multiple calls in a single transaction. | ||
* This method will set `sender` to the `msg.sender` temporarily | ||
* for the span of its execution. | ||
* This method does not support reentrancy. | ||
* @param targets An array of addresses to call. | ||
* @param data An array of calldata to forward to the targets. | ||
* @param values How much ETH to forward to each target. | ||
* @return An array of the returndata from each call. | ||
*/ | ||
function aggregateWithSender( | ||
address[] calldata targets, | ||
bytes[] calldata data, | ||
uint256[] calldata values | ||
) external payable returns (bytes[] memory) { | ||
assembly { | ||
if iszero(and(eq(targets.length, data.length), eq(data.length, values.length))) { | ||
// Store the function selector of `ArrayLengthsMismatch()`. | ||
mstore(returndatasize(), 0x3b800a46) | ||
// Revert with (offset, size). | ||
revert(0x1c, 0x04) | ||
} | ||
|
||
if iszero(and(sload(returndatasize()), shl(160, 1))) { | ||
// Store the function selector of `Reentrancy()`. | ||
mstore(returndatasize(), 0xab143c06) | ||
// Revert with (offset, size). | ||
revert(0x1c, 0x04) | ||
} | ||
|
||
mstore(returndatasize(), 0x20) // Store the memory offset of the `results`. | ||
mstore(0x20, data.length) // Store `data.length` into `results`. | ||
// Early return if no data. | ||
if iszero(data.length) { return(returndatasize(), 0x40) } | ||
|
||
// Set the sender slot temporarily for the span of this transaction. | ||
sstore(returndatasize(), caller()) | ||
|
||
let results := 0x40 | ||
// Left shift by 5 is equivalent to multiplying by 0x20. | ||
data.length := shl(5, data.length) | ||
// Copy the offsets from calldata into memory. | ||
calldatacopy(results, data.offset, data.length) | ||
// Offset into `results`. | ||
let resultsOffset := data.length | ||
// Pointer to the end of `results`. | ||
// Recycle `data.length` to avoid stack too deep. | ||
data.length := add(results, data.length) | ||
|
||
for {} 1 {} { | ||
// The offset of the current bytes in the calldata. | ||
let o := add(data.offset, mload(results)) | ||
let memPtr := add(resultsOffset, 0x40) | ||
// Copy the current bytes from calldata to the memory. | ||
calldatacopy( | ||
memPtr, | ||
add(o, 0x20), // The offset of the current bytes' bytes. | ||
calldataload(o) // The length of the current bytes. | ||
) | ||
if iszero( | ||
call( | ||
gas(), // Remaining gas. | ||
calldataload(targets.offset), // Address to call. | ||
calldataload(values.offset), // ETH to send. | ||
memPtr, // Start of input calldata in memory. | ||
calldataload(o), // Size of input calldata. | ||
0x00, // We will use returndatacopy instead. | ||
0x00 // We will use returndatacopy instead. | ||
) | ||
) { | ||
// Bubble up the revert if the call reverts. | ||
returndatacopy(0x00, 0x00, returndatasize()) | ||
revert(0x00, returndatasize()) | ||
} | ||
// Advance the `targets.offset`. | ||
targets.offset := add(targets.offset, 0x20) | ||
// Advance the `values.offset`. | ||
values.offset := add(values.offset, 0x20) | ||
// Append the current `resultsOffset` into `results`. | ||
mstore(results, resultsOffset) | ||
results := add(results, 0x20) | ||
// Append the returndatasize, and the returndata. | ||
mstore(memPtr, returndatasize()) | ||
returndatacopy(add(memPtr, 0x20), 0x00, returndatasize()) | ||
// Advance the `resultsOffset` by `returndatasize() + 0x20`, | ||
// rounded up to the next multiple of 0x20. | ||
resultsOffset := and(add(add(resultsOffset, returndatasize()), 0x3f), not(0x1f)) | ||
if iszero(lt(results, data.length)) { break } | ||
} | ||
// Restore the `sender` slot. | ||
sstore(0, shl(160, 1)) | ||
// Direct return. | ||
return(0x00, add(resultsOffset, 0x40)) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,58 +75,93 @@ def stringified_bytes_to_bytes(hexstring: str) -> ByteVec: | |
return ByteVec(ret_bytes) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class PrankResult: | ||
sender: Address | None = None | ||
origin: Address | None = None | ||
|
||
def __bool__(self) -> bool: | ||
""" | ||
True iff either sender or origin is set. | ||
""" | ||
return self.sender is not None or self.origin is not None | ||
|
||
def __str__(self) -> str: | ||
return f"{hexify(self.sender)}, {hexify(self.origin)}" | ||
|
||
|
||
NO_PRANK = PrankResult() | ||
|
||
|
||
@dataclass | ||
class Prank: | ||
addr: Any # prank address | ||
keep: bool # start / stop prank | ||
""" | ||
A mutable object to store current prank context, one per execution context. | ||
|
||
Because it's mutable, it must be copied across contexts. | ||
|
||
def __init__(self, addr: Any = None, keep: bool = False) -> None: | ||
if addr is not None: | ||
assert_address(addr) | ||
self.addr = addr | ||
self.keep = keep | ||
Can test for the existence of an active prank with `if prank: ...` | ||
|
||
A prank is active if either sender or origin is set. | ||
Technically supports pranking origin but not sender, which is not | ||
possible with the current cheatcodes: | ||
- prank(address) sets sender | ||
- prank(address, address) sets both sender and origin | ||
""" | ||
|
||
active: PrankResult = NO_PRANK # active prank context | ||
keep: bool = False # start / stop prank | ||
|
||
def __bool__(self) -> bool: | ||
""" | ||
True iff either sender or origin is set. | ||
""" | ||
return bool(self.active) | ||
|
||
def __str__(self) -> str: | ||
if self.addr: | ||
if self.keep: | ||
return f"startPrank({str(self.addr)})" | ||
else: | ||
return f"prank({str(self.addr)})" | ||
else: | ||
return "None" | ||
if not self: | ||
return "no active prank" | ||
|
||
fn_name = "startPrank" if self.keep else "prank" | ||
return f"{fn_name}({str(self.active)})" | ||
|
||
def lookup(self, to: Address) -> PrankResult: | ||
""" | ||
If `to` is an eligible prank destination, return the active prank context. | ||
|
||
If `keep` is False, this resets the prank context. | ||
""" | ||
|
||
def lookup(self, this: Any, to: Any) -> Any: | ||
assert_address(this) | ||
assert_address(to) | ||
caller = this | ||
if ( | ||
self.addr is not None | ||
self | ||
and not eq(to, hevm_cheat_code.address) | ||
and not eq(to, halmos_cheat_code.address) | ||
): | ||
caller = self.addr | ||
result = self.active | ||
if not self.keep: | ||
self.addr = None | ||
return caller | ||
self.stopPrank() | ||
return result | ||
|
||
def prank(self, addr: Any) -> bool: | ||
assert_address(addr) | ||
if self.addr is not None: | ||
return NO_PRANK | ||
|
||
def prank(self, sender: Address, origin: Address | None = None) -> bool: | ||
assert_address(sender) | ||
if self.active: | ||
return False | ||
self.addr = addr | ||
|
||
self.active = PrankResult(sender=sender, origin=origin) | ||
self.keep = False | ||
return True | ||
|
||
def startPrank(self, addr: Any) -> bool: | ||
assert_address(addr) | ||
if self.addr is not None: | ||
return False | ||
self.addr = addr | ||
self.keep = True | ||
return True | ||
def startPrank(self, sender: Address, origin: Address | None = None) -> bool: | ||
result = self.prank(sender, origin) | ||
self.keep = result if result else self.keep | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is accurate but a bit convoluted. For better code readability, I'd suggest having There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah you're right, in 17eba86 |
||
return result | ||
|
||
def stopPrank(self) -> bool: | ||
# stopPrank is allowed to call even when no active prank exists | ||
self.addr = None | ||
# stopPrank calls are allowed even when no active prank exists | ||
self.active = NO_PRANK | ||
self.keep = False | ||
return True | ||
|
||
|
@@ -282,9 +317,15 @@ class hevm_cheat_code: | |
# bytes4(keccak256("prank(address)")) | ||
prank_sig: int = 0xCA669FA7 | ||
|
||
# bytes4(keccak256("prank(address,address)")) | ||
prank_addr_addr_sig: int = 0x47E50CCE | ||
|
||
# bytes4(keccak256("startPrank(address)")) | ||
start_prank_sig: int = 0x06447D56 | ||
|
||
# bytes4(keccak256("startPrank(address,address)")) | ||
start_prank_addr_addr_sig: int = 0x45B56078 | ||
|
||
# bytes4(keccak256("stopPrank()")) | ||
stop_prank_sig: int = 0x90C5013B | ||
|
||
|
@@ -381,8 +422,17 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: | |
|
||
# vm.prank(address) | ||
elif funsig == hevm_cheat_code.prank_sig: | ||
address = uint160(arg.get_word(4)) | ||
result = ex.prank.prank(address) | ||
sender = uint160(arg.get_word(4)) | ||
result = ex.prank.prank(sender) | ||
if not result: | ||
raise HalmosException("You have an active prank already.") | ||
return ret | ||
|
||
# vm.prank(address sender, address origin) | ||
elif funsig == hevm_cheat_code.prank_addr_addr_sig: | ||
sender = uint160(arg.get_word(4)) | ||
origin = uint160(arg.get_word(36)) | ||
result = ex.prank.prank(sender, origin) | ||
if not result: | ||
raise HalmosException("You have an active prank already.") | ||
return ret | ||
|
@@ -395,6 +445,15 @@ def handle(sevm, ex, arg: ByteVec, stack, step_id) -> Optional[ByteVec]: | |
raise HalmosException("You have an active prank already.") | ||
return ret | ||
|
||
# vm.startPrank(address sender, address origin) | ||
elif funsig == hevm_cheat_code.start_prank_addr_addr_sig: | ||
sender = uint160(arg.get_word(4)) | ||
origin = uint160(arg.get_word(36)) | ||
result = ex.prank.startPrank(sender, origin) | ||
if not result: | ||
raise HalmosException("You have an active prank already.") | ||
return ret | ||
|
||
# vm.stopPrank() | ||
elif funsig == hevm_cheat_code.stop_prank_sig: | ||
ex.prank.stopPrank() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pytest -n 1
is what fixed CI, it seems that the issue was running multiple forge commands in parallel would result in race conditions trying to clone submodules or install solc