diff --git a/examples/bayes_risk.rs b/examples/bayes_risk.rs new file mode 100644 index 000000000..f919a62e1 --- /dev/null +++ b/examples/bayes_risk.rs @@ -0,0 +1,146 @@ +//! Example: Compute Bayes risk for a given sampling design +//! +//! Uses the same PK model and support points as Section 6 of Bayard & Neely (2017). +//! Instead of optimizing sample times, this calculates the Bayes risk for +//! user-specified observation times. + +use anyhow::Result; +use pmcore::mmopt::bayes_risk; +use pmcore::prelude::*; +use pmcore::structs::theta::Theta; +use pmcore::structs::weights::Weights; + +/// One-compartment model: dx/dt = -K*x + input, y = x/V +fn one_comp_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0] + rateiv[0]; + }, + |_p, _, _| lag! {}, + |_p, _, _| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ) +} + +fn main() -> Result<()> { + let eq = one_comp_model(); + let params = Parameters::new().add("ke", 0.01, 0.2).add("v", 80.0, 120.0); + + // Table 6.1 support points [K, V] + let support_points: [(f64, f64); 10] = [ + (0.090088, 113.7451), + (0.111611, 93.4326), + (0.066074, 90.2832), + (0.108604, 89.2334), + (0.103047, 112.1093), + (0.033965, 94.3847), + (0.100859, 109.8633), + (0.023174, 111.7920), + (0.087041, 108.6670), + (0.095996, 100.3418), + ]; + + let mat = faer::Mat::from_fn(10, 2, |r, c| match c { + 0 => support_points[r].0, + 1 => support_points[r].1, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let errormodel = ErrorModel::additive(ErrorPoly::new(0.1, 0.0, 0.0, 0.0), 0.0); + let weights = Weights::uniform(10); + + // --- Design A: Two observations at the MMopt-optimal times --- + let subject_a = Subject::builder("design_a") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(1.0, 0) + .missing_observation(9.5, 0) + .build(); + + let risk_a = bayes_risk( + &theta, + &subject_a, + eq.clone(), + errormodel.clone(), + 0, + &weights, + )?; + println!("Design A t = {{1.0, 9.5}} Bayes risk = {:.6}", risk_a); + + // --- Design B: Two observations at sub-optimal times --- + let subject_b = Subject::builder("design_b") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .build(); + + let risk_b = bayes_risk( + &theta, + &subject_b, + eq.clone(), + errormodel.clone(), + 0, + &weights, + )?; + println!( + "Design B t = {{2.0, 6.0}} Bayes risk = {:.6}", + risk_b + ); + + // --- Design C: B + one more sample --- + let subject_c = Subject::builder("design_c") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(12.0, 0) + .build(); + + let risk_c = bayes_risk( + &theta, + &subject_c, + eq.clone(), + errormodel.clone(), + 0, + &weights, + )?; + println!( + "Design C t = {{2.0, 6.0, 12.0}} Bayes risk = {:.6}", + risk_c + ); + + // --- Design D: C + one more sample --- + let subject_d = Subject::builder("design_d") + .infusion(0.0, 300.0, 0, 1.0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(12.0, 0) + .missing_observation(18.0, 0) + .build(); + + let risk_d = bayes_risk(&theta, &subject_d, eq, errormodel, 0, &weights)?; + println!( + "Design D t = {{2.0, 6.0, 12.0, 18.0}} Bayes risk = {:.6}", + risk_d + ); + + println!( + "\nDesign A vs B: {:.1}% lower risk with optimal times", + (1.0 - risk_a / risk_b) * 100.0 + ); + println!( + "B → C (add 1 sample): {:.1}% risk reduction", + (1.0 - risk_c / risk_b) * 100.0 + ); + println!( + "C → D (add 1 sample): {:.1}% risk reduction", + (1.0 - risk_d / risk_c) * 100.0 + ); + + Ok(()) +} diff --git a/examples/mmopt.rs b/examples/mmopt.rs new file mode 100644 index 000000000..0a4532173 --- /dev/null +++ b/examples/mmopt.rs @@ -0,0 +1,187 @@ +//! Replication of the experiments in Bayard & Neely (2017) +//! "Experiment Design for Nonparametric Models Based On Minimizing Bayes Risk" +//! Journal of Pharmacokinetics and Pharmacodynamics. +//! https://doi.org/10.1007/s10928-016-9498-5 + +use anyhow::Result; +use pmcore::mmopt::mmopt; +use pmcore::prelude::*; +use pmcore::structs::theta::Theta; +use pmcore::structs::weights::Weights; + +/// One-compartment model: dx/dt = -K*x + input, y = x/V +fn one_comp_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0] + rateiv[0]; + }, + |_p, _, _| lag! {}, + |_p, _, _| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ) +} + +fn main() -> Result<()> { + section4()?; + println!(); + section6()?; + Ok(()) +} + +/// Paper Section 4: Two-support-point exponential decay example +/// +/// Model: μ(t,a) = e^{-at} (implemented as 1-compartment with D=V=1) +/// Support points: a1 = 1.5 (fast), a2 = 0.25 (slow) +/// Uniform priors: p1 = p2 = 0.5 +/// Error: σ = 0.3 (constant additive) +/// Candidate times: 0.1 to 5.0 hours at 0.1-hour intervals +/// +/// Analytical optimum: t* = ln(6)/1.25 ≈ 1.4334 hours +fn section4() -> Result<()> { + println!("=== Section 4: Two-support-point example ===\n"); + + let eq = one_comp_model(); + let params = Parameters::new().add("ke", 0.1, 5.0).add("v", 0.5, 2.0); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 1.5, // a1 (fast) + (0, 1) => 1.0, // V = 1 + (1, 0) => 0.25, // a2 (slow) + (1, 1) => 1.0, // V = 1 + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let errormodel = ErrorModel::additive(ErrorPoly::new(0.3, 0.0, 0.0, 0.0), 0.0); + + // Candidate times: 0.1 to 5.0 at 0.1h steps + let mut builder = Subject::builder("section4"); + builder = builder.bolus(0.0, 1.0, 0); + for i in 1..=50 { + builder = builder.missing_observation(i as f64 * 0.1, 0); + } + let subject = builder.build(); + + let weights = Weights::from_vec(vec![0.5, 0.5]); + let analytical = (6.0_f64).ln() / 1.25; + + let result = mmopt(&theta, &subject, eq, errormodel, 0, 1, &weights)?; + + println!( + " Analytical optimum: t* = ln(6)/1.25 = {:.4} h", + analytical + ); + println!(" MMopt optimal time: t = {:.4} h", result.times[0]); + println!(" Bayes risk (overbound): {:.6}", result.risk); + + Ok(()) +} + +/// Paper Section 6: PK example with 10 support points +/// +/// Model: one-compartment, dx/dt = d(t) - K*x, y = x/V +/// Dose: 300 units infused over 1 hour (rate = 300/hr) +/// Error: σ = 0.1 (constant additive) +/// 10 support points from Table 6.1 with equal priors (p_i = 0.1) +/// Candidate times: 0.25 to 24.0 hours at 0.25-hour intervals +/// +/// Paper results (Table 6.2): +/// n=1: t* = {4.25}, Bayes Risk = 0.5474 +/// n=2: t* = {1.0, 9.5}, Bayes Risk = 0.2947 +/// n=3: t* = {1.0, 1.0, 10.5}, Bayes Risk = 0.2325 +fn section6() -> Result<()> { + println!("=== Section 6: PK example (10 support points, Table 6.1) ===\n"); + + let eq = one_comp_model(); + let params = Parameters::new().add("ke", 0.01, 0.2).add("v", 80.0, 120.0); + + // Table 6.1 support points [K, V] + let support_points: [(f64, f64); 10] = [ + (0.090088, 113.7451), + (0.111611, 93.4326), + (0.066074, 90.2832), + (0.108604, 89.2334), + (0.103047, 112.1093), + (0.033965, 94.3847), + (0.100859, 109.8633), + (0.023174, 111.7920), + (0.087041, 108.6670), + (0.095996, 100.3418), + ]; + + let mat = faer::Mat::from_fn(10, 2, |r, c| match c { + 0 => support_points[r].0, + 1 => support_points[r].1, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let errormodel = ErrorModel::additive(ErrorPoly::new(0.1, 0.0, 0.0, 0.0), 0.0); + + // 1-hour infusion of 300 units; candidate times 0.25 to 24h at 0.25h steps + let mut builder = Subject::builder("section6"); + builder = builder.infusion(0.0, 300.0, 0, 1.0); + for i in 1..=96 { + builder = builder.missing_observation(i as f64 * 0.25, 0); + } + let subject = builder.build(); + + let weights = Weights::uniform(10); + + // --- 1-sample design --- + let r1 = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 1, + &weights, + )?; + println!(" 1-sample design:"); + println!(" Paper: t* = {{4.25}}, Bayes Risk = 0.5474"); + println!( + " MMopt: t* = {{{:.2}}}, Bayes risk = {:.6}", + r1.times[0], r1.risk + ); + + // --- 2-sample design --- + let r2 = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 2, + &weights, + )?; + println!("\n 2-sample design:"); + println!(" Paper: t* = {{1.0, 9.5}}, Bayes Risk = 0.2947"); + println!( + " MMopt: t* = {{{:.2}, {:.2}}}, Bayes risk = {:.6}", + r2.times[0], r2.times[1], r2.risk + ); + + // --- 3-sample design --- + let r3 = mmopt(&theta, &subject, eq, errormodel, 0, 3, &weights)?; + println!("\n 3-sample design:"); + println!(" Paper: t* = {{1.0, 1.0, 10.5}}, Bayes Risk = 0.2325"); + println!( + " MMopt: t* = {{{:.2}, {:.2}, {:.2}}}, Bayes risk = {:.6}", + r3.times[0], r3.times[1], r3.times[2], r3.risk + ); + + println!( + "\n Risk reduction: 1→2 samples: {:.1}%, 2→3 samples: {:.1}%", + (1.0 - r2.risk / r1.risk) * 100.0, + (1.0 - r3.risk / r2.risk) * 100.0, + ); + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 41cbc9af8..49f0f4d3d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,9 @@ pub mod routines; // Structures pub mod structs; +// MMopt +pub mod mmopt; + // Re-export commonly used items pub use anyhow::Result; pub use std::collections::HashMap; @@ -46,6 +49,8 @@ pub mod prelude { pub use crate::routines::settings::*; pub use crate::structs::*; + pub use crate::mmopt::*; + pub mod simulator { pub use pharmsol::prelude::simulator::*; } diff --git a/src/mmopt/mod.rs b/src/mmopt/mod.rs new file mode 100644 index 000000000..b2e924966 --- /dev/null +++ b/src/mmopt/mod.rs @@ -0,0 +1,502 @@ +use anyhow::{Ok, Result}; +use faer::Mat; +use pharmsol::{Equation, ErrorModel, Predictions, Subject}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +use crate::structs::theta::Theta; +use crate::structs::weights::Weights; + +/// The results of a multiple-model optimization +/// +/// Contains the optimal sample times and the associated Bayes risk. +#[derive(Debug, Clone)] +pub struct MmoptResult { + /// Optimal sample times + pub times: Vec, + /// Bayes risk at the optimal sample times + pub risk: f64, +} + +impl std::fmt::Display for MmoptResult { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Optimal times: {:?}, Bayes risk: {:.6}", + self.times, self.risk + ) + } +} + +/// Compute the Bayes risk overbound for the observation times in the given subject. +/// +/// This evaluates the Bhattacharyya upper bound on the Bayes risk of misclassification +/// between support points, using all observation times defined in the subject. Unlike +/// [`mmopt`], this does not search over combinations — it simply scores the design +/// represented by the subject's current observations. +/// +/// # Arguments +/// * `theta` - Support points (population parameter distribution) +/// * `subject` - Subject whose observation times define the sampling design +/// * `equation` - The pharmacometric model equation +/// * `errormodel` - Error model for computing observation variance +/// * `outeq` - Output equation index to evaluate +/// * `weights` - Probability weights for each support point (must sum to ~1.0) +pub fn bayes_risk( + theta: &Theta, + subject: &Subject, + equation: impl Equation, + errormodel: ErrorModel, + outeq: usize, + weights: &Weights, +) -> Result { + let (pred_matrix, weights_vec, _) = + build_prediction_matrix(theta, subject, equation, outeq, weights)?; + + let all_indices: Vec = (0..pred_matrix.nrows()).collect(); + Ok(calculate_risk( + &all_indices, + &pred_matrix, + &errormodel, + &weights_vec, + )) +} + +/// Perform multiple-model optimization to determine optimal sample times. +/// +/// This function evaluates all possible combinations of `nsamp` sample times +/// from the candidate times defined as [pharmsol::data::Observation]s in the [Subject], and returns +/// the combination that minimizes the Bayes risk of misclassification between +/// support points. +/// +/// # Arguments +/// * `theta` - Support points (population parameter distribution) +/// * `subject` - Subject with candidate observation times (must have exactly one occasion) +/// * `equation` - The pharmacometric model equation +/// * `errormodel` - Error model for computing observation variance +/// * `outeq` - Output equation index to optimize for +/// * `nsamp` - Number of samples to select +/// * `weights` - Probability weights for each support point (must sum to ~1.0) +/// +/// # Errors +/// Returns an error if: +/// - The subject has more than one occasion +/// - The number of support points is less than 2 +/// - The weights length doesn't match the number of support points +/// - `nsamp` is 0 or exceeds the number of candidate times +pub fn mmopt( + theta: &Theta, + subject: &Subject, + equation: impl Equation, + errormodel: ErrorModel, + outeq: usize, + nsamp: usize, + weights: &Weights, +) -> Result { + if nsamp == 0 { + return Err(anyhow::anyhow!("Number of samples must be at least 1")); + } + + let (pred_matrix, weights_vec, times) = + build_prediction_matrix(theta, subject, equation, outeq, weights)?; + + if nsamp > times.len() { + return Err(anyhow::anyhow!( + "Number of samples ({}) exceeds number of candidate times ({})", + nsamp, + times.len() + )); + } + + // Guard against combinatorial explosion + let n_combinations = n_choose_k(times.len(), nsamp); + const MAX_COMBINATIONS: u128 = 1_000_000; + if n_combinations > MAX_COMBINATIONS { + return Err(anyhow::anyhow!( + "C({}, {}) = {} exceeds the maximum allowed combinations ({}). \ + Reduce the number of candidate times or increase nsamp.", + times.len(), + nsamp, + n_combinations, + MAX_COMBINATIONS + )); + } + + // Generate all C(m, n) sample candidate index combinations + let candidate_indices = generate_combinations(times.len(), nsamp); + + // Evaluate risk in parallel for all combinations and select minimum + let (best_combo, min_risk) = candidate_indices + .par_iter() + .map(|combo| { + let risk = calculate_risk(combo, &pred_matrix, &errormodel, &weights_vec); + (combo.clone(), risk) + }) + .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Greater)) + .ok_or_else(|| anyhow::anyhow!("No candidate combinations to evaluate"))?; + + let optimal_times = best_combo.iter().map(|&i| times[i]).collect(); + Ok(MmoptResult { + times: optimal_times, + risk: min_risk, + }) +} + +/// Shared helper: validate inputs, generate predictions, and build the prediction matrix. +/// +/// Returns `(pred_matrix, weights_vec, times)` where: +/// - `pred_matrix` has rows = time points, cols = support points +/// - `weights_vec` is the weight vector extracted from `Weights` +/// - `times` is the vector of prediction times +fn build_prediction_matrix( + theta: &Theta, + subject: &Subject, + equation: impl Equation, + outeq: usize, + weights: &Weights, +) -> Result<(Mat, Vec, Vec)> { + if subject.occasions().len() != 1 { + return Err(anyhow::anyhow!("Subject must contain exactly one Occasion")); + } + + if theta.nspp() < 2 { + return Err(anyhow::anyhow!( + "At least 2 support points are required, got {}", + theta.nspp() + )); + } + + if weights.len() != theta.nspp() { + return Err(anyhow::anyhow!( + "Weights length ({}) must match number of support points ({})", + weights.len(), + theta.nspp() + )); + } + + let weights_vec = weights.to_vec(); + + // Generate predictions for each support point + let predictions = theta + .matrix() + .row_iter() + .enumerate() + .map(|(idx, theta_row)| { + let support_point: Vec = theta_row.iter().cloned().collect(); + let all_preds = equation + .estimate_predictions(subject, &support_point) + .map_err(|e| { + anyhow::anyhow!( + "Failed to generate predictions for support point {}: {}", + idx, + e + ) + })? + .get_predictions(); + Ok(all_preds + .into_iter() + .filter(|p| p.outeq() == outeq) + .collect::>()) + }) + .collect::>>()?; + + if predictions[0].is_empty() { + return Err(anyhow::anyhow!( + "No predictions found for output equation {}", + outeq + )); + } + + let times: Vec = predictions[0].iter().map(|p| p.time()).collect(); + + let pred_matrix = Mat::from_fn(predictions[0].len(), theta.nspp(), |i, j| { + predictions[j][i].prediction() + }); + + Ok((pred_matrix, weights_vec, times)) +} + +/// Calculate the Bayes risk for a specific combination of sample time indices. +/// +/// The risk quantifies the expected misclassification probability between support +/// points, weighted by their probabilities. Lower risk means the selected sample +/// times provide better discrimination between support points. +fn calculate_risk( + combo: &[usize], + pred_matrix: &Mat, + errormodel: &ErrorModel, + weights: &[f64], +) -> f64 { + let nspp = pred_matrix.ncols(); + + (0..nspp) + .flat_map(|i| ((i + 1)..nspp).map(move |j| (i, j))) + .map(|(i, j)| { + // Extract predictions for the selected time points + let i_obs: Vec = combo.iter().map(|&k| pred_matrix[(k, i)]).collect(); + let j_obs: Vec = combo.iter().map(|&k| pred_matrix[(k, j)]).collect(); + + // Calculate the sum of log-likelihood discrimination terms + let sum_k_ijn: f64 = i_obs + .iter() + .zip(j_obs.iter()) + .map(|(&y_i, &y_j)| { + let i_var = errormodel.variance_from_value(y_i).unwrap_or(f64::EPSILON); + let j_var = errormodel.variance_from_value(y_j).unwrap_or(f64::EPSILON); + let denominator = i_var + j_var; + + let term1 = (y_i - y_j).powi(2) / (4.0 * denominator); + let term2 = 0.5 * (denominator / 2.0).ln(); + let term3 = -0.25 * (i_var * j_var).ln(); + + term1 + term2 + term3 + }) + .sum(); + + weights[i] * weights[j] * (-sum_k_ijn).exp() + }) + .sum() +} + +/// Compute C(m, n) without overflow risk by using u128 arithmetic. +fn n_choose_k(m: usize, n: usize) -> u128 { + if n > m { + return 0; + } + // Use the smaller of n and m-n for efficiency + let k = n.min(m - n) as u128; + let m = m as u128; + (0..k).fold(1u128, |acc, i| acc * (m - i) / (i + 1)) +} + +fn generate_combinations(m: usize, n: usize) -> Vec> { + fn backtrack( + m: usize, + n: usize, + start: usize, + current: &mut Vec, + results: &mut Vec>, + ) { + if current.len() == n { + results.push(current.clone()); + return; + } + + for i in start..m { + current.push(i); + backtrack(m, n, i + 1, current, results); + current.pop(); + } + } + + let mut results = Vec::new(); + let mut current = Vec::new(); + backtrack(m, n, 0, &mut current, &mut results); + results +} + +#[cfg(test)] +mod tests { + use super::*; + use faer::Mat; + use pharmsol::ErrorPoly; + + #[test] + fn test_combinations() { + let m = 5; + let n = 3; + let combinations = generate_combinations(m, n); + assert_eq!(combinations.len(), 10); + assert_eq!(combinations[0], vec![0, 1, 2]); + assert_eq!(combinations[1], vec![0, 1, 3]); + assert_eq!(combinations[2], vec![0, 1, 4]); + assert_eq!(combinations[3], vec![0, 2, 3]); + assert_eq!(combinations[4], vec![0, 2, 4]); + assert_eq!(combinations[5], vec![0, 3, 4]); + assert_eq!(combinations[6], vec![1, 2, 3]); + assert_eq!(combinations[7], vec![1, 2, 4]); + assert_eq!(combinations[8], vec![1, 3, 4]); + assert_eq!(combinations[9], vec![2, 3, 4]); + } + + #[test] + fn test_combinations_edge_cases() { + // Select all elements + let combinations = generate_combinations(3, 3); + assert_eq!(combinations.len(), 1); + assert_eq!(combinations[0], vec![0, 1, 2]); + + // Select 1 element + let combinations = generate_combinations(4, 1); + assert_eq!(combinations.len(), 4); + assert_eq!(combinations[0], vec![0]); + assert_eq!(combinations[3], vec![3]); + + // C(6, 2) = 15 + let combinations = generate_combinations(6, 2); + assert_eq!(combinations.len(), 15); + } + + fn make_error_model() -> ErrorModel { + ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 0.0) + } + + #[test] + fn test_calculate_risk_identical_predictions() { + // When two support points have identical predictions, misclassification risk is maximal + let errormodel = make_error_model(); + let weights = vec![0.5, 0.5]; + + // pred_matrix: 3 time points, 2 support points with identical predictions + let pred_matrix = Mat::from_fn(3, 2, |i, _j| match i { + 0 => 10.0, + 1 => 5.0, + 2 => 2.0, + _ => 0.0, + }); + + let combo = vec![0, 1, 2]; + let risk_identical = calculate_risk(&combo, &pred_matrix, &errormodel, &weights); + + // Now different predictions + let pred_matrix_diff = Mat::from_fn(3, 2, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 50.0, + (1, 0) => 5.0, + (1, 1) => 25.0, + (2, 0) => 2.0, + (2, 1) => 12.0, + _ => 0.0, + }); + + let risk_different = calculate_risk(&combo, &pred_matrix_diff, &errormodel, &weights); + + // Identical predictions should have higher risk (harder to discriminate) + assert!( + risk_identical > risk_different, + "Identical predictions should have higher risk: {} vs {}", + risk_identical, + risk_different + ); + } + + #[test] + fn test_calculate_risk_symmetric_weights() { + // Risk should be symmetric when weights are equal + let errormodel = make_error_model(); + let weights = vec![0.5, 0.5]; + + let pred_matrix = Mat::from_fn(3, 2, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 20.0, + (1, 0) => 5.0, + (1, 1) => 10.0, + (2, 0) => 2.0, + (2, 1) => 4.0, + _ => 0.0, + }); + + let combo = vec![0, 1, 2]; + let risk = calculate_risk(&combo, &pred_matrix, &errormodel, &weights); + assert!(risk > 0.0, "Risk should be positive"); + assert!(risk.is_finite(), "Risk should be finite"); + } + + #[test] + fn test_calculate_risk_more_samples_lower_risk() { + // Using more sample times should generally not increase (and usually decrease) risk + let errormodel = make_error_model(); + let weights = vec![0.5, 0.5]; + + // 4 time points with very different predictions + let pred_matrix = Mat::from_fn(4, 2, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 20.0, + (1, 0) => 5.0, + (1, 1) => 15.0, + (2, 0) => 2.0, + (2, 1) => 8.0, + (3, 0) => 1.0, + (3, 1) => 6.0, + _ => 0.0, + }); + + // Best 2-sample combo risk + let combos_2 = generate_combinations(4, 2); + let min_risk_2 = combos_2 + .iter() + .map(|combo| calculate_risk(combo, &pred_matrix, &errormodel, &weights)) + .fold(f64::INFINITY, f64::min); + + // Best 3-sample combo risk + let combos_3 = generate_combinations(4, 3); + let min_risk_3 = combos_3 + .iter() + .map(|combo| calculate_risk(combo, &pred_matrix, &errormodel, &weights)) + .fold(f64::INFINITY, f64::min); + + assert!( + min_risk_3 <= min_risk_2, + "More samples should yield equal or lower risk: {} vs {}", + min_risk_3, + min_risk_2 + ); + } + + #[test] + fn test_calculate_risk_zero_weight() { + // Setting a weight to zero should eliminate that support point's contribution + let errormodel = make_error_model(); + + let pred_matrix = Mat::from_fn(2, 3, |i, j| match (i, j) { + (0, 0) => 10.0, + (0, 1) => 20.0, + (0, 2) => 30.0, + (1, 0) => 5.0, + (1, 1) => 10.0, + (1, 2) => 15.0, + _ => 0.0, + }); + + let combo = vec![0, 1]; + + // With all weights + let weights_all = vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]; + let risk_all = calculate_risk(&combo, &pred_matrix, &errormodel, &weights_all); + + // Zero out one weight + let weights_zero = vec![0.5, 0.5, 0.0]; + let risk_zero = calculate_risk(&combo, &pred_matrix, &errormodel, &weights_zero); + + // Risks should differ when a weight is zeroed + assert!( + (risk_all - risk_zero).abs() > 1e-10, + "Zeroing a weight should change the risk" + ); + } + + #[test] + fn test_risk_positive_and_finite() { + let errormodel = make_error_model(); + let weights = vec![0.25, 0.25, 0.25, 0.25]; + + let pred_matrix = Mat::from_fn(5, 4, |i, j| (i as f64 + 1.0) * (j as f64 + 1.0) * 2.0); + + let combo = vec![0, 2, 4]; + let risk = calculate_risk(&combo, &pred_matrix, &errormodel, &weights); + assert!(risk >= 0.0, "Risk must be non-negative"); + assert!(risk.is_finite(), "Risk must be finite"); + } + + #[test] + fn test_mmopt_result_display() { + let result = MmoptResult { + times: vec![1.0, 4.0, 8.0], + risk: 0.123456, + }; + let display = format!("{}", result); + assert!(display.contains("1.0")); + assert!(display.contains("4.0")); + assert!(display.contains("8.0")); + assert!(display.contains("0.123456")); + } +} diff --git a/tests/mmopt_tests.rs b/tests/mmopt_tests.rs new file mode 100644 index 000000000..3d6c090e0 --- /dev/null +++ b/tests/mmopt_tests.rs @@ -0,0 +1,500 @@ +use anyhow::Result; +use pmcore::mmopt::{mmopt, MmoptResult}; +use pmcore::prelude::*; +use pmcore::structs::theta::Theta; +use pmcore::structs::weights::Weights; + +/// Helper to create a simple one-compartment model +fn one_comp_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, b, _rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = -ke * x[0] + b[0]; + }, + |_p, _, _| lag! {}, + |_p, _, _| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + (1, 1), + ) +} + +/// Helper to create a simple error model +fn additive_error_model() -> ErrorModel { + ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 0.0) +} + +/// Helper to create parameters for the one-compartment model +fn one_comp_params() -> Parameters { + Parameters::new().add("ke", 0.1, 1.0).add("v", 30.0, 100.0) +} + +/// Test basic mmopt functionality with a simple one-compartment model +#[test] +fn test_mmopt_basic() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + // Two support points with different PK parameters + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, // ke for spp1 + (0, 1) => 50.0, // v for spp1 + (1, 0) => 0.8, // ke for spp2 + (1, 1) => 80.0, // v for spp2 + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + // Subject with candidate observation times after a bolus dose + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + let weights = Weights::from_vec(vec![0.5, 0.5]); + + let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, &weights)?; + + assert_eq!( + result.times.len(), + 2, + "Should select exactly 2 sample times" + ); + assert!(result.risk >= 0.0, "Risk must be non-negative"); + assert!(result.risk.is_finite(), "Risk must be finite"); + + // All selected times should be from the candidate set + let candidate_times = vec![1.0, 2.0, 4.0, 6.0, 8.0, 12.0]; + for t in &result.times { + assert!( + candidate_times.contains(t), + "Selected time {} is not in the candidate set", + t + ); + } + + Ok(()) +} + +/// Test that selecting more samples results in equal or lower risk +#[test] +fn test_mmopt_more_samples_lower_risk() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 50.0, + (1, 0) => 0.7, + (1, 1) => 70.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + + let result_2 = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 2, + &Weights::from_vec(vec![0.5, 0.5]), + )?; + let result_3 = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 3, + &Weights::from_vec(vec![0.5, 0.5]), + )?; + + assert!( + result_3.risk <= result_2.risk + 1e-10, + "More samples should yield lower or equal risk: {} vs {}", + result_3.risk, + result_2.risk + ); + + Ok(()) +} + +/// Test mmopt with three support points +#[test] +fn test_mmopt_three_support_points() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(3, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 40.0, + (1, 0) => 0.5, + (1, 1) => 60.0, + (2, 0) => 0.9, + (2, 1) => 90.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(3.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + let weights = Weights::from_vec(vec![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0]); + + let result = mmopt(&theta, &subject, eq, errormodel, 0, 2, &weights)?; + + assert_eq!(result.times.len(), 2); + assert!(result.risk >= 0.0); + assert!(result.risk.is_finite()); + + Ok(()) +} + +/// Test that mmopt with all candidate times produces the lowest possible risk +#[test] +fn test_mmopt_all_samples() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + let weights = Weights::from_vec(vec![0.5, 0.5]); + + // Select all 3 samples (only one combination) + let result = mmopt(&theta, &subject, eq, errormodel, 0, 3, &weights)?; + + assert_eq!(result.times.len(), 3); + assert_eq!(result.times, vec![1.0, 4.0, 8.0]); + + Ok(()) +} + +/// Test validation: subject with multiple occasions should fail +#[test] +fn test_mmopt_multiple_occasions_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + // Subject with two occasions + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .repeat(1, 24.0) + .bolus(24.0, 100.0, 0) + .observation(25.0, 0.0, 0) + .build(); + + // Only proceed if the subject actually has multiple occasions + if subject.occasions().len() > 1 { + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + &Weights::from_vec(vec![0.5, 0.5]), + ); + assert!(result.is_err()); + } +} + +/// Test validation: fewer than 2 support points should fail +#[test] +fn test_mmopt_single_support_point_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(1, 2, |_r, c| match c { + 0 => 0.3, + 1 => 50.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + &Weights::from_vec(vec![1.0]), + ); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("At least 2 support points")); +} + +/// Test validation: weights length mismatch should fail +#[test] +fn test_mmopt_weights_mismatch_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // 3 weights for 2 support points + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + &Weights::from_vec(vec![0.33, 0.33, 0.34]), + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("Weights length")); +} + +/// Test validation: nsamp = 0 should fail +#[test] +fn test_mmopt_zero_samples_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .build(); + + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 0, + &Weights::from_vec(vec![0.5, 0.5]), + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("at least 1")); +} + +/// Test validation: nsamp exceeds candidate times should fail +#[test] +fn test_mmopt_too_many_samples_error() { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.3, + (0, 1) => 50.0, + (1, 0) => 0.6, + (1, 1) => 75.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params).unwrap(); + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .build(); + + // Request 5 samples but only 2 candidate times + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 5, + &Weights::from_vec(vec![0.5, 0.5]), + ); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("exceeds")); +} + +/// Test that unequal weights influence the optimal sampling design +#[test] +fn test_mmopt_unequal_weights() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 50.0, + (1, 0) => 0.8, + (1, 1) => 80.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .observation(12.0, 0.0, 0) + .build(); + + let errormodel = additive_error_model(); + + let result_equal = mmopt( + &theta, + &subject, + eq.clone(), + errormodel.clone(), + 0, + 2, + &Weights::from_vec(vec![0.5, 0.5]), + )?; + + let result_skewed = mmopt( + &theta, + &subject, + eq, + errormodel, + 0, + 2, + &Weights::from_vec(vec![0.9, 0.1]), + )?; + + // Different weights should generally produce different risks + // (or at least both should be valid) + assert!(result_equal.risk.is_finite()); + assert!(result_skewed.risk.is_finite()); + assert!(result_equal.risk >= 0.0); + assert!(result_skewed.risk >= 0.0); + + Ok(()) +} + +/// Test MmoptResult Display implementation +#[test] +fn test_mmopt_result_display() { + let result = MmoptResult { + times: vec![2.0, 6.0, 12.0], + risk: 0.042, + }; + let display = format!("{}", result); + assert!(display.contains("2.0")); + assert!(display.contains("6.0")); + assert!(display.contains("12.0")); + assert!(display.contains("0.042")); +} + +/// Test with a single sample selection +#[test] +fn test_mmopt_single_sample() -> Result<()> { + let eq = one_comp_model(); + let params = one_comp_params(); + + let mat = faer::Mat::from_fn(2, 2, |r, c| match (r, c) { + (0, 0) => 0.2, + (0, 1) => 40.0, + (1, 0) => 0.9, + (1, 1) => 90.0, + _ => 0.0, + }); + let theta = Theta::from_parts(mat, params)?; + + let subject = Subject::builder("test") + .bolus(0.0, 100.0, 0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(4.0, 0.0, 0) + .observation(8.0, 0.0, 0) + .build(); + + let result = mmopt( + &theta, + &subject, + eq, + additive_error_model(), + 0, + 1, + &Weights::from_vec(vec![0.5, 0.5]), + )?; + + assert_eq!(result.times.len(), 1); + assert!(result.risk.is_finite()); + assert!(result.risk >= 0.0); + + Ok(()) +}