Skip to content

Commit

Permalink
Implement State Overrides for Simulations
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Rodrigues Lordello committed Sep 20, 2023
1 parent 64fe96a commit dc55bd4
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 36 deletions.
24 changes: 8 additions & 16 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,6 @@ pub struct ErrorMessage {
pub message: String,
}

#[derive(Debug)]
pub struct FromHexError;

impl Reject for FromHexError {}

#[derive(Debug)]
pub struct FromDecStrError;

impl Reject for FromDecStrError {}

#[derive(Debug)]
pub struct NoURLForChainIdError;

Expand Down Expand Up @@ -50,6 +40,11 @@ pub struct StateNotFound();

impl Reject for StateNotFound {}

#[derive(Debug)]
pub struct OverrideError;

impl Reject for OverrideError {}

#[derive(Debug)]
pub struct EvmError(pub Report);

Expand All @@ -65,12 +60,6 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible>
} else if let Some(_e) = err.find::<StateNotFound>() {
code = StatusCode::NOT_FOUND;
message = "STATE_NOT_FOUND".to_string();
} else if let Some(FromHexError) = err.find() {
code = StatusCode::BAD_REQUEST;
message = "FROM_HEX_ERROR".to_string();
} else if let Some(FromDecStrError) = err.find() {
code = StatusCode::BAD_REQUEST;
message = "FROM_DEC_STR_ERROR".to_string();
} else if let Some(NoURLForChainIdError) = err.find() {
code = StatusCode::BAD_REQUEST;
message = "CHAIN_ID_NOT_SUPPORTED".to_string();
Expand All @@ -86,6 +75,9 @@ pub async fn handle_rejection(err: Rejection) -> Result<impl Reply, Infallible>
} else if let Some(_e) = err.find::<InvalidBlockNumbersError>() {
code = StatusCode::BAD_REQUEST;
message = "INVALID_BLOCK_NUMBERS".to_string();
} else if let Some(_e) = err.find::<OverrideError>() {
code = StatusCode::INTERNAL_SERVER_ERROR;
message = "OVERRIDE_ERROR".to_string();
} else if let Some(_e) = err.find::<EvmError>() {
if _e.0.to_string().contains("CallGasCostMoreThanGasLimit") {
code = StatusCode::BAD_REQUEST;
Expand Down
67 changes: 64 additions & 3 deletions src/evm.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use ethers::abi::{Address, Uint};
use std::collections::HashMap;

use ethers::abi::{Address, Hash, Uint};
use ethers::core::types::Log;
use ethers::types::Bytes;
use foundry_config::Chain;
Expand All @@ -7,10 +9,13 @@ use foundry_evm::executor::{opts::EvmOpts, Backend, ExecutorBuilder};
use foundry_evm::trace::identifier::{EtherscanIdentifier, SignaturesIdentifier};
use foundry_evm::trace::node::CallTraceNode;
use foundry_evm::trace::{CallTraceArena, CallTraceDecoder, CallTraceDecoderBuilder};
use foundry_evm::utils::{h160_to_b160, u256_to_ru256};
use revm::db::DatabaseRef;
use revm::interpreter::InstructionResult;
use revm::primitives::Env;
use revm::primitives::{Account, Bytecode, Env, StorageSlot};
use revm::DatabaseCommit;

use crate::errors::EvmError;
use crate::errors::{EvmError, OverrideError};
use crate::simulation::CallTrace;

#[derive(Debug, Clone)]
Expand All @@ -36,6 +41,12 @@ impl From<CallTraceNode> for CallTrace {
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct StorageOverride {
pub slots: HashMap<Hash, Uint>,
pub diff: bool,
}

pub struct Evm {
executor: Executor,
decoder: CallTraceDecoder,
Expand Down Expand Up @@ -155,6 +166,56 @@ impl Evm {
})
}

pub fn override_account(
&mut self,
address: Address,
balance: Option<Uint>,
nonce: Option<u64>,
code: Option<Bytes>,
storage: Option<StorageOverride>,
) -> Result<(), OverrideError> {
let address = h160_to_b160(address);
let mut account = Account {
info: self
.executor
.backend()
.basic(address)
.map_err(|_| OverrideError)?
.unwrap_or_default(),
..Account::new_not_existing()
};

if let Some(balance) = balance {
account.info.balance = u256_to_ru256(balance);
}
if let Some(nonce) = nonce {
account.info.nonce = nonce;
}
if let Some(code) = code {
account.info.code = Some(Bytecode::new_raw(code.to_vec().into()));
}
if let Some(storage) = storage {
// If we do a "full storage override", make sure to set this flag so
// that existing storage slots are cleared, and unknown ones aren't
// fetched from the forked node.
account.storage_cleared = !storage.diff;
account
.storage
.extend(storage.slots.into_iter().map(|(key, value)| {
(
u256_to_ru256(Uint::from_big_endian(key.as_bytes())),
StorageSlot::new(u256_to_ru256(value)),
)
}));
}

self.executor
.backend_mut()
.commit([(address, account)].into_iter().collect());

Ok(())
}

pub async fn call_raw_committing(
&mut self,
from: Address,
Expand Down
100 changes: 83 additions & 17 deletions src/simulation.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::Arc;

use dashmap::mapref::one::RefMut;
use ethers::abi::{Address, Uint};
use ethers::abi::{Address, Hash, Uint};
use ethers::core::types::Log;
use ethers::types::Bytes;
use foundry_evm::CallKind;
use revm::interpreter::InstructionResult;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use uuid::Uuid;
use warp::reject::custom;
use warp::reply::Json;
use warp::Rejection;

use crate::errors::{
FromDecStrError, FromHexError, IncorrectChainIdError, InvalidBlockNumbersError,
MultipleChainIdsError, NoURLForChainIdError, StateNotFound,
IncorrectChainIdError, InvalidBlockNumbersError, MultipleChainIdsError, NoURLForChainIdError,
StateNotFound,
};
use crate::evm::StorageOverride;
use crate::SharedSimulationState;

use super::config::Config;
Expand All @@ -32,11 +33,13 @@ pub struct SimulationRequest {
pub data: Option<Bytes>,
#[serde(rename = "gasLimit")]
pub gas_limit: u64,
pub value: Option<String>,
pub value: Option<PermissiveUint>,
#[serde(rename = "blockNumber")]
pub block_number: Option<u64>,
#[serde(rename = "formatTrace")]
pub format_trace: Option<bool>,
#[serde(rename = "stateOverrides")]
pub state_overrides: Option<HashMap<Address, StateOverride>>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
Expand Down Expand Up @@ -79,6 +82,44 @@ pub struct StatefulSimulationEndResponse {
pub success: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct StateOverride {
pub balance: Option<PermissiveUint>,
pub nonce: Option<u64>,
pub code: Option<Bytes>,
#[serde(flatten)]
pub state: Option<State>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum State {
Full {
state: HashMap<Hash, PermissiveUint>,
},
Diff {
#[serde(rename = "stateDiff")]
state_diff: HashMap<Hash, PermissiveUint>,
},
}

impl From<State> for StorageOverride {
fn from(value: State) -> Self {
let (slots, diff) = match value {
State::Full { state } => (state, false),
State::Diff { state_diff } => (state_diff, true),
};

StorageOverride {
slots: slots
.into_iter()
.map(|(key, value)| (key, value.into()))
.collect(),
diff,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct CallTrace {
#[serde(rename = "callType")]
Expand All @@ -88,6 +129,32 @@ pub struct CallTrace {
pub value: Uint,
}

#[derive(Debug, Default, Clone, Copy, Serialize, PartialEq)]
#[serde(transparent)]
pub struct PermissiveUint(pub Uint);

impl From<PermissiveUint> for Uint {
fn from(value: PermissiveUint) -> Self {
value.0
}
}

impl<'de> Deserialize<'de> for PermissiveUint {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
// Accept value in hex or decimal formats
let value = String::deserialize(deserializer)?;
let parsed = if value.starts_with("0x") {
Uint::from_str(&value).map_err(serde::de::Error::custom)?
} else {
Uint::from_dec_str(&value).map_err(serde::de::Error::custom)?
};
Ok(Self(parsed))
}
}

fn chain_id_to_fork_url(chain_id: u64) -> Result<String, Rejection> {
match chain_id {
// ethereum
Expand Down Expand Up @@ -123,22 +190,21 @@ async fn run(
transaction: SimulationRequest,
commit: bool,
) -> Result<SimulationResponse, Rejection> {
// Accept value in hex or decimal formats
let value = if let Some(value) = transaction.value {
if value.starts_with("0x") {
Some(Uint::from_str(value.as_str()).map_err(|_err| custom(FromHexError))?)
} else {
Some(Uint::from_dec_str(value.as_str()).map_err(|_err| custom(FromDecStrError))?)
}
} else {
None
};
for (address, state_override) in transaction.state_overrides.into_iter().flatten() {
evm.override_account(
address,
state_override.balance.map(Uint::from),
state_override.nonce,
state_override.code,
state_override.state.map(StorageOverride::from),
)?;
}

let result = if commit {
evm.call_raw_committing(
transaction.from,
transaction.to,
value,
transaction.value.map(Uint::from),
transaction.data,
transaction.gas_limit,
transaction.format_trace.unwrap_or_default(),
Expand All @@ -148,7 +214,7 @@ async fn run(
evm.call_raw(
transaction.from,
transaction.to,
value,
transaction.value.map(Uint::from),
transaction.data,
transaction.format_trace.unwrap_or_default(),
)
Expand Down

0 comments on commit dc55bd4

Please sign in to comment.