diff --git a/src/distributions/dirichlet.rs b/src/distributions/dirichlet.rs new file mode 100644 index 00000000000..281bdfb1821 --- /dev/null +++ b/src/distributions/dirichlet.rs @@ -0,0 +1,138 @@ +// Copyright 2013 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// https://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The dirichlet distribution. + +use Rng; +use distributions::Distribution; +use distributions::gamma::Gamma; + +/// The dirichelet distribution `Dirichlet(alpha)`. +/// +/// The Dirichlet distribution is a family of continuous multivariate probability distributions parameterized by +/// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution +/// It is a multivariate generalization of the beta distribution. +/// +/// # Example +/// +/// ``` +/// use rand::prelude::*; +/// use rand::distributions::Dirichlet; +/// +/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]); +/// let samples = dirichlet.sample(&mut rand::thread_rng()); +/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); +/// ``` + +#[derive(Clone, Debug)] +pub struct Dirichlet { + /// Concentration parameters (alpha) + alpha: Vec, +} + +impl Dirichlet { + /// Construct a new `Dirichlet` with the given alpha parameter `alpha`. + /// + /// # Panics + /// - if `alpha.len() < 2` + /// + #[inline] + pub fn new>>(alpha: V) -> Dirichlet { + let a = alpha.into(); + assert!(a.len() > 1); + for i in 0..a.len() { + assert!(a[i] > 0.0); + } + + Dirichlet { alpha: a } + } + + /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`. + /// + /// # Panics + /// - if `alpha <= 0.0` + /// - if `size < 2` + /// + #[inline] + pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet { + assert!(alpha > 0.0); + assert!(size > 1); + Dirichlet { + alpha: vec![alpha; size], + } + } +} + +impl Distribution> for Dirichlet { + fn sample(&self, rng: &mut R) -> Vec { + let n = self.alpha.len(); + let mut samples = vec![0.0f64; n]; + let mut sum = 0.0f64; + + for i in 0..n { + let g = Gamma::new(self.alpha[i], 1.0); + samples[i] = g.sample(rng); + sum += samples[i]; + } + let invacc = 1.0 / sum; + for i in 0..n { + samples[i] *= invacc; + } + samples + } +} + +#[cfg(test)] +mod test { + use super::Dirichlet; + use distributions::Distribution; + + #[test] + fn test_dirichlet() { + let d = Dirichlet::new(vec![1.0, 2.0, 3.0]); + let mut rng = ::test::rng(221); + let samples = d.sample(&mut rng); + let _: Vec = samples + .into_iter() + .map(|x| { + assert!(x > 0.0); + x + }) + .collect(); + } + + #[test] + fn test_dirichlet_with_param() { + let alpha = 0.5f64; + let size = 2; + let d = Dirichlet::new_with_param(alpha, size); + let mut rng = ::test::rng(221); + let samples = d.sample(&mut rng); + let _: Vec = samples + .into_iter() + .map(|x| { + assert!(x > 0.0); + x + }) + .collect(); + } + + #[test] + #[should_panic] + fn test_dirichlet_invalid_length() { + Dirichlet::new_with_param(0.5f64, 1); + } + + #[test] + #[should_panic] + fn test_dirichlet_invalid_alpha() { + Dirichlet::new_with_param(0.0f64, 2); + } +} diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index f6dccc1bbf7..6904d7b62b2 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -81,7 +81,6 @@ //! - Related to real-valued quantities that grow linearly //! (e.g. errors, offsets): //! - [`Normal`] distribution, and [`StandardNormal`] as a primitive -//! - [`Cauchy`] distribution //! - Related to Bernoulli trials (yes/no events, with a given probability): //! - [`Binomial`] distribution //! - [`Bernoulli`] distribution, similar to [`Rng::gen_bool`]. @@ -96,7 +95,8 @@ //! - [`ChiSquared`] distribution //! - [`StudentT`] distribution //! - [`FisherF`] distribution -//! +//! - Related to continuous multivariate probability distributions +//! - [`Dirichlet`] distribution //! //! # Examples //! @@ -150,6 +150,7 @@ //! [`Binomial`]: struct.Binomial.html //! [`Cauchy`]: struct.Cauchy.html //! [`ChiSquared`]: struct.ChiSquared.html +//! [`Dirichlet`]: struct.Dirichlet.html //! [`Exp`]: struct.Exp.html //! [`Exp1`]: struct.Exp1.html //! [`FisherF`]: struct.FisherF.html @@ -184,6 +185,8 @@ pub use self::uniform::Uniform as Range; #[doc(inline)] pub use self::bernoulli::Bernoulli; #[cfg(feature = "std")] #[doc(inline)] pub use self::cauchy::Cauchy; +#[cfg(feature = "std")] +#[doc(inline)] pub use self::dirichlet::Dirichlet; pub mod uniform; #[cfg(feature="std")] @@ -199,6 +202,8 @@ pub mod uniform; #[doc(hidden)] pub mod bernoulli; #[cfg(feature = "std")] #[doc(hidden)] pub mod cauchy; +#[cfg(feature = "std")] +#[doc(hidden)] pub mod dirichlet; mod float; mod integer;