Skip to content

Commit

Permalink
fix(forge): Optionally use create2 factory in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RPate97 committed Dec 23, 2023
1 parent c312c0d commit ac45ba5
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 21 deletions.
4 changes: 4 additions & 0 deletions crates/cheatcodes/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ use std::path::{Path, PathBuf};
pub struct CheatsConfig {
/// Whether the FFI cheatcode is enabled.
pub ffi: bool,
/// Use the create 2 factory in all cases including tests and non-broadcasting scripts.
pub always_use_create_2_factory: bool,
/// RPC storage caching settings determines what chains and endpoints to cache
pub rpc_storage_caching: StorageCachingConfig,
/// All known endpoints and their aliases
Expand Down Expand Up @@ -44,6 +46,7 @@ impl CheatsConfig {

Self {
ffi: evm_opts.ffi,
always_use_create_2_factory: evm_opts.always_use_create_2_factory,
rpc_storage_caching: config.rpc_storage_caching.clone(),
rpc_endpoints,
paths: config.project_paths(),
Expand Down Expand Up @@ -157,6 +160,7 @@ impl Default for CheatsConfig {
fn default() -> Self {
Self {
ffi: false,
always_use_create_2_factory: false,
rpc_storage_caching: Default::default(),
rpc_endpoints: Default::default(),
paths: ProjectPathsConfig::builder().build_with_root("./"),
Expand Down
61 changes: 40 additions & 21 deletions crates/cheatcodes/src/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1183,17 +1183,8 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
data.env.tx.caller = broadcast.new_origin;

if data.journaled_state.depth() == broadcast.depth {
let (bytecode, to, nonce) = match process_create(
broadcast.new_origin,
call.init_code.clone(),
data,
call,
) {
Ok(val) => val,
Err(err) => {
return (InstructionResult::Revert, None, gas, Error::encode(err))
}
};
let (bytecode, to, nonce) =
process_create(broadcast.new_origin, call.init_code.clone(), data, call);

let is_fixed_gas_limit = check_if_fixed_gas_limit(data, call.gas_limit);

Expand Down Expand Up @@ -1222,6 +1213,14 @@ impl<DB: DatabaseExt> Inspector<DB> for Cheatcodes {
}
}

// Apply the Create2 deployer
if self.broadcast.is_some() || self.config.always_use_create_2_factory {
match apply_create2_deployer(data, call, &self.prank, &self.broadcast) {
Ok(_val) => {}
Err(err) => return (InstructionResult::Revert, None, gas, Error::encode(err)),
};
}

// allow cheatcodes from the address of the new contract
// Compute the address *after* any possible broadcast updates, so it's based on the updated
// call inputs
Expand Down Expand Up @@ -1395,18 +1394,22 @@ fn mstore_revert_string(interpreter: &mut Interpreter<'_>, bytes: &[u8]) {
interpreter.return_len = interpreter.shared_memory.len() - starting_offset
}

fn process_create<DB: DatabaseExt>(
broadcast_sender: Address,
bytecode: Bytes,
fn apply_create2_deployer<DB: DatabaseExt>(
data: &mut EVMData<'_, DB>,
call: &mut CreateInputs,
) -> Result<(Bytes, Option<Address>, u64), DB::Error> {
match call.scheme {
CreateScheme::Create => {
call.caller = broadcast_sender;
Ok((bytecode, None, data.journaled_state.account(broadcast_sender).info.nonce))
prank: &Option<Prank>,
broadcast: &Option<Broadcast>,
) -> Result<(), DB::Error> {
if let CreateScheme::Create2 { salt: _ } = call.scheme {
let mut base_depth = 1;
if let Some(prank) = &prank {
base_depth = prank.depth;
} else if let Some(broadcast) = &broadcast {
base_depth = broadcast.depth;
}
CreateScheme::Create2 { salt } => {
// If the create scheme is Create2 and the depth equals the broadcast/prank/default
// depth, then use the default create2 factory as the deployer
if data.journaled_state.depth() == base_depth {
// Sanity checks for our CREATE2 deployer
let info =
&data.journaled_state.load_account(DEFAULT_CREATE2_DEPLOYER, data.db)?.0.info;
Expand All @@ -1419,7 +1422,23 @@ fn process_create<DB: DatabaseExt>(
}

call.caller = DEFAULT_CREATE2_DEPLOYER;
}
}
Ok(())
}

fn process_create<DB: DatabaseExt>(
broadcast_sender: Address,
bytecode: Bytes,
data: &mut EVMData<'_, DB>,
call: &mut CreateInputs,
) -> (Bytes, Option<Address>, u64) {
match call.scheme {
CreateScheme::Create => {
call.caller = broadcast_sender;
(bytecode, None, data.journaled_state.account(broadcast_sender).info.nonce)
}
CreateScheme::Create2 { salt } => {
// We have to increment the nonce of the user address, since this create2 will be done
// by the create2_deployer
let account = data.journaled_state.state().get_mut(&broadcast_sender).unwrap();
Expand All @@ -1429,7 +1448,7 @@ fn process_create<DB: DatabaseExt>(

// Proxy deployer requires the data to be `salt ++ init_code`
let calldata = [&salt.to_be_bytes::<32>()[..], &bytecode[..]].concat();
Ok((calldata.into(), Some(DEFAULT_CREATE2_DEPLOYER), prev))
(calldata.into(), Some(DEFAULT_CREATE2_DEPLOYER), prev)
}
}
}
Expand Down
12 changes: 12 additions & 0 deletions crates/common/src/evm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ pub struct EvmArgs {
#[serde(skip)]
pub ffi: bool,

/// Use the create 2 factory in all cases including tests and non-broadcasting scripts.
#[clap(long)]
#[serde(skip)]
pub always_use_create_2_factory: bool,

/// Verbosity of the EVM.
///
/// Pass multiple times to increase the verbosity (e.g. -v, -vv, -vvv).
Expand Down Expand Up @@ -161,6 +166,13 @@ impl Provider for EvmArgs {
dict.insert("ffi".to_string(), self.ffi.into());
}

if self.always_use_create_2_factory {
dict.insert(
"always_use_create_2_factory".to_string(),
self.always_use_create_2_factory.into(),
);
}

if self.no_storage_caching {
dict.insert("no_storage_caching".to_string(), self.no_storage_caching.into());
}
Expand Down
1 change: 1 addition & 0 deletions crates/config/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ no_match_contract = "Bar"
match_path = "*/Foo*"
no_match_path = "*/Bar*"
ffi = false
always_use_create_2_factory = false
# These are the default callers, generated using `address(uint160(uint256(keccak256("foundry default caller"))))`
sender = '0x1804c8AB1F12E6bbf3894d4083f33e07309d1f38'
tx_origin = '0x1804c8AB1F12E6bbf3894d4083f33e07309d1f38'
Expand Down
6 changes: 6 additions & 0 deletions crates/config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ pub struct Config {
pub invariant: InvariantConfig,
/// Whether to allow ffi cheatcodes in test
pub ffi: bool,
/// Use the create 2 factory in all cases including tests and non-broadcasting scripts.
pub always_use_create_2_factory: bool,
/// The address which will be executing all tests
pub sender: Address,
/// The tx.origin value during EVM execution
Expand Down Expand Up @@ -1786,6 +1788,7 @@ impl Default for Config {
path_pattern_inverse: None,
fuzz: Default::default(),
invariant: Default::default(),
always_use_create_2_factory: false,
ffi: false,
sender: Config::DEFAULT_SENDER,
tx_origin: Config::DEFAULT_SENDER,
Expand Down Expand Up @@ -3301,6 +3304,7 @@ mod tests {
revert_strings = "strip"
allow_paths = ["allow", "paths"]
build_info_path = "build-info"
always_use_create_2_factory = true
[rpc_endpoints]
optimism = "https://example.com/"
Expand Down Expand Up @@ -3352,6 +3356,7 @@ mod tests {
),
]),
build_info_path: Some("build-info".into()),
always_use_create_2_factory: true,
..Config::default()
}
);
Expand Down Expand Up @@ -3403,6 +3408,7 @@ mod tests {
evm_version = 'london'
extra_output = []
extra_output_files = []
always_use_create_2_factory = false
ffi = false
force = false
gas_limit = 9223372036854775807
Expand Down
3 changes: 3 additions & 0 deletions crates/evm/core/src/opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ pub struct EvmOpts {
/// Enables the FFI cheatcode.
pub ffi: bool,

/// Use the create 2 factory in all cases including tests and non-broadcasting scripts.
pub always_use_create_2_factory: bool,

/// Verbosity mode of EVM output as number of occurrences.
pub verbosity: u8,

Expand Down
1 change: 1 addition & 0 deletions crates/forge/tests/cli/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ forgetest!(can_extract_config_values, |prj, cmd| {
},
invariant: InvariantConfig { runs: 256, ..Default::default() },
ffi: true,
always_use_create_2_factory: false,
sender: "00a329c0648769A73afAc7F9381D08FB43dBEA72".parse().unwrap(),
tx_origin: "00a329c0648769A73afAc7F9F81E08FB43dBEA72".parse().unwrap(),
initial_balance: U256::from(0xffffffffffffffffffffffffu128),
Expand Down
7 changes: 7 additions & 0 deletions crates/forge/tests/it/repros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,10 @@ test_repro!(6554; |config| {
cheats_config.fs_permissions.add(PathPermission::read_write(path));
config.runner.cheats_config = std::sync::Arc::new(cheats_config);
});

// https://github.com/foundry-rs/foundry/issues/5529
test_repro!(5529; |config| {
let mut cheats_config = config.runner.cheats_config.as_ref().clone();
cheats_config.always_use_create_2_factory = true;
config.runner.cheats_config = std::sync::Arc::new(cheats_config);
});
37 changes: 37 additions & 0 deletions testdata/repros/Issue5529.t.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.13;

import "ds-test/test.sol";
import "../cheats/Vm.sol";

contract Counter {
uint256 public number;

function setNumber(uint256 newNumber) public {
number = newNumber;
}

function increment() public {
number++;
}
}

contract CounterTest is DSTest {
Vm constant vm = Vm(HEVM_ADDRESS);

Counter public counter;
address public constant default_create2_factory = 0x4e59b44847b379578588920cA78FbF26c0B4956C;

function testCreate2FactoryUsedInTests() public {
address a = vm.computeCreate2Address(0, keccak256(type(Counter).creationCode), address(default_create2_factory));
address b = address(new Counter{salt: 0}());
require(a == b, "create2 address mismatch");
}

function testCreate2FactoryUsedWhenPranking() public {
vm.startPrank(address(1234));
address a = vm.computeCreate2Address(0, keccak256(type(Counter).creationCode), address(default_create2_factory));
address b = address(new Counter{salt: 0}());
require(a == b, "create2 address mismatch");
}
}

0 comments on commit ac45ba5

Please sign in to comment.