Skip to content

Commit

Permalink
Merge branch 'main' into pub-adddim
Browse files Browse the repository at this point in the history
  • Loading branch information
coreylowman authored Sep 14, 2023
2 parents 6c30a09 + 2e5116c commit 7304377
Show file tree
Hide file tree
Showing 22 changed files with 1,524 additions and 169 deletions.
9 changes: 2 additions & 7 deletions examples/02-ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,8 @@ fn main() {
dbg!(f.array());

// and of course you can chain all of these together
let _ = dev
.sample_normal::<Rank2<5, 10>>()
.clamp(-1.0, 1.0)
.exp()
.abs()
.powf(0.5)
/ 2.0;
let _: Tensor<(Const<5>, Const<10>), f32, _> =
dev.sample_normal().clamp(-1.0, 1.0).exp().abs().powf(0.5) / 2.0;

// binary and unary operations can also be performed on dynamically sized tensors
let mut a: Tensor<(Const<3>, usize), f32, _> = dev.sample_uniform_like(&(Const, 5));
Expand Down
7 changes: 2 additions & 5 deletions src/losses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ pub fn huber_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
pub fn smooth_l1_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
pred: Tensor<S, E, D, T>,
targ: Tensor<S, E, D>,
delta: impl Into<f64>,
delta: impl Copy + Into<f64>,
) -> Tensor<Rank0, E, D, T> {
let delta: f64 = delta.into();
huber_loss(pred, targ, delta) / E::from_f64(delta).unwrap()
huber_loss(pred, targ, delta) / delta
}

/// [Cross entropy loss](https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_loss_function_and_logistic_regression).
Expand All @@ -83,7 +82,6 @@ pub fn cross_entropy_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<
target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
let inv_last_axis_numel = E::from_f64(inv_last_axis_numel).unwrap();
let probs = logits.log_softmax::<S::LastAxis>();
(probs * target_probs).mean().negate() / inv_last_axis_numel
}
Expand All @@ -103,7 +101,6 @@ pub fn kl_div_with_logits_loss<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>>(
target_probs: Tensor<S, E, D>,
) -> Tensor<Rank0, E, D, T> {
let inv_last_axis_numel = 1.0 / <S as HasAxes<S::LastAxis>>::size(logits.shape()) as f64;
let inv_last_axis_numel = E::from_f64(inv_last_axis_numel).unwrap();
let probs = logits.log_softmax::<S::LastAxis>();
((probs - target_probs.clone().ln()) * target_probs)
.mean()
Expand Down
9 changes: 2 additions & 7 deletions src/nn/batchnorm2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ where
var.try_axpy(1.0 - momentum, &var_chan, momentum * n / (n - 1.0))?;

// statistics for normalizing - on tape
let std = var_chan
.try_add(E::from_f64(epsilon).unwrap())?
.try_sqrt()?;
let std = var_chan.try_add(epsilon)?.try_sqrt()?;

// record broadcast of scale & bias - on tape
let scale = scale
Expand Down Expand Up @@ -81,10 +79,7 @@ where
let shape = *x.shape();

// statistics for normalizing
let std = var
.clone()
.try_add(E::from_f64(epsilon).unwrap())?
.try_sqrt()?;
let std = var.clone().try_add(epsilon)?.try_sqrt()?;

let scale = scale.clone().try_div(std)?.try_broadcast_like(&shape)?;

Expand Down
281 changes: 281 additions & 0 deletions src/nn/conv1d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
use num_traits::Float;
use rand_distr::uniform::SampleUniform;

use crate::{shapes::*, tensor::*, tensor_ops::*};

use super::*;

pub mod builder {
#[derive(Debug)]
pub struct Conv1D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize = 1,
const PADDING: usize = 0,
const DILATION: usize = 1,
const GROUPS: usize = 1,
>;
}

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> BuildOnDevice<D, E> for builder::Conv1D<I, O, K, S, P, L, G>
where
E: Dtype,
D: Device<E>,
Const<{ I / G }>: Sized,
Conv1D<I, O, K, S, P, L, G, E, D>: BuildModule<D, E>,
{
type Built = Conv1D<I, O, K, S, P, L, G, E, D>;
fn try_build_on_device(device: &D) -> Result<Self::Built, <D>::Err> {
Self::Built::try_build(device)
}
}

/// **Requires Nightly** Performs *unbiased* 1d convolutions on 2d and 3d images.
///
/// **Pytorch Equivalent**: `torch.nn.Conv1d(..., bias=False)`
///
/// Generics:
/// - `IN_CHAN`: The number of input channels in an image.
/// - `OUT_CHAN`: The number of channels in the output of the layer.
/// - `KERNEL_SIZE`: The size of the kernel applied to both width and height of the images.
/// - `STRIDE`: How far to move the kernel each step. Defaults to `1`
/// - `PADDING`: How much zero padding to add around the images. Defaults to `0`.
/// - `DILATION`: Controls the spacing between kernel points. Defaults to `1`.
/// - `GROUPS`: Controls the connections between inputs and outputs.
/// `IN_CHAN` and `OUT_CHAN` must both be divisible by `GROUPS`. For example,
///
/// See [conv animations](https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md) for helpful
/// visualization of all of these parameters.
#[derive(Debug, Clone)]
pub struct Conv1D<
const IN_CHAN: usize,
const OUT_CHAN: usize,
const KERNEL_SIZE: usize,
const STRIDE: usize,
const PADDING: usize,
const DILATION: usize,
const GROUPS: usize,
E: Dtype,
D: Storage<E>,
> where
Const<{ IN_CHAN / GROUPS }>: Sized,
{
pub weight: Tensor<Rank3<OUT_CHAN, { IN_CHAN / GROUPS }, KERNEL_SIZE>, E, D>,
}

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
> TensorCollection<E, D> for Conv1D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype + Float + SampleUniform,
D: Device<E>,
{
type To<E2: Dtype, D2: Device<E2>> = Conv1D<I, O, K, S, P, L, G, E2, D2>;

fn iter_tensors<V: ModuleVisitor<Self, E, D>>(
visitor: &mut V,
) -> Result<Option<Self::To<V::E2, V::D2>>, V::Err> {
visitor.visit_fields(
Self::tensor(
"weight",
|s| &s.weight,
|s| &mut s.weight,
TensorOptions::reset_with(|t| {
let scale = E::from_f64(G as f64 / (I * K) as f64).unwrap();
let b = scale.sqrt();
t.try_fill_with_distr(rand_distr::Uniform::new(-b, b))
}),
),
|weight| Conv1D { weight },
)
}
}

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E,
D,
Img,
> Module<Img> for Conv1D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
E: Dtype,
D: Device<E>,
(Img, Tensor<Rank3<O, { I / G }, K>, E, D>): TryConv1D<Const<S>, Const<P>, Const<L>, Const<G>>,
{
type Output = <(Img, Tensor<Rank3<O, { I / G }, K>, E, D>) as TryConv1D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Convolved;
type Error = <(Img, Tensor<Rank3<O, { I / G }, K>, E, D>) as TryConv1D<
Const<S>,
Const<P>,
Const<L>,
Const<G>,
>>::Error;

fn try_forward(&self, x: Img) -> Result<Self::Output, Self::Error> {
(x, self.weight.clone()).try_conv1d(Const, Const, Const, Const)
}
}

impl<
const I: usize,
const O: usize,
const K: usize,
const S: usize,
const P: usize,
const L: usize,
const G: usize,
E: Dtype,
D: Storage<E>,
> NonMutableModule for Conv1D<I, O, K, S, P, L, G, E, D>
where
Const<{ I / G }>: Sized,
{
}

#[cfg(test)]
mod tests {
use crate::{
optim::*,
tensor::{AsArray, SampleTensor, ZerosTensor},
tests::*,
};

use super::{builder::Conv1D, *};

#[rustfmt::skip]
#[test]
fn test_forward_3d_sizes() {
let dev: TestDevice = Default::default();
let x = dev.zeros::<Rank2<3, 10>>();
let _: Tensor<Rank2<2, 8>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<4, 8>, _, _, _> = dev.build_module::<Conv1D<3, 4, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<4, 9>, _, _, _> = dev.build_module::<Conv1D<3, 4, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<4, 7>, _, _, _> = dev.build_module::<Conv1D<3, 4, 4>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 4>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 3>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 10>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 1>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 12>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank2<2, 6>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
}

#[test]
fn test_grouped_forward_sizes() {
let dev: TestDevice = Default::default();

let x = dev.ones::<Rank2<16, 10>>();

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1>, TestDtype>();
let _: Tensor<Rank3<32, 16, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("1");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 2>, TestDtype>();
let _: Tensor<Rank3<32, 8, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("2");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 4>, TestDtype>();
let _: Tensor<Rank3<32, 4, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("3");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 8>, TestDtype>();
let _: Tensor<Rank3<32, 2, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x.clone());
println!("4");

let m = dev.build_module::<Conv1D<16, 32, 3, 1, 0, 1, 16>, TestDtype>();
let _: Tensor<Rank3<32, 1, 3>, _, _> = m.weight;
let _: Tensor<Rank2<32, 8>, _, _> = m.forward(x);
println!("5");
}

#[rustfmt::skip]
#[test]
fn test_forward_4d_sizes() {
let dev: TestDevice = Default::default();
let x = dev.zeros::<Rank3<5, 3, 10>>();
let _: Tensor<Rank3<5, 2, 8>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 4, 8>, _, _, _> = dev.build_module::<Conv1D<3, 4, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 4, 9>, _, _, _> = dev.build_module::<Conv1D<3, 4, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 4, 7>, _, _, _> = dev.build_module::<Conv1D<3, 4, 4>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 4>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 3>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 3>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 10>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 1>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 12>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 1, 2>, TestDtype>().forward(x.clone());
let _: Tensor<Rank3<5, 2, 6>, _, _, _> = dev.build_module::<Conv1D<3, 2, 3, 2, 2>, TestDtype>().forward(x.clone());
}

#[test]
fn test_2_conv_sizes() {
let dev = Cpu::default();
type A = Conv1D<1, 2, 3>;
type B = Conv1D<2, 4, 3>;
let _: Tensor<Rank2<4, 6>, _, _> = dev
.build_module::<(A, B), TestDtype>()
.forward(dev.zeros::<Rank2<1, 10>>());
}

#[test]
fn test_3_conv_sizes() {
type A = Conv1D<1, 2, 3>;
type B = Conv1D<2, 4, 3>;
type C = Conv1D<4, 1, 1, 1, 1>;

let dev = Cpu::default();
let _: Tensor<Rank2<1, 8>, _, _> = dev
.build_module::<(A, B, C), TestDtype>()
.forward_mut(dev.zeros::<Rank2<1, 10>>());
}

#[test]
fn test_conv_with_optimizer() {
let dev: TestDevice = Default::default();

let mut m = dev.build_module::<Conv1D<2, 4, 3>, TestDtype>();

let weight_init = m.weight.clone();

let mut opt = Sgd::new(&m, Default::default());
let out = m.forward(dev.sample_normal::<Rank3<8, 2, 28>>().leaky_trace());
let g = out.square().mean().backward();

assert_ne!(g.get(&m.weight).array(), [[[TestDtype::zero(); 3]; 2]; 4]);

opt.update(&mut m, &g).expect("unused params");

assert_ne!(weight_init.array(), m.weight.array());
}
}
File renamed without changes.
8 changes: 5 additions & 3 deletions src/nn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ mod batchnorm1d;
mod batchnorm2d;
mod bias2d;
#[cfg(feature = "nightly")]
mod conv;
mod conv1d;
#[cfg(feature = "nightly")]
mod conv2d;
#[cfg(feature = "nightly")]
mod convtrans;
mod dropout;
Expand Down Expand Up @@ -243,7 +245,7 @@ pub mod modules {
pub use super::batchnorm2d::BatchNorm2D;
pub use super::bias2d::Bias2D;
#[cfg(feature = "nightly")]
pub use super::conv::Conv2D;
pub use super::conv2d::Conv2D;
#[cfg(feature = "nightly")]
pub use super::convtrans::ConvTrans2D;
pub use super::dropout::{Dropout, DropoutOneIn};
Expand Down Expand Up @@ -279,7 +281,7 @@ pub mod builders {
pub use super::batchnorm2d::builder::BatchNorm2D;
pub use super::bias2d::builder::Bias2D;
#[cfg(feature = "nightly")]
pub use super::conv::builder::Conv2D;
pub use super::conv2d::builder::Conv2D;
#[cfg(feature = "nightly")]
pub use super::convtrans::builder::ConvTrans2D;
pub use super::dropout::{Dropout, DropoutOneIn};
Expand Down
2 changes: 1 addition & 1 deletion src/nn/transformer/mha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ where
let q = q.try_permute::<_, Axes4<0, 2, 1, 3>>()?;

// Get weights
let scalar: E = E::from_f64(1.0 / ((K / H) as f64).sqrt()).unwrap();
let scalar = 1.0 / ((K / H) as f64).sqrt();
let weights = q.try_matmul(k)?.try_mul(scalar)?;
let weights = weights.try_softmax::<Axis<3>>()?;

Expand Down
Loading

0 comments on commit 7304377

Please sign in to comment.