Skip to content

Commit

Permalink
add refines SAW command to build refinesS terms
Browse files Browse the repository at this point in the history
  • Loading branch information
m-yac committed Apr 25, 2023
1 parent 5df5615 commit 33694b2
Show file tree
Hide file tree
Showing 8 changed files with 274 additions and 39 deletions.
2 changes: 1 addition & 1 deletion cryptol-saw-core/src/Verifier/SAW/TypedTerm.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ data TypedTermType
deriving Show


-- | Convert the 'ttTerm' field of a 'TypedTerm' to a SAW core term
-- | Convert the 'ttType' field of a 'TypedTerm' to a SAW core term
ttTypeAsTerm :: SharedContext -> Env -> TypedTerm -> IO Term
ttTypeAsTerm sc env (TypedTerm (TypedTermSchema schema) _) =
importSchema sc env schema
Expand Down
71 changes: 51 additions & 20 deletions examples/mr_solver/mr_solver_unit_tests.saw
Original file line number Diff line number Diff line change
Expand Up @@ -28,42 +28,57 @@ const1 <- parse_core const1_core;
// const0 <= const0
run_test "const0 |= const0" (mr_solver_query const0 const0) true;
// (using mrsolver tactic)
prove_extcore mrsolver (refines [] const0 const0);
// (testing that "refines [] const0 const0" is actually "const0 <= const0")
let const0_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", const0_core, ") x) ", "((", const0_core, ") x)"];
prove_extcore mrsolver (parse_core const0_refines);
run_test "refines [] const0 const0" (is_convertible (parse_core const0_refines)
(refines [] const0 const0)) true;

// The function test_fun0 = const0
// The function test_fun0 <= const0
test_fun0 <- parse_core_mod "test_funs" "test_fun0";
run_test "const0 |= test_fun0" (mr_solver_query const0 test_fun0) true;
// (using mrsolver tactic)
prove_extcore mrsolver (refines [] const0 test_fun0);
// (testing that "refines [] const0 test_fun0" is actually "const0 <= test_fun0")
let const0_test_fun0_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", const0_core, ") x) ", "(test_fun0 x)"];
prove_extcore mrsolver (parse_core_mod "test_funs" const0_test_fun0_refines);
run_test "refines [] const0 test_fun0" (is_convertible (parse_core_mod "test_funs" const0_test_fun0_refines)
(refines [] const0 test_fun0)) true;

// not const0 <= const1
run_test "const0 |= const1" (mr_solver_query const0 const1) false;
// (using mrsolver tactic - fails as expected)
// let const0_const1_refines =
// str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
// "((", const0_core, ") x) ", "((", const1_core, ") x)"];
// prove_extcore mrsolver (parse_core const0_const1_refines);
// prove_extcore mrsolver (refines [] const0 const1);
// (testing that "refines [] const0 const1" is actually "const0 <= const1")
let const0_const1_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", const0_core, ") x) ", "((", const1_core, ") x)"];
run_test "refines [] const0 const1" (is_convertible (parse_core const0_const1_refines)
(refines [] const0 const1)) true;

// The function test_fun1 = const1
test_fun1 <- parse_core_mod "test_funs" "test_fun1";
run_test "const1 |= test_fun1" (mr_solver_query const1 test_fun1) true;
run_test "const0 |= test_fun1" (mr_solver_query const0 test_fun1) false;
// (using mrsolver tactic)
prove_extcore mrsolver (refines [] const1 test_fun1);
// (testing that "refines [] const1 test_fun1" is actually "const1 <= test_fun1")
let const1_test_fun1_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", const1_core, ") x) ", "(test_fun1 x)"];
prove_extcore mrsolver (parse_core_mod "test_funs" const1_test_fun1_refines);
run_test "refines [] const1 test_fun1" (is_convertible (parse_core_mod "test_funs" const1_test_fun1_refines)
(refines [] const1 test_fun1)) true;
// (using mrsolver tactic - fails as expected)
// let const0_test_fun1_refines =
// str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
// "((", const0_core, ") x) ", "(test_fun1 x)"];
// prove_extcore mrsolver (parse_core_mod "test_funs" const0_test_fun1_refines);
// prove_extcore mrsolver (refines [] const0 test_fun1);
// (testing that "refines [] const0 test_fun1" is actually "const0 <= test_fun1")
let const0_test_fun1_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", const0_core, ") x) ", "(test_fun1 x)"];
run_test "refines [] const0 test_fun1" (is_convertible (parse_core_mod "test_funs" const0_test_fun1_refines)
(refines [] const0 test_fun1)) true;

// ifxEq0 x = If x == 0 then x else 0; should be equal to 0
let ifxEq0_core = "\\ (x:Vec 64 Bool) -> \
Expand All @@ -76,18 +91,25 @@ ifxEq0 <- parse_core ifxEq0_core;
// ifxEq0 <= const0
run_test "ifxEq0 |= const0" (mr_solver_query ifxEq0 const0) true;
// (using mrsolver tactic)
prove_extcore mrsolver (refines [] ifxEq0 const0);
// (testing that "refines [] ifxEq0 const0" is actually "ifxEq0 <= const0")
let ifxEq0_const0_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", ifxEq0_core, ") x) ", "((", const0_core, ") x)"];
prove_extcore mrsolver (parse_core ifxEq0_const0_refines);
run_test "refines [] ifxEq0 const0" (is_convertible (parse_core ifxEq0_const0_refines)
(refines [] ifxEq0 const0)) true;


// not ifxEq0 <= const1
run_test "ifxEq0 |= const1" (mr_solver_query ifxEq0 const1) false;
// (using mrsolver tactic - fails as expected)
// let ifxEq0_const1_refines =
// str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
// "((", ifxEq0_core, ") x) ", "((", const1_core, ") x)"];
// prove_extcore mrsolver (parse_core ifxEq0_const1_refines);
// prove_extcore mrsolver (refines [] ifxEq0 const1);
// (testing that "refines [] ifxEq0 const1" is actually "ifxEq0 <= const1")
let ifxEq0_const1_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", ifxEq0_core, ") x) ", "((", const1_core, ") x)"];
run_test "refines [] ifxEq0 const1" (is_convertible (parse_core ifxEq0_const1_refines)
(refines [] ifxEq0 const1)) true;

// noErrors1 x = existsS x. retS x
let noErrors1_core =
Expand All @@ -97,18 +119,24 @@ noErrors1 <- parse_core noErrors1_core;
// const0 <= noErrors
run_test "noErrors1 |= noErrors1" (mr_solver_query noErrors1 noErrors1) true;
// (using mrsolver tactic)
prove_extcore mrsolver (refines [] noErrors1 noErrors1);
// (testing that "refines [] noErrors1 noErrors1" is actually "noErrors1 <= noErrors1")
let noErrors1_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", noErrors1_core, ") x) ", "((", noErrors1_core, ") x)"];
prove_extcore mrsolver (parse_core noErrors1_refines);
run_test "refines [] noErrors1 noErrors1" (is_convertible (parse_core noErrors1_refines)
(refines [] noErrors1 noErrors1)) true;

// const1 <= noErrors
run_test "const1 |= noErrors1" (mr_solver_query const1 noErrors1) true;
// (using mrsolver tactic)
prove_extcore mrsolver (refines [] const1 noErrors1);
// (testing that "refines [] const1 noErrors1" is actually "const1 <= noErrors1")
let const1_noErrors1_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", const1_core, ") x) ", "((", noErrors1_core, ") x)"];
prove_extcore mrsolver (parse_core const1_noErrors1_refines);
run_test "refines [] const1 noErrors1" (is_convertible (parse_core const1_noErrors1_refines)
(refines [] const1 noErrors1)) true;

// noErrorsRec1 _ = orS (existsM x. returnM x) (noErrorsRec1 x)
// Intuitively, this specifies functions that either return a value or loop
Expand Down Expand Up @@ -137,7 +165,10 @@ loop1 <- parse_core loop1_core;
// loop1 <= noErrorsRec1
run_test "loop1 |= noErrorsRec1" (mr_solver_query loop1 noErrorsRec1) true;
// (using mrsolver tactic)
prove_extcore mrsolver (refines [] loop1 noErrorsRec1);
// (testing that "refines [] loop1 noErrorsRec1" is actually "loop1 <= noErrorsRec1")
let loop1_noErrorsRec1_refines =
str_concats ["(x:Vec 64 Bool) -> refinesS_eq VoidEv emptyFunStack (Vec 64 Bool) ",
"((", loop1_core, ") x) ", "((", noErrorsRec1_core, ") x)"];
prove_extcore mrsolver (parse_core loop1_noErrorsRec1_refines);
run_test "refines [] loop1 noErrorsRec1" (is_convertible (parse_core loop1_noErrorsRec1_refines)
(refines [] loop1 noErrorsRec1)) true;
51 changes: 47 additions & 4 deletions src/SAWScript/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2209,12 +2209,26 @@ mrSolverTactic sc = execTactic $ Tactic $ \goal -> lift $ do
case sequentState (goalSequent goal) of
Unfocused -> fail "mrsolver: focus required"
HypFocus _ _ -> fail "mrsolver: cannot apply mrsolver in a hypothesis"
ConclFocus (asPiList . unProp -> (args, asApplyAll ->
(asGlobalDef -> Just "Prelude.refinesS",
[ev1, ev2, stack1, stack2,
asApplyAll -> (asGlobalDef -> Just "Prelude.eqPreRel", _),
asApplyAll -> (asGlobalDef -> Just "Prelude.eqPostRel", _),
rtp1, rtp2,
asApplyAll -> (asGlobalDef -> Just "Prelude.eqRR", _),
t1, t2]))) _ ->
on_refinesS dlvl goal args ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2
ConclFocus (asPiList . unProp -> (args, asApplyAll ->
(asGlobalDef -> Just "Prelude.refinesS_eq",
[ev, stack, rtp, t1, t2]))) _ ->
do tp <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev, stack, rtp]
let tt1 = TypedTerm (TypedTermOther tp) t1
let tt2 = TypedTerm (TypedTermOther tp) t2
on_refinesS dlvl goal args ev ev stack stack rtp rtp t1 t2
_ -> error "[MRSolver] cannot apply mrsolver tactic to a refinesS goal with non-trivial RPre/RPost/RR"
where
on_refinesS dlvl goal args ev1 ev2 stack1 stack2 rtp1 rtp2 t1 t2 =
do tp1 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev1, stack1, rtp1]
tp2 <- liftIO $ scGlobalApply sc "Prelude.SpecM" [ev2, stack2, rtp2]
let tt1 = TypedTerm (TypedTermOther tp1) t1
let tt2 = TypedTerm (TypedTermOther tp2) t2
(diff, res) <- mrSolver Prover.askMRSolver (Just "mrsolver") sc args tt1 tt2
case res of
Left err | dlvl == 0 ->
Expand All @@ -2231,7 +2245,6 @@ mrSolverTactic sc = execTactic $ Tactic $ \goal -> lift $ do
printOutLnTop Info (printf "[MRSolver] Success in %s" (show diff)) >>
let stats = solverStats "MRSOLVER ADMITTED" (sequentSharedSize (goalSequent goal)) in
return ((), stats, [], leafEvidence MrSolverEvidence)
_ -> error "mrsolver tactic not applied to a refinesS_eq goal"

-- | Run Mr Solver to prove that the first term refines the second, adding
-- any relevant 'Prover.FunAssump's to the 'Prover.MREnv' if the first argument
Expand Down Expand Up @@ -2318,6 +2331,36 @@ mrSolverSetDebug dlvl =
modify (\rw -> rw { rwMRSolverEnv =
Prover.mrEnvSetDebugLevel dlvl (rwMRSolverEnv rw) })

-- | Given a list of names and types representing variables over which to
-- quantify as as well as two terms containing those variables, which may be
-- terms or functions in the SpecM monad, construct the SAWCore term which is
-- the refinement (@Prelude.refinesS@) of the given terms, with the given
-- variables generalized with a Pi type.
refinesTerm :: [(Text, C.Schema)] -> TypedTerm -> TypedTerm -> TopLevel TypedTerm
refinesTerm args tt1 tt2 =
do dlvl <- Prover.mreDebugLevel <$> rwMRSolverEnv <$> get
sc <- getSharedContext
env <- rwMRSolverEnv <$> get
args' <- io $ mapM (mapM (argType sc)) args
m1 <- ttTerm <$> ensureMonadicTerm sc tt1
m2 <- ttTerm <$> ensureMonadicTerm sc tt2
res <- io $ Prover.refinementTerm sc env Nothing args' m1 m2
case res of
Left err | dlvl == 0 ->
io (putStrLn $ Prover.showMRFailure err) >>
printOutLnTop Info (printf "[MRSolver] Failed to build refinement term") >>
io (Exit.exitWith $ Exit.ExitFailure 1)
Left err ->
-- we ignore the MRFailure context here since it will have already
-- been printed by the debug trace
io (putStrLn $ Prover.showMRFailureNoCtx err) >>
printOutLnTop Info (printf "[MRSolver] Failed to build refinement term") >>
io (Exit.exitWith $ Exit.ExitFailure 1)
Right t ->
io (mkTypedTerm sc t)
where argType sc (C.Forall [] [] a) = Cryptol.importType sc Cryptol.emptyEnv a
argType _ _ = fail "refinesTerm: given a non-monomorphic type"

setMonadification :: SharedContext -> String -> String -> Bool -> TopLevel ()
setMonadification sc cry_str saw_str poly_p =
do rw <- get
Expand Down
14 changes: 14 additions & 0 deletions src/SAWScript/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3861,6 +3861,15 @@ primitives = Map.fromList
[ "Use MRSolver to prove a current goal of the form:"
, "(a1:A1) -> ... -> (an:A1) -> refinesS_eq ..." ]

, prim "refines" "[(String, Type)] -> Term -> Term -> Term"
(funVal3 refinesTerm)
Experimental
[ "Given a list of names and types representing variables over which"
, " to quantify as as well as two terms containing those variables,"
, " which may be terms or functions in the SpecM monad, construct the"
, " SAWCore term which is the refinement (`Prelude.refinesS`) of the"
, " given terms, with the given variables generalized with a Pi type." ]

---------------------------------------------------------------------

, prim "monadify_term" "Term -> TopLevel Term"
Expand Down Expand Up @@ -4315,6 +4324,11 @@ primitives = Map.fromList
funVal2 f _ _ = VLambda $ \a -> return $ VLambda $ \b ->
fmap toValue (f (fromValue a) (fromValue b))

funVal3 :: forall a b c t. (FromValue a, FromValue b, FromValue c, IsValue t) => (a -> b -> c -> TopLevel t)
-> Options -> BuiltinContext -> Value
funVal3 f _ _ = VLambda $ \a -> return $ VLambda $ \b -> return $ VLambda $ \c ->
fmap toValue (f (fromValue a) (fromValue b) (fromValue c))

scVal :: forall t. IsValue t =>
(SharedContext -> t) -> Options -> BuiltinContext -> Value
scVal f _ bic = toValue (f (biSharedContext bic))
Expand Down
2 changes: 1 addition & 1 deletion src/SAWScript/Prover/MRSolver.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Portability : non-portable (language extensions)
-}

module SAWScript.Prover.MRSolver
(askMRSolver, assumeMRSolver, MRSolverResult,
(askMRSolver, assumeMRSolver, MRSolverResult, refinementTerm,
MRFailure(..), showMRFailure, showMRFailureNoCtx,
FunAssump(..), FunAssumpRHS(..),
MREnv(..), emptyMREnv, mrEnvAddFunAssump, mrEnvSetDebugLevel,
Expand Down
77 changes: 64 additions & 13 deletions src/SAWScript/Prover/MRSolver/Monad.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ monadic combinators for operating on terms.

module SAWScript.Prover.MRSolver.Monad where

import Data.Maybe (fromJust)
import Data.List (find, findIndex, foldl')
import Data.Foldable (foldrM)
import qualified Data.Text as T
import System.IO (hPutStrLn, stderr)
import Control.Monad.Reader
import Control.Monad.State
import Control.Monad.Except
import Control.Monad.Trans.Maybe
import GHC.Generics

import Data.Map (Map)
import qualified Data.Map as Map
Expand Down Expand Up @@ -72,6 +73,7 @@ data MRFailure
| CannotLookupFunDef FunName
| RecursiveUnfold FunName
| MalformedLetRecTypes Term
| MalformedDataTypeAssump Term
| MalformedDefs Term
| MalformedComp Term
| NotCompFunType Term
Expand Down Expand Up @@ -151,6 +153,9 @@ instance PrettyInCtx MRFailure where
ppWithPrefix "Recursive unfolding of function inside its own body:" nm
prettyInCtx (MalformedLetRecTypes t) =
ppWithPrefix "Not a ground LetRecTypes list:" t
prettyInCtx (MalformedDataTypeAssump t) =
ppWithPrefix ("assertS/assumeS expects a Bool, Either, or TCNum equality"
++ " with a constructor on one side, got:") t
prettyInCtx (MalformedDefs t) =
ppWithPrefix "Cannot handle multiFixS recursive definitions term:" t
prettyInCtx (MalformedComp t) =
Expand Down Expand Up @@ -285,18 +290,6 @@ instance PrettyInCtx CoIndHyp where
return "|=",
prettyTermApp (funNameTerm f2) args2]

-- | An assumption that something is equal to one of the constructors of a
-- datatype, e.g. equal to @Left@ of some 'Term' or @Right@ of some 'Term'
data DataTypeAssump
= IsLeft Term | IsRight Term | IsNum Term | IsInf
deriving (Generic, Show, TermLike)

instance PrettyInCtx DataTypeAssump where
prettyInCtx (IsLeft x) = prettyInCtx x >>= ppWithPrefix "Left _ _"
prettyInCtx (IsRight x) = prettyInCtx x >>= ppWithPrefix "Right _ _"
prettyInCtx (IsNum x) = prettyInCtx x >>= ppWithPrefix "TCNum"
prettyInCtx IsInf = return "TCInf"

-- | A map from 'Term's to 'DataTypeAssump's over that term
type DataTypeAssumps = HashMap Term DataTypeAssump

Expand Down Expand Up @@ -825,6 +818,64 @@ mrCallsFun f = memoFixTermFun $ \recurse t -> case t of
(unwrapTermF -> tf) ->
foldM (\b t' -> if b then return b else recurse t') False tf

-- | Given a 'DataTypeAssump' and a 'Term' to which it applies, return the
-- equality representing the proposition that the 'DataTypeAssump' holds.
-- For example, @mrDataTypeAssumpTerm x (IsLeft y)@ for @x : Either a b@
-- would return @Eq (Either a b) x (Left a b y)@.
mrDataTypeAssumpTerm :: Term -> DataTypeAssump -> MRM Term
mrDataTypeAssumpTerm x dt =
do tp <- mrTypeOf x
y <- case dt of
IsLeft y
| Just (primName -> "Prelude.Either", [a, b]) <- asDataType tp ->
liftSC2 scCtorApp "Prelude.Left" [a, b, y]
| otherwise -> error $ "IsLeft expected Either, got: " ++ show tp
IsRight y
| Just (primName -> "Prelude.Either", [a, b]) <- asDataType tp ->
liftSC2 scCtorApp "Prelude.Right" [a, b, y]
| otherwise -> error $ "IsRight expected Either, got: " ++ show tp
IsNum y -> liftSC2 scCtorApp "Prelude.TCNum" [y]
IsInf -> liftSC2 scCtorApp "Prelude.TCInf" []
liftSC2 scGlobalApply "Prelude.Eq" [tp, x, y]

-- | Return the 'Term' which is the refinement (@Prelude.refinesS@) of the
-- given 'Term's, after quantifying over all current 'mrUVars' with Pi types
-- and adding calls to @assertS@ on the right hand side for any current
-- 'mrAssumps' and/or 'mrDataTypeAssump's
mrRefinementGoal :: Term -> Term -> MRM Term
mrRefinementGoal t1 t2 =
do (SpecMParams ev1 stack1, tp1) <- fromJust . asSpecM <$> mrTypeOf t1
(SpecMParams ev2 stack2, tp2) <- fromJust . asSpecM <$> mrTypeOf t2
assumps <- mrAssumptions
assumpsAssert <- liftSC2 scGlobalApply "Prelude.assertBoolS"
[ev2, stack2, assumps]
t2' <- case asBool assumps of
Just True -> return t2
_ -> bindConst ev2 stack2 tp2 assumpsAssert t2
dtAssumps <- HashMap.toList <$> mrDataTypeAssumps
dtAssumpAsserts <- forM dtAssumps $ \(nm, assump) ->
do assump_tm <- mrDataTypeAssumpTerm nm assump
liftSC2 scGlobalApply "Prelude.assertS"
[ev2, stack2, assump_tm]
t2'' <- foldrM (bindConst ev2 stack2 tp2) t2' dtAssumpAsserts
coIndHyps <- mrCoIndHyps
(rpre, rpost, rr) <-
if Map.null coIndHyps
then (,,) <$> liftSC2 scGlobalApply "Prelude.eqPreRel" [ev2, stack2]
<*> liftSC2 scGlobalApply "Prelude.eqPostRel" [ev2, stack2]
<*> liftSC2 scGlobalApply "Prelude.eqRR" [tp2]
else error "FIXME: Handle CoIndHyps in mrRefinementGoal"
ref_tm <- liftSC2 scGlobalApply "Prelude.refinesS"
[ev1, ev2, stack1, stack2, rpre, rpost,
tp1, tp2, rr, t1, t2'']
uvars <- mrUVarsOuterToInner
liftSC2 scPiList uvars ref_tm
where bindConst ev stack tp x y =
do unit <- liftSC0 scUnitType
const_y <- liftSC3 incVars 0 1 y >>= liftSC3 scLambda "_" unit
liftSC2 scGlobalApply "Prelude.bindS"
[ev, stack, unit, tp, x, const_y]


----------------------------------------------------------------------
-- * Monadic Operations on Mr. Solver State
Expand Down
Loading

0 comments on commit 33694b2

Please sign in to comment.