From 6f7f44407874343eff5f0601e3707e04ef7cf127 Mon Sep 17 00:00:00 2001 From: madvorak Date: Fri, 9 Feb 2024 15:25:22 +0100 Subject: [PATCH 1/7] feat: Matrix.fromRows_mulVec and Matrix.vecMul_fromColumns --- Mathlib/Data/Matrix/ColumnRowPartitioned.lean | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean index 41b8de1cc7d70..f56da7004d17b 100644 --- a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean +++ b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean @@ -94,11 +94,11 @@ lemma toColumns₂_apply (A : Matrix m (n₁ ⊕ n₂) R) (i : m) (j : n₂) : (toColumns₂ A) i j = A i (Sum.inr j) := rfl @[simp] -lemma toColumns₁_fromColumns (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) : +lemma toColumns₁_fromColumns (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) : toColumns₁ (fromColumns A₁ A₂) = A₁ := rfl @[simp] -lemma toColumns₂_fromColumns (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) : +lemma toColumns₂_fromColumns (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) : toColumns₂ (fromColumns A₁ A₂) = A₂ := rfl @[simp] @@ -144,6 +144,16 @@ section Semiring variable [Semiring R] +@[simp] +lemma fromRows_mulVec (A₁ : Matrix m₁ n R) (A₂ : Matrix m₂ n R) (v : n → R) : + fromRows A₁ A₂ *ᵥ v = Sum.elim (A₁ *ᵥ v) (A₂ *ᵥ v) := by + ext (_ | _) <;> rfl + +@[simp] +lemma vecMul_fromColumns (B₁ : Matrix m n₁ R) (B₂ : Matrix m n₂ R) (v : m → R) : + v ᵥ* Matrix.fromColumns B₁ B₂ = Sum.elim (v ᵥ* B₁) (v ᵥ* B₂) := by + ext (_ | _) <;> rfl + @[simp] lemma fromRows_mul (A₁ : Matrix m₁ n R) (A₂ : Matrix m₂ n R) (B : Matrix n m R) : (fromRows A₁ A₂) * B = fromRows (A₁ * B) (A₂ * B) := by From e01db9126983b9a8e3d23d88cf79dc6211ab75c5 Mon Sep 17 00:00:00 2001 From: madvorak Date: Fri, 9 Feb 2024 16:35:10 +0100 Subject: [PATCH 2/7] sorry --- Mathlib/Data/Matrix/ColumnRowPartitioned.lean | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean index f56da7004d17b..346b0db95896c 100644 --- a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean +++ b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean @@ -154,6 +154,12 @@ lemma vecMul_fromColumns (B₁ : Matrix m n₁ R) (B₂ : Matrix m n₂ R) (v : v ᵥ* Matrix.fromColumns B₁ B₂ = Sum.elim (v ᵥ* B₁) (v ᵥ* B₂) := by ext (_ | _) <;> rfl +@[simp] +lemma fromColumns_mulVec_sum_elim (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) + (v₁ : n₁ → R) (v₂ : n₂ → R) : + (Matrix.fromColumns A₁ A₂) *ᵥ (Sum.elim v₁ v₂) = A₁ *ᵥ v₁ + A₂ *ᵥ v₂ := by + sorry + @[simp] lemma fromRows_mul (A₁ : Matrix m₁ n R) (A₂ : Matrix m₂ n R) (B : Matrix n m R) : (fromRows A₁ A₂) * B = fromRows (A₁ * B) (A₂ * B) := by From 6b56bd1e47532758e90f7091ee8eccd9af0444e1 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Fri, 9 Feb 2024 15:51:56 +0000 Subject: [PATCH 3/7] prove the lemma --- Mathlib/Data/Matrix/ColumnRowPartitioned.lean | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean index 346b0db95896c..6d594643e5585 100644 --- a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean +++ b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean @@ -157,17 +157,18 @@ lemma vecMul_fromColumns (B₁ : Matrix m n₁ R) (B₂ : Matrix m n₂ R) (v : @[simp] lemma fromColumns_mulVec_sum_elim (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) (v₁ : n₁ → R) (v₂ : n₂ → R) : - (Matrix.fromColumns A₁ A₂) *ᵥ (Sum.elim v₁ v₂) = A₁ *ᵥ v₁ + A₂ *ᵥ v₂ := by - sorry + fromColumns A₁ A₂ *ᵥ Sum.elim v₁ v₂ = A₁ *ᵥ v₁ + A₂ *ᵥ v₂ := by + ext x + simp [Matrix.mulVec, fromColumns] @[simp] lemma fromRows_mul (A₁ : Matrix m₁ n R) (A₂ : Matrix m₂ n R) (B : Matrix n m R) : - (fromRows A₁ A₂) * B = fromRows (A₁ * B) (A₂ * B) := by + fromRows A₁ A₂ * B = fromRows (A₁ * B) (A₂ * B) := by ext (_ | _) _ <;> simp [mul_apply] @[simp] lemma mul_fromColumns (A : Matrix m n R) (B₁ : Matrix n n₁ R) (B₂ : Matrix n n₂ R) : - A * (fromColumns B₁ B₂) = fromColumns (A * B₁) (A * B₂) := by + A * fromColumns B₁ B₂ = fromColumns (A * B₁) (A * B₂) := by ext _ (_ | _) <;> simp [mul_apply] @[simp] From e452db059eb259d4263667e01ded9e98660272b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martin=20Dvo=C5=99=C3=A1k?= Date: Fri, 9 Feb 2024 16:54:44 +0100 Subject: [PATCH 4/7] Update Mathlib/Data/Matrix/ColumnRowPartitioned.lean Co-authored-by: Eric Wieser --- Mathlib/Data/Matrix/ColumnRowPartitioned.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean index 6d594643e5585..322385d55c300 100644 --- a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean +++ b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean @@ -151,7 +151,7 @@ lemma fromRows_mulVec (A₁ : Matrix m₁ n R) (A₂ : Matrix m₂ n R) (v : n @[simp] lemma vecMul_fromColumns (B₁ : Matrix m n₁ R) (B₂ : Matrix m n₂ R) (v : m → R) : - v ᵥ* Matrix.fromColumns B₁ B₂ = Sum.elim (v ᵥ* B₁) (v ᵥ* B₂) := by + v ᵥ* fromColumns B₁ B₂ = Sum.elim (v ᵥ* B₁) (v ᵥ* B₂) := by ext (_ | _) <;> rfl @[simp] From 9fc589462a0b4052034841a0c4f9dfc73e45edad Mon Sep 17 00:00:00 2001 From: madvorak Date: Fri, 9 Feb 2024 17:22:09 +0100 Subject: [PATCH 5/7] sum_elim_vecMul_fromRows --- Mathlib/Data/Matrix/ColumnRowPartitioned.lean | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean index 322385d55c300..45e10f59f5df5 100644 --- a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean +++ b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean @@ -154,11 +154,24 @@ lemma vecMul_fromColumns (B₁ : Matrix m n₁ R) (B₂ : Matrix m n₂ R) (v : v ᵥ* fromColumns B₁ B₂ = Sum.elim (v ᵥ* B₁) (v ᵥ* B₂) := by ext (_ | _) <;> rfl +@[simp] +lemma sum_elim_vecMul_fromRows (B₁ : Matrix m₁ n R) (B₂ : Matrix m₂ n R) + (v₁ : m₁ → R) (v₂ : m₂ → R) : + Sum.elim v₁ v₂ ᵥ* fromRows B₁ B₂ = v₁ ᵥ* B₁ + v₂ ᵥ* B₂ := by + ext j + rw [vecMul, fromRows] + convert_to + Sum.elim v₁ v₂ ⬝ᵥ Sum.elim (fun i ↦ B₁ i j) (fun i ↦ B₂ i j) = + (v₁ ⬝ᵥ fun i ↦ B₁ i j) + (v₂ ⬝ᵥ fun i ↦ B₂ i j) using 2 + · ext i + cases i <;> rfl + rw [Matrix.sum_elim_dotProduct_sum_elim] + @[simp] lemma fromColumns_mulVec_sum_elim (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) (v₁ : n₁ → R) (v₂ : n₂ → R) : fromColumns A₁ A₂ *ᵥ Sum.elim v₁ v₂ = A₁ *ᵥ v₁ + A₂ *ᵥ v₂ := by - ext x + ext simp [Matrix.mulVec, fromColumns] @[simp] From 3b618bf56cdf0cf7811a33f8c5b898bcacd32a0f Mon Sep 17 00:00:00 2001 From: madvorak Date: Fri, 9 Feb 2024 17:26:52 +0100 Subject: [PATCH 6/7] golf --- Mathlib/Data/Matrix/ColumnRowPartitioned.lean | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean index 45e10f59f5df5..7c2fa2a244456 100644 --- a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean +++ b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean @@ -158,14 +158,10 @@ lemma vecMul_fromColumns (B₁ : Matrix m n₁ R) (B₂ : Matrix m n₂ R) (v : lemma sum_elim_vecMul_fromRows (B₁ : Matrix m₁ n R) (B₂ : Matrix m₂ n R) (v₁ : m₁ → R) (v₂ : m₂ → R) : Sum.elim v₁ v₂ ᵥ* fromRows B₁ B₂ = v₁ ᵥ* B₁ + v₂ ᵥ* B₂ := by - ext j + ext rw [vecMul, fromRows] - convert_to - Sum.elim v₁ v₂ ⬝ᵥ Sum.elim (fun i ↦ B₁ i j) (fun i ↦ B₂ i j) = - (v₁ ⬝ᵥ fun i ↦ B₁ i j) + (v₂ ⬝ᵥ fun i ↦ B₂ i j) using 2 - · ext i - cases i <;> rfl - rw [Matrix.sum_elim_dotProduct_sum_elim] + convert sum_elim_dotProduct_sum_elim .. + aesop @[simp] lemma fromColumns_mulVec_sum_elim (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R) From 3ac06b52565259a8d9228c9efe04634b62042418 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Fri, 9 Feb 2024 19:17:50 +0000 Subject: [PATCH 7/7] simplify the proof --- Mathlib/Data/Matrix/ColumnRowPartitioned.lean | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean index 7c2fa2a244456..6536c880b207b 100644 --- a/Mathlib/Data/Matrix/ColumnRowPartitioned.lean +++ b/Mathlib/Data/Matrix/ColumnRowPartitioned.lean @@ -159,9 +159,7 @@ lemma sum_elim_vecMul_fromRows (B₁ : Matrix m₁ n R) (B₂ : Matrix m₂ n R) (v₁ : m₁ → R) (v₂ : m₂ → R) : Sum.elim v₁ v₂ ᵥ* fromRows B₁ B₂ = v₁ ᵥ* B₁ + v₂ ᵥ* B₂ := by ext - rw [vecMul, fromRows] - convert sum_elim_dotProduct_sum_elim .. - aesop + simp [Matrix.vecMul, fromRows, dotProduct] @[simp] lemma fromColumns_mulVec_sum_elim (A₁ : Matrix m n₁ R) (A₂ : Matrix m n₂ R)