Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support OWL-QN method (L-BFGS with L1-regularization) #244

Merged
merged 19 commits into from
Aug 17, 2022
Merged
12 changes: 12 additions & 0 deletions argmin-math/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,12 @@ pub trait ArgminScaledSub<T, U, V> {
fn scaled_sub(&self, factor: &U, vec: &T) -> V;
}

/// Compute the l2-norm (`U`) of `self`
pub trait ArgminL1Norm<U> {
/// Compute the l1-norm (`U`) of `self`
fn l1_norm(&self) -> U;
}

/// Compute the l2-norm (`U`) of `self`
pub trait ArgminNorm<U> {
/// Compute the l2-norm (`U`) of `self`
Expand Down Expand Up @@ -283,3 +289,9 @@ pub trait ArgminMinMax {
/// Select piecewise maximum
fn max(x: &Self, y: &Self) -> Self;
}

/// Returns a number that represents the sign of `self`.
pub trait ArgminSignum {
/// Returns a number that represents the sign of `self`.
fn signum(self) -> Self;
}
67 changes: 67 additions & 0 deletions argmin-math/src/nalgebra_m/l1norm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2018-2022 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::ArgminL1Norm;

use nalgebra::{
base::{dimension::Dim, storage::Storage},
LpNorm, Matrix, SimdComplexField,
};

impl<N, R, C, S> ArgminL1Norm<N::SimdRealField> for Matrix<N, R, C, S>
where
N: SimdComplexField,
R: Dim,
C: Dim,
S: Storage<N, R, C>,
{
#[inline]
fn l1_norm(&self) -> N::SimdRealField {
self.apply_norm(&LpNorm(1))
}
}

#[cfg(test)]
mod tests {
use super::*;
use nalgebra::Vector2;
use paste::item;

macro_rules! make_test {
($t:ty) => {
item! {
#[test]
fn [<test_l1norm_ $t>]() {
let a = Vector2::new(4 as $t, 3 as $t);
let res = <Vector2<$t> as ArgminL1Norm<$t>>::l1_norm(&a);
let target = 7 as $t;
assert!(((target - res) as f64).abs() < std::f64::EPSILON);
}
}
};
}

macro_rules! make_test_signed {
($t:ty) => {
item! {
#[test]
fn [<test_l1norm_signed_ $t>]() {
let a = Vector2::new(-4 as $t, -3 as $t);
let res = <Vector2<$t> as ArgminL1Norm<$t>>::l1_norm(&a);
let target = 7 as $t;
assert!(((target - res) as f64).abs() < std::f64::EPSILON);
}
}
};
}

make_test!(f32);
make_test!(f64);

make_test_signed!(f32);
make_test_signed!(f64);
}
2 changes: 2 additions & 0 deletions argmin-math/src/nalgebra_m/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod div;
mod dot;
mod eye;
mod inv;
mod l1norm;
mod mul;
mod norm;
mod scaledadd;
Expand All @@ -25,6 +26,7 @@ pub use div::*;
pub use dot::*;
pub use eye::*;
pub use inv::*;
pub use l1norm::*;
pub use mul::*;
pub use norm::*;
pub use scaledadd::*;
Expand Down
132 changes: 132 additions & 0 deletions argmin-math/src/ndarray_m/l1norm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright 2018-2022 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use crate::ArgminL1Norm;
use ndarray::Array1;
use num_complex::Complex;

macro_rules! make_l1norm_float {
($t:ty) => {
impl ArgminL1Norm<$t> for Array1<$t> {
#[inline]
fn l1_norm(&self) -> $t {
self.iter().map(|a| a.abs()).sum()
}
}
};
}

macro_rules! make_l1norm_complex_float {
($t:ty) => {
impl ArgminL1Norm<Complex<$t>> for Array1<Complex<$t>> {
#[inline]
fn l1_norm(&self) -> Complex<$t> {
self.iter().map(|a| a.l1_norm()).sum::<$t>().into()
}
}

impl ArgminL1Norm<$t> for Array1<Complex<$t>> {
#[inline]
fn l1_norm(&self) -> $t {
self.iter().map(|a| a.l1_norm()).sum()
}
}
};
}

macro_rules! make_l1norm_unsigned {
($t:ty) => {
impl ArgminL1Norm<$t> for Array1<$t> {
#[inline]
fn l1_norm(&self) -> $t {
self.sum()
}
}
};
}

macro_rules! make_l1norm_integer {
($t:ty) => {
impl ArgminL1Norm<$t> for Array1<$t> {
#[inline]
fn l1_norm(&self) -> $t {
self.iter().map(|a| a.abs()).sum()
}
}
};
}

make_l1norm_integer!(isize);
make_l1norm_unsigned!(usize);
make_l1norm_integer!(i8);
make_l1norm_integer!(i16);
make_l1norm_integer!(i32);
make_l1norm_integer!(i64);
make_l1norm_unsigned!(u8);
make_l1norm_unsigned!(u16);
make_l1norm_unsigned!(u32);
make_l1norm_unsigned!(u64);
make_l1norm_float!(f32);
make_l1norm_float!(f64);
make_l1norm_complex_float!(f32);
make_l1norm_complex_float!(f64);

#[cfg(test)]
mod tests {
use super::*;
use ndarray::{array, Array1};
use paste::item;

macro_rules! make_test {
($t:ty) => {
item! {
#[test]
fn [<test_norm_ $t>]() {
let a = array![4 as $t, 3 as $t];
let res = <Array1<$t> as ArgminL1Norm<$t>>::l1_norm(&a);
let target = 7 as $t;
assert!(((target - res) as f64).abs() < std::f64::EPSILON);
}
}
};
}

macro_rules! make_test_signed {
($t:ty) => {
item! {
#[test]
fn [<test_norm_signed_ $t>]() {
let a = array![-4 as $t, -3 as $t];
let res = <Array1<$t> as ArgminL1Norm<$t>>::l1_norm(&a);
let target = 7 as $t;
assert!(((target - res) as f64).abs() < std::f64::EPSILON);
}
}
};
}

make_test!(isize);
make_test!(usize);
make_test!(i8);
make_test!(u8);
make_test!(i16);
make_test!(u16);
make_test!(i32);
make_test!(u32);
make_test!(i64);
make_test!(u64);
make_test!(f32);
make_test!(f64);

make_test_signed!(isize);
make_test_signed!(i8);
make_test_signed!(i16);
make_test_signed!(i32);
make_test_signed!(i64);
make_test_signed!(f32);
make_test_signed!(f64);
}
77 changes: 77 additions & 0 deletions argmin-math/src/ndarray_m/minmax.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright 2018-2022 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: Tests for Array2 impl

use crate::ArgminMinMax;
use ndarray::{Array1, Array2};

macro_rules! make_minmax {
($t:ty) => {
impl ArgminMinMax for Array1<$t> {
#[inline]
fn min(x: &Self, y: &Self) -> Array1<$t> {
assert_eq!(x.shape(), y.shape());
x.iter()
.zip(y)
.map(|(&a, &b)| if a < b { a } else { b })
.collect()
}

#[inline]
fn max(x: &Self, y: &Self) -> Array1<$t> {
assert_eq!(x.shape(), y.shape());
x.iter()
.zip(y)
.map(|(&a, &b)| if a > b { a } else { b })
.collect()
}
}

impl ArgminMinMax for Array2<$t> {
#[inline]
fn min(x: &Self, y: &Self) -> Array2<$t> {
assert_eq!(x.shape(), y.shape());
let m = x.shape()[0];
let n = x.shape()[1];
let mut out = x.clone();
for i in 0..m {
for j in 0..n {
let a = x[(i, j)];
let b = y[(i, j)];
out[(i, j)] = if a < b { a } else { b };
}
}
out
}

#[inline]
fn max(x: &Self, y: &Self) -> Array2<$t> {
assert_eq!(x.shape(), y.shape());
let m = x.shape()[0];
let n = x.shape()[1];
let mut out = x.clone();
for i in 0..m {
for j in 0..n {
let a = x[(i, j)];
let b = y[(i, j)];
out[(i, j)] = if a > b { a } else { b };
}
}
out
}
}
};
}

make_minmax!(isize);
make_minmax!(i8);
make_minmax!(i16);
make_minmax!(i32);
make_minmax!(i64);
make_minmax!(f32);
make_minmax!(f64);
6 changes: 6 additions & 0 deletions argmin-math/src/ndarray_m/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@ mod div;
mod dot;
mod eye;
mod inv;
mod l1norm;
mod minmax;
mod mul;
mod norm;
mod scaledadd;
mod scaledsub;
mod signum;
mod sub;
mod transpose;
mod zero;
Expand All @@ -25,10 +28,13 @@ pub use div::*;
pub use dot::*;
pub use eye::*;
pub use inv::*;
pub use l1norm::*;
pub use minmax::*;
pub use mul::*;
pub use norm::*;
pub use scaledadd::*;
pub use scaledsub::*;
pub use signum::*;
pub use sub::*;
pub use transpose::*;
pub use zero::*;
Loading