Skip to content

Commit

Permalink
Add rand_distr::TruncatedNormal
Browse files Browse the repository at this point in the history
  • Loading branch information
ongchi committed Nov 4, 2024
1 parent 585b29f commit d0352bc
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 2 deletions.
1 change: 1 addition & 0 deletions rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ rand = { path = "..", version = "=0.9.0-alpha.1", default-features = false }
num-traits = { version = "0.2", default-features = false, features = ["libm"] }
serde = { version = "1.0.103", features = ["derive"], optional = true }
serde_with = { version = ">= 3.0, <= 3.11", optional = true }
spec_math = "0.1.6"

[dev-dependencies]
rand_pcg = { version = "=0.9.0-alpha.1", path = "../rand_pcg" }
Expand Down
2 changes: 1 addition & 1 deletion rand_distr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric};
pub use self::gumbel::{Error as GumbelError, Gumbel};
pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric};
pub use self::inverse_gaussian::{Error as InverseGaussianError, InverseGaussian};
pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal};
pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal, TruncatedNormal};
pub use self::normal_inverse_gaussian::{
Error as NormalInverseGaussianError, NormalInverseGaussian,
};
Expand Down
192 changes: 191 additions & 1 deletion rand_distr/src/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
use crate::utils::ziggurat;
use crate::{ziggurat_tables, Distribution, Open01};
use core::fmt;
use num_traits::Float;
use num_traits::{cast, Float, FloatConst};
use rand::distr::uniform::Uniform;
use rand::Rng;

use spec_math::cephes64::{ndtr, ndtri};

/// The standard Normal distribution `N(0, 1)`.
///
/// This is equivalent to `Normal::new(0.0, 1.0)`, but faster.
Expand Down Expand Up @@ -160,13 +163,16 @@ pub enum Error {
MeanTooSmall,
/// The standard deviation or other dispersion parameter is not finite.
BadVariance,
/// The left bound is greater than or equal to the right bound.
InvalidInterval,
}

impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::MeanTooSmall => "mean < 0 or NaN in log-normal distribution",
Error::BadVariance => "variation parameter is non-finite in (log)normal distribution",
Error::InvalidInterval => "the left bound must be less than right bound",
})
}
}
Expand Down Expand Up @@ -363,6 +369,144 @@ where
}
}

/// The [Truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution)
///
/// # Example
///
/// ```
/// use rand_distr::{TruncatedNormal, Distribution};
///
/// let truncnorm = TruncatedNormal::new(0., 1., -1.0, 2.0).unwrap();
/// let val = truncnorm.sample(&mut rand::thread_rng());
/// println!("{}", val);
/// ```
///
/// # Notes
///
/// This implementation is ported from [`scipy.stats.truncnorm`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html),
/// which is based on [Cephes Mathematical Library](https://www.netlib.org/cephes/).
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TruncatedNormal<F>
where
F: Float + FloatConst,
{
mean: F,
std_dev: F,
a: F,
b: F,
uniform: Uniform<f64>,
}

impl<F> TruncatedNormal<F>
where
F: Float + FloatConst,
{
/// Construct, from mean and standard deviation
///
/// Parameters:
///
/// - mean (`μ`, unrestricted)
/// - standard deviation (`σ`, must be finite)
/// - left bound (`a`)
/// - right bound (`b`)
#[inline]
pub fn new(mean: F, std_dev: F, a: F, b: F) -> Result<TruncatedNormal<F>, Error> {
if !std_dev.is_finite() {
return Err(Error::BadVariance);
}
if a >= b {
return Err(Error::InvalidInterval);
}
Ok(TruncatedNormal {
mean,
std_dev,
a,
b,
uniform: Uniform::new(0., 1.).unwrap(),
})
}

/// Returns the mean (`μ`) of the distribution.
pub fn mean(&self) -> F {
self.mean
}

/// Returns the standard deviation (`σ`) of the distribution.
pub fn std_dev(&self) -> F {
self.std_dev
}

/// Cumulative Distribution Function
pub fn cdf(&self, x: F) -> F {
cast(ndtr(cast(x).unwrap())).unwrap()
}

/// Inverse Cumulative Distribution Function
pub fn icdf(&self, x: F) -> F {
cast(ndtri(cast(x).unwrap())).unwrap()
}

/// Percent Point Function
/// based on `scipy.stats.truncnorm` with modifications
pub fn ppf(&self, q: F) -> F {
// logsumexp trick for log(p + q) with only log(p) and log(q)
let log_sum_exp = |log_p: F, log_q: F| -> F {
let max = log_p.max(log_q);
((log_p - max).exp() + (log_q - max).exp()).ln() + max
};

// Log diff for log(p - q) and insuring that the difference is not negative
let log_diff_exp = |log_p: F, log_q: F| -> F {
let max = log_p.max(log_q);
((log_p - max).exp() - (log_q - max).exp()).abs().ln() + max
};

// Log of Gaussian probability mass within an interval
let log_gauss_mass = |a: F, b: F| -> F {
if b <= F::zero() {
log_diff_exp(self.cdf(b).ln(), self.cdf(a).ln())
} else if a > F::zero() {
// Calculations in right tail are inaccurate, so we'll exploit the
// symmetry and work only in the left tail
log_diff_exp(self.cdf(-b).ln(), self.cdf(-a).ln())
} else {
// Catastrophic cancellation occurs as exp(log_mass) approaches 1.
// Correct for this with an alternative formulation.
// We're not concerned with underflow here: if only one term
// underflows, it was insignificant; if both terms underflow,
// the result can't accurately be represented in logspace anyway
// because log1p(x) ~ x for small x.
(-self.cdf(a) - self.cdf(-b)).ln_1p()
}
};

if self.a < F::zero() {
let log_phi_x = log_sum_exp(
self.cdf(self.a).ln(),
q.ln() + log_gauss_mass(self.a, self.b),
);
self.icdf(log_phi_x.exp())
} else {
let log_phi_x = log_sum_exp(
self.cdf(-self.b).ln(),
(-q).ln_1p() + log_gauss_mass(self.a, self.b),
);
-self.icdf(log_phi_x.exp())
}
}
}

impl<F> Distribution<F> for TruncatedNormal<F>
where
F: Float + FloatConst,
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
self.mean + self.std_dev * self.ppf(cast(self.uniform.sample(rng)).unwrap())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -429,4 +573,50 @@ mod tests {
fn log_normal_distributions_can_be_compared() {
assert_eq!(LogNormal::new(1.0, 2.0), LogNormal::new(1.0, 2.0));
}

#[test]
fn test_truncated_normal_pos() {
let truncnorm = TruncatedNormal::new(0., 1., 1., 2.).unwrap();
let mut rng = crate::test::rng(212);
let mut integral = 0.;
for _ in 0..1000 {
integral += truncnorm.sample(&mut rng);
}
// According to the result from:
// https://www.wolframalpha.com/input?i=integral+e%5E%28-%28x%5E2%29%2F2%29%2F%28sqrt%282+%CF%80%29%29+from+1+to+2
//
// integral e^(-(x^2)/2)/(sqrt(2 π)) from 1 to 2 ≈ 0.135905
//
// The error of the integral result by 1000 samples from TruncatedNormal is below 3%
assert_almost_eq!(integral, 1359.05, 1359.05 * 0.03);
}

#[test]
fn test_truncated_normal_neg() {
let truncnorm = TruncatedNormal::new(0., 1., -2., -1.).unwrap();
let mut rng = crate::test::rng(212);
let mut integral = 0.;
for _ in 0..1000 {
integral += truncnorm.sample(&mut rng);
}
// Mirror case of the `test_truncated_normal_pos`
assert_almost_eq!(integral, -1359.05, 1359.05 * 0.03);
}

#[test]
fn test_truncated_normal_across() {
let truncnorm = TruncatedNormal::new(0., 1., -1., 1.).unwrap();
let mut rng = crate::test::rng(212);
let mut integral = 0.;
for _ in 0..1000 {
integral += truncnorm.sample(&mut rng);
}
// Symmetry case, the sum of result is almost equal to zero.
assert_almost_eq!(integral, 0., 1359.05 * 0.03);
}

#[test]
fn test_truncated_normal_invalid_bounds() {
assert!(TruncatedNormal::new(0., 1., 2., 1.).is_err());
}
}

0 comments on commit d0352bc

Please sign in to comment.