diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index b8e7df44..3fa926b0 100644 --- a/src/distribution/gamma.rs +++ b/src/distribution/gamma.rs @@ -1,5 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; +use crate::prec; use crate::statistics::*; use crate::{Result, StatsError}; use rand::Rng; @@ -132,20 +133,60 @@ impl ContinuousCDF for Gamma { fn sf(&self, x: f64) -> f64 { if x <= 0.0 { 1.0 - } - else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { + } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { 0.0 - } - else if self.rate.is_infinite() { + } else if self.rate.is_infinite() { 1.0 - } - else if x.is_infinite() { + } else if x.is_infinite() { 0.0 - } - else { + } else { gamma::gamma_ur(self.shape, x * self.rate) } } + + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("default inverse_cdf implementation should be provided probability on [0,1]") + } + if p == 0.0 { + return self.min(); + }; + if p == 1.0 { + return self.max(); + }; + + // Bisection search for MAX_ITERS.0 iterations + let mut high = 2.0; + let mut low = 1.0; + while self.cdf(low) > p { + low /= 2.0; + } + while self.cdf(high) < p { + high *= 2.0; + } + let mut x_0 = (high + low) / 2.0; + + for _ in 0..8 { + if self.cdf(x_0) >= p { + high = x_0; + } else { + low = x_0; + } + if prec::convergence(&mut x_0, (high + low) / 2.0) { + break; + } + } + + // Newton Raphson, for at least one step + for _ in 0..4 { + let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0); + if prec::convergence(&mut x_0, x_next) { + break; + } + } + + x_0 + } } impl Min for Gamma { @@ -456,7 +497,11 @@ mod tests { for &(arg, res) in test.iter() { test_case_special(arg, res, 10e-6, f); } - let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, f64::INFINITY), 0.0)]; + let test = [ + ((10.0, 10.0), 0.9), + ((10.0, 1.0), 9.0), + ((10.0, f64::INFINITY), 0.0), + ]; for &(arg, res) in test.iter() { test_case(arg, res, f); } @@ -562,6 +607,32 @@ mod tests { test_case((1.0, 0.1), 0.0, |x| x.cdf(0.0)); } + #[test] + fn test_cdf_inverse_identity() { + let f = |p: f64| move |g: Gamma| g.cdf(g.inverse_cdf(p)); + let params = [ + (1.0, 0.1), + (1.0, 1.0), + (10.0, 10.0), + (10.0, 1.0), + (100.0, 200.0), + ]; + + for param in params { + for n in -5..0 { + let p = 10.0f64.powi(n); + test_case(param, p, f(p)); + } + } + + // test case from issue #200 + { + let x = 20.5567; + let f = |x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x)); + test_case((3.0, 0.5), x, f(x)) + } + } + #[test] fn test_sf() { let f = |arg: f64| move |x: Gamma| x.sf(arg); diff --git a/src/prec.rs b/src/prec.rs index 59ad3714..042c8b22 100644 --- a/src/prec.rs +++ b/src/prec.rs @@ -25,3 +25,12 @@ pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { (a - b).abs() < acc } + +/// Compares if two floats are close via `approx::relative_eq!` +/// and `crate::consts::ACC` relative precision. +/// Updates first argument to value of second argument +pub fn convergence(x: &mut f64, x_new: f64) -> bool { + let res = approx::relative_eq!(*x, x_new, max_relative = crate::consts::ACC); + *x = x_new; + res +}