From 83e5a863241697cb3d8b8f35967bdd45835cb23b Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 10:48:11 +0200 Subject: [PATCH 01/15] Serde for Posterior, Theta, Psi Also moves outputs from a monolithic file to a module --- src/routines/{output.rs => output/mod.rs} | 62 +++----- src/routines/output/posterior.rs | 173 ++++++++++++++++++++++ src/structs/psi.rs | 128 ++++++++++++++++ src/structs/theta.rs | 135 +++++++++++++++++ 4 files changed, 460 insertions(+), 38 deletions(-) rename src/routines/{output.rs => output/mod.rs} (96%) create mode 100644 src/routines/output/posterior.rs diff --git a/src/routines/output.rs b/src/routines/output/mod.rs similarity index 96% rename from src/routines/output.rs rename to src/routines/output/mod.rs index 10759383b..921bc3983 100644 --- a/src/routines/output.rs +++ b/src/routines/output/mod.rs @@ -6,7 +6,7 @@ use crate::structs::theta::Theta; use anyhow::{bail, Context, Result}; use csv::WriterBuilder; use faer::linalg::zip::IntoView; -use faer::{Col, Mat}; +use faer::Col; use faer_ext::IntoNdarray; use ndarray::{Array, Array1, Array2, Axis}; use pharmsol::prelude::data::*; @@ -15,6 +15,10 @@ use serde::Serialize; use std::fs::{create_dir_all, File, OpenOptions}; use std::path::{Path, PathBuf}; +pub mod posterior; + +use posterior::posterior; + /// Defines the result objects from an NPAG run /// An [NPResult] contains the necessary information to generate predictions and summary statistics #[derive(Debug)] @@ -321,22 +325,26 @@ impl NPResult { // Write contents let subjects = self.data.subjects(); - posterior.row_iter().enumerate().for_each(|(i, row)| { - let subject = subjects.get(i).unwrap(); - let id = subject.id(); + posterior + .matrix() + .row_iter() + .enumerate() + .for_each(|(i, row)| { + let subject = subjects.get(i).unwrap(); + let id = subject.id(); - row.iter().enumerate().for_each(|(spp, prob)| { - writer.write_field(id.clone()).unwrap(); - writer.write_field(spp.to_string()).unwrap(); + row.iter().enumerate().for_each(|(spp, prob)| { + writer.write_field(id.clone()).unwrap(); + writer.write_field(spp.to_string()).unwrap(); - theta.matrix().row(spp).iter().for_each(|val| { - writer.write_field(val.to_string()).unwrap(); - }); + theta.matrix().row(spp).iter().for_each(|val| { + writer.write_field(val.to_string()).unwrap(); + }); - writer.write_field(prob.to_string()).unwrap(); - writer.write_record(None::<&[u8]>).unwrap(); + writer.write_field(prob.to_string()).unwrap(); + writer.write_record(None::<&[u8]>).unwrap(); + }); }); - }); writer.flush()?; tracing::debug!( @@ -403,7 +411,7 @@ impl NPResult { let subjects = data.subjects(); - if subjects.len() != posterior.nrows() { + if subjects.len() != posterior.matrix().nrows() { bail!("Number of subjects and number of posterior means do not match"); }; @@ -488,7 +496,7 @@ impl NPResult { let (i, outer_pred) = outer_pred; for inner_pred in outer_pred.iter().enumerate() { let (j, pred) = inner_pred; - posterior_mean[j] += pred.prediction() * posterior[(subject_index, i)]; + posterior_mean[j] += pred.prediction() * posterior.matrix()[(subject_index, i)]; } } @@ -500,7 +508,7 @@ impl NPResult { for (i, outer_pred) in predictions.iter().enumerate() { values.push(outer_pred[j].prediction()); - weights.push(posterior[(subject_index, i)]); + weights.push(posterior.matrix()[(subject_index, i)]); } let median_val = weighted_median(&values, &weights); @@ -786,28 +794,6 @@ impl Default for CycleLog { } } -/// Calculates the posterior probabilities for each support point given the weights -/// -/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns. -pub fn posterior(psi: &Psi, w: &Col) -> Result> { - if psi.matrix().ncols() != w.nrows() { - bail!( - "Number of rows in psi ({}) and number of weights ({}) do not match.", - psi.matrix().nrows(), - w.nrows() - ); - } - - let psi_matrix = psi.matrix(); - let py = psi_matrix * w; - - let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| { - psi_matrix.get(i, j) * w.get(j) / py.get(i) - }); - - Ok(posterior) -} - pub fn median(data: Vec) -> f64 { let mut data = data.clone(); data.sort_by(|a, b| a.partial_cmp(b).unwrap()); diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs new file mode 100644 index 000000000..f8f08b775 --- /dev/null +++ b/src/routines/output/posterior.rs @@ -0,0 +1,173 @@ +pub use anyhow::{bail, Result}; +use faer::{Col, Mat}; +use serde::{Deserialize, Serialize}; + +use crate::structs::psi::Psi; + +/// Posterior probabilities for each support points +#[derive(Debug, Clone)] +pub struct Posterior { + mat: Mat, +} + +impl Posterior { + /// Create a new Posterior from a matrix + pub fn new(mat: Mat) -> Self { + Posterior { mat } + } + + /// Get the underlying matrix + pub fn matrix(&self) -> &Mat { + &self.mat + } + + /// Write the posterior probabilities to a CSV file + /// Each row represents a subject, each column represents a support point + pub fn to_csv(&self, writer: W) -> Result<()> { + let mut csv_writer = csv::Writer::from_writer(writer); + + // Write each row + for i in 0..self.mat.nrows() { + let row: Vec = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect(); + csv_writer.serialize(row)?; + } + + csv_writer.flush()?; + Ok(()) + } + + /// Read posterior probabilities from a CSV file + /// Each row represents a subject, each column represents a support point + pub fn from_csv(reader: R) -> Result { + let mut csv_reader = csv::Reader::from_reader(reader); + let mut rows: Vec> = Vec::new(); + + for result in csv_reader.deserialize() { + let row: Vec = result?; + rows.push(row); + } + + if rows.is_empty() { + bail!("CSV file is empty"); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + bail!("Row {} has {} columns, expected {}", i, row.len(), ncols); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Posterior::new(mat)) + } +} + +impl From> for Posterior { + fn from(mat: Mat) -> Self { + Posterior::new(mat) + } +} + +impl Serialize for Posterior { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(Some(self.mat.nrows()))?; + + // Serialize each row as a vector + for i in 0..self.mat.nrows() { + let row: Vec = (0..self.mat.ncols()).map(|j| *self.mat.get(i, j)).collect(); + seq.serialize_element(&row)?; + } + + seq.end() + } +} + +impl<'de> Deserialize<'de> for Posterior { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{SeqAccess, Visitor}; + use std::fmt; + + struct PosteriorVisitor; + + impl<'de> Visitor<'de> for PosteriorVisitor { + type Value = Posterior; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of rows (vectors of f64)") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: SeqAccess<'de>, + { + let mut rows: Vec> = Vec::new(); + + while let Some(row) = seq.next_element::>()? { + rows.push(row); + } + + if rows.is_empty() { + return Err(serde::de::Error::custom("Empty matrix not allowed")); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + return Err(serde::de::Error::custom(format!( + "Row {} has {} columns, expected {}", + i, + row.len(), + ncols + ))); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Posterior::new(mat)) + } + } + + deserializer.deserialize_seq(PosteriorVisitor) + } +} + +/// Calculates the posterior probabilities for each support point given the weights +/// +/// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns. +pub fn posterior(psi: &Psi, w: &Col) -> Result { + if psi.matrix().ncols() != w.nrows() { + bail!( + "Number of rows in psi ({}) and number of weights ({}) do not match.", + psi.matrix().nrows(), + w.nrows() + ); + } + + let psi_matrix = psi.matrix(); + let py = psi_matrix * w; + + let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| { + psi_matrix.get(i, j) * w.get(j) / py.get(i) + }); + + Ok(posterior.into()) +} diff --git a/src/structs/psi.rs b/src/structs/psi.rs index 3dfa61a10..c63756cae 100644 --- a/src/structs/psi.rs +++ b/src/structs/psi.rs @@ -1,3 +1,4 @@ +use anyhow::bail; use anyhow::Result; use faer::Mat; use faer_ext::IntoFaer; @@ -7,6 +8,7 @@ use pharmsol::prelude::simulator::psi; use pharmsol::Data; use pharmsol::Equation; use pharmsol::ErrorModels; +use serde::{Deserialize, Serialize}; use super::theta::Theta; @@ -53,6 +55,54 @@ impl Psi { .unwrap(); } } + + /// Write the psi matrix to a CSV writer + /// Each row represents a subject, each column represents a support point + pub fn to_csv(&self, writer: W) -> Result<()> { + let mut csv_writer = csv::Writer::from_writer(writer); + + // Write each row + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + csv_writer.serialize(row)?; + } + + csv_writer.flush()?; + Ok(()) + } + + /// Read psi matrix from a CSV reader + /// Each row represents a subject, each column represents a support point + pub fn from_csv(reader: R) -> Result { + let mut csv_reader = csv::Reader::from_reader(reader); + let mut rows: Vec> = Vec::new(); + + for result in csv_reader.deserialize() { + let row: Vec = result?; + rows.push(row); + } + + if rows.is_empty() { + bail!("CSV file is empty"); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + bail!("Row {} has {} columns, expected {}", i, row.len(), ncols); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Psi { matrix: mat }) + } } impl Default for Psi { @@ -88,6 +138,84 @@ impl From<&Array2> for Psi { } } +impl Serialize for Psi { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?; + + // Serialize each row as a vector + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + seq.serialize_element(&row)?; + } + + seq.end() + } +} + +impl<'de> Deserialize<'de> for Psi { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{SeqAccess, Visitor}; + use std::fmt; + + struct PsiVisitor; + + impl<'de> Visitor<'de> for PsiVisitor { + type Value = Psi; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of rows (vectors of f64)") + } + + fn visit_seq(self, mut seq: A) -> std::result::Result + where + A: SeqAccess<'de>, + { + let mut rows: Vec> = Vec::new(); + + while let Some(row) = seq.next_element::>()? { + rows.push(row); + } + + if rows.is_empty() { + return Err(serde::de::Error::custom("Empty matrix not allowed")); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + return Err(serde::de::Error::custom(format!( + "Row {} has {} columns, expected {}", + i, + row.len(), + ncols + ))); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + Ok(Psi { matrix: mat }) + } + } + + deserializer.deserialize_seq(PsiVisitor) + } +} + pub(crate) fn calculate_psi( equation: &impl Equation, subjects: &Data, diff --git a/src/structs/theta.rs b/src/structs/theta.rs index 2fbfac64a..efb8c50ff 100644 --- a/src/structs/theta.rs +++ b/src/structs/theta.rs @@ -1,6 +1,8 @@ use std::fmt::Debug; +use anyhow::{bail, Result}; use faer::Mat; +use serde::{Deserialize, Serialize}; use crate::prelude::Parameters; @@ -108,6 +110,58 @@ impl Theta { .unwrap(); } } + + /// Write the theta matrix to a CSV writer + /// Each row represents a support point, each column represents a parameter + pub fn to_csv(&self, writer: W) -> Result<()> { + let mut csv_writer = csv::Writer::from_writer(writer); + + // Write each row + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + csv_writer.serialize(row)?; + } + + csv_writer.flush()?; + Ok(()) + } + + /// Read theta matrix from a CSV reader + /// Each row represents a support point, each column represents a parameter + /// Note: This only reads the matrix values, not the parameter metadata + pub fn from_csv(reader: R) -> Result { + let mut csv_reader = csv::Reader::from_reader(reader); + let mut rows: Vec> = Vec::new(); + + for result in csv_reader.deserialize() { + let row: Vec = result?; + rows.push(row); + } + + if rows.is_empty() { + bail!("CSV file is empty"); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + bail!("Row {} has {} columns, expected {}", i, row.len(), ncols); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + // Create empty parameters - user will need to set these separately + let parameters = Parameters::new(); + + Ok(Theta::from_parts(mat, parameters)) + } } impl Debug for Theta { @@ -132,6 +186,87 @@ impl Debug for Theta { } } +impl Serialize for Theta { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(Some(self.matrix.nrows()))?; + + // Serialize each row as a vector + for i in 0..self.matrix.nrows() { + let row: Vec = (0..self.matrix.ncols()) + .map(|j| *self.matrix.get(i, j)) + .collect(); + seq.serialize_element(&row)?; + } + + seq.end() + } +} + +impl<'de> Deserialize<'de> for Theta { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + use serde::de::{SeqAccess, Visitor}; + use std::fmt; + + struct ThetaVisitor; + + impl<'de> Visitor<'de> for ThetaVisitor { + type Value = Theta; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a sequence of rows (vectors of f64)") + } + + fn visit_seq(self, mut seq: A) -> std::result::Result + where + A: SeqAccess<'de>, + { + let mut rows: Vec> = Vec::new(); + + while let Some(row) = seq.next_element::>()? { + rows.push(row); + } + + if rows.is_empty() { + return Err(serde::de::Error::custom("Empty matrix not allowed")); + } + + let nrows = rows.len(); + let ncols = rows[0].len(); + + // Verify all rows have the same length + for (i, row) in rows.iter().enumerate() { + if row.len() != ncols { + return Err(serde::de::Error::custom(format!( + "Row {} has {} columns, expected {}", + i, + row.len(), + ncols + ))); + } + } + + // Create matrix from rows + let mat = Mat::from_fn(nrows, ncols, |i, j| rows[i][j]); + + // Create empty parameters - user will need to set these separately + let parameters = Parameters::new(); + + Ok(Theta::from_parts(mat, parameters)) + } + } + + deserializer.deserialize_seq(ThetaVisitor) + } +} + #[cfg(test)] mod tests { use super::*; From 610177c404f664a43657850d7ec1c66dda43d090 Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 11:58:16 +0200 Subject: [PATCH 02/15] Move NPCycle and CycleLog --- src/algorithms/npag.rs | 20 ++-- src/algorithms/npod.rs | 21 ++-- src/algorithms/postprob.rs | 2 +- src/routines/output/cycles.rs | 203 ++++++++++++++++++++++++++++++++++ src/routines/output/mod.rs | 169 +--------------------------- 5 files changed, 227 insertions(+), 188 deletions(-) create mode 100644 src/routines/output/cycles.rs diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index ebcb988f6..9e889b713 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -5,7 +5,7 @@ pub use crate::routines::evaluation::ipm::burke; pub use crate::routines::evaluation::qr; use crate::routines::settings::Settings; -use crate::routines::output::{CycleLog, NPCycle, NPResult}; +use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; use crate::structs::psi::{calculate_psi, Psi}; use crate::structs::theta::Theta; @@ -165,15 +165,15 @@ impl Algorithms for NPAG { } // Create state object - let state = NPCycle { - cycle: self.cycle, - objf: -2. * self.objf, - delta_objf: (self.last_objf - self.objf).abs(), - nspp: self.theta.nspp(), - theta: self.theta.clone(), - error_models: self.error_models.clone(), - status: self.status.clone(), - }; + let state = NPCycle::new( + self.cycle, + -2. * self.objf, + self.error_models.clone(), + self.theta.clone(), + self.theta.nspp(), + (self.last_objf - self.objf).abs(), + self.status.clone(), + ); // Write cycle log self.cycle_log.push(state); diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index f00886ae4..bb8361816 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -1,10 +1,10 @@ +use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; use crate::{ algorithms::Status, prelude::{ algorithms::Algorithms, routines::{ evaluation::{ipm::burke, qr}, - output::{CycleLog, NPCycle, NPResult}, settings::Settings, }, }, @@ -13,6 +13,7 @@ use crate::{ theta::Theta, }, }; + use anyhow::bail; use anyhow::Result; use faer::Col; @@ -157,15 +158,15 @@ impl Algorithms for NPOD { } // Create state object - let state = NPCycle { - cycle: self.cycle, - objf: -2. * self.objf, - delta_objf: (self.last_objf - self.objf).abs(), - nspp: self.theta.nspp(), - theta: self.theta.clone(), - error_models: self.error_models.clone(), - status: self.status.clone(), - }; + let state = NPCycle::new( + self.cycle, + -2. * self.objf, + self.error_models.clone(), + self.theta.clone(), + self.theta.nspp(), + (self.last_objf - self.objf).abs(), + self.status.clone(), + ); // Write cycle log self.cycle_log.push(state); diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs index e0c19049d..234c558f0 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/postprob.rs @@ -15,7 +15,7 @@ use pharmsol::prelude::{ use crate::routines::evaluation::ipm::burke; use crate::routines::initialization; -use crate::routines::output::{CycleLog, NPResult}; +use crate::routines::output::{cycles::CycleLog, NPResult}; use crate::routines::settings::Settings; /// Posterior probability algorithm diff --git a/src/routines/output/cycles.rs b/src/routines/output/cycles.rs new file mode 100644 index 000000000..7997cf72a --- /dev/null +++ b/src/routines/output/cycles.rs @@ -0,0 +1,203 @@ +use anyhow::Result; +use csv::WriterBuilder; +use pharmsol::{ErrorModel, ErrorModels}; + +use crate::{ + algorithms::Status, + prelude::Settings, + routines::output::{median, OutputFile}, + structs::theta::Theta, +}; + +/// An [NPCycle] object contains the summary of a cycle +/// It holds the following information: +/// - `cycle`: The cycle number +/// - `objf`: The objective function value +/// - `gamlam`: The assay noise parameter, either gamma or lambda +/// - `theta`: The support points and their associated probabilities +/// - `nspp`: The number of support points +/// - `delta_objf`: The change in objective function value from last cycle +/// - `converged`: Whether the algorithm has reached convergence +#[derive(Debug, Clone)] +pub struct NPCycle { + cycle: usize, + objf: f64, + error_models: ErrorModels, + theta: Theta, + nspp: usize, + delta_objf: f64, + status: Status, +} + +impl NPCycle { + pub fn new( + cycle: usize, + objf: f64, + error_models: ErrorModels, + theta: Theta, + nspp: usize, + delta_objf: f64, + status: Status, + ) -> Self { + Self { + cycle, + objf, + error_models, + theta, + nspp, + delta_objf, + status, + } + } + + pub fn cycle(&self) -> usize { + self.cycle + } + pub fn objf(&self) -> f64 { + self.objf + } + pub fn error_models(&self) -> &ErrorModels { + &self.error_models + } + pub fn theta(&self) -> &Theta { + &self.theta + } + pub fn nspp(&self) -> usize { + self.nspp + } + pub fn delta_objf(&self) -> f64 { + self.delta_objf + } + pub fn status(&self) -> &Status { + &self.status + } + + pub fn placeholder() -> Self { + Self { + cycle: 0, + objf: 0.0, + error_models: ErrorModels::default(), + theta: Theta::new(), + nspp: 0, + delta_objf: 0.0, + status: Status::Starting, + } + } +} + +/// This holdes a vector of [NPCycle] objects to provide a more detailed log +#[derive(Debug, Clone)] +pub struct CycleLog { + cycles: Vec, +} + +impl CycleLog { + pub fn new() -> Self { + Self { cycles: Vec::new() } + } + + pub fn cycles(&self) -> &[NPCycle] { + &self.cycles + } + + pub fn push(&mut self, cycle: NPCycle) { + self.cycles.push(cycle); + } + + pub fn write(&self, settings: &Settings) -> Result<()> { + tracing::debug!("Writing cycles..."); + let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(false) + .from_writer(&outputfile.file); + + // Write headers + writer.write_field("cycle")?; + writer.write_field("converged")?; + writer.write_field("status")?; + writer.write_field("neg2ll")?; + writer.write_field("nspp")?; + if let Some(first_cycle) = self.cycles.first() { + first_cycle.error_models.iter().try_for_each( + |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> { + match errmod { + ErrorModel::Additive { .. } => { + writer.write_field(format!("gamlam.{}", outeq))?; + } + ErrorModel::Proportional { .. } => { + writer.write_field(format!("gamlam.{}", outeq))?; + } + ErrorModel::None => {} + } + Ok(()) + }, + )?; + } + + let parameter_names = settings.parameters().names(); + for param_name in ¶meter_names { + writer.write_field(format!("{}.mean", param_name))?; + writer.write_field(format!("{}.median", param_name))?; + writer.write_field(format!("{}.sd", param_name))?; + } + + writer.write_record(None::<&[u8]>)?; + + for cycle in &self.cycles { + writer.write_field(format!("{}", cycle.cycle))?; + writer.write_field(format!("{}", cycle.status == Status::Converged))?; + writer.write_field(format!("{}", cycle.status))?; + writer.write_field(format!("{}", cycle.objf))?; + writer + .write_field(format!("{}", cycle.theta.nspp())) + .unwrap(); + + // Write the error models + cycle.error_models.iter().try_for_each( + |(_, errmod): (usize, &ErrorModel)| -> Result<()> { + match errmod { + ErrorModel::Additive { + lambda: _, + poly: _, + lloq: _, + } => { + writer.write_field(format!("{:.5}", errmod.factor()?))?; + } + ErrorModel::Proportional { + gamma: _, + poly: _, + lloq: _, + } => { + writer.write_field(format!("{:.5}", errmod.factor()?))?; + } + ErrorModel::None => {} + } + Ok(()) + }, + )?; + + for param in cycle.theta.matrix().col_iter() { + let param_values: Vec = param.iter().cloned().collect(); + + let mean: f64 = param_values.iter().sum::() / param_values.len() as f64; + let median = median(param_values.clone()); + let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::() + / (param_values.len() as f64 - 1.0); + + writer.write_field(format!("{}", mean))?; + writer.write_field(format!("{}", median))?; + writer.write_field(format!("{}", std))?; + } + writer.write_record(None::<&[u8]>)?; + } + writer.flush()?; + tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path()); + Ok(()) + } +} + +impl Default for CycleLog { + fn default() -> Self { + Self::new() + } +} diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 921bc3983..55f8d8e8f 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -1,5 +1,6 @@ use crate::algorithms::Status; use crate::prelude::*; +use crate::routines::output::cycles::CycleLog; use crate::routines::settings::Settings; use crate::structs::psi::Psi; use crate::structs::theta::Theta; @@ -15,6 +16,7 @@ use serde::Serialize; use std::fs::{create_dir_all, File, OpenOptions}; use std::path::{Path, PathBuf}; +pub mod cycles; pub mod posterior; use posterior::posterior; @@ -627,173 +629,6 @@ impl NPResult { } } -/// An [NPCycle] object contains the summary of a cycle -/// It holds the following information: -/// - `cycle`: The cycle number -/// - `objf`: The objective function value -/// - `gamlam`: The assay noise parameter, either gamma or lambda -/// - `theta`: The support points and their associated probabilities -/// - `nspp`: The number of support points -/// - `delta_objf`: The change in objective function value from last cycle -/// - `converged`: Whether the algorithm has reached convergence -#[derive(Debug, Clone)] -pub struct NPCycle { - pub cycle: usize, - pub objf: f64, - pub error_models: ErrorModels, - pub theta: Theta, - pub nspp: usize, - pub delta_objf: f64, - pub status: Status, -} - -impl NPCycle { - pub fn new( - cycle: usize, - objf: f64, - error_models: ErrorModels, - theta: Theta, - nspp: usize, - delta_objf: f64, - status: Status, - ) -> Self { - Self { - cycle, - objf, - error_models, - theta, - nspp, - delta_objf, - status, - } - } - - pub fn placeholder() -> Self { - Self { - cycle: 0, - objf: 0.0, - error_models: ErrorModels::default(), - theta: Theta::new(), - nspp: 0, - delta_objf: 0.0, - status: Status::Starting, - } - } -} - -/// This holdes a vector of [NPCycle] objects to provide a more detailed log -#[derive(Debug, Clone)] -pub struct CycleLog { - pub cycles: Vec, -} - -impl CycleLog { - pub fn new() -> Self { - Self { cycles: Vec::new() } - } - - pub fn push(&mut self, cycle: NPCycle) { - self.cycles.push(cycle); - } - - pub fn write(&self, settings: &Settings) -> Result<()> { - tracing::debug!("Writing cycles..."); - let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?; - let mut writer = WriterBuilder::new() - .has_headers(false) - .from_writer(&outputfile.file); - - // Write headers - writer.write_field("cycle")?; - writer.write_field("converged")?; - writer.write_field("status")?; - writer.write_field("neg2ll")?; - writer.write_field("nspp")?; - if let Some(first_cycle) = self.cycles.first() { - first_cycle.error_models.iter().try_for_each( - |(outeq, errmod): (usize, &ErrorModel)| -> Result<(), csv::Error> { - match errmod { - ErrorModel::Additive { .. } => { - writer.write_field(format!("gamlam.{}", outeq))?; - } - ErrorModel::Proportional { .. } => { - writer.write_field(format!("gamlam.{}", outeq))?; - } - ErrorModel::None => {} - } - Ok(()) - }, - )?; - } - - let parameter_names = settings.parameters().names(); - for param_name in ¶meter_names { - writer.write_field(format!("{}.mean", param_name))?; - writer.write_field(format!("{}.median", param_name))?; - writer.write_field(format!("{}.sd", param_name))?; - } - - writer.write_record(None::<&[u8]>)?; - - for cycle in &self.cycles { - writer.write_field(format!("{}", cycle.cycle))?; - writer.write_field(format!("{}", cycle.status == Status::Converged))?; - writer.write_field(format!("{}", cycle.status))?; - writer.write_field(format!("{}", cycle.objf))?; - writer - .write_field(format!("{}", cycle.theta.nspp())) - .unwrap(); - - // Write the error models - cycle.error_models.iter().try_for_each( - |(_, errmod): (usize, &ErrorModel)| -> Result<()> { - match errmod { - ErrorModel::Additive { - lambda: _, - poly: _, - lloq: _, - } => { - writer.write_field(format!("{:.5}", errmod.factor()?))?; - } - ErrorModel::Proportional { - gamma: _, - poly: _, - lloq: _, - } => { - writer.write_field(format!("{:.5}", errmod.factor()?))?; - } - ErrorModel::None => {} - } - Ok(()) - }, - )?; - - for param in cycle.theta.matrix().col_iter() { - let param_values: Vec = param.iter().cloned().collect(); - - let mean: f64 = param_values.iter().sum::() / param_values.len() as f64; - let median = median(param_values.clone()); - let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::() - / (param_values.len() as f64 - 1.0); - - writer.write_field(format!("{}", mean))?; - writer.write_field(format!("{}", median))?; - writer.write_field(format!("{}", std))?; - } - writer.write_record(None::<&[u8]>)?; - } - writer.flush()?; - tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path()); - Ok(()) - } -} - -impl Default for CycleLog { - fn default() -> Self { - Self::new() - } -} - pub fn median(data: Vec) -> f64 { let mut data = data.clone(); data.sort_by(|a, b| a.partial_cmp(b).unwrap()); From 030408b0a0bad7fc08d8be1140c0a428dcae043a Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 12:05:01 +0200 Subject: [PATCH 03/15] The median helper function is no longer public, and takes a reference The helper function does not need to own the data, and as such it can take a reference to the data instead. --- src/routines/output/cycles.rs | 2 +- src/routines/output/mod.rs | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/routines/output/cycles.rs b/src/routines/output/cycles.rs index 7997cf72a..4cf02bdd4 100644 --- a/src/routines/output/cycles.rs +++ b/src/routines/output/cycles.rs @@ -180,7 +180,7 @@ impl CycleLog { let param_values: Vec = param.iter().cloned().collect(); let mean: f64 = param_values.iter().sum::() / param_values.len() as f64; - let median = median(param_values.clone()); + let median = median(¶m_values); let std = param_values.iter().map(|x| (x - mean).powi(2)).sum::() / (param_values.len() as f64 - 1.0); diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 55f8d8e8f..8df822b7c 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -629,7 +629,7 @@ impl NPResult { } } -pub fn median(data: Vec) -> f64 { +pub(crate) fn median(data: &Vec) -> f64 { let mut data = data.clone(); data.sort_by(|a, b| a.partial_cmp(b).unwrap()); @@ -837,37 +837,37 @@ mod tests { #[test] fn test_median_odd() { let data = vec![1.0, 3.0, 2.0]; - assert_eq!(median(data), 2.0); + assert_eq!(median(&data), 2.0); } #[test] fn test_median_even() { let data = vec![1.0, 2.0, 3.0, 4.0]; - assert_eq!(median(data), 2.5); + assert_eq!(median(&data), 2.5); } #[test] fn test_median_single() { let data = vec![42.0]; - assert_eq!(median(data), 42.0); + assert_eq!(median(&data), 42.0); } #[test] fn test_median_sorted() { let data = vec![5.0, 10.0, 15.0, 20.0, 25.0]; - assert_eq!(median(data), 15.0); + assert_eq!(median(&data), 15.0); } #[test] fn test_median_unsorted() { let data = vec![10.0, 30.0, 20.0, 50.0, 40.0]; - assert_eq!(median(data), 30.0); + assert_eq!(median(&data), 30.0); } #[test] fn test_median_with_duplicates() { let data = vec![1.0, 2.0, 2.0, 3.0, 4.0]; - assert_eq!(median(data), 2.0); + assert_eq!(median(&data), 2.0); } use super::weighted_median; From ec378b8717bdc5697a4037a0be6cffd8e774e93b Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 15:30:42 +0200 Subject: [PATCH 04/15] OutputFile --- src/routines/logger.rs | 2 +- src/routines/output/cycles.rs | 2 +- src/routines/output/mod.rs | 35 +++++++++++++++++------------------ src/routines/settings.rs | 2 +- 4 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/routines/logger.rs b/src/routines/logger.rs index 91214fd69..2b5d411cc 100644 --- a/src/routines/logger.rs +++ b/src/routines/logger.rs @@ -48,7 +48,7 @@ pub(crate) fn setup_log(settings: &mut Settings) -> Result<()> { let file_layer = match settings.log().write { true => { let layer = fmt::layer() - .with_writer(outputfile.file) + .with_writer(outputfile.file_owned()) .with_ansi(false) .with_timer(timestamper.clone()); diff --git a/src/routines/output/cycles.rs b/src/routines/output/cycles.rs index 4cf02bdd4..536ea892a 100644 --- a/src/routines/output/cycles.rs +++ b/src/routines/output/cycles.rs @@ -191,7 +191,7 @@ impl CycleLog { writer.write_record(None::<&[u8]>)?; } writer.flush()?; - tracing::debug!("Cycles written to {:?}", &outputfile.get_relative_path()); + tracing::debug!("Cycles written to {:?}", &outputfile.relative_path()); Ok(()) } } diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 8df822b7c..ce9ed145e 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -245,7 +245,7 @@ impl NPResult { writer.flush()?; tracing::debug!( "Observations with predictions written to {:?}", - &outputfile.get_relative_path() + &outputfile.relative_path() ); Ok(()) } @@ -287,7 +287,7 @@ impl NPResult { writer.flush()?; tracing::debug!( "Population parameter distribution written to {:?}", - &outputfile.get_relative_path() + &outputfile.relative_path() ); Ok(()) } @@ -351,7 +351,7 @@ impl NPResult { writer.flush()?; tracing::debug!( "Posterior parameters written to {:?}", - &outputfile.get_relative_path() + &outputfile.relative_path() ); Ok(()) @@ -393,10 +393,7 @@ impl NPResult { } writer.flush()?; - tracing::debug!( - "Observations written to {:?}", - &outputfile.get_relative_path() - ); + tracing::debug!("Observations written to {:?}", &outputfile.relative_path()); Ok(()) } @@ -549,10 +546,7 @@ impl NPResult { } writer.flush()?; - tracing::debug!( - "Predictions written to {:?}", - &outputfile.get_relative_path() - ); + tracing::debug!("Predictions written to {:?}", &outputfile.relative_path()); Ok(()) } @@ -621,10 +615,7 @@ impl NPResult { } writer.flush()?; - tracing::debug!( - "Covariates written to {:?}", - &outputfile.get_relative_path() - ); + tracing::debug!("Covariates written to {:?}", &outputfile.relative_path()); Ok(()) } } @@ -799,8 +790,8 @@ pub fn posterior_mean_median( /// Contains all the necessary information of an output file #[derive(Debug)] pub struct OutputFile { - pub file: File, - pub relative_path: PathBuf, + file: File, + relative_path: PathBuf, } impl OutputFile { @@ -825,7 +816,15 @@ impl OutputFile { }) } - pub fn get_relative_path(&self) -> &Path { + pub fn file(&self) -> &File { + &self.file + } + + pub fn file_owned(self) -> File { + self.file + } + + pub fn relative_path(&self) -> &Path { &self.relative_path } } diff --git a/src/routines/settings.rs b/src/routines/settings.rs index fd84d34e1..e065ee91e 100644 --- a/src/routines/settings.rs +++ b/src/routines/settings.rs @@ -134,7 +134,7 @@ impl Settings { let serialized = serde_json::to_string_pretty(self).map_err(std::io::Error::other)?; let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?; - let mut file = outputfile.file; + let mut file = outputfile.file_owned(); std::io::Write::write_all(&mut file, serialized.as_bytes())?; Ok(()) } From 9a046353d445efbc11339cfcd33f234dbbfc2ffe Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 16:10:34 +0200 Subject: [PATCH 05/15] Update posterior.rs --- src/routines/output/posterior.rs | 35 ++++++++++++++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs index f8f08b775..37190cac0 100644 --- a/src/routines/output/posterior.rs +++ b/src/routines/output/posterior.rs @@ -12,11 +12,41 @@ pub struct Posterior { impl Posterior { /// Create a new Posterior from a matrix - pub fn new(mat: Mat) -> Self { + fn new(mat: Mat) -> Self { Posterior { mat } } - /// Get the underlying matrix + /// Calculate the posterior probabilities for each support point given the weights + /// + /// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns. + /// /// # Errors + /// Returns an error if the number of rows in `psi` does not match the number of weights in `w`. + /// # Arguments + /// * `psi` - The Psi object containing the matrix of support points. + /// * `w` - The weights for each support point. + /// # Returns + /// A Result containing the Posterior probabilities if successful, or an error if the + /// dimensions do not match. + pub fn calculate(psi: &Psi, w: &Col) -> Result { + if psi.matrix().ncols() != w.nrows() { + bail!( + "Number of rows in psi ({}) and number of weights ({}) do not match.", + psi.matrix().nrows(), + w.nrows() + ); + } + + let psi_matrix = psi.matrix(); + let py = psi_matrix * w; + + let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| { + psi_matrix.get(i, j) * w.get(j) / py.get(i) + }); + + Ok(posterior.into()) + } + + /// Get a reference to the underlying matrix pub fn matrix(&self) -> &Mat { &self.mat } @@ -68,6 +98,7 @@ impl Posterior { } } +/// Convert a matrix to a [Posterior] impl From> for Posterior { fn from(mat: Mat) -> Self { Posterior::new(mat) From 42c3d0c651840e439fbfe02b68ceeedd3040c39c Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 16:13:31 +0200 Subject: [PATCH 06/15] Update posterior.rs --- src/routines/output/posterior.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs index 37190cac0..bc8a19e2b 100644 --- a/src/routines/output/posterior.rs +++ b/src/routines/output/posterior.rs @@ -27,7 +27,7 @@ impl Posterior { /// # Returns /// A Result containing the Posterior probabilities if successful, or an error if the /// dimensions do not match. - pub fn calculate(psi: &Psi, w: &Col) -> Result { + pub fn calculate(psi: &Psi, w: &Col) -> Result { if psi.matrix().ncols() != w.nrows() { bail!( "Number of rows in psi ({}) and number of weights ({}) do not match.", From 6e37524f2fece4bee480612b289a01ffe8b83245 Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 16:16:18 +0200 Subject: [PATCH 07/15] Clippy --- src/routines/output/mod.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index ce9ed145e..25dd2b02b 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -620,8 +620,8 @@ impl NPResult { } } -pub(crate) fn median(data: &Vec) -> f64 { - let mut data = data.clone(); +pub(crate) fn median(data: &[f64]) -> f64 { + let mut data: Vec = data.to_vec(); data.sort_by(|a, b| a.partial_cmp(b).unwrap()); let size = data.len(); @@ -635,7 +635,7 @@ pub(crate) fn median(data: &Vec) -> f64 { } } -fn weighted_median(data: &Vec, weights: &Vec) -> f64 { +fn weighted_median(data: &[f64], weights: &Vec) -> f64 { // Ensure the data and weights arrays have the same length assert_eq!( data.len(), From b3884d98a41e118a6d8ecc6736beca71d0cce891 Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 17:59:22 +0200 Subject: [PATCH 08/15] Move prediction logic --- src/routines/output/mod.rs | 1 + src/routines/output/predictions.rs | 225 +++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+) create mode 100644 src/routines/output/predictions.rs diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 25dd2b02b..62fc910ce 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -18,6 +18,7 @@ use std::path::{Path, PathBuf}; pub mod cycles; pub mod posterior; +pub mod predictions; use posterior::posterior; diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs new file mode 100644 index 000000000..8a51b7047 --- /dev/null +++ b/src/routines/output/predictions.rs @@ -0,0 +1,225 @@ +use anyhow::{bail, Result}; +use faer::Col; +use pharmsol::{prelude::simulator::Prediction, Data, Event, Predictions as PredTrait}; +use serde::Serialize; + +use crate::{ + routines::output::{posterior::Posterior, weighted_median}, + structs::theta::Theta, +}; + +// Structure for the output +#[derive(Debug, Clone, Serialize)] +pub struct NPPredictionRow { + id: String, + time: f64, + outeq: usize, + block: usize, + obs: Option, + pop_mean: f64, + pop_median: f64, + post_mean: f64, + post_median: f64, +} + +impl NPPredictionRow { + pub fn id(&self) -> &str { + &self.id + } + pub fn time(&self) -> f64 { + self.time + } + pub fn outeq(&self) -> usize { + self.outeq + } + pub fn block(&self) -> usize { + self.block + } + pub fn obs(&self) -> Option { + self.obs + } + pub fn pop_mean(&self) -> f64 { + self.pop_mean + } + pub fn pop_median(&self) -> f64 { + self.pop_median + } + pub fn post_mean(&self) -> f64 { + self.post_mean + } + pub fn post_median(&self) -> f64 { + self.post_median + } +} + +pub struct NPPredictions { + predictions: Vec, +} + +impl IntoIterator for NPPredictions { + type Item = NPPredictionRow; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.predictions.into_iter() + } +} + +impl Default for NPPredictions { + fn default() -> Self { + NPPredictions::new() + } +} + +impl NPPredictions { + pub fn new() -> Self { + NPPredictions { + predictions: Vec::new(), + } + } + + /// Add a [NPPredictionRow] to the predictions + pub fn add(&mut self, row: NPPredictionRow) { + self.predictions.push(row); + } + + /// Get a reference to the predictions + pub fn predictions(&self) -> &[NPPredictionRow] { + &self.predictions + } + + pub fn calculate( + equation: &impl pharmsol::prelude::simulator::Equation, + data: &Data, + theta: Theta, + w: &Col, + posterior: &Posterior, + idelta: f64, + tad: f64, + ) -> Result { + // Create a new NPPredictions instance + let mut container = NPPredictions::new(); + + // Expand data + let data = data.clone().expand(idelta, tad); + let subjects = data.subjects(); + + if subjects.len() != posterior.matrix().nrows() { + bail!("Number of subjects and number of posterior means do not match"); + }; + + // Iterate over each subject and then each support point + for subject in subjects.iter().enumerate() { + let (subject_index, subject) = subject; + + // Get a vector of occasions for this subject, for each predictions + let occasions = subject + .occasions() + .iter() + .flat_map(|o| { + o.events() + .iter() + .filter_map(|e| { + if let Event::Observation(_obs) = e { + Some(o.index()) + } else { + None + } + }) + .collect::>() + }) + .collect::>(); + + // Container for predictions for this subject + // This will hold predictions for each support point + // The outer vector is for each support point + // The inner vector is for the vector of predictions for that support point + let mut predictions: Vec> = Vec::new(); + + // And each support points + for spp in theta.matrix().row_iter() { + // Simulate the subject with the current support point + let spp_values = spp.iter().cloned().collect::>(); + let pred = equation + .simulate_subject(subject, &spp_values, None)? + .0 + .get_predictions(); + predictions.push(pred); + } + + if predictions.is_empty() { + continue; // Skip this subject if no predictions are available + } + + // Calculate population mean using + let mut pop_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; + for outer_pred in predictions.iter().enumerate() { + let (i, outer_pred) = outer_pred; + for inner_pred in outer_pred.iter().enumerate() { + let (j, pred) = inner_pred; + pop_mean[j] += pred.prediction() * w[i]; + } + } + + // Calculate population median + let mut pop_median: Vec = Vec::new(); + for j in 0..predictions.first().unwrap().len() { + let mut values: Vec = Vec::new(); + let mut weights: Vec = Vec::new(); + + for (i, outer_pred) in predictions.iter().enumerate() { + values.push(outer_pred[j].prediction()); + weights.push(w[i]); + } + + let median_val = weighted_median(&values, &weights); + pop_median.push(median_val); + } + + // Calculate posterior mean + let mut posterior_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; + for outer_pred in predictions.iter().enumerate() { + let (i, outer_pred) = outer_pred; + for inner_pred in outer_pred.iter().enumerate() { + let (j, pred) = inner_pred; + posterior_mean[j] += pred.prediction() * posterior.matrix()[(subject_index, i)]; + } + } + + // Calculate posterior median + let mut posterior_median: Vec = Vec::new(); + for j in 0..predictions.first().unwrap().len() { + let mut values: Vec = Vec::new(); + let mut weights: Vec = Vec::new(); + + for (i, outer_pred) in predictions.iter().enumerate() { + values.push(outer_pred[j].prediction()); + weights.push(posterior.matrix()[(subject_index, i)]); + } + + let median_val = weighted_median(&values, &weights); + posterior_median.push(median_val); + } + + for pred in predictions.iter().enumerate() { + let (_, preds) = pred; + for (j, p) in preds.iter().enumerate() { + let row = NPPredictionRow { + id: subject.id().clone(), + time: p.time(), + outeq: p.outeq(), + block: occasions[j], + obs: p.observation(), + pop_mean: pop_mean[j], + pop_median: pop_median[j], + post_mean: posterior_mean[j], + post_median: posterior_median[j], + }; + container.add(row); + } + } + } + + Ok(container) + } +} From 65d765ab123d7a5e1707e69f3b120f9154bda9af Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 17 Aug 2025 18:20:43 +0200 Subject: [PATCH 09/15] Move prediction logic from monolith to module --- src/routines/output/mod.rs | 148 +++-------------------------- src/routines/output/predictions.rs | 4 +- 2 files changed, 16 insertions(+), 136 deletions(-) diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 62fc910ce..4ddac556d 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -1,6 +1,7 @@ use crate::algorithms::Status; use crate::prelude::*; use crate::routines::output::cycles::CycleLog; +use crate::routines::output::predictions::NPPredictions; use crate::routines::settings::Settings; use crate::structs::psi::Psi; use crate::structs::theta::Theta; @@ -11,7 +12,7 @@ use faer::Col; use faer_ext::IntoNdarray; use ndarray::{Array, Array1, Array2, Axis}; use pharmsol::prelude::data::*; -use pharmsol::prelude::simulator::{Equation, Prediction}; +use pharmsol::prelude::simulator::Equation; use serde::Serialize; use std::fs::{create_dir_all, File, OpenOptions}; use std::path::{Path, PathBuf}; @@ -402,18 +403,17 @@ impl NPResult { pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> { tracing::debug!("Writing predictions..."); - // Get necessary data - let theta = self.theta.matrix(); - let w: Vec = self.w.iter().cloned().collect(); let posterior = posterior(&self.psi, &self.w)?; - let data = self.data.clone().expand(idelta, tad); - - let subjects = data.subjects(); - - if subjects.len() != posterior.matrix().nrows() { - bail!("Number of subjects and number of posterior means do not match"); - }; + let predictions = NPPredictions::calculate( + &self.equation, + &self.data, + self.theta.clone(), + &self.w, + &posterior, + idelta, + tad, + )?; // Create the output file and writer for pred.csv let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?; @@ -421,129 +421,9 @@ impl NPResult { .has_headers(true) .from_writer(&outputfile.file); - // Iterate over each subject and then each support point - for subject in subjects.iter().enumerate() { - let (subject_index, subject) = subject; - - // Get a vector of occasions for this subject, for each predictions - let occasions = subject - .occasions() - .iter() - .flat_map(|o| { - o.events() - .iter() - .filter_map(|e| { - if let Event::Observation(_obs) = e { - Some(o.index()) - } else { - None - } - }) - .collect::>() - }) - .collect::>(); - - // Container for predictions for this subject - // This will hold predictions for each support point - // The outer vector is for each support point - // The inner vector is for the vector of predictions for that support point - let mut predictions: Vec> = Vec::new(); - - // And each support points - for spp in theta.row_iter() { - // Simulate the subject with the current support point - let spp_values = spp.iter().cloned().collect::>(); - let pred = self - .equation - .simulate_subject(subject, &spp_values, None)? - .0 - .get_predictions(); - predictions.push(pred); - } - - if predictions.is_empty() { - continue; // Skip this subject if no predictions are available - } - - // Calculate population mean using - let mut pop_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; - for outer_pred in predictions.iter().enumerate() { - let (i, outer_pred) = outer_pred; - for inner_pred in outer_pred.iter().enumerate() { - let (j, pred) = inner_pred; - pop_mean[j] += pred.prediction() * w[i]; - } - } - - // Calculate population median - let mut pop_median: Vec = Vec::new(); - for j in 0..predictions.first().unwrap().len() { - let mut values: Vec = Vec::new(); - let mut weights: Vec = Vec::new(); - - for (i, outer_pred) in predictions.iter().enumerate() { - values.push(outer_pred[j].prediction()); - weights.push(w[i]); - } - - let median_val = weighted_median(&values, &weights); - pop_median.push(median_val); - } - - // Calculate posterior mean - let mut posterior_mean: Vec = vec![0.0; predictions.first().unwrap().len()]; - for outer_pred in predictions.iter().enumerate() { - let (i, outer_pred) = outer_pred; - for inner_pred in outer_pred.iter().enumerate() { - let (j, pred) = inner_pred; - posterior_mean[j] += pred.prediction() * posterior.matrix()[(subject_index, i)]; - } - } - - // Calculate posterior median - let mut posterior_median: Vec = Vec::new(); - for j in 0..predictions.first().unwrap().len() { - let mut values: Vec = Vec::new(); - let mut weights: Vec = Vec::new(); - - for (i, outer_pred) in predictions.iter().enumerate() { - values.push(outer_pred[j].prediction()); - weights.push(posterior.matrix()[(subject_index, i)]); - } - - let median_val = weighted_median(&values, &weights); - posterior_median.push(median_val); - } - - // Structure for the output - #[derive(Debug, Clone, Serialize)] - struct Row { - id: String, - time: f64, - outeq: usize, - block: usize, - pop_mean: f64, - pop_median: f64, - post_mean: f64, - post_median: f64, - } - - for pred in predictions.iter().enumerate() { - let (_, preds) = pred; - for (j, p) in preds.iter().enumerate() { - let row = Row { - id: subject.id().clone(), - time: p.time(), - outeq: p.outeq(), - block: occasions[j], - pop_mean: pop_mean[j], - pop_median: pop_median[j], - post_mean: posterior_mean[j], - post_median: posterior_median[j], - }; - writer.serialize(row)?; - } - } + // Write each prediction row + for row in predictions.predictions() { + writer.serialize(row)?; } writer.flush()?; diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs index 8a51b7047..a52c0f77f 100644 --- a/src/routines/output/predictions.rs +++ b/src/routines/output/predictions.rs @@ -1,7 +1,7 @@ use anyhow::{bail, Result}; use faer::Col; use pharmsol::{prelude::simulator::Prediction, Data, Event, Predictions as PredTrait}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use crate::{ routines::output::{posterior::Posterior, weighted_median}, @@ -9,7 +9,7 @@ use crate::{ }; // Structure for the output -#[derive(Debug, Clone, Serialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct NPPredictionRow { id: String, time: f64, From 10745c7696e6e36dbbe1bfed2a6b5df04366143a Mon Sep 17 00:00:00 2001 From: Markus Date: Mon, 18 Aug 2025 18:00:07 +0200 Subject: [PATCH 10/15] Update predictions.rs --- src/routines/output/predictions.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs index a52c0f77f..bbb89ea4e 100644 --- a/src/routines/output/predictions.rs +++ b/src/routines/output/predictions.rs @@ -88,6 +88,18 @@ impl NPPredictions { &self.predictions } + /// Calculate the populatuion and posterior predictions + /// + /// # Arguments + /// * `equation` - The equation to use for simulation + /// * `data` - The data to use for simulation + /// * `theta` - The theta values for the simulation + /// * `w` - The weights for the simulation + /// * `posterior` - The posterior values for the simulation + /// * `idelta` - The delta for the simulation + /// * `tad` - The time after dose for the simulation + /// # Returns + /// A Result containing the NPPredictions or an error pub fn calculate( equation: &impl pharmsol::prelude::simulator::Equation, data: &Data, From c957cb41d7516247ee4ba375f922c7c72062d9e7 Mon Sep 17 00:00:00 2001 From: Markus Date: Mon, 18 Aug 2025 18:14:07 +0200 Subject: [PATCH 11/15] Weights are now a dedicated structure --- src/structs/mod.rs | 1 + src/structs/weights.rs | 97 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 src/structs/weights.rs diff --git a/src/structs/mod.rs b/src/structs/mod.rs index 62886f8f0..eb24bf1e4 100644 --- a/src/structs/mod.rs +++ b/src/structs/mod.rs @@ -1,2 +1,3 @@ pub mod psi; pub mod theta; +pub mod weights; diff --git a/src/structs/weights.rs b/src/structs/weights.rs new file mode 100644 index 000000000..21d4742ab --- /dev/null +++ b/src/structs/weights.rs @@ -0,0 +1,97 @@ +use faer::Col; +use serde::{Deserialize, Serialize}; +use std::ops::{Index, IndexMut}; + +/// The weight (probabilities) for each support point in the model. +/// +/// This struct is used to hold the weights for each support point in the model. +#[derive(Debug, Clone)] +pub struct Weights { + weights: Col, +} + +impl Default for Weights { + fn default() -> Self { + Self { + weights: Col::from_fn(0, |_| 0.0), + } + } +} + +impl Weights { + pub fn new(weights: Col) -> Self { + Self { weights } + } + + /// Create a new [Weights] instance from a vector of weights. + pub fn from_vec(weights: Vec) -> Self { + Self { + weights: Col::from_fn(weights.len(), |i| weights[i]), + } + } + + /// Get a reference to the weights. + pub fn weights(&self) -> &Col { + &self.weights + } + + /// Get a mutable reference to the weights. + pub fn weights_mut(&mut self) -> &mut Col { + &mut self.weights + } + + /// Get the number of weights. + pub fn len(&self) -> usize { + self.weights.nrows() + } + + /// Get a vector representation of the weights. + pub fn to_vec(&self) -> Vec { + self.weights.iter().cloned().collect() + } +} + +impl Serialize for Weights { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + self.to_vec().serialize(serializer) + } +} + +impl<'de> Deserialize<'de> for Weights { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let weights_vec = Vec::::deserialize(deserializer)?; + Ok(Self::from_vec(weights_vec)) + } +} + +impl From> for Weights { + fn from(weights: Vec) -> Self { + Self::from_vec(weights) + } +} + +impl From> for Weights { + fn from(weights: Col) -> Self { + Self { weights } + } +} + +impl Index for Weights { + type Output = f64; + + fn index(&self, index: usize) -> &Self::Output { + &self.weights[index] + } +} + +impl IndexMut for Weights { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.weights[index] + } +} From 7f59af451959d7b681e70ffe3d174f089edd04fc Mon Sep 17 00:00:00 2001 From: Markus Date: Mon, 18 Aug 2025 18:29:23 +0200 Subject: [PATCH 12/15] Use weights instead of Col where applicable --- src/algorithms/npag.rs | 23 +++++++++++------------ src/algorithms/npod.rs | 16 ++++++++-------- src/algorithms/postprob.rs | 7 ++++--- src/routines/evaluation/ipm.rs | 15 ++++++++++----- src/routines/output/mod.rs | 26 ++++++++++++++++++++------ src/routines/output/posterior.rs | 12 ++++++------ src/routines/output/predictions.rs | 5 ++--- src/structs/weights.rs | 5 +++++ 8 files changed, 66 insertions(+), 43 deletions(-) diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 9e889b713..2cf7f37cd 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -8,6 +8,7 @@ use crate::routines::settings::Settings; use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; use crate::structs::psi::{calculate_psi, Psi}; use crate::structs::theta::Theta; +use crate::structs::weights::Weights; use anyhow::bail; use anyhow::Result; @@ -18,8 +19,6 @@ use pharmsol::prelude::{ use pharmsol::prelude::ErrorModel; -use faer::Col; - use crate::routines::initialization; use crate::routines::expansion::adaptative_grid::adaptative_grid; @@ -35,8 +34,8 @@ pub struct NPAG { ranges: Vec<(f64, f64)>, psi: Psi, theta: Theta, - lambda: Col, - w: Col, + lambda: Weights, + w: Weights, eps: f64, last_objf: f64, objf: f64, @@ -59,8 +58,8 @@ impl Algorithms for NPAG { ranges: settings.parameters().ranges(), psi: Psi::new(), theta: Theta::new(), - lambda: Col::zeros(0), - w: Col::zeros(0), + lambda: Weights::default(), + w: Weights::default(), eps: 0.2, last_objf: -1e30, objf: f64::NEG_INFINITY, @@ -138,7 +137,7 @@ impl Algorithms for NPAG { if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E { self.eps /= 2.; if self.eps <= THETA_E { - let pyl = psi * w; + let pyl = psi * w.weights(); self.f1 = pyl.iter().map(|x| x.ln()).sum(); if (self.f1 - self.f0).abs() <= THETA_F { tracing::info!("The model converged after {} cycles", self.cycle,); @@ -199,7 +198,7 @@ impl Algorithms for NPAG { } (self.lambda, _) = match burke(&self.psi) { - Ok((lambda, objf)) => (lambda, objf), + Ok((lambda, objf)) => (lambda.into(), objf), Err(err) => { bail!("Error in IPM during evaluation: {:?}", err); } @@ -213,11 +212,11 @@ impl Algorithms for NPAG { let max_lambda = self .lambda .iter() - .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc)); + .fold(f64::NEG_INFINITY, |acc, x| x.max(acc)); let mut keep = Vec::::new(); for (index, lam) in self.lambda.iter().enumerate() { - if *lam > max_lambda / 1000_f64 { + if lam > max_lambda / 1000_f64 { keep.push(index); } } @@ -262,7 +261,7 @@ impl Algorithms for NPAG { self.validate_psi()?; (self.lambda, self.objf) = match burke(&self.psi) { - Ok((lambda, objf)) => (lambda, objf), + Ok((lambda, objf)) => (lambda.into(), objf), Err(err) => { return Err(anyhow::anyhow!( "Error in IPM during condensation: {:?}", @@ -270,7 +269,7 @@ impl Algorithms for NPAG { )); } }; - self.w = self.lambda.clone(); + self.w = self.lambda.clone().into(); Ok(()) } diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index bb8361816..d31b02e46 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -1,4 +1,5 @@ use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; +use crate::structs::weights::Weights; use crate::{ algorithms::Status, prelude::{ @@ -16,7 +17,6 @@ use crate::{ use anyhow::bail; use anyhow::Result; -use faer::Col; use faer_ext::IntoNdarray; use pharmsol::{prelude::ErrorModel, ErrorModels}; use pharmsol::{ @@ -38,8 +38,8 @@ pub struct NPOD { equation: E, psi: Psi, theta: Theta, - lambda: Col, - w: Col, + lambda: Weights, + w: Weights, last_objf: f64, objf: f64, cycle: usize, @@ -58,8 +58,8 @@ impl Algorithms for NPOD { equation, psi: Psi::new(), theta: Theta::new(), - lambda: Col::zeros(0), - w: Col::zeros(0), + lambda: Weights::default(), + w: Weights::default(), last_objf: -1e30, objf: f64::NEG_INFINITY, cycle: 0, @@ -206,11 +206,11 @@ impl Algorithms for NPOD { let max_lambda = self .lambda .iter() - .fold(f64::NEG_INFINITY, |acc, &x| x.max(acc)); + .fold(f64::NEG_INFINITY, |acc, x| x.max(acc)); let mut keep = Vec::::new(); for (index, lam) in self.lambda.iter().enumerate() { - if *lam > max_lambda / 1000_f64 { + if lam > max_lambda / 1000_f64 { keep.push(index); } } @@ -368,7 +368,7 @@ impl Algorithms for NPOD { fn expansion(&mut self) -> Result<()> { // If no stop signal, add new point to theta based on the optimization of the D function let psi = self.psi().matrix().as_ref().into_ndarray().to_owned(); - let w: Array1 = self.w.clone().iter().cloned().collect(); + let w: Array1 = self.w.clone().iter().collect(); let pyl = psi.dot(&w); // Add new point to theta based on the optimization of the D function diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs index 234c558f0..0fcb7a1b0 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/postprob.rs @@ -4,10 +4,11 @@ use crate::{ structs::{ psi::{calculate_psi, Psi}, theta::Theta, + weights::Weights, }, }; use anyhow::{Context, Result}; -use faer::Col; + use pharmsol::prelude::{ data::{Data, ErrorModels}, simulator::Equation, @@ -24,7 +25,7 @@ pub struct POSTPROB { equation: E, psi: Psi, theta: Theta, - w: Col, + w: Weights, objf: f64, cycle: usize, status: Status, @@ -40,7 +41,7 @@ impl Algorithms for POSTPROB { equation, psi: Psi::new(), theta: Theta::new(), - w: Col::zeros(0), + w: Weights::default(), objf: f64::INFINITY, cycle: 0, status: Status::Starting, diff --git a/src/routines/evaluation/ipm.rs b/src/routines/evaluation/ipm.rs index d56910a49..cd9b67753 100644 --- a/src/routines/evaluation/ipm.rs +++ b/src/routines/evaluation/ipm.rs @@ -1,4 +1,5 @@ use crate::structs::psi::Psi; +use crate::structs::weights::Weights; use anyhow::bail; use faer::linalg::triangular_solve::solve_lower_triangular_in_place; use faer::linalg::triangular_solve::solve_upper_triangular_in_place; @@ -21,15 +22,15 @@ use rayon::prelude::*; /// /// # Returns /// -/// On success, returns a tuple `(lam, obj)` where: -/// - `lam` is a faer::Col containing the computed probability vector, +/// On success, returns a tuple `(weights, obj)` where: +/// - [Weights] contains the optimized weights (probabilities) for each support point. /// - `obj` is the value of the objective function at the solution. /// /// # Errors /// /// This function returns an error if any step in the optimization (e.g. Cholesky factorization) /// fails. -pub fn burke(psi: &Psi) -> anyhow::Result<(Col, f64)> { +pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> { let mut psi = psi.matrix().to_owned(); // Ensure all entries are finite and make them non-negative. @@ -274,7 +275,7 @@ pub fn burke(psi: &Psi) -> anyhow::Result<(Col, f64)> { let lam_sum: f64 = lam.iter().sum(); lam = &lam / lam_sum; - Ok((lam, obj)) + Ok((lam.into(), obj)) } #[cfg(test)] @@ -465,7 +466,11 @@ mod tests { // distribution depends on the optimization algorithm's convergence // Just verify that no single weight dominates excessively (basic sanity check) - let max_weight = lam.iter().cloned().fold(f64::NEG_INFINITY, f64::max); + let max_weight = lam + .weights() + .iter() + .cloned() + .fold(f64::NEG_INFINITY, f64::max); assert!( max_weight < 0.1, "No single weight should dominate in uniform matrix (max weight: {})", diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 4ddac556d..63d6d3bdb 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -5,10 +5,10 @@ use crate::routines::output::predictions::NPPredictions; use crate::routines::settings::Settings; use crate::structs::psi::Psi; use crate::structs::theta::Theta; +use crate::structs::weights::Weights; use anyhow::{bail, Context, Result}; use csv::WriterBuilder; use faer::linalg::zip::IntoView; -use faer::Col; use faer_ext::IntoNdarray; use ndarray::{Array, Array1, Array2, Axis}; use pharmsol::prelude::data::*; @@ -31,7 +31,7 @@ pub struct NPResult { data: Data, theta: Theta, psi: Psi, - w: Col, + w: Weights, objf: f64, cycles: usize, status: Status, @@ -48,7 +48,7 @@ impl NPResult { data: Data, theta: Theta, psi: Psi, - w: Col, + w: Weights, objf: f64, cycles: usize, status: Status, @@ -96,7 +96,7 @@ impl NPResult { } /// Get the weights (probabilities) of the support points - pub fn w(&self) -> &Col { + pub fn weights(&self) -> &Weights { &self.w } @@ -144,7 +144,14 @@ impl NPResult { .as_mut() .into_ndarray() .to_owned(); - let w: Array1 = self.w.clone().into_view().iter().cloned().collect(); + let w: Array1 = self + .w + .weights() + .clone() + .into_view() + .iter() + .cloned() + .collect(); let psi: Array2 = self.psi.matrix().as_ref().into_ndarray().to_owned(); let (post_mean, post_median) = posterior_mean_median(&theta, &psi, &w) @@ -258,7 +265,14 @@ impl NPResult { tracing::debug!("Writing population parameter distribution..."); let theta = &self.theta; - let w: Vec = self.w.clone().into_view().iter().cloned().collect(); + let w: Vec = self + .w + .weights() + .clone() + .into_view() + .iter() + .cloned() + .collect(); if w.len() != theta.matrix().nrows() { bail!( diff --git a/src/routines/output/posterior.rs b/src/routines/output/posterior.rs index bc8a19e2b..5d49cc3b8 100644 --- a/src/routines/output/posterior.rs +++ b/src/routines/output/posterior.rs @@ -2,7 +2,7 @@ pub use anyhow::{bail, Result}; use faer::{Col, Mat}; use serde::{Deserialize, Serialize}; -use crate::structs::psi::Psi; +use crate::structs::{psi::Psi, weights::Weights}; /// Posterior probabilities for each support points #[derive(Debug, Clone)] @@ -184,20 +184,20 @@ impl<'de> Deserialize<'de> for Posterior { /// Calculates the posterior probabilities for each support point given the weights /// /// The shape is the same as [Psi], and thus subjects are the rows and support points are the columns. -pub fn posterior(psi: &Psi, w: &Col) -> Result { - if psi.matrix().ncols() != w.nrows() { +pub fn posterior(psi: &Psi, w: &Weights) -> Result { + if psi.matrix().ncols() != w.len() { bail!( "Number of rows in psi ({}) and number of weights ({}) do not match.", psi.matrix().nrows(), - w.nrows() + w.len() ); } let psi_matrix = psi.matrix(); - let py = psi_matrix * w; + let py = psi_matrix * w.weights(); let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| { - psi_matrix.get(i, j) * w.get(j) / py.get(i) + psi_matrix.get(i, j) * w.weights().get(j) / py.get(i) }); Ok(posterior.into()) diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs index bbb89ea4e..411104ca9 100644 --- a/src/routines/output/predictions.rs +++ b/src/routines/output/predictions.rs @@ -1,11 +1,10 @@ use anyhow::{bail, Result}; -use faer::Col; use pharmsol::{prelude::simulator::Prediction, Data, Event, Predictions as PredTrait}; use serde::{Deserialize, Serialize}; use crate::{ routines::output::{posterior::Posterior, weighted_median}, - structs::theta::Theta, + structs::{theta::Theta, weights::Weights}, }; // Structure for the output @@ -104,7 +103,7 @@ impl NPPredictions { equation: &impl pharmsol::prelude::simulator::Equation, data: &Data, theta: Theta, - w: &Col, + w: &Weights, posterior: &Posterior, idelta: f64, tad: f64, diff --git a/src/structs/weights.rs b/src/structs/weights.rs index 21d4742ab..ce3c81082 100644 --- a/src/structs/weights.rs +++ b/src/structs/weights.rs @@ -5,6 +5,7 @@ use std::ops::{Index, IndexMut}; /// The weight (probabilities) for each support point in the model. /// /// This struct is used to hold the weights for each support point in the model. +/// It is a thin wrapper around [faer::Col] to provide additional functionality and context #[derive(Debug, Clone)] pub struct Weights { weights: Col, @@ -49,6 +50,10 @@ impl Weights { pub fn to_vec(&self) -> Vec { self.weights.iter().cloned().collect() } + + pub fn iter(&self) -> impl Iterator + '_ { + self.weights.iter().cloned() + } } impl Serialize for Weights { From 114625d39db4492e32604649a6b3870e173a9344 Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 24 Aug 2025 12:21:43 +0200 Subject: [PATCH 13/15] Improve calculation of op and pred.csv Pred.csv now contains observations (where available), and op.csv is the filtered version of pred.csv that only includes predictions where there exists an associated observation --- src/routines/output/mod.rs | 82 ++++++++++++------------------ src/routines/output/predictions.rs | 4 +- 2 files changed, 34 insertions(+), 52 deletions(-) diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 63d6d3bdb..9ef769653 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -107,13 +107,10 @@ impl NPResult { let idelta: f64 = self.settings.predictions().idelta; let tad = self.settings.predictions().tad; self.cyclelog.write(&self.settings)?; - self.write_obs().context("Failed to write observations")?; self.write_theta().context("Failed to write theta")?; - self.write_obspred() - .context("Failed to write observed-predicted file")?; - self.write_pred(idelta, tad) - .context("Failed to write predictions")?; self.write_covs().context("Failed to write covariates")?; + self.write_predictions(idelta, tad) + .context("Failed to write predictions")?; self.write_posterior() .context("Failed to write posterior")?; } @@ -373,52 +370,13 @@ impl NPResult { Ok(()) } - /// Write the observations, which is the reformatted input data - pub fn write_obs(&self) -> Result<()> { - tracing::debug!("Writing observations..."); - let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?; - - let mut writer = WriterBuilder::new() - .has_headers(true) - .from_writer(&outputfile.file); - - #[derive(Serialize)] - struct Row { - id: String, - block: usize, - time: f64, - out: Option, - outeq: usize, - } - - for subject in self.data.subjects() { - for occasion in subject.occasions() { - for event in occasion.iter() { - if let Event::Observation(event) = event { - let row = Row { - id: subject.id().clone(), - block: occasion.index(), - time: event.time(), - out: event.value(), - outeq: event.outeq(), - }; - writer.serialize(row)?; - } - } - } - } - writer.flush()?; - - tracing::debug!("Observations written to {:?}", &outputfile.relative_path()); - Ok(()) - } - /// Writes the predictions - pub fn write_pred(&self, idelta: f64, tad: f64) -> Result<()> { + pub fn write_predictions(&self, idelta: f64, tad: f64) -> Result<()> { tracing::debug!("Writing predictions..."); let posterior = posterior(&self.psi, &self.w)?; + // Calculate the predictions let predictions = NPPredictions::calculate( &self.equation, &self.data, @@ -429,11 +387,11 @@ impl NPResult { tad, )?; - // Create the output file and writer for pred.csv - let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?; + // Write (full) predictions to pred.csv + let outputfile_pred = OutputFile::new(&self.settings.output().path, "pred.csv")?; let mut writer = WriterBuilder::new() .has_headers(true) - .from_writer(&outputfile.file); + .from_writer(&outputfile_pred.file); // Write each prediction row for row in predictions.predictions() { @@ -441,7 +399,31 @@ impl NPResult { } writer.flush()?; - tracing::debug!("Predictions written to {:?}", &outputfile.relative_path()); + tracing::debug!( + "Predictions written to {:?}", + &outputfile_pred.relative_path() + ); + + // Write observations and predictions to op.csv + let outputfile_op = OutputFile::new(&self.settings.output().path, "op.csv")?; + let mut writer = WriterBuilder::new() + .has_headers(true) + .from_writer(&outputfile_op.file); + + // Write each prediction row + for row in predictions + .predictions() + .iter() + .filter(|r| r.obs().is_some()) + { + writer.serialize(row)?; + } + + writer.flush()?; + tracing::debug!( + "Observed-predicted values written to {:?}", + &outputfile_op.relative_path() + ); Ok(()) } diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs index 411104ca9..7f0733045 100644 --- a/src/routines/output/predictions.rs +++ b/src/routines/output/predictions.rs @@ -87,14 +87,14 @@ impl NPPredictions { &self.predictions } - /// Calculate the populatuion and posterior predictions + /// Calculate the population and posterior predictions /// /// # Arguments /// * `equation` - The equation to use for simulation /// * `data` - The data to use for simulation /// * `theta` - The theta values for the simulation /// * `w` - The weights for the simulation - /// * `posterior` - The posterior values for the simulation + /// * `posterior` - The posterior probabilities for the simulation /// * `idelta` - The delta for the simulation /// * `tad` - The time after dose for the simulation /// # Returns From 1394b89cd490e86229c3e53e95b7f39c967367f4 Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 7 Sep 2025 14:01:34 +0200 Subject: [PATCH 14/15] Documentation Also bumps pharmsol --- Cargo.toml | 2 +- src/routines/output/predictions.rs | 32 +++++++++++------------------- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 98914a3f6..f35b9dd03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [ ] } faer = "0.22.4" faer-ext = { version = "0.6.0", features = ["nalgebra", "ndarray"] } -pharmsol = "=0.17.0" +pharmsol = "=0.17.1" rand = "0.9.0" anyhow = "1.0.97" rayon = "1.10.0" diff --git a/src/routines/output/predictions.rs b/src/routines/output/predictions.rs index 7f0733045..d6a961f2d 100644 --- a/src/routines/output/predictions.rs +++ b/src/routines/output/predictions.rs @@ -1,5 +1,5 @@ use anyhow::{bail, Result}; -use pharmsol::{prelude::simulator::Prediction, Data, Event, Predictions as PredTrait}; +use pharmsol::{prelude::simulator::Prediction, Data, Predictions as PredTrait}; use serde::{Deserialize, Serialize}; use crate::{ @@ -10,14 +10,23 @@ use crate::{ // Structure for the output #[derive(Debug, Clone, Serialize, Deserialize)] pub struct NPPredictionRow { + /// The subject ID id: String, + /// The time of the prediction time: f64, + /// The output equation number outeq: usize, + /// The occasion of the prediction block: usize, + /// The observed value, if any obs: Option, + /// The population mean prediction pop_mean: f64, + /// The population median prediction pop_median: f64, + /// The posterior mean prediction post_mean: f64, + /// The posterior median prediction post_median: f64, } @@ -51,6 +60,7 @@ impl NPPredictionRow { } } +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct NPPredictions { predictions: Vec, } @@ -123,24 +133,6 @@ impl NPPredictions { for subject in subjects.iter().enumerate() { let (subject_index, subject) = subject; - // Get a vector of occasions for this subject, for each predictions - let occasions = subject - .occasions() - .iter() - .flat_map(|o| { - o.events() - .iter() - .filter_map(|e| { - if let Event::Observation(_obs) = e { - Some(o.index()) - } else { - None - } - }) - .collect::>() - }) - .collect::>(); - // Container for predictions for this subject // This will hold predictions for each support point // The outer vector is for each support point @@ -219,7 +211,7 @@ impl NPPredictions { id: subject.id().clone(), time: p.time(), outeq: p.outeq(), - block: occasions[j], + block: p.occasion(), obs: p.observation(), pop_mean: pop_mean[j], pop_median: pop_median[j], From 77eb47215857e409d9fe9754e574f5e394e619ae Mon Sep 17 00:00:00 2001 From: Markus Hovd Date: Fri, 12 Sep 2025 10:04:13 +0200 Subject: [PATCH 15/15] Update src/routines/output/mod.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/routines/output/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index 9ef769653..d42b63a84 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -512,7 +512,7 @@ pub(crate) fn median(data: &[f64]) -> f64 { } } -fn weighted_median(data: &[f64], weights: &Vec) -> f64 { +fn weighted_median(data: &[f64], weights: &[f64]) -> f64 { // Ensure the data and weights arrays have the same length assert_eq!( data.len(),