From dd744a20d54ffbfb40c26848b49ee215809d7e82 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Fri, 16 May 2025 12:11:40 +0200 Subject: [PATCH 1/8] WIP: Needs new pharmsol --- src/lib.rs | 5 ++ src/mmopt/mod.rs | 210 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) create mode 100644 src/mmopt/mod.rs diff --git a/src/lib.rs b/src/lib.rs index 1416e5ec1..990262988 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,9 @@ pub mod routines; // Structures pub mod structs; +// MMopt +pub mod mmopt; + // Re-export commonly used items pub use anyhow::Result; pub use std::collections::HashMap; @@ -41,6 +44,8 @@ pub mod prelude { pub use crate::routines::settings::*; pub use crate::structs::*; + pub use crate::mmopt::*; + //Alma re-exports pub mod simulator { pub use pharmsol::prelude::simulator::*; diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs new file mode 100644 index 000000000..5a77def9e --- /dev/null +++ b/src/mmopt/mod.rs @@ -0,0 +1,210 @@ +use anyhow::Result; +use faer::Mat; +use pharmsol::{ + prelude::simulator::SubjectPredictions, Data, Equation, ErrorModel, Predictions, Subject, +}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use serde_json::error; +use std::fmt::Error; + +use crate::structs::theta::Theta; + +pub struct PredictionsContainer { + pub matrix: Mat, + pub times: Vec, + pub probs: Vec, +} + +impl PredictionsContainer { + fn matrix(&self) -> &Mat { + &self.matrix + } + + fn nsub(&self) -> usize { + self.matrix.ncols() + } + fn nout(&self) -> usize { + self.matrix.nrows() + } +} + +struct CostMatrix { + matrix: Option>, + auc: f64, + cmax: f64, + cmin: f64, +} + +impl CostMatrix { + pub fn new(auc: f64, cmax: f64, cmin: f64) -> Self { + !unimplemented!() + } +} + +/// The results of a multiple-model optimization +/// +/// +#[derive(Debug)] +pub struct MmoptResult { + // Optimal sample times + pub times: Vec, + // Bayes risk + pub risk: f64, +} + +pub fn mmopt( + theta: &Theta, + subject: &Subject, + equation: impl Equation, + errormodel: ErrorModel, + nsamp: usize, +) -> Result { + // Check that subject contains only one Occasion + if subject.occasions().len() != 1 { + return Err(anyhow::anyhow!("Subject must contain only one Occasion")); + } + + // Generate predictions + let predictions = theta + .matrix() + .row_iter() + .map(|theta_row| { + let support_point: Vec = theta_row.iter().cloned().collect(); + let predictions = equation + .estimate_predictions(&subject, &support_point) + .get_predictions(); + predictions + }) + .collect::>(); + + // Times vector + let times = predictions[0].iter().map(|p| p.time()).collect::>(); + + // Generate prediction matrix + let pred_matrix = Mat::from_fn(predictions[0].len(), theta.nspp(), |i, j| { + predictions[j][i].prediction().to_owned() + }); + + // Generate sample candidate indices + let candidate_indices = generate_combinations(times.len(), nsamp); + + // em + let e = errormodel.; + + let (best_combo, min_risk) = candidate_indices + .par_iter() + .map(|combo| { + let mut risk = 0.0; + // Compare the i-th and the j-th subject predictions + for i in 0..theta.nspp() { + for j in 0..theta.nspp() { + if i != j { + let i_obs: Vec = pred_matrix + .col(i) + .iter() + .enumerate() + .filter_map(|(k, &x)| if combo.contains(&k) { Some(x) } else { None }) + .collect(); + + let j_obs: Vec = pred_matrix + .col(j) + .iter() + .enumerate() + .filter_map(|(k, &x)| if combo.contains(&k) { Some(x) } else { None }) + .collect(); + + let i_var: Vec = + i_obs.iter().map(|&x| errormodel.(x)).collect(); + let j_var: Vec = + j_obs.iter().map(|&x| errorpoly.variance(x)).collect(); + + let sum_k_ijn: f64 = i_obs + .iter() + .zip(j_obs.iter()) + .zip(i_var.iter()) + .zip(j_var.iter()) + .map(|(((y_i, y_j), i_var), j_var)| { + let denominator = i_var + j_var; + let term1 = (y_i - y_j).powi(2) / (4.0 * denominator); + let term2 = 0.5 * ((i_var + j_var) / 2.0).ln(); + let term3 = -0.25 * (i_var * j_var).ln(); + term1 + term2 + term3 + }) + .collect::>() + .iter() + .sum::(); + + let prob_i = predictions.probs[i]; + let prob_j = predictions.probs[j]; + let cost = cost_matrix.matrix[(i, j)]; + let risk_component = prob_i * prob_j * (-sum_k_ijn).exp() * cost; + risk += risk_component; + } + } + } + + (combo.clone(), risk) + }) + .min_by(|(_, risk_a), (_, risk_b)| risk_a.partial_cmp(risk_b).unwrap()) + .unwrap(); + + let res = MmoptResult { + best_combo_indices: best_combo.clone(), + best_combo_times: best_combo + .iter() + .map(|&index| predictions.times[index]) + .collect(), + min_risk, + }; + + Ok(res) +} + +fn generate_combinations(m: usize, n: usize) -> Vec> { + fn backtrack( + m: usize, + n: usize, + start: usize, + current: &mut Vec, + results: &mut Vec>, + ) { + if current.len() == n { + results.push(current.clone()); + return; + } + + for i in start..m { + current.push(i); + backtrack(m, n, i + 1, current, results); + current.pop(); + } + } + + let mut results = Vec::new(); + let mut current = Vec::new(); + backtrack(m, n, 0, &mut current, &mut results); + results +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_combinations() { + let m = 5; + let n = 3; + let combinations = generate_combinations(m, n); + assert_eq!(combinations.len(), 10); + assert_eq!(combinations[0], vec![0, 1, 2]); + assert_eq!(combinations[1], vec![0, 1, 3]); + assert_eq!(combinations[2], vec![0, 1, 4]); + assert_eq!(combinations[3], vec![0, 2, 3]); + assert_eq!(combinations[4], vec![0, 2, 4]); + assert_eq!(combinations[5], vec![0, 3, 4]); + assert_eq!(combinations[6], vec![1, 2, 3]); + assert_eq!(combinations[7], vec![1, 2, 4]); + assert_eq!(combinations[8], vec![1, 3, 4]); + assert_eq!(combinations[9], vec![2, 3, 4]); + } +} From a6beba4b0537adee4eb0c988929c41b6cbaebaea Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Thu, 22 May 2025 16:13:20 +0200 Subject: [PATCH 2/8] WIP --- src/mmopt/mod.rs | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs index 5a77def9e..2b070479d 100644 --- a/src/mmopt/mod.rs +++ b/src/mmopt/mod.rs @@ -88,9 +88,6 @@ pub fn mmopt( // Generate sample candidate indices let candidate_indices = generate_combinations(times.len(), nsamp); - // em - let e = errormodel.; - let (best_combo, min_risk) = candidate_indices .par_iter() .map(|combo| { @@ -114,7 +111,7 @@ pub fn mmopt( .collect(); let i_var: Vec = - i_obs.iter().map(|&x| errormodel.(x)).collect(); + i_obs.iter().map(|&x| errormodel.variance(x)).collect(); let j_var: Vec = j_obs.iter().map(|&x| errorpoly.variance(x)).collect(); @@ -148,13 +145,10 @@ pub fn mmopt( .min_by(|(_, risk_a), (_, risk_b)| risk_a.partial_cmp(risk_b).unwrap()) .unwrap(); + let times = best_combo.iter().map(|&i| times[i]).collect::>(); let res = MmoptResult { - best_combo_indices: best_combo.clone(), - best_combo_times: best_combo - .iter() - .map(|&index| predictions.times[index]) - .collect(), - min_risk, + times: times, + risk: min_risk, }; Ok(res) From e915e7d113e1d145f4b7fdfb85e9872ded48c4ae Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 13 Jul 2025 15:46:09 +0200 Subject: [PATCH 3/8] Refactor --- src/mmopt/mod.rs | 134 +++++++++++++++++++---------------------------- 1 file changed, 55 insertions(+), 79 deletions(-) diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs index 2b070479d..0b772508a 100644 --- a/src/mmopt/mod.rs +++ b/src/mmopt/mod.rs @@ -1,33 +1,10 @@ -use anyhow::Result; +use anyhow::{Ok, Result}; use faer::Mat; -use pharmsol::{ - prelude::simulator::SubjectPredictions, Data, Equation, ErrorModel, Predictions, Subject, -}; +use pharmsol::{Equation, ErrorModel, Predictions, Subject}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; -use serde_json::error; -use std::fmt::Error; use crate::structs::theta::Theta; -pub struct PredictionsContainer { - pub matrix: Mat, - pub times: Vec, - pub probs: Vec, -} - -impl PredictionsContainer { - fn matrix(&self) -> &Mat { - &self.matrix - } - - fn nsub(&self) -> usize { - self.matrix.ncols() - } - fn nout(&self) -> usize { - self.matrix.nrows() - } -} - struct CostMatrix { matrix: Option>, auc: f64, @@ -52,6 +29,7 @@ pub struct MmoptResult { pub risk: f64, } +/// Perform multiple-model optimization to determine optimal sample times pub fn mmopt( theta: &Theta, subject: &Subject, @@ -72,6 +50,7 @@ pub fn mmopt( let support_point: Vec = theta_row.iter().cloned().collect(); let predictions = equation .estimate_predictions(&subject, &support_point) + .unwrap() .get_predictions(); predictions }) @@ -91,67 +70,64 @@ pub fn mmopt( let (best_combo, min_risk) = candidate_indices .par_iter() .map(|combo| { - let mut risk = 0.0; - // Compare the i-th and the j-th subject predictions - for i in 0..theta.nspp() { - for j in 0..theta.nspp() { - if i != j { - let i_obs: Vec = pred_matrix - .col(i) - .iter() - .enumerate() - .filter_map(|(k, &x)| if combo.contains(&k) { Some(x) } else { None }) - .collect(); - - let j_obs: Vec = pred_matrix - .col(j) - .iter() - .enumerate() - .filter_map(|(k, &x)| if combo.contains(&k) { Some(x) } else { None }) - .collect(); - - let i_var: Vec = - i_obs.iter().map(|&x| errormodel.variance(x)).collect(); - let j_var: Vec = - j_obs.iter().map(|&x| errorpoly.variance(x)).collect(); - - let sum_k_ijn: f64 = i_obs - .iter() - .zip(j_obs.iter()) - .zip(i_var.iter()) - .zip(j_var.iter()) - .map(|(((y_i, y_j), i_var), j_var)| { - let denominator = i_var + j_var; - let term1 = (y_i - y_j).powi(2) / (4.0 * denominator); - let term2 = 0.5 * ((i_var + j_var) / 2.0).ln(); - let term3 = -0.25 * (i_var * j_var).ln(); - term1 + term2 + term3 - }) - .collect::>() - .iter() - .sum::(); - - let prob_i = predictions.probs[i]; - let prob_j = predictions.probs[j]; - let cost = cost_matrix.matrix[(i, j)]; - let risk_component = prob_i * prob_j * (-sum_k_ijn).exp() * cost; - risk += risk_component; - } - } - } - + let risk = calculate_risk(combo, &pred_matrix, theta, &errormodel).unwrap(); (combo.clone(), risk) }) .min_by(|(_, risk_a), (_, risk_b)| risk_a.partial_cmp(risk_b).unwrap()) .unwrap(); - let times = best_combo.iter().map(|&i| times[i]).collect::>(); - let res = MmoptResult { - times: times, + let optimal_times = best_combo.iter().map(|&i| times[i]).collect(); + Ok(MmoptResult { + times: optimal_times, risk: min_risk, - }; + }) +} + +/// Calculate the risk for a specific combination of sample times +fn calculate_risk( + combo: &[usize], + pred_matrix: &Mat, + theta: &Theta, + errormodel: &ErrorModel, +) -> Result { + let nspp = theta.nspp(); + let prob_uniform = 1.0 / nspp as f64; // Uniform probability for each support point + + let risk = (0..nspp) + .flat_map(|i| (0..nspp).map(move |j| (i, j))) + .filter(|(i, j)| i != j) + .map(|(i, j)| { + // Extract observations for the selected time points + let i_obs: Vec = combo.iter().map(|&k| pred_matrix[(k, i)]).collect(); + + let j_obs: Vec = combo.iter().map(|&k| pred_matrix[(k, j)]).collect(); + + // Calculate the sum of log-likelihood differences + let sum_k_ijn: f64 = i_obs + .iter() + .zip(j_obs.iter()) + .map(|(&y_i, &y_j)| { + let i_var = errormodel.variance_from_value(y_i).unwrap(); + let j_var = errormodel.variance_from_value(y_j).unwrap(); + let denominator = i_var + j_var; + + let term1 = (y_i - y_j).powi(2) / (4.0 * denominator); + let term2 = 0.5 * (denominator / 2.0).ln(); + let term3 = -0.25 * (i_var * j_var).ln(); + + term1 + term2 + term3 + }) + .sum(); + + // For now, assume unit cost matrix (cost = 1.0 for all pairs) + // This can be parameterized later if needed + let cost = 1.0; + + prob_uniform * prob_uniform * (-sum_k_ijn).exp() * cost + }) + .sum(); - Ok(res) + Ok(risk) } fn generate_combinations(m: usize, n: usize) -> Vec> { From a4ce2b7fe292e4d2d6079497372fbe10595312c7 Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 13 Jul 2025 15:48:59 +0200 Subject: [PATCH 4/8] Update mod.rs --- src/mmopt/mod.rs | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs index 0b772508a..1580681cb 100644 --- a/src/mmopt/mod.rs +++ b/src/mmopt/mod.rs @@ -5,19 +5,6 @@ use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use crate::structs::theta::Theta; -struct CostMatrix { - matrix: Option>, - auc: f64, - cmax: f64, - cmin: f64, -} - -impl CostMatrix { - pub fn new(auc: f64, cmax: f64, cmin: f64) -> Self { - !unimplemented!() - } -} - /// The results of a multiple-model optimization /// /// @@ -36,6 +23,7 @@ pub fn mmopt( equation: impl Equation, errormodel: ErrorModel, nsamp: usize, + weights: Vec, ) -> Result { // Check that subject contains only one Occasion if subject.occasions().len() != 1 { @@ -70,7 +58,8 @@ pub fn mmopt( let (best_combo, min_risk) = candidate_indices .par_iter() .map(|combo| { - let risk = calculate_risk(combo, &pred_matrix, theta, &errormodel).unwrap(); + let risk = + calculate_risk(combo, &pred_matrix, theta, &errormodel, weights.clone()).unwrap(); (combo.clone(), risk) }) .min_by(|(_, risk_a), (_, risk_b)| risk_a.partial_cmp(risk_b).unwrap()) @@ -89,9 +78,9 @@ fn calculate_risk( pred_matrix: &Mat, theta: &Theta, errormodel: &ErrorModel, + weights: Vec, ) -> Result { let nspp = theta.nspp(); - let prob_uniform = 1.0 / nspp as f64; // Uniform probability for each support point let risk = (0..nspp) .flat_map(|i| (0..nspp).map(move |j| (i, j))) @@ -123,7 +112,7 @@ fn calculate_risk( // This can be parameterized later if needed let cost = 1.0; - prob_uniform * prob_uniform * (-sum_k_ijn).exp() * cost + weights[i] * weights[j] * (-sum_k_ijn).exp() * cost }) .sum(); From aef09c471790b776c0b228a93f89b356e4a2856b Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 13 Jul 2025 15:49:54 +0200 Subject: [PATCH 5/8] Update mod.rs --- src/mmopt/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs index 1580681cb..09e94e36a 100644 --- a/src/mmopt/mod.rs +++ b/src/mmopt/mod.rs @@ -108,9 +108,8 @@ fn calculate_risk( }) .sum(); - // For now, assume unit cost matrix (cost = 1.0 for all pairs) - // This can be parameterized later if needed - let cost = 1.0; + // No cost for getting it right + let cost = if i == j { 0.0 } else { 1.0 }; weights[i] * weights[j] * (-sum_k_ijn).exp() * cost }) From 3889f83e859aca9bfa84fab2ca702c4736f1cea5 Mon Sep 17 00:00:00 2001 From: Markus Date: Mon, 23 Mar 2026 22:04:32 +0100 Subject: [PATCH 6/8] WIP --- examples/mmopt.rs | 79 +++++++ src/mmopt/mod.rs | 317 +++++++++++++++++++++++++--- tests/mmopt_tests.rs | 491 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 854 insertions(+), 33 deletions(-) create mode 100644 examples/mmopt.rs create mode 100644 tests/mmopt_tests.rs diff --git a/examples/mmopt.rs b/examples/mmopt.rs new file mode 100644 index 000000000..6e7841d8e --- /dev/null +++ b/examples/mmopt.rs @@ -0,0 +1,79 @@ +use anyhow::Result; +use pmcore::mmopt::mmopt; +use pmcore::prelude::*; +use pmcore::structs::theta::Theta; + +fn main() -> Result<()> { + // Define a one-compartment PK model + let eq = equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _, _| lag! {}, + |_p, _, _| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ); + + // Population support points representing two distinct PK sub-populations + let params = Parameters::new().add("ke", 0.1, 1.0).add("v", 30.0, 100.0); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, // ke: slow eliminator + (0, 1) => 50.0, // v + (1, 0) => 0.5, // ke: fast eliminator + (1, 1) => 60.0, // v + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + // Error model: additive with SD = 20% of observation + let errormodel = ErrorModel::additive(ErrorPoly::new(0.0, 0.20, 0.0, 0.0), 0.0); + + // Create a subject with a dose and candidate observation times + // The observations values are irrelevant — only their times matter + let subject = Subject::builder("candidate") + .bolus(0.0, 100.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) + .missing_observation(6.0, 0) + .missing_observation(8.0, 0) + .missing_observation(12.0, 0) + .missing_observation(24.0, 0) + .build(); + + // Equal prior weights for both sub-populations + let weights = vec![0.5, 0.5]; + + // Find the optimal 2 sample times (out of 8 candidates) + println!("Finding optimal 2 sample times from 8 candidates...\n"); + let result = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 2, + weights.clone(), + )?; + println!(" {}", result); + + // Compare with 3 samples + println!("\nFinding optimal 3 sample times...\n"); + let result_3 = mmopt(&theta, &subject, eq, errormodel, 0, 3, weights)?; + println!(" {}", result_3); + + println!( + "\nRisk reduction from 2 → 3 samples: {:.2}%", + (1.0 - result_3.risk / result.risk) * 100.0 + ); + + Ok(()) +} diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs index 09e94e36a..e30b0b7a3 100644 --- a/src/mmopt/mod.rs +++ b/src/mmopt/mod.rs @@ -7,59 +7,129 @@ use crate::structs::theta::Theta; /// The results of a multiple-model optimization /// -/// -#[derive(Debug)] +/// Contains the optimal sample times and the associated Bayes risk. +#[derive(Debug, Clone)] pub struct MmoptResult { - // Optimal sample times + /// Optimal sample times pub times: Vec, - // Bayes risk + /// Bayes risk at the optimal sample times pub risk: f64, } -/// Perform multiple-model optimization to determine optimal sample times +impl std::fmt::Display for MmoptResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Optimal times: {:?}, Bayes risk: {:.6}", + self.times, self.risk + ) + } +} + +/// Perform multiple-model optimization to determine optimal sample times. +/// +/// This function evaluates all possible combinations of `nsamp` sample times +/// from the candidate times defined as observations in the `subject`, and returns +/// the combination that minimizes the Bayes risk of misclassification between +/// support points. +/// +/// # Arguments +/// * `theta` - Support points (population parameter distribution) +/// * `subject` - Subject with candidate observation times (must have exactly one occasion) +/// * `equation` - The pharmacometric model equation +/// * `errormodel` - Error model for computing observation variance +/// * `outeq` - Output equation index to optimize for +/// * `nsamp` - Number of samples to select +/// * `weights` - Probability weights for each support point (must sum to ~1.0) +/// +/// # Errors +/// Returns an error if: +/// - The subject has more than one occasion +/// - The number of support points is less than 2 +/// - The weights length doesn't match the number of support points +/// - `nsamp` is 0 or exceeds the number of candidate times pub fn mmopt( theta: &Theta, subject: &Subject, equation: impl Equation, errormodel: ErrorModel, + outeq: usize, nsamp: usize, weights: Vec, ) -> Result { - // Check that subject contains only one Occasion + // Validate inputs if subject.occasions().len() != 1 { - return Err(anyhow::anyhow!("Subject must contain only one Occasion")); + return Err(anyhow::anyhow!("Subject must contain exactly one Occasion")); + } + + if theta.nspp() < 2 { + return Err(anyhow::anyhow!( + "At least 2 support points are required, got {}", + theta.nspp() + )); } - // Generate predictions + if weights.len() != theta.nspp() { + return Err(anyhow::anyhow!( + "Weights length ({}) must match number of support points ({})", + weights.len(), + theta.nspp() + )); + } + + if nsamp == 0 { + return Err(anyhow::anyhow!("Number of samples must be at least 1")); + } + + // Generate predictions for each support point let predictions = theta .matrix() .row_iter() .map(|theta_row| { let support_point: Vec = theta_row.iter().cloned().collect(); - let predictions = equation - .estimate_predictions(&subject, &support_point) + let all_preds = equation + .estimate_predictions(subject, &support_point) .unwrap() .get_predictions(); - predictions + // Filter predictions by output equation + all_preds + .into_iter() + .filter(|p| p.outeq() == outeq) + .collect::>() }) .collect::>(); - // Times vector + if predictions[0].is_empty() { + return Err(anyhow::anyhow!( + "No predictions found for output equation {}", + outeq + )); + } + + // Times vector from the first support point's predictions let times = predictions[0].iter().map(|p| p.time()).collect::>(); - // Generate prediction matrix + if nsamp > times.len() { + return Err(anyhow::anyhow!( + "Number of samples ({}) exceeds number of candidate times ({})", + nsamp, + times.len() + )); + } + + // Generate prediction matrix: rows = time points, cols = support points let pred_matrix = Mat::from_fn(predictions[0].len(), theta.nspp(), |i, j| { - predictions[j][i].prediction().to_owned() + predictions[j][i].prediction() }); - // Generate sample candidate indices + // Generate all C(m, n) sample candidate index combinations let candidate_indices = generate_combinations(times.len(), nsamp); + // Evaluate risk in parallel for all combinations and select minimum let (best_combo, min_risk) = candidate_indices .par_iter() .map(|combo| { - let risk = - calculate_risk(combo, &pred_matrix, theta, &errormodel, weights.clone()).unwrap(); + let risk = calculate_risk(combo, &pred_matrix, &errormodel, &weights); (combo.clone(), risk) }) .min_by(|(_, risk_a), (_, risk_b)| risk_a.partial_cmp(risk_b).unwrap()) @@ -72,26 +142,28 @@ pub fn mmopt( }) } -/// Calculate the risk for a specific combination of sample times +/// Calculate the Bayes risk for a specific combination of sample time indices. +/// +/// The risk quantifies the expected misclassification probability between support +/// points, weighted by their probabilities. Lower risk means the selected sample +/// times provide better discrimination between support points. fn calculate_risk( combo: &[usize], pred_matrix: &Mat, - theta: &Theta, errormodel: &ErrorModel, - weights: Vec, -) -> Result { - let nspp = theta.nspp(); + weights: &[f64], +) -> f64 { + let nspp = pred_matrix.ncols(); - let risk = (0..nspp) + (0..nspp) .flat_map(|i| (0..nspp).map(move |j| (i, j))) .filter(|(i, j)| i != j) .map(|(i, j)| { - // Extract observations for the selected time points + // Extract predictions for the selected time points let i_obs: Vec = combo.iter().map(|&k| pred_matrix[(k, i)]).collect(); - let j_obs: Vec = combo.iter().map(|&k| pred_matrix[(k, j)]).collect(); - // Calculate the sum of log-likelihood differences + // Calculate the sum of log-likelihood discrimination terms let sum_k_ijn: f64 = i_obs .iter() .zip(j_obs.iter()) @@ -108,14 +180,9 @@ fn calculate_risk( }) .sum(); - // No cost for getting it right - let cost = if i == j { 0.0 } else { 1.0 }; - - weights[i] * weights[j] * (-sum_k_ijn).exp() * cost + weights[i] * weights[j] * (-sum_k_ijn).exp() }) - .sum(); - - Ok(risk) + .sum() } fn generate_combinations(m: usize, n: usize) -> Vec> { @@ -147,6 +214,8 @@ fn generate_combinations(m: usize, n: usize) -> Vec> { #[cfg(test)] mod tests { use super::*; + use faer::Mat; + use pharmsol::ErrorPoly; #[test] fn test_combinations() { @@ -165,4 +234,186 @@ mod tests { assert_eq!(combinations[8], vec![1, 3, 4]); assert_eq!(combinations[9], vec![2, 3, 4]); } + + #[test] + fn test_combinations_edge_cases() { + // Select all elements + let combinations = generate_combinations(3, 3); + assert_eq!(combinations.len(), 1); + assert_eq!(combinations[0], vec![0, 1, 2]); + + // Select 1 element + let combinations = generate_combinations(4, 1); + assert_eq!(combinations.len(), 4); + assert_eq!(combinations[0], vec![0]); + assert_eq!(combinations[3], vec![3]); + + // C(6, 2) = 15 + let combinations = generate_combinations(6, 2); + assert_eq!(combinations.len(), 15); + } + + fn make_error_model() -> ErrorModel { + ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 0.0) + } + + #[test] + fn test_calculate_risk_identical_predictions() { + // When two support points have identical predictions, misclassification risk is maximal + let errormodel = make_error_model(); + let weights = vec![0.5, 0.5]; + + // pred_matrix: 3 time points, 2 support points with identical predictions + let pred_matrix = Mat::from_fn(3, 2, |i, _j| match i { + 0 => 10.0, + 1 => 5.0, + 2 => 2.0, + _ => 0.0, + }); + + let combo = vec![0, 1, 2]; + let risk_identical = calculate_risk(&combo, &pred_matrix, &errormodel, &weights); + + // Now different predictions + let pred_matrix_diff = Mat::from_fn(3, 2, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 50.0, + (1, 0) => 5.0, + (1, 1) => 25.0, + (2, 0) => 2.0, + (2, 1) => 12.0, + _ => 0.0, + }); + + let risk_different = calculate_risk(&combo, &pred_matrix_diff, &errormodel, &weights); + + // Identical predictions should have higher risk (harder to discriminate) + assert!( + risk_identical > risk_different, + "Identical predictions should have higher risk: {} vs {}", + risk_identical, + risk_different + ); + } + + #[test] + fn test_calculate_risk_symmetric_weights() { + // Risk should be symmetric when weights are equal + let errormodel = make_error_model(); + let weights = vec![0.5, 0.5]; + + let pred_matrix = Mat::from_fn(3, 2, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 20.0, + (1, 0) => 5.0, + (1, 1) => 10.0, + (2, 0) => 2.0, + (2, 1) => 4.0, + _ => 0.0, + }); + + let combo = vec![0, 1, 2]; + let risk = calculate_risk(&combo, &pred_matrix, &errormodel, &weights); + assert!(risk > 0.0, "Risk should be positive"); + assert!(risk.is_finite(), "Risk should be finite"); + } + + #[test] + fn test_calculate_risk_more_samples_lower_risk() { + // Using more sample times should generally not increase (and usually decrease) risk + let errormodel = make_error_model(); + let weights = vec![0.5, 0.5]; + + // 4 time points with very different predictions + let pred_matrix = Mat::from_fn(4, 2, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 20.0, + (1, 0) => 5.0, + (1, 1) => 15.0, + (2, 0) => 2.0, + (2, 1) => 8.0, + (3, 0) => 1.0, + (3, 1) => 6.0, + _ => 0.0, + }); + + // Best 2-sample combo risk + let combos_2 = generate_combinations(4, 2); + let min_risk_2 = combos_2 + .iter() + .map(|combo| calculate_risk(combo, &pred_matrix, &errormodel, &weights)) + .fold(f64::INFINITY, f64::min); + + // Best 3-sample combo risk + let combos_3 = generate_combinations(4, 3); + let min_risk_3 = combos_3 + .iter() + .map(|combo| calculate_risk(combo, &pred_matrix, &errormodel, &weights)) + .fold(f64::INFINITY, f64::min); + + assert!( + min_risk_3 <= min_risk_2, + "More samples should yield equal or lower risk: {} vs {}", + min_risk_3, + min_risk_2 + ); + } + + #[test] + fn test_calculate_risk_zero_weight() { + // Setting a weight to zero should eliminate that support point's contribution + let errormodel = make_error_model(); + + let pred_matrix = Mat::from_fn(2, 3, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 20.0, + (0, 2) => 30.0, + (1, 0) => 5.0, + (1, 1) => 10.0, + (1, 2) => 15.0, + _ => 0.0, + }); + + let combo = vec![0, 1]; + + // With all weights + let weights_all = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]; + let risk_all = calculate_risk(&combo, &pred_matrix, &errormodel, &weights_all); + + // Zero out one weight + let weights_zero = vec![0.5, 0.5, 0.0]; + let risk_zero = calculate_risk(&combo, &pred_matrix, &errormodel, &weights_zero); + + // Risks should differ when a weight is zeroed + assert!( + (risk_all - risk_zero).abs() > 1e-10, + "Zeroing a weight should change the risk" + ); + } + + #[test] + fn test_risk_positive_and_finite() { + let errormodel = make_error_model(); + let weights = vec![0.25, 0.25, 0.25, 0.25]; + + let pred_matrix = Mat::from_fn(5, 4, |i, j| (i as f64 + 1.0) * (j as f64 + 1.0) * 2.0); + + let combo = vec![0, 2, 4]; + let risk = calculate_risk(&combo, &pred_matrix, &errormodel, &weights); + assert!(risk >= 0.0, "Risk must be non-negative"); + assert!(risk.is_finite(), "Risk must be finite"); + } + + #[test] + fn test_mmopt_result_display() { + let result = MmoptResult { + times: vec![1.0, 4.0, 8.0], + risk: 0.123456, + }; + let display = format!("{}", result); + assert!(display.contains("1.0")); + assert!(display.contains("4.0")); + assert!(display.contains("8.0")); + assert!(display.contains("0.123456")); + } } diff --git a/tests/mmopt_tests.rs b/tests/mmopt_tests.rs new file mode 100644 index 000000000..e6caf20d8 --- /dev/null +++ b/tests/mmopt_tests.rs @@ -0,0 +1,491 @@ +use anyhow::Result; +use pmcore::mmopt::{mmopt, MmoptResult}; +use pmcore::prelude::*; +use pmcore::structs::theta::Theta; + +/// Helper to create a simple one-compartment model +fn one_comp_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _, _| lag! {}, + |_p, _, _| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ) +} + +/// Helper to create a simple error model +fn additive_error_model() -> ErrorModel { + ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 0.0) +} + +/// Helper to create parameters for the one-compartment model +fn one_comp_params() -> Parameters { + Parameters::new().add("ke", 0.1, 1.0).add("v", 30.0, 100.0) +} + +/// Test basic mmopt functionality with a simple one-compartment model +#[test] +fn test_mmopt_basic() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + // Two support points with different PK parameters + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, // ke for spp1 + (0, 1) => 50.0, // v for spp1 + (1, 0) => 0.8, // ke for spp2 + (1, 1) => 80.0, // v for spp2 + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + // Subject with candidate observation times after a bolus dose + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + let weights = vec![0.5, 0.5]; + + let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, weights)?; + + assert_eq!( + result.times.len(), + 2, + "Should select exactly 2 sample times" + ); + assert!(result.risk >= 0.0, "Risk must be non-negative"); + assert!(result.risk.is_finite(), "Risk must be finite"); + + // All selected times should be from the candidate set + let candidate_times = vec![1.0, 2.0, 4.0, 6.0, 8.0, 12.0]; + for t in &result.times { + assert!( + candidate_times.contains(t), + "Selected time {} is not in the candidate set", + t + ); + } + + Ok(()) +} + +/// Test that selecting more samples results in equal or lower risk +#[test] +fn test_mmopt_more_samples_lower_risk() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 50.0, + (1, 0) => 0.7, + (1, 1) => 70.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + + let result_2 = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 2, + vec![0.5, 0.5], + )?; + let result_3 = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 3, + vec![0.5, 0.5], + )?; + + assert!( + result_3.risk <= result_2.risk + 1e-10, + "More samples should yield lower or equal risk: {} vs {}", + result_3.risk, + result_2.risk + ); + + Ok(()) +} + +/// Test mmopt with three support points +#[test] +fn test_mmopt_three_support_points() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(3, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 40.0, + (1, 0) => 0.5, + (1, 1) => 60.0, + (2, 0) => 0.9, + (2, 1) => 90.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(3.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + let weights = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]; + + let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, weights)?; + + assert_eq!(result.times.len(), 2); + assert!(result.risk >= 0.0); + assert!(result.risk.is_finite()); + + Ok(()) +} + +/// Test that mmopt with all candidate times produces the lowest possible risk +#[test] +fn test_mmopt_all_samples() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + let weights = vec![0.5, 0.5]; + + // Select all 3 samples (only one combination) + let result = mmopt(&theta, &subject, eq, errormodel, 0, 3, weights)?; + + assert_eq!(result.times.len(), 3); + assert_eq!(result.times, vec![1.0, 4.0, 8.0]); + + Ok(()) +} + +/// Test validation: subject with multiple occasions should fail +#[test] +fn test_mmopt_multiple_occasions_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + // Subject with two occasions + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .repeat(1, 24.0) + .bolus(24.0, 100.0, 0) + .observation(25.0, 0.0, 0) + .build(); + + // Only proceed if the subject actually has multiple occasions + if subject.occasions().len() > 1 { + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + vec![0.5, 0.5], + ); + assert!(result.is_err()); + } +} + +/// Test validation: fewer than 2 support points should fail +#[test] +fn test_mmopt_single_support_point_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(1, 2, |_r, c| match c { + 0 => 0.3, + 1 => 50.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + vec![1.0], + ); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("At least 2 support points")); +} + +/// Test validation: weights length mismatch should fail +#[test] +fn test_mmopt_weights_mismatch_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // 3 weights for 2 support points + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + vec![0.33, 0.33, 0.34], + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Weights length")); +} + +/// Test validation: nsamp = 0 should fail +#[test] +fn test_mmopt_zero_samples_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .build(); + + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 0, + vec![0.5, 0.5], + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("at least 1")); +} + +/// Test validation: nsamp exceeds candidate times should fail +#[test] +fn test_mmopt_too_many_samples_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // Request 5 samples but only 2 candidate times + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 5, + vec![0.5, 0.5], + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("exceeds")); +} + +/// Test that unequal weights influence the optimal sampling design +#[test] +fn test_mmopt_unequal_weights() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 50.0, + (1, 0) => 0.8, + (1, 1) => 80.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + + let result_equal = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 2, + vec![0.5, 0.5], + )?; + + let result_skewed = mmopt(&theta, &subject, eq, errormodel, 0, 2, vec![0.9, 0.1])?; + + // Different weights should generally produce different risks + // (or at least both should be valid) + assert!(result_equal.risk.is_finite()); + assert!(result_skewed.risk.is_finite()); + assert!(result_equal.risk >= 0.0); + assert!(result_skewed.risk >= 0.0); + + Ok(()) +} + +/// Test MmoptResult Display implementation +#[test] +fn test_mmopt_result_display() { + let result = MmoptResult { + times: vec![2.0, 6.0, 12.0], + risk: 0.042, + }; + let display = format!("{}", result); + assert!(display.contains("2.0")); + assert!(display.contains("6.0")); + assert!(display.contains("12.0")); + assert!(display.contains("0.042")); +} + +/// Test with a single sample selection +#[test] +fn test_mmopt_single_sample() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 40.0, + (1, 0) => 0.9, + (1, 1) => 90.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + vec![0.5, 0.5], + )?; + + assert_eq!(result.times.len(), 1); + assert!(result.risk.is_finite()); + assert!(result.risk >= 0.0); + + Ok(()) +} From 77002e11307304ae04a74a36dda8524ead4323f0 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:02:36 +0100 Subject: [PATCH 7/8] Use paper for example --- examples/mmopt.rs | 186 ++++++++++++++++++++++++++++++++++++---------- src/lib.rs | 1 - src/mmopt/mod.rs | 55 +++++++++++--- 3 files changed, 189 insertions(+), 53 deletions(-) diff --git a/examples/mmopt.rs b/examples/mmopt.rs index 6e7841d8e..1e1cc940d 100644 --- a/examples/mmopt.rs +++ b/examples/mmopt.rs @@ -1,14 +1,18 @@ +//! Replication of the experiments in Bayard & Neely (2017) +//! "Experiment Design for Nonparametric Models Based On Minimizing Bayes Risk" +//! J Pharmacokinet Pharmacodyn. 2017;44(2):95-111. PMCID: PMC5376526 + use anyhow::Result; use pmcore::mmopt::mmopt; use pmcore::prelude::*; use pmcore::structs::theta::Theta; -fn main() -> Result<()> { - // Define a one-compartment PK model - let eq = equation::ODE::new( - |x, p, _t, dx, b, _rateiv, _cov| { +/// One-compartment model: dx/dt = -K*x + input, y = x/V +fn one_comp_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, rateiv, _cov| { fetch_params!(p, ke, _v); - dx[0] = -ke * x[0] + b[0]; + dx[0] = -ke * x[0] + b[0] + rateiv[0]; }, |_p, _, _| lag! {}, |_p, _, _| fa! {}, @@ -18,43 +22,135 @@ fn main() -> Result<()> { y[0] = x[0] / v; }, (1, 1), - ); + ) +} - // Population support points representing two distinct PK sub-populations - let params = Parameters::new().add("ke", 0.1, 1.0).add("v", 30.0, 100.0); +fn main() -> Result<()> { + section4()?; + println!(); + section6()?; + Ok(()) +} + +/// Paper Section 4: Two-support-point exponential decay example +/// +/// Model: μ(t,a) = e^{-at} (implemented as 1-compartment with D=V=1) +/// Support points: a1 = 1.5 (fast), a2 = 0.25 (slow) +/// Uniform priors: p1 = p2 = 0.5 +/// Error: σ = 0.3 (constant additive) +/// Candidate times: 0.1 to 5.0 hours at 0.1-hour intervals +/// +/// Analytical optimum: t* = ln(6)/1.25 ≈ 1.4334 hours +fn section4() -> Result<()> { + println!("=== Section 4: Two-support-point example ===\n"); + + let eq = one_comp_model(); + let params = Parameters::new().add("ke", 0.1, 5.0).add("v", 0.5, 2.0); let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { - (0, 0) => 0.3, // ke: slow eliminator - (0, 1) => 50.0, // v - (1, 0) => 0.5, // ke: fast eliminator - (1, 1) => 60.0, // v + (0, 0) => 1.5, // a1 (fast) + (0, 1) => 1.0, // V = 1 + (1, 0) => 0.25, // a2 (slow) + (1, 1) => 1.0, // V = 1 _ => 0.0, }); let theta = Theta::from_parts(mat, params)?; - // Error model: additive with SD = 20% of observation - let errormodel = ErrorModel::additive(ErrorPoly::new(0.0, 0.20, 0.0, 0.0), 0.0); - - // Create a subject with a dose and candidate observation times - // The observations values are irrelevant — only their times matter - let subject = Subject::builder("candidate") - .bolus(0.0, 100.0, 0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(4.0, 0) - .missing_observation(6.0, 0) - .missing_observation(8.0, 0) - .missing_observation(12.0, 0) - .missing_observation(24.0, 0) - .build(); - - // Equal prior weights for both sub-populations + let errormodel = ErrorModel::additive(ErrorPoly::new(0.3, 0.0, 0.0, 0.0), 0.0); + + // Candidate times: 0.1 to 5.0 at 0.1h steps + let mut builder = Subject::builder("section4"); + builder = builder.bolus(0.0, 1.0, 0); + for i in 1..=50 { + builder = builder.missing_observation(i as f64 * 0.1, 0); + } + let subject = builder.build(); + let weights = vec![0.5, 0.5]; + let analytical = (6.0_f64).ln() / 1.25; + + let result = mmopt(&theta, &subject, eq, errormodel, 0, 1, weights)?; + + println!( + " Analytical optimum: t* = ln(6)/1.25 = {:.4} h", + analytical + ); + println!(" MMopt optimal time: t = {:.4} h", result.times[0]); + println!(" Bayes risk (overbound): {:.6}", result.risk); + + Ok(()) +} + +/// Paper Section 6: PK example with 10 support points +/// +/// Model: one-compartment, dx/dt = d(t) - K*x, y = x/V +/// Dose: 300 units infused over 1 hour (rate = 300/hr) +/// Error: σ = 0.1 (constant additive) +/// 10 support points from Table 6.1 with equal priors (p_i = 0.1) +/// Candidate times: 0.25 to 24.0 hours at 0.25-hour intervals +/// +/// Paper results (Table 6.2): +/// n=1: t* = {4.25}, Bayes Risk = 0.5474 +/// n=2: t* = {1.0, 9.5}, Bayes Risk = 0.2947 +/// n=3: t* = {1.0, 1.0, 10.5}, Bayes Risk = 0.2325 +fn section6() -> Result<()> { + println!("=== Section 6: PK example (10 support points, Table 6.1) ===\n"); + + let eq = one_comp_model(); + let params = Parameters::new().add("ke", 0.01, 0.2).add("v", 80.0, 120.0); + + // Table 6.1 support points [K, V] + let support_points: [(f64, f64); 10] = [ + (0.090088, 113.7451), + (0.111611, 93.4326), + (0.066074, 90.2832), + (0.108604, 89.2334), + (0.103047, 112.1093), + (0.033965, 94.3847), + (0.100859, 109.8633), + (0.023174, 111.7920), + (0.087041, 108.6670), + (0.095996, 100.3418), + ]; - // Find the optimal 2 sample times (out of 8 candidates) - println!("Finding optimal 2 sample times from 8 candidates...\n"); - let result = mmopt( + let mat = faer::Mat::from_fn(10, 2, |r, c| match c { + 0 => support_points[r].0, + 1 => support_points[r].1, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let errormodel = ErrorModel::additive(ErrorPoly::new(0.1, 0.0, 0.0, 0.0), 0.0); + + // 1-hour infusion of 300 units; candidate times 0.25 to 24h at 0.25h steps + let mut builder = Subject::builder("section6"); + builder = builder.infusion(0.0, 300.0, 0, 1.0); + for i in 1..=96 { + builder = builder.missing_observation(i as f64 * 0.25, 0); + } + let subject = builder.build(); + + let weights = vec![0.1; 10]; + + // --- 1-sample design --- + let r1 = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 1, + weights.clone(), + )?; + println!(" 1-sample design:"); + println!(" Paper: t* = {{4.25}}, Bayes Risk = 0.5474"); + println!( + " MMopt: t* = {{{:.2}}}, Bayes risk = {:.6}", + r1.times[0], r1.risk + ); + + // --- 2-sample design --- + let r2 = mmopt( &theta, &subject, eq.clone(), @@ -63,16 +159,26 @@ fn main() -> Result<()> { 2, weights.clone(), )?; - println!(" {}", result); + println!("\n 2-sample design:"); + println!(" Paper: t* = {{1.0, 9.5}}, Bayes Risk = 0.2947"); + println!( + " MMopt: t* = {{{:.2}, {:.2}}}, Bayes risk = {:.6}", + r2.times[0], r2.times[1], r2.risk + ); - // Compare with 3 samples - println!("\nFinding optimal 3 sample times...\n"); - let result_3 = mmopt(&theta, &subject, eq, errormodel, 0, 3, weights)?; - println!(" {}", result_3); + // --- 3-sample design --- + let r3 = mmopt(&theta, &subject, eq, errormodel, 0, 3, weights)?; + println!("\n 3-sample design:"); + println!(" Paper: t* = {{1.0, 1.0, 10.5}}, Bayes Risk = 0.2325"); + println!( + " MMopt: t* = {{{:.2}, {:.2}, {:.2}}}, Bayes risk = {:.6}", + r3.times[0], r3.times[1], r3.times[2], r3.risk + ); println!( - "\nRisk reduction from 2 → 3 samples: {:.2}%", - (1.0 - result_3.risk / result.risk) * 100.0 + "\n Risk reduction: 1→2 samples: {:.1}%, 2→3 samples: {:.1}%", + (1.0 - r2.risk / r1.risk) * 100.0, + (1.0 - r3.risk / r2.risk) * 100.0, ); Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 4b5358483..49f0f4d3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,6 @@ pub mod prelude { pub use crate::routines::settings::*; pub use crate::structs::*; - pub use crate::mmopt::*; pub mod simulator { diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs index e30b0b7a3..57d8c8a11 100644 --- a/src/mmopt/mod.rs +++ b/src/mmopt/mod.rs @@ -29,7 +29,7 @@ impl std::fmt::Display for MmoptResult { /// Perform multiple-model optimization to determine optimal sample times. /// /// This function evaluates all possible combinations of `nsamp` sample times -/// from the candidate times defined as observations in the `subject`, and returns +/// from the candidate times defined as [pharmsol::data::Observation]s in the [Subject], and returns /// the combination that minimizes the Bayes risk of misclassification between /// support points. /// @@ -85,19 +85,26 @@ pub fn mmopt( let predictions = theta .matrix() .row_iter() - .map(|theta_row| { + .enumerate() + .map(|(idx, theta_row)| { let support_point: Vec = theta_row.iter().cloned().collect(); let all_preds = equation .estimate_predictions(subject, &support_point) - .unwrap() + .map_err(|e| { + anyhow::anyhow!( + "Failed to generate predictions for support point {}: {}", + idx, + e + ) + })? .get_predictions(); // Filter predictions by output equation - all_preds + Ok(all_preds .into_iter() .filter(|p| p.outeq() == outeq) - .collect::>() + .collect::>()) }) - .collect::>(); + .collect::>>()?; if predictions[0].is_empty() { return Err(anyhow::anyhow!( @@ -117,6 +124,20 @@ pub fn mmopt( )); } + // Guard against combinatorial explosion + let n_combinations = n_choose_k(times.len(), nsamp); + const MAX_COMBINATIONS: u128 = 1_000_000; + if n_combinations > MAX_COMBINATIONS { + return Err(anyhow::anyhow!( + "C({}, {}) = {} exceeds the maximum allowed combinations ({}). \ + Reduce the number of candidate times or increase nsamp.", + times.len(), + nsamp, + n_combinations, + MAX_COMBINATIONS + )); + } + // Generate prediction matrix: rows = time points, cols = support points let pred_matrix = Mat::from_fn(predictions[0].len(), theta.nspp(), |i, j| { predictions[j][i].prediction() @@ -132,8 +153,8 @@ pub fn mmopt( let risk = calculate_risk(combo, &pred_matrix, &errormodel, &weights); (combo.clone(), risk) }) - .min_by(|(_, risk_a), (_, risk_b)| risk_a.partial_cmp(risk_b).unwrap()) - .unwrap(); + .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Greater)) + .ok_or_else(|| anyhow::anyhow!("No candidate combinations to evaluate"))?; let optimal_times = best_combo.iter().map(|&i| times[i]).collect(); Ok(MmoptResult { @@ -156,8 +177,7 @@ fn calculate_risk( let nspp = pred_matrix.ncols(); (0..nspp) - .flat_map(|i| (0..nspp).map(move |j| (i, j))) - .filter(|(i, j)| i != j) + .flat_map(|i| ((i + 1)..nspp).map(move |j| (i, j))) .map(|(i, j)| { // Extract predictions for the selected time points let i_obs: Vec = combo.iter().map(|&k| pred_matrix[(k, i)]).collect(); @@ -168,8 +188,8 @@ fn calculate_risk( .iter() .zip(j_obs.iter()) .map(|(&y_i, &y_j)| { - let i_var = errormodel.variance_from_value(y_i).unwrap(); - let j_var = errormodel.variance_from_value(y_j).unwrap(); + let i_var = errormodel.variance_from_value(y_i).unwrap_or(f64::EPSILON); + let j_var = errormodel.variance_from_value(y_j).unwrap_or(f64::EPSILON); let denominator = i_var + j_var; let term1 = (y_i - y_j).powi(2) / (4.0 * denominator); @@ -185,6 +205,17 @@ fn calculate_risk( .sum() } +/// Compute C(m, n) without overflow risk by using u128 arithmetic. +fn n_choose_k(m: usize, n: usize) -> u128 { + if n > m { + return 0; + } + // Use the smaller of n and m-n for efficiency + let k = n.min(m - n) as u128; + let m = m as u128; + (0..k).fold(1u128, |acc, i| acc * (m - i) / (i + 1)) +} + fn generate_combinations(m: usize, n: usize) -> Vec> { fn backtrack( m: usize, From f9f793ce7611421f0514b8a149b7724d6a65c4df Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:17:04 +0100 Subject: [PATCH 8/8] Add example of Bayes risk --- examples/bayes_risk.rs | 146 +++++++++++++++++++++++++++++++++++++++ examples/mmopt.rs | 16 +++-- src/mmopt/mod.rs | 150 +++++++++++++++++++++++++++-------------- tests/mmopt_tests.rs | 41 ++++++----- 4 files changed, 281 insertions(+), 72 deletions(-) create mode 100644 examples/bayes_risk.rs diff --git a/examples/bayes_risk.rs b/examples/bayes_risk.rs new file mode 100644 index 000000000..f919a62e1 --- /dev/null +++ b/examples/bayes_risk.rs @@ -0,0 +1,146 @@ +//! Example: Compute Bayes risk for a given sampling design +//! +//! Uses the same PK model and support points as Section 6 of Bayard & Neely (2017). +//! Instead of optimizing sample times, this calculates the Bayes risk for +//! user-specified observation times. + +use anyhow::Result; +use pmcore::mmopt::bayes_risk; +use pmcore::prelude::*; +use pmcore::structs::theta::Theta; +use pmcore::structs::weights::Weights; + +/// One-compartment model: dx/dt = -K*x + input, y = x/V +fn one_comp_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0] + rateiv[0]; + }, + |_p, _, _| lag! {}, + |_p, _, _| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ) +} + +fn main() -> Result<()> { + let eq = one_comp_model(); + let params = Parameters::new().add("ke", 0.01, 0.2).add("v", 80.0, 120.0); + + // Table 6.1 support points [K, V] + let support_points: [(f64, f64); 10] = [ + (0.090088, 113.7451), + (0.111611, 93.4326), + (0.066074, 90.2832), + (0.108604, 89.2334), + (0.103047, 112.1093), + (0.033965, 94.3847), + (0.100859, 109.8633), + (0.023174, 111.7920), + (0.087041, 108.6670), + (0.095996, 100.3418), + ]; + + let mat = faer::Mat::from_fn(10, 2, |r, c| match c { + 0 => support_points[r].0, + 1 => support_points[r].1, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let errormodel = ErrorModel::additive(ErrorPoly::new(0.1, 0.0, 0.0, 0.0), 0.0); + let weights = Weights::uniform(10); + + // --- Design A: Two observations at the MMopt-optimal times --- + let subject_a = Subject::builder("design_a") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(1.0, 0) + .missing_observation(9.5, 0) + .build(); + + let risk_a = bayes_risk( + &theta, + &subject_a, + eq.clone(), + errormodel.clone(), + 0, + &weights, + )?; + println!("Design A t = {{1.0, 9.5}} Bayes risk = {:.6}", risk_a); + + // --- Design B: Two observations at sub-optimal times --- + let subject_b = Subject::builder("design_b") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .build(); + + let risk_b = bayes_risk( + &theta, + &subject_b, + eq.clone(), + errormodel.clone(), + 0, + &weights, + )?; + println!( + "Design B t = {{2.0, 6.0}} Bayes risk = {:.6}", + risk_b + ); + + // --- Design C: B + one more sample --- + let subject_c = Subject::builder("design_c") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(12.0, 0) + .build(); + + let risk_c = bayes_risk( + &theta, + &subject_c, + eq.clone(), + errormodel.clone(), + 0, + &weights, + )?; + println!( + "Design C t = {{2.0, 6.0, 12.0}} Bayes risk = {:.6}", + risk_c + ); + + // --- Design D: C + one more sample --- + let subject_d = Subject::builder("design_d") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(12.0, 0) + .missing_observation(18.0, 0) + .build(); + + let risk_d = bayes_risk(&theta, &subject_d, eq, errormodel, 0, &weights)?; + println!( + "Design D t = {{2.0, 6.0, 12.0, 18.0}} Bayes risk = {:.6}", + risk_d + ); + + println!( + "\nDesign A vs B: {:.1}% lower risk with optimal times", + (1.0 - risk_a / risk_b) * 100.0 + ); + println!( + "B → C (add 1 sample): {:.1}% risk reduction", + (1.0 - risk_c / risk_b) * 100.0 + ); + println!( + "C → D (add 1 sample): {:.1}% risk reduction", + (1.0 - risk_d / risk_c) * 100.0 + ); + + Ok(()) +} diff --git a/examples/mmopt.rs b/examples/mmopt.rs index 1e1cc940d..0a4532173 100644 --- a/examples/mmopt.rs +++ b/examples/mmopt.rs @@ -1,11 +1,13 @@ //! Replication of the experiments in Bayard & Neely (2017) //! "Experiment Design for Nonparametric Models Based On Minimizing Bayes Risk" -//! J Pharmacokinet Pharmacodyn. 2017;44(2):95-111. PMCID: PMC5376526 +//! Journal of Pharmacokinetics and Pharmacodynamics. +//! https://doi.org/10.1007/s10928-016-9498-5 use anyhow::Result; use pmcore::mmopt::mmopt; use pmcore::prelude::*; use pmcore::structs::theta::Theta; +use pmcore::structs::weights::Weights; /// One-compartment model: dx/dt = -K*x + input, y = x/V fn one_comp_model() -> equation::ODE { @@ -66,10 +68,10 @@ fn section4() -> Result<()> { } let subject = builder.build(); - let weights = vec![0.5, 0.5]; + let weights = Weights::from_vec(vec![0.5, 0.5]); let analytical = (6.0_f64).ln() / 1.25; - let result = mmopt(&theta, &subject, eq, errormodel, 0, 1, weights)?; + let result = mmopt(&theta, &subject, eq, errormodel, 0, 1, &weights)?; println!( " Analytical optimum: t* = ln(6)/1.25 = {:.4} h", @@ -130,7 +132,7 @@ fn section6() -> Result<()> { } let subject = builder.build(); - let weights = vec![0.1; 10]; + let weights = Weights::uniform(10); // --- 1-sample design --- let r1 = mmopt( @@ -140,7 +142,7 @@ fn section6() -> Result<()> { errormodel.clone(), 0, 1, - weights.clone(), + &weights, )?; println!(" 1-sample design:"); println!(" Paper: t* = {{4.25}}, Bayes Risk = 0.5474"); @@ -157,7 +159,7 @@ fn section6() -> Result<()> { errormodel.clone(), 0, 2, - weights.clone(), + &weights, )?; println!("\n 2-sample design:"); println!(" Paper: t* = {{1.0, 9.5}}, Bayes Risk = 0.2947"); @@ -167,7 +169,7 @@ fn section6() -> Result<()> { ); // --- 3-sample design --- - let r3 = mmopt(&theta, &subject, eq, errormodel, 0, 3, weights)?; + let r3 = mmopt(&theta, &subject, eq, errormodel, 0, 3, &weights)?; println!("\n 3-sample design:"); println!(" Paper: t* = {{1.0, 1.0, 10.5}}, Bayes Risk = 0.2325"); println!( diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs index 57d8c8a11..b2e924966 100644 --- a/src/mmopt/mod.rs +++ b/src/mmopt/mod.rs @@ -4,6 +4,7 @@ use pharmsol::{Equation, ErrorModel, Predictions, Subject}; use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use crate::structs::theta::Theta; +use crate::structs::weights::Weights; /// The results of a multiple-model optimization /// @@ -26,6 +27,40 @@ impl std::fmt::Display for MmoptResult { } } +/// Compute the Bayes risk overbound for the observation times in the given subject. +/// +/// This evaluates the Bhattacharyya upper bound on the Bayes risk of misclassification +/// between support points, using all observation times defined in the subject. Unlike +/// [`mmopt`], this does not search over combinations — it simply scores the design +/// represented by the subject's current observations. +/// +/// # Arguments +/// * `theta` - Support points (population parameter distribution) +/// * `subject` - Subject whose observation times define the sampling design +/// * `equation` - The pharmacometric model equation +/// * `errormodel` - Error model for computing observation variance +/// * `outeq` - Output equation index to evaluate +/// * `weights` - Probability weights for each support point (must sum to ~1.0) +pub fn bayes_risk( + theta: &Theta, + subject: &Subject, + equation: impl Equation, + errormodel: ErrorModel, + outeq: usize, + weights: &Weights, +) -> Result { + let (pred_matrix, weights_vec, _) = + build_prediction_matrix(theta, subject, equation, outeq, weights)?; + + let all_indices: Vec = (0..pred_matrix.nrows()).collect(); + Ok(calculate_risk( + &all_indices, + &pred_matrix, + &errormodel, + &weights_vec, + )) +} + /// Perform multiple-model optimization to determine optimal sample times. /// /// This function evaluates all possible combinations of `nsamp` sample times @@ -55,9 +90,70 @@ pub fn mmopt( errormodel: ErrorModel, outeq: usize, nsamp: usize, - weights: Vec, + weights: &Weights, ) -> Result { - // Validate inputs + if nsamp == 0 { + return Err(anyhow::anyhow!("Number of samples must be at least 1")); + } + + let (pred_matrix, weights_vec, times) = + build_prediction_matrix(theta, subject, equation, outeq, weights)?; + + if nsamp > times.len() { + return Err(anyhow::anyhow!( + "Number of samples ({}) exceeds number of candidate times ({})", + nsamp, + times.len() + )); + } + + // Guard against combinatorial explosion + let n_combinations = n_choose_k(times.len(), nsamp); + const MAX_COMBINATIONS: u128 = 1_000_000; + if n_combinations > MAX_COMBINATIONS { + return Err(anyhow::anyhow!( + "C({}, {}) = {} exceeds the maximum allowed combinations ({}). \ + Reduce the number of candidate times or increase nsamp.", + times.len(), + nsamp, + n_combinations, + MAX_COMBINATIONS + )); + } + + // Generate all C(m, n) sample candidate index combinations + let candidate_indices = generate_combinations(times.len(), nsamp); + + // Evaluate risk in parallel for all combinations and select minimum + let (best_combo, min_risk) = candidate_indices + .par_iter() + .map(|combo| { + let risk = calculate_risk(combo, &pred_matrix, &errormodel, &weights_vec); + (combo.clone(), risk) + }) + .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Greater)) + .ok_or_else(|| anyhow::anyhow!("No candidate combinations to evaluate"))?; + + let optimal_times = best_combo.iter().map(|&i| times[i]).collect(); + Ok(MmoptResult { + times: optimal_times, + risk: min_risk, + }) +} + +/// Shared helper: validate inputs, generate predictions, and build the prediction matrix. +/// +/// Returns `(pred_matrix, weights_vec, times)` where: +/// - `pred_matrix` has rows = time points, cols = support points +/// - `weights_vec` is the weight vector extracted from `Weights` +/// - `times` is the vector of prediction times +fn build_prediction_matrix( + theta: &Theta, + subject: &Subject, + equation: impl Equation, + outeq: usize, + weights: &Weights, +) -> Result<(Mat, Vec, Vec)> { if subject.occasions().len() != 1 { return Err(anyhow::anyhow!("Subject must contain exactly one Occasion")); } @@ -77,9 +173,7 @@ pub fn mmopt( )); } - if nsamp == 0 { - return Err(anyhow::anyhow!("Number of samples must be at least 1")); - } + let weights_vec = weights.to_vec(); // Generate predictions for each support point let predictions = theta @@ -98,7 +192,6 @@ pub fn mmopt( ) })? .get_predictions(); - // Filter predictions by output equation Ok(all_preds .into_iter() .filter(|p| p.outeq() == outeq) @@ -113,54 +206,13 @@ pub fn mmopt( )); } - // Times vector from the first support point's predictions - let times = predictions[0].iter().map(|p| p.time()).collect::>(); + let times: Vec = predictions[0].iter().map(|p| p.time()).collect(); - if nsamp > times.len() { - return Err(anyhow::anyhow!( - "Number of samples ({}) exceeds number of candidate times ({})", - nsamp, - times.len() - )); - } - - // Guard against combinatorial explosion - let n_combinations = n_choose_k(times.len(), nsamp); - const MAX_COMBINATIONS: u128 = 1_000_000; - if n_combinations > MAX_COMBINATIONS { - return Err(anyhow::anyhow!( - "C({}, {}) = {} exceeds the maximum allowed combinations ({}). \ - Reduce the number of candidate times or increase nsamp.", - times.len(), - nsamp, - n_combinations, - MAX_COMBINATIONS - )); - } - - // Generate prediction matrix: rows = time points, cols = support points let pred_matrix = Mat::from_fn(predictions[0].len(), theta.nspp(), |i, j| { predictions[j][i].prediction() }); - // Generate all C(m, n) sample candidate index combinations - let candidate_indices = generate_combinations(times.len(), nsamp); - - // Evaluate risk in parallel for all combinations and select minimum - let (best_combo, min_risk) = candidate_indices - .par_iter() - .map(|combo| { - let risk = calculate_risk(combo, &pred_matrix, &errormodel, &weights); - (combo.clone(), risk) - }) - .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Greater)) - .ok_or_else(|| anyhow::anyhow!("No candidate combinations to evaluate"))?; - - let optimal_times = best_combo.iter().map(|&i| times[i]).collect(); - Ok(MmoptResult { - times: optimal_times, - risk: min_risk, - }) + Ok((pred_matrix, weights_vec, times)) } /// Calculate the Bayes risk for a specific combination of sample time indices. diff --git a/tests/mmopt_tests.rs b/tests/mmopt_tests.rs index e6caf20d8..3d6c090e0 100644 --- a/tests/mmopt_tests.rs +++ b/tests/mmopt_tests.rs @@ -2,6 +2,7 @@ use anyhow::Result; use pmcore::mmopt::{mmopt, MmoptResult}; use pmcore::prelude::*; use pmcore::structs::theta::Theta; +use pmcore::structs::weights::Weights; /// Helper to create a simple one-compartment model fn one_comp_model() -> equation::ODE { @@ -59,9 +60,9 @@ fn test_mmopt_basic() -> Result<()> { .build(); let errormodel = additive_error_model(); - let weights = vec![0.5, 0.5]; + let weights = Weights::from_vec(vec![0.5, 0.5]); - let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, weights)?; + let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, &weights)?; assert_eq!( result.times.len(), @@ -117,7 +118,7 @@ fn test_mmopt_more_samples_lower_risk() -> Result<()> { errormodel.clone(), 0, 2, - vec![0.5, 0.5], + &Weights::from_vec(vec![0.5, 0.5]), )?; let result_3 = mmopt( &theta, @@ -126,7 +127,7 @@ fn test_mmopt_more_samples_lower_risk() -> Result<()> { errormodel.clone(), 0, 3, - vec![0.5, 0.5], + &Weights::from_vec(vec![0.5, 0.5]), )?; assert!( @@ -165,9 +166,9 @@ fn test_mmopt_three_support_points() -> Result<()> { .build(); let errormodel = additive_error_model(); - let weights = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]; + let weights = Weights::from_vec(vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); - let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, weights)?; + let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, &weights)?; assert_eq!(result.times.len(), 2); assert!(result.risk >= 0.0); @@ -199,10 +200,10 @@ fn test_mmopt_all_samples() -> Result<()> { .build(); let errormodel = additive_error_model(); - let weights = vec![0.5, 0.5]; + let weights = Weights::from_vec(vec![0.5, 0.5]); // Select all 3 samples (only one combination) - let result = mmopt(&theta, &subject, eq, errormodel, 0, 3, weights)?; + let result = mmopt(&theta, &subject, eq, errormodel, 0, 3, &weights)?; assert_eq!(result.times.len(), 3); assert_eq!(result.times, vec![1.0, 4.0, 8.0]); @@ -243,7 +244,7 @@ fn test_mmopt_multiple_occasions_error() { additive_error_model(), 0, 1, - vec![0.5, 0.5], + &Weights::from_vec(vec![0.5, 0.5]), ); assert!(result.is_err()); } @@ -275,7 +276,7 @@ fn test_mmopt_single_support_point_error() { additive_error_model(), 0, 1, - vec![1.0], + &Weights::from_vec(vec![1.0]), ); assert!(result.is_err()); assert!(result @@ -313,7 +314,7 @@ fn test_mmopt_weights_mismatch_error() { additive_error_model(), 0, 1, - vec![0.33, 0.33, 0.34], + &Weights::from_vec(vec![0.33, 0.33, 0.34]), ); assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("Weights length")); @@ -346,7 +347,7 @@ fn test_mmopt_zero_samples_error() { additive_error_model(), 0, 0, - vec![0.5, 0.5], + &Weights::from_vec(vec![0.5, 0.5]), ); assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("at least 1")); @@ -381,7 +382,7 @@ fn test_mmopt_too_many_samples_error() { additive_error_model(), 0, 5, - vec![0.5, 0.5], + &Weights::from_vec(vec![0.5, 0.5]), ); assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("exceeds")); @@ -420,10 +421,18 @@ fn test_mmopt_unequal_weights() -> Result<()> { errormodel.clone(), 0, 2, - vec![0.5, 0.5], + &Weights::from_vec(vec![0.5, 0.5]), )?; - let result_skewed = mmopt(&theta, &subject, eq, errormodel, 0, 2, vec![0.9, 0.1])?; + let result_skewed = mmopt( + &theta, + &subject, + eq, + errormodel, + 0, + 2, + &Weights::from_vec(vec![0.9, 0.1]), + )?; // Different weights should generally produce different risks // (or at least both should be valid) @@ -480,7 +489,7 @@ fn test_mmopt_single_sample() -> Result<()> { additive_error_model(), 0, 1, - vec![0.5, 0.5], + &Weights::from_vec(vec![0.5, 0.5]), )?; assert_eq!(result.times.len(), 1);