Skip to content

Commit

Permalink
add relu module
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 24, 2021
1 parent cb896f8 commit 036df7b
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 85 deletions.
2 changes: 1 addition & 1 deletion rusty_grad/src/modules/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod add;
mod div;
mod mul;
mod relu;
mod sub;
// mod relu;
69 changes: 69 additions & 0 deletions rusty_grad/src/modules/relu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use crate::variable::Module;
use crate::variable::VariableRef;

pub struct Relu {}

impl Module for Relu {
fn forward(&self, x: f32, _y: f32) -> f32 {
if x > 0.0 {
x
} else {
0.0
}
}

fn backward<'a>(
&self,
grad: &'a f32,
left_ref: &'a VariableRef,
_right_ref: &'a VariableRef,
) -> [f32; 2] {
let left_var = left_ref.borrow();

if left_var.data > 0.0 {
[*grad, 0.0]
} else {
[0.0, 0.0]
}
}
}

impl VariableRef {
fn relu(&mut self) -> VariableRef {
let module = Relu {};
module.subscribe(self, self, Box::new(Relu {}))
}
}

#[cfg(test)]
mod tests {

use crate::variable::Variable;

#[test]
fn check_method() {
let mut x = Variable::new(2.0);

assert_eq!(x.relu().borrow().data, 2.0);

let mut y = Variable::new(-10.0);

assert_eq!(y.relu().borrow().data, 0.0);
}

#[test]
fn check_backward_positive() {
let mut x = Variable::new(2.0);
let mut z = x.relu();
z.backward();
assert_eq!(x.borrow().grad, 1.0);
}

#[test]
fn check_backward_negative() {
let mut x = Variable::new(-2.0);
let mut z = x.relu();
z.backward();
assert_eq!(x.borrow().grad, 0.0);
}
}
27 changes: 7 additions & 20 deletions rusty_grad/src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,24 +162,11 @@ mod tests {
assert_eq!(true, x.borrow().is_leaf());
}

// #[test]
// fn new_node_is_not_leaf() {
// let ref x = Variable::new(2.0);
// let ref y = Variable::new(2.0);
//
// assert_eq!(false, (x + y).borrow().is_leaf());
// }

// #[test]
// fn div_check_backward() {
// let ref x = Variable::new(2.0);
// let ref y = Variable::new(3.0);
//
// let mut z = x / y;
//
// z.backward();
//
// assert_eq!(x.borrow().grad, 1.0 / 3.0);
// assert_eq!(y.borrow().grad, -2.0 / 9.0);
/* } */
#[test]
fn new_node_is_not_leaf() {
let ref x = Variable::new(2.0);
let ref y = Variable::new(2.0);

assert_eq!(false, (x + y).borrow().is_leaf());
}
}
128 changes: 64 additions & 64 deletions rusty_grad/tests/autograd.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,64 @@
/* use rusty_grad::variable::Variable; */
// use rusty_grad::variable::VariableRef;
/* */
// #[test]
// fn test_double_add() {
// let ref x = Variable::new(4.0);
// let mut z = x + x;
//
// z.backward();
//
// assert_eq!(x.borrow().grad, 2.0);
// }
//
// // #[test]
// // fn test_simple_autograd() {
// // let ref x = Variable::new(4.0);
// // let ref y = Variable::new(3.0);
// //
// // let mut z = (x + x) + (x + y);
// //
// // z.backward();
// //
// // assert_eq!(x.borrow().grad, 3.0);
// // assert_eq!(y.borrow().grad, 1.0);
// // }
// [> <]
// #[test]
// fn test_simple_two_stage_autograd() {
// let ref x = Variable::new(3.0);
// let ref y = Variable::new(5.0);
//
// let ref h = x + y;
// let mut z = h * x;
//
// z.backward();
//
// assert_eq!(x.borrow().grad, 11.0);
// assert_eq!(y.borrow().grad, 3.0);
// }
//
// // #[test]
// // fn test_complexautograd_1() {
// // let ref x = Variable::new(8.0);
// // let ref y = Variable::new(-3.0);
// //
// // let mut z = (x * y) * (x * y) + (x - y);
// //
// // z.backward();
// //
// // assert_eq!(x.borrow().grad, 145.0);
// // assert_eq!(y.borrow().grad, -385.0);
// // }
// //
// // #[test]
// // fn test_complexautograd_2() {
// // let ref x = Variable::new(-8.0);
// // let ref y = Variable::new(13.0);
// //
// // let mut z = (x + y) * (x + y);
// // z.backward();
// //
// // assert_eq!(x.borrow().grad, 10.0);
// // assert_eq!(y.borrow().grad, 10.0);
/* [> } <] */
use rusty_grad::variable::Variable;
use rusty_grad::variable::VariableRef;

#[test]
fn test_double_add() {
let ref x = Variable::new(4.0);
let mut z = x + x;

z.backward();

assert_eq!(x.borrow().grad, 2.0);
}

#[test]
fn test_simple_autograd() {
let ref x = Variable::new(4.0);
let ref y = Variable::new(3.0);

let mut z = (x + x) + (x + y);

z.backward();

assert_eq!(x.borrow().grad, 3.0);
assert_eq!(y.borrow().grad, 1.0);
}

#[test]
fn test_simple_two_stage_autograd() {
let ref x = Variable::new(3.0);
let ref y = Variable::new(5.0);

let ref h = x + y;
let mut z = h * x;

z.backward();

assert_eq!(x.borrow().grad, 11.0);
assert_eq!(y.borrow().grad, 3.0);
}

#[test]
fn test_complexautograd_1() {
let ref x = Variable::new(8.0);
let ref y = Variable::new(-3.0);

let mut z = (x * y) * (x * y) + (x - y);

z.backward();

assert_eq!(x.borrow().grad, 145.0);
assert_eq!(y.borrow().grad, -385.0);
}

#[test]
fn test_complexautograd_2() {
let ref x = Variable::new(-8.0);
let ref y = Variable::new(13.0);

let mut z = (x + y) * (x + y);
z.backward();

assert_eq!(x.borrow().grad, 10.0);
assert_eq!(y.borrow().grad, 10.0);
}

0 comments on commit 036df7b

Please sign in to comment.