Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions examples/iov/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ fn main() -> Result<()> {
let data = data::read_pmetrics("examples/iov/test.csv").unwrap();
let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap();
algorithm.initialize().unwrap();
while !algorithm.next_cycle().unwrap() {}
let result = algorithm.into_npresult();
let result = algorithm.fit().unwrap();
result.write_outputs().unwrap();

Ok(())
Expand Down
3 changes: 1 addition & 2 deletions examples/meta/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ fn main() {
let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap();
// let result = algorithm.fit().unwrap();
algorithm.initialize().unwrap();
while !algorithm.next_cycle().unwrap() {}
let result = algorithm.into_npresult();
let result = algorithm.fit().unwrap();
result.write_outputs().unwrap();
}
3 changes: 1 addition & 2 deletions examples/new_iov/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ fn main() {
let data = data::read_pmetrics("examples/new_iov/data.csv").unwrap();
let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap();
algorithm.initialize().unwrap();
while !algorithm.next_cycle().unwrap() {}
let result = algorithm.into_npresult();
let result = algorithm.fit().unwrap();
result.write_outputs().unwrap();
}
3 changes: 1 addition & 2 deletions examples/theophylline/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ fn main() {
let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap();
// let result = algorithm.fit().unwrap();
algorithm.initialize().unwrap();
while !algorithm.next_cycle().unwrap() {}
let result = algorithm.into_npresult();
let result = algorithm.fit().unwrap();
result.write_outputs().unwrap();
}
3 changes: 1 addition & 2 deletions examples/vanco_sde/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ fn main() {

let mut algorithm = dispatch_algorithm(settings, sde, data).unwrap();
algorithm.initialize().unwrap();
while !algorithm.next_cycle().unwrap() {}
let result = algorithm.into_npresult();
let result = algorithm.fit().unwrap();
result.write_outputs().unwrap();
}
120 changes: 82 additions & 38 deletions src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
.collect::<Vec<_>>();

if !indices.is_empty() {
let subject: Vec<&Subject> = self.get_data().subjects();
let subject: Vec<&Subject> = self.data().subjects();
let zero_probability_subjects: Vec<&String> =
indices.iter().map(|&i| subject[i].id()).collect();

Expand All @@ -89,7 +89,7 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {
for index in &indices {
tracing::debug!("Subject with zero probability: {}", subject[*index].id());

let error_model = self.get_settings().errormodels().clone();
let error_model = self.settings().errormodels().clone();

// Simulate all support points in parallel
let spp_results: Vec<_> = self
Expand Down Expand Up @@ -207,54 +207,103 @@ pub trait Algorithms<E: Equation + Send + 'static>: Sync + Send + 'static {

Ok(())
}
fn get_settings(&self) -> &Settings;

fn settings(&self) -> &Settings;
/// Get the equation used in the algorithm
fn equation(&self) -> &E;
fn get_data(&self) -> &Data;
/// Get the data used in the algorithm
fn data(&self) -> &Data;
fn get_prior(&self) -> Theta;
fn inc_cycle(&mut self) -> usize;
fn get_cycle(&self) -> usize;
/// Increment the cycle counter and return the new value
fn increment_cycle(&mut self) -> usize;
/// Get the current cycle number
fn cycle(&self) -> usize;
/// Set the current [Theta]
fn set_theta(&mut self, theta: Theta);
/// Get the current [Theta]
fn theta(&self) -> &Theta;
/// Get the current [Psi]
fn psi(&self) -> &Psi;
/// Get the current likelihood
fn likelihood(&self) -> f64;
/// Get the current negative two log-likelihood
fn n2ll(&self) -> f64 {
-2.0 * self.likelihood()
}
/// Get the current [Status] of the algorithm
fn status(&self) -> &Status;
/// Set the current [Status] of the algorithm
fn set_status(&mut self, status: Status);
fn convergence_evaluation(&mut self);
fn converged(&self) -> bool;
/// Evaluate convergence criteria and update status
fn evaluation(&mut self) -> Result<Status>;

/// Create and log a cycle state with the current algorithm state
fn log_cycle_state(&mut self);

/// Initialize the algorithm, setting up initial [Theta] and [Status]
fn initialize(&mut self) -> Result<()> {
// If a stop file exists in the current directory, remove it
if Path::new("stop").exists() {
tracing::info!("Removing existing stop file prior to run");
fs::remove_file("stop").context("Unable to remove previous stop file")?;
}
self.set_status(Status::InProgress);
self.set_status(Status::Continue);
self.set_theta(self.get_prior());
Ok(())
}
fn evaluation(&mut self) -> Result<()>;
fn estimation(&mut self) -> Result<()>;
/// Performs condensation of [Theta] and updates [Psi]
///
/// This step reduces the number of support points in [Theta] based on the current weights,
/// and updates the [Psi] matrix accordingly to reflect the new set of support points.
/// It is typically performed after the estimation step in each cycle of the algorithm.
fn condensation(&mut self) -> Result<()>;

/// Performs optimizations on the current [ErrorModels] and updates [Psi] accordingly
///
/// This step refines the error model parameters to better fit the data,
/// and subsequently updates the [Psi] matrix to reflect these changes.
fn optimizations(&mut self) -> Result<()>;
fn logs(&self);

/// Performs expansion of [Theta]
///
/// This step increases the number of support points in [Theta] based on the current distribution,
/// allowing for exploration of the parameter space.
fn expansion(&mut self) -> Result<()>;
fn next_cycle(&mut self) -> Result<bool> {
if self.inc_cycle() > 1 {

/// Proceed to the next cycle of the algorithm
///
/// This method increments the cycle counter, performs expansion if necessary,
/// and then runs the estimation, condensation, optimization, logging, and evaluation steps
/// in sequence. It returns the current [Status] of the algorithm after completing these steps.
fn next_cycle(&mut self) -> Result<Status> {
let cycle = self.increment_cycle();

if cycle > 1 {
self.expansion()?;
}
let span = tracing::info_span!("", "{}", format!("Cycle {}", self.get_cycle()));

let span = tracing::info_span!("", "{}", format!("Cycle {}", self.cycle()));
let _enter = span.enter();
self.evaluation()?;
self.estimation()?;
self.condensation()?;
self.optimizations()?;
self.logs();
self.convergence_evaluation();
Ok(self.converged())
self.evaluation()
}

/// Fit the model until convergence or stopping criteria are met
///
/// This method runs the full fitting process, starting with initialization,
/// followed by iterative cycles of estimation, condensation, optimization, and evaluation
/// until the algorithm converges or meets a stopping criteria.
fn fit(&mut self) -> Result<NPResult<E>> {
self.initialize().unwrap();
while !self.next_cycle()? {}
loop {
match self.next_cycle()? {
Status::Continue => continue,
Status::Stop(_) => break,
}
}
Ok(self.into_npresult())
}

Expand All @@ -274,32 +323,27 @@ pub fn dispatch_algorithm<E: Equation + Send + 'static>(
}
}

/// Represents the status of the algorithm
/// Represents the status/result of the algorithm
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum Status {
/// Algorithm is starting up
Starting,
/// Algorithm has converged to a solution
Converged,
/// Algorithm stopped due to reaching maximum cycles
MaxCycles,
/// Algorithm is currently running
InProgress,
/// Algorithm was manually stopped by user
ManualStop,
/// Other status with custom message
Other(String),
Continue,
Stop(StopReason),
}

impl std::fmt::Display for Status {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Status::Starting => write!(f, "Starting"),
Status::Converged => write!(f, "Converged"),
Status::MaxCycles => write!(f, "Maximum cycles reached"),
Status::InProgress => write!(f, "In progress"),
Status::ManualStop => write!(f, "Manual stop requested"),
Status::Other(msg) => write!(f, "{}", msg),
Status::Continue => write!(f, "Continue"),
Status::Stop(s) => write!(f, "Stop: {:?}", s),
}
}
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]

pub enum StopReason {
Converged,
MaxCycles,
Stopped,
Completed,
}
Loading