Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: isDefEq performance issue #3807

Merged
merged 3 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 58 additions & 26 deletions src/Lean/Meta/ExprDefEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -1690,9 +1690,9 @@ private def isDefEqOnFailure (t s : Expr) : MetaM Bool := do
tryUnificationHints t s <||> tryUnificationHints s t

private def isDefEqProj : Expr → Expr → MetaM Bool
| Expr.proj m i t, Expr.proj n j s => pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s
| Expr.proj structName 0 s, v => isDefEqSingleton structName s v
| v, Expr.proj structName 0 s => isDefEqSingleton structName s v
| .proj m i t, .proj n j s => pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s
| .proj structName 0 s, v => isDefEqSingleton structName s v
| v, .proj structName 0 s => isDefEqSingleton structName s v
| _, _ => pure false
where
/-- If `structName` is a structure with a single field and `(?m ...).1 =?= v`, then solve constraint as `?m ... =?= ⟨v⟩` -/
Expand Down Expand Up @@ -1779,25 +1779,30 @@ private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do
whenUndefDo (isDefEqEta t s) do
whenUndefDo (isDefEqEta s t) do
if (← isDefEqProj t s) then return true
whenUndefDo (isDefEqNative t s) do
whenUndefDo (isDefEqNat t s) do
whenUndefDo (isDefEqOffset t s) do
whenUndefDo (isDefEqDelta t s) do
-- We try structure eta *after* lazy delta reduction;
-- otherwise we would end up applying it at every step of a reduction chain
-- as soon as one of the sides is a constructor application,
-- which is very costly because it requires us to unify the fields.
if (← (isDefEqEtaStruct t s <||> isDefEqEtaStruct s t)) then
return true
if t.isConst && s.isConst then
if t.constName! == s.constName! then isListLevelDefEqAux t.constLevels! s.constLevels! else return false
else if (← pure t.isApp <&&> pure s.isApp <&&> isDefEqApp t s) then
return true
let t' ← whnfCore t
let s' ← whnfCore s
if t != t' || s != s' then
Meta.isExprDefEqAux t' s'
else
whenUndefDo (isDefEqProjInst t s) do
whenUndefDo (isDefEqStringLit t s) do
if (← isDefEqUnitLike t s) then return true else
isDefEqOnFailure t s
whenUndefDo (isDefEqNative t s) do
whenUndefDo (isDefEqNat t s) do
whenUndefDo (isDefEqOffset t s) do
whenUndefDo (isDefEqDelta t s) do
-- We try structure eta *after* lazy delta reduction;
-- otherwise we would end up applying it at every step of a reduction chain
-- as soon as one of the sides is a constructor application,
-- which is very costly because it requires us to unify the fields.
if (← (isDefEqEtaStruct t s <||> isDefEqEtaStruct s t)) then
return true
if t.isConst && s.isConst then
if t.constName! == s.constName! then isListLevelDefEqAux t.constLevels! s.constLevels! else return false
else if (← pure t.isApp <&&> pure s.isApp <&&> isDefEqApp t s) then
return true
else
whenUndefDo (isDefEqProjInst t s) do
whenUndefDo (isDefEqStringLit t s) do
if (← isDefEqUnitLike t s) then return true else
isDefEqOnFailure t s

inductive DefEqCacheKind where
| transient -- problem has mvars or is using nonstandard configuration, we should use transient cache
Expand Down Expand Up @@ -1863,14 +1868,41 @@ partial def isExprDefEqAuxImpl (t : Expr) (s : Expr) : MetaM Bool := withIncRecD
whenUndefDo (isDefEqProofIrrel t s) do
/-
We also reduce projections here to prevent expensive defeq checks when unifying TC operations.
When unifying e.g. `@Neg.neg α (@Field.toNeg α inst1) =?= @Neg.neg α (@Field.toNeg α inst2)`,
When unifying e.g. `(@Field.toNeg α inst1).1 =?= (@Field.toNeg α inst2).1`,
we only want to unify negation (and not all other field operations as well).
Unifying the field instances slowed down unification: https://github.com/leanprover/lean4/issues/1986
We used to *not* reduce projections here, to support unifying `(?a).1 =?= (x, y).1`.
NOTE: this still seems to work because we don't eagerly unfold projection definitions to primitive projections.

Note that ew use `proj := .yesWithDeltaI` to ensure `whnfI` is used to reduce the projection structure.
We added this refinement to address a performance issue in code such as
```
let val : Test := bar c1 key
have : val.1 = (bar c1 key).1 := rfl
```
where `bar` is a complex function that takes a long time to be reduced.

Note that the current solution times out at unification problems such as
`(f x).1 =?= (g x).1` where `f`, `g` are defined as
```
structure Foo where
x : Nat
y : Nat

def f (x : Nat) : Foo :=
{ x, y := ack 10 10 }

def g (x : Nat) : Foo :=
{ x, y := ack 10 11 }
```
and `ack` is ackermann. We claim this is an abuse of the unifier.
That being said, we could in principle address this issue by implementing
lazy-delta reduction at `isDefEqProj`.

The current solution should be sufficient. In the past, we have used
`whnfCore t (config := { proj := .yes })` which more conservative than `.yesWithDeltaI`,
and it only created performance issues when handling TC unification problems.
-/
let t' ← whnfCore t
let s' ← whnfCore s
let t' ← whnfCore t (config := { proj := .yesWithDeltaI })
let s' ← whnfCore s (config := { proj := .yesWithDeltaI })
if t != t' || s != s' then
isExprDefEqAuxImpl t' s'
else
Expand Down
30 changes: 19 additions & 11 deletions src/Lean/Meta/WHNF.lean
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,16 @@ inductive ProjReductionKind where
Recall that `whnfCore` does not perform `delta` reduction (i.e., it will not unfold constant declarations), but `whnf` does.
-/
| yesWithDelta
/--
Projections `s.i` are reduced at `whnfCore`, and `whnfI` is used at `s` during the process.
Recall that `whnfI` is like `whnf` but uses transparency `instances`.
This option is stronger than `yes`, but weaker than `yesWithDelta`.
We use this option to ensure we reduce projections to prevent expensive defeq checks when unifying TC operations.
When unifying e.g. `(@Field.toNeg α inst1).1 =?= (@Field.toNeg α inst2).1`,
we only want to unify negation (and not all other field operations as well).
Unifying the field instances slowed down unification: https://github.com/leanprover/lean4/issues/1986
-/
| yesWithDeltaI
deriving DecidableEq, Inhabited, Repr

/--
Expand Down Expand Up @@ -566,12 +576,6 @@ private def whnfDelayedAssigned? (f' : Expr) (e : Expr) : MetaM (Option Expr) :=
/--
Apply beta-reduction, zeta-reduction (i.e., unfold let local-decls), iota-reduction,
expand let-expressions, expand assigned meta-variables.

The parameter `deltaAtProj` controls how to reduce projections `s.i`. If `deltaAtProj == true`,
then delta reduction is used to reduce `s` (i.e., `whnf` is used), otherwise `whnfCore`.

If `simpleReduceOnly`, then `iota` and projection reduction are not performed.
Note that the value of `deltaAtProj` is irrelevant if `simpleReduceOnly = true`.
-/
partial def whnfCore (e : Expr) (config : WhnfCoreConfig := {}): MetaM Expr :=
go e
Expand Down Expand Up @@ -613,11 +617,15 @@ where
return e
| _ => return e
| .proj _ i c =>
if config.proj == .no then return e
let c ← if config.proj == .yesWithDelta then whnf c else go c
match (← projectCore? c i) with
| some e => go e
| none => return e
let k (c : Expr) := do
match (← projectCore? c i) with
| some e => go e
| none => return e
match config.proj with
| .no => return e
| .yes => k (← go c)
| .yesWithDelta => k (← whnf c)
| .yesWithDeltaI => k (← whnfI c)
| _ => unreachable!

/--
Expand Down
60 changes: 48 additions & 12 deletions src/kernel/type_checker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,16 +376,8 @@ expr type_checker::whnf_fvar(expr const & e, bool cheap_rec, bool cheap_proj) {
return e;
}

/* If `cheap == true`, then we don't perform delta-reduction when reducing major premise. */
optional<expr> type_checker::reduce_proj(expr const & e, bool cheap_rec, bool cheap_proj) {
if (!proj_idx(e).is_small())
return none_expr();
unsigned idx = proj_idx(e).get_small_value();
expr c;
if (cheap_proj)
c = whnf_core(proj_expr(e), cheap_rec, cheap_proj);
else
c = whnf(proj_expr(e));
/* Auxiliary method for `reduce_proj` */
optional<expr> type_checker::reduce_proj_core(expr c, unsigned idx) {
if (is_string_lit(c))
c = string_lit_to_constructor(c);
buffer<expr> args;
Expand All @@ -402,6 +394,19 @@ optional<expr> type_checker::reduce_proj(expr const & e, bool cheap_rec, bool ch
return none_expr();
}

/* If `cheap == true`, then we don't perform delta-reduction when reducing major premise. */
optional<expr> type_checker::reduce_proj(expr const & e, bool cheap_rec, bool cheap_proj) {
if (!proj_idx(e).is_small())
return none_expr();
unsigned idx = proj_idx(e).get_small_value();
expr c;
if (cheap_proj)
c = whnf_core(proj_expr(e), cheap_rec, cheap_proj);
else
c = whnf(proj_expr(e));
return reduce_proj_core(c, idx);
}

static bool is_let_fvar(local_ctx const & lctx, expr const & e) {
lean_assert(is_fvar(e));
if (optional<local_decl> decl = lctx.find_local_decl(e)) {
Expand Down Expand Up @@ -983,6 +988,33 @@ lbool type_checker::lazy_delta_reduction(expr & t_n, expr & s_n) {
}
}

/*
Auxiliary method for checking `t_n.idx =?= s_n.idx`.
It lazily unfolds `t_n` and `s_n`.
Recall that the simpler approach used at `Meta.ExprDefEq` cannot be used in the
kernel since it does not have access to reducibility annotations.
The approach used here is more complicated, but it is also more powerful.
*/
bool type_checker::lazy_delta_proj_reduction(expr & t_n, expr & s_n, nat const & idx) {
while (true) {
switch (lazy_delta_reduction_step(t_n, s_n)) {
case reduction_status::Continue: break;
case reduction_status::DefEqual: return true;
case reduction_status::DefUnknown:
case reduction_status::DefDiff:
if (idx.is_small()) {
unsigned i = idx.get_small_value();
if (auto t = reduce_proj_core(t_n, i)) {
if (auto s = reduce_proj_core(s_n, i)) {
return is_def_eq_core(*t, *s);
}}
}
return is_def_eq_core(t_n, s_n);
}
}
}


static expr * g_string_mk = nullptr;

lbool type_checker::try_string_lit_expansion_core(expr const & t, expr const & s) {
Expand Down Expand Up @@ -1054,8 +1086,12 @@ bool type_checker::is_def_eq_core(expr const & t, expr const & s) {
if (is_fvar(t_n) && is_fvar(s_n) && fvar_name(t_n) == fvar_name(s_n))
return true;

if (is_proj(t_n) && is_proj(s_n) && proj_idx(t_n) == proj_idx(s_n) && is_def_eq(proj_expr(t_n), proj_expr(s_n)))
return true;
if (is_proj(t_n) && is_proj(s_n) && proj_idx(t_n) == proj_idx(s_n)) {
expr t_c = proj_expr(t_n);
expr s_c = proj_expr(s_n);
if (lazy_delta_proj_reduction(t_c, s_c, proj_idx(t_n)))
return true;
}

// Invoke `whnf_core` again, but now using `whnf` to reduce projections.
expr t_n_n = whnf_core(t_n);
Expand Down
2 changes: 2 additions & 0 deletions src/kernel/type_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class type_checker {

enum class reduction_status { Continue, DefUnknown, DefEqual, DefDiff };
optional<expr> reduce_recursor(expr const & e, bool cheap_rec, bool cheap_proj);
optional<expr> reduce_proj_core(expr c, unsigned idx);
optional<expr> reduce_proj(expr const & e, bool cheap_rec, bool cheap_proj);
expr whnf_fvar(expr const & e, bool cheap_rec, bool cheap_proj);
optional<constant_info> is_delta(expr const & e) const;
Expand Down Expand Up @@ -91,6 +92,7 @@ class type_checker {
void cache_failure(expr const & t, expr const & s);
reduction_status lazy_delta_reduction_step(expr & t_n, expr & s_n);
lbool lazy_delta_reduction(expr & t_n, expr & s_n);
bool lazy_delta_proj_reduction(expr & t_n, expr & s_n, nat const & idx);
bool is_def_eq_core(expr const & t, expr const & s);
/** \brief Like \c check, but ignores undefined universes */
expr check_ignore_undefined_universes(expr const & e);
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/1986.lean
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ instance Pi.completeDistribLattice' {ι : Type _} {π : ι → Type _}
[∀ i, CompleteDistribLattice (π i)] : CompleteDistribLattice (∀ i, π i) :=
CompleteDistribLattice.mk (Pi.coframe.infᵢ_sup_le_sup_infₛ)

-- takes around 2 seconds wall clock time on my PC (but very quick in Lean 3)
-- User: takes around 2 seconds wall clock time on my PC (but very quick in Lean 3)
set_option maxHeartbeats 400 -- make sure it stays fast
set_option synthInstance.maxHeartbeats 400
instance Pi.completeDistribLattice'' {ι : Type _} {π : ι → Type _}
Expand Down
2 changes: 1 addition & 1 deletion tests/lean/run/bv_math_lit_perf.lean
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def f (x : BitVec 32) : Nat :=
| 920#32 => 12
| _ => 1000

set_option maxHeartbeats 2800
set_option maxHeartbeats 3000
example : f 500#32 = x := by
simp [f]
sorry
85 changes: 85 additions & 0 deletions tests/lean/run/isDefEqProjIssue.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import Lean
open Lean

-- We need a structure as this is related to isDefEq problems of the form `e1.proj =?= e2.proj`.
structure Test where
x : Nat

-- We need a data structure with functions that are not meant for reduction purposes.
abbrev Cache := HashMap Nat Test

def Cache.insert (cache : Cache) (key : Nat) (val : Test) : Cache :=
HashMap.insert cache key val

def Cache.find? (cache : Cache) (key : Nat) : Option Test :=
HashMap.find? cache key

-- This function just contains a call to a function that we definitely do not want to reduce.
-- To illustrate that the problem is actually noticeable there are multiple implementations provided.
-- Each of these implementations does additional modifications on the cache before looking things up,
-- as one might expect in irl functions.
-- Each version has a lot of additional complexity from the type checkers POV.
def barImpl1 (cache : Cache) (key : Nat) : Test :=
match cache.find? key with
| some val => val
| none => ⟨0⟩

def barImpl2 (cache : Cache) (key : Nat) : Test :=
match (cache.insert key ⟨0⟩).find? key with
| some val => val
| none => ⟨0⟩

def barImpl3 (cache : Cache) (key : Nat) : Test :=
match ((cache.insert key ⟨0⟩).insert 0 ⟨0⟩).find? key with
| some val => val
| none => ⟨0⟩

def barImpl4 (cache : Cache) (key : Nat) : Test :=
match (((cache.insert key ⟨0⟩).insert 0 ⟨0⟩).insert key ⟨key⟩).find? key with
| some val => val
| none => ⟨0⟩

def bar := barImpl4

set_option maxHeartbeats 400 in
def test (c1 : Cache) (key : Nat) : Nat :=
go c1 key
where
go (c1 : Cache) (key : Nat) : Nat :=
let val : Test := bar c1 key
have : val.x = (bar c1 key).x := rfl
val.x

def ack : Nat → Nat → Nat
| 0, y => y+1
| x+1, 0 => ack x 1
| x+1, y+1 => ack x (ack (x+1) y)

class Foo where
x : Nat
y : Nat

instance f (x : Nat) : Foo :=
{ x, y := ack 10 10 }

instance g (x : Nat) : Foo :=
{ x, y := ack 10 11 }

open Lean Meta
set_option maxHeartbeats 400 in
run_meta do
withLocalDeclD `x (mkConst ``Nat) fun x => do
let lhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``f) x
let rhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``g) x
assert! (← isDefEq lhs rhs)

run_meta do
withLocalDeclD `x (mkConst ``Nat) fun x => do
let lhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``f) x
let rhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``g) x
match Kernel.isDefEq (← getEnv) {} lhs rhs with
| .ok b => assert! b
| .error _ => throwError "failed"

example : (f x).1 = (g x).1 :=
rfl
Loading