diff --git a/examples/bestdose.rs b/examples/bestdose.rs index 35e9deb06..b3be5a096 100644 --- a/examples/bestdose.rs +++ b/examples/bestdose.rs @@ -114,35 +114,22 @@ fn main() -> Result<()> { // Print results for (bias_weight, optimal) in &results { - let opt_doses = optimal - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Bolus(bolus) => Some(bolus.amount()), - Event::Infusion(infusion) => Some(infusion.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let opt_doses = optimal.doses(); println!( "Bias weight: {:.2}\t\t Optimal dose: {:?}\t\tCost: {:.6}\t\tln Cost: {:.4}\t\tMethod: {}", bias_weight, opt_doses, - optimal.objf, - optimal.objf.ln(), - optimal.optimization_method + optimal.objf(), + optimal.objf().ln(), + optimal.optimization_method() ); } // Print concentration-time predictions for the optimal dose let optimal = &results.last().unwrap().1; println!("\nConcentration-time predictions for optimal dose:"); - for pred in optimal.preds.predictions().into_iter() { + for pred in optimal.predictions().predictions().into_iter() { println!( "Time: {:.2} h, Observed: {:.2}, (Pop Mean: {:.4}, Pop Median: {:.4}, Post Mean: {:.4}, Post Median: {:.4})", pred.time(), pred.obs().unwrap_or(0.0), pred.pop_mean(), pred.pop_median(), pred.post_mean(), pred.post_median() diff --git a/examples/bestdose_auc.rs b/examples/bestdose_auc.rs index 0d9a46cfe..17738da1d 100644 --- a/examples/bestdose_auc.rs +++ b/examples/bestdose_auc.rs @@ -83,33 +83,20 @@ fn main() -> Result<()> { println!("Optimizing dose...\n"); let optimal = problem.optimize()?; - let opt_doses = optimal - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Bolus(bolus) => Some(bolus.amount()), - Event::Infusion(infusion) => Some(infusion.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let opt_doses = optimal.doses(); println!("=== RESULTS ==="); println!("Optimal dose: {:.1} mg", opt_doses[0]); - println!("Cost function: {:.6}", optimal.objf); + println!("Cost function: {:.6}", optimal.objf()); - if let Some(auc_preds) = &optimal.auc_predictions { + if let Some(auc_preds) = &optimal.auc_predictions() { println!("\nAUC Predictions:"); let mut total_error = 0.0; for (time, auc) in auc_preds { // Find the target AUC for this time - let target = if (*time - 6.0).abs() < 0.1 { + let target = if (time - 6.0).abs() < 0.1 { 50.0 - } else if (*time - 12.0).abs() < 0.1 { + } else if (time - 12.0).abs() < 0.1 { 80.0 } else { 0.0 @@ -127,7 +114,7 @@ fn main() -> Result<()> { ); } else { println!("\nConcentration Predictions:"); - for pred in optimal.preds.predictions() { + for pred in optimal.predictions().predictions() { println!( " Time: {:5.1}h | Target: {:6.1} | Predicted: {:6.2}", pred.time(), @@ -172,30 +159,13 @@ fn main() -> Result<()> { println!("Optimizing maintenance dose...\n"); let optimal_interval = problem_interval.optimize()?; - let doses: Vec = optimal_interval - .optimal_subject - .iter() - .map(|occ| { - occ.iter() - .filter(|event| match event { - Event::Bolus(_) => true, - Event::Infusion(_) => true, - _ => false, - }) - .map(|event| match event { - Event::Bolus(bolus) => bolus.amount(), - Event::Infusion(infusion) => infusion.amount(), - _ => 0.0, - }) - }) - .flatten() - .collect(); + let doses: Vec = optimal_interval.doses(); println!("=== INTERVAL AUC RESULTS ==="); println!("Optimal maintenance dose (at t=12h): {:.1} mg", doses[0]); - println!("Cost function: {:.6}", optimal_interval.objf); + println!("Cost function: {:.6}", optimal_interval.objf()); - if let Some(auc_preds) = &optimal_interval.auc_predictions { + if let Some(auc_preds) = &optimal_interval.auc_predictions() { println!("\nInterval AUC Predictions:"); for (time, auc) in auc_preds { let target = 60.0; diff --git a/examples/bestdose_bounds.rs b/examples/bestdose_bounds.rs index 4fab2862c..2c3bdd6dd 100644 --- a/examples/bestdose_bounds.rs +++ b/examples/bestdose_bounds.rs @@ -89,7 +89,7 @@ fn main() -> Result<()> { let result = problem.optimize()?; let doses: Vec = result - .optimal_subject + .optimal_subject() .iter() .map(|occ| { occ.iter() @@ -118,7 +118,10 @@ fn main() -> Result<()> { println!( "{:<30} | {:>10.1} mg | {:>10.6}{}", - description, doses[0], result.objf, at_bound + description, + doses[0], + result.objf(), + at_bound ); } diff --git a/src/bestdose/optimization.rs b/src/bestdose/optimization.rs index 4fb6c4cca..bd4056ca2 100644 --- a/src/bestdose/optimization.rs +++ b/src/bestdose/optimization.rs @@ -45,7 +45,7 @@ use argmin::solver::neldermead::NelderMead; use crate::bestdose::cost::calculate_cost; use crate::bestdose::predictions::calculate_final_predictions; -use crate::bestdose::types::{BestDoseProblem, BestDoseResult}; +use crate::bestdose::types::{BestDoseProblem, BestDoseResult, BestDoseStatus, OptimalMethod}; use crate::structs::weights::Weights; use pharmsol::prelude::*; @@ -244,10 +244,15 @@ pub fn dual_optimization(problem: &BestDoseProblem) -> Result { let (final_doses, final_cost, method, final_weights) = if cost1 <= cost2 { tracing::info!(" → Winner: Posterior (lower cost) ✓"); - (doses1, cost1, "posterior", problem.posterior.clone()) + ( + doses1, + cost1, + OptimalMethod::Posterior, + problem.posterior.clone(), + ) } else { tracing::info!(" → Winner: Uniform (lower cost) ✓"); - (doses2, cost2, "uniform", uniform_weights) + (doses2, cost2, OptimalMethod::Uniform, uniform_weights) }; // ═════════════════════════════════════════════════════════════ @@ -290,9 +295,9 @@ pub fn dual_optimization(problem: &BestDoseProblem) -> Result { Ok(BestDoseResult { optimal_subject, objf: final_cost, - status: "Converged".to_string(), + status: BestDoseStatus::Converged, preds, auc_predictions, - optimization_method: method.to_string(), + optimization_method: method, }) } diff --git a/src/bestdose/types.rs b/src/bestdose/types.rs index 61872d3dc..e422cd0be 100644 --- a/src/bestdose/types.rs +++ b/src/bestdose/types.rs @@ -6,12 +6,15 @@ //! - [`Target`]: Enum specifying concentration or AUC targets //! - [`DoseRange`]: Dose constraint specification +use std::fmt::Display; + use crate::prelude::*; use crate::routines::output::predictions::NPPredictions; use crate::routines::settings::Settings; use crate::structs::theta::Theta; use crate::structs::weights::Weights; use pharmsol::prelude::*; +use serde::{Deserialize, Serialize}; /// Target type for dose optimization /// @@ -49,7 +52,7 @@ use pharmsol::prelude::*; /// - Automatically finds the most recent bolus/infusion before each observation /// /// Both methods use trapezoidal rule on a dense time grid controlled by `settings.predictions().idelta`. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub enum Target { /// Target concentrations at observation times /// @@ -133,10 +136,10 @@ pub enum Target { /// ```rust,ignore /// use pmcore::bestdose::DoseRange; /// -/// // Standard range: 0-1000 mg +/// // Large range: 0-1000 mg /// let range = DoseRange::new(0.0, 1000.0); /// -/// // Narrow therapeutic window +/// // Narrow range: 50-150 mg /// let range = DoseRange::new(50.0, 150.0); /// /// // Access bounds @@ -357,31 +360,31 @@ pub struct BestDoseProblem { /// # Ok(()) /// # } /// ``` -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct BestDoseResult { - /// Optimal dose amount(s) + /// Subject with optimal doses /// - /// Vector contains one element per dose in the target subject. - /// Order matches the dose events in the target subject. - pub optimal_subject: Subject, + /// The [Subject] contains the same events as the target subject, + /// but with the dose amounts updated to the optimal values. + pub(crate) optimal_subject: Subject, /// Final cost function value /// /// Lower is better. Represents the weighted combination of variance /// (patient-specific error) and bias (deviation from population). - pub objf: f64, + pub(crate) objf: f64, /// Optimization status message /// /// Examples: "converged", "maximum iterations reached", etc. - pub status: String, + pub(crate) status: BestDoseStatus, /// Concentration-time predictions for optimal doses /// /// Contains predicted concentrations at observation times using the /// optimal doses. Predictions use the weights from the winning optimization /// method (posterior or uniform). - pub preds: NPPredictions, + pub(crate) preds: NPPredictions, /// AUC values at observation times /// @@ -389,7 +392,7 @@ pub struct BestDoseResult { /// Each tuple contains `(time, cumulative_auc)`. /// /// For [`Target::Concentration`], this field is `None`. - pub auc_predictions: Option>, + pub(crate) auc_predictions: Option>, /// Which optimization method produced the best result /// @@ -397,5 +400,84 @@ pub struct BestDoseResult { /// - `"uniform"`: Population-based optimization (uses uniform weights) /// /// The algorithm runs both optimizations and selects the one with lower cost. - pub optimization_method: String, + pub(crate) optimization_method: OptimalMethod, +} + +impl BestDoseResult { + /// Get the optimized subject + pub fn optimal_subject(&self) -> &Subject { + &self.optimal_subject + } + + /// Get the dose amounts of the optimized subject + /// + /// This includes all doses (bolus and infusion) in the order they appear + /// in the optimal subject, and returns their amounts as a vector of f64. + pub fn doses(&self) -> Vec { + self.optimal_subject() + .iter() + .flat_map(|occ| { + occ.events() + .iter() + .filter_map(|event| match event { + Event::Bolus(bolus) => Some(bolus.amount()), + Event::Infusion(infusion) => Some(infusion.amount()), + _ => None, + }) + .collect::>() + }) + .collect::>() + } + + /// Get the objective cost function value + pub fn objf(&self) -> f64 { + self.objf + } + + /// Get the optimization status + pub fn status(&self) -> &BestDoseStatus { + &self.status + } + + /// Get the concentration-time predictions + pub fn predictions(&self) -> &NPPredictions { + &self.preds + } + + /// Get the AUC predictions, if available + pub fn auc_predictions(&self) -> Option> { + self.auc_predictions.clone() + } + + /// Get the optimization method used + pub fn optimization_method(&self) -> OptimalMethod { + self.optimization_method + } +} + +/// Optimization method used in BestDose +/// +/// This returns the type of optimization method that produced the best result: +/// - `Posterior`: Patient-specific optimization using posterior weights +/// - `Uniform`: Population-based optimization using uniform weights +#[derive(Debug, Clone, Serialize, Deserialize, Copy, PartialEq, Eq)] +pub enum OptimalMethod { + Posterior, + Uniform, +} + +impl Display for OptimalMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OptimalMethod::Posterior => write!(f, "Posterior"), + OptimalMethod::Uniform => write!(f, "Uniform"), + } + } +} + +/// Status of the BestDose optimization +#[derive(Debug, Clone, Serialize, Deserialize, Copy, PartialEq, Eq)] +pub enum BestDoseStatus { + Converged, + MaxIterations, } diff --git a/tests/bestdose_tests.rs b/tests/bestdose_tests.rs index 220c88991..e99d74e69 100644 --- a/tests/bestdose_tests.rs +++ b/tests/bestdose_tests.rs @@ -105,36 +105,12 @@ fn test_infusion_mask_inclusion() -> Result<()> { // We should get back 1 optimized dose (the infusion placeholder) assert_eq!( - result - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Infusion(inf) => Some(inf.amount()), - _ => None, - }) - .collect::>() - }) - .count(), + result.doses().len(), 1, "Should have 1 optimized dose (the infusion)" ); - let optinf = result - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Infusion(inf) => Some(inf.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let optinf = result.doses(); // The optimized dose should be reasonable (not NaN, not infinite) assert!( @@ -222,22 +198,15 @@ fn test_fixed_infusion_preservation() -> Result<()> { let result = problem.optimize()?; // Should only optimize the future bolus, not the past infusion - let doses = result - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Infusion(inf) if inf.amount() != 200.0 => Some(inf.amount()), - Event::Bolus(bol) => Some(bol.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let doses = result.doses(); eprintln!("Optimized doses: {:?}", doses); - assert_eq!(doses.len(), 1, "Should have 1 optimized dose"); + assert_eq!( + doses.len(), + 2, + "Should have 2 doses (past infusion + future bolus)" + ); + assert_eq!(doses[0], 200.0, "Past infusion dose should be preserved"); + assert!(doses[1] > 0.0, "Future bolus dose should be optimized"); Ok(()) } @@ -465,25 +434,12 @@ fn test_basic_auc_mode() -> Result<()> { ); let result = result?; - let doses = result - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Infusion(inf) => Some(inf.amount()), - Event::Bolus(bol) => Some(bol.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let doses = result.doses(); assert_eq!(doses.len(), 1); - assert!(result.auc_predictions.is_some()); + assert!(result.auc_predictions().is_some()); - let auc_preds = result.auc_predictions.unwrap(); + let auc_preds = result.auc_predictions().unwrap(); eprintln!("Basic AUC test - AUC predictions: {:?}", auc_preds); assert_eq!(auc_preds.len(), 1); @@ -574,20 +530,7 @@ fn test_infusion_auc_mode() -> Result<()> { ); let result = result?; - let doses = result - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Infusion(inf) => Some(inf.amount()), - Event::Bolus(bol) => Some(bol.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let doses = result.doses(); eprintln!("Optimized dose: {:?}", doses); @@ -596,11 +539,11 @@ fn test_infusion_auc_mode() -> Result<()> { // Should have AUC predictions assert!( - result.auc_predictions.is_some(), + result.auc_predictions().is_some(), "Should have AUC predictions" ); - let auc_preds = result.auc_predictions.unwrap(); + let auc_preds = result.auc_predictions().unwrap(); eprintln!("AUC predictions: {:?}", auc_preds); assert_eq!(auc_preds.len(), 2, "Should have 2 AUC predictions"); @@ -767,27 +710,14 @@ fn test_multi_outeq_auc_optimization() -> Result<()> { let best_dose_result = result?; - let doses = best_dose_result - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Infusion(inf) => Some(inf.amount()), - Event::Bolus(bol) => Some(bol.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let doses = best_dose_result.doses(); assert_eq!(doses.len(), 1); assert!(doses[0] > 0.0); - assert!(best_dose_result.objf.is_finite()); + assert!(best_dose_result.objf().is_finite()); - assert!(best_dose_result.auc_predictions.is_some()); - let auc_preds = best_dose_result.auc_predictions.unwrap(); + assert!(best_dose_result.auc_predictions().is_some()); + let auc_preds = best_dose_result.auc_predictions().unwrap(); assert_eq!( auc_preds.len(), 2, @@ -868,33 +798,16 @@ fn test_auc_from_zero_single_dose() -> Result<()> { let result = problem.optimize()?; - let doses: Vec = result - .optimal_subject - .iter() - .map(|occ| { - occ.iter() - .filter(|event| match event { - Event::Bolus(_) => true, - Event::Infusion(_) => true, - _ => false, - }) - .map(|event| match event { - Event::Bolus(bolus) => bolus.amount(), - Event::Infusion(infusion) => infusion.amount(), - _ => 0.0, - }) - }) - .flatten() - .collect(); + let doses: Vec = result.doses(); // Verify we got a result assert_eq!(doses.len(), 1); assert!(doses[0] > 0.0); - assert!(result.objf.is_finite()); + assert!(result.objf().is_finite()); // Verify we have AUC predictions - assert!(result.auc_predictions.is_some()); - let auc_preds = result.auc_predictions.unwrap(); + assert!(result.auc_predictions().is_some()); + let auc_preds = result.auc_predictions().unwrap(); assert_eq!(auc_preds.len(), 1); let (time, auc) = auc_preds[0]; @@ -977,31 +890,18 @@ fn test_auc_from_last_dose_maintenance() -> Result<()> { )?; let result = problem.optimize()?; - let doses = result - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Infusion(inf) => Some(inf.amount()), - Event::Bolus(bol) => Some(bol.amount()), - _ => None, - }) - .collect::>() - }) - .collect::>(); + let doses = result.doses(); // Verify we got a result assert_eq!(doses.len(), 2, "Should be 2 doses (loading + maintenance)"); // Very first one is fixed loading dose, second is optimized maintenance dose assert_eq!(doses[0], 300.0); assert!(doses[0] > 0.0); - assert!(result.objf.is_finite()); + assert!(result.objf().is_finite()); // Verify we have AUC predictions - assert!(result.auc_predictions.is_some()); - let auc_preds = result.auc_predictions.unwrap(); + assert!(result.auc_predictions().is_some()); + let auc_preds = result.auc_predictions().unwrap(); assert_eq!(auc_preds.len(), 1); let (time, auc) = auc_preds[0]; @@ -1088,21 +988,7 @@ fn test_auc_modes_comparison() -> Result<()> { let result_zero = problem_zero.optimize()?; // Extract only the second dose (the optimized one at t=12) - let dose_zero = result_zero - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Bolus(bol) if bol.time() == 12.0 => Some(bol.amount()), - Event::Infusion(inf) if inf.time() == 12.0 => Some(inf.amount()), - _ => None, - }) - .collect::>() - }) - .next() - .unwrap(); + let dose_zero = result_zero.doses()[1]; // Mode 2: AUCFromLastDose - target is interval AUC from t=12 to t=24 let target_last = Subject::builder("patient_last") @@ -1127,21 +1013,7 @@ fn test_auc_modes_comparison() -> Result<()> { let result_last = problem_last.optimize()?; // Extract only the second dose (the optimized one at t=12) - let dose_last = result_last - .optimal_subject - .iter() - .flat_map(|occ| { - occ.events() - .iter() - .filter_map(|event| match event { - Event::Bolus(bol) if bol.time() == 12.0 => Some(bol.amount()), - Event::Infusion(inf) if inf.time() == 12.0 => Some(inf.amount()), - _ => None, - }) - .collect::>() - }) - .next() - .unwrap(); + let dose_last = result_last.doses()[1]; // The two modes should recommend DIFFERENT doses for the same target value // because they're measuring different things @@ -1153,14 +1025,14 @@ fn test_auc_modes_comparison() -> Result<()> { eprintln!(" Optimal 2nd dose: {:.1} mg", dose_zero); eprintln!( " AUC prediction: {:.2}", - result_zero.auc_predictions.as_ref().unwrap()[0].1 + result_zero.auc_predictions().as_ref().unwrap()[0].1 ); eprintln!(" "); eprintln!(" AUCFromLastDose (interval 12→24h):"); eprintln!(" Optimal 2nd dose: {:.1} mg", dose_last); eprintln!( " AUC prediction: {:.2}", - result_last.auc_predictions.as_ref().unwrap()[0].1 + result_last.auc_predictions().as_ref().unwrap()[0].1 ); // Verify both modes work @@ -1248,24 +1120,7 @@ fn test_auc_from_last_dose_multiple_observations() -> Result<()> { )?; let result = problem.optimize()?; - let doses: Vec = result - .optimal_subject - .iter() - .map(|occ| { - occ.iter() - .filter(|event| match event { - Event::Bolus(_) => true, - Event::Infusion(_) => true, - _ => false, - }) - .map(|event| match event { - Event::Bolus(bolus) => bolus.amount(), - Event::Infusion(infusion) => infusion.amount(), - _ => 0.0, - }) - }) - .flatten() - .collect(); + let doses: Vec = result.doses(); // Should optimize 2 doses assert_eq!(doses.len(), 2); @@ -1273,8 +1128,8 @@ fn test_auc_from_last_dose_multiple_observations() -> Result<()> { assert!(doses[1] > 0.0); // Should have 2 AUC predictions - assert!(result.auc_predictions.is_some()); - let auc_preds = result.auc_predictions.unwrap(); + assert!(result.auc_predictions().is_some()); + let auc_preds = result.auc_predictions().unwrap(); assert_eq!(auc_preds.len(), 2); // First observation measures AUC from t=0 (first dose) to t=12 @@ -1364,30 +1219,13 @@ fn test_auc_from_last_dose_no_prior_dose() -> Result<()> { )?; let result = problem.optimize()?; - let doses: Vec = result - .optimal_subject - .iter() - .map(|occ| { - occ.iter() - .filter(|event| match event { - Event::Bolus(_) => true, - Event::Infusion(_) => true, - _ => false, - }) - .map(|event| match event { - Event::Bolus(bolus) => bolus.amount(), - Event::Infusion(infusion) => infusion.amount(), - _ => 0.0, - }) - }) - .flatten() - .collect(); + let doses: Vec = result.doses(); assert_eq!(doses.len(), 1); assert!(doses[0] > 0.0); - assert!(result.auc_predictions.is_some()); - let auc_preds = result.auc_predictions.unwrap(); + assert!(result.auc_predictions().is_some()); + let auc_preds = result.auc_predictions().unwrap(); assert_eq!(auc_preds.len(), 1); let (_time, auc) = auc_preds[0]; @@ -1479,24 +1317,7 @@ fn test_dose_range_bounds_respected() -> Result<()> { )?; let result = problem.optimize()?; - let doses: Vec = result - .optimal_subject - .iter() - .map(|occ| { - occ.iter() - .filter(|event| match event { - Event::Bolus(_) => true, - Event::Infusion(_) => true, - _ => false, - }) - .map(|event| match event { - Event::Bolus(bolus) => bolus.amount(), - Event::Infusion(infusion) => infusion.amount(), - _ => 0.0, - }) - }) - .flatten() - .collect(); + let doses: Vec = result.doses(); println!("Optimal dose: {:.1} mg", doses[0]); println!("Dose range: 50-200 mg");