-
-
Notifications
You must be signed in to change notification settings - Fork 490
Full update of weighted index by assigning weights #1194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ use crate::distributions::Distribution; | |
| use crate::Rng; | ||
| use core::cmp::PartialOrd; | ||
| use core::fmt; | ||
| use core::iter::ExactSizeIterator; | ||
|
|
||
| // Note that this whole module is only imported if feature="alloc" is enabled. | ||
| use alloc::vec::Vec; | ||
|
|
@@ -130,6 +131,119 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> { | |
| }) | ||
| } | ||
|
|
||
| /// Reuses the weighted index by assigning a new set of weights without changing the number of | ||
| /// weights. | ||
| /// | ||
| /// Returns an error if: | ||
| /// | ||
| /// + the number of items in the iterator does not match the number of items used to create the | ||
| /// distribution; | ||
| /// + the iterator yields invalid values (such as `f64::NAN`); | ||
| /// + the weights yielded by the iterator sum up to zero. | ||
| /// | ||
| /// NOTE: If this method fails the distribution should no longer be used for sampling, because | ||
| /// results of sampling from it are undefined. | ||
| pub fn assign_new_weights<I>(&mut self, weights: I) -> Result<(), WeightedError > | ||
vks marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| where | ||
| I: IntoIterator, | ||
| I::IntoIter: ExactSizeIterator, | ||
| I::Item: SampleBorrow<X>, | ||
| X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, | ||
| { | ||
| let mut iter = weights.into_iter(); | ||
|
|
||
| if iter.len() != self.cumulative_weights.len() + 1 { | ||
| return Err(WeightedError::LenMismatch); | ||
| } | ||
|
|
||
| let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); | ||
| let zero = <X as Default>::default(); | ||
|
|
||
| if !(total_weight >= zero) { | ||
| return Err(WeightedError::InvalidWeight); | ||
| } | ||
|
|
||
| for (w, c) in iter.zip(self.cumulative_weights.iter_mut()) { | ||
| if !(w.borrow() >= &zero) { | ||
| return Err(WeightedError::InvalidWeight); | ||
| } | ||
| *c = total_weight.clone(); | ||
| total_weight += w.borrow(); | ||
| } | ||
|
|
||
| if total_weight == zero { | ||
dhardy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return Err(WeightedError::AllWeightsZero); | ||
| }; | ||
|
|
||
| self.weight_distribution = X::Sampler::new(zero, total_weight.clone()); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's still a problem: this panics if
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, but here |
||
| self.total_weight = total_weight; | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Create a `WeightedIndex` from a vector of cumulative weights without verification. | ||
| pub fn from_cumulative_weights_unchecked(ws: Vec<X>) -> Self | ||
| where | ||
| X: Clone + Default, | ||
| { | ||
| let mut cumulative_weights = ws; | ||
| let total_weight = cumulative_weights.pop().unwrap(); | ||
| let zero = <X as Default>::default(); | ||
| let weight_distribution = X::Sampler::new(zero, total_weight.clone()); | ||
| Self { | ||
| cumulative_weights, | ||
| total_weight, | ||
| weight_distribution, | ||
| } | ||
| } | ||
|
|
||
| /// Create a `WeightedIndex` from a vector of cumulative weights without verification. | ||
| pub fn from_weights(ws: Vec<X>) -> Result<Self, WeightedError> | ||
| where | ||
| X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, | ||
| for<'a> &'a X: SampleBorrow<X>, | ||
| { | ||
| let mut cumulative_weights = ws; | ||
| let mut iter = cumulative_weights.iter_mut(); | ||
|
|
||
| let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); | ||
| let zero = <X as Default>::default(); | ||
|
|
||
| if !(total_weight >= zero) { | ||
| return Err(WeightedError::InvalidWeight); | ||
| } | ||
|
|
||
| for w in iter { | ||
| if !(w.borrow() >= &zero) { | ||
| return Err(WeightedError::InvalidWeight); | ||
| } | ||
| total_weight += w.borrow(); | ||
| *w = total_weight.clone(); | ||
| } | ||
|
|
||
| if total_weight == zero { | ||
| return Err(WeightedError::AllWeightsZero); | ||
| }; | ||
|
|
||
| let weight_distribution = X::Sampler::new(zero, total_weight.clone()); | ||
| cumulative_weights.pop().unwrap(); | ||
| Ok(Self { | ||
| cumulative_weights, | ||
| total_weight, | ||
| weight_distribution, | ||
| }) | ||
| } | ||
|
|
||
| /// Remove the inner vector containing the cumulative weights. | ||
| pub fn into_cumulative_weights(self) -> Vec<X> { | ||
| let Self { | ||
| mut cumulative_weights, | ||
| total_weight, | ||
| .. | ||
| } = self; | ||
| cumulative_weights.push(total_weight); | ||
| cumulative_weights | ||
| } | ||
|
|
||
| /// Update a subset of weights, without changing the number of weights. | ||
| /// | ||
| /// `new_weights` must be sorted by the index. | ||
|
|
@@ -389,6 +503,55 @@ mod test { | |
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_assign_new_weights() { | ||
| let data = [ | ||
| ( | ||
| &[10u32, 2, 3, 4][..], | ||
| &[10, 100, 4, 4][..], | ||
| ), | ||
| ( | ||
| &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], | ||
| &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], | ||
| ), | ||
| ]; | ||
|
|
||
| for &(weights, new_weights) in data.iter() { | ||
| let total_weight = weights.iter().sum::<u32>(); | ||
| let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); | ||
| assert_eq!(distr.total_weight, total_weight); | ||
|
|
||
| distr.assign_new_weights(weights).unwrap(); | ||
| assert_eq!(distr.total_weight, total_weight); | ||
|
|
||
| let new_total_weight = new_weights.iter().sum::<u32>(); | ||
| let new_distr = WeightedIndex::new(new_weights.to_vec()).unwrap(); | ||
| distr.assign_new_weights(new_weights).unwrap(); | ||
| assert_eq!(new_total_weight, new_distr.total_weight); | ||
| assert_eq!(new_total_weight, distr.total_weight); | ||
| assert_eq!(new_distr.cumulative_weights, distr.cumulative_weights); | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn assigning_error_states() { | ||
| { | ||
| let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap(); | ||
| let res = distr.assign_new_weights(&[1.0f64, 2.0, 3.0][..]); | ||
| assert_eq!(res, Err(WeightedError::LenMismatch)); | ||
| } | ||
| { | ||
| let mut distr = WeightedIndex::new(&[1.0f64, 2.0, 3.0, 0.0][..]).unwrap(); | ||
| let res = distr.assign_new_weights(&[1.0f64, 2.0, ::core::f64::NAN, 0.0][..]); | ||
| assert_eq!(res, Err(WeightedError::InvalidWeight)); | ||
| } | ||
| { | ||
| let mut distr = WeightedIndex::new(&[1u32, 2, 3, 0][..]).unwrap(); | ||
| let res = distr.assign_new_weights(&[0u32, 0, 0, 0][..]); | ||
| assert_eq!(res, Err(WeightedError::AllWeightsZero)); | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn value_stability() { | ||
| fn test_samples<X: SampleUniform + PartialOrd, I>( | ||
|
|
@@ -436,6 +599,10 @@ pub enum WeightedError { | |
|
|
||
| /// Too many weights are provided (length greater than `u32::MAX`) | ||
| TooMany, | ||
|
|
||
| /// Have to provide exactly as many weights when assigning as were present when constructing | ||
| /// the weighted index. | ||
| LenMismatch, | ||
SuperFluffy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| #[cfg(feature = "std")] | ||
|
|
@@ -448,6 +615,7 @@ impl fmt::Display for WeightedError { | |
| WeightedError::InvalidWeight => "A weight is invalid in distribution", | ||
| WeightedError::AllWeightsZero => "All weights are zero in distribution", | ||
| WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", | ||
| WeightedError::LenMismatch => "Length mismatch between previous and provided weights", | ||
| }) | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.