From cdd602acb2a6b01178f9747f83fd204d3580f5af Mon Sep 17 00:00:00 2001
From: Corey Lowman
Date: Thu, 14 Sep 2023 12:18:55 -0400
Subject: [PATCH] Fixing documentation

---
 dfdx-nn-derives/src/lib.rs | 18 ++--
 dfdx-nn/src/lib.rs | 152 +++++++++++++++++++++++++++++
 dfdx/src/lib.rs | 72 --------------
 dfdx/src/tensor/gradients.rs | 3 -
 dfdx/src/tensor/mod.rs | 2 +-
 dfdx/src/tensor/safetensors.rs | 2 +-
 dfdx/src/tensor/tensor_impls.rs | 12 +--
 dfdx/src/tensor/tensorlike.rs | 2 +-
 dfdx/src/tensor_ops/adam/mod.rs | 4 +-
 dfdx/src/tensor_ops/optim.rs | 4 +-
 dfdx/src/tensor_ops/rmsprop/mod.rs | 2 +-
 dfdx/src/tensor_ops/sgd/mod.rs | 12 +--
 12 files changed, 177 insertions(+), 108 deletions(-)

diff --git a/dfdx-nn-derives/src/lib.rs b/dfdx-nn-derives/src/lib.rs
index 78bec0d4e..11a3318bd 100644
--- a/dfdx-nn-derives/src/lib.rs
+++ b/dfdx-nn-derives/src/lib.rs
@@ -8,15 +8,15 @@ macro_rules! has_attr {
     };
 }
 
-/// Allows you to implement [dfdx_nn::Module], while automatically implementing the following:
-/// 1. [dfdx_nn::BuildOnDevice]
-/// 2. [dfdx_nn::ResetParams]
-/// 3. [dfdx_nn::UpdateParams]
-/// 4. [dfdx_nn::ZeroGrads]
-/// 5. [dfdx_nn::SaveSafeTensors]
-/// 6. [dfdx_nn::LoadSafeTensors]
+/// Allows you to implement [dfdx_nn_core::Module], while automatically implementing the following:
+/// 1. [dfdx_nn_core::BuildOnDevice]
+/// 2. [dfdx_nn_core::ResetParams]
+/// 3. [dfdx_nn_core::UpdateParams]
+/// 4. [dfdx_nn_core::ZeroGrads]
+/// 5. [dfdx_nn_core::SaveSafeTensors]
+/// 6. [dfdx_nn_core::LoadSafeTensors]
 ///
-/// If your struct contains sub module configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx_nn::BuildOnDevice].
+/// If your struct contains submodule configs, then you must add the `#[module]` attribute to those items. Any field that is marked with `#[module]` will be expected to implement [dfdx_nn_core::BuildOnDevice].
 ///
 /// You can control the name of the built struct with the `#[built()]` attribute on the struct.
 ///
@@ -309,7 +309,7 @@ pub fn custom_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream
 /// Implements all of the dfdx_nn traits automatically on your type. Assumes all fields on your type
 /// are modules (i.e. they also implement all the dfdx_nn traits).
 ///
-/// [dfdx_nn::Module] is implemented as calling each of the fields on the type in definition order.
+/// [dfdx_nn_core::Module] is implemented as calling each of the fields on the type in definition order.
 ///
 /// # Example usage
 /// Here we define a simple feedforward network with 3 layers.
diff --git a/dfdx-nn/src/lib.rs b/dfdx-nn/src/lib.rs
index d9d22074f..de2fd7d7f 100644
--- a/dfdx-nn/src/lib.rs
+++ b/dfdx-nn/src/lib.rs
@@ -1,3 +1,155 @@
+//! # Architecture Configuration vs Models
+//!
+//! `dfdx-nn` differentiates between the *architecture* of a model and the constructed model (that has parameters on the device).
+//! This is mainly to make specifying an architecture not dependent on the dtype and device
+//! that a model is stored on.
+//!
+//! For example, a linear model has a couple of pieces:
+//! 1. The architecture configuration type: [LinearConfig]
+//! 2. The actual built type that contains the parameters: [Linear]
+//!
+//! There's a third piece for convenience: [LinearConstConfig], which lets you specify dimensions at compile time.
+//!
+//! For specifying the architecture, you just need the dimensions of the linear layer, but not the device/dtype:
+//! ```rust
+//! # use dfdx::prelude::*;
+//! use dfdx_nn::{LinearConfig, LinearConstConfig};
+//! let _: LinearConfig<usize, usize> = LinearConfig::new(3, 5);
+//! let _: LinearConfig<Const<3>, usize> = LinearConfig::new(Const, 5);
+//! let _: LinearConfig<usize, Const<5>> = LinearConfig::new(3, Const);
+//! let _: LinearConfig<Const<3>, Const<5>> = LinearConfig::new(Const, Const);
+//! let _: LinearConfig<Const<3>, Const<5>> = Default::default();
+//! let _: LinearConstConfig<3, 5> = Default::default();
+//! ```
+//! **Note** that we don't have any idea of what device or what dtype this will be.
+//!
+//! When we build this configuration into a [Linear] object, it will be placed on a device and have a certain dtype.
+//!
+//! # Building a model from an architecture
+//!
+//! We will use [BuildModuleExt::build_module()], an extension method on devices, to actually construct a model.
+//!
+//! ```rust
+//! # use dfdx_nn::*;
+//! # use dfdx::prelude::*;
+//! let dev: Cpu = Default::default();
+//! let arch = LinearConfig::new(Const::<3>, 5);
+//! let model: Linear<Const<3>, usize, f32, Cpu> = dev.build_module::<f32>(arch);
+//! ```
+//!
+//! Notice here we have to give both the architecture configuration and a dtype. Since we are calling this method
+//! on a specific device, we also end up giving the model the device it will be located on.
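+//!
+//! Since the configuration itself carries no dtype or device, the same architecture can be built
+//! with a different dtype, for example `f64` (a minimal sketch mirroring the example above):
+//!
+//! ```rust
+//! # use dfdx_nn::*;
+//! # use dfdx::prelude::*;
+//! # let dev: Cpu = Default::default();
+//! // same dimensions as before, only the dtype given to build_module changes
+//! let arch = LinearConfig::new(Const::<3>, 5);
+//! let model: Linear<Const<3>, usize, f64, Cpu> = dev.build_module::<f64>(arch);
+//! ```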
+//!
+//! # Using a model
+//!
+//! There are many things you can do with models. The main action is calling [Module::forward()] and [Module::forward_mut()]
+//! during inference and training.
+//!
+//! ```rust
+//! # use dfdx_nn::*;
+//! # use dfdx::prelude::*;
+//! # let dev: Cpu = Default::default();
+//! let arch = LinearConfig::new(Const::<3>, 5);
+//! let model = dev.build_module::<f32>(arch);
+//! let x: Tensor<(Const<3>,), f32, _> = dev.sample_normal();
+//! let y = model.forward(x);
+//! assert_eq!(y.shape(), &(5, ));
+//! ```
+//!
+//! # Composing layers into Sequential models
+//!
+//! There are multiple ways of doing this.
+//!
+//! The recommended way is to derive Sequential because:
+//! 1. You can reference fields/submodules with named items instead of indexing into tuples.
+//! 2. Error messages of deeply nested models are more readable.
+//!
+//! Under the hood, the code generated for Sequential vs tuples is identical.
+//!
+//! ## Deriving [Sequential]
+//!
+//! See [Sequential] for more detailed information.
+//!
+//! ```rust
+//! # use dfdx::prelude::*;
+//! # use dfdx_nn::*;
+//! # let dev: Cpu = Default::default();
+//! #[derive(Debug, Clone, Sequential)]
+//! #[built(Mlp)]
+//! struct MlpConfig {
+//!     // Linear with compile time input size & runtime known output size
+//!     linear1: LinearConfig<Const<3>, usize>,
+//!     act1: ReLU,
+//!     // Linear with runtime input & output size
+//!     linear2: LinearConfig<usize, usize>,
+//!     act2: Tanh,
+//!     // Linear with runtime input & compile time output size.
+//!     linear3: LinearConfig<usize, Const<10>>,
+//! }
+//!
+//! // fill in the dimensions for the architecture
+//! let arch = MlpConfig {
+//!     linear1: LinearConfig::new(Const, 256),
+//!     act1: Default::default(),
+//!     linear2: LinearConfig::new(256, 128),
+//!     act2: Default::default(),
+//!     linear3: LinearConfig::new(128, Const),
+//! };
+//! let mut model = dev.build_module::<f32>(arch);
+//! let x: Tensor<(usize, Const<3>), f32, _> = dev.sample_uniform_like(&(100, Const));
+//! let y = model.forward_mut(x);
+//! assert_eq!(y.shape(), &(100, Const::<10>));
+//! ```
+//!
+//! ## Tuples
+//! The simplest way is to create a tuple of layer configs, which represents a sequential model.
+//!
+//! Here's an example of how this works:
+//!
+//! ```rust
+//! # use dfdx::prelude::*;
+//! # use dfdx_nn::*;
+//! # let dev: Cpu = Default::default();
+//! type Arch = (LinearConstConfig<3, 5>, ReLU, LinearConstConfig<5, 10>);
+//! let mut model = dev.build_module::<f32>(Arch::default());
+//! let x: Tensor<(usize, Const<3>), f32, _> = dev.sample_uniform_like(&(100, Const));
+//! let y = model.forward_mut(x);
+//! assert_eq!(y.shape(), &(100, Const::<10>));
+//! ```
+//!
+//! # Optimizers and Gradients
+//!
+//! *See [optim] for more information*
+//!
+//! dfdx-nn supports a number of the standard optimizers:
+//!
+//! | Optimizer | dfdx | pytorch |
+//! | --- | --- | --- |
+//! | SGD | [optim::Sgd] | `torch.optim.SGD` |
+//! | Adam | [optim::Adam] | `torch.optim.Adam` |
+//! | AdamW | [optim::Adam] with [optim::WeightDecay::Decoupled] | `torch.optim.AdamW` |
+//! | RMSprop | [optim::RMSprop] | `torch.optim.RMSprop` |
+//!
+//! You can use optimizers to optimize neural networks (or even tensors!). Here's
+//! a simple example of how to do this:
+//! ```rust
+//! # use dfdx::prelude::*;
+//! # use dfdx_nn::{*, optim::*};
+//! # let dev: Cpu = Default::default();
+//! type Arch = (LinearConstConfig<3, 5>, ReLU, LinearConstConfig<5, 10>);
+//! let arch = Arch::default();
+//! let mut model = dev.build_module::<f32>(arch);
+//! // 1. allocate gradients for the model
+//! let mut grads = model.alloc_grads();
+//! // 2. create our optimizer
+//! let mut opt = Sgd::new(&model, Default::default());
+//! // 3. trace gradients through forward pass
+//! let x: Tensor<Rank1<3>, f32, _> = dev.sample_normal();
+//! let y = model.forward_mut(x.traced(grads));
+//! // 4. compute loss & run backpropagation
+//! let loss = y.square().mean();
+//! grads = loss.backward();
+//! // 5. apply gradients
+//! opt.update(&mut model, &grads);
+//! ```
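+//!
+//! Putting these steps together gives a minimal training-loop sketch. Note how the gradients
+//! are moved into `traced()` at the start of each step and handed back by `backward()`:
+//! ```rust
+//! # use dfdx::prelude::*;
+//! # use dfdx_nn::{*, optim::*};
+//! # let dev: Cpu = Default::default();
+//! # type Arch = (LinearConstConfig<3, 5>, ReLU, LinearConstConfig<5, 10>);
+//! # let mut model = dev.build_module::<f32>(Arch::default());
+//! # let mut grads = model.alloc_grads();
+//! # let mut opt = Sgd::new(&model, Default::default());
+//! for _step in 0..3 {
+//!     // trace the forward pass so gradients are recorded
+//!     let x: Tensor<Rank1<3>, f32, _> = dev.sample_normal();
+//!     let y = model.forward_mut(x.traced(grads));
+//!     // compute a loss and backpropagate; backward() returns the gradients
+//!     let loss = y.square().mean();
+//!     grads = loss.backward();
+//!     // apply the gradients to the model's parameters
+//!     opt.update(&mut model, &grads);
+//! }
+//! ```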
+
 #![cfg_attr(feature = "nightly", feature(generic_const_exprs))]
 
 mod layers;
diff --git a/dfdx/src/lib.rs b/dfdx/src/lib.rs
index cadd8f278..f3fb85f08 100644
--- a/dfdx/src/lib.rs
+++ b/dfdx/src/lib.rs
@@ -100,78 +100,6 @@
 //! | Concat | [tensor_ops::TryConcat] | `np.concatenate` | `torch.concat` |
 //!
 //! and **much much more!**
-//!
-//! # Neural networks
-//!
-//! *See [nn] for more information.*
-//!
-//! Neural networks are composed of building blocks that you can chain together. In
-//! dfdx, sequential neural networks are represents by **tuples**! For example,
-//! the following two networks are identical:
-//!
-//! | dfdx | pytorch |
-//! | --- | --- |
-//! | `(Linear<3, 5>, ReLU, Linear<5, 10>)` | `nn.Sequential(nn.Linear(3, 5), nn.ReLU(), nn.Linear(5, 10))` |
-//! | `((Conv2D<3, 2, 1>, Tanh), Conv2D<3, 2, 1>)` | `nn.Sequential(nn.Sequential(nn.Conv2d(3, 2, 1), nn.Tanh()), nn.Conv2d(3, 2, 1))`
-//!
-//! To build a neural network, you of course need a device:
-//!
-//! ```rust
-//! # use dfdx::prelude::*;
-//! let dev: Cpu = Default::default();
-//! type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
-//! let model = dev.build_module::<Model, f32>();
-//! ```
-//!
-//! Note two things:
-//! 1. We are using [nn::DeviceBuildExt] to instantiate the model
-//! 2. We **need** to pass a dtype (in this case f32) to create the model.
-//!
-//! You can then pass tensors into the model with [nn::Module::forward()]:
-//!
-//! ```rust
-//! # use dfdx::prelude::*;
-//! # let dev: Cpu = Default::default();
-//! # type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
-//! # let model = dev.build_module::<Model, f32>();
-//! // tensor with runtime batch dimension of 10
-//! let x: Tensor<(usize, Const<3>), f32, _> = dev.sample_normal_like(&(10, Const));
-//! let y = model.forward(x);
-//! ```
-//!
-//! # Optimizers and Gradients
-//!
-//! *See [optim] for more information*
-//!
-//! dfdx supports a number of the standard optimizers:
-//!
-//! | Optimizer | dfdx | pytorch |
-//! | --- | --- | --- |
-//! | SGD | [optim::Sgd] | `torch.optim.SGD` |
-//! | Adam | [optim::Adam] | torch.optim.Adam` |
-//! | AdamW | [optim::Adam] with [optim::WeightDecay::Decoupled] | `torch.optim.AdamW` |
-//! | RMSprop | [optim::RMSprop] | `torch.optim.RMSprop` |
-//!
-//! You can use optimizers to optimize neural networks (or even tensors!). Here's
-//! a simple example of how to do this with [nn::ZeroGrads]:
-//! ```rust
-//! # use dfdx::{prelude::*, optim::*};
-//! # let dev: Cpu = Default::default();
-//! type Model = (Linear<3, 5>, ReLU, Linear<5, 10>);
-//! let mut model = dev.build_module::<Model, f32>();
-//! // 1. allocate gradients for the model
-//! let mut grads = model.alloc_grads();
-//! // 2. create our optimizer
-//! let mut opt = Sgd::new(&model, Default::default());
-//! // 3. trace gradients through forward pass
-//! let x: Tensor<Rank1<3>, f32, _> = dev.sample_normal();
-//! let y = model.forward_mut(x.traced(grads));
-//! // 4. compute loss & run backpropagation
-//! let loss = y.square().mean();
-//! grads = loss.backward();
-//! // 5. apply gradients
-//! opt.update(&mut model, &grads);
-//! ```
 
 #![cfg_attr(all(feature = "no-std", not(feature = "std")), no_std)]
 #![allow(incomplete_features)]
diff --git a/dfdx/src/tensor/gradients.rs b/dfdx/src/tensor/gradients.rs
index 6b8644c14..99dc7a163 100644
--- a/dfdx/src/tensor/gradients.rs
+++ b/dfdx/src/tensor/gradients.rs
@@ -31,9 +31,6 @@ impl<E, D: Storage<E>> Gradients<E, D> {
     /// This is why this method is called `leaky`, because
     /// it will keep gradients from previous passes if it is
     /// used consecutively.
-    ///
-    /// **You should use [crate::nn::ZeroGrads::alloc_grads]**,
-    /// which will ensure non-leaf gradients are freed after backwards.
     pub fn leaky() -> Self {
         Self {
             gradient_by_id: Default::default(),
diff --git a/dfdx/src/tensor/mod.rs b/dfdx/src/tensor/mod.rs
index 4774dfb24..a6e79cbff 100644
--- a/dfdx/src/tensor/mod.rs
+++ b/dfdx/src/tensor/mod.rs
@@ -108,7 +108,7 @@
 //! # use dfdx::prelude::*;
 //! # let dev: Cpu = Default::default();
 //! let t: Tensor<Rank1<5>, f32, _> = dev.zeros();
-//! let mut grads = t.alloc_grads();
+//! let mut grads = Gradients::leaky();
 //! let t_clone: Tensor<Rank1<5>, f32, _, OwnedTape<f32, Cpu>> = t.trace(grads);
 //! ```
 //!
diff --git a/dfdx/src/tensor/safetensors.rs b/dfdx/src/tensor/safetensors.rs
index 6dc3efe42..c0566c406 100644
--- a/dfdx/src/tensor/safetensors.rs
+++ b/dfdx/src/tensor/safetensors.rs
@@ -4,7 +4,7 @@ use safetensors::tensor::{SafeTensorError, SafeTensors};
 use std::vec::Vec;
 
 impl<S: Shape, E: Dtype, D: CopySlice<E>, T> Tensor<S, E, D, T> {
-    /// Loads data from the [SafeTensors] Storage with the given `key`
+    /// Loads data from the [SafeTensors] `Storage` with the given `key`
     pub fn load_safetensor(
         &mut self,
         tensors: &SafeTensors,
diff --git a/dfdx/src/tensor/tensor_impls.rs b/dfdx/src/tensor/tensor_impls.rs
index 305d6ca34..5d7a96674 100644
--- a/dfdx/src/tensor/tensor_impls.rs
+++ b/dfdx/src/tensor/tensor_impls.rs
@@ -63,26 +63,18 @@ pub trait Trace<E: Dtype, D: Storage<E>>: Clone {
     type Traced;
     /// Start tracking gradients, clones self. The gradients will never free
     /// temporary gradients - See [Gradients::leaky()] for more info.
-    ///
-    /// Prefer to use [Tensor::trace()] with gradients allocated
-    /// with [crate::nn::ZeroGrads::alloc_grads()].
     fn leaky_trace(&self) -> Self::Traced {
         self.clone().leaky_traced()
     }
     /// Start tracking gradients. The gradients will never free
     /// temporary gradients - See [Gradients::leaky()] for more info.
-    ///
-    /// Prefer to use [Tensor::traced()] with gradients allocated
-    /// with [crate::nn::ZeroGrads::alloc_grads()].
     fn leaky_traced(self) -> Self::Traced;
 
-    /// Accumulates gradients into `gradients`, clones self. Use [crate::nn::ZeroGrads::alloc_grads()]
-    /// to create gradients.
+    /// Accumulates gradients into `gradients`, clones self.
     fn trace(&self, gradients: Gradients<E, D>) -> Self::Traced {
         self.clone().traced(gradients)
     }
 
-    /// Accumulates gradients into `gradients`. Use [crate::nn::ZeroGrads::alloc_grads()]
-    /// to create gradients.
+    /// Accumulates gradients into `gradients`.
     fn traced(self, gradients: Gradients<E, D>) -> Self::Traced;
 }
diff --git a/dfdx/src/tensor/tensorlike.rs b/dfdx/src/tensor/tensorlike.rs
index 85c538260..18e4f074d 100644
--- a/dfdx/src/tensor/tensorlike.rs
+++ b/dfdx/src/tensor/tensorlike.rs
@@ -7,7 +7,7 @@ use crate::{
 use super::{storage_traits::AllocGrad, GhostTensor, Tensor, UniqueId};
 
 /// Contains everything that comprises a tensor, except possibly for the actual data. This really
-/// exists to unify handling of [Tensor] and [GhostTensor].
+/// exists to unify handling of [Tensor] and tensors without data.
 ///
 /// *If it looks like a tensor and barks like a tensor, then pet it like a tensor.*
 #[allow(clippy::len_without_is_empty)]
diff --git a/dfdx/src/tensor_ops/adam/mod.rs b/dfdx/src/tensor_ops/adam/mod.rs
index 1e95777af..34ab40b90 100644
--- a/dfdx/src/tensor_ops/adam/mod.rs
+++ b/dfdx/src/tensor_ops/adam/mod.rs
@@ -10,11 +10,11 @@ use crate::{
 
 use super::WeightDecay;
 
-/// Configuration of hyperparameters for [crate::optim::Adam].
+/// Configuration of hyperparameters for Adam.
 ///
 /// Changing all default parameters:
 /// ```rust
-/// # use dfdx::{prelude::*, optim::*};
+/// # use dfdx::prelude::*;
 /// AdamConfig {
 ///     lr: 1e-2,
 ///     betas: [0.1, 0.2],
diff --git a/dfdx/src/tensor_ops/optim.rs b/dfdx/src/tensor_ops/optim.rs
index 469c22c14..793e67bf0 100644
--- a/dfdx/src/tensor_ops/optim.rs
+++ b/dfdx/src/tensor_ops/optim.rs
@@ -28,13 +28,13 @@ pub(super) fn weight_decay_to_cuda(wd: Option<WeightDecay>) -> (WeightDecayType,
     }
 }
 
-/// Momentum used for [crate::optim::Sgd] and others
+/// Momentum used for Sgd and others
 #[derive(Debug, Clone, Copy)]
 pub enum Momentum {
     /// Momentum that is applied to the velocity of a parameter directly.
     Classic(f64),
 
-    /// Momentum that is applied to both velocity and gradients. See [crate::optim::Sgd] nesterov paper for more.
+    /// Momentum that is applied to both velocity and gradients. See the Nesterov momentum paper for more.
     Nesterov(f64),
 }
 
diff --git a/dfdx/src/tensor_ops/rmsprop/mod.rs b/dfdx/src/tensor_ops/rmsprop/mod.rs
index 1899cd5e7..55afb3089 100644
--- a/dfdx/src/tensor_ops/rmsprop/mod.rs
+++ b/dfdx/src/tensor_ops/rmsprop/mod.rs
@@ -10,7 +10,7 @@ use crate::{
 
 use super::WeightDecay;
 
-/// Configuration of hyperparameters for [crate::optim::RMSprop].
+/// Configuration of hyperparameters for RMSprop.
 #[derive(Debug, Clone, Copy)]
 pub struct RMSpropConfig {
     /// Learning rate. Defaults to `1e-2`.
diff --git a/dfdx/src/tensor_ops/sgd/mod.rs b/dfdx/src/tensor_ops/sgd/mod.rs
index 112f248a7..6b060feb7 100644
--- a/dfdx/src/tensor_ops/sgd/mod.rs
+++ b/dfdx/src/tensor_ops/sgd/mod.rs
@@ -10,11 +10,11 @@ use crate::{
 
 use super::optim::{Momentum, WeightDecay};
 
-/// Configuration of hyperparameters for [crate::optim::Sgd].
+/// Configuration of hyperparameters for Sgd.
 ///
 /// Using different learning rate:
 /// ```rust
-/// # use dfdx::{prelude::*, optim::*};
+/// # use dfdx::prelude::*;
 /// SgdConfig {
 ///     lr: 1e-1,
 ///     momentum: None,
@@ -24,7 +24,7 @@ use super::optim::{Momentum, WeightDecay};
 ///
 /// Using classic momentum:
 /// ```rust
-/// # use dfdx::{prelude::*, optim::*};
+/// # use dfdx::prelude::*;
 /// SgdConfig {
 ///     lr: 1e-2,
 ///     momentum: Some(Momentum::Classic(0.5)),
@@ -34,7 +34,7 @@ use super::optim::{Momentum, WeightDecay};
 ///
 /// Using nesterov momentum:
 /// ```rust
-/// # use dfdx::{prelude::*, optim::*};
+/// # use dfdx::prelude::*;
 /// SgdConfig {
 ///     lr: 1e-3,
 ///     momentum: Some(Momentum::Nesterov(0.25)),
@@ -44,7 +44,7 @@ use super::optim::{Momentum, WeightDecay};
 ///
 /// Using L2 weight decay:
 /// ```rust
-/// # use dfdx::{prelude::*, optim::*};
+/// # use dfdx::prelude::*;
 /// SgdConfig {
 ///     lr: 1e-3,
 ///     momentum: None,
@@ -54,7 +54,7 @@ use super::optim::{Momentum, WeightDecay};
 ///
 /// Using decoupled weight decay:
 /// ```rust
-/// # use dfdx::{prelude::*, optim::*};
+/// # use dfdx::prelude::*;
 /// SgdConfig {
 ///     lr: 1e-3,
 ///     momentum: None,