diff --git a/Cargo.toml b/Cargo.toml index b8de530fb..164c98b25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [ ] } faer = "0.23.1" faer-ext = { version = "0.7.1", features = ["nalgebra", "ndarray"] } -pharmsol = "=0.20.0" +pharmsol = "=0.21.0" rand = "0.9.0" anyhow = "1.0.100" rayon = "1.10.0" diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index 556713691..bbd92febe 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -48,7 +48,7 @@ fn main() -> Result<()> { settings.initialize_logs()?; let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv")?; let mut algorithm = dispatch_algorithm(settings, eq, data)?; - let result = algorithm.fit()?; + let mut result = algorithm.fit()?; result.write_outputs()?; Ok(()) diff --git a/examples/drusano/main.rs b/examples/drusano/main.rs index 2475000b0..2262b249a 100644 --- a/examples/drusano/main.rs +++ b/examples/drusano/main.rs @@ -137,7 +137,7 @@ fn main() -> Result<()> { algorithm.initialize().unwrap(); algorithm.fit().unwrap(); // while !algorithm.next_cycle().unwrap() {} - let result = algorithm.into_npresult(); + let mut result = algorithm.into_npresult()?; result.write_outputs().unwrap(); Ok(()) } diff --git a/examples/iov/main.rs b/examples/iov/main.rs index 75ca02255..ffa3f39d2 100644 --- a/examples/iov/main.rs +++ b/examples/iov/main.rs @@ -52,7 +52,7 @@ fn main() -> Result<()> { let data = data::read_pmetrics("examples/iov/test.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap(); algorithm.initialize().unwrap(); - let result = algorithm.fit().unwrap(); + let mut result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); Ok(()) diff --git a/examples/meta/main.rs b/examples/meta/main.rs index b809c9d24..f2784a7c3 100644 --- a/examples/meta/main.rs +++ b/examples/meta/main.rs @@ -64,6 +64,6 @@ fn main() { let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); // let result = algorithm.fit().unwrap(); algorithm.initialize().unwrap(); - let result = algorithm.fit().unwrap(); + let mut result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/examples/neely/main.rs b/examples/neely/main.rs index 8064cab02..47a8894a6 100644 --- a/examples/neely/main.rs +++ b/examples/neely/main.rs @@ -96,6 +96,6 @@ fn main() { settings.initialize_logs().unwrap(); let data = data::read_pmetrics("examples/neely/data.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, ode, data).unwrap(); - let result = algorithm.fit().unwrap(); + let mut result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/examples/new_iov/main.rs b/examples/new_iov/main.rs index d6b309a0f..e23ff50a6 100644 --- a/examples/new_iov/main.rs +++ b/examples/new_iov/main.rs @@ -55,6 +55,6 @@ fn main() { let data = data::read_pmetrics("examples/new_iov/data.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap(); algorithm.initialize().unwrap(); - let result = algorithm.fit().unwrap(); + let mut result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/examples/theophylline/main.rs b/examples/theophylline/main.rs index 9e9e0db6f..67e2b5724 100644 --- a/examples/theophylline/main.rs +++ b/examples/theophylline/main.rs @@ -53,6 +53,6 @@ fn main() { let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); // let result = algorithm.fit().unwrap(); algorithm.initialize().unwrap(); - let result = algorithm.fit().unwrap(); + let mut result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/examples/two_eq_lag/main.rs b/examples/two_eq_lag/main.rs index 74f48aa26..b9226003d 100644 --- a/examples/two_eq_lag/main.rs +++ b/examples/two_eq_lag/main.rs @@ -89,7 +89,7 @@ fn main() { settings.initialize_logs().unwrap(); let data = data::read_pmetrics("examples/two_eq_lag/two_eq_lag.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); - let result = algorithm.fit().unwrap(); + let mut result = algorithm.fit().unwrap(); // algorithm.initialize().unwrap(); // while !algorithm.next_cycle().unwrap() {} // let result = algorithm.into_npresult(); diff --git a/examples/vanco_sde/main.rs b/examples/vanco_sde/main.rs index d1ee4a040..94459d93e 100644 --- a/examples/vanco_sde/main.rs +++ b/examples/vanco_sde/main.rs @@ -78,6 +78,6 @@ fn main() { let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap(); algorithm.initialize().unwrap(); - let result = algorithm.fit().unwrap(); + let mut result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 7d8d48fbd..ab5c16912 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -304,11 +304,11 @@ pub trait Algorithms: Sync + Send + 'static { Status::Stop(_) => break, } } - Ok(self.into_npresult()) + Ok(self.into_npresult()?) } #[allow(clippy::wrong_self_convention)] - fn into_npresult(&self) -> NPResult; + fn into_npresult(&self) -> Result>; } pub fn dispatch_algorithm( diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index c466a0daa..68ed04693 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -77,7 +77,7 @@ impl Algorithms for NPAG { fn equation(&self) -> &E { &self.equation } - fn into_npresult(&self) -> NPResult { + fn into_npresult(&self) -> Result> { NPResult::new( self.equation.clone(), self.data.clone(), @@ -329,17 +329,13 @@ impl Algorithms for NPAG { let (lambda_up, objf_up) = match burke(&psi_up) { Ok((lambda, objf)) => (lambda, objf), Err(err) => { - //todo: write out report - return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err)); + bail!("Error in IPM during optim: {:?}", err); } }; let (lambda_down, objf_down) = match burke(&psi_down) { Ok((lambda, objf)) => (lambda, objf), Err(err) => { - //todo: write out report - //panic!("Error in IPM: {:?}", err); - return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err)); - //(Array1::zeros(1), f64::NEG_INFINITY) + bail!("Error in IPM during optim: {:?}", err); } }; if objf_up > self.objf { diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 9b3d2d74d..ed962d971 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -73,7 +73,7 @@ impl Algorithms for NPOD { data, })) } - fn into_npresult(&self) -> NPResult { + fn into_npresult(&self) -> Result> { NPResult::new( self.equation.clone(), self.data.clone(), @@ -333,17 +333,13 @@ impl Algorithms for NPOD { let (lambda_up, objf_up) = match burke(&psi_up) { Ok((lambda, objf)) => (lambda, objf), Err(err) => { - //todo: write out report - return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err)); + bail!("Error in IPM during optim: {:?}", err); } }; let (lambda_down, objf_down) = match burke(&psi_down) { Ok((lambda, objf)) => (lambda, objf), Err(err) => { - //todo: write out report - //panic!("Error in IPM: {:?}", err); - return Err(anyhow::anyhow!("Error in IPM during optim: {:?}", err)); - //(Array1::zeros(1), f64::NEG_INFINITY) + bail!("Error in IPM during optim: {:?}", err); } }; if objf_up > self.objf { diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs index 3d0325d58..496b36e28 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/postprob.rs @@ -51,7 +51,7 @@ impl Algorithms for POSTPROB { cyclelog: CycleLog::new(), })) } - fn into_npresult(&self) -> NPResult { + fn into_npresult(&self) -> Result> { NPResult::new( self.equation.clone(), self.data.clone(), diff --git a/src/bestdose/predictions.rs b/src/bestdose/predictions.rs index 307ea9fae..740a06690 100644 --- a/src/bestdose/predictions.rs +++ b/src/bestdose/predictions.rs @@ -319,7 +319,7 @@ pub fn calculate_final_predictions( let concentration_preds = NPPredictions::calculate( &problem.eq, &Data::new(vec![target_with_optimal.clone()]), - problem.theta.clone(), + &problem.theta, weights, &posterior, 0.0, diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 2680f550d..2a9e43d69 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -1,6 +1,7 @@ use crate::algorithms::{Status, StopReason}; use crate::prelude::*; use crate::routines::output::cycles::CycleLog; +use crate::routines::output::posterior::Posterior; use crate::routines::output::predictions::NPPredictions; use crate::routines::settings::Settings; use crate::structs::psi::Psi; @@ -36,15 +37,18 @@ pub struct NPResult { objf: f64, cycles: usize, status: Status, - par_names: Vec, settings: Settings, cyclelog: CycleLog, + predictions: Option, + posterior: Posterior, } #[allow(clippy::too_many_arguments)] impl NPResult { /// Create a new NPResult object - pub fn new( + /// + /// This will also calculate the [Posterior] structure and add it to the NPResult + pub(crate) fn new( equation: E, data: Data, theta: Theta, @@ -55,12 +59,12 @@ impl NPResult { status: Status, settings: Settings, cyclelog: CycleLog, - ) -> Self { - // TODO: Add support for fixed and constant parameters - - let par_names = settings.parameters().names(); + ) -> Result { + // Calculate the posterior probabilities + let posterior = posterior(&psi, &w) + .context("Failed to calculate posterior during initialization of NPResult")?; - Self { + let result = Self { equation, data, theta, @@ -69,10 +73,13 @@ impl NPResult { objf, cycles, status, - par_names, settings, cyclelog, - } + predictions: None, + posterior, + }; + + Ok(result) } pub fn cycles(&self) -> usize { @@ -91,6 +98,18 @@ impl NPResult { &self.theta } + pub fn data(&self) -> &Data { + &self.data + } + + pub fn cycle_log(&self) -> &CycleLog { + &self.cyclelog + } + + pub fn settings(&self) -> &Settings { + &self.settings + } + /// Get the [Psi] structure pub fn psi(&self) -> &Psi { &self.psi @@ -101,7 +120,24 @@ impl NPResult { &self.w } - pub fn write_outputs(&self) -> Result<()> { + /// Calculate and store the [NPPredictions] in the [NPResult] + /// + /// This will overwrite any existing predictions stored in the result! + pub fn calculate_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> { + let predictions = NPPredictions::calculate( + &self.equation, + &self.data, + &self.theta, + &self.w, + &self.posterior, + idelta, + tad, + )?; + self.predictions = Some(predictions); + Ok(()) + } + + pub fn write_outputs(&mut self) -> Result<()> { if self.settings.output().write { tracing::debug!("Writing outputs to {:?}", self.settings.output().path); self.settings.write()?; @@ -288,7 +324,7 @@ impl NPResult { .from_writer(&outputfile.file); // Create the headers - let mut theta_header = self.par_names.clone(); + let mut theta_header = self.settings.parameters().names(); theta_header.push("prob".to_string()); writer.write_record(&theta_header)?; @@ -310,11 +346,9 @@ impl NPResult { pub fn write_posterior(&self) -> Result<()> { tracing::debug!("Writing posterior parameter probabilities..."); let theta = &self.theta; - let w = &self.w; - let psi = &self.psi; // Calculate the posterior probabilities - let posterior = posterior(psi, w)?; + let posterior = self.posterior.clone(); // Create the output folder if it doesn't exist let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") { @@ -372,21 +406,15 @@ impl NPResult { } /// Writes the predictions - pub fn write_predictions(&self, idelta: f64, tad: f64) -> Result<()> { + pub fn write_predictions(&mut self, idelta: f64, tad: f64) -> Result<()> { tracing::debug!("Writing predictions..."); - let posterior = posterior(&self.psi, &self.w)?; + self.calculate_predictions(idelta, tad)?; - // Calculate the predictions - let predictions = NPPredictions::calculate( - &self.equation, - &self.data, - self.theta.clone(), - &self.w, - &posterior, - idelta, - tad, - )?; + let predictions = self + .predictions + .as_ref() + .expect("Predictions should have been calculated, but are of type None."); // Write (full) predictions to pred.csv let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?; diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs index d755e0f11..64f78e41b 100644 --- a/src/routines/output/predictions.rs +++ b/src/routines/output/predictions.rs @@ -7,7 +7,11 @@ use crate::{ structs::{theta::Theta, weights::Weights}, }; -// Structure for the output +/// Container for the multiple model estimated predictions +/// +/// Each row contains the predictions for a single time point for a single subject +/// It includes the population and posterior mean and median predictions +/// These are defined by the mean and median of the prediction for each model, weighted by the population or posterior weights #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NPPredictionRow { /// The subject ID @@ -60,6 +64,10 @@ impl NPPredictionRow { pub fn post_median(&self) -> f64 { self.post_median } + + pub fn censoring(&self) -> Censor { + self.cens + } } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -114,7 +122,7 @@ impl NPPredictions { pub fn calculate( equation: &impl pharmsol::prelude::simulator::Equation, data: &Data, - theta: Theta, + theta: &Theta, w: &Weights, posterior: &Posterior, idelta: f64, diff --git a/src/structs/theta.rs b/src/structs/theta.rs index b03a873b6..37d9f65e7 100644 --- a/src/structs/theta.rs +++ b/src/structs/theta.rs @@ -4,7 +4,7 @@ use anyhow::{bail, Result}; use faer::Mat; use serde::{Deserialize, Serialize}; -use crate::prelude::Parameters; +use crate::{prelude::Parameters, structs::weights::Weights}; /// [Theta] is a structure that holds the support points /// These represent the joint population parameter distribution @@ -150,6 +150,36 @@ impl Theta { } } + /// Write the matrix to a CSV file with weights + pub fn write_with_weights(&self, path: &str, weights: &Weights) -> Result<()> { + if self.nspp() != weights.len() { + bail!( + "Number of support points ({}) does not match number of weights ({})", + self.nspp(), + weights.len() + ); + } + + let mut writer = csv::Writer::from_path(path)?; + + let header: Vec = self + .parameters + .names() + .iter() + .cloned() + .chain(std::iter::once("prob".to_string())) + .collect(); + + writer.write_record(header)?; + + for (row_idx, row) in self.matrix.row_iter().enumerate() { + let mut record: Vec = row.iter().map(|x| x.to_string()).collect(); + record.push(weights[row_idx].to_string()); + writer.write_record(record)?; + } + Ok(()) + } + /// Write the theta matrix to a CSV writer /// Each row represents a support point, each column represents a parameter pub fn to_csv(&self, writer: W) -> Result<()> {