diff --git a/argmin-math/src/lib.rs b/argmin-math/src/lib.rs index 310fc970f..57db4e5fa 100644 --- a/argmin-math/src/lib.rs +++ b/argmin-math/src/lib.rs @@ -251,6 +251,12 @@ pub trait ArgminScaledSub { fn scaled_sub(&self, factor: &U, vec: &T) -> V; } +/// Compute the l2-norm (`U`) of `self` +pub trait ArgminL1Norm { + /// Compute the l1-norm (`U`) of `self` + fn l1_norm(&self) -> U; +} + /// Compute the l2-norm (`U`) of `self` pub trait ArgminNorm { /// Compute the l2-norm (`U`) of `self` @@ -283,3 +289,9 @@ pub trait ArgminMinMax { /// Select piecewise maximum fn max(x: &Self, y: &Self) -> Self; } + +/// Returns a number that represents the sign of `self`. +pub trait ArgminSignum { + /// Returns a number that represents the sign of `self`. + fn signum(self) -> Self; +} diff --git a/argmin-math/src/nalgebra_m/l1norm.rs b/argmin-math/src/nalgebra_m/l1norm.rs new file mode 100644 index 000000000..0784f157b --- /dev/null +++ b/argmin-math/src/nalgebra_m/l1norm.rs @@ -0,0 +1,67 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use crate::ArgminL1Norm; + +use nalgebra::{ + base::{dimension::Dim, storage::Storage}, + LpNorm, Matrix, SimdComplexField, +}; + +impl ArgminL1Norm for Matrix +where + N: SimdComplexField, + R: Dim, + C: Dim, + S: Storage, +{ + #[inline] + fn l1_norm(&self) -> N::SimdRealField { + self.apply_norm(&LpNorm(1)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use nalgebra::Vector2; + use paste::item; + + macro_rules! make_test { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = Vector2::new(4 as $t, 3 as $t); + let res = as ArgminL1Norm<$t>>::l1_norm(&a); + let target = 7 as $t; + assert!(((target - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + macro_rules! make_test_signed { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = Vector2::new(-4 as $t, -3 as $t); + let res = as ArgminL1Norm<$t>>::l1_norm(&a); + let target = 7 as $t; + assert!(((target - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + make_test!(f32); + make_test!(f64); + + make_test_signed!(f32); + make_test_signed!(f64); +} diff --git a/argmin-math/src/nalgebra_m/mod.rs b/argmin-math/src/nalgebra_m/mod.rs index b7c47e643..0d9cea7c6 100644 --- a/argmin-math/src/nalgebra_m/mod.rs +++ b/argmin-math/src/nalgebra_m/mod.rs @@ -11,6 +11,7 @@ mod div; mod dot; mod eye; mod inv; +mod l1norm; mod mul; mod norm; mod scaledadd; @@ -25,6 +26,7 @@ pub use div::*; pub use dot::*; pub use eye::*; pub use inv::*; +pub use l1norm::*; pub use mul::*; pub use norm::*; pub use scaledadd::*; diff --git a/argmin-math/src/ndarray_m/l1norm.rs b/argmin-math/src/ndarray_m/l1norm.rs new file mode 100644 index 000000000..9447a15e6 --- /dev/null +++ b/argmin-math/src/ndarray_m/l1norm.rs @@ -0,0 +1,132 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use crate::ArgminL1Norm; +use ndarray::Array1; +use num_complex::Complex; + +macro_rules! make_l1norm_float { + ($t:ty) => { + impl ArgminL1Norm<$t> for Array1<$t> { + #[inline] + fn l1_norm(&self) -> $t { + self.iter().map(|a| a.abs()).sum() + } + } + }; +} + +macro_rules! make_l1norm_complex_float { + ($t:ty) => { + impl ArgminL1Norm> for Array1> { + #[inline] + fn l1_norm(&self) -> Complex<$t> { + self.iter().map(|a| a.l1_norm()).sum::<$t>().into() + } + } + + impl ArgminL1Norm<$t> for Array1> { + #[inline] + fn l1_norm(&self) -> $t { + self.iter().map(|a| a.l1_norm()).sum() + } + } + }; +} + +macro_rules! make_l1norm_unsigned { + ($t:ty) => { + impl ArgminL1Norm<$t> for Array1<$t> { + #[inline] + fn l1_norm(&self) -> $t { + self.sum() + } + } + }; +} + +macro_rules! make_l1norm_integer { + ($t:ty) => { + impl ArgminL1Norm<$t> for Array1<$t> { + #[inline] + fn l1_norm(&self) -> $t { + self.iter().map(|a| a.abs()).sum() + } + } + }; +} + +make_l1norm_integer!(isize); +make_l1norm_unsigned!(usize); +make_l1norm_integer!(i8); +make_l1norm_integer!(i16); +make_l1norm_integer!(i32); +make_l1norm_integer!(i64); +make_l1norm_unsigned!(u8); +make_l1norm_unsigned!(u16); +make_l1norm_unsigned!(u32); +make_l1norm_unsigned!(u64); +make_l1norm_float!(f32); +make_l1norm_float!(f64); +make_l1norm_complex_float!(f32); +make_l1norm_complex_float!(f64); + +#[cfg(test)] +mod tests { + use super::*; + use ndarray::{array, Array1}; + use paste::item; + + macro_rules! make_test { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = array![4 as $t, 3 as $t]; + let res = as ArgminL1Norm<$t>>::l1_norm(&a); + let target = 7 as $t; + assert!(((target - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + macro_rules! make_test_signed { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = array![-4 as $t, -3 as $t]; + let res = as ArgminL1Norm<$t>>::l1_norm(&a); + let target = 7 as $t; + assert!(((target - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + make_test!(isize); + make_test!(usize); + make_test!(i8); + make_test!(u8); + make_test!(i16); + make_test!(u16); + make_test!(i32); + make_test!(u32); + make_test!(i64); + make_test!(u64); + make_test!(f32); + make_test!(f64); + + make_test_signed!(isize); + make_test_signed!(i8); + make_test_signed!(i16); + make_test_signed!(i32); + make_test_signed!(i64); + make_test_signed!(f32); + make_test_signed!(f64); +} diff --git a/argmin-math/src/ndarray_m/minmax.rs b/argmin-math/src/ndarray_m/minmax.rs new file mode 100644 index 000000000..ba491e743 --- /dev/null +++ b/argmin-math/src/ndarray_m/minmax.rs @@ -0,0 +1,77 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: Tests for Array2 impl + +use crate::ArgminMinMax; +use ndarray::{Array1, Array2}; + +macro_rules! make_minmax { + ($t:ty) => { + impl ArgminMinMax for Array1<$t> { + #[inline] + fn min(x: &Self, y: &Self) -> Array1<$t> { + assert_eq!(x.shape(), y.shape()); + x.iter() + .zip(y) + .map(|(&a, &b)| if a < b { a } else { b }) + .collect() + } + + #[inline] + fn max(x: &Self, y: &Self) -> Array1<$t> { + assert_eq!(x.shape(), y.shape()); + x.iter() + .zip(y) + .map(|(&a, &b)| if a > b { a } else { b }) + .collect() + } + } + + impl ArgminMinMax for Array2<$t> { + #[inline] + fn min(x: &Self, y: &Self) -> Array2<$t> { + assert_eq!(x.shape(), y.shape()); + let m = x.shape()[0]; + let n = x.shape()[1]; + let mut out = x.clone(); + for i in 0..m { + for j in 0..n { + let a = x[(i, j)]; + let b = y[(i, j)]; + out[(i, j)] = if a < b { a } else { b }; + } + } + out + } + + #[inline] + fn max(x: &Self, y: &Self) -> Array2<$t> { + assert_eq!(x.shape(), y.shape()); + let m = x.shape()[0]; + let n = x.shape()[1]; + let mut out = x.clone(); + for i in 0..m { + for j in 0..n { + let a = x[(i, j)]; + let b = y[(i, j)]; + out[(i, j)] = if a > b { a } else { b }; + } + } + out + } + } + }; +} + +make_minmax!(isize); +make_minmax!(i8); +make_minmax!(i16); +make_minmax!(i32); +make_minmax!(i64); +make_minmax!(f32); +make_minmax!(f64); diff --git a/argmin-math/src/ndarray_m/mod.rs b/argmin-math/src/ndarray_m/mod.rs index b7c47e643..df563da8f 100644 --- a/argmin-math/src/ndarray_m/mod.rs +++ b/argmin-math/src/ndarray_m/mod.rs @@ -11,10 +11,13 @@ mod div; mod dot; mod eye; mod inv; +mod l1norm; +mod minmax; mod mul; mod norm; mod scaledadd; mod scaledsub; +mod signum; mod sub; mod transpose; mod zero; @@ -25,10 +28,13 @@ pub use div::*; pub use dot::*; pub use eye::*; pub use inv::*; +pub use l1norm::*; +pub use minmax::*; pub use mul::*; pub use norm::*; pub use scaledadd::*; pub use scaledsub::*; +pub use signum::*; pub use sub::*; pub use transpose::*; pub use zero::*; diff --git a/argmin-math/src/ndarray_m/signum.rs b/argmin-math/src/ndarray_m/signum.rs new file mode 100644 index 000000000..0006967ee --- /dev/null +++ b/argmin-math/src/ndarray_m/signum.rs @@ -0,0 +1,140 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: Tests for Array2 impl + +use crate::ArgminSignum; +use ndarray::{Array1, Array2}; +use num_complex::Complex; + +macro_rules! make_signum { + ($t:ty) => { + impl ArgminSignum for Array1<$t> { + #[inline] + fn signum(mut self) -> Array1<$t> { + for a in &mut self { + *a = a.signum(); + } + self + } + } + + impl ArgminSignum for Array2<$t> { + #[inline] + fn signum(mut self) -> Array2<$t> { + let m = self.shape()[0]; + let n = self.shape()[1]; + for i in 0..m { + for j in 0..n { + self[(i, j)] = self[(i, j)].signum(); + } + } + self + } + } + }; +} + +macro_rules! make_signum_complex { + ($t:ty) => { + impl ArgminSignum for Array1<$t> { + #[inline] + fn signum(mut self) -> Array1<$t> { + for a in &mut self { + a.re = a.re.signum(); + a.im = a.im.signum(); + } + self + } + } + + impl ArgminSignum for Array2<$t> { + #[inline] + fn signum(mut self) -> Array2<$t> { + let m = self.shape()[0]; + let n = self.shape()[1]; + for i in 0..m { + for j in 0..n { + self[(i, j)].re = self[(i, j)].re.signum(); + self[(i, j)].im = self[(i, j)].im.signum(); + } + } + self + } + } + }; +} + +make_signum!(isize); +make_signum!(i8); +make_signum!(i16); +make_signum!(i32); +make_signum!(i64); +make_signum!(f32); +make_signum!(f64); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); + +#[cfg(test)] +mod tests { + use super::*; + use paste::item; + + macro_rules! make_test { + ($t:ty) => { + item! { + #[test] + fn []() { + let x = Array1::from(vec![ + Complex::new(1 as $t, 2 as $t), + Complex::new(4 as $t, -3 as $t), + Complex::new(-8 as $t, 4 as $t), + Complex::new(-8 as $t, -1 as $t), + ]); + let y = Array1::from(vec![ + Complex::new(1 as $t, 1 as $t), + Complex::new(1 as $t, -1 as $t), + Complex::new(-1 as $t, 1 as $t), + Complex::new(-1 as $t, -1 as $t), + ]); + let res = > as ArgminSignum>::signum(x); + for i in 0..4 { + let tmp = y[i] - res[i]; + let norm = ((tmp.re * tmp.re + tmp.im * tmp.im) as f64).sqrt(); + assert!(norm < std::f64::EPSILON); + } + } + } + + item! { + #[test] + fn []() { + let x = Array1::from(vec![1 as $t, -4 as $t, 8 as $t]); + let y = Array1::from(vec![1 as $t, -1 as $t, 1 as $t]); + let res = as ArgminSignum>::signum(x); + for i in 0..3 { + let diff = (y[i] - res[i]).abs() as f64; + assert!(diff < std::f64::EPSILON); + } + } + } + }; + } + + make_test!(isize); + make_test!(i8); + make_test!(i16); + make_test!(i32); + make_test!(i64); + make_test!(f32); + make_test!(f64); +} diff --git a/argmin-math/src/primitives/l1norm.rs b/argmin-math/src/primitives/l1norm.rs new file mode 100644 index 000000000..8c87f9c37 --- /dev/null +++ b/argmin-math/src/primitives/l1norm.rs @@ -0,0 +1,110 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use crate::ArgminL1Norm; +use num_complex::Complex; + +macro_rules! make_l1norm_unsigned { + ($t:ty) => { + impl ArgminL1Norm<$t> for $t { + #[inline] + fn l1_norm(&self) -> $t { + *self + } + } + }; +} + +macro_rules! make_l1norm { + ($t:ty) => { + impl ArgminL1Norm<$t> for $t { + #[inline] + fn l1_norm(&self) -> $t { + self.abs() + } + } + }; +} + +macro_rules! make_l1norm_complex { + ($t:ty) => { + impl ArgminL1Norm<$t> for Complex<$t> { + #[inline] + fn l1_norm(&self) -> $t { + self.l1_norm() + } + } + }; +} + +make_l1norm!(isize); +make_l1norm_unsigned!(usize); +make_l1norm!(i8); +make_l1norm!(i16); +make_l1norm!(i32); +make_l1norm!(i64); +make_l1norm_unsigned!(u8); +make_l1norm_unsigned!(u16); +make_l1norm_unsigned!(u32); +make_l1norm_unsigned!(u64); +make_l1norm!(f32); +make_l1norm!(f64); +make_l1norm_complex!(f32); +make_l1norm_complex!(f64); + +#[cfg(test)] +mod tests { + use super::*; + use paste::item; + + macro_rules! make_test { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = 8 as $t; + let res = <$t as ArgminL1Norm<$t>>::l1_norm(&a); + assert!(((a - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + macro_rules! make_test_signed { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = -8 as $t; + let res = <$t as ArgminL1Norm<$t>>::l1_norm(&a); + assert!(((8 as $t - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + make_test!(isize); + make_test!(usize); + make_test!(i8); + make_test!(u8); + make_test!(i16); + make_test!(u16); + make_test!(i32); + make_test!(u32); + make_test!(i64); + make_test!(u64); + make_test!(f32); + make_test!(f64); + + make_test_signed!(isize); + make_test_signed!(i8); + make_test_signed!(i16); + make_test_signed!(i32); + make_test_signed!(i64); + make_test_signed!(f32); + make_test_signed!(f64); +} diff --git a/argmin-math/src/primitives/mod.rs b/argmin-math/src/primitives/mod.rs index 366da7751..dae4e9222 100644 --- a/argmin-math/src/primitives/mod.rs +++ b/argmin-math/src/primitives/mod.rs @@ -9,6 +9,7 @@ mod add; mod conj; mod div; mod dot; +mod l1norm; mod mul; mod norm; mod scaledadd; @@ -22,6 +23,7 @@ pub use add::*; pub use conj::*; pub use div::*; pub use dot::*; +pub use l1norm::*; pub use mul::*; pub use norm::*; pub use scaledadd::*; diff --git a/argmin-math/src/vec/l1norm.rs b/argmin-math/src/vec/l1norm.rs new file mode 100644 index 000000000..45c241042 --- /dev/null +++ b/argmin-math/src/vec/l1norm.rs @@ -0,0 +1,123 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use crate::ArgminL1Norm; +use num_complex::Complex; + +macro_rules! make_l1norm_float { + ($t:ty) => { + impl ArgminL1Norm<$t> for Vec<$t> { + #[inline] + fn l1_norm(&self) -> $t { + self.iter().map(|a| a.abs()).sum() + } + } + }; +} + +macro_rules! make_l1norm_complex_float { + ($t:ty) => { + impl ArgminL1Norm> for Vec> { + #[inline] + fn l1_norm(&self) -> Complex<$t> { + self.iter().map(|a| a.l1_norm()).sum::<$t>().into() + } + } + }; +} + +macro_rules! make_l1norm_unsigned { + ($t:ty) => { + impl ArgminL1Norm<$t> for Vec<$t> { + #[inline] + fn l1_norm(&self) -> $t { + self.iter().sum() + } + } + }; +} + +macro_rules! make_l1norm_integer { + ($t:ty) => { + impl ArgminL1Norm<$t> for Vec<$t> { + #[inline] + fn l1_norm(&self) -> $t { + self.iter().map(|a| a.abs()).sum() + } + } + }; +} + +make_l1norm_integer!(isize); +make_l1norm_unsigned!(usize); +make_l1norm_integer!(i8); +make_l1norm_integer!(i16); +make_l1norm_integer!(i32); +make_l1norm_integer!(i64); +make_l1norm_unsigned!(u8); +make_l1norm_unsigned!(u16); +make_l1norm_unsigned!(u32); +make_l1norm_unsigned!(u64); +make_l1norm_float!(f32); +make_l1norm_float!(f64); +make_l1norm_complex_float!(f32); +make_l1norm_complex_float!(f64); + +#[cfg(test)] +mod tests { + use super::*; + use paste::item; + + macro_rules! make_test { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = vec![4 as $t, 3 as $t]; + let res = as ArgminL1Norm<$t>>::l1_norm(&a); + let target = 7 as $t; + assert!(((target - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + macro_rules! make_test_signed { + ($t:ty) => { + item! { + #[test] + fn []() { + let a = vec![-4 as $t, -3 as $t]; + let res = as ArgminL1Norm<$t>>::l1_norm(&a); + let target = 7 as $t; + assert!(((target - res) as f64).abs() < std::f64::EPSILON); + } + } + }; + } + + make_test!(isize); + make_test!(usize); + make_test!(i8); + make_test!(u8); + make_test!(i16); + make_test!(u16); + make_test!(i32); + make_test!(u32); + make_test!(i64); + make_test!(u64); + make_test!(f32); + make_test!(f64); + + make_test_signed!(isize); + make_test_signed!(i8); + make_test_signed!(i16); + make_test_signed!(i32); + make_test_signed!(i64); + make_test_signed!(f32); + make_test_signed!(f64); +} diff --git a/argmin-math/src/vec/mod.rs b/argmin-math/src/vec/mod.rs index 45fab6c54..4fc3d9acb 100644 --- a/argmin-math/src/vec/mod.rs +++ b/argmin-math/src/vec/mod.rs @@ -10,12 +10,14 @@ mod conj; mod div; mod dot; mod eye; +mod l1norm; mod minmax; mod mul; mod norm; mod random; mod scaledadd; mod scaledsub; +mod signum; mod sub; mod transpose; mod zero; @@ -25,12 +27,14 @@ pub use conj::*; pub use div::*; pub use dot::*; pub use eye::*; +pub use l1norm::*; pub use minmax::*; pub use mul::*; pub use norm::*; pub use random::*; pub use scaledadd::*; pub use scaledsub::*; +pub use signum::*; pub use sub::*; pub use transpose::*; pub use zero::*; diff --git a/argmin-math/src/vec/signum.rs b/argmin-math/src/vec/signum.rs new file mode 100644 index 000000000..2b1fcee64 --- /dev/null +++ b/argmin-math/src/vec/signum.rs @@ -0,0 +1,51 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use crate::ArgminSignum; +use num_complex::Complex; + +macro_rules! make_signum { + ($t:ty) => { + impl ArgminSignum for Vec<$t> { + fn signum(mut self) -> Self { + for x in &mut self { + *x = x.signum(); + } + self + } + } + }; +} + +macro_rules! make_signum_complex { + ($t:ty) => { + impl ArgminSignum for Vec<$t> { + fn signum(mut self) -> Self { + for x in &mut self { + x.re = x.re.signum(); + x.im = x.im.signum(); + } + self + } + } + }; +} + +make_signum!(isize); +make_signum!(i8); +make_signum!(i16); +make_signum!(i32); +make_signum!(i64); +make_signum!(f32); +make_signum!(f64); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); +make_signum_complex!(Complex); diff --git a/argmin/Cargo.toml b/argmin/Cargo.toml index 15c19b988..a8574633b 100644 --- a/argmin/Cargo.toml +++ b/argmin/Cargo.toml @@ -141,6 +141,10 @@ required-features = ["slog-logger"] name = "observer" required-features = ["slog-logger", "gnuplot"] +[[example]] +name = "owl_qn" +required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"] + [[example]] name = "particleswarm" required-features = [] diff --git a/argmin/examples/owl_qn.rs b/argmin/examples/owl_qn.rs new file mode 100644 index 000000000..5fcf09083 --- /dev/null +++ b/argmin/examples/owl_qn.rs @@ -0,0 +1,73 @@ +// Copyright 2018-2022 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +use argmin::core::observers::{ObserverMode, SlogLogger}; +use argmin::core::{CostFunction, Error, Executor, Gradient}; +use argmin::solver::linesearch::MoreThuenteLineSearch; +use argmin::solver::quasinewton::LBFGS; +use argmin_testfunctions::rosenbrock; +use finitediff::FiniteDiff; +use ndarray::{array, Array1}; + +struct Rosenbrock { + a: f64, + b: f64, +} + +impl CostFunction for Rosenbrock { + type Param = Array1; + type Output = f64; + + fn cost(&self, p: &Self::Param) -> Result { + Ok(rosenbrock(&p.to_vec(), self.a, self.b)) + } +} +impl Gradient for Rosenbrock { + type Param = Array1; + type Gradient = Array1; + + fn gradient(&self, p: &Self::Param) -> Result { + Ok((*p).forward_diff(&|x| rosenbrock(&x.to_vec(), self.a, self.b))) + } +} + +fn run() -> Result<(), Error> { + // Define cost function + let cost = Rosenbrock { a: 1.0, b: 100.0 }; + + // Define initial parameter vector + let init_param: Array1 = array![-1.2, 1.0]; + // let init_param: Array1 = array![-1.2, 1.0, -10.0, 2.0, 3.0, 2.0, 4.0, 10.0]; + + // set up a line search + let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9)?; + + // Set up solver + let solver = LBFGS::new(linesearch, 7) + .with_l1_regularization(1.0)? + .with_tolerance_cost(1e-6)?; + + // Run solver + let res = Executor::new(cost, solver) + .configure(|state| state.param(init_param).max_iters(100)) + .add_observer(SlogLogger::term(), ObserverMode::Always) + .run()?; + + // Wait a second (lets the logger flush everything before printing again) + std::thread::sleep(std::time::Duration::from_secs(1)); + + // Print result + println!("{}", res); + Ok(()) +} + +fn main() { + if let Err(ref e) = run() { + println!("{}", e); + std::process::exit(1); + } +} diff --git a/argmin/src/core/test_utils.rs b/argmin/src/core/test_utils.rs index 70ea0419a..6797e181e 100644 --- a/argmin/src/core/test_utils.rs +++ b/argmin/src/core/test_utils.rs @@ -207,6 +207,102 @@ impl Anneal for TestProblem { } } +/// A struct representing the following sparse problem. +/// +/// Example 1: x = [1, 1, 0, 0], y = 1 +/// Example 2: x = [0, 0, 1, 1], y = -1 +/// Example 3: x = [1, 0, 0, 0], y = 1 +/// Example 4: x = [0, 0, 1, 0], y = -1 +/// +/// cost = Σ (w^T x - y)^2 +/// +/// Implements [`CostFunction`] and [`Gradient`]. +#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct TestSparseProblem {} + +impl TestSparseProblem { + /// Create an instance of `TestSparseProblem`. + /// + /// # Example + /// + /// ``` + /// use argmin::core::test_utils::TestSparseProblem; + /// + /// let problem = TestSparseProblem::new(); + /// # assert_eq!(problem, TestSparseProblem {}); + /// ``` + #[allow(dead_code)] + pub fn new() -> Self { + TestSparseProblem {} + } +} + +impl CostFunction for TestSparseProblem { + type Param = Vec; + type Output = f64; + + /// Returns a sum of squared errors. + /// + /// # Example + /// + /// ``` + /// use argmin::core::test_utils::TestSparseProblem; + /// use argmin::core::CostFunction; + /// # use argmin::core::Error; + /// + /// # fn main() -> Result<(), Error> { + /// let problem = TestSparseProblem::new(); + /// + /// let param = vec![1.0, 2.0, 3.0, 4.0]; + /// + /// let res = problem.cost(¶m)?; + /// # assert_eq!(res, 84f64); + /// # Ok(()) + /// # } + /// ``` + fn cost(&self, param: &Self::Param) -> Result { + let err1 = (param[0] + param[1] - 1.0).powi(2); + let err2 = (param[2] + param[3] + 1.0).powi(2); + let err3 = (param[0] - 1.0).powi(2); + let err4 = (param[2] + 1.0).powi(2); + Ok(err1 + err2 + err3 + err4) + } +} + +impl Gradient for TestSparseProblem { + type Param = Vec; + type Gradient = Vec; + + /// Returns a gradient of the cost function. + /// + /// # Example + /// + /// ``` + /// use argmin::core::test_utils::TestSparseProblem; + /// use argmin::core::Gradient; + /// # use argmin::core::Error; + /// + /// # fn main() -> Result<(), Error> { + /// let problem = TestSparseProblem::new(); + /// + /// let param = vec![1.0, 2.0, 3.0, 4.0]; + /// + /// let res = problem.gradient(¶m)?; + /// # assert_eq!(res, vec![4.0, 4.0, 24.0, 16.0]); + /// # Ok(()) + /// # } + /// ``` + fn gradient(&self, param: &Self::Param) -> Result { + let mut g = vec![0.0; 4]; + g[0] = 4.0 * param[0] + 2.0 * param[1] - 4.0; + g[1] = 2.0 * param[0] + 2.0 * param[1] - 2.0; + g[2] = 4.0 * param[2] + 2.0 * param[3] + 4.0; + g[3] = 2.0 * param[2] + 2.0 * param[3] + 2.0; + Ok(g) + } +} + /// A (non-working) solver useful for testing /// /// Implements the [`Solver`] trait. diff --git a/argmin/src/solver/quasinewton/lbfgs.rs b/argmin/src/solver/quasinewton/lbfgs.rs index 18be3cbc3..e8f62e202 100644 --- a/argmin/src/solver/quasinewton/lbfgs.rs +++ b/argmin/src/solver/quasinewton/lbfgs.rs @@ -9,10 +9,27 @@ use crate::core::{ ArgminFloat, CostFunction, DeserializeOwnedAlias, Error, Executor, Gradient, IterState, LineSearch, OptimizationResult, Problem, SerializeAlias, Solver, State, TerminationReason, KV, }; -use argmin_math::{ArgminAdd, ArgminDot, ArgminMul, ArgminNorm, ArgminSub}; +use argmin_math::{ + ArgminAdd, ArgminDot, ArgminL1Norm, ArgminMinMax, ArgminMul, ArgminNorm, ArgminSignum, + ArgminSub, ArgminZeroLike, +}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; use std::collections::VecDeque; +use std::marker::PhantomData; + +/// Calculates pseudo-gradient of OWL-QN method. +fn calculate_pseudo_gradient(l1_coeff: F, param: &P, gradient: &G) -> G +where + P: ArgminAdd + ArgminSub + ArgminMul + ArgminSignum, + G: ArgminAdd + ArgminAdd + ArgminMinMax + ArgminZeroLike, + F: ArgminFloat, +{ + let coeff_p = param.add(&F::min_positive_value()).signum().mul(&l1_coeff); + let coeff_n = param.sub(&F::min_positive_value()).signum().mul(&l1_coeff); + let zeros = gradient.zero_like(); + G::max(&gradient.add(&coeff_n), &zeros).add(&G::min(&gradient.add(&coeff_p), &zeros)) +} /// # Limited-memory BFGS (L-BFGS) method /// @@ -36,6 +53,13 @@ use std::collections::VecDeque; /// other. If the change is below this tolerance (default: `EPSILON`), the algorithm stops. This /// parameter can be set via [`with_tolerance_cost`](`LBFGS::with_tolerance_cost`). /// +/// ## Orthant-Wise Limited-memory Quasi-Newton (OWL-QN) method +/// +/// OWL-QN is a method that adapts L-BFGS to L1-regularization. The original L-BFGS requires a +/// loss function to be differentiable and does not support L1-regularization. Therefore, +/// this library switches to OWL-QN when L1-regularization is specified. L1-regularization can be +/// performed via [`with_l1_regularization`](`LBFGS::with_l1_regularization`). +/// /// TODO: Implement compact representation of BFGS updating (Nocedal/Wright p.230) /// /// ## Requirements on the optimization problem @@ -46,6 +70,9 @@ use std::collections::VecDeque; /// /// Jorge Nocedal and Stephen J. Wright (2006). Numerical Optimization. /// Springer. ISBN 0-387-30303-0. +/// +/// Galen Andrew and Jianfeng Gao (2007). Scalable Training of L1-Regularized Log-Linear Models, +/// International Conference on Machine Learning. #[derive(Clone)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] pub struct LBFGS { @@ -61,6 +88,10 @@ pub struct LBFGS { tol_grad: F, /// Tolerance for the stopping criterion based on the change of the cost stopping criterion tol_cost: F, + /// Coefficient of L1-regularization + l1_coeff: Option, + /// Unregularized gradient used for calculation of `y`. + l1_prev_unreg_grad: Option, } impl LBFGS @@ -84,6 +115,8 @@ where y: VecDeque::with_capacity(m), tol_grad: F::epsilon().sqrt(), tol_cost: F::epsilon(), + l1_coeff: None, + l1_prev_unreg_grad: None, } } @@ -138,28 +171,170 @@ where self.tol_cost = tol_cost; Ok(self) } + + /// Activates L1-regularization with coefficient `l1_coeff`. + /// + /// Parameter `l1_coeff` must be `> 0.0`. + /// + /// # Example + /// + /// ``` + /// # use argmin::solver::quasinewton::LBFGS; + /// # use argmin::core::Error; + /// # fn main() -> Result<(), Error> { + /// # let linesearch = (); + /// let lbfgs: LBFGS<_, Vec, Vec, f64> = LBFGS::new(linesearch, 3).with_l1_regularization(1.0)?; + /// # Ok(()) + /// # } + pub fn with_l1_regularization(mut self, l1_coeff: F) -> Result { + if l1_coeff <= float!(0.0) { + return Err(argmin_error!( + InvalidParameter, + "`L-BFGS`: coefficient of L1-regularization must be > 0." + )); + } + self.l1_coeff = Some(l1_coeff); + Ok(self) + } +} + +/// Wrapper problem for supporting constrained line search. +struct LineSearchProblem { + problem: O, + xi: Option

, + l1_coeff: Option, + phantom: PhantomData, +} + +impl LineSearchProblem +where + P: ArgminSub, + F: ArgminFloat, +{ + fn new(problem: O) -> Self { + Self { + problem, + xi: None, + l1_coeff: None, + phantom: PhantomData, + } + } + + fn with_l1_constraint(&mut self, l1_coeff: F, param: &P, pseudo_gradient: &G) + where + P: ArgminZeroLike + + ArgminMinMax + + ArgminSignum + + ArgminAdd + + ArgminAdd + + ArgminMul + + ArgminSub + + ArgminMul, + { + let zeros = param.zero_like(); + let sig_param = P::max(¶m.sub(&F::min_positive_value()).signum(), &zeros).add(&P::min( + ¶m.add(&F::min_positive_value()).signum(), + &zeros, + )); + self.xi = Some( + sig_param.add( + &sig_param + .mul(&sig_param) + .sub(&float!(1.0)) + .mul(pseudo_gradient), + ), + ); + self.l1_coeff = Some(l1_coeff); + } +} + +impl CostFunction for LineSearchProblem +where + O: CostFunction, + P: ArgminMul + ArgminMinMax + ArgminSignum + ArgminZeroLike + ArgminL1Norm, + F: ArgminFloat, +{ + type Param = P; + type Output = F; + + fn cost(&self, param: &Self::Param) -> Result { + if let Some(xi) = self.xi.as_ref() { + let zeros = param.zero_like(); + let param = P::max(¶m.mul(xi).signum(), &zeros).mul(param); + let cost = self.problem.cost(¶m)?; + Ok(cost + self.l1_coeff.unwrap() * param.l1_norm()) + } else { + self.problem.cost(param) + } + } +} + +impl Gradient for LineSearchProblem +where + O: Gradient, + P: ArgminAdd + + ArgminMul + + ArgminMul + + ArgminSub + + ArgminMinMax + + ArgminSignum + + ArgminZeroLike, + G: ArgminAdd + ArgminZeroLike + ArgminMinMax + ArgminAdd, + F: ArgminFloat, +{ + type Param = P; + type Gradient = G; + + fn gradient(&self, param: &Self::Param) -> Result { + if let Some(xi) = self.xi.as_ref() { + let zeros = param.zero_like(); + let param = P::max(¶m.mul(xi).signum(), &zeros).mul(param); + let gradient = self.problem.gradient(¶m)?; + Ok(calculate_pseudo_gradient( + self.l1_coeff.unwrap(), + ¶m, + &gradient, + )) + } else { + self.problem.gradient(param) + } + } } impl Solver> for LBFGS where O: CostFunction + Gradient, P: Clone + + std::fmt::Debug + SerializeAlias + DeserializeOwnedAlias + ArgminSub + + ArgminSub + ArgminAdd + + ArgminAdd + ArgminDot - + ArgminMul, + + ArgminMul + + ArgminMul + + ArgminMul + + ArgminL1Norm + + ArgminSignum + + ArgminZeroLike + + ArgminMinMax, G: Clone + + std::fmt::Debug + SerializeAlias + DeserializeOwnedAlias + ArgminNorm + ArgminSub + + ArgminAdd + + ArgminAdd + ArgminDot + ArgminDot + ArgminMul - + ArgminMul, - L: Clone + LineSearch + Solver>, + + ArgminMul + + ArgminZeroLike + + ArgminMinMax, + L: Clone + LineSearch + Solver, IterState>, F: ArgminFloat, { const NAME: &'static str = "L-BFGS"; @@ -179,7 +354,11 @@ where let cost = state.get_cost(); let cost = if cost.is_infinite() { - problem.cost(¶m)? + if let Some(l1_coeff) = self.l1_coeff { + problem.cost(¶m)? + l1_coeff * param.l1_norm() + } else { + problem.cost(¶m)? + } } else { cost }; @@ -202,10 +381,18 @@ where "`L-BFGS`: Parameter vector in state not set." ))?; let cur_cost = state.get_cost(); - let prev_grad = state.take_gradient().ok_or_else(argmin_error_closure!( + + // If L1 regularization is enabled, the state contains pseudo gradient. + let mut prev_grad = state.take_gradient().ok_or_else(argmin_error_closure!( PotentialBug, "`L-BFGS`: Gradient in state not set." ))?; + if let Some(l1_coeff) = self.l1_coeff { + if self.l1_prev_unreg_grad.is_none() { + self.l1_prev_unreg_grad = Some(prev_grad.clone()); + prev_grad = calculate_pseudo_gradient(l1_coeff, ¶m, &prev_grad) + } + } let gamma: F = if let (Some(sk), Some(yk)) = (self.s.back(), self.y.back()) { sk.dot(yk) / yk.dot(yk) @@ -234,14 +421,25 @@ where r = r.add(&sk.mul(&(alpha[i] - beta))); } - self.linesearch.search_direction(r.mul(&float!(-1.0))); + let mut line_problem = LineSearchProblem::new(problem.take_problem().unwrap()); + let d = if let Some(l1_coeff) = self.l1_coeff { + line_problem.with_l1_constraint(l1_coeff, ¶m, &prev_grad); + let zeros = r.zero_like(); + P::max(&r.mul(&prev_grad).signum(), &zeros) + .mul(&r) + .mul(&float!(-1.0)) + } else { + r.mul(&float!(-1.0)) + }; + + self.linesearch.search_direction(d); // Run solver let OptimizationResult { - problem: line_problem, + problem: mut line_problem, state: mut linesearch_state, .. - } = Executor::new(problem.take_problem().unwrap(), self.linesearch.clone()) + } = Executor::new(line_problem, self.linesearch.clone()) .configure(|config| { config .param(param.clone()) @@ -251,11 +449,18 @@ where .ctrlc(false) .run()?; - let xk1 = linesearch_state.take_param().unwrap(); + let mut xk1 = linesearch_state.take_param().unwrap(); let next_cost = linesearch_state.get_cost(); // take back problem and take care of function evaluation counts - problem.consume_problem(line_problem); + let mut internal_line_problem = line_problem.take_problem().unwrap(); + let xi = internal_line_problem.xi.take(); + problem.problem = Some(internal_line_problem.problem); + problem.consume_func_counts(line_problem); + if let Some(xi) = xi { + let zeros = xk1.zero_like(); + xk1 = P::max(&xk1.mul(&xi).signum(), &zeros).mul(&xk1); + } if state.get_iter() >= self.m as u64 { self.s.pop_front(); @@ -265,7 +470,17 @@ where let grad = problem.gradient(&xk1)?; self.s.push_back(xk1.sub(¶m)); - self.y.push_back(grad.sub(&prev_grad)); + let grad = if let Some(l1_coeff) = self.l1_coeff { + // Stores unregularized gradient and returns L1 gradient. + let pseudo_grad = calculate_pseudo_gradient(l1_coeff, ¶m, &grad); + self.y + .push_back(grad.sub(self.l1_prev_unreg_grad.as_ref().unwrap())); + self.l1_prev_unreg_grad = Some(grad); + pseudo_grad + } else { + self.y.push_back(grad.sub(&prev_grad)); + grad + }; Ok(( state.param(xk1).cost(next_cost).gradient(grad), @@ -287,7 +502,10 @@ where #[cfg(test)] mod tests { use super::*; - use crate::core::{test_utils::TestProblem, ArgminError, IterState, State}; + use crate::core::{ + test_utils::{TestProblem, TestSparseProblem}, + ArgminError, IterState, State, + }; use crate::solver::linesearch::MoreThuenteLineSearch; use crate::test_trait_impl; @@ -309,6 +527,8 @@ mod tests { m, s, y, + l1_coeff, + l1_prev_unreg_grad, } = lbfgs; assert_eq!(linesearch, MyFakeLineSearch {}); @@ -317,6 +537,8 @@ mod tests { assert_eq!(m, 3); assert!(s.capacity() >= 3); assert!(y.capacity() >= 3); + assert!(l1_coeff.is_none()); + assert!(l1_prev_unreg_grad.is_none()); } #[test] @@ -459,4 +681,51 @@ mod tests { assert_eq!(s.to_ne_bytes(), g.to_ne_bytes()); } } + + #[test] + fn test_l1_regularization() { + { + let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap(); + + let param: Vec = vec![0.0; 4]; + + let lbfgs: LBFGS<_, Vec, Vec, f64> = LBFGS::new(linesearch, 3); + + let cost = TestSparseProblem::new(); + let res = Executor::new(cost, lbfgs) + .configure(|state| state.param(param).max_iters(2)) + .run() + .unwrap(); + + let result_param = res.state.param.unwrap(); + + assert!((result_param[0] - 0.5).abs() > 1e-6); + assert!((result_param[1]).abs() > 1e-6); + assert!((result_param[2] + 0.5).abs() > 1e-6); + assert!((result_param[3]).abs() > 1e-6); + } + { + let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9).unwrap(); + + let param: Vec = vec![0.0; 4]; + + let lbfgs: LBFGS<_, Vec, Vec, f64> = LBFGS::new(linesearch, 3) + .with_l1_regularization(2.0) + .unwrap(); + + let cost = TestSparseProblem::new(); + let res = Executor::new(cost, lbfgs) + .configure(|state| state.param(param).max_iters(2)) + .run() + .unwrap(); + + let result_param = res.state.param.unwrap(); + dbg!(&result_param); + + assert!((result_param[0] - 0.5).abs() < 1e-6); + assert!((result_param[1]).abs() < 1e-6); + assert!((result_param[2] + 0.5).abs() < 1e-6); + assert!((result_param[3]).abs() < 1e-6); + } + } }