diff --git a/examples/iov/main.rs b/examples/iov/main.rs index 292572f32..75ca02255 100644 --- a/examples/iov/main.rs +++ b/examples/iov/main.rs @@ -52,8 +52,7 @@ fn main() -> Result<()> { let data = data::read_pmetrics("examples/iov/test.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap(); algorithm.initialize().unwrap(); - while !algorithm.next_cycle().unwrap() {} - let result = algorithm.into_npresult(); + let result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); Ok(()) diff --git a/examples/meta/main.rs b/examples/meta/main.rs index e30d54292..b809c9d24 100644 --- a/examples/meta/main.rs +++ b/examples/meta/main.rs @@ -64,7 +64,6 @@ fn main() { let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); // let result = algorithm.fit().unwrap(); algorithm.initialize().unwrap(); - while !algorithm.next_cycle().unwrap() {} - let result = algorithm.into_npresult(); + let result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/examples/new_iov/main.rs b/examples/new_iov/main.rs index f61be08cc..d6b309a0f 100644 --- a/examples/new_iov/main.rs +++ b/examples/new_iov/main.rs @@ -55,7 +55,6 @@ fn main() { let data = data::read_pmetrics("examples/new_iov/data.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap(); algorithm.initialize().unwrap(); - while !algorithm.next_cycle().unwrap() {} - let result = algorithm.into_npresult(); + let result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/examples/theophylline/main.rs b/examples/theophylline/main.rs index 455452d65..9e9e0db6f 100644 --- a/examples/theophylline/main.rs +++ b/examples/theophylline/main.rs @@ -53,7 +53,6 @@ fn main() { let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); // let result = algorithm.fit().unwrap(); algorithm.initialize().unwrap(); - while !algorithm.next_cycle().unwrap() {} - let result = algorithm.into_npresult(); + let result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/examples/vanco_sde/main.rs b/examples/vanco_sde/main.rs index 37981897b..d1ee4a040 100644 --- a/examples/vanco_sde/main.rs +++ b/examples/vanco_sde/main.rs @@ -78,7 +78,6 @@ fn main() { let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap(); algorithm.initialize().unwrap(); - while !algorithm.next_cycle().unwrap() {} - let result = algorithm.into_npresult(); + let result = algorithm.fit().unwrap(); result.write_outputs().unwrap(); } diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index effa1397a..7d8d48fbd 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -75,7 +75,7 @@ pub trait Algorithms: Sync + Send + 'static { .collect::>(); if !indices.is_empty() { - let subject: Vec<&Subject> = self.get_data().subjects(); + let subject: Vec<&Subject> = self.data().subjects(); let zero_probability_subjects: Vec<&String> = indices.iter().map(|&i| subject[i].id()).collect(); @@ -89,7 +89,7 @@ pub trait Algorithms: Sync + Send + 'static { for index in &indices { tracing::debug!("Subject with zero probability: {}", subject[*index].id()); - let error_model = self.get_settings().errormodels().clone(); + let error_model = self.settings().errormodels().clone(); // Simulate all support points in parallel let spp_results: Vec<_> = self @@ -207,54 +207,103 @@ pub trait Algorithms: Sync + Send + 'static { Ok(()) } - fn get_settings(&self) -> &Settings; + + fn settings(&self) -> &Settings; + /// Get the equation used in the algorithm fn equation(&self) -> &E; - fn get_data(&self) -> &Data; + /// Get the data used in the algorithm + fn data(&self) -> &Data; fn get_prior(&self) -> Theta; - fn inc_cycle(&mut self) -> usize; - fn get_cycle(&self) -> usize; + /// Increment the cycle counter and return the new value + fn increment_cycle(&mut self) -> usize; + /// Get the current cycle number + fn cycle(&self) -> usize; + /// Set the current [Theta] fn set_theta(&mut self, theta: Theta); + /// Get the current [Theta] fn theta(&self) -> Θ + /// Get the current [Psi] fn psi(&self) -> Ψ + /// Get the current likelihood fn likelihood(&self) -> f64; + /// Get the current negative two log-likelihood fn n2ll(&self) -> f64 { -2.0 * self.likelihood() } + /// Get the current [Status] of the algorithm fn status(&self) -> &Status; + /// Set the current [Status] of the algorithm fn set_status(&mut self, status: Status); - fn convergence_evaluation(&mut self); - fn converged(&self) -> bool; + /// Evaluate convergence criteria and update status + fn evaluation(&mut self) -> Result; + + /// Create and log a cycle state with the current algorithm state + fn log_cycle_state(&mut self); + + /// Initialize the algorithm, setting up initial [Theta] and [Status] fn initialize(&mut self) -> Result<()> { // If a stop file exists in the current directory, remove it if Path::new("stop").exists() { tracing::info!("Removing existing stop file prior to run"); fs::remove_file("stop").context("Unable to remove previous stop file")?; } - self.set_status(Status::InProgress); + self.set_status(Status::Continue); self.set_theta(self.get_prior()); Ok(()) } - fn evaluation(&mut self) -> Result<()>; + fn estimation(&mut self) -> Result<()>; + /// Performs condensation of [Theta] and updates [Psi] + /// + /// This step reduces the number of support points in [Theta] based on the current weights, + /// and updates the [Psi] matrix accordingly to reflect the new set of support points. + /// It is typically performed after the estimation step in each cycle of the algorithm. fn condensation(&mut self) -> Result<()>; + + /// Performs optimizations on the current [ErrorModels] and updates [Psi] accordingly + /// + /// This step refines the error model parameters to better fit the data, + /// and subsequently updates the [Psi] matrix to reflect these changes. fn optimizations(&mut self) -> Result<()>; - fn logs(&self); + + /// Performs expansion of [Theta] + /// + /// This step increases the number of support points in [Theta] based on the current distribution, + /// allowing for exploration of the parameter space. fn expansion(&mut self) -> Result<()>; - fn next_cycle(&mut self) -> Result { - if self.inc_cycle() > 1 { + + /// Proceed to the next cycle of the algorithm + /// + /// This method increments the cycle counter, performs expansion if necessary, + /// and then runs the estimation, condensation, optimization, logging, and evaluation steps + /// in sequence. It returns the current [Status] of the algorithm after completing these steps. + fn next_cycle(&mut self) -> Result { + let cycle = self.increment_cycle(); + + if cycle > 1 { self.expansion()?; } - let span = tracing::info_span!("", "{}", format!("Cycle {}", self.get_cycle())); + + let span = tracing::info_span!("", "{}", format!("Cycle {}", self.cycle())); let _enter = span.enter(); - self.evaluation()?; + self.estimation()?; self.condensation()?; self.optimizations()?; - self.logs(); - self.convergence_evaluation(); - Ok(self.converged()) + self.evaluation() } + + /// Fit the model until convergence or stopping criteria are met + /// + /// This method runs the full fitting process, starting with initialization, + /// followed by iterative cycles of estimation, condensation, optimization, and evaluation + /// until the algorithm converges or meets a stopping criteria. fn fit(&mut self) -> Result> { self.initialize().unwrap(); - while !self.next_cycle()? {} + loop { + match self.next_cycle()? { + Status::Continue => continue, + Status::Stop(_) => break, + } + } Ok(self.into_npresult()) } @@ -274,32 +323,27 @@ pub fn dispatch_algorithm( } } -/// Represents the status of the algorithm +/// Represents the status/result of the algorithm #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum Status { - /// Algorithm is starting up - Starting, - /// Algorithm has converged to a solution - Converged, - /// Algorithm stopped due to reaching maximum cycles - MaxCycles, - /// Algorithm is currently running - InProgress, - /// Algorithm was manually stopped by user - ManualStop, - /// Other status with custom message - Other(String), + Continue, + Stop(StopReason), } impl std::fmt::Display for Status { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Status::Starting => write!(f, "Starting"), - Status::Converged => write!(f, "Converged"), - Status::MaxCycles => write!(f, "Maximum cycles reached"), - Status::InProgress => write!(f, "In progress"), - Status::ManualStop => write!(f, "Manual stop requested"), - Status::Other(msg) => write!(f, "{}", msg), + Status::Continue => write!(f, "Continue"), + Status::Stop(s) => write!(f, "Stop: {:?}", s), } } } + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] + +pub enum StopReason { + Converged, + MaxCycles, + Stopped, + Completed, +} diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index ef3979a2f..48afa8e9b 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -1,8 +1,8 @@ -use crate::algorithms::Status; +use crate::algorithms::{Status, StopReason}; use crate::prelude::algorithms::Algorithms; -pub use crate::routines::evaluation::ipm::burke; -pub use crate::routines::evaluation::qr; +pub use crate::routines::estimation::ipm::burke; +pub use crate::routines::estimation::qr; use crate::routines::settings::Settings; use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; @@ -44,7 +44,6 @@ pub struct NPAG { cycle: usize, gamma_delta: Vec, error_models: ErrorModels, - converged: bool, status: Status, cycle_log: CycleLog, data: Data, @@ -68,8 +67,7 @@ impl Algorithms for NPAG { cycle: 0, gamma_delta: vec![0.1; settings.errormodels().len()], error_models: settings.errormodels().clone(), - converged: false, - status: Status::Starting, + status: Status::Continue, cycle_log: CycleLog::new(), settings, data, @@ -94,11 +92,11 @@ impl Algorithms for NPAG { ) } - fn get_settings(&self) -> &Settings { + fn settings(&self) -> &Settings { &self.settings } - fn get_data(&self) -> &Data { + fn data(&self) -> &Data { &self.data } @@ -110,12 +108,12 @@ impl Algorithms for NPAG { self.objf } - fn inc_cycle(&mut self) -> usize { + fn increment_cycle(&mut self) -> usize { self.cycle += 1; self.cycle } - fn get_cycle(&self) -> usize { + fn cycle(&self) -> usize { self.cycle } @@ -131,7 +129,32 @@ impl Algorithms for NPAG { &self.psi } - fn convergence_evaluation(&mut self) { + fn evaluation(&mut self) -> Result { + tracing::info!("Objective function = {:.4}", -2.0 * self.objf); + tracing::debug!("Support points: {}", self.theta.nspp()); + + self.error_models.iter().for_each(|(outeq, em)| { + if ErrorModel::None == *em { + return; + } + tracing::debug!( + "Error model for outeq {}: {:.2}", + outeq, + em.factor().unwrap_or_default() + ); + }); + + tracing::debug!("EPS = {:.4}", self.eps); + // Increasing objf signals instability or model misspecification. + if self.last_objf > self.objf + 1e-4 { + tracing::warn!( + "Objective function decreased from {:.4} to {:.4} (delta = {})", + -2.0 * self.last_objf, + -2.0 * self.objf, + -2.0 * self.last_objf - -2.0 * self.objf + ); + } + let psi = self.psi.matrix(); let w = &self.w; if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E { @@ -141,8 +164,9 @@ impl Algorithms for NPAG { self.f1 = pyl.iter().map(|x| x.ln()).sum(); if (self.f1 - self.f0).abs() <= THETA_F { tracing::info!("The model converged after {} cycles", self.cycle,); - self.converged = true; - self.status = Status::Converged; + self.set_status(Status::Stop(StopReason::Converged)); + self.log_cycle_state(); + return Ok(self.status().clone()); } else { self.f0 = self.f1; self.eps = 0.2; @@ -153,37 +177,26 @@ impl Algorithms for NPAG { // Stop if we have reached maximum number of cycles if self.cycle >= self.settings.config().cycles { tracing::warn!("Maximum number of cycles reached"); - self.converged = true; - self.status = Status::MaxCycles; + self.set_status(Status::Stop(StopReason::MaxCycles)); + self.log_cycle_state(); + return Ok(self.status().clone()); } // Stop if stopfile exists if std::path::Path::new("stop").exists() { tracing::warn!("Stopfile detected - breaking"); - self.status = Status::ManualStop; + self.set_status(Status::Stop(StopReason::Stopped)); + self.log_cycle_state(); + return Ok(self.status().clone()); } - // Create state object - let state = NPCycle::new( - self.cycle, - -2. * self.objf, - self.error_models.clone(), - self.theta.clone(), - self.theta.nspp(), - (self.last_objf - self.objf).abs(), - self.status.clone(), - ); - - // Write cycle log - self.cycle_log.push(state); - self.last_objf = self.objf; - } - - fn converged(&self) -> bool { - self.converged + // Continue with normal operation + self.set_status(Status::Continue); + self.log_cycle_state(); + Ok(self.status().clone()) } - fn evaluation(&mut self) -> Result<()> { + fn estimation(&mut self) -> Result<()> { self.psi = calculate_psi( &self.equation, &self.data, @@ -200,7 +213,7 @@ impl Algorithms for NPAG { (self.lambda, _) = match burke(&self.psi) { Ok((lambda, objf)) => (lambda.into(), objf), Err(err) => { - bail!("Error in IPM during evaluation: {:?}", err); + bail!("Error in IPM during estimation: {:?}", err); } }; Ok(()) @@ -353,33 +366,6 @@ impl Algorithms for NPAG { Ok(()) } - fn logs(&self) { - tracing::info!("Objective function = {:.4}", -2.0 * self.objf); - tracing::debug!("Support points: {}", self.theta.nspp()); - - self.error_models.iter().for_each(|(outeq, em)| { - if ErrorModel::None == *em { - return; - } - tracing::debug!( - "Error model for outeq {}: {:.16}", - outeq, - em.factor().unwrap_or_default() - ); - }); - - tracing::debug!("EPS = {:.4}", self.eps); - // Increasing objf signals instability or model misspecification. - if self.last_objf > self.objf + 1e-4 { - tracing::warn!( - "Objective function decreased from {:.4} to {:.4} (delta = {})", - -2.0 * self.last_objf, - -2.0 * self.objf, - -2.0 * self.last_objf - -2.0 * self.objf - ); - } - } - fn expansion(&mut self) -> Result<()> { adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D)?; Ok(()) @@ -392,4 +378,18 @@ impl Algorithms for NPAG { fn status(&self) -> &Status { &self.status } + + fn log_cycle_state(&mut self) { + let state = NPCycle::new( + self.cycle, + -2. * self.objf, + self.error_models.clone(), + self.theta.clone(), + self.theta.nspp(), + (self.last_objf - self.objf).abs(), + self.status.clone(), + ); + self.cycle_log.push(state); + self.last_objf = self.objf; + } } diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index f25563799..9b3d2d74d 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -1,3 +1,4 @@ +use crate::algorithms::StopReason; use crate::routines::initialization::sample_space; use crate::routines::output::{cycles::CycleLog, cycles::NPCycle, NPResult}; use crate::structs::weights::Weights; @@ -6,7 +7,7 @@ use crate::{ prelude::{ algorithms::Algorithms, routines::{ - evaluation::{ipm::burke, qr}, + estimation::{ipm::burke, qr}, settings::Settings, }, }, @@ -66,7 +67,7 @@ impl Algorithms for NPOD { gamma_delta: vec![0.1; settings.errormodels().len()], error_models: settings.errormodels().clone(), converged: false, - status: Status::Starting, + status: Status::Continue, cycle_log: CycleLog::new(), settings, data, @@ -91,11 +92,11 @@ impl Algorithms for NPOD { &self.equation } - fn get_settings(&self) -> &Settings { + fn settings(&self) -> &Settings { &self.settings } - fn get_data(&self) -> &Data { + fn data(&self) -> &Data { &self.data } @@ -103,12 +104,12 @@ impl Algorithms for NPOD { sample_space(&self.settings).unwrap() } - fn inc_cycle(&mut self) -> usize { + fn increment_cycle(&mut self) -> usize { self.cycle += 1; self.cycle } - fn get_cycle(&self) -> usize { + fn cycle(&self) -> usize { self.cycle } @@ -136,48 +137,76 @@ impl Algorithms for NPOD { &self.status } - fn convergence_evaluation(&mut self) { + fn log_cycle_state(&mut self) { + let state = NPCycle::new( + self.cycle, + -2. * self.objf, + self.error_models.clone(), + self.theta.clone(), + self.theta.nspp(), + (self.last_objf - self.objf).abs(), + self.status.clone(), + ); + self.cycle_log.push(state); + self.last_objf = self.objf; + } + + fn evaluation(&mut self) -> Result { + tracing::info!("Objective function = {:.4}", -2.0 * self.objf); + tracing::debug!("Support points: {}", self.theta.nspp()); + self.error_models.iter().for_each(|(outeq, em)| { + if ErrorModel::None == *em { + return; + } + tracing::debug!( + "Error model for outeq {}: {:.16}", + outeq, + em.factor().unwrap_or_default() + ); + }); + // Increasing objf signals instability or model misspecification. + if self.last_objf > self.objf + 1e-4 { + tracing::warn!( + "Objective function decreased from {:.4} to {:.4} (delta = {})", + -2.0 * self.last_objf, + -2.0 * self.objf, + -2.0 * self.last_objf - -2.0 * self.objf + ); + } + if (self.last_objf - self.objf).abs() <= THETA_F { tracing::info!("Objective function convergence reached"); self.converged = true; - self.status = Status::Converged; + self.set_status(Status::Stop(StopReason::Converged)); + self.log_cycle_state(); + return Ok(self.status.clone()); } // Stop if we have reached maximum number of cycles if self.cycle >= self.settings.config().cycles { tracing::warn!("Maximum number of cycles reached"); self.converged = true; - self.status = Status::MaxCycles; + self.set_status(Status::Stop(StopReason::MaxCycles)); + self.log_cycle_state(); + return Ok(self.status.clone()); } // Stop if stopfile exists if std::path::Path::new("stop").exists() { tracing::warn!("Stopfile detected - breaking"); self.converged = true; - self.status = Status::ManualStop; + self.set_status(Status::Stop(StopReason::Stopped)); + self.log_cycle_state(); + return Ok(self.status.clone()); } - // Create state object - let state = NPCycle::new( - self.cycle, - -2. * self.objf, - self.error_models.clone(), - self.theta.clone(), - self.theta.nspp(), - (self.last_objf - self.objf).abs(), - self.status.clone(), - ); - - // Write cycle log - self.cycle_log.push(state); - self.last_objf = self.objf; + // Continue with normal operation + self.status = Status::Continue; + self.log_cycle_state(); + Ok(self.status.clone()) } - fn converged(&self) -> bool { - self.converged - } - - fn evaluation(&mut self) -> Result<()> { + fn estimation(&mut self) -> Result<()> { let error_model: ErrorModels = self.error_models.clone(); self.psi = calculate_psi( @@ -341,30 +370,6 @@ impl Algorithms for NPOD { Ok(()) } - fn logs(&self) { - tracing::info!("Objective function = {:.4}", -2.0 * self.objf); - tracing::debug!("Support points: {}", self.theta.nspp()); - self.error_models.iter().for_each(|(outeq, em)| { - if ErrorModel::None == *em { - return; - } - tracing::debug!( - "Error model for outeq {}: {:.16}", - outeq, - em.factor().unwrap_or_default() - ); - }); - // Increasing objf signals instability or model misspecification. - if self.last_objf > self.objf + 1e-4 { - tracing::warn!( - "Objective function decreased from {:.4} to {:.4} (delta = {})", - -2.0 * self.last_objf, - -2.0 * self.objf, - -2.0 * self.last_objf - -2.0 * self.objf - ); - } - } - fn expansion(&mut self) -> Result<()> { // If no stop signal, add new point to theta based on the optimization of the D function let psi = self.psi().matrix().as_ref().into_ndarray().to_owned(); diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs index 7df1e75c3..3d0325d58 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/postprob.rs @@ -1,5 +1,5 @@ use crate::{ - algorithms::Status, + algorithms::{Status, StopReason}, prelude::algorithms::Algorithms, structs::{ psi::{calculate_psi, Psi}, @@ -14,7 +14,7 @@ use pharmsol::prelude::{ simulator::Equation, }; -use crate::routines::evaluation::ipm::burke; +use crate::routines::estimation::ipm::burke; use crate::routines::initialization; use crate::routines::output::{cycles::CycleLog, NPResult}; use crate::routines::settings::Settings; @@ -44,7 +44,7 @@ impl Algorithms for POSTPROB { w: Weights::default(), objf: f64::INFINITY, cycle: 0, - status: Status::Starting, + status: Status::Continue, error_models: settings.errormodels().clone(), settings, data, @@ -65,7 +65,7 @@ impl Algorithms for POSTPROB { self.cyclelog.clone(), ) } - fn get_settings(&self) -> &Settings { + fn settings(&self) -> &Settings { &self.settings } @@ -73,7 +73,7 @@ impl Algorithms for POSTPROB { &self.equation } - fn get_data(&self) -> &Data { + fn data(&self) -> &Data { &self.data } @@ -85,11 +85,11 @@ impl Algorithms for POSTPROB { self.objf } - fn inc_cycle(&mut self) -> usize { + fn increment_cycle(&mut self) -> usize { 0 } - fn get_cycle(&self) -> usize { + fn cycle(&self) -> usize { 0 } @@ -113,16 +113,12 @@ impl Algorithms for POSTPROB { &self.status } - fn convergence_evaluation(&mut self) { - // POSTPROB algorithm converges after a single evaluation - self.status = Status::MaxCycles; + fn evaluation(&mut self) -> Result { + self.status = Status::Stop(StopReason::Converged); + Ok(self.status.clone()) } - fn converged(&self) -> bool { - true - } - - fn evaluation(&mut self) -> Result<()> { + fn estimation(&mut self) -> Result<()> { self.psi = calculate_psi( &self.equation, &self.data, @@ -142,9 +138,21 @@ impl Algorithms for POSTPROB { Ok(()) } - fn logs(&self) {} - fn expansion(&mut self) -> Result<()> { Ok(()) } + + fn log_cycle_state(&mut self) { + // Postprob doesn't track last_objf, so we use 0.0 as the delta + let state = crate::routines::output::cycles::NPCycle::new( + self.cycle, + self.objf, + self.error_models.clone(), + self.theta.clone(), + self.theta.nspp(), + 0.0, + self.status.clone(), + ); + self.cyclelog.push(state); + } } diff --git a/src/routines/evaluation/ipm.rs b/src/routines/estimation/ipm.rs similarity index 100% rename from src/routines/evaluation/ipm.rs rename to src/routines/estimation/ipm.rs diff --git a/src/routines/evaluation/mod.rs b/src/routines/estimation/mod.rs similarity index 100% rename from src/routines/evaluation/mod.rs rename to src/routines/estimation/mod.rs diff --git a/src/routines/evaluation/qr.rs b/src/routines/estimation/qr.rs similarity index 100% rename from src/routines/evaluation/qr.rs rename to src/routines/estimation/qr.rs diff --git a/src/routines/mod.rs b/src/routines/mod.rs index a11a1d3c4..af25d67e6 100644 --- a/src/routines/mod.rs +++ b/src/routines/mod.rs @@ -1,7 +1,7 @@ // Routines for condensation pub mod condensation; -// Routines for evaluation -pub mod evaluation; +// Routines for estimation +pub mod estimation; // Routines for expansion pub mod expansion; // Routines for initialization diff --git a/src/routines/output/cycles.rs b/src/routines/output/cycles.rs index e3a30a58f..f65720aa9 100644 --- a/src/routines/output/cycles.rs +++ b/src/routines/output/cycles.rs @@ -4,7 +4,7 @@ use pharmsol::{ErrorModel, ErrorModels}; use serde::Serialize; use crate::{ - algorithms::Status, + algorithms::{Status, StopReason}, prelude::Settings, routines::output::{median, OutputFile}, structs::theta::Theta, @@ -81,7 +81,7 @@ impl NPCycle { theta: Theta::new(), nspp: 0, delta_objf: 0.0, - status: Status::Starting, + status: Status::Continue, } } } @@ -146,7 +146,10 @@ impl CycleLog { for cycle in &self.cycles { writer.write_field(format!("{}", cycle.cycle))?; - writer.write_field(format!("{}", cycle.status == Status::Converged))?; + writer.write_field(format!( + "{}", + cycle.status == Status::Stop(StopReason::Converged) + ))?; writer.write_field(format!("{}", cycle.status))?; writer.write_field(format!("{}", cycle.objf))?; writer diff --git a/src/routines/output/mod.rs b/src/routines/output/mod.rs index e57d3986b..2680f550d 100644 --- a/src/routines/output/mod.rs +++ b/src/routines/output/mod.rs @@ -1,4 +1,4 @@ -use crate::algorithms::Status; +use crate::algorithms::{Status, StopReason}; use crate::prelude::*; use crate::routines::output::cycles::CycleLog; use crate::routines::output::predictions::NPPredictions; @@ -84,7 +84,7 @@ impl NPResult { } pub fn converged(&self) -> bool { - self.status == Status::Converged + self.status == Status::Stop(StopReason::Converged) } pub fn get_theta(&self) -> &Theta { diff --git a/tests/cycles_tests.rs b/tests/cycles_tests.rs deleted file mode 100644 index 2d2ebf6b0..000000000 --- a/tests/cycles_tests.rs +++ /dev/null @@ -1,129 +0,0 @@ -use anyhow::Result; -use pmcore::algorithms::Status; -use pmcore::prelude::*; -use pmcore::routines::output::cycles::{CycleLog, NPCycle}; -use pmcore::structs::theta::Theta; - -/// Test NPCycle creation and accessors -#[test] -fn test_npcycle_creation() -> Result<()> { - let em = ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); - let ems = ErrorModels::new().add(0, em)?; - - let theta = Theta::new(); - - let cycle = NPCycle::new( - 1, // cycle - 100.5, // objf - ems.clone(), // error_models - theta.clone(), // theta - 10, // nspp - -5.2, // delta_objf - Status::Converged, // status - ); - - // Test accessors - assert_eq!(cycle.cycle(), 1); - assert_eq!(cycle.objf(), 100.5); - assert_eq!(cycle.nspp(), 10); - assert_eq!(cycle.delta_objf(), -5.2); - - Ok(()) -} - -/// Test NPCycle placeholder -#[test] -fn test_npcycle_placeholder() { - let cycle = NPCycle::placeholder(); - - // Placeholder should have default values - assert_eq!(cycle.cycle(), 0); - assert_eq!(cycle.objf(), 0.0); - assert_eq!(cycle.nspp(), 0); - assert_eq!(cycle.delta_objf(), 0.0); -} - -/// Test CycleLog creation and operations -#[test] -fn test_cycle_log() -> Result<()> { - let mut log = CycleLog::new(); - - let em = ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); - let ems = ErrorModels::new().add(0, em)?; - let theta = Theta::new(); - - // Add a few cycles - for i in 1..=5 { - let cycle = NPCycle::new( - i, - 100.0 - (i as f64) * 2.0, - ems.clone(), - theta.clone(), - 10 + i, - -2.0, - if i == 5 { - Status::Converged - } else { - Status::InProgress - }, - ); - log.push(cycle); - } - - // Check that cycles were added - assert_eq!(log.cycles().len(), 5); - - // Check individual cycles - let cycles = log.cycles(); - assert_eq!(cycles[0].cycle(), 1); - assert_eq!(cycles[4].cycle(), 5); - - Ok(()) -} - -/// Test CycleLog with different statuses -#[test] -fn test_cycle_log_statuses() -> Result<()> { - let mut log = CycleLog::new(); - - let em = ErrorModel::additive(ErrorPoly::new(0.0, 0.10, 0.0, 0.0), 2.0); - let ems = ErrorModels::new().add(0, em)?; - let theta = Theta::new(); - - let statuses = vec![Status::Starting, Status::InProgress, Status::Converged]; - - for (i, status) in statuses.iter().enumerate() { - let cycle = NPCycle::new( - i + 1, - 100.0, - ems.clone(), - theta.clone(), - 10, - 0.0, - status.clone(), - ); - log.push(cycle); - } - - assert_eq!(log.cycles().len(), 3); - - Ok(()) -} - -/// Test Status enum display -#[test] -fn test_status_display() { - let status_starting = Status::Starting; - let status_progress = Status::InProgress; - let status_converged = Status::Converged; - let status_max = Status::MaxCycles; - - // These should be displayable - let _ = format!("{:?}", status_starting); - let _ = format!("{:?}", status_progress); - let _ = format!("{:?}", status_converged); - let _ = format!("{:?}", status_max); - - // Test Display trait - assert!(format!("{}", status_converged).contains("Converged")); -} diff --git a/tests/ipm_tests.rs b/tests/ipm_tests.rs index f0059f3dd..efcd3e0e2 100644 --- a/tests/ipm_tests.rs +++ b/tests/ipm_tests.rs @@ -19,7 +19,7 @@ fn test_burke_ipm_simple() -> Result<()> { let psi = Psi::from(mat); // Run Burke's IPM - let result = pmcore::routines::evaluation::ipm::burke(&psi); + let result = pmcore::routines::estimation::ipm::burke(&psi); // Should succeed assert!(result.is_ok()); @@ -58,7 +58,7 @@ fn test_burke_ipm_larger() -> Result<()> { let psi = Psi::from(mat); // Run Burke's IPM - let result = pmcore::routines::evaluation::ipm::burke(&psi); + let result = pmcore::routines::estimation::ipm::burke(&psi); assert!(result.is_ok()); @@ -92,7 +92,7 @@ fn test_burke_ipm_uniform() -> Result<()> { let psi = Psi::from(mat); // Run Burke's IPM - let result = pmcore::routines::evaluation::ipm::burke(&psi); + let result = pmcore::routines::estimation::ipm::burke(&psi); assert!(result.is_ok()); @@ -131,7 +131,7 @@ fn test_burke_ipm_with_negatives() -> Result<()> { let psi = Psi::from(mat); // Run Burke's IPM - should handle negatives by taking absolute value - let result = pmcore::routines::evaluation::ipm::burke(&psi); + let result = pmcore::routines::estimation::ipm::burke(&psi); assert!(result.is_ok()); @@ -160,7 +160,7 @@ fn test_burke_ipm_with_infinites() { let psi = Psi::from(mat); // Run Burke's IPM - should fail with infinite values - let result = pmcore::routines::evaluation::ipm::burke(&psi); + let result = pmcore::routines::estimation::ipm::burke(&psi); assert!(result.is_err(), "Should fail with infinite values"); } @@ -177,7 +177,7 @@ fn test_burke_ipm_with_nan() { let psi = Psi::from(mat); // Run Burke's IPM - should fail with NaN values - let result = pmcore::routines::evaluation::ipm::burke(&psi); + let result = pmcore::routines::estimation::ipm::burke(&psi); assert!(result.is_err(), "Should fail with NaN values"); } @@ -195,7 +195,7 @@ fn test_burke_ipm_high_dimensional() -> Result<()> { let psi = Psi::from(mat); // Run Burke's IPM - let result = pmcore::routines::evaluation::ipm::burke(&psi); + let result = pmcore::routines::estimation::ipm::burke(&psi); assert!(result.is_ok());