diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml
index efcfed67ba..5445f3baf3 100644
--- a/rand_distr/Cargo.toml
+++ b/rand_distr/Cargo.toml
@@ -36,6 +36,7 @@ rand = { path = "..", version = "=0.9.0-alpha.1", default-features = false }
 num-traits = { version = "0.2", default-features = false, features = ["libm"] }
 serde = { version = "1.0.103", features = ["derive"], optional = true }
 serde_with = { version = ">= 3.0, <= 3.11", optional = true }
+spec_math = "0.1.6"

 [dev-dependencies]
 rand_pcg = { version = "=0.9.0-alpha.1", path = "../rand_pcg" }
diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs
index 03fad85c91..995bd878d9 100644
--- a/rand_distr/src/lib.rs
+++ b/rand_distr/src/lib.rs
@@ -112,7 +112,7 @@ pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric};
 pub use self::gumbel::{Error as GumbelError, Gumbel};
 pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric};
 pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian};
-pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal};
+pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal, TruncatedNormal};
 pub use self::normal_inverse_gaussian::{
     Error as NormalInverseGaussianError, NormalInverseGaussian,
 };
diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs
index 330c1ec2d6..4da8b1c58e 100644
--- a/rand_distr/src/normal.rs
+++ b/rand_distr/src/normal.rs
@@ -12,9 +12,12 @@
 use crate::utils::ziggurat;
 use crate::{ziggurat_tables, Distribution, Open01};
 use core::fmt;
-use num_traits::Float;
+use num_traits::{cast, Float, FloatConst};
+use rand::distr::uniform::Uniform;
 use rand::Rng;

+use spec_math::cephes64::{ndtr, ndtri};
+
 /// The standard Normal distribution `N(0, 1)`.
 ///
 /// This is equivalent to `Normal::new(0.0, 1.0)`, but faster.
@@ -160,6 +163,8 @@ pub enum Error {
     MeanTooSmall,
     /// The standard deviation or other dispersion parameter is not finite.
     BadVariance,
+    /// The left bound is greater than or equal to the right bound.
+    InvalidInterval,
 }

 impl fmt::Display for Error {
@@ -167,6 +172,7 @@
         f.write_str(match self {
             Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution",
             Error::BadVariance => "variation parameter is non-finite in (log)normal distribution",
+            Error::InvalidInterval => "the left bound must be less than the right bound",
         })
     }
 }
@@ -363,6 +369,144 @@ where
     }
 }

+/// The [Truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution)
+///
+/// # Example
+///
+/// ```
+/// use rand_distr::{TruncatedNormal, Distribution};
+///
+/// let truncnorm = TruncatedNormal::new(0.0, 1.0, -1.0, 2.0).unwrap();
+/// let val = truncnorm.sample(&mut rand::thread_rng());
+/// println!("{}", val);
+/// ```
+///
+/// # Notes
+///
+/// This implementation is ported from [`scipy.stats.truncnorm`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html),
+/// which is based on the [Cephes Mathematical Library](https://www.netlib.org/cephes/).
+#[derive(Clone, Copy, Debug)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct TruncatedNormal<F>
+where
+    F: Float + FloatConst,
+{
+    mean: F,
+    std_dev: F,
+    a: F,
+    b: F,
+    uniform: Uniform<f64>,
+}
+
+impl<F> TruncatedNormal<F>
+where
+    F: Float + FloatConst,
+{
+    /// Construct, from mean, standard deviation and bounds
+    ///
+    /// Parameters:
+    ///
+    /// - mean (`μ`, unrestricted)
+    /// - standard deviation (`σ`, must be finite)
+    /// - left bound (`a`) and right bound (`b`, must be greater than `a`), applied to
+    ///   the standard normal before scaling by `σ` and shifting by `μ`, as in `scipy.stats.truncnorm`
+    #[inline]
+    pub fn new(mean: F, std_dev: F, a: F, b: F) -> Result<TruncatedNormal<F>, Error> {
+        if !std_dev.is_finite() {
+            return Err(Error::BadVariance);
+        }
+        if a >= b {
+            return Err(Error::InvalidInterval);
+        }
+        Ok(TruncatedNormal {
+            mean,
+            std_dev,
+            a,
+            b,
+            uniform: Uniform::new(0., 1.).unwrap(),
+        })
+    }
+
+    /// Returns the mean (`μ`) of the distribution.
+    pub fn mean(&self) -> F {
+        self.mean
+    }
+
+    /// Returns the standard deviation (`σ`) of the distribution.
+    pub fn std_dev(&self) -> F {
+        self.std_dev
+    }
+
+    /// Cumulative distribution function of the standard normal distribution
+    pub fn cdf(&self, x: F) -> F {
+        cast(ndtr(cast(x).unwrap())).unwrap()
+    }
+
+    /// Inverse cumulative distribution function of the standard normal distribution
+    pub fn icdf(&self, x: F) -> F {
+        cast(ndtri(cast(x).unwrap())).unwrap()
+    }
+
+    /// Percent point function (inverse CDF) of the truncated standard normal,
+    /// based on `scipy.stats.truncnorm` with modifications
+    pub fn ppf(&self, q: F) -> F {
+        // logsumexp trick for log(p + q) with only log(p) and log(q)
+        let log_sum_exp = |log_p: F, log_q: F| -> F {
+            let max = log_p.max(log_q);
+            ((log_p - max).exp() + (log_q - max).exp()).ln() + max
+        };
+
+        // Log diff for log(p - q), ensuring that the difference is not negative
+        let log_diff_exp = |log_p: F, log_q: F| -> F {
+            let max = log_p.max(log_q);
+            ((log_p - max).exp() - (log_q - max).exp()).abs().ln() + max
+        };
+
+        // Log of the Gaussian probability mass within an interval
+        let log_gauss_mass = |a: F, b: F| -> F {
+            if b <= F::zero() {
+                log_diff_exp(self.cdf(b).ln(), self.cdf(a).ln())
+            } else if a > F::zero() {
+                // Calculations in the right tail are inaccurate, so we exploit the
+                // symmetry and work only in the left tail
+                log_diff_exp(self.cdf(-b).ln(), self.cdf(-a).ln())
+            } else {
+                // Catastrophic cancellation occurs as exp(log_mass) approaches 1.
+                // Correct for this with an alternative formulation.
+                // We're not concerned with underflow here: if only one term
+                // underflows, it was insignificant; if both terms underflow,
+                // the result can't accurately be represented in logspace anyway
+                // because log1p(x) ~ x for small x.
+                (-self.cdf(a) - self.cdf(-b)).ln_1p()
+            }
+        };
+
+        if self.a < F::zero() {
+            let log_phi_x = log_sum_exp(
+                self.cdf(self.a).ln(),
+                q.ln() + log_gauss_mass(self.a, self.b),
+            );
+            self.icdf(log_phi_x.exp())
+        } else {
+            let log_phi_x = log_sum_exp(
+                self.cdf(-self.b).ln(),
+                (-q).ln_1p() + log_gauss_mass(self.a, self.b),
+            );
+            -self.icdf(log_phi_x.exp())
+        }
+    }
+}
+
+impl<F> Distribution<F> for TruncatedNormal<F>
+where
+    F: Float + FloatConst,
+{
+    #[inline]
+    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
+        self.mean + self.std_dev * self.ppf(cast(self.uniform.sample(rng)).unwrap())
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -429,4 +573,50 @@
     fn log_normal_distributions_can_be_compared() {
         assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0));
     }
+
+    #[test]
+    fn test_truncated_normal_pos() {
+        let truncnorm = TruncatedNormal::new(0., 1., 1., 2.).unwrap();
+        let mut rng = crate::test::rng(212);
+        let mut sum = 0.;
+        for _ in 0..1000 {
+            sum += truncnorm.sample(&mut rng);
+        }
+        // The mean of a standard normal truncated to [1, 2] is
+        //
+        //     (φ(1) - φ(2)) / (Φ(2) - Φ(1)) ≈ 1.3832
+        //
+        // (see https://en.wikipedia.org/wiki/Truncated_normal_distribution),
+        // so the sum of 1000 samples should be within a few percent of 1383.2.
+        assert_almost_eq!(sum, 1383.2, 1383.2 * 0.03);
+    }
+
+    #[test]
+    fn test_truncated_normal_neg() {
+        let truncnorm = TruncatedNormal::new(0., 1., -2., -1.).unwrap();
+        let mut rng = crate::test::rng(212);
+        let mut sum = 0.;
+        for _ in 0..1000 {
+            sum += truncnorm.sample(&mut rng);
+        }
+        // Mirror of `test_truncated_normal_pos`: the truncated mean is ≈ -1.3832.
+        assert_almost_eq!(sum, -1383.2, 1383.2 * 0.03);
+    }
+
+    #[test]
+    fn test_truncated_normal_across() {
+        let truncnorm = TruncatedNormal::new(0., 1., -1., 1.).unwrap();
+        let mut rng = crate::test::rng(212);
+        let mut sum = 0.;
+        for _ in 0..1000 {
+            sum += truncnorm.sample(&mut rng);
+        }
+        // Symmetric case: the truncated mean is zero, so the sum should be close to zero.
+        assert_almost_eq!(sum, 0., 1383.2 * 0.03);
+    }
+
+    #[test]
+    fn test_truncated_normal_invalid_bounds() {
+        assert!(TruncatedNormal::new(0., 1., 2., 1.).is_err());
+    }
 }
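
For reference (not part of the patch): `sample` is plain inverse-transform sampling. It draws `u ~ U(0, 1)`, maps it through the percent point function of the standard normal truncated to `[a, b]`, namely `Φ⁻¹(Φ(a) + u · (Φ(b) − Φ(a)))`, and then scales by `σ` and shifts by `μ`. The sketch below writes that formula directly for `f64`, reusing the `spec_math::cephes64` functions the patch imports; `naive_truncnorm_ppf` is a hypothetical helper, not an API added by the patch, and the patch's `ppf` is the log-space, tail-safe version of the same formula.

```rust
// Sketch only: direct (non-log-space) PPF of N(0, 1) truncated to [a, b].
use spec_math::cephes64::{ndtr, ndtri};

fn naive_truncnorm_ppf(q: f64, a: f64, b: f64) -> f64 {
    // Probability mass of the standard normal on [a, b], then invert the rescaled CDF.
    let mass = ndtr(b) - ndtr(a);
    ndtri(ndtr(a) + q * mass)
}

fn main() {
    // Median of N(0, 1) truncated to [-1, 2]; roughly 0.17.
    println!("{}", naive_truncnorm_ppf(0.5, -1.0, 2.0));
}
```

For bounds near the bulk of the distribution the direct formula and the patch's `ppf` agree closely; the difference only matters far out in the tails, which the next sketch illustrates.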
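Why `log_gauss_mass` works in log space and reflects to the left tail: for an interval far in the right tail, both CDF values round to 1.0 in `f64`, so their naive difference cancels to zero, while the left-tail, log-space form keeps the tiny mass representable. A standalone `f64` sketch of the effect, using the same `log_diff_exp` formula as `ppf` (specialized to `f64`) and the `ndtr` function the patch imports:

```rust
// Sketch only: catastrophic cancellation in the right tail vs. the log-space form.
use spec_math::cephes64::ndtr;

// log(|exp(log_p) - exp(log_q)|), the same formula as the closure in `ppf`.
fn log_diff_exp(log_p: f64, log_q: f64) -> f64 {
    let max = log_p.max(log_q);
    ((log_p - max).exp() - (log_q - max).exp()).abs().ln() + max
}

fn main() {
    // Log-mass of N(0, 1) on [10, 11].
    // Naive: ndtr(10.0) and ndtr(11.0) both round to 1.0 in f64, so the
    // difference is 0.0 and its log is -inf.
    let naive = (ndtr(11.0) - ndtr(10.0)).ln();
    // Stable: Phi(x) = 1 - Phi(-x), so work with the small left-tail values in log space.
    let stable = log_diff_exp(ndtr(-10.0).ln(), ndtr(-11.0).ln());
    println!("naive: {naive}, stable: {stable}"); // naive: -inf, stable: about -53.2
}
```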
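Finally, a hypothetical downstream usage check, assuming the patch is applied and `TruncatedNormal` is re-exported as in the `lib.rs` hunk: the sample mean over many draws should approach the closed-form truncated-normal mean, which is about 0.23 for `N(0, 1)` truncated to `[-1, 2]`.

```rust
// Hypothetical usage, assuming the patch above is applied to rand_distr.
use rand_distr::{Distribution, TruncatedNormal};

fn main() {
    // N(0, 1) truncated to [-1, 2]; closed-form mean (φ(-1) - φ(2)) / (Φ(2) - Φ(-1)) ≈ 0.2296.
    let truncnorm = TruncatedNormal::new(0.0, 1.0, -1.0, 2.0).unwrap();
    let mut rng = rand::thread_rng();
    let n = 100_000;
    let mean = (0..n).map(|_| truncnorm.sample(&mut rng)).sum::<f64>() / n as f64;
    println!("sample mean = {mean:.3}"); // expect roughly 0.23
}
```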