Skip to content

Commit

Permalink
Track constructors through the monomorphization procedure
Browse files Browse the repository at this point in the history
  • Loading branch information
JOSHCLUNE committed May 18, 2024
1 parent 29d65a8 commit 3a973bb
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Auto/Solver/SMT.lean
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ def createSolver (name : SolverName) : MetaM SolverProc := do
if auto.smt.dumpHints.get (← getOptions) then
if auto.smt.dumpHints.limitedRws.get (← getOptions) then
createAux "cvc5"
#[s!"--tlimit={tlim * 1000}", "--produce-models",
#[s!"--tlimit={tlim * 1000}", "--produce-models", "--enum-inst",
"--dump-hints", "--proof-granularity=dsl-rewrite", "--hints-only-rw-insts"]
else
createAux "cvc5"
#[s!"--tlimit={tlim * 1000}", "--produce-models",
#[s!"--tlimit={tlim * 1000}", "--produce-models", "--enum-inst",
"--dump-hints", "--proof-granularity=dsl-rewrite"]
else
createAux "cvc5" #[s!"--tlimit={tlim * 1000}", "--produce-models"]
Expand Down
31 changes: 22 additions & 9 deletions Auto/Tactic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ abbrev solverLemmas := List Expr × List Expr × List Expr × List (List Expr)

open Embedding.Lam in
def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndInfo)
: LamReif.ReifM (Option solverLemmas) := do
(idCiMap : HashMap FVarId Monomorphization.ConstInst) : LamReif.ReifM (Option solverLemmas) := do
let lamVarTy := (← LamReif.getVarVal).map Prod.snd
let lamEVarTy ← LamReif.getLamEVarTy
let exportLamTerms ← exportFacts.mapM (fun re => do
Expand All @@ -384,8 +384,20 @@ def querySMTForHints (exportFacts : Array REntry) (exportInds : Array MutualIndI
match varAtom with
| .term termNum =>
let vExp := varVal[termNum]!.1
symbolMap := symbolMap.insert varName vExp
| _ => logWarning s!"varName: {varName} maps to an atom other than term"
match vExp with
| .fvar fVarId =>
-- If `vExp` is an fvar, check whether its fVarId appears in `idCiMap`
-- If it does, have `symbolMap` map `varName` to the original Lean Expr indicated by `idCiMap`
match idCiMap.find? fVarId with
| some ci => symbolMap := symbolMap.insert varName (← ci.toExpr)
| none => symbolMap := symbolMap.insert varName vExp
| _ => symbolMap := symbolMap.insert varName vExp
| .sort _ => logWarning s!"varName: {varName} maps to a sort"
| .etom _ => logWarning s!"varName: {varName} maps to an etom"
| .bvOfNat n => logWarning s!"varName: {varName} maps to bvOfNat {n}"
| .bvToNat n => logWarning s!"varName: {varName} maps to bvToNat {n}"
| .compCtor lamTerm => logWarning s!"varName: {varName} maps to compCtor {lamTerm}"
| .compProj lamTerm => logWarning s!"varName: {varName} maps to compProj {lamTerm}"
if ← auto.getHints.getFailOnParseErrorM then
let preprocessFacts ← preprocessFacts.mapM (fun lemTerm => Parser.SMTTerm.parseTerm lemTerm symbolMap)
let theoryLemmas ← theoryLemmas.mapM (fun lemTerm => Parser.SMTTerm.parseTerm lemTerm symbolMap)
Expand Down Expand Up @@ -534,7 +546,7 @@ def runAutoGetHints (lemmas : Array Lemma) (inhFacts : Array Lemma) : MetaM solv
let decide_simp_lem ← Lemma.ofConst ``Auto.Bool.decide_simp (.leaf "hw Auto.Bool.decide_simp")
let lemmas ← lemmas.mapM (fun lem => Lemma.rewriteUPolyRigid lem decide_simp_lem)
let afterReify (uvalids : Array UMonoFact) (uinhs : Array UMonoFact) (minds : Array (Array SimpleIndVal))
: LamReif.ReifM solverLemmas := (do
(idCiMap : HashMap FVarId Monomorphization.ConstInst) : LamReif.ReifM solverLemmas := (do
let exportFacts ← LamReif.reifFacts uvalids
let mut exportFacts := exportFacts.map (Embedding.Lam.REntry.valid [])
let _ ← LamReif.reifInhabitations uinhs
Expand All @@ -545,14 +557,15 @@ def runAutoGetHints (lemmas : Array Lemma) (inhFacts : Array Lemma) : MetaM solv
-- runAutoGetHints only supports SMT right now
-- **SMT**
if auto.smt.get (← getOptions) then
if let .some lemmas ← querySMTForHints exportFacts exportInds then
if let .some lemmas ← querySMTForHints exportFacts exportInds idCiMap then
return lemmas
throwError "autoGetHints only implemented for cvc5 (enable option auto.smt)"
)
let (lemmas, _) ← Monomorphization.monomorphize lemmas inhFacts (@id (Reif.ReifM solverLemmas) do
let s ← get
let u ← computeMaxLevel s.facts
(afterReify s.facts s.inhTys s.inds).run' {u := u})
let (lemmas, _) ← Monomorphization.monomorphizePreserveMap lemmas inhFacts $ fun (idCiMap, s) => do
let callAfterReify : Reif.ReifM solverLemmas := do
let u ← computeMaxLevel s.facts
(afterReify s.facts s.inhTys s.inds idCiMap).run' {u := u}
callAfterReify.run s
trace[auto.tactic] "Auto found preprocessing and theory lemmas: {lemmas}"
return lemmas

Expand Down
54 changes: 54 additions & 0 deletions Auto/Translation/Monomorphization.lean
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,8 @@ namespace FVarRep
exprMap : HashMap Expr FVarId := {}
ciMap : HashMap Expr ConstInsts
ciIdMap : HashMap ConstInst FVarId := {}
-- Inverse of `ciIdMap`
idCiMap : HashMap FVarId ConstInst := {}
-- Canonicalization map for types
tyCanMap : HashMap Expr Expr := {}

Expand Down Expand Up @@ -620,6 +622,7 @@ namespace FVarRep
processType city
let fvarId ← MetaState.withLetDecl userName city cie .default
setCiIdMap ((← getCiIdMap).insert ci fvarId)
setIdCiMap ((← getIdCiMap).insert fvarId ci)
setFfvars ((← getFfvars).push fvarId)
return fvarId

Expand Down Expand Up @@ -776,4 +779,55 @@ def monomorphize (lemmas : Array Lemma) (inhFacts : Array Lemma) (k : Reif.State
return (s.ffvars, Reif.State.mk s.ffvars uvalids polyVal s.tyCanMap inhs inductiveVals none))
MetaState.runWithIntroducedFVars metaStateMAction k

/-- Like `monomoprhize` but `k` is also passed in `idCiMap` so that `querySMTForHints` can look up the original Lean expressions corresponding
to fvars generated by `FVarRep.replacePolyWithFVar` -/
def monomorphizePreserveMap (lemmas : Array Lemma) (inhFacts : Array Lemma) (k : HashMap FVarId ConstInst × Reif.State → MetaM α) : MetaM α := do
let monoMAction : MonoM (Array (Array SimpleIndVal)) := (do
let startTime ← IO.monoMsNow
initializeMonoM lemmas
saturate
postprocessSaturate
trace[auto.mono] "Monomorphization took {(← IO.monoMsNow) - startTime}ms"
collectMonoMutInds)
let (inductiveVals, monoSt) ← monoMAction.run {}
-- Lemma instances
let lis := monoSt.lisArr.concatMap id
let fvarRepMFactAction : FVarRep.FVarRepM (Array UMonoFact) :=
lis.mapM (fun li => do return ⟨li.proof, ← FVarRep.replacePolyWithFVar li.type, li.deriv⟩)
let fvarRepMInductAction (ivals : Array (Array SimpleIndVal)) : FVarRep.FVarRepM (Array (Array SimpleIndVal)) :=
ivals.mapM (fun svals => svals.mapM (fun ⟨name, type, ctors, projs⟩ => do
FVarRep.processType type
let ctors ← ctors.mapM (fun (val, ty) => do
FVarRep.processType ty
let val' ← FVarRep.replacePolyWithFVar val
return (val', ty))
let projs ← projs.mapM (fun arr => arr.mapM (fun e => do
FVarRep.replacePolyWithFVar e))
return ⟨name, type, ctors, projs⟩))
let metaStateMAction : MetaState.MetaStateM (Array FVarId × HashMap FVarId ConstInst × Reif.State) := (do
let (uvalids, s) ← fvarRepMFactAction.run { ciMap := monoSt.ciMap }
for ⟨proof, ty, _⟩ in uvalids do
trace[auto.mono.printResult] "Monomorphized :: {proof} : {ty}"
let exlis := s.exprMap.toList.map (fun (e, id) => (id, e))
let cilis ← s.ciIdMap.toList.mapM (fun (ci, id) => do return (id, ← MetaState.runMetaM ci.toExpr))
let polyVal := HashMap.ofList (exlis ++ cilis)
let tyCans := s.tyCanMap.toArray.map Prod.snd
-- Inhabited types
let startTime ← IO.monoMsNow
let mut tyCanInhs := #[]
for e in tyCans do
if let .some inh ← MetaState.runMetaM <| Meta.withNewMCtxDepth <| Meta.trySynthInhabited e then
tyCanInhs := tyCanInhs.push ⟨inh, e, .leaf "tyCanInh"
let inhMatches ← MetaState.runMetaM (Inhabitation.inhFactMatchAtomTys inhFacts tyCans)
let inhs := tyCanInhs ++ inhMatches
trace[auto.mono] "Monomorphizing inhabitation facts took {(← IO.monoMsNow) - startTime}ms"
-- Inductive types
let startTime ← IO.monoMsNow
trace[auto.mono] "Monomorphizing inductive types took {(← IO.monoMsNow) - startTime}ms"
let (inductiveVals, s) ← (fvarRepMInductAction inductiveVals).run s
-- For the sake of `querySMTForHints`, we also return `s.idCiMap` which will allow `querySMTForHints` to map
-- generated fVarIds back to the expressions that generated them
return (s.ffvars, s.idCiMap, Reif.State.mk s.ffvars uvalids polyVal s.tyCanMap inhs inductiveVals none))
MetaState.runWithIntroducedFVars metaStateMAction k

end Auto.Monomorphization

0 comments on commit 3a973bb

Please sign in to comment.