Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: per-function termination hints #3040

Merged
merged 3 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 129 additions & 2 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,133 @@ example : x + foo 2 = 12 + x := by
simp_arith
```

* The syntax of the `termination_by` and `decreasing_by` termination hints is overhauled:

* They are now placed directly after the function they apply to, instead of
after the whole `mutual` block.
* Therefore, the function name no longer has to be mentioned in the hint.
* If the function has a `where` clause, the `termination_by` and
`decreasing_by` for that function come before the `where`. The
functions in the `where` clause can have their own termination hints, each
following the corresponding definition.
* The `termination_by` clause can only bind “extra parameters”, that are not
already bound by the function header, but are bound in a lambda (`:= fun x
y z =>`) or in patterns (`| x, n + 1 => …`). These extra parameters used to
be understood as a suffix of the function parameters; now it is a prefix.

Migration guide: In simple cases just remove the function name, and any
variables already bound at the header.
```diff
def foo : Nat → Nat → Nat := …
-termination_by foo a b => a - b
+termination_by a b => a - b
```
or
```diff
def foo : Nat → Nat → Nat := …
-termination_by _ a b => a - b
+termination_by a b => a - b
```

If the parameters are bound in the function header (before the `:`), remove them as well:
```diff
def foo (a b : Nat) : Nat := …
-termination_by foo a b => a - b
+termination_by a - b
```

Else, if there are multiple extra parameters, make sure to refer to the right
ones; the bound variables are interpreted from left to right, no longer from
right to left:
```diff
def foo : Nat → Nat → Nat → Nat
| a, b, c => …
-termination_by foo b c => b
+termination_by a b => b
```

In the case of a `mutual` block, place the termination arguments (without the
function name) next to the function definition:
```diff
-mutual
-def foo : Nat → Nat → Nat := …
-def bar : Nat → Nat := …
-end
-termination_by
- foo a b => a - b
- bar a => a
+mutual
+def foo : Nat → Nat → Nat := …
+termination_by a b => a - b
+def bar : Nat → Nat := …
+termination_by a => a
+end
```

Similarly, if you have (mutual) recursion through `where` or `let rec`, the
termination hints are now placed directly after the function they apply to:
```diff
-def foo (a b : Nat) : Nat := …
- where bar (x : Nat) : Nat := …
-termination_by
- foo a b => a - b
- bar x => x
+def foo (a b : Nat) : Nat := …
+termination_by a - b
+ where
+ bar (x : Nat) : Nat := …
+ termination_by x

-def foo (a b : Nat) : Nat :=
- let rec bar (x : Nat) : Nat := …
- …
-termination_by
- foo a b => a - b
- bar x => x
+def foo (a b : Nat) : Nat :=
+ let rec bar (x : Nat) : Nat := …
+ termination_by x
+ …
+termination_by a - b
```

In cases where a single `decreasing_by` clause applied to multiple mutually
recursive functions before, the tactic now has to be duplicated.

* The semantics of `decreasing_by` changed; the tactic is applied to all
termination proof goals together, not individually.

This helps when writing termination proofs interactively, as one can focus
each subgoal individually, for example using `·`. Previously, the given
tactic script had to work for _all_ goals, and one had to resort to tactic
combinators like `first`:

```diff
def foo (n : Nat) := … foo e1 … foo e2 …
-decreasing_by
-simp_wf
-first | apply something_about_e1; …
- | apply something_about_e2; …
+decreasing_by
+all_goals simp_wf
+· apply something_about_e1; …
+· apply something_about_e2; …
```

To obtain the old behaviour of applying a tactic to each goal individually,
use `all_goals`:
```diff
def foo (n : Nat) := …
-decreasing_by some_tactic
+decreasing_by all_goals some_tactic
```

In the case of mutual recursion each `decreasing_by` now applies to just its
function. If some functions in a recursive group do not have their own
`decreasing_by`, the default `decreasing_tactic` is used. If the same tactic
ought to be applied to multiple functions, the `decreasing_by` clause has to
be repeated at each of these functions.


v4.5.0
---------
Expand All @@ -88,15 +215,15 @@ v4.5.0
Migration guide: Use `termination_by` instead, e.g.:
```diff
-termination_by' measure (fun ⟨i, _⟩ => as.size - i)
+termination_by go i _ => as.size - i
+termination_by i _ => as.size - i
```

If the well-founded relation you want to use is not the one that the
`WellFoundedRelation` type class would infer for your termination argument,
you can use `WellFounded.wrap` from the std libarary to explicitly give one:
```diff
-termination_by' ⟨r, hwf⟩
+termination_by _ x => hwf.wrap x
+termination_by x => hwf.wrap x
```

* Support snippet edits in LSP `TextEdit`s. See `Lean.Lsp.SnippetString` for more details.
Expand Down
2 changes: 1 addition & 1 deletion doc/examples/palindromes.lean
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ theorem List.palindrome_ind (motive : List α → Prop)
have ih := palindrome_ind motive h₁ h₂ h₃ (a₂::as').dropLast
have : [a₁] ++ (a₂::as').dropLast ++ [(a₂::as').last (by simp)] = a₁::a₂::as' := by simp
this ▸ h₃ _ _ _ ih
termination_by _ as => as.length
termination_by as.length

/-!
We use our new induction principle to prove that if `as.reverse = as`, then `Palindrome as` holds.
Expand Down
26 changes: 13 additions & 13 deletions src/Init/Data/Array/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def mapM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α
map (i+1) (r.push (← f as[i]))
else
pure r
termination_by as.size - i
map 0 (mkEmpty as.size)
termination_by map => as.size - i

@[inline]
def mapIdxM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (f : Fin as.size → α → m β) : m (Array β) :=
Expand Down Expand Up @@ -348,12 +348,12 @@ def anyM {α : Type u} {m : Type → Type w} [Monad m] (p : α → m Bool) (as :
loop (j+1)
else
pure false
termination_by stop - j
loop start
if h : stop ≤ as.size then
any stop h
else
any as.size (Nat.le_refl _)
termination_by loop i j => stop - j

@[inline]
def allM {α : Type u} {m : Type → Type w} [Monad m] (p : α → m Bool) (as : Array α) (start := 0) (stop := as.size) : m Bool :=
Expand Down Expand Up @@ -523,7 +523,7 @@ def isEqvAux (a b : Array α) (hsz : a.size = b.size) (p : α → α → Bool) (
p a[i] b[i] && isEqvAux a b hsz p (i+1)
else
true
termination_by _ => a.size - i
termination_by a.size - i

@[inline] def isEqv (a b : Array α) (p : α → α → Bool) : Bool :=
if h : a.size = b.size then
Expand Down Expand Up @@ -627,7 +627,7 @@ def indexOfAux [BEq α] (a : Array α) (v : α) (i : Nat) : Option (Fin a.size)
if a.get idx == v then some idx
else indexOfAux a v (i+1)
else none
termination_by _ => a.size - i
termination_by a.size - i

def indexOf? [BEq α] (a : Array α) (v : α) : Option (Fin a.size) :=
indexOfAux a v 0
Expand Down Expand Up @@ -659,7 +659,7 @@ where
loop as (i+1) ⟨j-1, this⟩
else
as
termination_by _ => j - i
termination_by j - i

def popWhile (p : α → Bool) (as : Array α) : Array α :=
if h : as.size > 0 then
Expand All @@ -669,7 +669,7 @@ def popWhile (p : α → Bool) (as : Array α) : Array α :=
as
else
as
termination_by popWhile as => as.size
termination_by as.size

def takeWhile (p : α → Bool) (as : Array α) : Array α :=
let rec go (i : Nat) (r : Array α) : Array α :=
Expand All @@ -681,8 +681,8 @@ def takeWhile (p : α → Bool) (as : Array α) : Array α :=
r
else
r
termination_by as.size - i
go 0 #[]
termination_by go i r => as.size - i

def eraseIdxAux (i : Nat) (a : Array α) : Array α :=
if h : i < a.size then
Expand All @@ -692,7 +692,7 @@ def eraseIdxAux (i : Nat) (a : Array α) : Array α :=
eraseIdxAux (i+1) a'
else
a.pop
termination_by _ => a.size - i
termination_by a.size - i

def feraseIdx (a : Array α) (i : Fin a.size) : Array α :=
eraseIdxAux (i.val + 1) a
Expand All @@ -707,7 +707,7 @@ def eraseIdxSzAux (a : Array α) (i : Nat) (r : Array α) (heq : r.size = a.size
eraseIdxSzAux a (i+1) (r.swap idx idx1) ((size_swap r idx idx1).trans heq)
else
⟨r.pop, (size_pop r).trans (heq ▸ rfl)⟩
termination_by _ => r.size - i
termination_by r.size - i

def eraseIdx' (a : Array α) (i : Fin a.size) : { r : Array α // r.size = a.size - 1 } :=
eraseIdxSzAux a (i.val + 1) a rfl
Expand All @@ -726,10 +726,10 @@ def erase [BEq α] (as : Array α) (a : α) : Array α :=
loop as ⟨j', by rw [size_swap]; exact j'.2⟩
else
as
termination_by j.1
let j := as.size
let as := as.push a
loop as ⟨j, size_push .. ▸ j.lt_succ_self⟩
termination_by loop j => j.1

/-- Insert element `a` at position `i`. Panics if `i` is not `i ≤ as.size`. -/
def insertAt! (as : Array α) (i : Nat) (a : α) : Array α :=
Expand Down Expand Up @@ -779,7 +779,7 @@ def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size ≤ bs.size) (i : N
false
else
true
termination_by _ => as.size - i
termination_by as.size - i

/-- Return true iff `as` is a prefix of `bs`.
That is, `bs = as ++ t` for some `t : List α`.-/
Expand All @@ -800,7 +800,7 @@ private def allDiffAux [BEq α] (as : Array α) (i : Nat) : Bool :=
allDiffAuxAux as as[i] i h && allDiffAux as (i+1)
else
true
termination_by _ => as.size - i
termination_by as.size - i

def allDiff [BEq α] (as : Array α) : Bool :=
allDiffAux as 0
Expand All @@ -815,7 +815,7 @@ def allDiff [BEq α] (as : Array α) : Bool :=
cs
else
cs
termination_by _ => as.size - i
termination_by as.size - i

@[inline] def zipWith (as : Array α) (bs : Array β) (f : α → β → γ) : Array γ :=
zipWithAux f as bs 0 #[]
Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/Array/BasicAux.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
have hlt : i < as.size := Nat.lt_of_le_of_ne hle h
let b ← f as[i]
go (i+1) ⟨acc.val.push b, by simp [acc.property]⟩ hlt
termination_by go i _ _ => as.size - i
termination_by as.size - i

@[inline] private unsafe def mapMonoMImp [Monad m] (as : Array α) (f : α → m α) : m (Array α) :=
go 0 as
Expand Down
4 changes: 2 additions & 2 deletions src/Init/Data/Array/DecidableEq.lean
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size)
· have heq : i = a.size := Nat.le_antisymm hi (Nat.ge_of_not_lt h)
subst heq
exact absurd (Nat.lt_of_lt_of_le high low) (Nat.lt_irrefl j)
termination_by _ => a.size - i
termination_by a.size - i

theorem eq_of_isEqv [DecidableEq α] (a b : Array α) : Array.isEqv a b (fun x y => x = y) → a = b := by
simp [Array.isEqv]
Expand All @@ -36,7 +36,7 @@ theorem isEqvAux_self [DecidableEq α] (a : Array α) (i : Nat) : Array.isEqvAux
split
case inl h => simp [h, isEqvAux_self a (i+1)]
case inr h => simp [h]
termination_by _ => a.size - i
termination_by a.size - i

theorem isEqv_self [DecidableEq α] (a : Array α) : Array.isEqv a a (fun x y => x = y) = true := by
simp [isEqv, isEqvAux_self]
Expand Down
2 changes: 1 addition & 1 deletion src/Init/Data/Array/QSort.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def qpartition (as : Array α) (lt : α → α → Bool) (lo hi : Nat) : Nat ×
else
let as := as.swap! i hi
(i, as)
termination_by hi - j
loop as lo lo
termination_by _ => hi - j

@[inline] partial def qsort (as : Array α) (lt : α → α → Bool) (low := 0) (high := as.size - 1) : Array α :=
let rec @[specialize] sort (as : Array α) (low high : Nat) :=
Expand Down
8 changes: 4 additions & 4 deletions src/Init/Data/Nat/Power2.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ where
go (power * 2) (Nat.mul_pos h (by decide))
else
power
termination_by go p h => n - p
decreasing_by simp_wf; apply nextPowerOfTwo_dec <;> assumption
termination_by n - power
decreasing_by simp_wf; apply nextPowerOfTwo_dec <;> assumption

def isPowerOfTwo (n : Nat) := ∃ k, n = 2 ^ k

Expand All @@ -48,7 +48,7 @@ where
split
. exact isPowerOfTwo_go (power*2) (Nat.mul_pos h₁ (by decide)) (Nat.mul2_isPowerOfTwo_of_isPowerOfTwo h₂)
. assumption
termination_by isPowerOfTwo_go p _ _ => n - p
decreasing_by simp_wf; apply nextPowerOfTwo_dec <;> assumption
termination_by n - power
decreasing_by simp_wf; apply nextPowerOfTwo_dec <;> assumption

end Nat
Loading
Loading