-
-
Notifications
You must be signed in to change notification settings - Fork 433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for Dirichlet distribution #485
Changes from 9 commits
546caa8
14462ca
c530db7
57fb716
aad2f45
4d1cc6c
ad13516
d7a5ced
d728989
732209a
2f706a6
fde9567
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
|
||
//! The dirichlet distribution. | ||
|
||
use Rng; | ||
use distributions::Distribution; | ||
use distributions::gamma::Gamma; | ||
|
||
/// The dirichelet distribution `Dirichlet(alpha)`. | ||
/// | ||
/// The Dirichlet distribution is a family of continuous multivariate probability distributions parameterized by | ||
/// a vector alpha of positive reals. https://en.wikipedia.org/wiki/Dirichlet_distribution | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The naked link will look weird in the docs. I think you can remove, there is not precedence in Rand for linking to Wikipedia. |
||
/// It is a multivariate generalization of the beta distribution. | ||
/// | ||
/// # Example | ||
/// | ||
/// ``` | ||
/// use rand::prelude::*; | ||
/// use rand::distributions::Dirichlet; | ||
/// | ||
/// let dirichlet = Dirichlet::new(vec![1.0, 2.0, 3.0]); | ||
/// let samples = dirichlet.sample(&mut rand::thread_rng()); | ||
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); | ||
/// ``` | ||
|
||
#[derive(Clone, Debug)] | ||
pub struct Dirichlet { | ||
/// Concentration parameters (alpha) | ||
alpha: Vec<f64>, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thinking about it, we could probably use On the other hand it may not be worth it since it makes the type less ergonomic to use for what is probably not a lot of gain. Another option would be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do our distributions even work when you write There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
} | ||
|
||
impl Dirichlet { | ||
/// Construct a new `Dirichlet` with the given alpha parameter | ||
/// `alpha`. Panics if `alpha.len() < 2`. | ||
#[inline] | ||
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet { | ||
let a = alpha.into(); | ||
assert!(a.len() > 1); | ||
for i in 0..a.len() { | ||
assert!(a[i] > 0.0); | ||
} | ||
|
||
Dirichlet { alpha: a.into() } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you don't need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This bit. Not sure why my last comment went somewhere else. |
||
} | ||
|
||
/// Construct a new `Dirichlet` with the given shape parameter and size | ||
/// `alpha`. Panics if `alpha <= 0.0`. | ||
/// `size` . Panic if `size < 2` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This won't render well. If you want a list, leave a blank line, the prefix each item with |
||
#[inline] | ||
pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet { | ||
assert!(alpha > 0.0); | ||
assert!(size > 1); | ||
Dirichlet { | ||
alpha: vec![alpha; size], | ||
} | ||
} | ||
} | ||
|
||
impl Distribution<Vec<f64>> for Dirichlet { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure our current distribution trait is well suited for multivariate distributions. It would be nice to sample without allocating, but this requires different method. Something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is why you opened #496. I agree. On the other hand, I'm not too fussed about having to make breaking changes to this distribution later (it's still better for users than not having it, and we're not close to 1.0). |
||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> { | ||
let n = self.alpha.len(); | ||
let mut samples = vec![0.0f64; n]; | ||
let mut sum = 0.0f64; | ||
|
||
for i in 0..n { | ||
let g = Gamma::new(self.alpha[i], 1.0); | ||
samples[i] = g.sample(rng); | ||
sum += samples[i]; | ||
} | ||
let invacc = 1.0 / sum; | ||
for i in 0..n { | ||
samples[i] *= invacc; | ||
} | ||
samples | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use super::Dirichlet; | ||
use distributions::Distribution; | ||
|
||
#[test] | ||
fn test_dirichlet() { | ||
let d = Dirichlet::new(vec![1.0, 2.0, 3.0]); | ||
let mut rng = ::test::rng(221); | ||
let samples = d.sample(&mut rng); | ||
let _: Vec<f64> = samples | ||
.into_iter() | ||
.map(|x| { | ||
assert!(x > 0.0); | ||
x | ||
}) | ||
.collect(); | ||
} | ||
|
||
#[test] | ||
fn test_dirichlet_with_param() { | ||
let alpha = 0.5f64; | ||
let size = 2; | ||
let d = Dirichlet::new_with_param(alpha, size); | ||
let mut rng = ::test::rng(221); | ||
let samples = d.sample(&mut rng); | ||
let _: Vec<f64> = samples | ||
.into_iter() | ||
.map(|x| { | ||
assert!(x > 0.0); | ||
x | ||
}) | ||
.collect(); | ||
} | ||
|
||
#[test] | ||
#[should_panic] | ||
fn test_dirichlet_invalid_length() { | ||
Dirichlet::new_with_param(0.5f64, 1); | ||
} | ||
|
||
#[test] | ||
#[should_panic] | ||
fn test_dirichlet_invalid_alpha() { | ||
Dirichlet::new_with_param(0.0f64, 2); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,7 +81,6 @@ | |
//! - Related to real-valued quantities that grow linearly | ||
//! (e.g. errors, offsets): | ||
//! - [`Normal`] distribution, and [`StandardNormal`] as a primitive | ||
//! - [`Cauchy`] distribution | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rohitjoshi did you intend to move the reference to Cauchy in this documentation? Because you haven't added it back. |
||
//! - Related to Bernoulli trials (yes/no events, with a given probability): | ||
//! - [`Binomial`] distribution | ||
//! - [`Bernoulli`] distribution, similar to [`Rng::gen_bool`]. | ||
|
@@ -96,7 +95,8 @@ | |
//! - [`ChiSquared`] distribution | ||
//! - [`StudentT`] distribution | ||
//! - [`FisherF`] distribution | ||
//! | ||
//! - Dirichlet distribution | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's been added now, but |
||
//! - [`Dirichlet`] distribution | ||
//! | ||
//! # Examples | ||
//! | ||
|
@@ -150,6 +150,7 @@ | |
//! [`Binomial`]: struct.Binomial.html | ||
//! [`Cauchy`]: struct.Cauchy.html | ||
//! [`ChiSquared`]: struct.ChiSquared.html | ||
//! [`Dirichlet`]: struct.Dirichlet.html | ||
//! [`Exp`]: struct.Exp.html | ||
//! [`Exp1`]: struct.Exp1.html | ||
//! [`FisherF`]: struct.FisherF.html | ||
|
@@ -184,6 +185,8 @@ pub use self::uniform::Uniform as Range; | |
#[doc(inline)] pub use self::bernoulli::Bernoulli; | ||
#[cfg(feature = "std")] | ||
#[doc(inline)] pub use self::cauchy::Cauchy; | ||
#[cfg(feature = "std")] | ||
#[doc(inline)] pub use self::dirichlet::Dirichlet; | ||
|
||
pub mod uniform; | ||
#[cfg(feature="std")] | ||
|
@@ -199,6 +202,8 @@ pub mod uniform; | |
#[doc(hidden)] pub mod bernoulli; | ||
#[cfg(feature = "std")] | ||
#[doc(hidden)] pub mod cauchy; | ||
#[cfg(feature = "std")] | ||
#[doc(hidden)] pub mod dirichlet; | ||
|
||
mod float; | ||
mod integer; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is a bit long. I think we usually wrap comments at 80 characters.