Skip to content

Commit

Permalink
Update scalar_tac to use the aesop tactic (#282)
Browse files Browse the repository at this point in the history
* Start using aesop in scalar_tac

* Update the Lean dependencies

* Start adding some utilities for Aesop

* Update intTac to use Aesop.saturate

* Make progress on updating scalar_tac to use aesop

* Start updating Lean to v4.10.0-rc1

* Update the tests

* Make progress on scalar_tac

* Update the dependencies

* Update the dependencies

* Update the proofs in the Lean standard library

* Update the dependencies in the Lean tests

* Use scalar_tac patterns in the proof of the hashmap

* Make minor modifications

* Add options and persistent extensions to add more rule sets for scalar_tac
  • Loading branch information
sonmarcho authored Jul 19, 2024
1 parent 595ca8a commit 219d478
Show file tree
Hide file tree
Showing 16 changed files with 232 additions and 339 deletions.
1 change: 1 addition & 0 deletions backends/lean/Base/Arith.lean
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
import Base.Arith.Int
import Base.Arith.Scalar
import Base.Arith.Lemmas
40 changes: 40 additions & 0 deletions backends/lean/Base/Arith/Init.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import Base.Extensions
import Aesop
open Lean

/-!
# Scalar tac rules sets
This module defines several Aesop rule sets and options which are used by the
`scalar_tac` tactic. Aesop rule sets only become visible once the file in which
they're declared is imported, so we must put this declaration into its own file.
-/

namespace Arith

declare_aesop_rule_sets [Aeneas.ScalarTac, Aeneas.ScalarTacNonLin]

#check Lean.Option.register
register_option scalarTac.nonLin : Bool := {
defValue := false
group := ""
descr := "Activate the use of a set of lemmas to reason about non-linear arithmetic by `scalar_tac`"
}

-- The sets of rules that `scalar_tac` should use
open Extensions in
initialize scalarTacRuleSets : ListDeclarationExtension Name ← do
mkListDeclarationExtension `scalarTacRuleSetsList

def scalarTacRuleSets.get : MetaM (List Name) := do
pure (scalarTacRuleSets.getState (← getEnv))

-- Note that the changes are not persistent
def scalarTacRuleSets.set (names : List Name) : MetaM Unit := do
let _ := scalarTacRuleSets.setState (← getEnv) names

-- Note that the changes are not persistent
def scalarTacRuleSets.add (name : Name) : MetaM Unit := do
let _ := scalarTacRuleSets.modifyState (← getEnv) (fun ls => name :: ls)

end Arith
221 changes: 30 additions & 191 deletions backends/lean/Base/Arith/Int.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,28 @@ import Init.Data.List.Basic
import Mathlib.Tactic.Ring.RingNF
import Base.Utils
import Base.Arith.Base
import Base.Arith.Init

namespace Arith

open Utils
open Lean Lean.Elab Lean.Meta
open Lean Lean.Elab Lean.Meta Lean.Elab.Tactic

/- We can introduce a term in the context.
For instance, if we find `x : U32` in the context we can introduce `0 ≤ x ∧ x ≤ U32.max`
/- Defining a custom attribute for Aesop - we use Aesop tactic in the arithmetic tactics -/

Remark: I tried a version of the shape `HasScalarProp {a : Type} (x : a)`
but the lookup didn't work.
-/
class HasIntProp (a : Sort u) where
prop_ty : a → Prop
prop : ∀ x:a, prop_ty x
attribute [aesop (rule_sets := [Aeneas.ScalarTac]) unfold norm] Function.comp

/- Terms that induces predicates: if we can find the term `x`, we can introduce `concl` in the context. -/
class HasIntPred {a: Sort u} (x: a) where
concl : Prop
prop : concl
/-- The `int_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/
macro "int_tac" pat:term : attr =>
`(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTac):ident]) (pattern := $pat))

/- Proposition with implications: if we find P we can introduce Q in the context -/
class PropHasImp (x : Sort u) where
concl : Prop
prop : x → concl
/-- The `scalar_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/
macro "scalar_tac" pat:term : attr =>
`(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTac):ident]) (pattern := $pat))

instance (p : Int → Prop) : HasIntProp (Subtype p) where
prop_ty := λ x => p x
prop := λ x => x.property
/-- The `nonlin_scalar_tac` attribute used to tag forward theorems for the `int_tac` and `scalar_tac` tactics. -/
macro "nonlin_scalar_tac" pat:term : attr =>
`(attr|aesop safe forward (rule_sets := [$(Lean.mkIdent `Aeneas.ScalarTacNonLin):ident]) (pattern := $pat))

/- Check if a proposition is a linear integer proposition.
We notably use this to check the goals: this is useful to filter goals that
Expand Down Expand Up @@ -70,186 +63,32 @@ def goalIsLinearInt : Tactic.TacticM Bool := do
| .some _ => pure true
| _ => pure false

/- Explore a term by decomposing the applications (we explore the applied
functions and their arguments, but ignore lambdas, forall, etc. -
should we go inside?).
Remark: we pretend projections are applications, and explore the projected
terms. -/
partial def foldTermApps (k : α → Expr → MetaM α) (s : α) (e : Expr) : MetaM α := do
-- Explore the current expression
let e := e.consumeMData
let s ← k s e
-- Recurse
match e with
| .proj _ _ e' =>
foldTermApps k s e'
| .app .. =>
e.withApp fun f args => do
let s ← k s f
args.foldlM (foldTermApps k) s
| _ => pure s

/- Provided a function `k` which lookups type class instances on an expression,
collect all the instances lookuped by applying `k` on the sub-expressions of `e`. -/
def collectInstances
(k : Expr → MetaM (Option Expr)) (s : HashSet Expr) (e : Expr) : MetaM (HashSet Expr) := do
let k s e := do
match ← k e with
| none => pure s
| some i => pure (s.insert i)
foldTermApps k s e

/- Similar to `collectInstances`, but explores all the local declarations in the
main context. -/
def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do
Tactic.withMainContext do
-- Get the local context
let ctx ← Lean.MonadLCtx.getLCtx
-- Initialize the hashset
let hs := HashSet.empty
-- Explore the declarations
let decls ← ctx.getDecls
let hs ← decls.foldlM (fun hs d => do
-- Collect instances over all subexpressions in the context.
-- Note that if the local declaration is
-- Note that we explore the *type* of propositions: if we have
-- for instance `h : A ∧ B` in the context, the expression itself is simply
-- `h`; the information we are interested in is its type.
-- However, if the decl is not a proposition, we explore it directly.
-- For instance: `x : U32`
-- TODO: case disjunction on whether the local decl is a Prop or not. If prop,
-- we need to explore its type.
let d := d.toExpr
if d.isProp then
collectInstances k hs d
else
let ty ← Lean.Meta.inferType d
collectInstances k hs ty
) hs
-- Also explore the goal
collectInstances k hs (← Tactic.getMainTarget)

-- Helper
def lookupProp (fName : String) (className : Name) (e : Expr)
(instantiateClassFn : Expr → MetaM (Array Expr))
(instantiateProjectionFn : Expr → MetaM (Array Expr)) : MetaM (Option Expr) := do
trace[Arith] fName
-- TODO: do we need Lean.observing?
-- This actually eliminates the error messages
trace[Arith] m!"{fName}: {e}"
Lean.observing? do
trace[Arith] m!"{fName}: observing: {e}"
let hasProp ← mkAppM className (← instantiateClassFn e)
let hasPropInst ← trySynthInstance hasProp
match hasPropInst with
| LOption.some i =>
trace[Arith] "Found {fName} instance"
let i_prop ← mkProjection i (Name.mkSimple "prop")
some (← mkAppM' i_prop (← instantiateProjectionFn e))
| _ => none

-- Return an instance of `HasIntProp` for `e` if it has some
def lookupHasIntProp (e : Expr) : MetaM (Option Expr) :=
lookupProp "lookupHasIntProp" ``HasIntProp e (fun e => do pure #[← Lean.Meta.inferType e]) (fun e => pure #[e])

-- Collect the instances of `HasIntProp` for the subexpressions in the context
def collectHasIntPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
collectInstancesFromMainCtx lookupHasIntProp

-- Return an instance of `HasIntPred` for `e` if it has some
def lookupHasIntPred (e : Expr) : MetaM (Option Expr) :=
lookupProp "lookupHasIntPred" ``HasIntPred e (fun term => pure #[term]) (fun _ => pure #[])

-- Collect the instances of `HasIntPred` for the subexpressions in the context
def collectHasIntPredInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
collectInstancesFromMainCtx lookupHasIntPred

-- Return an instance of `PropHasImp` for `e` if it has some
def lookupPropHasImp (e : Expr) : MetaM (Option Expr) := do
trace[Arith] m!"lookupPropHasImp: {e}"
-- TODO: do we need Lean.observing?
-- This actually eliminates the error messages
Lean.observing? do
trace[Arith] "lookupPropHasImp: observing: {e}"
let ty ← Lean.Meta.inferType e
trace[Arith] "lookupPropHasImp: ty: {ty}"
let cl ← mkAppM ``PropHasImp #[ty]
let inst ← trySynthInstance cl
match inst with
| LOption.some i =>
trace[Arith] "Found PropHasImp instance"
let i_prop ← mkProjection i (Name.mkSimple "prop")
some (← mkAppM' i_prop #[e])
| _ => none

-- Collect the instances of `PropHasImp` for the subexpressions in the context
def collectPropHasImpInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
collectInstancesFromMainCtx lookupPropHasImp

elab "display_prop_has_imp_instances" : tactic => do
trace[Arith] "Displaying the PropHasImp instances"
let hs ← collectPropHasImpInstancesFromMainCtx
hs.forM fun e => do
trace[Arith] "+ PropHasImp instance: {e}"

example (x y : Int) (_ : x ≠ y) (_ : ¬ x = y) : True := by
display_prop_has_imp_instances
simp

example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by
omega

-- Lookup instances in a context and introduce them with additional declarations.
def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) : Tactic.TacticM (Array Expr) := do
let hs ← collectInstancesFromMainCtx lookup
hs.toArray.mapM fun e => do
let type ← inferType e
let name ← mkFreshAnonPropUserName
-- Add a declaration
let nval ← Utils.addDeclTac name e type (asLet := false)
-- Simplify to unfold the declaration to unfold (i.e., the projector)
Utils.simpAt true {} [] [declToUnfold] [] [] (Location.targets #[mkIdent name] false)
-- Return the new value
pure nval

def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do
trace[Arith] "Introducing the HasIntProp instances"
introInstances ``HasIntProp.prop_ty lookupHasIntProp

-- Lookup the instances of `HasIntProp for all the sub-expressions in the context,
-- and introduce the corresponding assumptions
elab "intro_has_int_prop_instances" : tactic => do
let _ ← introHasIntPropInstances

def introHasIntPredInstances : Tactic.TacticM (Array Expr) := do
trace[Arith] "Introducing the HasIntPred instances"
introInstances ``HasIntPred.concl lookupHasIntPred

elab "intro_has_int_pred_instances" : tactic => do
let _ ← introHasIntPredInstances

def introPropHasImpInstances : Tactic.TacticM (Array Expr) := do
trace[Arith] "Introducing the PropHasImp instances"
introInstances ``PropHasImp.concl lookupPropHasImp

-- Lookup the instances of `PropHasImp for all the sub-expressions in the context,
-- and introduce the corresponding assumptions
elab "intro_prop_has_imp_instances" : tactic => do
let _ ← introPropHasImpInstances

def intTacSimpRocs : List Name := [``Int.reduceNegSucc, ``Int.reduceNeg]

/-- Apply the scalar_tac forward rules -/
def intTacSaturateForward : Tactic.TacticM Unit := do
let options : Aesop.Options := {}
-- Use a forward max depth of 0 to prevent recursively applying forward rules on the assumptions
-- introduced by the forward rules themselves.
let options ← options.toOptions' (some 0)
-- We always use the rule set `Aeneas.ScalarTac`, but also need to add other rule sets locally
-- activated by the user. The `Aeneas.ScalarTacNonLin` rule set has a special treatment as
-- it is activated through an option.
let ruleSets :=
let ruleSets := `Aeneas.ScalarTac :: (← scalarTacRuleSets.get)
if scalarTac.nonLin.get (← getOptions) then `Aeneas.ScalarTacNonLin :: ruleSets
else ruleSets
evalAesopSaturate options ruleSets.toArray

/- Boosting a bit the `omega` tac.
-/
def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
Tactic.withMainContext do
-- Introduce the instances of `HasIntProp`
let _ ← introHasIntPropInstances
-- Introduce the instances of `HasIntPred`
let _ ← introHasIntPredInstances
-- Introduce the instances of `PropHasImp`
let _ ← introPropHasImpInstances
-- Apply the forward rules
intTacSaturateForward
-- Extra preprocessing
extraPreprocess
-- Reduce all the terms in the goal - note that the extra preprocessing step
Expand Down
31 changes: 31 additions & 0 deletions backends/lean/Base/Arith/Lemmas.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import Base.Arith.Int
import Base.Arith.Scalar

@[nonlin_scalar_tac n % m]
theorem Int.emod_of_pos_disj (n m : Int) : m ≤ 0 ∨ (0 ≤ n % m ∧ n % m < m) := by
if h: 0 < m then
right; constructor
. apply Int.emod_nonneg; omega
. apply Int.emod_lt_of_pos; omega
else left; omega

theorem Int.pos_mul_pos_is_pos (n m : Int) (hm : 0 ≤ m) (hn : 0 ≤ n): 0 ≤ m * n := by
have h : (0 : Int) = 0 * 0 := by simp
rw [h]
apply mul_le_mul <;> norm_cast

@[nonlin_scalar_tac m * n]
theorem Int.pos_mul_pos_is_pos_disj (n m : Int) : m < 0 ∨ n < 00 ≤ m * n := by
cases h: (m < 0 : Bool) <;> simp_all
cases h: (n < 0 : Bool) <;> simp_all
right; right; apply pos_mul_pos_is_pos <;> tauto

-- Some tests
section

-- Activate the rule set for non linear arithmetic
set_option scalarTac.nonLin true

example (x y : Int) (h : 0 ≤ x ∧ 0 ≤ y) : 0 ≤ x * y := by scalar_tac

end
10 changes: 6 additions & 4 deletions backends/lean/Base/Arith/Scalar.lean
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ def scalarTac (splitGoalConjs : Bool) : Tactic.TacticM Unit := do
elab "scalar_tac" : tactic =>
scalarTac false

instance (ty : ScalarTy) : HasIntProp (Scalar ty) where
-- prop_ty is inferred
prop := λ x => And.intro x.hmin x.hmax
@[scalar_tac x]
theorem Scalar.bounds {ty : ScalarTy} (x : Scalar ty) :
Scalar.min ty ≤ x.val ∧ x.val ≤ Scalar.max ty :=
And.intro x.hmin x.hmax

example (x _y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
intro_has_int_prop_instances
scalar_tac_preprocess
simp [*]

example (x _y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
Expand All @@ -65,6 +66,7 @@ example : U32.ofInt 1 ≤ U32.max := by

example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) :
U32.ofIntCore x (by constructor <;> scalar_tac) ≤ U32.max := by
scalar_tac_preprocess
scalar_tac

-- Not equal
Expand Down
Loading

0 comments on commit 219d478

Please sign in to comment.