Skip to content

Commit

Permalink
feat: Expr internalization, add facts, congruence theorem cache
Browse files Browse the repository at this point in the history
  • Loading branch information
leodemoura committed Dec 18, 2024
1 parent c1dade0 commit 0504f50
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 29 deletions.
2 changes: 2 additions & 0 deletions src/Lean/Meta/Tactic/Grind.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ namespace Lean

builtin_initialize registerTraceClass `grind
builtin_initialize registerTraceClass `grind.eq
builtin_initialize registerTraceClass `grind.issues
builtin_initialize registerTraceClass `grind.add

end Lean
121 changes: 112 additions & 9 deletions src/Lean/Meta/Tactic/Grind/Core.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def isInterpreted (e : Expr) : MetaM Bool := do
Creates an `ENode` for `e` if one does not already exist.
This method assumes `e` has been hashconsed.
-/
def mkENode (e : Expr) (generation : Nat := 0) : GoalM Unit := do
if (← getENode? e).isSome then return ()
def mkENode (e : Expr) (generation : Nat) : GoalM Unit := do
if (← alreadyInternalized e) then return ()
let ctor := (← isConstructorAppCore? e).isSome
let interpreted ← isInterpreted e
mkENodeCore e interpreted ctor generation
Expand All @@ -40,11 +40,6 @@ def getNext (e : Expr) : GoalM Expr := do
let some n ← getENode? e | return e
return n.next

@[inline] def isSameExpr (a b : Expr) : Bool :=
-- It is safe to use pointer equality because we hashcons all expressions
-- inserted into the E-graph
unsafe ptrEq a b

private def pushNewEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit :=
modify fun s => { s with newEqs := s.newEqs.push { lhs, rhs, proof, isHEq } }

Expand All @@ -54,6 +49,44 @@ private def pushNewEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit :=
@[inline] private def pushNewHEq (lhs rhs proof : Expr) : GoalM Unit :=
pushNewEqCore lhs rhs proof (isHEq := true)

/--
Adds `e` to congruence table.
-/
def addCongrTable (_e : Expr) : GoalM Unit := do
-- TODO
return ()

partial def internalize (e : Expr) (generation : Nat) : GoalM Unit := do
if (← alreadyInternalized e) then return ()
match e with
| .bvar .. => unreachable!
| .sort .. => return ()
| .fvar .. | .letE .. | .lam .. | .forallE .. =>
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
| .lit .. | .const .. =>
mkENode e generation
| .mvar ..
| .mdata ..
| .proj .. =>
trace[grind.issues] "unexpected term during internalization{indentExpr e}"
mkENodeCore e (ctor := false) (interpreted := false) (generation := generation)
| .app .. => e.withApp fun f args => do
let congrThm ← mkHCongrWithArity f args.size
let info ← getFunInfo f
let shouldInternalize (i : Nat) : GoalM Bool := do
if h : i < info.paramInfo.size then
let pinfo := info.paramInfo[i]
if pinfo.binderInfo.isInstImplicit || pinfo.isProp then
return false
return true
for h : i in [: args.size] do
let arg := args[i]
if (← shouldInternalize i) then
unless (← isTypeFormerType arg) do
internalize arg generation
mkENode e generation
addCongrTable e

/--
The fields `target?` and `proof?` in `e`'s `ENode` are encoding a transitivity proof
from `e` to the root of the equivalence class
Expand All @@ -77,7 +110,11 @@ where
private def markAsInconsistent : GoalM Unit :=
modify fun s => { s with inconsistent := true }

def isInconsistent : GoalM Bool :=
return (← get).inconsistent

private partial def addEqStep (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
trace[grind.eq] "{lhs} {if isHEq then "=" else ""} {rhs}"
let some lhsNode ← getENode? lhs | return () -- `lhs` has not been internalized yet
let some rhsNode ← getENode? rhs | return () -- `rhs` has not been internalized yet
if isSameExpr lhsNode.root rhsNode.root then return () -- `lhs` and `rhs` are already in the same equivalence class.
Expand Down Expand Up @@ -136,13 +173,17 @@ where
loop n.next
loop lhs

/-- Ensures collection of equations to be processed is empty. -/
def resetNewEqs : GoalM Unit :=
modify fun s => { s with newEqs := #[] }

partial def addEqCore (lhs rhs proof : Expr) (isHEq : Bool) : GoalM Unit := do
addEqStep lhs rhs proof isHEq
processTodo
where
processTodo : GoalM Unit := do
if (← get).inconsistent then
modify fun s => { s with newEqs := #[] }
if (← isInconsistent) then
resetNewEqs
return ()
let some { lhs, rhs, proof, isHEq } := (← get).newEqs.back? | return ()
addEqStep lhs rhs proof isHEq
Expand All @@ -154,4 +195,66 @@ def addEq (lhs rhs proof : Expr) : GoalM Unit := do
def addHEq (lhs rhs proof : Expr) : GoalM Unit := do
addEqCore lhs rhs proof true

/--
Adds a new `fact` justified by the given proof and using the given generation.
-/
def add (fact : Expr) (proof : Expr) (generation := 0) : GoalM Unit := do
trace[grind.add] "{proof} : {fact}"
if (← isInconsistent) then return ()
resetNewEqs
let_expr Not p := fact
| go fact false
go p true
where
go (p : Expr) (isNeg : Bool) : GoalM Unit := do
trace[grind.add] "isNeg: {isNeg}, {p}"
match_expr p with
| Eq _ lhs rhs => goEq p lhs rhs isNeg false
| HEq _ _ lhs rhs => goEq p lhs rhs isNeg true
| _ =>
internalize p generation
if isNeg then
addEq p (← getFalseExpr) (← mkEqFalse proof)
else
addEq p (← getFalseExpr) (← mkEqTrue proof)

goEq (p : Expr) (lhs rhs : Expr) (isNeg : Bool) (isHEq : Bool) : GoalM Unit := do
if isNeg then
internalize p generation
addEq p (← getFalseExpr) (← mkEqFalse proof)
else
internalize lhs generation
internalize rhs generation
addEqCore lhs rhs proof isHEq

/--
Adds a new hypothesis.
-/
def addHyp (fvarId : FVarId) (generation := 0) : GoalM Unit := do
add (← fvarId.getType) (mkFVar fvarId) generation

/--
Returns expressions in the given expression equivalence class.
-/
partial def getEqc (e : Expr) : GoalM (List Expr) :=
go e e []
where
go (first : Expr) (e : Expr) (acc : List Expr) : GoalM (List Expr) := do
let next ← getNext e
let acc := e :: acc
if isSameExpr e next then
return acc
else
go first next acc

/--
Returns all equivalence classes in the current goal.
-/
partial def getEqcs : GoalM (List (List Expr)) := do
let mut r := []
for (_, node) in (← get).enodes do
if isSameExpr node.root node.self then
r := (← getEqc node.self) :: r
return r

end Lean.Meta.Grind
4 changes: 2 additions & 2 deletions src/Lean/Meta/Tactic/Grind/Preprocessor.lean
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def introNext (goal : Goal) : PreM IntroResult := do
let r ← simp goal p
let p' := r.expr
let p' ← eraseIrrelevantMData p'
let p' ← foldProjs p'
let p' ← canon p'
let p' ← shareCommon p'
let fvarId ← mkFreshFVarId
Expand Down Expand Up @@ -135,8 +136,7 @@ partial def loop (goal : Goal) : PreM Unit := do
else if let some goal ← applyInjection? goal fvarId then
loop goal
else
let clause ← goal.mvarId.withContext do mkInputClause fvarId
loop { goal with clauses := goal.clauses.push clause }
loop (← GoalM.run' goal <| addHyp fvarId)
| .newDepHyp goal =>
loop goal
| .newLocal fvarId goal =>
Expand Down
Loading

0 comments on commit 0504f50

Please sign in to comment.