Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch to std::simd, expand SIMD & docs #1239

Merged
merged 5 commits into from
Aug 10, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ alloc = ["rand_core/alloc"]
# Option: use getrandom package for seeding
getrandom = ["rand_core/getrandom"]

# Option (requires nightly): experimental SIMD support
simd_support = ["packed_simd"]
# Option (requires nightly Rust): experimental SIMD support
simd_support = []

# Option (enabled by default): enable StdRng
std_rng = ["rand_chacha"]
Expand All @@ -68,13 +68,6 @@ log = { version = "0.4.4", optional = true }
serde = { version = "1.0.103", features = ["derive"], optional = true }
rand_chacha = { path = "rand_chacha", version = "0.3.0", default-features = false, optional = true }

[dependencies.packed_simd]
# NOTE: so far no version works reliably due to dependence on unstable features
package = "packed_simd_2"
version = "0.3.7"
optional = true
features = ["into_bits"]

[target.'cfg(unix)'.dependencies]
# Used for fork protection (reseeding.rs)
libc = { version = "0.2.22", optional = true, default-features = false }
Expand Down
5 changes: 3 additions & 2 deletions src/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use core::{fmt, u64};

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
Expand Down Expand Up @@ -147,10 +148,10 @@ mod test {
use crate::Rng;

#[test]
#[cfg(feature="serde1")]
#[cfg(feature = "serde1")]
fn test_serializing_deserializing_bernoulli() {
let coin_flip = Bernoulli::new(0.5).unwrap();
let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();

assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
}
Expand Down
74 changes: 39 additions & 35 deletions src/distributions/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

//! Basic floating-point number distributions

use crate::distributions::utils::FloatSIMDUtils;
use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
use crate::distributions::{Distribution, Standard};
use crate::Rng;
use core::mem;
#[cfg(feature = "simd_support")] use packed_simd::*;
#[cfg(feature = "simd_support")] use core::simd::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
Expand Down Expand Up @@ -99,7 +99,7 @@ macro_rules! float_impls {
// The exponent is encoded using an offset-binary representation
let exponent_bits: $u_scalar =
(($exponent_bias + exponent) as $u_scalar) << $fraction_bits;
$ty::from_bits(self | exponent_bits)
$ty::from_bits(self | $uty::splat(exponent_bits))
}
}

Expand All @@ -108,13 +108,13 @@ macro_rules! float_impls {
// Multiply-based method; 24/53 random bits; [0, 1) interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = value >> (float_size - precision);
scale * $ty::cast_from_int(value)
let value = value >> $uty::splat(float_size - precision);
$ty::splat(scale) * $ty::cast_from_int(value)
}
}

Expand All @@ -123,14 +123,14 @@ macro_rules! float_impls {
// Multiply-based method; 24/53 random bits; (0, 1] interval.
// We use the most significant bits because for simple RNGs
// those are usually more random.
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
let precision = $fraction_bits + 1;
let scale = 1.0 / ((1 as $u_scalar << precision) as $f_scalar);

let value: $uty = rng.gen();
let value = value >> (float_size - precision);
let value = value >> $uty::splat(float_size - precision);
// Add 1 to shift up; will not overflow because of right-shift:
scale * $ty::cast_from_int(value + 1)
$ty::splat(scale) * $ty::cast_from_int(value + $uty::splat(1))
}
}

Expand All @@ -140,11 +140,11 @@ macro_rules! float_impls {
// We use the most significant bits because for simple RNGs
// those are usually more random.
use core::$f_scalar::EPSILON;
let float_size = mem::size_of::<$f_scalar>() as u32 * 8;
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;

let value: $uty = rng.gen();
let fraction = value >> (float_size - $fraction_bits);
fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0)
let fraction = value >> $uty::splat(float_size - $fraction_bits);
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0)
}
}
}
Expand All @@ -169,10 +169,10 @@ float_impls! { f64x4, u64x4, f64, u64, 52, 1023 }
#[cfg(feature = "simd_support")]
float_impls! { f64x8, u64x8, f64, u64, 52, 1023 }


#[cfg(test)]
mod tests {
use super::*;
use crate::distributions::utils::FloatAsSIMD;
use crate::rngs::mock::StepRng;

const EPSILON32: f32 = ::core::f32::EPSILON;
Expand All @@ -182,29 +182,31 @@ mod tests {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
let two = $ty::splat(2.0);

// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);

// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 8 | 1 << (8 + 32), 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));

// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
}
};
}
Expand All @@ -222,29 +224,31 @@ mod tests {
($fnn:ident, $ty:ident, $ZERO:expr, $EPSILON:expr) => {
#[test]
fn $fnn() {
let two = $ty::splat(2.0);

// Standard
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.gen::<$ty>(), $ZERO);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.gen::<$ty>(), $EPSILON / 2.0);
assert_eq!(one.gen::<$ty>(), $EPSILON / two);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.gen::<$ty>(), 1.0 - $EPSILON / 2.0);
assert_eq!(max.gen::<$ty>(), $ty::splat(1.0) - $EPSILON / two);

// OpenClosed01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(OpenClosed01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 11, 0);
assert_eq!(one.sample::<$ty, _>(OpenClosed01), $EPSILON);
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + 1.0);
assert_eq!(max.sample::<$ty, _>(OpenClosed01), $ZERO + $ty::splat(1.0));

// Open01
let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), 0.0 + $EPSILON / 2.0);
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 12, 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / 2.0 * 3.0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), 1.0 - $EPSILON / 2.0);
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
}
};
}
Expand Down Expand Up @@ -296,16 +300,16 @@ mod tests {
// non-SIMD types; we assume this pattern continues across all
// SIMD types.

test_samples(&Standard, f32x2::new(0.0, 0.0), &[
f32x2::new(0.0035963655, 0.7346052),
f32x2::new(0.09778172, 0.20298547),
f32x2::new(0.34296435, 0.81664366),
test_samples(&Standard, f32x2::from([0.0, 0.0]), &[
f32x2::from([0.0035963655, 0.7346052]),
f32x2::from([0.09778172, 0.20298547]),
f32x2::from([0.34296435, 0.81664366]),
]);

test_samples(&Standard, f64x2::new(0.0, 0.0), &[
f64x2::new(0.7346051961657583, 0.20298547462974248),
f64x2::new(0.8166436635290655, 0.7423708925400552),
f64x2::new(0.16387782224016323, 0.9087068770169618),
test_samples(&Standard, f64x2::from([0.0, 0.0]), &[
f64x2::from([0.7346051961657583, 0.20298547462974248]),
f64x2::from([0.8166436635290655, 0.7423708925400552]),
f64x2::from([0.16387782224016323, 0.9087068770169618]),
]);
}
}
Expand Down
Loading