diff --git a/src/errors.rs b/src/errors.rs index 2265674..68b1ef4 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -10,16 +10,6 @@ pub struct ErrorMessage { pub message: String, } -#[derive(Debug)] -pub struct FromHexError; - -impl Reject for FromHexError {} - -#[derive(Debug)] -pub struct FromDecStrError; - -impl Reject for FromDecStrError {} - #[derive(Debug)] pub struct NoURLForChainIdError; @@ -50,6 +40,11 @@ pub struct StateNotFound(); impl Reject for StateNotFound {} +#[derive(Debug)] +pub struct OverrideError; + +impl Reject for OverrideError {} + #[derive(Debug)] pub struct EvmError(pub Report); @@ -65,12 +60,6 @@ pub async fn handle_rejection(err: Rejection) -> Result } else if let Some(_e) = err.find::() { code = StatusCode::NOT_FOUND; message = "STATE_NOT_FOUND".to_string(); - } else if let Some(FromHexError) = err.find() { - code = StatusCode::BAD_REQUEST; - message = "FROM_HEX_ERROR".to_string(); - } else if let Some(FromDecStrError) = err.find() { - code = StatusCode::BAD_REQUEST; - message = "FROM_DEC_STR_ERROR".to_string(); } else if let Some(NoURLForChainIdError) = err.find() { code = StatusCode::BAD_REQUEST; message = "CHAIN_ID_NOT_SUPPORTED".to_string(); @@ -86,6 +75,9 @@ pub async fn handle_rejection(err: Rejection) -> Result } else if let Some(_e) = err.find::() { code = StatusCode::BAD_REQUEST; message = "INVALID_BLOCK_NUMBERS".to_string(); + } else if let Some(_e) = err.find::() { + code = StatusCode::INTERNAL_SERVER_ERROR; + message = "OVERRIDE_ERROR".to_string(); } else if let Some(_e) = err.find::() { if _e.0.to_string().contains("CallGasCostMoreThanGasLimit") { code = StatusCode::BAD_REQUEST; diff --git a/src/evm.rs b/src/evm.rs index b460870..3720977 100644 --- a/src/evm.rs +++ b/src/evm.rs @@ -1,4 +1,6 @@ -use ethers::abi::{Address, Uint}; +use std::collections::HashMap; + +use ethers::abi::{Address, Hash, Uint}; use ethers::core::types::Log; use ethers::types::Bytes; use foundry_config::Chain; @@ -7,10 +9,13 @@ use foundry_evm::executor::{opts::EvmOpts, Backend, ExecutorBuilder}; use foundry_evm::trace::identifier::{EtherscanIdentifier, SignaturesIdentifier}; use foundry_evm::trace::node::CallTraceNode; use foundry_evm::trace::{CallTraceArena, CallTraceDecoder, CallTraceDecoderBuilder}; +use foundry_evm::utils::{h160_to_b160, u256_to_ru256}; +use revm::db::DatabaseRef; use revm::interpreter::InstructionResult; -use revm::primitives::Env; +use revm::primitives::{Account, Bytecode, Env, StorageSlot}; +use revm::DatabaseCommit; -use crate::errors::EvmError; +use crate::errors::{EvmError, OverrideError}; use crate::simulation::CallTrace; #[derive(Debug, Clone)] @@ -36,6 +41,12 @@ impl From for CallTrace { } } +#[derive(Debug, Clone, PartialEq)] +pub struct StorageOverride { + pub slots: HashMap, + pub diff: bool, +} + pub struct Evm { executor: Executor, decoder: CallTraceDecoder, @@ -155,6 +166,56 @@ impl Evm { }) } + pub fn override_account( + &mut self, + address: Address, + balance: Option, + nonce: Option, + code: Option, + storage: Option, + ) -> Result<(), OverrideError> { + let address = h160_to_b160(address); + let mut account = Account { + info: self + .executor + .backend() + .basic(address) + .map_err(|_| OverrideError)? + .unwrap_or_default(), + ..Account::new_not_existing() + }; + + if let Some(balance) = balance { + account.info.balance = u256_to_ru256(balance); + } + if let Some(nonce) = nonce { + account.info.nonce = nonce; + } + if let Some(code) = code { + account.info.code = Some(Bytecode::new_raw(code.to_vec().into())); + } + if let Some(storage) = storage { + // If we do a "full storage override", make sure to set this flag so + // that existing storage slots are cleared, and unknown ones aren't + // fetched from the forked node. + account.storage_cleared = !storage.diff; + account + .storage + .extend(storage.slots.into_iter().map(|(key, value)| { + ( + u256_to_ru256(Uint::from_big_endian(key.as_bytes())), + StorageSlot::new(u256_to_ru256(value)), + ) + })); + } + + self.executor + .backend_mut() + .commit([(address, account)].into_iter().collect()); + + Ok(()) + } + pub async fn call_raw_committing( &mut self, from: Address, diff --git a/src/simulation.rs b/src/simulation.rs index 989b214..8b54406 100644 --- a/src/simulation.rs +++ b/src/simulation.rs @@ -1,8 +1,9 @@ +use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; use dashmap::mapref::one::RefMut; -use ethers::abi::{Address, Uint}; +use ethers::abi::{Address, Hash, Uint}; use ethers::core::types::Log; use ethers::types::Bytes; use foundry_evm::CallKind; @@ -10,14 +11,14 @@ use revm::interpreter::InstructionResult; use serde::{Deserialize, Serialize}; use tokio::sync::Mutex; use uuid::Uuid; -use warp::reject::custom; use warp::reply::Json; use warp::Rejection; use crate::errors::{ - FromDecStrError, FromHexError, IncorrectChainIdError, InvalidBlockNumbersError, - MultipleChainIdsError, NoURLForChainIdError, StateNotFound, + IncorrectChainIdError, InvalidBlockNumbersError, MultipleChainIdsError, NoURLForChainIdError, + StateNotFound, }; +use crate::evm::StorageOverride; use crate::SharedSimulationState; use super::config::Config; @@ -32,11 +33,13 @@ pub struct SimulationRequest { pub data: Option, #[serde(rename = "gasLimit")] pub gas_limit: u64, - pub value: Option, + pub value: Option, #[serde(rename = "blockNumber")] pub block_number: Option, #[serde(rename = "formatTrace")] pub format_trace: Option, + #[serde(rename = "stateOverrides")] + pub state_overrides: Option>, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -79,6 +82,44 @@ pub struct StatefulSimulationEndResponse { pub success: bool, } +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct StateOverride { + pub balance: Option, + pub nonce: Option, + pub code: Option, + #[serde(flatten)] + pub state: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum State { + Full { + state: HashMap, + }, + Diff { + #[serde(rename = "stateDiff")] + state_diff: HashMap, + }, +} + +impl From for StorageOverride { + fn from(value: State) -> Self { + let (slots, diff) = match value { + State::Full { state } => (state, false), + State::Diff { state_diff } => (state_diff, true), + }; + + StorageOverride { + slots: slots + .into_iter() + .map(|(key, value)| (key, value.into())) + .collect(), + diff, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct CallTrace { #[serde(rename = "callType")] @@ -88,6 +129,32 @@ pub struct CallTrace { pub value: Uint, } +#[derive(Debug, Default, Clone, Copy, Serialize, PartialEq)] +#[serde(transparent)] +pub struct PermissiveUint(pub Uint); + +impl From for Uint { + fn from(value: PermissiveUint) -> Self { + value.0 + } +} + +impl<'de> Deserialize<'de> for PermissiveUint { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + // Accept value in hex or decimal formats + let value = String::deserialize(deserializer)?; + let parsed = if value.starts_with("0x") { + Uint::from_str(&value).map_err(serde::de::Error::custom)? + } else { + Uint::from_dec_str(&value).map_err(serde::de::Error::custom)? + }; + Ok(Self(parsed)) + } +} + fn chain_id_to_fork_url(chain_id: u64) -> Result { match chain_id { // ethereum @@ -123,22 +190,21 @@ async fn run( transaction: SimulationRequest, commit: bool, ) -> Result { - // Accept value in hex or decimal formats - let value = if let Some(value) = transaction.value { - if value.starts_with("0x") { - Some(Uint::from_str(value.as_str()).map_err(|_err| custom(FromHexError))?) - } else { - Some(Uint::from_dec_str(value.as_str()).map_err(|_err| custom(FromDecStrError))?) - } - } else { - None - }; + for (address, state_override) in transaction.state_overrides.into_iter().flatten() { + evm.override_account( + address, + state_override.balance.map(Uint::from), + state_override.nonce, + state_override.code, + state_override.state.map(StorageOverride::from), + )?; + } let result = if commit { evm.call_raw_committing( transaction.from, transaction.to, - value, + transaction.value.map(Uint::from), transaction.data, transaction.gas_limit, transaction.format_trace.unwrap_or_default(), @@ -148,7 +214,7 @@ async fn run( evm.call_raw( transaction.from, transaction.to, - value, + transaction.value.map(Uint::from), transaction.data, transaction.format_trace.unwrap_or_default(), )