Skip to content

Commit

Permalink
Convert MonadState VM to explicit arguments (#1074)
Browse files Browse the repository at this point in the history
  • Loading branch information
arcz authored Jun 14, 2023
1 parent 916d2fc commit d022554
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 98 deletions.
24 changes: 9 additions & 15 deletions lib/Echidna/Campaign.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ import Optics.Core hiding ((|>))
import Control.Concurrent (writeChan)
import Control.DeepSeq (force)
import Control.Monad (replicateM, when, void, forM_)
import Control.Monad.Catch (MonadCatch(..), MonadThrow(..))
import Control.Monad.Catch (MonadThrow(..))
import Control.Monad.Random.Strict (MonadRandom, RandT, evalRandT)
import Control.Monad.Reader (MonadReader, asks, liftIO, ask)
import Control.Monad.State.Strict
(MonadState(..), StateT(..), evalStateT, gets, MonadIO, modify')
(MonadState(..), StateT(..), gets, MonadIO, modify')
import Control.Monad.Trans (lift)
import Control.Monad.Trans.Random.Strict (liftCatch)
import Data.Binary.Get (runGetOrFail)
import Data.ByteString.Lazy qualified as LBS
import Data.IORef (readIORef, writeIORef, atomicModifyIORef')
Expand Down Expand Up @@ -51,8 +50,6 @@ import Echidna.Utility (getTimestamp)

instance MonadThrow m => MonadThrow (RandT g m) where
throwM = lift . throwM
instance MonadCatch m => MonadCatch (RandT g m) where
catch = liftCatch catch

-- | Given a 'Campaign', check if the test results should be reported as a
-- success or a failure.
Expand All @@ -64,7 +61,7 @@ isSuccessful tests =
-- state. Can be used to minimize corpus as the final campaign state will
-- contain minized corpus without sequences that didn't increase the coverage.
replayCorpus
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
:: (MonadIO m, MonadThrow m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
=> VM -- ^ VM to start replaying from
-> [[Tx]] -- ^ corpus to replay
-> m ()
Expand All @@ -77,7 +74,7 @@ replayCorpus vm txSeqs =
-- optional dictionary to generate calls with. Return the 'Campaign' state once
-- we can't solve or shrink anything.
runWorker
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m)
:: (MonadIO m, MonadThrow m, MonadRandom m, MonadReader Env m)
=> StateT WorkerState m ()
-- ^ Callback to run after each state update (for instrumentation)
-> VM -- ^ Initial VM state
Expand Down Expand Up @@ -189,7 +186,7 @@ randseq deployedContracts world = do
-- | Runs a transaction sequence and checks if any test got falsified or can be
-- minimized. Stores any useful data in the campaign state if coverage increased.
callseq
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
:: (MonadIO m, MonadThrow m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
=> VM
-> [Tx]
-> m VM
Expand All @@ -200,10 +197,7 @@ callseq vm txSeq = do
let
conf = env.cfg.campaignConf
coverageEnabled = isJust conf.knownCoverage
execFunc =
if coverageEnabled
then execTxOptC
else \vm' tx -> runStateT (execTx tx) vm'
execFunc = if coverageEnabled then execTxOptC else execTx

-- Run each call sequentially. This gives us the result of each call
-- and the new state
Expand Down Expand Up @@ -327,7 +321,7 @@ updateGasInfo ((t, _):ts) tseq gi = updateGasInfo ts (t:tseq) gi
-- of transactions, constantly checking if we've solved any tests or can shrink
-- known solves.
evalSeq
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
:: (MonadIO m, MonadThrow m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
=> VM -- ^ Initial VM
-> (VM -> Tx -> m (result, VM))
-> [Tx]
Expand Down Expand Up @@ -370,7 +364,7 @@ runUpdate f = do
-- (3): The test is unshrunk, and we can shrink it
-- Then update accordingly, keeping track of how many times we've tried to solve or shrink.
updateTest
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
:: (MonadIO m, MonadThrow m, MonadRandom m, MonadReader Env m, MonadState WorkerState m)
=> VM
-> (VM, [Tx])
-> EchidnaTest
Expand All @@ -379,7 +373,7 @@ updateTest vmForShrink (vm, xs) test = do
dappInfo <- asks (.dapp)
case test.state of
Open -> do
(testValue, vm') <- evalStateT (checkETest test) vm
(testValue, vm') <- checkETest test vm
let
events = extractEvents False dappInfo vm'
results = getResultFromVM vm'
Expand Down
6 changes: 3 additions & 3 deletions lib/Echidna/Deploy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Echidna.Deploy where
import Control.Monad (foldM)
import Control.Monad.Catch (MonadThrow(..), throwM)
import Control.Monad.Reader (MonadReader, asks)
import Control.Monad.State.Strict (execStateT, MonadIO)
import Control.Monad.State.Strict (MonadIO)
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.ByteString.Base16 qualified as BS16 (decode)
Expand Down Expand Up @@ -50,8 +50,8 @@ deployBytecodes'
deployBytecodes' cs src initialVM = foldM deployOne initialVM cs
where
deployOne vm (dst, bytecode) = do
vm' <- flip execStateT vm $
execTx $ createTx (bytecode <> zeros) src dst unlimitedGasPerBlock (0, 0)
(_, vm') <-
execTx vm $ createTx (bytecode <> zeros) src dst unlimitedGasPerBlock (0, 0)
case vm'.result of
Just (VMSuccess _) -> pure vm'
_ -> do
Expand Down
9 changes: 5 additions & 4 deletions lib/Echidna/Exec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,11 @@ logMsg msg = do

-- | Execute a transaction "as normal".
execTx
:: (MonadIO m, MonadState VM m, MonadReader Env m, MonadThrow m)
=> Tx
-> m (VMResult, Gas)
execTx = execTxWith equality' vmExcept $ fromEVM exec
:: (MonadIO m, MonadReader Env m, MonadThrow m)
=> VM
-> Tx
-> m ((VMResult, Gas), VM)
execTx vm tx = runStateT (execTxWith equality' vmExcept (fromEVM exec) tx) vm

-- | A type alias for the context we carry while executing instructions
type CoverageContext = (Bool, Maybe (BS.ByteString, Int))
Expand Down
27 changes: 13 additions & 14 deletions lib/Echidna/Shrink.hs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module Echidna.Shrink (shrinkTest) where

import Control.Monad ((<=<))
import Control.Monad.Catch (MonadThrow, MonadCatch)
import Control.Monad.Catch (MonadThrow)
import Control.Monad.Random.Strict (MonadRandom, getRandomR, uniform)
import Control.Monad.Reader.Class (MonadReader (ask), asks)
import Control.Monad.State.Strict (MonadState(get, put), evalStateT, MonadIO)
import Data.Foldable (traverse_)
import Control.Monad.State.Strict (MonadIO)
import Data.Set qualified as Set
import Data.List qualified as List

Expand All @@ -22,7 +21,7 @@ import Echidna.Types.Campaign (CampaignConf(..))
import Echidna.Test (getResultFromVM, checkETest)

shrinkTest
:: (MonadIO m, MonadCatch m, MonadRandom m, MonadReader Env m)
:: (MonadIO m, MonadThrow m, MonadRandom m, MonadReader Env m)
=> VM
-> EchidnaTest
-> m (Maybe EchidnaTest)
Expand All @@ -33,7 +32,7 @@ shrinkTest vm test = do
pure $ Just test { state = Solved }
Large i ->
if length test.reproducer > 1 || any canShrinkTx test.reproducer then do
maybeShrunk <- evalStateT (shrinkSeq (checkETest test) test.value test.reproducer) vm
maybeShrunk <- shrinkSeq vm (checkETest test) test.value test.reproducer
pure $ case maybeShrunk of
Just (txs, val, vm') -> do
Just test { state = Large (i + 1)
Expand All @@ -53,25 +52,25 @@ shrinkTest vm test = do
-- | Given a call sequence that solves some Echidna test, try to randomly
-- generate a smaller one that still solves that test.
shrinkSeq
:: (MonadIO m, MonadRandom m, MonadReader Env m, MonadThrow m, MonadState VM m)
=> m (TestValue, VM)
:: (MonadIO m, MonadRandom m, MonadReader Env m, MonadThrow m)
=> VM
-> (VM -> m (TestValue, VM))
-> TestValue
-> [Tx]
-> m (Maybe ([Tx], TestValue, VM))
shrinkSeq f v txs = do
shrinkSeq vm f v txs = do
txs' <- uniform =<< sequence [shorten, shrunk]
(value, vm') <- check txs'
(value, vm') <- check txs' vm
-- if the test passed it means we didn't shrink successfully
pure $ case (value,v) of
(BoolValue False, _) -> Just (txs', value, vm')
(IntValue x, IntValue y) | x >= y -> Just (txs', value, vm')
_ -> Nothing
where
check xs' = do
vm <- get
res <- traverse_ execTx xs' >> f
put vm
pure res
check [] vm' = f vm'
check (x:xs') vm' = do
(_, vm'') <- execTx vm' x
check xs' vm''
shrunk = mapM (shrinkSender <=< shrinkTx) txs
shorten = (\i -> take i txs ++ drop (i + 1) txs) <$> getRandomR (0, length txs)

Expand Down
33 changes: 16 additions & 17 deletions lib/Echidna/Solidity.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import Control.Monad (when, unless, forM_)
import Control.Monad.Catch (MonadThrow(..))
import Control.Monad.Extra (whenM)
import Control.Monad.Reader (ReaderT(runReaderT))
import Control.Monad.State.Strict (execStateT)
import Data.Foldable (toList)
import Data.List (find, partition, isSuffixOf, (\\))
import Data.List.NonEmpty (NonEmpty((:|)))
Expand Down Expand Up @@ -266,27 +265,27 @@ loadSpecified env name cs = do
vm2 <- deployBytecodes solConf.deployBytecodes solConf.deployer vm1

-- main contract deployment
let deployment = execTx $ createTxWithValue
mainContract.creationCode
solConf.deployer
solConf.contractAddr
unlimitedGasPerBlock
(fromIntegral solConf.balanceContract)
(0, 0)
vm3 <- execStateT deployment vm2
let deployment = execTx vm2 $ createTxWithValue
mainContract.creationCode
solConf.deployer
solConf.contractAddr
unlimitedGasPerBlock
(fromIntegral solConf.balanceContract)
(0, 0)
(_, vm3) <- deployment
when (isNothing $ currentContract vm3) $
throwM $ DeploymentFailed solConf.contractAddr $ T.unlines $ extractEvents True env.dapp vm3

-- Run
let transaction = execTx $ uncurry basicTx
setUpFunction
solConf.deployer
solConf.contractAddr
unlimitedGasPerBlock
(0, 0)
let transaction = execTx vm3 $ uncurry basicTx
setUpFunction
solConf.deployer
solConf.contractAddr
unlimitedGasPerBlock
(0, 0)
vm4 <- if isDapptestMode solConf.testMode && setUpFunction `elem` abi
then execStateT transaction vm3
else return vm3
then snd <$> transaction
else pure vm3

case vm4.result of
Just (VMFailure _) -> throwM SetUpCallFailed
Expand Down
82 changes: 40 additions & 42 deletions lib/Echidna/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ module Echidna.Test where
import Prelude hiding (Word)

import Control.Monad.Catch (MonadThrow)
import Control.Monad.IO.Class (MonadIO)
import Control.Monad.Reader.Class (MonadReader, asks)
import Control.Monad.State.Strict (MonadState(get, put), gets, MonadIO)
import Data.ByteString qualified as BS
import Data.ByteString.Lazy qualified as LBS
import Data.Text (Text)
Expand Down Expand Up @@ -136,47 +136,45 @@ updateOpenTest _ _ _ = error "Invalid type of test"

-- | Given a 'SolTest', evaluate it and see if it currently passes.
checkETest
:: (MonadIO m, MonadReader Env m, MonadState VM m, MonadThrow m)
:: (MonadIO m, MonadReader Env m, MonadThrow m)
=> EchidnaTest
-> VM
-> m (TestValue, VM)
checkETest test = case test.testType of
Exploration -> (BoolValue True,) <$> get -- These values are never used
PropertyTest n a -> checkProperty n a
OptimizationTest n a -> checkOptimization n a
AssertionTest dt n a -> if dt then checkDapptestAssertion n a
else checkStatefulAssertion n a
CallTest _ f -> checkCall f
checkETest test vm = case test.testType of
Exploration -> pure (BoolValue True, vm) -- These values are never used
PropertyTest n a -> checkProperty vm n a
OptimizationTest n a -> checkOptimization vm n a
AssertionTest dt n a -> if dt then checkDapptestAssertion vm n a
else checkStatefulAssertion vm n a
CallTest _ f -> checkCall vm f

-- | Given a property test, evaluate it and see if it currently passes.
checkProperty
:: (MonadIO m, MonadReader Env m, MonadState VM m, MonadThrow m)
=> Text
:: (MonadIO m, MonadReader Env m, MonadThrow m)
=> VM
-> Text
-> Addr
-> m (TestValue, VM)
checkProperty f a = do
vm <- get
checkProperty vm f a = do
case vm.result of
Just (VMSuccess _) -> do
TestConf{classifier, testSender} <- asks (.cfg.testConf)
(_, vm') <- runTx f testSender a
b <- gets $ classifier f
put vm -- restore EVM state
pure (BoolValue b, vm')
vm' <- runTx vm f testSender a
pure (BoolValue (classifier f vm'), vm')
_ -> pure (BoolValue True, vm) -- These values are never used

runTx
:: (MonadIO m, MonadReader Env m, MonadState VM m, MonadThrow m)
=> Text
:: (MonadIO m, MonadReader Env m, MonadThrow m)
=> VM
-> Text
-> (Addr -> Addr)
-> Addr
-> m (VM, VM)
runTx f s a = do
vm <- get -- save EVM state
-> m VM
runTx vm f s a = do
-- Our test is a regular user-defined test, we exec it and check the result
g <- asks (.cfg.txConf.propGas)
_ <- execTx $ basicTx f [] (s a) a g (0, 0)
vm' <- get
return (vm, vm')
(_, vm') <- execTx vm $ basicTx f [] (s a) a g (0, 0)
pure vm'

--- | Extract a test value from an execution.
getIntFromResult :: Maybe VMResult -> TestValue
Expand All @@ -189,24 +187,24 @@ getIntFromResult _ = IntValue minBound

-- | Given a property test, evaluate it and see if it currently passes.
checkOptimization
:: (MonadIO m, MonadReader Env m, MonadState VM m, MonadThrow m)
=> Text
:: (MonadIO m, MonadReader Env m, MonadThrow m)
=> VM
-> Text
-> Addr
-> m (TestValue, VM)
checkOptimization f a = do
checkOptimization vm f a = do
TestConf _ s <- asks (.cfg.testConf)
(vm, vm') <- runTx f s a
put vm -- restore EVM state
vm' <- runTx vm f s a
pure (getIntFromResult vm'.result, vm')

checkStatefulAssertion
:: (MonadReader Env m, MonadState VM m, MonadThrow m)
=> SolSignature
:: (MonadReader Env m, MonadThrow m)
=> VM
-> SolSignature
-> Addr
-> m (TestValue, VM)
checkStatefulAssertion sig addr = do
checkStatefulAssertion vm sig addr = do
dappInfo <- asks (.dapp)
vm <- get
let
-- Whether the last transaction called the function `sig`.
isCorrectFn =
Expand All @@ -230,12 +228,12 @@ assumeMagicReturnCode :: BS.ByteString
assumeMagicReturnCode = "FOUNDRY::ASSUME\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"

checkDapptestAssertion
:: (MonadReader Env m, MonadState VM m, MonadThrow m)
=> SolSignature
:: (MonadReader Env m, MonadThrow m)
=> VM
-> SolSignature
-> Addr
-> m (TestValue, VM)
checkDapptestAssertion sig addr = do
vm <- get
checkDapptestAssertion vm sig addr = do
let
-- Whether the last transaction has any value
hasValue = vm.state.callvalue /= Lit 0
Expand All @@ -254,12 +252,12 @@ checkDapptestAssertion sig addr = do
pure (BoolValue (not isFailure), vm)

checkCall
:: (MonadReader Env m, MonadState VM m, MonadThrow m)
=> (DappInfo -> VM -> TestValue)
:: (MonadReader Env m, MonadThrow m)
=> VM
-> (DappInfo -> VM -> TestValue)
-> m (TestValue, VM)
checkCall f = do
checkCall vm f = do
dappInfo <- asks (.dapp)
vm <- get
pure (f dappInfo vm, vm)

checkAssertionTest :: DappInfo -> VM -> TestValue
Expand Down
Loading

0 comments on commit d022554

Please sign in to comment.