diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index be730e9cbb46..dde3b7d3b215 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -1690,9 +1690,9 @@ private def isDefEqOnFailure (t s : Expr) : MetaM Bool := do tryUnificationHints t s <||> tryUnificationHints s t private def isDefEqProj : Expr → Expr → MetaM Bool - | Expr.proj m i t, Expr.proj n j s => pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s - | Expr.proj structName 0 s, v => isDefEqSingleton structName s v - | v, Expr.proj structName 0 s => isDefEqSingleton structName s v + | .proj m i t, .proj n j s => pure (i == j && m == n) <&&> Meta.isExprDefEqAux t s + | .proj structName 0 s, v => isDefEqSingleton structName s v + | v, .proj structName 0 s => isDefEqSingleton structName s v | _, _ => pure false where /-- If `structName` is a structure with a single field and `(?m ...).1 =?= v`, then solve constraint as `?m ... =?= ⟨v⟩` -/ @@ -1779,25 +1779,30 @@ private def isExprDefEqExpensive (t : Expr) (s : Expr) : MetaM Bool := do whenUndefDo (isDefEqEta t s) do whenUndefDo (isDefEqEta s t) do if (← isDefEqProj t s) then return true - whenUndefDo (isDefEqNative t s) do - whenUndefDo (isDefEqNat t s) do - whenUndefDo (isDefEqOffset t s) do - whenUndefDo (isDefEqDelta t s) do - -- We try structure eta *after* lazy delta reduction; - -- otherwise we would end up applying it at every step of a reduction chain - -- as soon as one of the sides is a constructor application, - -- which is very costly because it requires us to unify the fields. - if (← (isDefEqEtaStruct t s <||> isDefEqEtaStruct s t)) then - return true - if t.isConst && s.isConst then - if t.constName! == s.constName! then isListLevelDefEqAux t.constLevels! s.constLevels! 
else return false - else if (← pure t.isApp <&&> pure s.isApp <&&> isDefEqApp t s) then - return true + let t' ← whnfCore t + let s' ← whnfCore s + if t != t' || s != s' then + Meta.isExprDefEqAux t' s' else - whenUndefDo (isDefEqProjInst t s) do - whenUndefDo (isDefEqStringLit t s) do - if (← isDefEqUnitLike t s) then return true else - isDefEqOnFailure t s + whenUndefDo (isDefEqNative t s) do + whenUndefDo (isDefEqNat t s) do + whenUndefDo (isDefEqOffset t s) do + whenUndefDo (isDefEqDelta t s) do + -- We try structure eta *after* lazy delta reduction; + -- otherwise we would end up applying it at every step of a reduction chain + -- as soon as one of the sides is a constructor application, + -- which is very costly because it requires us to unify the fields. + if (← (isDefEqEtaStruct t s <||> isDefEqEtaStruct s t)) then + return true + if t.isConst && s.isConst then + if t.constName! == s.constName! then isListLevelDefEqAux t.constLevels! s.constLevels! else return false + else if (← pure t.isApp <&&> pure s.isApp <&&> isDefEqApp t s) then + return true + else + whenUndefDo (isDefEqProjInst t s) do + whenUndefDo (isDefEqStringLit t s) do + if (← isDefEqUnitLike t s) then return true else + isDefEqOnFailure t s inductive DefEqCacheKind where | transient -- problem has mvars or is using nonstandard configuration, we should use transient cache @@ -1863,14 +1868,41 @@ partial def isExprDefEqAuxImpl (t : Expr) (s : Expr) : MetaM Bool := withIncRecD whenUndefDo (isDefEqProofIrrel t s) do /- We also reduce projections here to prevent expensive defeq checks when unifying TC operations. - When unifying e.g. `@Neg.neg α (@Field.toNeg α inst1) =?= @Neg.neg α (@Field.toNeg α inst2)`, + When unifying e.g. `(@Field.toNeg α inst1).1 =?= (@Field.toNeg α inst2).1`, we only want to unify negation (and not all other field operations as well). 
Unifying the field instances slowed down unification: https://github.com/leanprover/lean4/issues/1986 - We used to *not* reduce projections here, to support unifying `(?a).1 =?= (x, y).1`. - NOTE: this still seems to work because we don't eagerly unfold projection definitions to primitive projections. + + Note that we use `proj := .yesWithDeltaI` to ensure `whnfI` is used to reduce the projection structure. + We added this refinement to address a performance issue in code such as + ``` + let val : Test := bar c1 key + have : val.1 = (bar c1 key).1 := rfl + ``` + where `bar` is a complex function that takes a long time to be reduced. + + Note that the current solution times out at unification problems such as + `(f x).1 =?= (g x).1` where `f`, `g` are defined as + ``` + structure Foo where + x : Nat + y : Nat + + def f (x : Nat) : Foo := + { x, y := ack 10 10 } + + def g (x : Nat) : Foo := + { x, y := ack 10 11 } + ``` + and `ack` is the Ackermann function. We claim this is an abuse of the unifier. + That being said, we could in principle address this issue by implementing + lazy-delta reduction at `isDefEqProj`. + + The current solution should be sufficient. In the past, we have used + `whnfCore t (config := { proj := .yes })` which is more conservative than `.yesWithDeltaI`, + and it only created performance issues when handling TC unification problems. -/ - let t' ← whnfCore t - let s' ← whnfCore s + let t' ← whnfCore t (config := { proj := .yesWithDeltaI }) + let s' ← whnfCore s (config := { proj := .yesWithDeltaI }) if t != t' || s != s' then isExprDefEqAuxImpl t' s' else diff --git a/src/Lean/Meta/WHNF.lean b/src/Lean/Meta/WHNF.lean index 1a5287901c29..a5b08970b734 100644 --- a/src/Lean/Meta/WHNF.lean +++ b/src/Lean/Meta/WHNF.lean @@ -342,6 +342,16 @@ inductive ProjReductionKind where Recall that `whnfCore` does not perform `delta` reduction (i.e., it will not unfold constant declarations), but `whnf` does. 
-/ | yesWithDelta + /-- + Projections `s.i` are reduced at `whnfCore`, and `whnfI` is used at `s` during the process. + Recall that `whnfI` is like `whnf` but uses transparency `instances`. + This option is stronger than `yes`, but weaker than `yesWithDelta`. + We use this option to ensure we reduce projections to prevent expensive defeq checks when unifying TC operations. + When unifying e.g. `(@Field.toNeg α inst1).1 =?= (@Field.toNeg α inst2).1`, + we only want to unify negation (and not all other field operations as well). + Unifying the field instances slowed down unification: https://github.com/leanprover/lean4/issues/1986 + -/ + | yesWithDeltaI deriving DecidableEq, Inhabited, Repr /-- @@ -566,12 +576,6 @@ private def whnfDelayedAssigned? (f' : Expr) (e : Expr) : MetaM (Option Expr) := /-- Apply beta-reduction, zeta-reduction (i.e., unfold let local-decls), iota-reduction, expand let-expressions, expand assigned meta-variables. - -The parameter `deltaAtProj` controls how to reduce projections `s.i`. If `deltaAtProj == true`, -then delta reduction is used to reduce `s` (i.e., `whnf` is used), otherwise `whnfCore`. - -If `simpleReduceOnly`, then `iota` and projection reduction are not performed. -Note that the value of `deltaAtProj` is irrelevant if `simpleReduceOnly = true`. -/ partial def whnfCore (e : Expr) (config : WhnfCoreConfig := {}): MetaM Expr := go e @@ -613,11 +617,15 @@ where return e | _ => return e | .proj _ i c => - if config.proj == .no then return e - let c ← if config.proj == .yesWithDelta then whnf c else go c - match (← projectCore? c i) with - | some e => go e - | none => return e + let k (c : Expr) := do + match (← projectCore? c i) with + | some e => go e + | none => return e + match config.proj with + | .no => return e + | .yes => k (← go c) + | .yesWithDelta => k (← whnf c) + | .yesWithDeltaI => k (← whnfI c) | _ => unreachable! 
/-- diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 1de1dc73dd90..9adb025818c2 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -376,16 +376,8 @@ expr type_checker::whnf_fvar(expr const & e, bool cheap_rec, bool cheap_proj) { return e; } -/* If `cheap == true`, then we don't perform delta-reduction when reducing major premise. */ -optional type_checker::reduce_proj(expr const & e, bool cheap_rec, bool cheap_proj) { - if (!proj_idx(e).is_small()) - return none_expr(); - unsigned idx = proj_idx(e).get_small_value(); - expr c; - if (cheap_proj) - c = whnf_core(proj_expr(e), cheap_rec, cheap_proj); - else - c = whnf(proj_expr(e)); +/* Auxiliary method for `reduce_proj` */ +optional type_checker::reduce_proj_core(expr c, unsigned idx) { if (is_string_lit(c)) c = string_lit_to_constructor(c); buffer args; @@ -402,6 +394,19 @@ optional type_checker::reduce_proj(expr const & e, bool cheap_rec, bool ch return none_expr(); } +/* If `cheap == true`, then we don't perform delta-reduction when reducing major premise. */ +optional type_checker::reduce_proj(expr const & e, bool cheap_rec, bool cheap_proj) { + if (!proj_idx(e).is_small()) + return none_expr(); + unsigned idx = proj_idx(e).get_small_value(); + expr c; + if (cheap_proj) + c = whnf_core(proj_expr(e), cheap_rec, cheap_proj); + else + c = whnf(proj_expr(e)); + return reduce_proj_core(c, idx); +} + static bool is_let_fvar(local_ctx const & lctx, expr const & e) { lean_assert(is_fvar(e)); if (optional decl = lctx.find_local_decl(e)) { @@ -983,6 +988,33 @@ lbool type_checker::lazy_delta_reduction(expr & t_n, expr & s_n) { } } +/* +Auxiliary method for checking `t_n.idx =?= s_n.idx`. +It lazily unfolds `t_n` and `s_n`. +Recall that the simpler approach used at `Meta.ExprDefEq` cannot be used in the +kernel since it does not have access to reducibility annotations. +The approach used here is more complicated, but it is also more powerful. 
+*/ +bool type_checker::lazy_delta_proj_reduction(expr & t_n, expr & s_n, nat const & idx) { + while (true) { + switch (lazy_delta_reduction_step(t_n, s_n)) { + case reduction_status::Continue: break; + case reduction_status::DefEqual: return true; + case reduction_status::DefUnknown: + case reduction_status::DefDiff: + if (idx.is_small()) { + unsigned i = idx.get_small_value(); + if (auto t = reduce_proj_core(t_n, i)) { + if (auto s = reduce_proj_core(s_n, i)) { + return is_def_eq_core(*t, *s); + }} + } + return is_def_eq_core(t_n, s_n); + } + } +} + + static expr * g_string_mk = nullptr; lbool type_checker::try_string_lit_expansion_core(expr const & t, expr const & s) { @@ -1054,8 +1086,12 @@ bool type_checker::is_def_eq_core(expr const & t, expr const & s) { if (is_fvar(t_n) && is_fvar(s_n) && fvar_name(t_n) == fvar_name(s_n)) return true; - if (is_proj(t_n) && is_proj(s_n) && proj_idx(t_n) == proj_idx(s_n) && is_def_eq(proj_expr(t_n), proj_expr(s_n))) - return true; + if (is_proj(t_n) && is_proj(s_n) && proj_idx(t_n) == proj_idx(s_n)) { + expr t_c = proj_expr(t_n); + expr s_c = proj_expr(s_n); + if (lazy_delta_proj_reduction(t_c, s_c, proj_idx(t_n))) + return true; + } // Invoke `whnf_core` again, but now using `whnf` to reduce projections. 
expr t_n_n = whnf_core(t_n); diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index cb15eaca74c8..e38a60772bf8 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -63,6 +63,7 @@ class type_checker { enum class reduction_status { Continue, DefUnknown, DefEqual, DefDiff }; optional reduce_recursor(expr const & e, bool cheap_rec, bool cheap_proj); + optional reduce_proj_core(expr c, unsigned idx); optional reduce_proj(expr const & e, bool cheap_rec, bool cheap_proj); expr whnf_fvar(expr const & e, bool cheap_rec, bool cheap_proj); optional is_delta(expr const & e) const; @@ -91,6 +92,7 @@ class type_checker { void cache_failure(expr const & t, expr const & s); reduction_status lazy_delta_reduction_step(expr & t_n, expr & s_n); lbool lazy_delta_reduction(expr & t_n, expr & s_n); + bool lazy_delta_proj_reduction(expr & t_n, expr & s_n, nat const & idx); bool is_def_eq_core(expr const & t, expr const & s); /** \brief Like \c check, but ignores undefined universes */ expr check_ignore_undefined_universes(expr const & e); diff --git a/tests/lean/run/1986.lean b/tests/lean/run/1986.lean index 68cf03faabab..04e5f6520c44 100644 --- a/tests/lean/run/1986.lean +++ b/tests/lean/run/1986.lean @@ -194,7 +194,7 @@ instance Pi.completeDistribLattice' {ι : Type _} {π : ι → Type _} [∀ i, CompleteDistribLattice (π i)] : CompleteDistribLattice (∀ i, π i) := CompleteDistribLattice.mk (Pi.coframe.infᵢ_sup_le_sup_infₛ) --- takes around 2 seconds wall clock time on my PC (but very quick in Lean 3) +-- User: takes around 2 seconds wall clock time on my PC (but very quick in Lean 3) set_option maxHeartbeats 400 -- make sure it stays fast set_option synthInstance.maxHeartbeats 400 instance Pi.completeDistribLattice'' {ι : Type _} {π : ι → Type _} diff --git a/tests/lean/run/bv_math_lit_perf.lean b/tests/lean/run/bv_math_lit_perf.lean index 8a43e7c2844f..098f2773d744 100644 --- a/tests/lean/run/bv_math_lit_perf.lean +++ 
b/tests/lean/run/bv_math_lit_perf.lean @@ -16,7 +16,7 @@ def f (x : BitVec 32) : Nat := | 920#32 => 12 | _ => 1000 -set_option maxHeartbeats 2800 +set_option maxHeartbeats 3000 example : f 500#32 = x := by simp [f] sorry diff --git a/tests/lean/run/isDefEqProjIssue.lean b/tests/lean/run/isDefEqProjIssue.lean new file mode 100644 index 000000000000..15ba43b86dc3 --- /dev/null +++ b/tests/lean/run/isDefEqProjIssue.lean @@ -0,0 +1,85 @@ +import Lean +open Lean + +-- We need a structure as this is related to isDefEq problems of the form `e1.proj =?= e2.proj`. +structure Test where + x : Nat + +-- We need a data structure with functions that are not meant for reduction purposes. +abbrev Cache := HashMap Nat Test + +def Cache.insert (cache : Cache) (key : Nat) (val : Test) : Cache := + HashMap.insert cache key val + +def Cache.find? (cache : Cache) (key : Nat) : Option Test := + HashMap.find? cache key + +-- This function just contains a call to a function that we definitely do not want to reduce. +-- To illustrate that the problem is actually noticeable there are multiple implementations provided. +-- Each of these implementations does additional modifications on the cache before looking things up, +-- as one might expect in real-world functions. +-- Each version has a lot of additional complexity from the type checker's point of view. +def barImpl1 (cache : Cache) (key : Nat) : Test := + match cache.find? key with + | some val => val + | none => ⟨0⟩ + +def barImpl2 (cache : Cache) (key : Nat) : Test := + match (cache.insert key ⟨0⟩).find? key with + | some val => val + | none => ⟨0⟩ + +def barImpl3 (cache : Cache) (key : Nat) : Test := + match ((cache.insert key ⟨0⟩).insert 0 ⟨0⟩).find? key with + | some val => val + | none => ⟨0⟩ + +def barImpl4 (cache : Cache) (key : Nat) : Test := + match (((cache.insert key ⟨0⟩).insert 0 ⟨0⟩).insert key ⟨key⟩).find? 
key with + | some val => val + | none => ⟨0⟩ + +def bar := barImpl4 + +set_option maxHeartbeats 400 in +def test (c1 : Cache) (key : Nat) : Nat := + go c1 key +where + go (c1 : Cache) (key : Nat) : Nat := + let val : Test := bar c1 key + have : val.x = (bar c1 key).x := rfl + val.x + +def ack : Nat → Nat → Nat + | 0, y => y+1 + | x+1, 0 => ack x 1 + | x+1, y+1 => ack x (ack (x+1) y) + +class Foo where + x : Nat + y : Nat + +instance f (x : Nat) : Foo := + { x, y := ack 10 10 } + +instance g (x : Nat) : Foo := + { x, y := ack 10 11 } + +open Lean Meta +set_option maxHeartbeats 400 in +run_meta do + withLocalDeclD `x (mkConst ``Nat) fun x => do + let lhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``f) x + let rhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``g) x + assert! (← isDefEq lhs rhs) + +run_meta do + withLocalDeclD `x (mkConst ``Nat) fun x => do + let lhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``f) x + let rhs := Expr.proj ``Foo 0 <| mkApp (mkConst ``g) x + match Kernel.isDefEq (← getEnv) {} lhs rhs with + | .ok b => assert! b + | .error _ => throwError "failed" + +example : (f x).1 = (g x).1 := + rfl