Add WeightedIndexTree to rand_distr #1372

Merged
merged 20 commits into from
Feb 8, 2024
9 changes: 8 additions & 1 deletion rand_distr/src/lib.rs
@@ -76,8 +76,9 @@
//! - [`UnitBall`] distribution
//! - [`UnitCircle`] distribution
//! - [`UnitDisc`] distribution
-//! - Alternative implementation for weighted index sampling
+//! - Alternative implementations for weighted index sampling
//! - [`WeightedAliasIndex`] distribution
//! - [`WeightedTreeIndex`] distribution
//! - Misc. distributions
//! - [`InverseGaussian`] distribution
//! - [`NormalInverseGaussian`] distribution
@@ -133,6 +134,9 @@ pub use rand::distributions::{WeightedError, WeightedIndex};
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use weighted_alias::WeightedAliasIndex;
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use weighted_tree::WeightedTreeIndex;

pub use num_traits;

@@ -186,6 +190,9 @@ mod test {
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub mod weighted_alias;
#[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub mod weighted_tree;

mod binomial;
mod cauchy;
364 changes: 364 additions & 0 deletions rand_distr/src/weighted_tree.rs
@@ -0,0 +1,364 @@
// Copyright 2024 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This module contains an implementation of a tree structure for sampling random
//! indices with probabilities proportional to a collection of weights.

use core::ops::SubAssign;

use super::WeightedError;
use crate::Distribution;
use alloc::vec::Vec;
use rand::{
    distributions::{
        uniform::{SampleBorrow, SampleUniform},
        Weight,
    },
    Rng,
};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

/// A distribution using weighted sampling to pick a discretely selected item.
///
/// Sampling a [`WeightedTreeIndex<W>`] distribution returns the index of a randomly
/// selected element from the vector used to create the [`WeightedTreeIndex<W>`].
/// The chance of a given element being picked is proportional to the value of
/// the element. The weights can have any type `W` for which an implementation of
/// [`Weight`] exists.
///
/// # Key differences
///
/// The main distinction between [`WeightedTreeIndex<W>`] and [`rand::distributions::WeightedIndex<W>`]
/// lies in the internal representation of weights. In [`WeightedTreeIndex<W>`],
/// weights are structured as a tree, which is optimized for frequent updates of the weights.
///
/// # Caution: Floating point types
///
/// When utilizing [`WeightedTreeIndex<W>`] with floating point types (such as f32 or f64),
/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types
/// are susceptible to numerical rounding errors. Since operations on floating point weights are
/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable
/// deviations from the expected behavior.
///
/// Ideally, use fixed point or integer types whenever possible.
///
/// # Performance
///
/// A [`WeightedTreeIndex<W>`] with `n` elements requires `O(n)` memory.
///
/// The time complexities of the operations on a [`WeightedTreeIndex<W>`] are:
/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time.
/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time.
/// * Weight Update: Modifying a weight (traversing up the tree) requires `O(log n)` time.
/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree) requires `O(log n)` time.
/// * Weight Removal (Popping): Removing a weight (traversing up the tree) requires `O(log n)` time.
///
/// # Example
///
/// ```
/// use rand_distr::WeightedTreeIndex;
/// use rand::prelude::*;
///
/// let choices = vec!['a', 'b', 'c'];
/// let weights = vec![2, 0];
/// let mut dist = WeightedTreeIndex::new(&weights).unwrap();
/// dist.push(1).unwrap();
/// dist.update(1, 1).unwrap();
/// let mut rng = thread_rng();
/// for _ in 0..100 {
///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
///     let i = dist.sample(&mut rng).unwrap();
///     println!("{}", choices[i]);
/// }
/// ```
///
/// [`WeightedTreeIndex<W>`]: WeightedTreeIndex
/// [`Uniform<W>::sample`]: Distribution::sample
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "serde1",
    serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
)]
#[cfg_attr(
    feature = "serde1",
    serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
)]
#[derive(Clone, Default, Debug, PartialEq)]
pub struct WeightedTreeIndex<W: SampleUniform> {
    subtotals: Vec<W>,
}

impl<W: Clone + PartialEq + PartialOrd + SampleUniform + Weight> WeightedTreeIndex<W> {
    /// Creates a new [`WeightedTreeIndex`] from an iterator of weights.
    pub fn new<I>(weights: I) -> Result<Self, WeightedError>
    where
        I: IntoIterator,
        I::Item: SampleBorrow<W>,
    {
        let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
        for weight in subtotals.iter() {
            if *weight < W::ZERO {
                return Err(WeightedError::InvalidWeight);
            }
        }
        let n = subtotals.len();
        // Fold each node's subtotal into its parent, from the last node up to the root.
        for i in (1..n).rev() {
            let w = subtotals[i].clone();
            let parent = (i - 1) / 2;
            subtotals[parent]
                .checked_add_assign(&w)
                .map_err(|()| WeightedError::Overflow)?;
        }
        Ok(Self { subtotals })
    }

    /// Returns `true` if the tree contains no weights.
    pub fn is_empty(&self) -> bool {
        self.subtotals.is_empty()
    }

    /// Returns the number of weights.
    pub fn len(&self) -> usize {
        self.subtotals.len()
    }

    /// Returns `true` if we can sample.
    ///
    /// This is the case if the total weight of the tree is greater than zero.
    pub fn can_sample(&self) -> bool {
        if let Some(weight) = self.subtotals.first() {
            *weight > W::ZERO
        } else {
            false
        }
    }

    /// Gets the weight at an index.
    pub fn get(&self, index: usize) -> W
    where
        W: for<'a> SubAssign<&'a W>,
    {
        // A node's own weight is its subtotal minus the subtotals of both children.
        let left_index = 2 * index + 1;
        let right_index = 2 * index + 2;
        let mut w = self.subtotals[index].clone();
        w -= &self.subtotal(left_index);
        w -= &self.subtotal(right_index);
        w
    }

    /// Removes the last weight and returns it, or [`None`] if it is empty.
    pub fn pop(&mut self) -> Option<W>
    where
        W: for<'a> SubAssign<&'a W>,
    {
        self.subtotals.pop().map(|weight| {
            // The removed node was a leaf, so its subtotal equals its weight.
            // Subtract it from every ancestor up to the root.
            let mut index = self.len();
            while index != 0 {
                index = (index - 1) / 2;
                self.subtotals[index] -= &weight;
            }
            weight
        })
    }

    /// Appends a new weight at the end.
    pub fn push(&mut self, weight: W) -> Result<(), WeightedError> {
        if weight < W::ZERO {
            return Err(WeightedError::InvalidWeight);
        }
        // Check for overflow on a copy of the total before modifying the tree.
        if let Some(total) = self.subtotals.first() {
            let mut total = total.clone();
            if total.checked_add_assign(&weight).is_err() {
                return Err(WeightedError::Overflow);
            }
        }
        let mut index = self.len();
        self.subtotals.push(weight.clone());
        while index != 0 {
            index = (index - 1) / 2;
            self.subtotals[index].checked_add_assign(&weight).unwrap();
        }
        Ok(())
    }

    /// Updates the weight at an index.
    pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError>
    where
        W: for<'a> SubAssign<&'a W>,
    {
        if weight < W::ZERO {
            return Err(WeightedError::InvalidWeight);
        }
        let mut difference = weight;
        difference -= &self.get(index);
        if difference == W::ZERO {
            return Ok(());
        }
        // Check for overflow on a copy of the total before modifying the tree.
        if let Some(total) = self.subtotals.first() {
            let mut total = total.clone();
            if total.checked_add_assign(&difference).is_err() {
                return Err(WeightedError::Overflow);
            }
        }
        self.subtotals[index]
            .checked_add_assign(&difference)
            .unwrap();
        while index != 0 {
            index = (index - 1) / 2;
            self.subtotals[index]
                .checked_add_assign(&difference)
                .unwrap();
        }
        Ok(())
    }

    fn subtotal(&self, index: usize) -> W {
        if index < self.subtotals.len() {
            self.subtotals[index].clone()
        } else {
            W::ZERO
        }
    }
}

impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
    Distribution<Result<usize, WeightedError>> for WeightedTreeIndex<W>
Member

We don't usually implement Distribution<Result<..>>. Do we need to?

I think I would prefer to panic on error, but guarantee no panic if self.is_valid() (self.can_sample()).

@vks?

Contributor Author

I removed the Result<>, so sample now panics instead. I added the info about is_valid to the doc string and introduced safe_sample as an alternative that does not panic.
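
A minimal sketch of the contract discussed in this thread, not the code that was merged: one entry point returns `None` when nothing can be sampled, and a panicking wrapper is safe to call whenever `can_sample()` is true. `Tree`, `try_pick`, and `pick` are hypothetical names, weights are fixed to `u32`, and the `target` argument stands in for the value that `gen_range` would draw.

```rust
/// Hypothetical stand-in for the tree: only the subtotal array matters here.
struct Tree {
    subtotals: Vec<u32>,
}

impl Tree {
    /// True when the total weight (stored at the root) is positive.
    fn can_sample(&self) -> bool {
        self.subtotals.first().map_or(false, |total| *total > 0)
    }

    /// Non-panicking variant: `None` if the tree cannot be sampled or
    /// `target` is not below the total weight.
    fn try_pick(&self, mut target: u32) -> Option<usize> {
        if !self.can_sample() || target >= self.subtotals[0] {
            return None;
        }
        // Subtotal of a node, or 0 for indices past the end.
        let sub = |i: usize| self.subtotals.get(i).copied().unwrap_or(0);
        let mut index = 0;
        loop {
            let (left, right) = (2 * index + 1, 2 * index + 2);
            if target < sub(left) {
                index = left;
            } else if target - sub(left) < sub(right) {
                target -= sub(left);
                index = right;
            } else {
                // The target falls into this node's own weight.
                return Some(index);
            }
        }
    }

    /// Panicking variant: does not panic when `can_sample()` is true and
    /// `target` is below the total weight.
    fn pick(&self, target: u32) -> usize {
        self.try_pick(target)
            .expect("cannot sample: tree is empty, all weights are zero, or target is out of range")
    }
}

fn main() {
    // Weights [2, 1, 1] correspond to subtotals [4, 1, 1].
    let tree = Tree { subtotals: vec![4, 1, 1] };
    assert!(tree.can_sample());
    assert_eq!(tree.pick(0), 1);

    let empty = Tree { subtotals: vec![] };
    assert!(!empty.can_sample());
    assert_eq!(empty.try_pick(0), None);
}
```

Under this shape, callers that update weights in arbitrary order can pass through an all-zero state and only need to check `can_sample()` before drawing.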

{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, WeightedError> {
        if self.subtotals.is_empty() {
            return Err(WeightedError::NoItem);
        }
        let total_weight = self.subtotals[0].clone();
        if total_weight == W::ZERO {
            return Err(WeightedError::AllWeightsZero);
        }
        let mut target_weight = rng.gen_range(W::ZERO..total_weight);
        let mut index = 0;
        loop {
            // Maybe descend into the left sub tree.
            let left_index = 2 * index + 1;
            let left_subtotal = self.subtotal(left_index);
            if target_weight < left_subtotal {
                index = left_index;
                continue;
            }
            target_weight -= left_subtotal;

            // Maybe descend into the right sub tree.
            let right_index = 2 * index + 2;
            let right_subtotal = self.subtotal(right_index);
            if target_weight < right_subtotal {
                index = right_index;
                continue;
            }
            target_weight -= right_subtotal;

            // Otherwise we found the index with the target weight.
            break;
        }
        assert!(target_weight >= W::ZERO);
Member

If I understand correctly, this must be strictly > 0?

Contributor Author

@xmakro xmakro Feb 8, 2024

As an intermediate state, this can be zero if all weights are zero or there are no elements. is_valid would then return false. Allowing this is useful, so that the user does not have to apply the updates in the right order to avoid intermediate zero states.

Member

In this case the function will already have returned an error.

Contributor Author

Sorry I mixed up what this comment was referring to.

I think target_weight can be 0 at this line. For example, if the tree has only one node and gen_range samples 0, the loop body is a no-op and we then reach this line with target_weight = 0.

However, I realized the line right after was incorrect; it should be: assert!(target_weight < self.get(index)). It should be correct now.

Member

I agree
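
The single-node case from this thread can be traced with a fixed target instead of a random draw. The following is a sketch under assumptions rather than the merged code: `descend`, `subtotal`, and `weight` are hypothetical free functions over a plain `u32` subtotal array, and `target` stands in for the result of `gen_range`.

```rust
/// Subtotal stored at `index`, or 0 when the index is past the end.
fn subtotal(subtotals: &[u32], index: usize) -> u32 {
    subtotals.get(index).copied().unwrap_or(0)
}

/// Own weight of node `index`: its subtotal minus both child subtotals.
fn weight(subtotals: &[u32], index: usize) -> u32 {
    subtotals[index] - subtotal(subtotals, 2 * index + 1) - subtotal(subtotals, 2 * index + 2)
}

/// Descends from the root for a fixed `target` and returns the chosen index
/// together with the remaining target. Assumes `target` is below the total.
fn descend(subtotals: &[u32], mut target: u32) -> (usize, u32) {
    let mut index = 0;
    loop {
        let left = 2 * index + 1;
        let left_subtotal = subtotal(subtotals, left);
        if target < left_subtotal {
            index = left;
            continue;
        }
        target -= left_subtotal;

        let right = 2 * index + 2;
        let right_subtotal = subtotal(subtotals, right);
        if target < right_subtotal {
            index = right;
            continue;
        }
        target -= right_subtotal;

        return (index, target);
    }
}

fn main() {
    // Single node with weight 5: the loop body changes nothing and the
    // remaining target is 0, so only `>= 0` can hold, not `> 0`.
    let single: [u32; 1] = [5];
    let (index, rest) = descend(&single, 0);
    assert_eq!((index, rest), (0, 0));
    assert!(rest < weight(&single, index));

    // Weights [2, 1, 1] give subtotals [4, 1, 1]; the remaining target is
    // always below the chosen node's own weight.
    let tree: [u32; 3] = [4, 1, 1];
    for target in 0..4 {
        let (index, rest) = descend(&tree, target);
        assert!(rest < weight(&tree, index));
    }
}
```

Comparing against the node's own weight is the tighter of the two checks, since a node's subtotal is never smaller than its own weight.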

        assert!(target_weight < self.subtotal(index));
        Ok(index)
    }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_no_item_error() {
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
        assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem);
    }

    #[test]
    fn test_overflow_error() {
        assert_eq!(
            WeightedTreeIndex::new(&[i32::MAX, 2]),
            Err(WeightedError::Overflow)
        );
        let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap();
        assert_eq!(tree.push(3), Err(WeightedError::Overflow));
        assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow));
        tree.update(1, 2).unwrap();
    }

    #[test]
    fn test_all_weights_zero_error() {
        let tree = WeightedTreeIndex::<f64>::new(&[0.0, 0.0]).unwrap();
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        assert_eq!(
            tree.sample(&mut rng).unwrap_err(),
            WeightedError::AllWeightsZero
        );
    }

    #[test]
    fn test_invalid_weight_error() {
        assert_eq!(
            WeightedTreeIndex::<i32>::new(&[1, -1]).unwrap_err(),
            WeightedError::InvalidWeight
        );
        let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
        assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight);
        tree.push(1).unwrap();
        assert_eq!(
            tree.update(0, -1).unwrap_err(),
            WeightedError::InvalidWeight
        );
    }

    #[test]
    fn test_tree_modifications() {
        let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap();
        tree.push(3).unwrap();
        tree.push(5).unwrap();
        tree.update(0, 0).unwrap();
        assert_eq!(tree.pop(), Some(5));
        let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap();
        assert_eq!(tree, expected);
    }

    #[test]
    fn test_sample_counts_match_probabilities() {
        let start = 1;
        let end = 3;
        let samples = 20;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect();
        let mut tree = WeightedTreeIndex::new(&weights).unwrap();
        let mut total_weight = 0.0;
        let mut weights = alloc::vec![0.0; end];
        for i in 0..end {
            tree.update(i, i as f64).unwrap();
            weights[i] = i as f64;
            total_weight += i as f64;
        }
        for i in 0..start {
            tree.update(i, 0.0).unwrap();
            weights[i] = 0.0;
            total_weight -= i as f64;
        }
        let mut counts = alloc::vec![0_usize; end];
        for _ in 0..samples {
            let i = tree.sample(&mut rng).unwrap();
            counts[i] += 1;
        }
        for i in 0..start {
            assert_eq!(counts[i], 0);
        }
        for i in start..end {
            let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight;
            assert!(diff.abs() < 0.05);
        }
    }
}
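
The subtotal layout used throughout `weighted_tree.rs` can be illustrated without `rand`. This is a standalone sketch, not code from the PR: `build`, `get`, and `update` are hypothetical free functions over `u64` weights, and overflow checks are omitted for brevity. Node `i` has its parent at `(i - 1) / 2`, so construction touches each node once and an update touches one node per tree level.

```rust
/// Builds the subtotal array: each node's subtotal is folded into its
/// parent, walking from the last node back to the root.
fn build(weights: &[u64]) -> Vec<u64> {
    let mut subtotals = weights.to_vec();
    for i in (1..subtotals.len()).rev() {
        let w = subtotals[i];
        subtotals[(i - 1) / 2] += w;
    }
    subtotals
}

/// Recovers a node's own weight from its subtotal and its children's subtotals.
fn get(subtotals: &[u64], index: usize) -> u64 {
    let child = |i: usize| subtotals.get(i).copied().unwrap_or(0);
    subtotals[index] - child(2 * index + 1) - child(2 * index + 2)
}

/// Sets the weight at `index` and fixes every ancestor's subtotal.
/// Returns the number of nodes touched, which is at most the tree height.
fn update(subtotals: &mut [u64], mut index: usize, weight: u64) -> usize {
    let old = get(subtotals, index);
    let mut touched = 0;
    loop {
        // Subtract before adding so unsigned arithmetic cannot underflow.
        subtotals[index] = subtotals[index] - old + weight;
        touched += 1;
        if index == 0 {
            break;
        }
        index = (index - 1) / 2;
    }
    touched
}

fn main() {
    let mut subtotals = build(&[9, 1, 2, 3, 5]);
    assert_eq!(subtotals, vec![20, 9, 2, 3, 5]);
    assert_eq!(get(&subtotals, 1), 1);

    // Zeroing the last weight touches index 4, its parent 1, and the root.
    let touched = update(&mut subtotals, 4, 0);
    assert_eq!(touched, 3);
    assert_eq!(subtotals, build(&[9, 1, 2, 3, 0]));
}
```

The sketch subtracts the old weight and adds the new one at each node, which differs from the signed `difference` used in the diff above; it is only one way to keep the arithmetic valid for unsigned types.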