From 7acbee8ae454cb9be195b952e3a269bad2d1267c Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Mon, 18 Dec 2023 14:46:42 +0100 Subject: [PATCH] refactor: move unpackArg etc. to WF.PackDomain/WF.PackMutual (#3077) extracted from #3040 to keep the diff smaller --- src/Lean/Elab/PreDefinition/WF/GuessLex.lean | 34 +-------- .../Elab/PreDefinition/WF/PackDomain.lean | 13 ++++ .../Elab/PreDefinition/WF/PackMutual.lean | 73 +++++++++++++++---- 3 files changed, 72 insertions(+), 48 deletions(-) diff --git a/src/Lean/Elab/PreDefinition/WF/GuessLex.lean b/src/Lean/Elab/PreDefinition/WF/GuessLex.lean index 14a1c8773fd8..447e3e593025 100644 --- a/src/Lean/Elab/PreDefinition/WF/GuessLex.lean +++ b/src/Lean/Elab/PreDefinition/WF/GuessLex.lean @@ -14,6 +14,7 @@ import Lean.Elab.RecAppSyntax import Lean.Elab.PreDefinition.Basic import Lean.Elab.PreDefinition.Structural.Basic import Lean.Elab.PreDefinition.WF.TerminationHint +import Lean.Elab.PreDefinition.WF.PackMutual import Lean.Data.Array @@ -263,39 +264,6 @@ def filterSubsumed (rcs : Array RecCallWithContext ) : Array RecCallWithContext return (false, true) return (true, true) -/-- Given the packed argument of a (possibly) mutual and (possibly) nary call, -return the function index that is called and the arguments individually. - -We expect precisely the expressions produced by `packMutual`, with manifest -`PSum.inr`, `PSum.inl` and `PSigma.mk` constructors, and thus take them apart -rather than using projectinos. -/ -def unpackArg {m} [Monad m] [MonadError m] (arities : Array Nat) (e : Expr) : - m (Nat × Array Expr) := do - -- count PSum injections to find out which function is doing the call - let mut funidx := 0 - let mut e := e - while funidx + 1 < arities.size do - if e.isAppOfArity ``PSum.inr 3 then - e := e.getArg! 2 - funidx := funidx + 1 - else if e.isAppOfArity ``PSum.inl 3 then - e := e.getArg! 2 - break - else - throwError "Unexpected expression while unpacking mutual argument" - - -- now unpack PSigmas - let arity := arities[funidx]! - let mut args := #[] - while args.size + 1 < arity do - if e.isAppOfArity ``PSigma.mk 4 then - args := args.push (e.getArg! 2) - e := e.getArg! 3 - else - throwError "Unexpected expression while unpacking n-ary argument" - args := args.push e - return (funidx, args) - /-- Traverse a unary PreDefinition, and returns a `WithRecCall` closure for each recursive call site. -/ diff --git a/src/Lean/Elab/PreDefinition/WF/PackDomain.lean b/src/Lean/Elab/PreDefinition/WF/PackDomain.lean index f2083ec7fb3e..fdd83d35b6eb 100644 --- a/src/Lean/Elab/PreDefinition/WF/PackDomain.lean +++ b/src/Lean/Elab/PreDefinition/WF/PackDomain.lean @@ -40,6 +40,19 @@ where else return args[i]! +/-- Unpacks a unary packed argument created with `mkUnaryArg`. -/ +def unpackUnaryArg {m} [Monad m] [MonadError m] (arity : Nat) (e : Expr) : m (Array Expr) := do + let mut e := e + let mut args := #[] + while args.size + 1 < arity do + if e.isAppOfArity ``PSigma.mk 4 then + args := args.push (e.getArg! 2) + e := e.getArg! 3 + else + throwError "Unexpected expression while unpacking n-ary argument" + args := args.push e + return args + private partial def mkPSigmaCasesOn (y : Expr) (codomain : Expr) (xs : Array Expr) (value : Expr) : MetaM Expr := do let mvar ← mkFreshExprSyntheticOpaqueMVar codomain let rec go (mvarId : MVarId) (y : FVarId) (ys : Array Expr) : MetaM Unit := do diff --git a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean index 71bd812916b0..0167f9358156 100644 --- a/src/Lean/Elab/PreDefinition/WF/PackMutual.lean +++ b/src/Lean/Elab/PreDefinition/WF/PackMutual.lean @@ -5,6 +5,7 @@ Authors: Leonardo de Moura -/ import Lean.Meta.Tactic.Cases import Lean.Elab.PreDefinition.Basic +import Lean.Elab.PreDefinition.WF.PackDomain namespace Lean.Elab.WF open Meta @@ -110,8 +111,60 @@ def withAppN (n : Nat) (e : Expr) (k : Array Expr → MetaM Expr) : MetaM Expr : mkLambdaFVars xs e' /-- - Auxiliary function for replacing nested `preDefs` recursive calls in `e` with the new function `newFn`. - See: `packMutual` +If `arg` is the argument to the `fidx`th of the `numFuncs` in the recursive group, +then `mkMutualArg` packs that argument in `PSum.inl` and `PSum.inr` constructors +to create the mutual-packed argument of type `domain`. +-/ +partial def mkMutualArg (numFuncs : Nat) (domain : Expr) (fidx : Nat) (arg : Expr) : MetaM Expr := do + let rec go (i : Nat) (type : Expr) : MetaM Expr := do + if i == numFuncs - 1 then + return arg + else + (← whnfD type).withApp fun f args => do + assert! args.size == 2 + if i == fidx then + return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg + else + let r ← go (i+1) args[1]! + return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r + go 0 domain + +/-- +Unpacks a mutually packed argument, returning the argument and function index. +Inverse of `mkMutualArg`. Cf. `unpackUnaryArg` and `unpackArg`, which does both +-/ +def unpackMutualArg {m} [Monad m] [MonadError m] (numFuncs : Nat) (e : Expr) : m (Nat × Expr) := do + let mut funidx := 0 + let mut e := e + while funidx + 1 < numFuncs do + if e.isAppOfArity ``PSum.inr 3 then + e := e.getArg! 2 + funidx := funidx + 1 + else if e.isAppOfArity ``PSum.inl 3 then + e := e.getArg! 2 + break + else + throwError "Unexpected expression while unpacking mutual argument" + return (funidx, e) + +/-- +Given the packed argument of a (possibly) mutual and (possibly) nary call, +return the function index that is called and the arguments individually. + +We expect precisely the expressions produced by `packMutual`, with manifest +`PSum.inr`, `PSum.inl` and `PSigma.mk` constructors, and thus take them apart +rather than using projectinos. +-/ +def unpackArg {m} [Monad m] [MonadError m] (arities : Array Nat) (e : Expr) : + m (Nat × Array Expr) := do + let (funidx, e) ← unpackMutualArg arities.size e + let args ← unpackUnaryArg arities[funidx]! e + return (funidx, args) + + +/-- +Auxiliary function for replacing nested `preDefs` recursive calls in `e` with the new function `newFn`. +See: `packMutual` -/ private partial def post (fixedPrefix : Nat) (preDefs : Array PreDefinition) (domain : Expr) (newFn : Name) (e : Expr) : MetaM TransformStep := do let f := e.getAppFn @@ -122,19 +175,9 @@ private partial def post (fixedPrefix : Nat) (preDefs : Array PreDefinition) (do if let some fidx := preDefs.findIdx? (·.declName == declName) then let e' ← withAppN (fixedPrefix + 1) e fun args => do let fixedArgs := args[:fixedPrefix] - let arg := args[fixedPrefix]! - let rec mkNewArg (i : Nat) (type : Expr) : MetaM Expr := do - if i == preDefs.size - 1 then - return arg - else - (← whnfD type).withApp fun f args => do - assert! args.size == 2 - if i == fidx then - return mkApp3 (mkConst ``PSum.inl f.constLevels!) args[0]! args[1]! arg - else - let r ← mkNewArg (i+1) args[1]! - return mkApp3 (mkConst ``PSum.inr f.constLevels!) args[0]! args[1]! r - return mkApp (mkAppN (mkConst newFn us) fixedArgs) (← mkNewArg 0 domain) + let arg := args[fixedPrefix]! + let packedArg ← mkMutualArg preDefs.size domain fidx arg + return mkApp (mkAppN (mkConst newFn us) fixedArgs) packedArg return TransformStep.done e' return TransformStep.done e