diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index ebcb988f6..2cf7f37cd 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -5,9 +5,10 @@ pub use crate::routines::evaluation::ipm::burke; pub use crate::routines::evaluation::qr; use crate::routines::settings::Settings; -use crate::routines::output::{CycleLog, NPCycle, NPResult}; +use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; use crate::structs::psi::{calculate_psi, Psi}; use crate::structs::theta::Theta; +use crate::structs::weights::Weights; use anyhow::bail; use anyhow::Result; @@ -18,8 +19,6 @@ use pharmsol::prelude::{ use pharmsol::prelude::ErrorModel; -use faer::Col; - use crate::routines::initialization; use crate::routines::expansion::adaptative_grid::adaptative_grid; @@ -35,8 +34,8 @@ pub struct NPAG { ranges: Vec<(f64, f64)>, psi: Psi, theta: Theta, - lambda: Col, - w: Col, + lambda: Weights, + w: Weights, eps: f64, last_objf: f64, objf: f64, @@ -59,8 +58,8 @@ impl Algorithms for NPAG { ranges: settings.parameters().ranges(), psi: Psi::new(), theta: Theta::new(), - lambda: Col::zeros(0), - w: Col::zeros(0), + lambda: Weights::default(), + w: Weights::default(), eps: 0.2, last_objf: -1e30, objf: f64::NEG_INFINITY, @@ -138,7 +137,7 @@ impl Algorithms for NPAG { if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E { self.eps /= 2.; if self.eps <= THETA_E { - let pyl = psi * w; + let pyl = psi * w.weights(); self.f1 = pyl.iter().map(|x| x.ln()).sum(); if (self.f1 - self.f0).abs() <= THETA_F { tracing::info!("The model converged after {} cycles", self.cycle,); @@ -165,15 +164,15 @@ impl Algorithms for NPAG { } // Create state object - let state = NPCycle { - cycle: self.cycle, - objf: -2. * self.objf, - delta_objf: (self.last_objf - self.objf).abs(), - nspp: self.theta.nspp(), - theta: self.theta.clone(), - error_models: self.error_models.clone(), - status: self.status.clone(), - }; + let state = NPCycle::new( + self.cycle, + -2. * self.objf, + self.error_models.clone(), + self.theta.clone(), + self.theta.nspp(), + (self.last_objf - self.objf).abs(), + self.status.clone(), + ); // Write cycle log self.cycle_log.push(state); @@ -199,7 +198,7 @@ impl Algorithms for NPAG { } (self.lambda, _) = match burke(&self.psi) { - Ok((lambda, objf)) => (lambda, objf), + Ok((lambda, objf)) => (lambda.into(), objf), Err(err) => { bail!("Error in IPM during evaluation: {:?}", err); } @@ -213,11 +212,11 @@ impl Algorithms for NPAG { let max_lambda = self .lambda .iter() - .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc)); + .fold(f64::NEG_INFINITY, |acc, x| x.max(acc)); let mut keep = Vec::::new(); for (index, lam) in self.lambda.iter().enumerate() { - if *lam > max_lambda / 1000_f64 { + if lam > max_lambda / 1000_f64 { keep.push(index); } } @@ -262,7 +261,7 @@ impl Algorithms for NPAG { self.validate_psi()?; (self.lambda, self.objf) = match burke(&self.psi) { - Ok((lambda, objf)) => (lambda, objf), + Ok((lambda, objf)) => (lambda.into(), objf), Err(err) => { return Err(anyhow::anyhow!( "Error in IPM during condensation: {:?}", @@ -270,7 +269,7 @@ impl Algorithms for NPAG { )); } }; - self.w = self.lambda.clone(); + self.w = self.lambda.clone().into(); Ok(()) } diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index f00886ae4..d31b02e46 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -1,10 +1,11 @@ +use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; +use crate::structs::weights::Weights; use crate::{ algorithms::Status, prelude::{ algorithms::Algorithms, routines::{ evaluation::{ipm::burke, qr}, - output::{CycleLog, NPCycle, NPResult}, settings::Settings, }, }, @@ -13,9 +14,9 @@ use crate::{ theta::Theta, }, }; + use anyhow::bail; use anyhow::Result; -use faer::Col; use faer_ext::IntoNdarray; use pharmsol::{prelude::ErrorModel, ErrorModels}; use pharmsol::{ @@ -37,8 +38,8 @@ pub struct NPOD { equation: E, psi: Psi, theta: Theta, - lambda: Col, - w: Col, + lambda: Weights, + w: Weights, last_objf: f64, objf: f64, cycle: usize, @@ -57,8 +58,8 @@ impl Algorithms for NPOD { equation, psi: Psi::new(), theta: Theta::new(), - lambda: Col::zeros(0), - w: Col::zeros(0), + lambda: Weights::default(), + w: Weights::default(), last_objf: -1e30, objf: f64::NEG_INFINITY, cycle: 0, @@ -157,15 +158,15 @@ impl Algorithms for NPOD { } // Create state object - let state = NPCycle { - cycle: self.cycle, - objf: -2. * self.objf, - delta_objf: (self.last_objf - self.objf).abs(), - nspp: self.theta.nspp(), - theta: self.theta.clone(), - error_models: self.error_models.clone(), - status: self.status.clone(), - }; + let state = NPCycle::new( + self.cycle, + -2. * self.objf, + self.error_models.clone(), + self.theta.clone(), + self.theta.nspp(), + (self.last_objf - self.objf).abs(), + self.status.clone(), + ); // Write cycle log self.cycle_log.push(state); @@ -205,11 +206,11 @@ impl Algorithms for NPOD { let max_lambda = self .lambda .iter() - .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc)); + .fold(f64::NEG_INFINITY, |acc, x| x.max(acc)); let mut keep = Vec::::new(); for (index, lam) in self.lambda.iter().enumerate() { - if *lam > max_lambda / 1000_f64 { + if lam > max_lambda / 1000_f64 { keep.push(index); } } @@ -367,7 +368,7 @@ impl Algorithms for NPOD { fn expansion(&mut self) -> Result<()> { // If no stop signal, add new point to theta based on the optimization of the D function let psi = self.psi().matrix().as_ref().into_ndarray().to_owned(); - let w: Array1 = self.w.clone().iter().cloned().collect(); + let w: Array1 = self.w.clone().iter().collect(); let pyl = psi.dot(&w); // Add new point to theta based on the optimization of the D function diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs index e0c19049d..0fcb7a1b0 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/postprob.rs @@ -4,10 +4,11 @@ use crate::{ structs::{ psi::{calculate_psi, Psi}, theta::Theta, + weights::Weights, }, }; use anyhow::{Context, Result}; -use faer::Col; + use pharmsol::prelude::{ data::{Data, ErrorModels}, simulator::Equation, @@ -15,7 +16,7 @@ use pharmsol::prelude::{ use crate::routines::evaluation::ipm::burke; use crate::routines::initialization; -use crate::routines::output::{CycleLog, NPResult}; +use crate::routines::output::{cycles::CycleLog, NPResult}; use crate::routines::settings::Settings; /// Posterior probability algorithm @@ -24,7 +25,7 @@ pub struct POSTPROB { equation: E, psi: Psi, theta: Theta, - w: Col, + w: Weights, objf: f64, cycle: usize, status: Status, @@ -40,7 +41,7 @@ impl Algorithms for POSTPROB { equation, psi: Psi::new(), theta: Theta::new(), - w: Col::zeros(0), + w: Weights::default(), objf: f64::INFINITY, cycle: 0, status: Status::Starting, diff --git a/src/routines/evaluation/ipm.rs b/src/routines/evaluation/ipm.rs index d56910a49..cd9b67753 100644 --- a/src/routines/evaluation/ipm.rs +++ b/src/routines/evaluation/ipm.rs @@ -1,4 +1,5 @@ use crate::structs::psi::Psi; +use crate::structs::weights::Weights; use anyhow::bail; use faer::linalg::triangular_solve::solve_lower_triangular_in_place; use faer::linalg::triangular_solve::solve_upper_triangular_in_place; @@ -21,15 +22,15 @@ use rayon::prelude::*; /// /// # Returns /// -/// On success, returns a tuple `(lam, obj)` where: -/// - `lam` is a faer::Col containing the computed probability vector, +/// On success, returns a tuple `(weights, obj)` where: +/// - [Weights] contains the optimized weights (probabilities) for each support point. /// - `obj` is the value of the objective function at the solution. /// /// # Errors /// /// This function returns an error if any step in the optimization (e.g. Cholesky factorization) /// fails. -pub fn burke(psi: &Psi) -> anyhow::Result<(Col, f64)> { +pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> { let mut psi = psi.matrix().to_owned(); // Ensure all entries are finite and make them non-negative. @@ -274,7 +275,7 @@ pub fn burke(psi: &Psi) -> anyhow::Result<(Col, f64)> { let lam_sum: f64 = lam.iter().sum(); lam = &lam / lam_sum; - Ok((lam, obj)) + Ok((lam.into(), obj)) } #[cfg(test)] @@ -465,7 +466,11 @@ mod tests { // distribution depends on the optimization algorithm's convergence // Just verify that no single weight dominates excessively (basic sanity check) - let max_weight = lam.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + let max_weight = lam + .weights() + .iter() + .cloned() + .fold(f64::NEG_INFINITY, f64::max); assert!( max_weight < 0.1, "No single weight should dominate in uniform matrix (max weight: {})", diff --git a/src/routines/logger.rs b/src/routines/logger.rs index 91214fd69..2b5d411cc 100644 --- a/src/routines/logger.rs +++ b/src/routines/logger.rs @@ -48,7 +48,7 @@ pub(crate) fn setup_log(settings: &mut Settings) -> Result<()> { let file_layer = match settings.log().write { true => { let layer = fmt::layer() - .with_writer(outputfile.file) + .with_writer(outputfile.file_owned()) .with_ansi(false) .with_timer(timestamper.clone()); diff --git a/src/routines/output/cycles.rs b/src/routines/output/cycles.rs new file mode 100644 index 000000000..536ea892a --- /dev/null +++ b/src/routines/output/cycles.rs @@ -0,0 +1,203 @@ +use anyhow::Result; +use csv::WriterBuilder; +use pharmsol::{ErrorModel, ErrorModels}; + +use crate::{ + algorithms::Status, + prelude::Settings, + routines::output::{median, OutputFile}, + structs::theta::Theta, +}; + +/// An [NPCycle] object contains the summary of a cycle +/// It holds the following information: +/// - `cycle`: The cycle number +/// - `objf`: The objective function value +/// - `gamlam`: The assay noise parameter, either gamma or lambda +/// - `theta`: The support points and their associated probabilities +/// - `nspp`: The number of support points +/// - `delta_objf`: The change in objective function value from last cycle +/// - `converged`: Whether the algorithm has reached convergence +#[derive(Debug, Clone)] +pub struct NPCycle { + cycle: usize, + objf: f64, + error_models: ErrorModels, + theta: Theta, + nspp: usize, + delta_objf: f64, + status: Status, +} + +impl NPCycle { + pub fn new( + cycle: usize, + objf: f64, + error_models: ErrorModels, + theta: Theta, + nspp: usize, + delta_objf: f64, + status: Status, + ) -> Self { + Self { + cycle, + objf, + error_models, + theta, + nspp, + delta_objf, + status, + } + } + + pub fn cycle(&self) -> usize { + self.cycle + } + pub fn objf(&self) -> f64 { + self.objf + } + pub fn error_models(&self) -> &ErrorModels { + &self.error_models + } + pub fn theta(&self) -> &Theta { + &self.theta + } + pub fn nspp(&self) -> usize { + self.nspp + } + pub fn delta_objf(&self) -> f64 { + self.delta_objf + } + pub fn status(&self) -> &Status { + &self.status + } + + pub fn placeholder() -> Self { + Self { + cycle: 0, + objf: 0.0, + error_models: ErrorModels::default(), + theta: Theta::new(), + nspp: 0, + delta_objf: 0.0, + status: Status::Starting, + } + } +} + +/// This holdes a vector of [NPCycle] objects to provide a more detailed log +#[derive(Debug, Clone)] +pub struct CycleLog { + cycles: Vec, +} + +impl CycleLog { + pub fn new() -> Self { + Self { cycles: Vec::new() } + } + + pub fn cycles(&self) -> &[NPCycle] { + &self.cycles + } + + pub fn push(&mut self, cycle: NPCycle) { + self.cycles.push(cycle); + } + + pub fn write(&self, settings: &Settings) -> Result<()> { + tracing::debug!("Writing cycles..."); + let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(false) + .from_writer(&outputfile.file); + + // Write headers + writer.write_field("cycle")?; + writer.write_field("converged")?; + writer.write_field("status")?; + writer.write_field("neg2ll")?; + writer.write_field("nspp")?; + if let Some(first_cycle) = self.cycles.first() { + first_cycle.error_models.iter().try_for_each( + |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> { + match errmod { + ErrorModel::Additive { .. } => { + writer.write_field(format!("gamlam.{}", outeq))?; + } + ErrorModel::Proportional { .. } => { + writer.write_field(format!("gamlam.{}", outeq))?; + } + ErrorModel::None => {} + } + Ok(()) + }, + )?; + } + + let parameter_names = settings.parameters().names(); + for param_name in ¶meter_names { + writer.write_field(format!("{}.mean", param_name))?; + writer.write_field(format!("{}.median", param_name))?; + writer.write_field(format!("{}.sd", param_name))?; + } + + writer.write_record(None::<&[u8]>)?; + + for cycle in &self.cycles { + writer.write_field(format!("{}", cycle.cycle))?; + writer.write_field(format!("{}", cycle.status == Status::Converged))?; + writer.write_field(format!("{}", cycle.status))?; + writer.write_field(format!("{}", cycle.objf))?; + writer + .write_field(format!("{}", cycle.theta.nspp())) + .unwrap(); + + // Write the error models + cycle.error_models.iter().try_for_each( + |(_, errmod): (usize, &ErrorModel)| -> Result<()> { + match errmod { + ErrorModel::Additive { + lambda: _, + poly: _, + lloq: _, + } => { + writer.write_field(format!("{:.5}", errmod.factor()?))?; + } + ErrorModel::Proportional { + gamma: _, + poly: _, + lloq: _, + } => { + writer.write_field(format!("{:.5}", errmod.factor()?))?; + } + ErrorModel::None => {} + } + Ok(()) + }, + )?; + + for param in cycle.theta.matrix().col_iter() { + let param_values: Vec = param.iter().cloned().collect(); + + let mean: f64 = param_values.iter().sum::() / param_values.len() as f64; + let median = median(¶m_values); + let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::() + / (param_values.len() as f64 - 1.0); + + writer.write_field(format!("{}", mean))?; + writer.write_field(format!("{}", median))?; + writer.write_field(format!("{}", std))?; + } + writer.write_record(None::<&[u8]>)?; + } + writer.flush()?; + tracing::debug!("Cycles written to {:?}", &outputfile.relative_path()); + Ok(()) + } +} + +impl Default for CycleLog { + fn default() -> Self { + Self::new() + } +} diff --git a/src/routines/output.rs b/src/routines/output/mod.rs similarity index 59% rename from src/routines/output.rs rename to src/routines/output/mod.rs index 10759383b..d42b63a84 100644 --- a/src/routines/output.rs +++ b/src/routines/output/mod.rs @@ -1,20 +1,28 @@ use crate::algorithms::Status; use crate::prelude::*; +use crate::routines::output::cycles::CycleLog; +use crate::routines::output::predictions::NPPredictions; use crate::routines::settings::Settings; use crate::structs::psi::Psi; use crate::structs::theta::Theta; +use crate::structs::weights::Weights; use anyhow::{bail, Context, Result}; use csv::WriterBuilder; use faer::linalg::zip::IntoView; -use faer::{Col, Mat}; use faer_ext::IntoNdarray; use ndarray::{Array, Array1, Array2, Axis}; use pharmsol::prelude::data::*; -use pharmsol::prelude::simulator::{Equation, Prediction}; +use pharmsol::prelude::simulator::Equation; use serde::Serialize; use std::fs::{create_dir_all, File, OpenOptions}; use std::path::{Path, PathBuf}; +pub mod cycles; +pub mod posterior; +pub mod predictions; + +use posterior::posterior; + /// Defines the result objects from an NPAG run /// An [NPResult] contains the necessary information to generate predictions and summary statistics #[derive(Debug)] @@ -23,7 +31,7 @@ pub struct NPResult { data: Data, theta: Theta, psi: Psi, - w: Col, + w: Weights, objf: f64, cycles: usize, status: Status, @@ -40,7 +48,7 @@ impl NPResult { data: Data, theta: Theta, psi: Psi, - w: Col, + w: Weights, objf: f64, cycles: usize, status: Status, @@ -88,7 +96,7 @@ impl NPResult { } /// Get the weights (probabilities) of the support points - pub fn w(&self) -> &Col { + pub fn weights(&self) -> &Weights { &self.w } @@ -99,13 +107,10 @@ impl NPResult { let idelta: f64 = self.settings.predictions().idelta; let tad = self.settings.predictions().tad; self.cyclelog.write(&self.settings)?; - self.write_obs().context("Failed to write observations")?; self.write_theta().context("Failed to write theta")?; - self.write_obspred() - .context("Failed to write observed-predicted file")?; - self.write_pred(idelta, tad) - .context("Failed to write predictions")?; self.write_covs().context("Failed to write covariates")?; + self.write_predictions(idelta, tad) + .context("Failed to write predictions")?; self.write_posterior() .context("Failed to write posterior")?; } @@ -136,7 +141,14 @@ impl NPResult { .as_mut() .into_ndarray() .to_owned(); - let w: Array1 = self.w.clone().into_view().iter().cloned().collect(); + let w: Array1 = self + .w + .weights() + .clone() + .into_view() + .iter() + .cloned() + .collect(); let psi: Array2 = self.psi.matrix().as_ref().into_ndarray().to_owned(); let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w) @@ -239,7 +251,7 @@ impl NPResult { writer.flush()?; tracing::debug!( "Observations with predictions written to {:?}", - &outputfile.get_relative_path() + &outputfile.relative_path() ); Ok(()) } @@ -250,7 +262,14 @@ impl NPResult { tracing::debug!("Writing population parameter distribution..."); let theta = &self.theta; - let w: Vec = self.w.clone().into_view().iter().cloned().collect(); + let w: Vec = self + .w + .weights() + .clone() + .into_view() + .iter() + .cloned() + .collect(); if w.len() != theta.matrix().nrows() { bail!( @@ -281,7 +300,7 @@ impl NPResult { writer.flush()?; tracing::debug!( "Population parameter distribution written to {:?}", - &outputfile.get_relative_path() + &outputfile.relative_path() ); Ok(()) } @@ -321,227 +340,89 @@ impl NPResult { // Write contents let subjects = self.data.subjects(); - posterior.row_iter().enumerate().for_each(|(i, row)| { - let subject = subjects.get(i).unwrap(); - let id = subject.id(); + posterior + .matrix() + .row_iter() + .enumerate() + .for_each(|(i, row)| { + let subject = subjects.get(i).unwrap(); + let id = subject.id(); - row.iter().enumerate().for_each(|(spp, prob)| { - writer.write_field(id.clone()).unwrap(); - writer.write_field(spp.to_string()).unwrap(); + row.iter().enumerate().for_each(|(spp, prob)| { + writer.write_field(id.clone()).unwrap(); + writer.write_field(spp.to_string()).unwrap(); - theta.matrix().row(spp).iter().for_each(|val| { - writer.write_field(val.to_string()).unwrap(); - }); + theta.matrix().row(spp).iter().for_each(|val| { + writer.write_field(val.to_string()).unwrap(); + }); - writer.write_field(prob.to_string()).unwrap(); - writer.write_record(None::<&[u8]>).unwrap(); + writer.write_field(prob.to_string()).unwrap(); + writer.write_record(None::<&[u8]>).unwrap(); + }); }); - }); writer.flush()?; tracing::debug!( "Posterior parameters written to {:?}", - &outputfile.get_relative_path() + &outputfile.relative_path() ); Ok(()) } - /// Write the observations, which is the reformatted input data - pub fn write_obs(&self) -> Result<()> { - tracing::debug!("Writing observations..."); - let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?; + /// Writes the predictions + pub fn write_predictions(&self, idelta: f64, tad: f64) -> Result<()> { + tracing::debug!("Writing predictions..."); + let posterior = posterior(&self.psi, &self.w)?; + + // Calculate the predictions + let predictions = NPPredictions::calculate( + &self.equation, + &self.data, + self.theta.clone(), + &self.w, + &posterior, + idelta, + tad, + )?; + + // Write (full) predictions to pred.csv + let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?; let mut writer = WriterBuilder::new() .has_headers(true) - .from_writer(&outputfile.file); + .from_writer(&outputfile_pred.file); - #[derive(Serialize)] - struct Row { - id: String, - block: usize, - time: f64, - out: Option, - outeq: usize, + // Write each prediction row + for row in predictions.predictions() { + writer.serialize(row)?; } - for subject in self.data.subjects() { - for occasion in subject.occasions() { - for event in occasion.iter() { - if let Event::Observation(event) = event { - let row = Row { - id: subject.id().clone(), - block: occasion.index(), - time: event.time(), - out: event.value(), - outeq: event.outeq(), - }; - writer.serialize(row)?; - } - } - } - } writer.flush()?; - tracing::debug!( - "Observations written to {:?}", - &outputfile.get_relative_path() + "Predictions written to {:?}", + &outputfile_pred.relative_path() ); - Ok(()) - } - - /// Writes the predictions - pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> { - tracing::debug!("Writing predictions..."); - - // Get necessary data - let theta = self.theta.matrix(); - let w: Vec = self.w.iter().cloned().collect(); - let posterior = posterior(&self.psi, &self.w)?; - - let data = self.data.clone().expand(idelta, tad); - - let subjects = data.subjects(); - - if subjects.len() != posterior.nrows() { - bail!("Number of subjects and number of posterior means do not match"); - }; - // Create the output file and writer for pred.csv - let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?; + // Write observations and predictions to op.csv + let outputfile_op = OutputFile::new(&self.settings.output().path, "op.csv")?; let mut writer = WriterBuilder::new() .has_headers(true) - .from_writer(&outputfile.file); - - // Iterate over each subject and then each support point - for subject in subjects.iter().enumerate() { - let (subject_index, subject) = subject; - - // Get a vector of occasions for this subject, for each predictions - let occasions = subject - .occasions() - .iter() - .flat_map(|o| { - o.events() - .iter() - .filter_map(|e| { - if let Event::Observation(_obs) = e { - Some(o.index()) - } else { - None - } - }) - .collect::>() - }) - .collect::>(); - - // Container for predictions for this subject - // This will hold predictions for each support point - // The outer vector is for each support point - // The inner vector is for the vector of predictions for that support point - let mut predictions: Vec> = Vec::new(); - - // And each support points - for spp in theta.row_iter() { - // Simulate the subject with the current support point - let spp_values = spp.iter().cloned().collect::>(); - let pred = self - .equation - .simulate_subject(subject, &spp_values, None)? - .0 - .get_predictions(); - predictions.push(pred); - } - - if predictions.is_empty() { - continue; // Skip this subject if no predictions are available - } - - // Calculate population mean using - let mut pop_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; - for outer_pred in predictions.iter().enumerate() { - let (i, outer_pred) = outer_pred; - for inner_pred in outer_pred.iter().enumerate() { - let (j, pred) = inner_pred; - pop_mean[j] += pred.prediction() * w[i]; - } - } - - // Calculate population median - let mut pop_median: Vec = Vec::new(); - for j in 0..predictions.first().unwrap().len() { - let mut values: Vec = Vec::new(); - let mut weights: Vec = Vec::new(); - - for (i, outer_pred) in predictions.iter().enumerate() { - values.push(outer_pred[j].prediction()); - weights.push(w[i]); - } - - let median_val = weighted_median(&values, &weights); - pop_median.push(median_val); - } - - // Calculate posterior mean - let mut posterior_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; - for outer_pred in predictions.iter().enumerate() { - let (i, outer_pred) = outer_pred; - for inner_pred in outer_pred.iter().enumerate() { - let (j, pred) = inner_pred; - posterior_mean[j] += pred.prediction() * posterior[(subject_index, i)]; - } - } - - // Calculate posterior median - let mut posterior_median: Vec = Vec::new(); - for j in 0..predictions.first().unwrap().len() { - let mut values: Vec = Vec::new(); - let mut weights: Vec = Vec::new(); - - for (i, outer_pred) in predictions.iter().enumerate() { - values.push(outer_pred[j].prediction()); - weights.push(posterior[(subject_index, i)]); - } - - let median_val = weighted_median(&values, &weights); - posterior_median.push(median_val); - } - - // Structure for the output - #[derive(Debug, Clone, Serialize)] - struct Row { - id: String, - time: f64, - outeq: usize, - block: usize, - pop_mean: f64, - pop_median: f64, - post_mean: f64, - post_median: f64, - } - - for pred in predictions.iter().enumerate() { - let (_, preds) = pred; - for (j, p) in preds.iter().enumerate() { - let row = Row { - id: subject.id().clone(), - time: p.time(), - outeq: p.outeq(), - block: occasions[j], - pop_mean: pop_mean[j], - pop_median: pop_median[j], - post_mean: posterior_mean[j], - post_median: posterior_median[j], - }; - writer.serialize(row)?; - } - } + .from_writer(&outputfile_op.file); + + // Write each prediction row + for row in predictions + .predictions() + .iter() + .filter(|r| r.obs().is_some()) + { + writer.serialize(row)?; } writer.flush()?; tracing::debug!( - "Predictions written to {:?}", - &outputfile.get_relative_path() + "Observed-predicted values written to {:?}", + &outputfile_op.relative_path() ); Ok(()) @@ -611,205 +492,13 @@ impl NPResult { } writer.flush()?; - tracing::debug!( - "Covariates written to {:?}", - &outputfile.get_relative_path() - ); - Ok(()) - } -} - -/// An [NPCycle] object contains the summary of a cycle -/// It holds the following information: -/// - `cycle`: The cycle number -/// - `objf`: The objective function value -/// - `gamlam`: The assay noise parameter, either gamma or lambda -/// - `theta`: The support points and their associated probabilities -/// - `nspp`: The number of support points -/// - `delta_objf`: The change in objective function value from last cycle -/// - `converged`: Whether the algorithm has reached convergence -#[derive(Debug, Clone)] -pub struct NPCycle { - pub cycle: usize, - pub objf: f64, - pub error_models: ErrorModels, - pub theta: Theta, - pub nspp: usize, - pub delta_objf: f64, - pub status: Status, -} - -impl NPCycle { - pub fn new( - cycle: usize, - objf: f64, - error_models: ErrorModels, - theta: Theta, - nspp: usize, - delta_objf: f64, - status: Status, - ) -> Self { - Self { - cycle, - objf, - error_models, - theta, - nspp, - delta_objf, - status, - } - } - - pub fn placeholder() -> Self { - Self { - cycle: 0, - objf: 0.0, - error_models: ErrorModels::default(), - theta: Theta::new(), - nspp: 0, - delta_objf: 0.0, - status: Status::Starting, - } - } -} - -/// This holdes a vector of [NPCycle] objects to provide a more detailed log -#[derive(Debug, Clone)] -pub struct CycleLog { - pub cycles: Vec, -} - -impl CycleLog { - pub fn new() -> Self { - Self { cycles: Vec::new() } - } - - pub fn push(&mut self, cycle: NPCycle) { - self.cycles.push(cycle); - } - - pub fn write(&self, settings: &Settings) -> Result<()> { - tracing::debug!("Writing cycles..."); - let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?; - let mut writer = WriterBuilder::new() - .has_headers(false) - .from_writer(&outputfile.file); - - // Write headers - writer.write_field("cycle")?; - writer.write_field("converged")?; - writer.write_field("status")?; - writer.write_field("neg2ll")?; - writer.write_field("nspp")?; - if let Some(first_cycle) = self.cycles.first() { - first_cycle.error_models.iter().try_for_each( - |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> { - match errmod { - ErrorModel::Additive { .. } => { - writer.write_field(format!("gamlam.{}", outeq))?; - } - ErrorModel::Proportional { .. } => { - writer.write_field(format!("gamlam.{}", outeq))?; - } - ErrorModel::None => {} - } - Ok(()) - }, - )?; - } - - let parameter_names = settings.parameters().names(); - for param_name in ¶meter_names { - writer.write_field(format!("{}.mean", param_name))?; - writer.write_field(format!("{}.median", param_name))?; - writer.write_field(format!("{}.sd", param_name))?; - } - - writer.write_record(None::<&[u8]>)?; - - for cycle in &self.cycles { - writer.write_field(format!("{}", cycle.cycle))?; - writer.write_field(format!("{}", cycle.status == Status::Converged))?; - writer.write_field(format!("{}", cycle.status))?; - writer.write_field(format!("{}", cycle.objf))?; - writer - .write_field(format!("{}", cycle.theta.nspp())) - .unwrap(); - - // Write the error models - cycle.error_models.iter().try_for_each( - |(_, errmod): (usize, &ErrorModel)| -> Result<()> { - match errmod { - ErrorModel::Additive { - lambda: _, - poly: _, - lloq: _, - } => { - writer.write_field(format!("{:.5}", errmod.factor()?))?; - } - ErrorModel::Proportional { - gamma: _, - poly: _, - lloq: _, - } => { - writer.write_field(format!("{:.5}", errmod.factor()?))?; - } - ErrorModel::None => {} - } - Ok(()) - }, - )?; - - for param in cycle.theta.matrix().col_iter() { - let param_values: Vec = param.iter().cloned().collect(); - - let mean: f64 = param_values.iter().sum::() / param_values.len() as f64; - let median = median(param_values.clone()); - let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::() - / (param_values.len() as f64 - 1.0); - - writer.write_field(format!("{}", mean))?; - writer.write_field(format!("{}", median))?; - writer.write_field(format!("{}", std))?; - } - writer.write_record(None::<&[u8]>)?; - } - writer.flush()?; - tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path()); + tracing::debug!("Covariates written to {:?}", &outputfile.relative_path()); Ok(()) } } -impl Default for CycleLog { - fn default() -> Self { - Self::new() - } -} - -/// Calculates the posterior probabilities for each support point given the weights -/// -/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns. -pub fn posterior(psi: &Psi, w: &Col) -> Result> { - if psi.matrix().ncols() != w.nrows() { - bail!( - "Number of rows in psi ({}) and number of weights ({}) do not match.", - psi.matrix().nrows(), - w.nrows() - ); - } - - let psi_matrix = psi.matrix(); - let py = psi_matrix * w; - - let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| { - psi_matrix.get(i, j) * w.get(j) / py.get(i) - }); - - Ok(posterior) -} - -pub fn median(data: Vec) -> f64 { - let mut data = data.clone(); +pub(crate) fn median(data: &[f64]) -> f64 { + let mut data: Vec = data.to_vec(); data.sort_by(|a, b| a.partial_cmp(b).unwrap()); let size = data.len(); @@ -823,7 +512,7 @@ pub fn median(data: Vec) -> f64 { } } -fn weighted_median(data: &Vec, weights: &Vec) -> f64 { +fn weighted_median(data: &[f64], weights: &[f64]) -> f64 { // Ensure the data and weights arrays have the same length assert_eq!( data.len(), @@ -978,8 +667,8 @@ pub fn posterior_mean_median( /// Contains all the necessary information of an output file #[derive(Debug)] pub struct OutputFile { - pub file: File, - pub relative_path: PathBuf, + file: File, + relative_path: PathBuf, } impl OutputFile { @@ -1004,7 +693,15 @@ impl OutputFile { }) } - pub fn get_relative_path(&self) -> &Path { + pub fn file(&self) -> &File { + &self.file + } + + pub fn file_owned(self) -> File { + self.file + } + + pub fn relative_path(&self) -> &Path { &self.relative_path } } @@ -1016,37 +713,37 @@ mod tests { #[test] fn test_median_odd() { let data = vec![1.0, 3.0, 2.0]; - assert_eq!(median(data), 2.0); + assert_eq!(median(&data), 2.0); } #[test] fn test_median_even() { let data = vec![1.0, 2.0, 3.0, 4.0]; - assert_eq!(median(data), 2.5); + assert_eq!(median(&data), 2.5); } #[test] fn test_median_single() { let data = vec![42.0]; - assert_eq!(median(data), 42.0); + assert_eq!(median(&data), 42.0); } #[test] fn test_median_sorted() { let data = vec![5.0, 10.0, 15.0, 20.0, 25.0]; - assert_eq!(median(data), 15.0); + assert_eq!(median(&data), 15.0); } #[test] fn test_median_unsorted() { let data = vec![10.0, 30.0, 20.0, 50.0, 40.0]; - assert_eq!(median(data), 30.0); + assert_eq!(median(&data), 30.0); } #[test] fn test_median_with_duplicates() { let data = vec![1.0, 2.0, 2.0, 3.0, 4.0]; - assert_eq!(median(data), 2.0); + assert_eq!(median(&data), 2.0); } use super::weighted_median; diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs new file mode 100644 index 000000000..5d49cc3b8 --- /dev/null +++ b/src/routines/output/posterior.rs @@ -0,0 +1,204 @@ +pub use anyhow::{bail, Result}; +use faer::{Col, Mat}; +use serde::{Deserialize, Serialize}; + +use crate::structs::{psi::Psi, weights::Weights}; + +/// Posterior probabilities for each support points +#[derive(Debug, Clone)] +pub struct Posterior { + mat: Mat, +} + +impl Posterior { + /// Create a new Posterior from a matrix + fn new(mat: Mat) -> Self { + Posterior { mat } + } + + /// Calculate the posterior probabilities for each support point given the weights + /// + /// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns. + /// /// # Errors + /// Returns an error if the number of rows in `psi` does not match the number of weights in `w`. + /// # Arguments + /// * `psi` - The Psi object containing the matrix of support points. + /// * `w` - The weights for each support point. + /// # Returns + /// A Result containing the Posterior probabilities if successful, or an error if the + /// dimensions do not match. + pub fn calculate(psi: &Psi, w: &Col) -> Result { + if psi.matrix().ncols() != w.nrows() { + bail!( + "Number of rows in psi ({}) and number of weights ({}) do not match.", + psi.matrix().nrows(), + w.nrows() + ); + } + + let psi_matrix = psi.matrix(); + let py = psi_matrix * w; + + let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| { + psi_matrix.get(i, j) * w.get(j) / py.get(i) + }); + + Ok(posterior.into()) + } + + /// Get a reference to the underlying matrix + pub fn matrix(&self) -> &Mat { + &self.mat + } + + /// Write the posterior probabilities to a CSV file + /// Each row represents a subject, each column represents a support point + pub fn to_csv(&self, writer: W) -> Result<()> { + let mut csv_writer = csv::Writer::from_writer(writer); + + // Write each row + for i in 0..self.mat.nrows() { + let row: Vec = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect(); + csv_writer.serialize(row)?; + } + + csv_writer.flush()?; + Ok(()) + } + + /// Read posterior probabilities from a CSV file + /// Each row represents a subject, each column represents a support point + pub fn from_csv(reader: R) -> Result { + let mut csv_reader = csv::Reader::from_reader(reader); + let mut rows: Vec> = Vec::new(); + + for result in csv_reader.deserialize() { + let row: Vec = result?; + rows.push(row); + } + + if rows.is_empty() { + bail!("CSV file is empty"); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + bail!("Row {} has {} columns, expected {}", i, row.len(), ncols); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Posterior::new(mat)) + } +} + +/// Convert a matrix to a [Posterior] +impl From> for Posterior { + fn from(mat: Mat) -> Self { + Posterior::new(mat) + } +} + +impl Serialize for Posterior { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(Some(self.mat.nrows()))?; + + // Serialize each row as a vector + for i in 0..self.mat.nrows() { + let row: Vec = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect(); + seq.serialize_element(&row)?; + } + + seq.end() + } +} + +impl<'de> Deserialize<'de> for Posterior { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{SeqAccess, Visitor}; + use std::fmt; + + struct PosteriorVisitor; + + impl<'de> Visitor<'de> for PosteriorVisitor { + type Value = Posterior; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of rows (vectors of f64)") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut rows: Vec> = Vec::new(); + + while let Some(row) = seq.next_element::>()? { + rows.push(row); + } + + if rows.is_empty() { + return Err(serde::de::Error::custom("Empty matrix not allowed")); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + return Err(serde::de::Error::custom(format!( + "Row {} has {} columns, expected {}", + i, + row.len(), + ncols + ))); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Posterior::new(mat)) + } + } + + deserializer.deserialize_seq(PosteriorVisitor) + } +} + +/// Calculates the posterior probabilities for each support point given the weights +/// +/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns. +pub fn posterior(psi: &Psi, w: &Weights) -> Result { + if psi.matrix().ncols() != w.len() { + bail!( + "Number of rows in psi ({}) and number of weights ({}) do not match.", + psi.matrix().nrows(), + w.len() + ); + } + + let psi_matrix = psi.matrix(); + let py = psi_matrix * w.weights(); + + let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| { + psi_matrix.get(i, j) * w.weights().get(j) / py.get(i) + }); + + Ok(posterior.into()) +} diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs new file mode 100644 index 000000000..d6a961f2d --- /dev/null +++ b/src/routines/output/predictions.rs @@ -0,0 +1,228 @@ +use anyhow::{bail, Result}; +use pharmsol::{prelude::simulator::Prediction, Data, Predictions as PredTrait}; +use serde::{Deserialize, Serialize}; + +use crate::{ + routines::output::{posterior::Posterior, weighted_median}, + structs::{theta::Theta, weights::Weights}, +}; + +// Structure for the output +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NPPredictionRow { + /// The subject ID + id: String, + /// The time of the prediction + time: f64, + /// The output equation number + outeq: usize, + /// The occasion of the prediction + block: usize, + /// The observed value, if any + obs: Option, + /// The population mean prediction + pop_mean: f64, + /// The population median prediction + pop_median: f64, + /// The posterior mean prediction + post_mean: f64, + /// The posterior median prediction + post_median: f64, +} + +impl NPPredictionRow { + pub fn id(&self) -> &str { + &self.id + } + pub fn time(&self) -> f64 { + self.time + } + pub fn outeq(&self) -> usize { + self.outeq + } + pub fn block(&self) -> usize { + self.block + } + pub fn obs(&self) -> Option { + self.obs + } + pub fn pop_mean(&self) -> f64 { + self.pop_mean + } + pub fn pop_median(&self) -> f64 { + self.pop_median + } + pub fn post_mean(&self) -> f64 { + self.post_mean + } + pub fn post_median(&self) -> f64 { + self.post_median + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NPPredictions { + predictions: Vec, +} + +impl IntoIterator for NPPredictions { + type Item = NPPredictionRow; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.predictions.into_iter() + } +} + +impl Default for NPPredictions { + fn default() -> Self { + NPPredictions::new() + } +} + +impl NPPredictions { + pub fn new() -> Self { + NPPredictions { + predictions: Vec::new(), + } + } + + /// Add a [NPPredictionRow] to the predictions + pub fn add(&mut self, row: NPPredictionRow) { + self.predictions.push(row); + } + + /// Get a reference to the predictions + pub fn predictions(&self) -> &[NPPredictionRow] { + &self.predictions + } + + /// Calculate the population and posterior predictions + /// + /// # Arguments + /// * `equation` - The equation to use for simulation + /// * `data` - The data to use for simulation + /// * `theta` - The theta values for the simulation + /// * `w` - The weights for the simulation + /// * `posterior` - The posterior probabilities for the simulation + /// * `idelta` - The delta for the simulation + /// * `tad` - The time after dose for the simulation + /// # Returns + /// A Result containing the NPPredictions or an error + pub fn calculate( + equation: &impl pharmsol::prelude::simulator::Equation, + data: &Data, + theta: Theta, + w: &Weights, + posterior: &Posterior, + idelta: f64, + tad: f64, + ) -> Result { + // Create a new NPPredictions instance + let mut container = NPPredictions::new(); + + // Expand data + let data = data.clone().expand(idelta, tad); + let subjects = data.subjects(); + + if subjects.len() != posterior.matrix().nrows() { + bail!("Number of subjects and number of posterior means do not match"); + }; + + // Iterate over each subject and then each support point + for subject in subjects.iter().enumerate() { + let (subject_index, subject) = subject; + + // Container for predictions for this subject + // This will hold predictions for each support point + // The outer vector is for each support point + // The inner vector is for the vector of predictions for that support point + let mut predictions: Vec> = Vec::new(); + + // And each support points + for spp in theta.matrix().row_iter() { + // Simulate the subject with the current support point + let spp_values = spp.iter().cloned().collect::>(); + let pred = equation + .simulate_subject(subject, &spp_values, None)? + .0 + .get_predictions(); + predictions.push(pred); + } + + if predictions.is_empty() { + continue; // Skip this subject if no predictions are available + } + + // Calculate population mean using + let mut pop_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; + for outer_pred in predictions.iter().enumerate() { + let (i, outer_pred) = outer_pred; + for inner_pred in outer_pred.iter().enumerate() { + let (j, pred) = inner_pred; + pop_mean[j] += pred.prediction() * w[i]; + } + } + + // Calculate population median + let mut pop_median: Vec = Vec::new(); + for j in 0..predictions.first().unwrap().len() { + let mut values: Vec = Vec::new(); + let mut weights: Vec = Vec::new(); + + for (i, outer_pred) in predictions.iter().enumerate() { + values.push(outer_pred[j].prediction()); + weights.push(w[i]); + } + + let median_val = weighted_median(&values, &weights); + pop_median.push(median_val); + } + + // Calculate posterior mean + let mut posterior_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; + for outer_pred in predictions.iter().enumerate() { + let (i, outer_pred) = outer_pred; + for inner_pred in outer_pred.iter().enumerate() { + let (j, pred) = inner_pred; + posterior_mean[j] += pred.prediction() * posterior.matrix()[(subject_index, i)]; + } + } + + // Calculate posterior median + let mut posterior_median: Vec = Vec::new(); + for j in 0..predictions.first().unwrap().len() { + let mut values: Vec = Vec::new(); + let mut weights: Vec = Vec::new(); + + for (i, outer_pred) in predictions.iter().enumerate() { + values.push(outer_pred[j].prediction()); + weights.push(posterior.matrix()[(subject_index, i)]); + } + + let median_val = weighted_median(&values, &weights); + posterior_median.push(median_val); + } + + for pred in predictions.iter().enumerate() { + let (_, preds) = pred; + for (j, p) in preds.iter().enumerate() { + let row = NPPredictionRow { + id: subject.id().clone(), + time: p.time(), + outeq: p.outeq(), + block: p.occasion(), + obs: p.observation(), + pop_mean: pop_mean[j], + pop_median: pop_median[j], + post_mean: posterior_mean[j], + post_median: posterior_median[j], + }; + container.add(row); + } + } + } + + Ok(container) + } +} diff --git a/src/routines/settings.rs b/src/routines/settings.rs index fd84d34e1..e065ee91e 100644 --- a/src/routines/settings.rs +++ b/src/routines/settings.rs @@ -134,7 +134,7 @@ impl Settings { let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?; let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?; - let mut file = outputfile.file; + let mut file = outputfile.file_owned(); std::io::Write::write_all(&mut file, serialized.as_bytes())?; Ok(()) } diff --git a/src/structs/mod.rs b/src/structs/mod.rs index 62886f8f0..eb24bf1e4 100644 --- a/src/structs/mod.rs +++ b/src/structs/mod.rs @@ -1,2 +1,3 @@ pub mod psi; pub mod theta; +pub mod weights; diff --git a/src/structs/psi.rs b/src/structs/psi.rs index 3dfa61a10..c63756cae 100644 --- a/src/structs/psi.rs +++ b/src/structs/psi.rs @@ -1,3 +1,4 @@ +use anyhow::bail; use anyhow::Result; use faer::Mat; use faer_ext::IntoFaer; @@ -7,6 +8,7 @@ use pharmsol::prelude::simulator::psi; use pharmsol::Data; use pharmsol::Equation; use pharmsol::ErrorModels; +use serde::{Deserialize, Serialize}; use super::theta::Theta; @@ -53,6 +55,54 @@ impl Psi { .unwrap(); } } + + /// Write the psi matrix to a CSV writer + /// Each row represents a subject, each column represents a support point + pub fn to_csv(&self, writer: W) -> Result<()> { + let mut csv_writer = csv::Writer::from_writer(writer); + + // Write each row + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + csv_writer.serialize(row)?; + } + + csv_writer.flush()?; + Ok(()) + } + + /// Read psi matrix from a CSV reader + /// Each row represents a subject, each column represents a support point + pub fn from_csv(reader: R) -> Result { + let mut csv_reader = csv::Reader::from_reader(reader); + let mut rows: Vec> = Vec::new(); + + for result in csv_reader.deserialize() { + let row: Vec = result?; + rows.push(row); + } + + if rows.is_empty() { + bail!("CSV file is empty"); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + bail!("Row {} has {} columns, expected {}", i, row.len(), ncols); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Psi { matrix: mat }) + } } impl Default for Psi { @@ -88,6 +138,84 @@ impl From<&Array2> for Psi { } } +impl Serialize for Psi { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?; + + // Serialize each row as a vector + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + seq.serialize_element(&row)?; + } + + seq.end() + } +} + +impl<'de> Deserialize<'de> for Psi { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{SeqAccess, Visitor}; + use std::fmt; + + struct PsiVisitor; + + impl<'de> Visitor<'de> for PsiVisitor { + type Value = Psi; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of rows (vectors of f64)") + } + + fn visit_seq(self, mut seq: A) -> std::result::Result + where + A: SeqAccess<'de>, + { + let mut rows: Vec> = Vec::new(); + + while let Some(row) = seq.next_element::>()? { + rows.push(row); + } + + if rows.is_empty() { + return Err(serde::de::Error::custom("Empty matrix not allowed")); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + return Err(serde::de::Error::custom(format!( + "Row {} has {} columns, expected {}", + i, + row.len(), + ncols + ))); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Psi { matrix: mat }) + } + } + + deserializer.deserialize_seq(PsiVisitor) + } +} + pub(crate) fn calculate_psi( equation: &impl Equation, subjects: &Data, diff --git a/src/structs/theta.rs b/src/structs/theta.rs index 2fbfac64a..efb8c50ff 100644 --- a/src/structs/theta.rs +++ b/src/structs/theta.rs @@ -1,6 +1,8 @@ use std::fmt::Debug; +use anyhow::{bail, Result}; use faer::Mat; +use serde::{Deserialize, Serialize}; use crate::prelude::Parameters; @@ -108,6 +110,58 @@ impl Theta { .unwrap(); } } + + /// Write the theta matrix to a CSV writer + /// Each row represents a support point, each column represents a parameter + pub fn to_csv(&self, writer: W) -> Result<()> { + let mut csv_writer = csv::Writer::from_writer(writer); + + // Write each row + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + csv_writer.serialize(row)?; + } + + csv_writer.flush()?; + Ok(()) + } + + /// Read theta matrix from a CSV reader + /// Each row represents a support point, each column represents a parameter + /// Note: This only reads the matrix values, not the parameter metadata + pub fn from_csv(reader: R) -> Result { + let mut csv_reader = csv::Reader::from_reader(reader); + let mut rows: Vec> = Vec::new(); + + for result in csv_reader.deserialize() { + let row: Vec = result?; + rows.push(row); + } + + if rows.is_empty() { + bail!("CSV file is empty"); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + bail!("Row {} has {} columns, expected {}", i, row.len(), ncols); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + // Create empty parameters - user will need to set these separately + let parameters = Parameters::new(); + + Ok(Theta::from_parts(mat, parameters)) + } } impl Debug for Theta { @@ -132,6 +186,87 @@ impl Debug for Theta { } } +impl Serialize for Theta { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?; + + // Serialize each row as a vector + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + seq.serialize_element(&row)?; + } + + seq.end() + } +} + +impl<'de> Deserialize<'de> for Theta { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{SeqAccess, Visitor}; + use std::fmt; + + struct ThetaVisitor; + + impl<'de> Visitor<'de> for ThetaVisitor { + type Value = Theta; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of rows (vectors of f64)") + } + + fn visit_seq(self, mut seq: A) -> std::result::Result + where + A: SeqAccess<'de>, + { + let mut rows: Vec> = Vec::new(); + + while let Some(row) = seq.next_element::>()? { + rows.push(row); + } + + if rows.is_empty() { + return Err(serde::de::Error::custom("Empty matrix not allowed")); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + return Err(serde::de::Error::custom(format!( + "Row {} has {} columns, expected {}", + i, + row.len(), + ncols + ))); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + // Create empty parameters - user will need to set these separately + let parameters = Parameters::new(); + + Ok(Theta::from_parts(mat, parameters)) + } + } + + deserializer.deserialize_seq(ThetaVisitor) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/structs/weights.rs b/src/structs/weights.rs new file mode 100644 index 000000000..ce3c81082 --- /dev/null +++ b/src/structs/weights.rs @@ -0,0 +1,102 @@ +use faer::Col; +use serde::{Deserialize, Serialize}; +use std::ops::{Index, IndexMut}; + +/// The weight (probabilities) for each support point in the model. +/// +/// This struct is used to hold the weights for each support point in the model. +/// It is a thin wrapper around [faer::Col] to provide additional functionality and context +#[derive(Debug, Clone)] +pub struct Weights { + weights: Col, +} + +impl Default for Weights { + fn default() -> Self { + Self { + weights: Col::from_fn(0, |_| 0.0), + } + } +} + +impl Weights { + pub fn new(weights: Col) -> Self { + Self { weights } + } + + /// Create a new [Weights] instance from a vector of weights. + pub fn from_vec(weights: Vec) -> Self { + Self { + weights: Col::from_fn(weights.len(), |i| weights[i]), + } + } + + /// Get a reference to the weights. + pub fn weights(&self) -> &Col { + &self.weights + } + + /// Get a mutable reference to the weights. + pub fn weights_mut(&mut self) -> &mut Col { + &mut self.weights + } + + /// Get the number of weights. + pub fn len(&self) -> usize { + self.weights.nrows() + } + + /// Get a vector representation of the weights. + pub fn to_vec(&self) -> Vec { + self.weights.iter().cloned().collect() + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.weights.iter().cloned() + } +} + +impl Serialize for Weights { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.to_vec().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Weights { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let weights_vec = Vec::::deserialize(deserializer)?; + Ok(Self::from_vec(weights_vec)) + } +} + +impl From> for Weights { + fn from(weights: Vec) -> Self { + Self::from_vec(weights) + } +} + +impl From> for Weights { + fn from(weights: Col) -> Self { + Self { weights } + } +} + +impl Index for Weights { + type Output = f64; + + fn index(&self, index: usize) -> &Self::Output { + &self.weights[index] + } +} + +impl IndexMut for Weights { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.weights[index] + } +}