diff --git a/Cargo.toml b/Cargo.toml index 9257b5dd..b18d332b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,6 +41,11 @@ default-features = false version = "0.3" default-features = false +[dependencies.sprs] +git = "https://github.com/nlhepler/sprs.git" +branch = "master" +optional = true + [dependencies.openblas-src] version = "0.6" default-features = false diff --git a/src/lapack/mod.rs b/src/lapack/mod.rs index c0f2690a..c07dbfa7 100644 --- a/src/lapack/mod.rs +++ b/src/lapack/mod.rs @@ -7,7 +7,6 @@ pub mod qr; pub mod solve; pub mod solveh; pub mod svd; -pub mod svddc; pub mod triangular; pub use self::cholesky::*; @@ -17,7 +16,6 @@ pub use self::qr::*; pub use self::solve::*; pub use self::solveh::*; pub use self::svd::*; -pub use self::svddc::*; pub use self::triangular::*; use super::error::*; @@ -26,7 +24,7 @@ use super::types::*; pub type Pivot = Vec; /// Trait for primitive types which implements LAPACK subroutines -pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + SVDDC_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {} +pub trait Lapack: OperatorNorm_ + QR_ + SVD_ + Solve_ + Solveh_ + Cholesky_ + Eigh_ + Triangular_ {} impl Lapack for f32 {} impl Lapack for f64 {} diff --git a/src/lapack/svd.rs b/src/lapack/svd.rs index ca06502a..9be9f491 100644 --- a/src/lapack/svd.rs +++ b/src/lapack/svd.rs @@ -9,13 +9,7 @@ use crate::types::*; use super::into_result; -#[repr(u8)] -enum FlagSVD { - All = b'A', - // OverWrite = b'O', - // Separately = b'S', - No = b'N', -} +use crate::svd::FlagSVD; /// Result of SVD pub struct SVDOutput { @@ -27,35 +21,45 @@ pub struct SVDOutput { pub vt: Option>, } -/// Wraps `*gesvd` +/// Wraps `*gesvd` and `*gesdd` pub trait SVD_: Scalar { - unsafe fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, a: &mut [Self]) -> Result>; + unsafe fn svd(l: MatrixLayout, jobu: FlagSVD, jobvt: FlagSVD, a: &mut [Self]) -> Result>; + unsafe fn svd_dc(l: MatrixLayout, jobz: FlagSVD, a: &mut [Self]) -> Result>; } macro_rules! impl_svd { - ($scalar:ty, $gesvd:path) => { + ($scalar:ty, $gesvd:path, $gesdd:path) => { impl SVD_ for $scalar { - unsafe fn svd(l: MatrixLayout, calc_u: bool, calc_vt: bool, mut a: &mut [Self]) -> Result> { + unsafe fn svd( + l: MatrixLayout, + jobu: FlagSVD, + jobvt: FlagSVD, + mut a: &mut [Self], + ) -> Result> { let (m, n) = l.size(); let k = ::std::cmp::min(n, m); let lda = l.lda(); - let (ju, ldu, mut u) = if calc_u { - (FlagSVD::All, m, vec![Self::zero(); (m * m) as usize]) - } else { - (FlagSVD::No, 1, Vec::new()) + let ucol = match jobu { + FlagSVD::All => m, + FlagSVD::Some => k, + FlagSVD::None => 0, }; - let (jvt, ldvt, mut vt) = if calc_vt { - (FlagSVD::All, n, vec![Self::zero(); (n * n) as usize]) - } else { - (FlagSVD::No, n, Vec::new()) + let vtrow = match jobvt { + FlagSVD::All => n, + FlagSVD::Some => k, + FlagSVD::None => 0, }; + let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; + let ldu = l.resized(m, ucol).lda(); + let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; + let ldvt = l.resized(vtrow, n).lda(); let mut s = vec![Self::Real::zero(); k as usize]; let mut superb = vec![Self::Real::zero(); (k - 1) as usize]; dbg!(ldvt); let info = $gesvd( l.lapacke_layout(), - ju as u8, - jvt as u8, + jobu as u8, + jobvt as u8, m, n, &mut a, @@ -71,8 +75,45 @@ macro_rules! impl_svd { info, SVDOutput { s: s, - u: if calc_u { Some(u) } else { None }, - vt: if calc_vt { Some(vt) } else { None }, + u: if jobu == FlagSVD::None { None } else { Some(u) }, + vt: if jobvt == FlagSVD::None { None } else { Some(vt) }, + }, + ) + } + + unsafe fn svd_dc(l: MatrixLayout, jobz: FlagSVD, mut a: &mut [Self]) -> Result> { + let (m, n) = l.size(); + let k = m.min(n); + let lda = l.lda(); + let (ucol, vtrow) = match jobz { + FlagSVD::All => (m, n), + FlagSVD::Some => (k, k), + FlagSVD::None => (0, 0), + }; + let mut s = vec![Self::Real::zero(); k.max(1) as usize]; + let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; + let ldu = l.resized(m, ucol).lda(); + let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; + let ldvt = l.resized(vtrow, n).lda(); + let info = $gesdd( + l.lapacke_layout(), + jobz as u8, + m, + n, + &mut a, + lda, + &mut s, + &mut u, + ldu, + &mut vt, + ldvt, + ); + into_result( + info, + SVDOutput { + s: s, + u: if jobz == FlagSVD::None { None } else { Some(u) }, + vt: if jobz == FlagSVD::None { None } else { Some(vt) }, }, ) } @@ -80,7 +121,7 @@ macro_rules! impl_svd { }; } // impl_svd! -impl_svd!(f64, lapacke::dgesvd); -impl_svd!(f32, lapacke::sgesvd); -impl_svd!(c64, lapacke::zgesvd); -impl_svd!(c32, lapacke::cgesvd); +impl_svd!(f64, lapacke::dgesvd, lapacke::dgesdd); +impl_svd!(f32, lapacke::sgesvd, lapacke::sgesdd); +impl_svd!(c64, lapacke::zgesvd, lapacke::zgesdd); +impl_svd!(c32, lapacke::cgesvd, lapacke::cgesdd); diff --git a/src/lapack/svddc.rs b/src/lapack/svddc.rs deleted file mode 100644 index 9c59b7aa..00000000 --- a/src/lapack/svddc.rs +++ /dev/null @@ -1,69 +0,0 @@ -use lapacke; -use num_traits::Zero; - -use crate::error::*; -use crate::layout::MatrixLayout; -use crate::types::*; -use crate::svddc::UVTFlag; - -use super::{SVDOutput, into_result}; - -pub trait SVDDC_: Scalar { - unsafe fn svddc(l: MatrixLayout, jobz: UVTFlag, a: &mut [Self]) -> Result>; -} - -macro_rules! impl_svdd { - ($scalar:ty, $gesdd:path) => { - impl SVDDC_ for $scalar { - unsafe fn svddc( - l: MatrixLayout, - jobz: UVTFlag, - mut a: &mut [Self], - ) -> Result> { - let (m, n) = l.size(); - let k = m.min(n); - let lda = l.lda(); - let (ucol, vtrow) = match jobz { - UVTFlag::Full => (m, n), - UVTFlag::Some => (k, k), - UVTFlag::None => (1, 1), - }; - let mut s = vec![Self::Real::zero(); k.max(1) as usize]; - let mut u = vec![Self::zero(); (m * ucol).max(1) as usize]; - let ldu = l.resized(m, ucol).lda(); - let mut vt = vec![Self::zero(); (vtrow * n).max(1) as usize]; - let ldvt = l.resized(vtrow, n).lda(); - let info = $gesdd( - l.lapacke_layout(), - jobz as u8, - m, - n, - &mut a, - lda, - &mut s, - &mut u, - ldu, - &mut vt, - ldvt, - ); - into_result( - info, - SVDOutput { - s: s, - u: if jobz == UVTFlag::None { None } else { Some(u) }, - vt: if jobz == UVTFlag::None { - None - } else { - Some(vt) - }, - }, - ) - } - } - }; -} - -impl_svdd!(f32, lapacke::sgesdd); -impl_svdd!(f64, lapacke::dgesdd); -impl_svdd!(c32, lapacke::cgesdd); -impl_svdd!(c64, lapacke::zgesdd); diff --git a/src/lib.rs b/src/lib.rs index 198449cc..3ced3120 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,7 +57,7 @@ pub mod qr; pub mod solve; pub mod solveh; pub mod svd; -pub mod svddc; +pub mod svd_rand; pub mod trace; pub mod triangular; pub mod types; @@ -77,7 +77,7 @@ pub use qr::*; pub use solve::*; pub use solveh::*; pub use svd::*; -pub use svddc::*; +pub use svd_rand::*; pub use trace::*; pub use triangular::*; pub use types::*; diff --git a/src/svd.rs b/src/svd.rs index 3acce0c2..696a7c8e 100644 --- a/src/svd.rs +++ b/src/svd.rs @@ -9,12 +9,36 @@ use super::error::*; use super::layout::*; use super::types::*; +#[derive(Clone, Copy, Eq, PartialEq)] +#[repr(u8)] +pub enum FlagSVD { + All = b'A', + // Overwrite = b'O', + Some = b'S', + None = b'N', +} + +impl Into for bool { + fn into(self) -> FlagSVD { + if self { + FlagSVD::All + } else { + FlagSVD::None + } + } +} + /// singular-value decomposition of matrix reference pub trait SVD { type U; type VT; type Sigma; - fn svd(&self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)>; + fn svd, Y: Into>( + &self, + calc_u: X, + calc_vt: Y, + ) -> Result<(Option, Self::Sigma, Option)>; + fn svd_dc>(&self, mode: X) -> Result<(Option, Self::Sigma, Option)>; } /// singular-value decomposition @@ -22,7 +46,12 @@ pub trait SVDInto { type U; type VT; type Sigma; - fn svd_into(self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)>; + fn svd_into, Y: Into>( + self, + calc_u: X, + calc_vt: Y, + ) -> Result<(Option, Self::Sigma, Option)>; + fn svd_dc_into>(self, mode: X) -> Result<(Option, Self::Sigma, Option)>; } /// singular-value decomposition for mutable reference of matrix @@ -30,35 +59,63 @@ pub trait SVDInplace { type U; type VT; type Sigma; - fn svd_inplace(&mut self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)>; + fn svd_inplace, Y: Into>( + &mut self, + calc_u: X, + calc_vt: Y, + ) -> Result<(Option, Self::Sigma, Option)>; + fn svd_dc_inplace>(&mut self, mode: X) + -> Result<(Option, Self::Sigma, Option)>; } -impl SVDInto for ArrayBase +impl SVD for ArrayBase where A: Scalar + Lapack, - S: DataMut, + S: Data, { type U = Array2; type VT = Array2; type Sigma = Array1; - fn svd_into(mut self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)> { - self.svd_inplace(calc_u, calc_vt) + fn svd(&self, calc_u: X, calc_vt: Y) -> Result<(Option, Self::Sigma, Option)> + where + X: Into, + Y: Into, + { + let a = self.to_owned(); + a.svd_into(calc_u, calc_vt) + } + + fn svd_dc(&self, mode: X) -> Result<(Option, Self::Sigma, Option)> + where + X: Into, + { + self.to_owned().svd_dc_into(mode) } } -impl SVD for ArrayBase +impl SVDInto for ArrayBase where A: Scalar + Lapack, - S: Data, + S: DataMut, { type U = Array2; type VT = Array2; type Sigma = Array1; - fn svd(&self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)> { - let a = self.to_owned(); - a.svd_into(calc_u, calc_vt) + fn svd_into(mut self, calc_u: X, calc_vt: Y) -> Result<(Option, Self::Sigma, Option)> + where + X: Into, + Y: Into, + { + self.svd_inplace(calc_u, calc_vt) + } + + fn svd_dc_into(mut self, mode: X) -> Result<(Option, Self::Sigma, Option)> + where + X: Into, + { + self.svd_dc_inplace(mode) } } @@ -71,9 +128,13 @@ where type VT = Array2; type Sigma = Array1; - fn svd_inplace(&mut self, calc_u: bool, calc_vt: bool) -> Result<(Option, Self::Sigma, Option)> { + fn svd_inplace(&mut self, calc_u: X, calc_vt: Y) -> Result<(Option, Self::Sigma, Option)> + where + X: Into, + Y: Into, + { let l = self.layout()?; - let svd_res = unsafe { A::svd(l, calc_u, calc_vt, self.as_allocated_mut()?)? }; + let svd_res = unsafe { A::svd(l, calc_u.into(), calc_vt.into(), self.as_allocated_mut()?)? }; let (n, m) = l.size(); let u = svd_res .u @@ -84,4 +145,28 @@ where let s = ArrayBase::from_vec(svd_res.s); Ok((u, s, vt)) } + + fn svd_dc_inplace(&mut self, mode: X) -> Result<(Option, Self::Sigma, Option)> + where + X: Into, + { + let mode = mode.into(); + let l = self.layout()?; + let svd_res = unsafe { A::svd_dc(l, mode, self.as_allocated_mut()?)? }; + let (m, n) = l.size(); + let k = m.min(n); + let (ldu, tdu, ldvt, tdvt) = match mode { + FlagSVD::All => (m, m, n, n), + FlagSVD::Some => (m, k, k, n), + FlagSVD::None => (1, 1, 1, 1), + }; + let u = svd_res + .u + .map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches")); + let vt = svd_res + .vt + .map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches")); + let s = ArrayBase::from_vec(svd_res.s); + Ok((u, s, vt)) + } } diff --git a/src/svd_rand.rs b/src/svd_rand.rs new file mode 100644 index 00000000..a842554f --- /dev/null +++ b/src/svd_rand.rs @@ -0,0 +1,150 @@ +//! Singular-value decomposition (SVD) +//! +//! [arXiv article on randomized linear algebra algorithms (including SVD)](https://arxiv.org/pdf/0909.4061.pdf) + +use ndarray::linalg::Dot; +use ndarray::*; +use rand::{ + distributions::{ + uniform::{SampleUniform, Uniform}, + Distribution, + }, + SeedableRng, +}; + +use super::error::*; +use super::qr::QR; +use super::svd::{FlagSVD, SVDInto}; +use super::types::*; + +#[cfg(feature = "sprs")] +use ::{sprs::{CsMatBase, SpIndex}, std::ops::Deref}; + +/// trait to capture shape of matrix +pub trait ArrayLike { + type A; + fn dim(&self) -> D::Pattern; +} + +impl ArrayLike for ArrayBase +where + D: ndarray::Dimension, + S: DataMut, +{ + type A = A; + fn dim(&self) -> D::Pattern { + ArrayBase::dim(self) + } +} + +#[cfg(feature = "sprs")] +impl ArrayLike for CsMatBase +where + I: SpIndex, + IptrStorage: Deref, + IndStorage: Deref, + DataStorage: Deref, +{ + type A = N; + fn dim(&self) -> ::Pattern { + (self.rows(), self.cols()) + } +} + +/// randomized truncated singular-value decomposition +pub trait SVDRand { + type U; + type VT; + type Sigma; + fn svd_rand( + &self, + k: usize, + n_iter: Option, + l: Option, + seed: Option, + ) -> Result<(Option, Self::Sigma, Option)>; +} + +impl SVDRand for T +where + A: Scalar + Lapack + SampleUniform, + T: ArrayLike + Dot, Output = Array2>, + for<'a> ArrayView2<'a, A>: Dot>, + Array2: QR> + Dot> + Dot, Output = Array2>, +{ + type U = Array2; + type VT = Array2; + type Sigma = Array1; + + fn svd_rand( + &self, + k: usize, + n_iter: Option, + l: Option, + seed: Option, + ) -> Result<(Option, Self::Sigma, Option)> { + let n_iter = n_iter.unwrap_or(7); + let l = l.unwrap_or(k + 2); + let (m, n) = self.dim(); + + if m < 2 || n < 2 { + panic!("m or n are <2!") + } + if m.min(n) < k { + panic!("min(m, n) is = Uniform::new(-A::one(), A::one()); + + // TODO(nlhepler): Additional cases to handle + // - fall through to straight svd when l/k is within ~25% of m or n. + + if m >= n { + let omega: Array2 = Array2::from_shape_fn((n, l), move |_| unif.sample(&mut rng)); + let mut q = self.dot(&omega).qr()?.0; + + for _ in 0..n_iter { + q = q.t().dot(self).reversed_axes().qr()?.0; + q = self.dot(&q).qr()?.0; + } + + let (u, s, vt) = { + let b = q.t().dot(self); + // info!("performing svd"); + let svd = b.svd_dc_into(FlagSVD::Some)?; + // info!("svd finished"); + ( + svd.0.unwrap().slice(s![.., ..k]).to_owned(), + svd.1.slice(s![..k]).to_owned(), + svd.2.unwrap().slice(s![..k, ..]).to_owned(), + ) + }; + + let u = q.dot(&u); + Ok((Some(u), s, Some(vt))) + } else { + // n > m + let omega = Array2::from_shape_fn((l, m), move |_| unif.sample(&mut rng)); + let mut q = omega.dot(self).reversed_axes().qr()?.0; + + for _ in 0..n_iter { + q = self.dot(&q).qr()?.0; + q = q.t().dot(self).reversed_axes().qr()?.0; + } + + let (u, s, vt) = { + let b = self.dot(&q); + let svd = b.svd_dc_into(FlagSVD::Some)?; + ( + svd.0.unwrap().slice(s![.., ..k]).to_owned(), + svd.1.slice(s![..k]).to_owned(), + svd.2.unwrap().slice(s![..k, ..]).to_owned(), + ) + }; + + let vt = vt.dot(&q.t()); + Ok((Some(u), s, Some(vt))) + } + } +} diff --git a/src/svddc.rs b/src/svddc.rs deleted file mode 100644 index 10a49668..00000000 --- a/src/svddc.rs +++ /dev/null @@ -1,110 +0,0 @@ -//! Singular-value decomposition (SVD) by divide-and-conquer (?gesdd) - -use ndarray::*; - -use super::convert::*; -use super::error::*; -use super::layout::*; -use super::types::*; - -#[derive(Clone, Copy, Eq, PartialEq)] -#[repr(u8)] -pub enum UVTFlag { - Full = b'A', - Some = b'S', - None = b'N', -} - -/// Singular-value decomposition of matrix (copying) by divide-and-conquer -pub trait SVDDC { - type U; - type VT; - type Sigma; - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)>; -} - -/// Singular-value decomposition of matrix by divide-and-conquer -pub trait SVDDCInto { - type U; - type VT; - type Sigma; - fn svddc_into( - self, - uvt_flag: UVTFlag, - ) -> Result<(Option, Self::Sigma, Option)>; -} - -/// Singular-value decomposition of matrix reference by divide-and-conquer -pub trait SVDDCInplace { - type U; - type VT; - type Sigma; - fn svddc_inplace( - &mut self, - uvt_flag: UVTFlag, - ) -> Result<(Option, Self::Sigma, Option)>; -} - -impl SVDDC for ArrayBase -where - A: Scalar + Lapack, - S: DataMut, -{ - type U = Array2; - type VT = Array2; - type Sigma = Array1; - - fn svddc(&self, uvt_flag: UVTFlag) -> Result<(Option, Self::Sigma, Option)> { - self.to_owned().svddc_into(uvt_flag) - } -} - -impl SVDDCInto for ArrayBase -where - A: Scalar + Lapack, - S: DataMut, -{ - type U = Array2; - type VT = Array2; - type Sigma = Array1; - - fn svddc_into( - mut self, - uvt_flag: UVTFlag, - ) -> Result<(Option, Self::Sigma, Option)> { - self.svddc_inplace(uvt_flag) - } -} - -impl SVDDCInplace for ArrayBase -where - A: Scalar + Lapack, - S: DataMut, -{ - type U = Array2; - type VT = Array2; - type Sigma = Array1; - - fn svddc_inplace( - &mut self, - uvt_flag: UVTFlag, - ) -> Result<(Option, Self::Sigma, Option)> { - let l = self.layout()?; - let svd_res = unsafe { A::svddc(l, uvt_flag, self.as_allocated_mut()?)? }; - let (m, n) = l.size(); - let k = m.min(n); - let (ldu, tdu, ldvt, tdvt) = match uvt_flag { - UVTFlag::Full => (m, m, n, n), - UVTFlag::Some => (m, k, k, n), - UVTFlag::None => (1, 1, 1, 1), - }; - let u = svd_res - .u - .map(|u| into_matrix(l.resized(ldu, tdu), u).expect("Size of U mismatches")); - let vt = svd_res - .vt - .map(|vt| into_matrix(l.resized(ldvt, tdvt), vt).expect("Size of VT mismatches")); - let s = ArrayBase::from_vec(svd_res.s); - Ok((u, s, vt)) - } -} diff --git a/tests/svd_dc.rs b/tests/svd_dc.rs new file mode 100644 index 00000000..4ae6c779 --- /dev/null +++ b/tests/svd_dc.rs @@ -0,0 +1,74 @@ +use ndarray::*; +use ndarray_linalg::*; + +fn test(a: &Array2, flag: FlagSVD) { + let (m, n) = a.dim(); + let k = m.min(n); + let answer = a.clone(); + println!("a = \n{:?}", a); + let (u, s, vt): (_, Array1<_>, _) = a.svd_dc(flag).unwrap(); + let mut sm = match flag { + FlagSVD::All => Array::zeros((m, n)), + FlagSVD::Some => Array::zeros((k, k)), + FlagSVD::None => { + assert!(u.is_none()); + assert!(vt.is_none()); + return; + }, + }; + let u: Array2<_> = u.unwrap(); + let vt: Array2<_> = vt.unwrap(); + println!("u = \n{:?}", &u); + println!("s = \n{:?}", &s); + println!("v = \n{:?}", &vt); + for i in 0..k { + sm[(i, i)] = s[i]; + } + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); +} + +macro_rules! test_svd_dc_impl { + ($m:expr, $n:expr) => { + paste::item! { + #[test] + fn []() { + let a = random(($m, $n)); + test(&a, FlagSVD::All); + } + + #[test] + fn []() { + let a = random(($m, $n)); + test(&a, FlagSVD::Some); + } + + #[test] + fn []() { + let a = random(($m, $n)); + test(&a, FlagSVD::None); + } + + #[test] + fn []() { + let a = random(($m, $n).f()); + test(&a, FlagSVD::All); + } + + #[test] + fn []() { + let a = random(($m, $n).f()); + test(&a, FlagSVD::Some); + } + + #[test] + fn []() { + let a = random(($m, $n).f()); + test(&a, FlagSVD::None); + } + } + }; +} + +test_svd_dc_impl!(3, 3); +test_svd_dc_impl!(4, 3); +test_svd_dc_impl!(3, 4); diff --git a/tests/svd_rand.rs b/tests/svd_rand.rs new file mode 100644 index 00000000..8d9517e8 --- /dev/null +++ b/tests/svd_rand.rs @@ -0,0 +1,43 @@ +use ndarray::*; +use ndarray_linalg::*; + +#[cfg(feature = "sprs")] +use sprs::{CsMatBase, CsMatI}; + +fn test, Sigma = Array1, VT = Array2>>(a: &A, k: usize, answer: &Array2) { + let s0 = answer.svd_dc(FlagSVD::Some).unwrap().1; + let (u, s, vt) = a.svd_rand(k, None, None, None).unwrap(); + let u = u.unwrap(); + let vt = vt.unwrap(); + let mut sm = Array::zeros((k, k)); + for i in 0..k { + sm[(i, i)] = s[i]; + } + assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 0.2); + assert_close_max!(&s, &s0.slice(s![..k]), 1e-7) +} + +macro_rules! test_svd_rand_impl { + ($m:expr, $n:expr, $k:expr) => { + paste::item! { + #[test] + fn []() { + // TODO(nlhepler): implement lower-rank matrices (w/ noise) for testing + let a = random(($m, $n)); + test(&a, $k, &a); + } + + #[cfg(feature = "sprs")] + #[test] + fn []() { + let a: Array2 = random(($m, $n)); + let a: CsMatI = CsMatBase::csr_from_dense(a.view(), 0.0); + test(&a, $k, &a.to_dense()); + } + } + }; +} + +test_svd_rand_impl!(20, 20, 17); +test_svd_rand_impl!(25, 20, 17); +test_svd_rand_impl!(20, 25, 17); diff --git a/tests/svddc.rs b/tests/svddc.rs deleted file mode 100644 index 1dbbf32b..00000000 --- a/tests/svddc.rs +++ /dev/null @@ -1,74 +0,0 @@ -use ndarray::*; -use ndarray_linalg::*; - -fn test(a: &Array2, flag: UVTFlag) { - let (n, m) = a.dim(); - let k = n.min(m); - let answer = a.clone(); - println!("a = \n{:?}", a); - let (u, s, vt): (_, Array1<_>, _) = a.svddc(flag).unwrap(); - let mut sm = match flag { - UVTFlag::Full => Array::zeros((n, m)), - UVTFlag::Some => Array::zeros((k, k)), - UVTFlag::None => { - assert!(u.is_none()); - assert!(vt.is_none()); - return; - }, - }; - let u: Array2<_> = u.unwrap(); - let vt: Array2<_> = vt.unwrap(); - println!("u = \n{:?}", &u); - println!("s = \n{:?}", &s); - println!("v = \n{:?}", &vt); - for i in 0..k { - sm[(i, i)] = s[i]; - } - assert_close_l2!(&u.dot(&sm).dot(&vt), &answer, 1e-7); -} - -macro_rules! test_svd_impl { - ($n:expr, $m:expr) => { - paste::item! { - #[test] - fn []() { - let a = random(($n, $m)); - test(&a, UVTFlag::Full); - } - - #[test] - fn []() { - let a = random(($n, $m)); - test(&a, UVTFlag::Some); - } - - #[test] - fn []() { - let a = random(($n, $m)); - test(&a, UVTFlag::None); - } - - #[test] - fn []() { - let a = random(($n, $m).f()); - test(&a, UVTFlag::Full); - } - - #[test] - fn []() { - let a = random(($n, $m).f()); - test(&a, UVTFlag::Some); - } - - #[test] - fn []() { - let a = random(($n, $m).f()); - test(&a, UVTFlag::None); - } - } - }; -} - -test_svd_impl!(3, 3); -test_svd_impl!(4, 3); -test_svd_impl!(3, 4);