diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index c8fd298171..1e28aaaa79 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -76,8 +76,9 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution -//! - Alternative implementation for weighted index sampling +//! - Alternative implementations for weighted index sampling //! - [`WeightedAliasIndex`] distribution +//! - [`WeightedTreeIndex`] distribution //! - Misc. distributions //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution @@ -133,6 +134,9 @@ pub use rand::distributions::{WeightedError, WeightedIndex}; #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub use weighted_alias::WeightedAliasIndex; +#[cfg(feature = "alloc")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub use weighted_tree::WeightedTreeIndex; pub use num_traits; @@ -186,6 +190,9 @@ mod test { #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod weighted_alias; +#[cfg(feature = "alloc")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub mod weighted_tree; mod binomial; mod cauchy; diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs new file mode 100644 index 0000000000..b308cdb2c0 --- /dev/null +++ b/rand_distr/src/weighted_tree.rs @@ -0,0 +1,384 @@ +// Copyright 2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! This module contains an implementation of a tree structure for sampling random +//! indices with probabilities proportional to a collection of weights. + +use core::ops::SubAssign; + +use super::WeightedError; +use crate::Distribution; +use alloc::vec::Vec; +use rand::distributions::uniform::{SampleBorrow, SampleUniform}; +use rand::distributions::Weight; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedTreeIndex`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which an implementation of +/// [`Weight`] exists. +/// +/// # Key differences +/// +/// The main distinction between [`WeightedTreeIndex`] and [`rand::distributions::WeightedIndex`] +/// lies in the internal representation of weights. In [`WeightedTreeIndex`], +/// weights are structured as a tree, which is optimized for frequent updates of the weights. +/// +/// # Caution: Floating point types +/// +/// When utilizing [`WeightedTreeIndex`] with floating point types (such as f32 or f64), +/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types +/// are susceptible to numerical rounding errors. Since operations on floating point weights are +/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable +/// deviations from the expected behavior. +/// +/// Ideally, use fixed point or integer types whenever possible. +/// +/// # Performance +/// +/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. +/// +/// Time complexity for the operations of a [`WeightedTreeIndex`] are: +/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time. +/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. +/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. +/// +/// # Example +/// +/// ``` +/// use rand_distr::WeightedTreeIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 0]; +/// let mut dist = WeightedTreeIndex::new(&weights).unwrap(); +/// dist.push(1).unwrap(); +/// dist.update(1, 1).unwrap(); +/// let mut rng = thread_rng(); +/// let mut samples = [0; 3]; +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// let i = dist.sample(&mut rng); +/// samples[i] += 1; +/// } +/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::>()); +/// ``` +/// +/// [`WeightedTreeIndex`]: WeightedTreeIndex +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde1", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde1 ", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] +#[derive(Clone, Default, Debug, PartialEq)] +pub struct WeightedTreeIndex< + W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign + Weight, +> { + subtotals: Vec, +} + +impl + Weight> + WeightedTreeIndex +{ + /// Creates a new [`WeightedTreeIndex`] from a slice of weights. + pub fn new(weights: I) -> Result + where + I: IntoIterator, + I::Item: SampleBorrow, + { + let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); + for weight in subtotals.iter() { + if *weight < W::ZERO { + return Err(WeightedError::InvalidWeight); + } + } + let n = subtotals.len(); + for i in (1..n).rev() { + let w = subtotals[i].clone(); + let parent = (i - 1) / 2; + subtotals[parent] + .checked_add_assign(&w) + .map_err(|()| WeightedError::Overflow)?; + } + Ok(Self { subtotals }) + } + + /// Returns `true` if the tree contains no weights. + pub fn is_empty(&self) -> bool { + self.subtotals.is_empty() + } + + /// Returns the number of weights. + pub fn len(&self) -> usize { + self.subtotals.len() + } + + /// Returns `true` if we can sample. + /// + /// This is the case if the total weight of the tree is greater than zero. + pub fn is_valid(&self) -> bool { + if let Some(weight) = self.subtotals.first() { + *weight > W::ZERO + } else { + false + } + } + + /// Gets the weight at an index. + pub fn get(&self, index: usize) -> W { + let left_index = 2 * index + 1; + let right_index = 2 * index + 2; + let mut w = self.subtotals[index].clone(); + w -= self.subtotal(left_index); + w -= self.subtotal(right_index); + w + } + + /// Removes the last weight and returns it, or [`None`] if it is empty. + pub fn pop(&mut self) -> Option { + self.subtotals.pop().map(|weight| { + let mut index = self.len(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= weight.clone(); + } + weight + }) + } + + /// Appends a new weight at the end. + pub fn push(&mut self, weight: W) -> Result<(), WeightedError> { + if weight < W::ZERO { + return Err(WeightedError::InvalidWeight); + } + if let Some(total) = self.subtotals.first() { + let mut total = total.clone(); + if total.checked_add_assign(&weight).is_err() { + return Err(WeightedError::Overflow); + } + } + let mut index = self.len(); + self.subtotals.push(weight.clone()); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index].checked_add_assign(&weight).unwrap(); + } + Ok(()) + } + + /// Updates the weight at an index. + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> { + if weight < W::ZERO { + return Err(WeightedError::InvalidWeight); + } + let old_weight = self.get(index); + if weight > old_weight { + let mut difference = weight; + difference -= old_weight; + if let Some(total) = self.subtotals.first() { + let mut total = total.clone(); + if total.checked_add_assign(&difference).is_err() { + return Err(WeightedError::Overflow); + } + } + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); + } + } else if weight < old_weight { + let mut difference = old_weight; + difference -= weight; + self.subtotals[index] -= difference.clone(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= difference.clone(); + } + } + Ok(()) + } + + fn subtotal(&self, index: usize) -> W { + if index < self.subtotals.len() { + self.subtotals[index].clone() + } else { + W::ZERO + } + } +} + +impl + Weight> + WeightedTreeIndex +{ + /// Samples a randomly selected index from the weighted distribution. + /// + /// Returns an error if there are no elements or all weights are zero. This + /// is unlike [`Distribution::sample`], which panics in those cases. + fn try_sample(&self, rng: &mut R) -> Result { + if self.subtotals.is_empty() { + return Err(WeightedError::NoItem); + } + let total_weight = self.subtotals[0].clone(); + if total_weight == W::ZERO { + return Err(WeightedError::AllWeightsZero); + } + let mut target_weight = rng.gen_range(W::ZERO..total_weight); + let mut index = 0; + loop { + // Maybe descend into the left sub tree. + let left_index = 2 * index + 1; + let left_subtotal = self.subtotal(left_index); + if target_weight < left_subtotal { + index = left_index; + continue; + } + target_weight -= left_subtotal; + + // Maybe descend into the right sub tree. + let right_index = 2 * index + 2; + let right_subtotal = self.subtotal(right_index); + if target_weight < right_subtotal { + index = right_index; + continue; + } + target_weight -= right_subtotal; + + // Otherwise we found the index with the target weight. + break; + } + assert!(target_weight >= W::ZERO); + assert!(target_weight < self.get(index)); + Ok(index) + } +} + +/// Samples a randomly selected index from the weighted distribution. +/// +/// Caution: This method panics if there are no elements or all weights are zero. However, +/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] +/// returns `true`. +impl + Weight> Distribution + for WeightedTreeIndex +{ + fn sample(&self, rng: &mut R) -> usize { + self.try_sample(rng).unwrap() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_no_item_error() { + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!( + tree.try_sample(&mut rng).unwrap_err(), + WeightedError::NoItem + ); + } + + #[test] + fn test_overflow_error() { + assert_eq!( + WeightedTreeIndex::new(&[i32::MAX, 2]), + Err(WeightedError::Overflow) + ); + let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap(); + assert_eq!(tree.push(3), Err(WeightedError::Overflow)); + assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow)); + tree.update(1, 2).unwrap(); + } + + #[test] + fn test_all_weights_zero_error() { + let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + assert_eq!( + tree.try_sample(&mut rng).unwrap_err(), + WeightedError::AllWeightsZero + ); + } + + #[test] + fn test_invalid_weight_error() { + assert_eq!( + WeightedTreeIndex::::new(&[1, -1]).unwrap_err(), + WeightedError::InvalidWeight + ); + let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight); + tree.push(1).unwrap(); + assert_eq!( + tree.update(0, -1).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[test] + fn test_tree_modifications() { + let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap(); + tree.push(3).unwrap(); + tree.push(5).unwrap(); + tree.update(0, 0).unwrap(); + assert_eq!(tree.pop(), Some(5)); + let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap(); + assert_eq!(tree, expected); + } + + #[test] + fn test_sample_counts_match_probabilities() { + let start = 1; + let end = 3; + let samples = 20; + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect(); + let mut tree = WeightedTreeIndex::new(&weights).unwrap(); + let mut total_weight = 0.0; + let mut weights = alloc::vec![0.0; end]; + for i in 0..end { + tree.update(i, i as f64).unwrap(); + weights[i] = i as f64; + total_weight += i as f64; + } + for i in 0..start { + tree.update(i, 0.0).unwrap(); + weights[i] = 0.0; + total_weight -= i as f64; + } + let mut counts = alloc::vec![0_usize; end]; + for _ in 0..samples { + let i = tree.sample(&mut rng); + counts[i] += 1; + } + for i in 0..start { + assert_eq!(counts[i], 0); + } + for i in start..end { + let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; + assert!(diff.abs() < 0.05); + } + } +} diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index de3628b5ea..0b1b4da947 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -18,7 +18,7 @@ use core::fmt; use alloc::vec::Vec; #[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; /// A distribution using weighted sampling of discrete items /// @@ -33,9 +33,12 @@ use serde::{Serialize, Deserialize}; /// # Performance /// /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. As an alternative, -/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) +/// `N` is the number of weights. There are two alternative implementations with +/// different runtimes characteristics: +/// * [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) /// supports `O(1)` sampling, but with much higher initialisation cost. +/// * [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) +/// keeps the weights in a tree structure where sampling and updating is `O(log N)`. /// /// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its /// size is the sum of the size of those objects, possibly plus some alignment. @@ -144,15 +147,21 @@ impl WeightedIndex { /// allocation internally. /// /// In case of error, `self` is not modified. - /// + /// + /// Updates take `O(N)` time. If you need to frequently update weights, consider + /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) + /// as an alternative where an update is `O(log N)`. + /// /// Note: Updating floating-point weights may cause slight inaccuracies in the total weight. /// This method may not return `WeightedError::AllWeightsZero` when all weights - /// are zero if using floating-point weights. + /// are zero if using floating-point weights. pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where X: for<'a> ::core::ops::AddAssign<&'a X> + where + X: for<'a> ::core::ops::AddAssign<&'a X> + for<'a> ::core::ops::SubAssign<&'a X> + Clone - + Default { + + Default, + { if new_weights.is_empty() { return Ok(()); } @@ -230,12 +239,14 @@ impl WeightedIndex { } impl Distribution for WeightedIndex -where X: SampleUniform + PartialOrd +where + X: SampleUniform + PartialOrd, { fn sample(&self, rng: &mut R) -> usize { let chosen_weight = self.weight_distribution.sample(rng); // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights.partition_point(|w| w <= &chosen_weight) + self.cumulative_weights + .partition_point(|w| w <= &chosen_weight) } } @@ -288,7 +299,7 @@ macro_rules! impl_weight_float { Ok(()) } } - } + }; } impl_weight_float!(f32); impl_weight_float!(f64); @@ -314,7 +325,7 @@ mod test { } #[test] - fn test_accepting_nan(){ + fn test_accepting_nan() { assert_eq!( WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), WeightedError::InvalidWeight, @@ -337,7 +348,6 @@ mod test { ) } - #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_weightedindex() { @@ -461,15 +471,21 @@ mod test { } let mut buf = [0; 10]; - test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, - ]); - test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, - ]); - test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, - ]); + test_samples( + &[1i32, 1, 1, 1, 1, 1, 1, 1, 1], + &mut buf, + &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], + ); + test_samples( + &[0.7f32, 0.1, 0.1, 0.1], + &mut buf, + &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], + ); + test_samples( + &[1.0f64, 0.999, 0.998, 0.997], + &mut buf, + &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], + ); } #[test] @@ -479,7 +495,10 @@ mod test { #[test] fn overflow() { - assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(WeightedError::Overflow)); + assert_eq!( + WeightedIndex::new([2, usize::MAX]), + Err(WeightedError::Overflow) + ); } }