Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WeightedIndexTree to rand_distr #1372

Merged
merged 20 commits into from
Feb 8, 2024
53 changes: 35 additions & 18 deletions rand_distr/src/weighted_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,17 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This module contains an implementation of a tree sttructure for sampling random
//! This module contains an implementation of a tree structure for sampling random
//! indices with probabilities proportional to a collection of weights.

use core::ops::SubAssign;

use super::WeightedError;
use crate::Distribution;
use alloc::vec::Vec;
use rand::{
distributions::{
uniform::{SampleBorrow, SampleUniform},
Weight,
},
Rng,
};
use rand::distributions::uniform::{SampleBorrow, SampleUniform};
use rand::distributions::Weight;
use rand::Rng;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

Expand All @@ -29,7 +25,7 @@ use serde::{Deserialize, Serialize};
/// Sampling a [`WeightedTreeIndex<W>`] distribution returns the index of a randomly
/// selected element from the vector used to create the [`WeightedTreeIndex<W>`].
/// The chance of a given element being picked is proportional to the value of
/// the element. The weights can have any type `W` for which a implementation of
/// the element. The weights can have any type `W` for which an implementation of
/// [`Weight`] exists.
///
/// # Key differences
Expand Down Expand Up @@ -71,15 +67,16 @@ use serde::{Deserialize, Serialize};
/// dist.push(1).unwrap();
/// dist.update(1, 1).unwrap();
/// let mut rng = thread_rng();
/// let mut samples = [0; 3];
/// for _ in 0..100 {
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
/// let i = dist.sample(&mut rng).unwrap();
/// println!("{}", choices[i]);
/// let i = dist.sample(&mut rng);
/// samples[i] += 1;
/// }
/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::<Vec<_>>());
/// ```
///
/// [`WeightedTreeIndex<W>`]: WeightedTreeIndex
/// [`Uniform<W>::sample`]: Distribution::sample
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(
Expand Down Expand Up @@ -132,7 +129,7 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeInd
/// Returns `true` if we can sample.
///
/// This is the case if the total weight of the tree is greater than zero.
pub fn can_sample(&self) -> bool {
pub fn is_valid(&self) -> bool {
if let Some(weight) = self.subtotals.first() {
*weight > W::ZERO
} else {
Expand Down Expand Up @@ -229,9 +226,13 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeInd
}

impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
Distribution<Result<usize, WeightedError>> for WeightedTreeIndex<W>
WeightedTreeIndex<W>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
/// Samples a randomly selected index from the weighted distribution.
///
/// Returns an error if there are no elements or all weights are zero. This
/// is unlike [`Distribution::sample`], which panics in those cases.
fn safe_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
xmakro marked this conversation as resolved.
Show resolved Hide resolved
if self.subtotals.is_empty() {
return Err(WeightedError::NoItem);
}
Expand Down Expand Up @@ -269,6 +270,19 @@ impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
}
}

impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight> Distribution<usize>
for WeightedTreeIndex<W>
{
/// Samples a randomly selected index from the weighted distribution.
///
/// Caution: This method panics if there are no elements or all weights are zero. However,
/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`]
/// returns `true`.
xmakro marked this conversation as resolved.
Show resolved Hide resolved
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
self.safe_sample(rng).unwrap()
}
}

#[cfg(test)]
mod test {
use super::*;
Expand All @@ -277,7 +291,10 @@ mod test {
fn test_no_item_error() {
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem);
assert_eq!(
tree.safe_sample(&mut rng).unwrap_err(),
WeightedError::NoItem
);
}

#[test]
Expand All @@ -297,7 +314,7 @@ mod test {
let tree = WeightedTreeIndex::<f64>::new(&[0.0, 0.0]).unwrap();
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
assert_eq!(
tree.sample(&mut rng).unwrap_err(),
tree.safe_sample(&mut rng).unwrap_err(),
WeightedError::AllWeightsZero
);
}
Expand Down Expand Up @@ -350,7 +367,7 @@ mod test {
}
let mut counts = alloc::vec![0_usize; end];
for _ in 0..samples {
let i = tree.sample(&mut rng).unwrap();
let i = tree.sample(&mut rng);
counts[i] += 1;
}
for i in 0..start {
Expand Down
Loading