Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benches/distributions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ distr_float!(distr_log_normal, f64, LogNormal::new(-1.23, 4.56).unwrap());
distr_float!(distr_gamma_large_shape, f64, Gamma::new(10., 1.0).unwrap());
distr_float!(distr_gamma_small_shape, f64, Gamma::new(0.1, 1.0).unwrap());
distr_float!(distr_cauchy, f64, Cauchy::new(4.2, 6.9).unwrap());
distr_float!(distr_triangular, f64, Triangular::new(0., 1., 0.9).unwrap());
distr_int!(distr_binomial, u64, Binomial::new(20, 0.7).unwrap());
distr_int!(distr_binomial_small, u64, Binomial::new(1000000, 1e-30).unwrap());
distr_int!(distr_poisson, u64, Poisson::new(4.0).unwrap());
Expand Down
8 changes: 4 additions & 4 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl Dirichlet {
///
/// Requires `size >= 2`.
#[inline]
pub fn new_with_param(alpha: f64, size: usize) -> Result<Dirichlet, Error> {
pub fn new_with_size(alpha: f64, size: usize) -> Result<Dirichlet, Error> {
if !(alpha > 0.0) {
return Err(Error::AlphaTooSmall);
}
Expand Down Expand Up @@ -124,7 +124,7 @@ mod test {
fn test_dirichlet_with_param() {
let alpha = 0.5f64;
let size = 2;
let d = Dirichlet::new_with_param(alpha, size).unwrap();
let d = Dirichlet::new_with_size(alpha, size).unwrap();
let mut rng = crate::test::rng(221);
let samples = d.sample(&mut rng);
let _: Vec<f64> = samples
Expand All @@ -139,12 +139,12 @@ mod test {
#[test]
#[should_panic]
fn test_dirichlet_invalid_length() {
Dirichlet::new_with_param(0.5f64, 1).unwrap();
Dirichlet::new_with_size(0.5f64, 1).unwrap();
}

#[test]
#[should_panic]
fn test_dirichlet_invalid_alpha() {
Dirichlet::new_with_param(0.0f64, 2).unwrap();
Dirichlet::new_with_size(0.0f64, 2).unwrap();
}
}
25 changes: 24 additions & 1 deletion rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,12 @@ pub use self::gamma::{Gamma, Error as GammaError, ChiSquared, ChiSquaredError,
pub use self::normal::{Normal, Error as NormalError, LogNormal, StandardNormal};
pub use self::exponential::{Exp, Error as ExpError, Exp1};
pub use self::pareto::{Pareto, Error as ParetoError};
pub use self::pert::{Pert, PertError};
pub use self::poisson::{Poisson, Error as PoissonError};
pub use self::binomial::{Binomial, Error as BinomialError};
pub use self::cauchy::{Cauchy, Error as CauchyError};
pub use self::dirichlet::{Dirichlet, Error as DirichletError};
pub use self::triangular::{Triangular, Error as TriangularError};
pub use self::triangular::{Triangular, TriangularError};
pub use self::weibull::{Weibull, Error as WeibullError};

mod unit_sphere;
Expand All @@ -81,6 +82,7 @@ mod gamma;
mod normal;
mod exponential;
mod pareto;
mod pert;
mod poisson;
mod binomial;
mod cauchy;
Expand All @@ -92,8 +94,29 @@ mod ziggurat_tables;

#[cfg(test)]
mod test {
// Notes on testing
//
// Testing random number distributions correctly is hard. The following
// testing is desired:
//
// - Construction: test initialisation with a few valid parameter sets.
// - Erroneous usage: test that incorrect usage generates an error.
// - Vector: test that usage with fixed inputs (including RNG) generates a
// fixed output sequence on all platforms.
// - Correctness at fixed points (optional): using a specific mock RNG,
// check that specific values are sampled (e.g. end-points and median of
// distribution).
// - Correctness of PDF (extra): generate a histogram of samples within a
// certain range, and check this approximates the PDF. These tests are
// expected to be expensive, and should be behind a feature-gate.
//
// TODO: Vector and correctness tests are largely absent so far.
// NOTE: Some distributions have tests checking only that samples can be
// generated. This is redundant with vector and correctness tests.

use rand::{RngCore, SeedableRng, rngs::StdRng};

/// Construct a deterministic RNG with the given seed
pub fn rng(seed: u64) -> impl RngCore {
StdRng::seed_from_u64(seed)
}
Expand Down
125 changes: 125 additions & 0 deletions rand_distr/src/pert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The PERT distribution.

use rand::Rng;
use crate::{Distribution, Beta};

/// The PERT distribution.
///
/// Similar to the [Triangular] distribution, the PERT distribution is
/// parameterised by a range and a mode within that range. Unlike the
/// [Triangular] distribution, the probability density function of the PERT
/// distribution is smooth, with a configurable weighting around the mode.
///
/// # Example
///
/// ```rust
/// use rand_distr::{Pert, Distribution};
///
/// let d = Pert::new(0., 5., 2.5).unwrap();
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a PERT distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Pert {
min: f64,
range: f64,
beta: Beta,
}

/// Error type returned from [`Pert`] constructors.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PertError {
/// `max < min` or `min` or `max` is NaN.
RangeTooSmall,
/// `mode < min` or `mode > max` or `mode` is NaN.
ModeRange,
/// `shape < 0` or `shape` is NaN
ShapeTooSmall,
}

impl Pert {
/// Set up the PERT distribution with defined `min`, `max` and `mode`.
///
/// This is equivalent to calling `Pert::new_shape` with `shape == 4.0`.
#[inline]
pub fn new(min: f64, max: f64, mode: f64) -> Result<Pert, PertError> {
Pert::new_with_shape(min, max, mode, 4.)
}

/// Set up the PERT distribution with defined `min`, `max`, `mode` and
/// `shape`.
pub fn new_with_shape(min: f64, max: f64, mode: f64, shape: f64) -> Result<Pert, PertError> {
if !(max > min) {
return Err(PertError::RangeTooSmall);
}
if !(mode >= min && max >= mode) {
return Err(PertError::ModeRange);
}
if !(shape >= 0.) {
return Err(PertError::ShapeTooSmall);
}

let range = max - min;
let mu = (min + max + shape * mode) / (shape + 2.);
let v = if mu == mode {
shape * 0.5 + 1.
} else {
(mu - min) * (2. * mode - min - max)
/ ((mode - mu) * (max - min))
};
let w = v * (max - mu) / (mu - min);
let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
Ok(Pert{ min, range, beta })
}
}

impl Distribution<f64> for Pert {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
self.beta.sample(rng) * self.range + self.min
}
}

#[cfg(test)]
mod test {
use std::f64;
use super::*;

#[test]
fn test_pert() {
for &(min, max, mode) in &[
(-1., 1., 0.),
(1., 2., 1.),
(5., 25., 25.),
] {
let _distr = Pert::new(min, max, mode).unwrap();
// TODO: test correctness
}

for &(min, max, mode) in &[
(-1., 1., 2.),
(-1., 1., -2.),
(2., 1., 1.),
] {
assert!(Pert::new(min, max, mode).is_err());
}
}

#[test]
fn test_pert_vector() {
let rng = crate::test::rng(860);
let distr = Pert::new(2., 10., 3.).unwrap(); // mean = 4, var = 12/7
let seq = distr.sample_iter(rng).take(5).collect::<Vec<f64>>();
println!("seq: {:?}", seq);
let expected = vec![3.945192480331639, 4.571769050527243,
7.419819712922435, 4.049743197259167, 5.825644880531534];
assert!(seq == expected);
}
}
92 changes: 56 additions & 36 deletions rand_distr/src/triangular.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ use rand::Rng;
use crate::{Distribution, Standard};

/// The triangular distribution.
///
/// A continuous probability distribution parameterised by a range, and a mode
/// (most likely value) within that range.
///
/// The probability density function is triangular. For a similar distribution
/// with a smooth PDF, see the [Pert] distribution.
///
/// # Example
///
Expand All @@ -28,30 +34,24 @@ pub struct Triangular {
mode: f64,
}

/// Error type returned from `Triangular::new`.
/// Error type returned from [`Triangular::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `max < mode` or `max` is `nan`.
MaxTooSmall,
/// `mode < min` or `mode` is `nan`.
ModeTooSmall,
/// `max == min` or `min` is `nan`.
MaxEqualsMin,
pub enum TriangularError {
/// `max < min` or `min` or `max` is NaN.
RangeTooSmall,
/// `mode < min` or `mode > max` or `mode` is NaN.
ModeRange,
}

impl Triangular {
/// Construct a new `Triangular` with minimum `min`, maximum `max` and mode
/// `mode`.
/// Set up the Triangular distribution with defined `min`, `max` and `mode`.
#[inline]
pub fn new(min: f64, max: f64, mode: f64) -> Result<Triangular, Error> {
if !(max >= mode) {
return Err(Error::MaxTooSmall);
}
if !(mode >= min) {
return Err(Error::ModeTooSmall);
pub fn new(min: f64, max: f64, mode: f64) -> Result<Triangular, TriangularError> {
if !(max >= min) {
return Err(TriangularError::RangeTooSmall);
}
if !(max != min) {
return Err(Error::MaxEqualsMin);
if !(mode >= min && max >= mode) {
return Err(TriangularError::ModeRange);
}
Ok(Triangular { min, max, mode })
}
Expand All @@ -62,37 +62,57 @@ impl Distribution<f64> for Triangular {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let f: f64 = rng.sample(Standard);
let diff_mode_min = self.mode - self.min;
let diff_max_min = self.max - self.min;
if f * diff_max_min < diff_mode_min {
self.min + (f * diff_max_min * diff_mode_min).sqrt()
let range = self.max - self.min;
let f_range = f * range;
if f_range < diff_mode_min {
self.min + (f_range * diff_mode_min).sqrt()
} else {
self.max - ((1. - f) * diff_max_min * (self.max - self.mode)).sqrt()
self.max - ((range - f_range) * (self.max - self.mode)).sqrt()
}
}
}

#[cfg(test)]
mod test {
use crate::Distribution;
use super::Triangular;
use std::f64;
use rand::{Rng, rngs::mock};
use super::*;

#[test]
fn test_new() {
fn test_triangular() {
let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0);
assert_eq!(half_rng.gen::<f64>(), 0.5);
for &(min, max, mode, median) in &[
(-1., 1., 0., 0.),
(1., 2., 1., 2. - 0.5f64.sqrt()),
(5., 25., 25., 5. + 200f64.sqrt()),
(1e-5, 1e5, 1e-3, 1e5 - 4999999949.5f64.sqrt()),
(0., 1., 0.9, 0.45f64.sqrt()),
(-4., -0.5, -2., -4.0 + 3.5f64.sqrt()),
] {
println!("{} {} {} {}", min, max, mode, median);
let distr = Triangular::new(min, max, mode).unwrap();
// Test correct value at median:
assert_eq!(distr.sample(&mut half_rng), median);
}

for &(min, max, mode) in &[
(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.), (1e-5, 1e5, 1e-3),
(0., 1., 0.9), (-4., -0.5, -2.), (-13.039, 8.41, 1.17),
(-1., 1., 2.),
(-1., 1., -2.),
(2., 1., 1.),
] {
println!("{} {} {}", min, max, mode);
let _ = Triangular::new(min, max, mode).unwrap();
assert!(Triangular::new(min, max, mode).is_err());
}
}

#[test]
fn test_sample() {
let norm = Triangular::new(0., 1., 0.5).unwrap();
let mut rng = crate::test::rng(1);
for _ in 0..1000 {
norm.sample(&mut rng);
}
fn test_triangular_vector() {
let rng = crate::test::rng(860);
let distr = Triangular::new(2., 10., 3.).unwrap();
let seq = distr.sample_iter(rng).take(5).collect::<Vec<f64>>();
println!("seq: {:?}", seq);
let expected = vec![4.941640229082449, 2.421447306833011,
4.5964271605527385, 2.789763631136542, 4.8014432067978445];
assert!(seq == expected);
}
}