Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(forge) _expectRevertCheatcode #6841

Merged
merged 16 commits into from
Jan 22, 2024
60 changes: 60 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions crates/cheatcodes/assets/cheatcodes.schema.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions crates/cheatcodes/spec/src/cheatcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ pub enum Status {
///
/// Use of removed cheatcodes will result in a hard error.
Removed,
/// The cheatcode is only used internally for foundry testing and may be changed or removed at
/// any time.
///
/// Use of internal cheatcodes is discouraged and will result in a warning.
Internal,
}

/// Cheatcode groups.
Expand Down
12 changes: 12 additions & 0 deletions crates/cheatcodes/spec/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,18 @@ interface Vm {
#[cheatcode(group = Testing, safety = Unsafe)]
function expectRevert(bytes calldata revertData) external;

/// Expects an error on next cheatcode call with any revert data.
#[cheatcode(group = Testing, safety = Unsafe, status = Internal)]
function _expectCheatcodeRevert() external;

/// Expects an error on next cheatcode call that starts with the revert data.
#[cheatcode(group = Testing, safety = Unsafe, status = Internal)]
function _expectCheatcodeRevert(bytes4 revertData) external;

/// Expects an error on next cheatcode call that exactly matches the revert data.
#[cheatcode(group = Testing, safety = Unsafe, status = Internal)]
function _expectCheatcodeRevert(bytes calldata revertData) external;

/// Only allows memory writes to offsets [0x00, 0x60) ∪ [min, max) in the current subcontext. If any other
/// memory is written to, the test will fail. Can be called multiple times to add more ranges to the set.
#[cheatcode(group = Testing, safety = Unsafe)]
Expand Down
111 changes: 67 additions & 44 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ use crate::{
},
script::Broadcast,
test::expect::{
self, ExpectedCallData, ExpectedCallTracker, ExpectedCallType, ExpectedEmit, ExpectedRevert,
self, ExpectedCallData, ExpectedCallTracker, ExpectedCallType, ExpectedEmit,
ExpectedRevert, ExpectedRevertKind,
},
CheatsConfig, CheatsCtxt, Error, Result, Vm,
};
Expand All @@ -22,7 +23,7 @@ use ethers_signers::LocalWallet;
use foundry_common::{evm::Breakpoints, provider::alloy::RpcUrl, types::ToEthers};
use foundry_evm_core::{
backend::{DatabaseError, DatabaseExt, RevertDiagnostic},
constants::{CHEATCODE_ADDRESS, DEFAULT_CREATE2_DEPLOYER, HARDHAT_CONSOLE_ADDRESS, MAGIC_SKIP},
constants::{CHEATCODE_ADDRESS, DEFAULT_CREATE2_DEPLOYER, HARDHAT_CONSOLE_ADDRESS},
};
use itertools::Itertools;
use revm::{
Expand Down Expand Up @@ -130,9 +131,6 @@ pub struct Cheatcodes {
/// Remembered private keys
pub script_wallets: Vec<LocalWallet>,

/// Whether the skip cheatcode was activated
pub skip: bool,

/// Prank information
pub prank: Option<Prank>,

Expand Down Expand Up @@ -919,61 +917,84 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
status: InstructionResult,
retdata: Bytes,
) -> (InstructionResult, Gas, Bytes) {
if call.contract == CHEATCODE_ADDRESS || call.contract == HARDHAT_CONSOLE_ADDRESS {
return (status, remaining_gas, retdata);
}

if data.journaled_state.depth() == 0 && self.skip {
return (
InstructionResult::Revert,
remaining_gas,
super::Error::from(MAGIC_SKIP).abi_encode().into(),
);
}

// Clean up pranks
if let Some(prank) = &self.prank {
if data.journaled_state.depth() == prank.depth {
data.env.tx.caller = prank.prank_origin;
let cheatcode_call =
call.contract == CHEATCODE_ADDRESS || call.contract == HARDHAT_CONSOLE_ADDRESS;

// Clean up pranks/broadcasts if it's not a cheatcode call end. We shouldn't do
// it for cheatcode calls because they are not appplied for cheatcodes in the `call` hook.
// This should be placed before the revert handling, because we might exit early there
if !cheatcode_call {
klkvr marked this conversation as resolved.
Show resolved Hide resolved
// Clean up pranks
if let Some(prank) = &self.prank {
if data.journaled_state.depth() == prank.depth {
data.env.tx.caller = prank.prank_origin;

// Clean single-call prank once we have returned to the original depth
if prank.single_call {
let _ = self.prank.take();
// Clean single-call prank once we have returned to the original depth
if prank.single_call {
let _ = self.prank.take();
}
}
}
}

// Clean up broadcast
if let Some(broadcast) = &self.broadcast {
if data.journaled_state.depth() == broadcast.depth {
data.env.tx.caller = broadcast.original_origin;
// Clean up broadcast
if let Some(broadcast) = &self.broadcast {
if data.journaled_state.depth() == broadcast.depth {
data.env.tx.caller = broadcast.original_origin;

// Clean single-call broadcast once we have returned to the original depth
if broadcast.single_call {
let _ = self.broadcast.take();
// Clean single-call broadcast once we have returned to the original depth
if broadcast.single_call {
let _ = self.broadcast.take();
}
}
}
}

// Handle expected reverts
if let Some(expected_revert) = &self.expected_revert {
if data.journaled_state.depth() <= expected_revert.depth {
let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
return match expect::handle_expect_revert(
false,
expected_revert.reason.as_deref(),
status,
retdata,
) {
Err(error) => {
trace!(expected=?expected_revert, ?error, ?status, "Expected revert mismatch");
(InstructionResult::Revert, remaining_gas, error.abi_encode().into())
let needs_processing: bool = match expected_revert.kind {
ExpectedRevertKind::Default => !cheatcode_call,
// `pending_processing` == true means that we're in the `call_end` hook for
// `vm.expectCheatcodeRevert` and shouldn't expect revert here
ExpectedRevertKind::Cheatcode { pending_processing } => {
cheatcode_call && !pending_processing
klkvr marked this conversation as resolved.
Show resolved Hide resolved
}
Ok((_, retdata)) => (InstructionResult::Return, remaining_gas, retdata),
};

if needs_processing {
let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
return match expect::handle_expect_revert(
false,
expected_revert.reason.as_deref(),
status,
retdata,
) {
Err(error) => {
trace!(expected=?expected_revert, ?error, ?status, "Expected revert mismatch");
(InstructionResult::Revert, remaining_gas, error.abi_encode().into())
}
Ok((_, retdata)) => (InstructionResult::Return, remaining_gas, retdata),
};
}

// Flip `pending_processing` flag for cheatcode revert expectations, marking that
// we've exited the `expectCheatcodeRevert` call scope
if let ExpectedRevertKind::Cheatcode { pending_processing } =
&mut self.expected_revert.as_mut().unwrap().kind
{
if *pending_processing {
*pending_processing = false;
}
}
}
}

// Exit early for calls to cheatcodes as other logic is not relevant for cheatcode
// invocations
if cheatcode_call {
return (status, remaining_gas, retdata);
}

// If `startStateDiffRecording` has been called, update the `reverted` status of the
// previous call depth's recorded accesses, if any
if let Some(recorded_account_diffs_stack) = &mut self.recorded_account_diffs_stack {
Expand Down Expand Up @@ -1290,7 +1311,9 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {

// Handle expected reverts
if let Some(expected_revert) = &self.expected_revert {
if data.journaled_state.depth() <= expected_revert.depth {
if data.journaled_state.depth() <= expected_revert.depth &&
matches!(expected_revert.kind, ExpectedRevertKind::Default)
{
let expected_revert = std::mem::take(&mut self.expected_revert).unwrap();
return match expect::handle_expect_revert(
true,
Expand Down
1 change: 0 additions & 1 deletion crates/cheatcodes/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ impl Cheatcode for skipCall {
// Skip should not work if called deeper than at test level.
// Since we're not returning the magic skip bytes, this will cause a test failure.
ensure!(ccx.data.journaled_state.depth() <= 1, "`skip` can only be used at test level");
ccx.state.skip = true;
Err(MAGIC_SKIP.into())
} else {
Ok(Default::default())
Expand Down
60 changes: 54 additions & 6 deletions crates/cheatcodes/src/test/expect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,27 @@ pub enum ExpectedCallType {
Count,
}

#[derive(Clone, Debug, Default)]
/// The type of expected revert.
#[derive(Clone, Debug)]
pub enum ExpectedRevertKind {
/// Expects revert from the next non-cheatcode call.
Default,
/// Expects revert from the next cheatcode call.
///
/// The `pending_processing` flag is used to track whether we have exited
/// `expectCheatcodeRevert` context or not.
/// We have to track it to avoid expecting `expectCheatcodeRevert` call to revert itself.
Cheatcode { pending_processing: bool },
}

#[derive(Clone, Debug)]
pub struct ExpectedRevert {
/// The expected data returned by the revert, None being any
pub reason: Option<Vec<u8>>,
/// The depth at which the revert is expected
pub depth: u64,
/// The type of expected revert.
pub kind: ExpectedRevertKind,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -222,21 +237,41 @@ impl Cheatcode for expectEmit_3Call {
impl Cheatcode for expectRevert_0Call {
fn apply_full<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self {} = self;
expect_revert(ccx.state, None, ccx.data.journaled_state.depth())
expect_revert(ccx.state, None, ccx.data.journaled_state.depth(), false)
}
}

impl Cheatcode for expectRevert_1Call {
fn apply_full<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData.as_ref()), ccx.data.journaled_state.depth())
expect_revert(ccx.state, Some(revertData.as_ref()), ccx.data.journaled_state.depth(), false)
}
}

impl Cheatcode for expectRevert_2Call {
fn apply_full<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData), ccx.data.journaled_state.depth())
expect_revert(ccx.state, Some(revertData), ccx.data.journaled_state.depth(), false)
}
}

impl Cheatcode for _expectCheatcodeRevert_0Call {
fn apply_full<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
expect_revert(ccx.state, None, ccx.data.journaled_state.depth(), true)
}
}

impl Cheatcode for _expectCheatcodeRevert_1Call {
fn apply_full<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData.as_ref()), ccx.data.journaled_state.depth(), true)
}
}

impl Cheatcode for _expectCheatcodeRevert_2Call {
fn apply_full<DB: DatabaseExt>(&self, ccx: &mut CheatsCtxt<DB>) -> Result {
let Self { revertData } = self;
expect_revert(ccx.state, Some(revertData), ccx.data.journaled_state.depth(), true)
}
}

Expand Down Expand Up @@ -430,12 +465,25 @@ pub(crate) fn handle_expect_emit(
}
}

fn expect_revert(state: &mut Cheatcodes, reason: Option<&[u8]>, depth: u64) -> Result {
fn expect_revert(
state: &mut Cheatcodes,
reason: Option<&[u8]>,
depth: u64,
cheatcode: bool,
) -> Result {
ensure!(
state.expected_revert.is_none(),
"you must call another function prior to expecting a second revert"
);
state.expected_revert = Some(ExpectedRevert { reason: reason.map(<[_]>::to_vec), depth });
state.expected_revert = Some(ExpectedRevert {
reason: reason.map(<[_]>::to_vec),
depth,
kind: if cheatcode {
ExpectedRevertKind::Cheatcode { pending_processing: true }
} else {
ExpectedRevertKind::Default
},
});
Ok(Default::default())
}

Expand Down
2 changes: 1 addition & 1 deletion testdata/cheats/Etch.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ contract EtchTest is DSTest {
function testEtchNotAvailableOnPrecompiles() public {
address target = address(1);
bytes memory code = hex"1010";
vm.expectRevert(bytes("cannot call `etch` on precompile 0x0000000000000000000000000000000000000001"));
vm._expectCheatcodeRevert(bytes("cannot call `etch` on precompile 0x0000000000000000000000000000000000000001"));
vm.etch(target, code);
}
}
Loading