diff --git a/.gitignore b/.gitignore index a34b170b..5d7e9858 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ Cargo.lock /paper_files paper.html /tests/browser-e2e/node_modules/ +docs/ diff --git a/Cargo.toml b/Cargo.toml index 35f233d1..18698813 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,3 +79,7 @@ harness = false [[bench]] name = "runtime_matrix" harness = false + +[[bench]] +name = "likelihood_matrix" +harness = false diff --git a/README.md b/README.md index b1620b8e..73932de2 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ A high-performance Rust library for pharmacokinetic/pharmacodynamic (PK/PD) simu ## Installation -Add `pharmsol` to your `Cargo.toml`, either manually or using +Add `pharmsol` to `Cargo.toml`: ```bash cargo add pharmsol @@ -16,65 +16,72 @@ cargo add pharmsol ## Quick Start +Most Rust-first workflows start with one of the equation macros: `analytical!`, +`ode!`, or `sde!`. Here is a simple one-compartment IV infusion model using `analytical!`: + ```rust -use pharmsol::*; +use pharmsol::prelude::*; + +let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, +}; -// Create a subject with an IV infusion and observations let subject = Subject::builder("patient_001") - .infusion(0.0, 500.0, 0, 0.5) // 500 units over 0.5 hours - .observation(0.5, 1.645, 0) - .observation(1.0, 1.216, 0) - .observation(2.0, 0.462, 0) - .observation(4.0, 0.063, 0) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); -// Define parameters: ke (elimination rate), v (volume) -let ke = 1.022; -let v = 194.0; - -// Use the built-in one-compartment analytical solution -let analytical = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; // Concentration = Amount / Volume - }, - (1, 1), // (compartments, outputs) -); - -// Get predictions -let predictions = analytical.estimate_predictions(&subject, &vec![ke, v]).unwrap(); +let predictions = analytical + .estimate_predictions(&subject, &[1.022, 194.0]) + .unwrap(); ``` -## ODE-Based Models +## Modeling Surfaces -For custom or complex models, define your own ODEs: +Here is the same one-compartment IV setup written as an ODE: ```rust -use pharmsol::*; +use pharmsol::prelude::*; -let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, _v); - // One-compartment model with IV infusion support - dx[0] = -ke * x[0] + rateiv[0]; +let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - (1, 1), -); +}; ``` -## Supported Analytical Models +See [examples/analytical_readme.rs](examples/analytical_readme.rs), +[examples/ode_readme.rs](examples/ode_readme.rs), +[examples/sde_readme.rs](examples/sde_readme.rs), +[examples/analytical_vs_ode.rs](examples/analytical_vs_ode.rs), and +[examples/compare_solvers.rs](examples/compare_solvers.rs). For migration-oriented notes, +see [docs/analytical-authoring-migration.md](docs/analytical-authoring-migration.md) and +[docs/ode-authoring-migration.md](docs/ode-authoring-migration.md). + +### Built-In Analytical Kernels - [x] One-compartment with IV infusion - [x] One-compartment with IV infusion and oral absorption @@ -83,6 +90,21 @@ let ode = equation::ODE::new( - [x] Three-compartment with IV infusion - [x] Three-compartment with IV infusion and oral absorption +## DSL and Runtime Targets + +If the model needs to be loaded or compiled at runtime, pharmsol also provides a DSL with +the same broad modeling coverage: ODE, analytical, and SDE authoring. The DSL can target +an in-process JIT runtime, native ahead-of-time artifacts, or WASM bundles depending on +how you want to ship and execute the model. + +- `dsl-jit`: compile DSL source into a runtime model inside the current process. +- `dsl-aot` and `dsl-aot-load`: emit a native artifact and load it later. +- `dsl-wasm`: compile and execute portable WASM model artifacts. + +See [examples/dsl_runtime_jit.rs](examples/dsl_runtime_jit.rs) for the in-repo JIT flow. +The companion `pharmsol-examples` crate includes end-to-end native AOT and WASM runtime +examples. + ## Performance Analytical solutions provide 20-33× speedups compared to equivalent ODE formulations. See [benchmarks](benches/) for details. @@ -96,12 +118,12 @@ use pharmsol::prelude::*; use pharmsol::nca::NCAOptions; let subject = Subject::builder("patient_001") - .bolus(0.0, 100.0, 0) // 100 mg oral dose - .observation(0.5, 5.0, 0) - .observation(1.0, 10.0, 0) - .observation(2.0, 8.0, 0) - .observation(4.0, 4.0, 0) - .observation(8.0, 2.0, 0) + .bolus(0.0, 100.0, "oral") // 100 mg oral dose + .observation(0.5, 5.0, "cp") + .observation(1.0, 10.0, "cp") + .observation(2.0, 8.0, "cp") + .observation(4.0, 4.0, "cp") + .observation(8.0, 2.0, "cp") .build(); let result = subject.nca(&NCAOptions::default()).expect("NCA failed"); diff --git a/benches/likelihood_matrix.rs b/benches/likelihood_matrix.rs new file mode 100644 index 00000000..b8f94d54 --- /dev/null +++ b/benches/likelihood_matrix.rs @@ -0,0 +1,247 @@ +use criterion::{ + criterion_group, criterion_main, BenchmarkId, Criterion, SamplingMode, Throughput, +}; +use ndarray::Array2; +use pharmsol::prelude::simulator::{log_likelihood_batch, log_likelihood_matrix}; +use pharmsol::prelude::*; +use pharmsol::{Cache, ResidualErrorModel, ResidualErrorModels, ODE}; +use std::hint::black_box; +use std::time::Duration; + +fn example_equation() -> ODE { + equation::ODE::new( + |x, p, _t, dx, _b, _rateiv, _cov| { + fetch_params!(p, ka, ke, _tlag, _v); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, tlag, _v); + lag! {0=>tlag} + }, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _tlag, v); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_nout(1) +} + +fn example_data(n_subjects: usize) -> Data { + let subjects = (0..n_subjects) + .map(|index| { + let dose = 100.0 + index as f64; + let offset = index as f64 * 0.02; + Subject::builder(format!("subject_{index}")) + .bolus(0.0, dose, 0) + .observation(0.5, 4.2 + offset, 0) + .observation(1.0, 5.9 + offset, 0) + .observation(2.0, 7.4 + offset, 0) + .observation(4.0, 6.8 + offset, 0) + .observation(8.0, 4.7 + offset, 0) + .observation(12.0, 3.3 + offset, 0) + .build() + }) + .collect(); + + Data::new(subjects) +} + +fn example_subject() -> Subject { + example_data(1) + .get_subject("subject_0") + .expect("example subject exists") + .clone() +} + +fn example_params() -> [f64; 4] { + [0.60, 0.09, 0.05, 20.0] +} + +fn support_points(n_support_points: usize) -> Array2 { + Array2::from_shape_fn((n_support_points, 4), |(row, column)| match column { + 0 => 0.55 + row as f64 * 0.002, + 1 => 0.08 + row as f64 * 0.0004, + 2 => 0.05, + 3 => 18.0 + row as f64 * 0.15, + _ => unreachable!(), + }) +} + +fn batch_parameters(n_subjects: usize) -> Array2 { + Array2::from_shape_fn((n_subjects, 4), |(row, column)| match column { + 0 => 0.60 + row as f64 * 0.001, + 1 => 0.09 + row as f64 * 0.0002, + 2 => 0.05, + 3 => 20.0 + row as f64 * 0.05, + _ => unreachable!(), + }) +} + +fn assay_error_models() -> AssayErrorModels { + AssayErrorModels::new() + .add( + 0, + AssayErrorModel::additive(ErrorPoly::new(0.0, 0.1, 0.0, 0.0), 0.0), + ) + .unwrap() +} + +fn residual_error_models() -> ResidualErrorModels { + ResidualErrorModels::new().add(0, ResidualErrorModel::constant(0.2)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let assay_error_models = assay_error_models(); + let residual_error_models = residual_error_models(); + + let subject = example_subject(); + let params = example_params(); + let equation_cold = example_equation().disable_cache(); + let equation_hot = example_equation(); + let _ = equation_hot + .estimate_predictions(&subject, params.as_slice()) + .unwrap(); + let predictions = equation_cold + .estimate_predictions(&subject, params.as_slice()) + .unwrap(); + + let mut breakdown_group = c.benchmark_group("ode/runtime-breakdown"); + breakdown_group.sample_size(10); + breakdown_group.warm_up_time(Duration::from_millis(250)); + breakdown_group.measurement_time(Duration::from_secs(1)); + breakdown_group.bench_function("predict-cold", |b| { + b.iter(|| { + black_box( + equation_cold + .estimate_predictions(&subject, params.as_slice()) + .unwrap(), + ) + }) + }); + breakdown_group.bench_function("predict-hot", |b| { + b.iter(|| { + black_box( + equation_hot + .estimate_predictions(&subject, params.as_slice()) + .unwrap(), + ) + }) + }); + breakdown_group.bench_function("score-only", |b| { + b.iter(|| black_box(predictions.log_likelihood(&assay_error_models)).unwrap()) + }); + breakdown_group.bench_function("loglik-cold", |b| { + b.iter(|| { + black_box( + equation_cold + .estimate_log_likelihood(&subject, params.as_slice(), &assay_error_models) + .unwrap(), + ) + }) + }); + breakdown_group.bench_function("loglik-hot", |b| { + b.iter(|| { + black_box( + equation_hot + .estimate_log_likelihood(&subject, params.as_slice(), &assay_error_models) + .unwrap(), + ) + }) + }); + breakdown_group.finish(); + + let mut matrix_group = c.benchmark_group("likelihood/matrix"); + matrix_group.sampling_mode(SamplingMode::Flat); + matrix_group.sample_size(10); + matrix_group.warm_up_time(Duration::from_millis(250)); + matrix_group.measurement_time(Duration::from_secs(2)); + + for (n_subjects, n_support_points) in [(16usize, 32usize), (64usize, 128usize)] { + let case = (example_data(n_subjects), support_points(n_support_points)); + let equation_cold = example_equation().disable_cache(); + let equation_hot = example_equation(); + let _ = log_likelihood_matrix(&equation_hot, &case.0, &case.1, &assay_error_models, false) + .unwrap(); + + matrix_group.throughput(Throughput::Elements((n_subjects * n_support_points) as u64)); + matrix_group.bench_with_input( + BenchmarkId::new("cold", format!("{n_subjects}x{n_support_points}")), + &case, + |b, (data, theta)| { + b.iter(|| { + black_box(log_likelihood_matrix( + &equation_cold, + data, + theta, + &assay_error_models, + false, + )) + .unwrap() + }) + }, + ); + matrix_group.bench_with_input( + BenchmarkId::new("hot", format!("{n_subjects}x{n_support_points}")), + &case, + |b, (data, theta)| { + b.iter(|| { + black_box(log_likelihood_matrix( + &equation_hot, + data, + theta, + &assay_error_models, + false, + )) + .unwrap() + }) + }, + ); + } + matrix_group.finish(); + + let mut batch_group = c.benchmark_group("likelihood/batch"); + batch_group.sample_size(10); + batch_group.warm_up_time(Duration::from_millis(250)); + batch_group.measurement_time(Duration::from_secs(1)); + + for n_subjects in [16usize, 64usize, 256usize] { + let data = example_data(n_subjects); + let parameters = batch_parameters(n_subjects); + let equation_cold = example_equation().disable_cache(); + let equation_hot = example_equation(); + let _ = log_likelihood_batch(&equation_hot, &data, ¶meters, &residual_error_models) + .unwrap(); + + batch_group.throughput(Throughput::Elements(n_subjects as u64)); + batch_group.bench_with_input(BenchmarkId::new("cold", n_subjects), &data, |b, data| { + b.iter(|| { + black_box(log_likelihood_batch( + &equation_cold, + data, + ¶meters, + &residual_error_models, + )) + .unwrap() + }) + }); + batch_group.bench_with_input(BenchmarkId::new("hot", n_subjects), &data, |b, data| { + b.iter(|| { + black_box(log_likelihood_batch( + &equation_hot, + data, + ¶meters, + &residual_error_models, + )) + .unwrap() + }) + }); + } + batch_group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs new file mode 100644 index 00000000..8451b478 --- /dev/null +++ b/examples/analytical_readme.rs @@ -0,0 +1,32 @@ +fn main() -> Result<(), pharmsol::PharmsolError> { + use pharmsol::prelude::*; + + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let subject = Subject::builder("analytical_readme") + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); + + let predictions = analytical.estimate_predictions(&subject, &[1.022, 194.0])?; + + println!("times => {:?}", predictions.flat_times()); + println!("predictions => {:?}", predictions.flat_predictions()); + + Ok(()) +} diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 97112f15..fe5ada96 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -4,6 +4,11 @@ //! two-compartment IV, two-compartment oral), this example runs both the //! closed-form analytical solution and the equivalent ODE, then prints //! the predictions side by side so you can verify they match. +//! Both authoring paths use the declaration-first macro surface so the +//! example stays on the preferred public authoring story. +//! Built-in analytical structures are positional: the `params: [...]` +//! declaration becomes metadata, but the runtime kernel still expects values +//! in the structure's native positional order. //! //! cargo run --release --example analytical_vs_ode @@ -11,29 +16,29 @@ use pharmsol::prelude::*; // ── Subjects ─────────────────────────────────────────────────────── -fn subject_iv() -> Subject { +fn subject_iv(input: impl ToString, output: impl ToString) -> Subject { Subject::builder("1") - .infusion(0.0, 500.0, 0, 0.5) - .observation(0.5, 0.0, 0) - .observation(1.0, 0.0, 0) - .observation(2.0, 0.0, 0) - .observation(4.0, 0.0, 0) - .observation(8.0, 0.0, 0) - .observation(12.0, 0.0, 0) - .observation(24.0, 0.0, 0) + .infusion(0.0, 500.0, input, 0.5) + .observation(0.5, 0.0, output.to_string()) + .observation(1.0, 0.0, output.to_string()) + .observation(2.0, 0.0, output.to_string()) + .observation(4.0, 0.0, output.to_string()) + .observation(8.0, 0.0, output.to_string()) + .observation(12.0, 0.0, output.to_string()) + .observation(24.0, 0.0, output) .build() } -fn subject_oral() -> Subject { +fn subject_oral(input: impl ToString, output: impl ToString) -> Subject { Subject::builder("1") - .bolus(0.0, 500.0, 0) - .observation(0.5, 0.0, 0) - .observation(1.0, 0.0, 0) - .observation(2.0, 0.0, 0) - .observation(4.0, 0.0, 0) - .observation(8.0, 0.0, 0) - .observation(12.0, 0.0, 0) - .observation(24.0, 0.0, 0) + .bolus(0.0, 500.0, input) + .observation(0.5, 0.0, output.to_string()) + .observation(1.0, 0.0, output.to_string()) + .observation(2.0, 0.0, output.to_string()) + .observation(4.0, 0.0, output.to_string()) + .observation(8.0, 0.0, output.to_string()) + .observation(12.0, 0.0, output.to_string()) + .observation(24.0, 0.0, output) .build() } @@ -64,168 +69,180 @@ fn print_comparison(label: &str, analytical: &SubjectPredictions, ode: &SubjectP // ── One-compartment IV ───────────────────────────────────────────── -fn one_cmt_iv(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; +fn one_cmt_iv(params: &[f64]) { + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(1) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, _v); - dx[0] = -ke * x[0] + rateiv[0]; + }; + + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(1) - .with_nout(1); + }; - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); + let subject = subject_iv("iv", "cp"); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("One-compartment IV", &pred_a, &pred_o); } // ── One-compartment oral ─────────────────────────────────────────── -fn one_cmt_oral(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - one_compartment_with_absorption, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, v); - y[0] = x[1] / v; +fn one_cmt_oral(params: &[f64]) { + let analytical = analytical! { + name: "one_cmt_oral", + params: [ka, ke, v], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + ], + structure: one_compartment_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { - fetch_params!(p, ka, ke, _v); - dx[0] = -ka * x[0]; - dx[1] = ka * x[0] - ke * x[1]; + }; + + let ode = ode! { + name: "one_cmt_oral", + params: [ka, ke, v], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, v); - y[0] = x[1] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_nout(1); + }; + + let subject = subject_oral("oral", "cp"); - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("One-compartment oral", &pred_a, &pred_o); } // ── Two-compartment IV ───────────────────────────────────────────── -fn two_cmt_iv(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - two_compartments, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _k12, _k21, v); - y[0] = x[0] / v; +fn two_cmt_iv(params: &[f64]) { + let analytical = analytical! { + name: "two_cmt_iv", + params: [ke, k12, k21, v], + states: [central, peripheral], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + structure: two_compartments, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, k12, k21, _v); - dx[0] = -ke * x[0] - k12 * x[0] + k21 * x[1] + rateiv[0]; - dx[1] = k12 * x[0] - k21 * x[1]; + }; + + let ode = ode! { + name: "two_cmt_iv", + params: [ke, k12, k21, v], + states: [central, peripheral], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - k12 * x[central] + k21 * x[peripheral]; + dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _k12, _k21, v); - y[0] = x[0] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_nout(1); + }; + + let subject = subject_iv("iv", "cp"); - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("Two-compartment IV", &pred_a, &pred_o); } // ── Two-compartment oral ─────────────────────────────────────────── -fn two_cmt_oral(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - two_compartments_with_absorption, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _k12, _k21, v); - y[0] = x[1] / v; +fn two_cmt_oral(params: &[f64]) { + let analytical = analytical! { + name: "two_cmt_oral", + params: [ke, ka, k12, k21, v], + states: [gut, central, peripheral], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + ], + structure: two_compartments_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(3) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { - fetch_params!(p, ka, ke, k12, k21, _v); - dx[0] = -ka * x[0]; - dx[1] = ka * x[0] - ke * x[1] - k12 * x[1] + k21 * x[2]; - dx[2] = k12 * x[1] - k21 * x[2]; + }; + + let ode = ode! { + name: "two_cmt_oral", + params: [ka, ke, k12, k21, v], + states: [gut, central, peripheral], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central] - k12 * x[central] + k21 * x[peripheral]; + dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _k12, _k21, v); - y[0] = x[1] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(3) - .with_nout(1); + }; - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); + let subject = subject_oral("oral", "cp"); + + // `two_compartments_with_absorption` is positional and expects + // `ke, ka, k12, k21, v`, while the ODE closure below is authored as + // `ka, ke, k12, k21, v`. + let analytical_params = [params[1], params[0], params[2], params[3], params[4]]; + + let pred_a = analytical + .estimate_predictions(&subject, &analytical_params) + .unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("Two-compartment oral", &pred_a, &pred_o); } // ── Main ─────────────────────────────────────────────────────────── fn main() { - let iv = subject_iv(); - let oral = subject_oral(); - - one_cmt_iv(&iv, &[0.1, 50.0]); // ke, v - one_cmt_oral(&oral, &[1.0, 0.1, 50.0]); // ka, ke, v - two_cmt_iv(&iv, &[0.1, 0.3, 0.2, 50.0]); // ke, k12, k21, v - two_cmt_oral(&oral, &[1.0, 0.1, 0.3, 0.2, 50.0]); // ka, ke, k12, k21, v + one_cmt_iv(&[0.1, 50.0]); // ke, v + one_cmt_oral(&[1.0, 0.1, 50.0]); // ka, ke, v + two_cmt_iv(&[0.1, 0.3, 0.2, 50.0]); // ke, k12, k21, v + two_cmt_oral(&[1.0, 0.1, 0.3, 0.2, 50.0]); // ka, ke, k12, k21, v } diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index 3a34b424..5d8fdbb6 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -1,4 +1,4 @@ -//! Shows how to select different ODE solvers for the same model. +//! Shows how to select different ODE solvers for the same declaration-first model. //! //! pharmsol wraps diffsol's solver families: //! @@ -14,19 +14,26 @@ use std::time::Instant; use pharmsol::prelude::*; // ── Model ────────────────────────────────────────────────────────── -// Two-compartment IV model. The solver is the only thing that changes -// between runs — the ODE, output equation and dimensions stay the same. +// Two-compartment IV model. The solver is the only thing that changes +// between runs; the declaration-first `ode!` surface and the generated +// metadata stay the same. fn two_cpt(solver: OdeSolver) -> equation::ODE { ode! { - diffeq: |x, p, _t, dx, b, rateiv, _cov| { - fetch_params!(p, ke, kcp, kpc, _v); - dx[0] = rateiv[0] + b[0] - ke * x[0] - kcp * x[0] + kpc * x[1]; - dx[1] = kcp * x[0] - kpc * x[1]; + name: "two_cpt", + params: [ke, kcp, kpc, v], + states: [central, peripheral], + outputs: [cp], + routes: [ + bolus(load) -> central, + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; }, - out: |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _kcp, _kpc, v); - y[0] = x[0] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, } .with_solver(solver) @@ -35,30 +42,34 @@ fn two_cpt(solver: OdeSolver) -> equation::ODE { // ── Main ─────────────────────────────────────────────────────────── fn main() { - let subject = Subject::builder("id1") - .bolus(0.0, 100.0, 0) - .infusion(12.0, 200.0, 0, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(4.0, 0) - .missing_observation(8.0, 0) - .missing_observation(12.0, 0) - .missing_observation(12.5, 0) - .missing_observation(13.0, 0) - .missing_observation(14.0, 0) - .missing_observation(16.0, 0) - .missing_observation(24.0, 0) - .build(); - - let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V - // Run each solver and collect predictions let bdf = two_cpt(OdeSolver::Bdf); let tsit45 = two_cpt(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let trbdf2 = two_cpt(OdeSolver::Sdirk(SdirkTableau::TrBdf2)); let esdirk34 = two_cpt(OdeSolver::Sdirk(SdirkTableau::Esdirk34)); + // Both declarations resolve to the same shared input, so subject + // authoring still uses one numeric index for the loading bolus and the + // maintenance infusion. + + let subject = Subject::builder("id1") + .bolus(0.0, 100.0, "load") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") + .build(); + + let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V + let results: Vec<(&str, equation::ODE)> = vec![ ("Bdf", bdf), ("Sdirk(TrBdf2)", trbdf2), diff --git a/examples/covariates.rs b/examples/covariates.rs index f9b97f29..83516ebb 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -1,61 +1,47 @@ fn main() { use pharmsol::prelude::*; - // Create a subject with a bolus dose, observations, and covariates - let subject = Subject::builder("id1") - // Administer a bolus dose of 100 units at time 0 - .bolus(0.0, 100.0, 0) - // Give two additional doses at 2-hour intervals - .repeat(2, 2.0) - .observation(0.5, 0.1, 0) - .observation(1.0, 0.4, 0) - .observation(2.0, 1.0, 0) - .observation(2.5, 1.1, 0) - // Creatinine covariate changes over time, with initial value of 80 at time 0 - .covariate("creatinine", 0.0, 80.0) - // New obseration of creatinine at time 6 hours - // The value will be linearly interpolated between time 0 and time 6 - .covariate("creatinine", 1.0, 40.0) - // For age, the covariate is constant over time, as there are no changes - .covariate("age", 0.0, 25.0) - .missing_observation(8.0, 0) - .build(); - - let ode = equation::ODE::new( - |x, p, t, dx, b, _rateiv, cov| { - // Macro to get the (possibly interpolated) covariate values at time `t` - fetch_cov!(cov, t, creatinine, age); - // Macro to fetch parameter values from `p` - // Note the order must match the order in which parameters are defined later - fetch_params!(p, ka, ke, _tlag, _v); - - let ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); - - //Struct - dx[0] = -ka * x[0] + b[0]; - dx[1] = ka * x[0] - ke * x[1]; + let ode = ode! { + name: "one_cmt_covariates", + params: [ka, ke, tlag, v], + covariates: [creatinine, age], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + ], + diffeq: |x, _t, dx| { + let scaled_ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); + + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - scaled_ke * x[central]; }, // This blocks defines the lag-time of the bolus dose - |p, _t, _cov| { - fetch_params!(p, _ka, _ke, tlag, _v); + lag: |_t| { // Macro used to define the lag-time for the input of the bolus dose - lag! {0=>tlag} + lag! { oral => tlag } }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); - + out: |x, _t, y| { // Define the predicted concentration as the amount in the central compartment divided by volume - y[0] = x[1] / v; + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_nout(1); + }; + + // Create a subject using route and output labels directly. + let subject = Subject::builder("id1") + .bolus(0.0, 100.0, "oral") + .repeat(2, 2.0) + .observation(0.5, 0.1, "cp") + .observation(1.0, 0.4, "cp") + .observation(2.0, 1.0, "cp") + .observation(2.5, 1.1, "cp") + .covariate("creatinine", 0.0, 80.0) + .covariate("creatinine", 1.0, 40.0) + .covariate("age", 0.0, 25.0) + .missing_observation(8.0, "cp") + .build(); // Define parameter values - // Note that the order matters and should correspond to the order in which parameters are fetched in the model - // This is subject to change in future versions let ka = 1.0; // Absorption rate constant let ke = 0.2; // Elimination rate constant let tlag = 0.0; // Lag time diff --git a/examples/dsl_runtime_jit.rs b/examples/dsl_runtime_jit.rs index 655981bc..3f7d1efe 100644 --- a/examples/dsl_runtime_jit.rs +++ b/examples/dsl_runtime_jit.rs @@ -8,7 +8,7 @@ fn main() -> Result<(), Box> { use pharmsol::prelude::*; let model_source = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v @@ -43,24 +43,16 @@ out(cp) = central / v on_compile_event, )?; - // 2. Resolve the route and output indices declared by the model. - let iv = model - .route_index("iv") - .ok_or_else(|| io::Error::other("missing iv route"))?; - let cp = model - .output_index("cp") - .ok_or_else(|| io::Error::other("missing cp output"))?; - // 3. Define the subject data. let subject = Subject::builder("bimodal_ke") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(3.0, cp) - .missing_observation(4.0, cp) - .missing_observation(6.0, cp) - .missing_observation(8.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(3.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(8.0, "cp") .build(); // 4. Estimate predictions for one support point. diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs new file mode 100644 index 00000000..5e8ec9ec --- /dev/null +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -0,0 +1,104 @@ +//! Compares a declaration-first macro ODE with the equivalent handwritten ODE. +//! +//! This is the advanced comparison path for users who want to confirm that the +//! preferred macro surface and the low-level API produce the same metadata and +//! predictions on the same one-compartment IV problem. + +use pharmsol::prelude::*; + +fn macro_model() -> equation::ODE { + ode! { + name: "one_cpt_macro_parity", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_macro_parity") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten one-compartment metadata should validate") +} + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn main() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_model(); + let handwritten_ode = handwritten_model(); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let subject = Subject::builder("macro-vs-handwritten-one-cpt") + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .build(); + + let params = [1.022, 194.0]; + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let macro_flat = macro_predictions.flat_predictions(); + let handwritten_flat = handwritten_predictions.flat_predictions(); + let diff = max_abs_diff(¯o_flat, &handwritten_flat); + + assert!( + diff <= 1e-10, + "macro and handwritten one-compartment predictions diverged: {diff:e}" + ); + + println!("one-compartment parity max abs diff: {diff:e}"); + for ((time, macro_pred), handwritten_pred) in macro_predictions + .flat_times() + .iter() + .zip(macro_flat.iter()) + .zip(handwritten_flat.iter()) + { + println!("t={time:>4.1} macro={macro_pred:>12.8} handwritten={handwritten_pred:>12.8}"); + } + + Ok(()) +} diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs new file mode 100644 index 00000000..d3c10a0f --- /dev/null +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -0,0 +1,127 @@ +//! Compares a declaration-first macro ODE with the equivalent handwritten ODE +//! on a two-compartment IV problem that shares one numeric input across +//! a loading bolus and a maintenance infusion. +//! +//! This keeps the macro story as the default surface while showing the +//! low-level API as an explicit advanced comparison path. + +use pharmsol::prelude::*; + +fn macro_model() -> equation::ODE { + ode! { + name: "two_cpt_shared_input_parity", + params: [ke, kcp, kpc, v], + states: [central, peripheral], + outputs: [cp], + routes: [ + bolus(load) -> central, + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ke, kcp, kpc, _v); + dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; + dx[1] = kcp * x[0] - kpc * x[1]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, _kcp, _kpc, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("two_cpt_shared_input_parity") + .parameters(["ke", "kcp", "kpc", "v"]) + .states(["central", "peripheral"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("load") + .to_state("central") + .inject_input_to_destination(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("handwritten two-compartment metadata should validate") +} + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn main() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_model(); + let handwritten_ode = handwritten_model(); + let macro_metadata = macro_ode.metadata().expect("macro metadata exists"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!( + macro_metadata + .route("load") + .map(|route| route.input_index()), + macro_metadata.route("iv").map(|route| route.input_index()), + "load and iv should share one numeric input" + ); + assert!(macro_metadata.output("cp").is_some()); + + let subject = Subject::builder("macro-vs-handwritten-two-cpt") + .bolus(0.0, 100.0, "load") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") + .build(); + + let params = [0.1, 0.05, 0.03, 50.0]; + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let macro_flat = macro_predictions.flat_predictions(); + let handwritten_flat = handwritten_predictions.flat_predictions(); + let diff = max_abs_diff(¯o_flat, &handwritten_flat); + + assert!( + diff <= 1e-10, + "macro and handwritten two-compartment predictions diverged: {diff:e}" + ); + + println!("two-compartment parity max abs diff: {diff:e}"); + for ((time, macro_pred), handwritten_pred) in macro_predictions + .flat_times() + .iter() + .zip(macro_flat.iter()) + .zip(handwritten_flat.iter()) + { + println!("t={time:>5.1} macro={macro_pred:>12.8} handwritten={handwritten_pred:>12.8}"); + } + + Ok(()) +} diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index 51765af1..2989895f 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -1,43 +1,37 @@ -fn main() { +fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - let subject = Subject::builder("id1") - .bolus(0.0, 100.0, 0) - .repeat(2, 0.5) - .observation(0.5, 0.1, 0) - .observation(1.0, 0.4, 0) - .observation(2.0, 1.0, 0) - .observation(2.5, 1.1, 0) - .covariate("wt", 0.0, 80.0) - .covariate("wt", 1.0, 83.0) - .covariate("age", 0.0, 25.0) - .build(); - println!("{subject}"); - let ode = equation::ODE::new( - |x, p, _t, dx, b, _rateiv, _cov| { - // fetch_cov!(cov, t,); - fetch_params!(p, ka, ke, _tlag, _v); - //Struct - dx[0] = -ka * x[0] + b[0]; - dx[1] = ka * x[0] - ke * x[1]; + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |p, _t, _cov| { - fetch_params!(p, _ka, _ke, tlag, _v); - lag! {0=>tlag} + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(2) - .with_ndrugs(5) - .with_nout(1); + }; + + let subject = Subject::builder("id1") + .infusion(0.0, 100.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); + + let predictions = ode.estimate_predictions(&subject, &[1.022, 194.0])?; + println!( + "state central => {}", + ode.state_index("central").expect("central state exists") + ); + println!("prediction times => {:?}", predictions.flat_times()); + println!("predictions => {:?}", predictions.flat_predictions()); - let op = ode - .estimate_predictions(&subject, &[0.3, 0.5, 0.1, 70.0]) - .unwrap(); - println!("{:#?}", op.flat_predictions()); + Ok(()) } diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index a66112b5..e5813e2a 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -1,67 +1,53 @@ fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - // Create a subject using the builder pattern - let subject = Subject::builder("Nikola Tesla") - // An initial infusion of 500 units over 0.5 time units - .infusion(0., 500.0, 0, 0.5) - // Observations at various time points - .observation(0.5, 1.645, 0) - .observation(1., 1.216, 0) - .observation(2., 0.462, 0) - .observation(3., 0.169, 0) - .observation(4., 0.063, 0) - .observation(6., 0.009, 0) - .observation(8., 0.001, 0) - // A missing observation, to force the simulator to predict to this time point - // For missing observations, predictions are made but no likelihood contribution is computed - .missing_observation(12.0, 0) - // Build the subject - .build(); - - // Define the one-compartment analytical solution function - let an = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - // Calculate the output concentration, here defined as amount over volume - y[0] = x[0] / v; + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(1) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - // Macro to fetch parameters from the parameter vector - // This exposes them as local variables - fetch_params!(p, ke, _v); + }; - // Define the ODE for the one-compartment model - // Note that rateiv is used to include infusion rates - dx[0] = -ke * x[0] + rateiv[0]; + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - // Calculate the output concentration, here defined as amount over volume - y[0] = x[0] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(1) - .with_nout(1); + }; + + // Create a subject using route and output labels directly. + let subject = Subject::builder("Nikola Tesla") + .infusion(0., 500.0, "iv", 0.5) + .observation(0.5, 1.645, "cp") + .observation(1., 1.216, "cp") + .observation(2., 0.462, "cp") + .observation(3., 0.169, "cp") + .observation(4., 0.063, "cp") + .observation(6., 0.009, "cp") + .observation(8., 0.001, "cp") + .missing_observation(12.0, "cp") + .build(); - // Define the error models for the observations - let ems = AssayErrorModels::new(). - // For this example, we use a simple additive error model with 5% error - add( - 0, + // Define the assay error models once by label and reuse them across both + // equations. + let ems = AssayErrorModels::new().add( + "cp", AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), )?; @@ -70,9 +56,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let v = 194.0; // Volume of distribution // Compute likelihoods and predictions for both models - let analytical_likelihoods = an.estimate_log_likelihood(&subject, &[ke, v], &ems)?; + let analytical_likelihoods = analytical.estimate_log_likelihood(&subject, &[ke, v], &ems)?; - let analytical_predictions = an.estimate_predictions(&subject, &[ke, v])?; + let analytical_predictions = analytical.estimate_predictions(&subject, &[ke, v])?; let ode_likelihoods = ode.estimate_log_likelihood(&subject, &[ke, v], &ems)?; diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs new file mode 100644 index 00000000..97b5fed4 --- /dev/null +++ b/examples/sde_readme.rs @@ -0,0 +1,38 @@ +fn main() -> Result<(), pharmsol::PharmsolError> { + use pharmsol::prelude::*; + + let sde = sde! { + name: "one_cmt_sde", + params: [ke, sigma_ke, v], + states: [central], + outputs: [cp], + particles: 16, + routes: [ + infusion(iv) -> central, + ], + drift: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[central] = sigma_ke; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let subject = Subject::builder("sde_readme") + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); + + let predictions = sde.estimate_predictions(&subject, &[1.022, 0.0, 194.0])?; + + println!("first prediction => {}", predictions[[0, 0]].prediction()); + println!("prediction grid shape => {:?}", predictions.dim()); + + Ok(()) +} diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index d81aa1f8..e6f44e32 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -3,6 +3,9 @@ /// This example demonstrates how to implement a two-compartment pharmacokinetic model /// with weight-based covariate scaling using pharmsol. /// +/// It uses the declaration-first `ode!` surface so the route, covariate, +/// state, and output metadata stay aligned with the generated execution path. +/// /// The two-compartment model describes drug distribution between: /// - Central compartment (x[0]): where drug enters and is eliminated /// - Peripheral compartment (x[1]): a tissue compartment in equilibrium with central @@ -18,36 +21,18 @@ fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - // Create a subject using the builder pattern - let subject = Subject::builder("subject_001") - // An infusion of 500 mg over 0.5 hours (1000 mg/hr rate) - .infusion(0.0, 500.0, 0, 0.5) - // Weight covariate at baseline (85 kg reference weight) - .covariate("wt", 0.0, 70.0) - // Observations at various time points (concentration in mg/L) - .observation(0.5, 8.5, 0) - .observation(1.0, 6.2, 0) - .observation(2.0, 4.1, 0) - .observation(4.0, 2.3, 0) - .observation(6.0, 1.5, 0) - .observation(8.0, 1.1, 0) - .observation(12.0, 0.7, 0) - // Missing observation to force prediction at this time point - .missing_observation(24.0, 0) - .build(); - - // Define the two-compartment ODE model - let ode = equation::ODE::new( - // Primary differential equation block - |x, p, t, dx, _b, rateiv, cov| { - // Fetch the (possibly interpolated) weight covariate at time t - fetch_cov!(cov, t, wt); - - // Fetch parameters from the parameter vector + let ode = ode! { + name: "two_cmt_wt", + params: [cl, v, vp, q], + covariates: [wt], + states: [central, peripheral], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _t, dx| { // CL: Clearance (L/hr), V: Central volume (L) // Vp: Peripheral volume (L), Q: Inter-compartmental clearance (L/hr) - fetch_params!(p, cl, v, vp, q); - // Weight-based allometric scaling // Reference weight is 85 kg let wt_ratio = wt / 85.0; @@ -64,36 +49,37 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let kpc = q_scaled / vp_scaled; // Peripheral to central rate constant // Two-compartment model differential equations - // Central compartment: elimination + distribution + infusion input - dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0]; + // Central compartment: elimination + distribution + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; // Peripheral compartment: distribution equilibrium - dx[1] = kcp * x[0] - kpc * x[1]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; }, - // Lag time block (no lag in this model) - |_p, _t, _cov| lag! {}, - // Bioavailability block (100% for IV, so not needed) - |_p, _t, _cov| fa! {}, - // Secondary equations block (not used here) - |_p, _t, _cov, _x| {}, // Output equation block - calculates observed concentration - |x, p, t, cov, y| { - fetch_cov!(cov, t, wt); - fetch_params!(p, _cl, v, _vp, _q); - + out: |x, _t, y| { // Calculate scaled volume for concentration let wt_ratio = wt / 85.0; let v_scaled = v * wt_ratio; // Concentration = Amount / Volume - y[0] = x[0] / v_scaled; + y[cp] = x[central] / v_scaled; }, - // Model dimensions: (number of compartments, number of outputs) - ) - .with_nstates(2) - .with_nout(1); + }; + + // Create a subject using route and output labels directly. + let subject = Subject::builder("subject_001") + .infusion(0.0, 500.0, "iv", 0.5) + .covariate("wt", 0.0, 70.0) + .observation(0.5, 8.5, "cp") + .observation(1.0, 6.2, "cp") + .observation(2.0, 4.1, "cp") + .observation(4.0, 2.3, "cp") + .observation(6.0, 1.5, "cp") + .observation(8.0, 1.1, "cp") + .observation(12.0, 0.7, "cp") + .missing_observation(24.0, "cp") + .build(); // Define parameter values - // Note: order must match the fetch_params! macro order let cl = 5.0; // Clearance (L/hr) let v = 50.0; // Central volume of distribution (L) let vp = 100.0; // Peripheral volume of distribution (L) diff --git a/pharmsol-dsl/README.md b/pharmsol-dsl/README.md index e3a2469c..0280727b 100644 --- a/pharmsol-dsl/README.md +++ b/pharmsol-dsl/README.md @@ -1,50 +1,65 @@ # pharmsol-dsl -`pharmsol-dsl` is the extraction target for the backend-neutral frontend of the pharmsol DSL. +`pharmsol-dsl` is the backend-neutral frontend crate for the pharmsol DSL. -The crate is introduced as an internal workspace member so the codebase can move to a clean engine / DSL split without duplicating workflows or breaking the current `pharmsol::dsl` user-facing API mid-migration. +Use this crate when you need to work with model source as data: -## Current Status +- parse DSL text into syntax nodes +- inspect spans and diagnostics +- analyze names and types into typed IR +- lower validated models into the execution model used by runtime backends -Slices 1 through 7 have moved the shared frontend data modules, parsing frontend, semantic analysis, and execution lowering here, rewired `pharmsol` backend modules to consume that frontend directly, and cleaned up frontend test ownership: +Do not use this crate for JIT compilation, native AoT export or load, WASM runtime loading, or `Subject`-based prediction helpers. Those workflows stay in `pharmsol::dsl` in the main `pharmsol` crate. -- AST types -- diagnostics and spans -- typed IR -- lexer -- parser -- authoring desugaring used by the parser -- semantic analysis and semantic diagnostics -- execution lowering and execution model types -- frontend-only integration tests and authoring fixtures +## Main Pipeline -`pharmsol::dsl` now acts as a deliberate compatibility façade: it re-exports the frontend surface from this crate while keeping runtime compilation, artifact loading, and execution wrappers in `pharmsol`. +The public pipeline is: -## Planned Ownership +1. `parse_model` or `parse_module` +2. `analyze_model` or `analyze_module` +3. `lower_typed_model` or `lower_typed_module` -The crate will own the backend-neutral frontend pipeline: +The main public modules are: -- AST and syntax types -- authoring desugaring -- diagnostics and spans -- lexical analysis -- parse entrypoints -- typed IR -- execution IR and lowering -- parse / analyze / lower entrypoints +- `ast` for syntax-level nodes +- `diagnostic` for spans, codes, and rendered reports +- `ir` for the typed intermediate representation +- `execution` for the lowered execution model shared by JIT, AoT, and WASM backends -The crate will not own runtime-facing APIs such as JIT, AoT, WASM loading, or `Subject`-based prediction wrappers in the initial extraction. +The parser accepts both canonical `model { ... }` source and the authoring shorthand used by the `pharmsol` examples. -## Migration Rule +## Small Example -Until the move is complete, `pharmsol::dsl` remains the compatibility façade. +```rust +use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model}; -That means: +let source = r#" +name = bimodal_ke +kind = ode -- backend code continues to live in `pharmsol` -- frontend modules move here slice by slice -- user-facing import churn is deferred until the architecture is stable +params = ke, v +states = central +outputs = cp -## Transitional Note +infusion(iv) -> central -The temporary lexer bridge from Slice 1 is gone. Frontend-only authoring fixtures now live under `pharmsol-dsl/tests/fixtures/dsl`, while the shared structured-block corpus remains under `tests/fixtures/dsl` because both crates still consume it. +dx(central) = -ke * central +out(cp) = central / v +"#; + +let syntax = parse_model(source).expect("model parses"); +let typed = analyze_model(&syntax).expect("model analyzes"); +let execution = lower_typed_model(&typed).expect("model lowers"); + +assert_eq!(execution.name, "bimodal_ke"); +assert_eq!(execution.metadata.routes.len(), 1); +assert_eq!(execution.metadata.outputs.len(), 1); +``` + +## Boundary With `pharmsol` + +`pharmsol-dsl` owns the frontend pipeline and its data structures. + +`pharmsol::dsl` re-exports that frontend surface and adds the runtime-facing APIs for backend selection, artifact loading, and prediction execution. + +Use `pharmsol-dsl` when you are building tooling, validation, migration, or your own backend. Use `pharmsol::dsl` when you want a complete source-to-runtime workflow. diff --git a/pharmsol-dsl/src/ast.rs b/pharmsol-dsl/src/ast.rs index d43e7404..6cff4483 100644 --- a/pharmsol-dsl/src/ast.rs +++ b/pharmsol-dsl/src/ast.rs @@ -111,10 +111,17 @@ pub struct RoutesBlock { pub span: Span, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum RouteKind { + Bolus, + Infusion, +} + #[derive(Debug, Clone, PartialEq)] pub struct RouteDecl { pub input: Ident, pub destination: Place, + pub kind: Option, pub properties: Vec, pub span: Span, } @@ -141,7 +148,7 @@ pub struct StatementBlock { #[derive(Debug, Clone, PartialEq)] pub struct AnalyticalBlock { - pub kernel: Ident, + pub structure: Ident, pub span: Span, } @@ -491,7 +498,7 @@ fn write_analytical_block( indent(out, indent_level); writeln!(out, "analytical {{")?; indent(out, indent_level + 1); - writeln!(out, "kernel = {}", block.kernel.text)?; + writeln!(out, "structure = {}", block.structure.text)?; indent(out, indent_level); write!(out, "}}") } diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index db32b9c8..04a64290 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -3,6 +3,7 @@ use std::collections::{BTreeMap, BTreeSet}; use super::ast::*; use super::diagnostic::{Applicability, DiagnosticSuggestion, ParseError, Span, TextEdit}; use super::parser::{parse_expr_fragment, parse_place_fragment}; +use crate::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME}; const DEFAULT_MODEL_NAME: &str = "main"; @@ -12,7 +13,7 @@ pub(super) fn parse_module(src: &str) -> Result { struct AuthoringParser<'a> { src: &'a str, - model_name: Option, + name: Option, explicit_kind: Option<(ModelKind, Span)>, parameters: Vec, constants: Vec, @@ -20,10 +21,12 @@ struct AuthoringParser<'a> { states: Vec, declared_derived: BTreeSet, declared_outputs: BTreeSet, + explicit_output_order: Vec, explicit_outputs: BTreeMap, assigned_outputs: BTreeMap, declared_outputs_span: Option, routes: BTreeMap, + route_order: Vec, route_modifiers: BTreeMap>, derive_statements: Vec, derivative_statements: Vec, @@ -68,7 +71,7 @@ impl<'a> AuthoringParser<'a> { fn new(src: &'a str) -> Self { Self { src, - model_name: None, + name: None, explicit_kind: None, parameters: Vec::new(), constants: Vec::new(), @@ -76,10 +79,12 @@ impl<'a> AuthoringParser<'a> { states: Vec::new(), declared_derived: BTreeSet::new(), declared_outputs: BTreeSet::new(), + explicit_output_order: Vec::new(), explicit_outputs: BTreeMap::new(), assigned_outputs: BTreeMap::new(), declared_outputs_span: None, routes: BTreeMap::new(), + route_order: Vec::new(), route_modifiers: BTreeMap::new(), derive_statements: Vec::new(), derivative_statements: Vec::new(), @@ -132,11 +137,15 @@ impl<'a> AuthoringParser<'a> { } let surface_routes = std::mem::take(&mut self.routes); + let route_order = std::mem::take(&mut self.route_order); let mut route_modifiers = std::mem::take(&mut self.route_modifiers); let mut routes = Vec::with_capacity(surface_routes.len()); - for (route_name, route) in &surface_routes { + for route_name in route_order { + let Some(route) = surface_routes.get(&route_name) else { + continue; + }; let mut span = route.span; - let properties = route_modifiers.remove(route_name).unwrap_or_default(); + let properties = route_modifiers.remove(&route_name).unwrap_or_default(); if !properties.is_empty() { span = properties .iter() @@ -145,6 +154,10 @@ impl<'a> AuthoringParser<'a> { routes.push(RouteDecl { input: route.input.clone(), destination: route.destination.clone(), + kind: Some(match route.kind { + SurfaceRouteKind::Bolus => RouteKind::Bolus, + SurfaceRouteKind::Infusion => RouteKind::Infusion, + }), properties, span, }); @@ -160,16 +173,30 @@ impl<'a> AuthoringParser<'a> { let kind = self.determine_kind(module_span)?; if matches!(kind, ModelKind::Analytical) && !self.derivative_statements.is_empty() { return Err(ParseError::new( - "analytical authoring models cannot declare `dx(...)` equations", + "analytical models cannot declare `dx(...)` equations", self.derivative_statements[0].span, )); } + if !self.explicit_output_order.is_empty() { + let output_order = self + .explicit_output_order + .iter() + .enumerate() + .map(|(index, name)| (name.clone(), index)) + .collect::>(); + self.output_statements.sort_by_key(|statement| { + output_statement_name(statement) + .and_then(|name| output_order.get(name).copied()) + .unwrap_or(usize::MAX) + }); + } + let mut derivative_statements = std::mem::take(&mut self.derivative_statements); inject_infusion_rates(&surface_routes, &routes, &mut derivative_statements); let name = self - .model_name + .name .unwrap_or_else(|| Ident::new(DEFAULT_MODEL_NAME, module_span)); let mut items = Vec::new(); @@ -285,7 +312,7 @@ impl<'a> AuthoringParser<'a> { let eq_index = find_top_level_assignment(trimmed).ok_or_else(|| { ParseError::new( - "expected an authoring declaration, equation, or route shorthand", + "expected an declaration, equation, or route shorthand", span, ) })?; @@ -298,9 +325,19 @@ impl<'a> AuthoringParser<'a> { if let Some(rest) = lhs_trimmed.strip_prefix("model") { if !rest.trim().is_empty() { - return Err(ParseError::new("expected `model = `", span)); + return Err(ParseError::new("expected `name = `", span)); + } + return Err(ParseError::new( + "`model = ...` has been renamed to `name = ...`", + span, + )); + } + + if let Some(rest) = lhs_trimmed.strip_prefix("name") { + if !rest.trim().is_empty() { + return Err(ParseError::new("expected `name = `", span)); } - self.model_name = Some(parse_ident_segment(rhs, rhs_abs)?); + self.name = Some(parse_ident_segment(rhs, rhs_abs)?); return Ok(()); } @@ -351,7 +388,8 @@ impl<'a> AuthoringParser<'a> { if lhs_trimmed == "outputs" { self.declared_outputs_span = Some(span); - for ident in parse_ident_list(rhs, rhs_abs)? { + for ident in parse_output_label_list(rhs, rhs_abs)? { + self.explicit_output_order.push(ident.text.clone()); self.declared_outputs.insert(ident.text.clone()); self.explicit_outputs.insert(ident.text, ident.span); } @@ -365,8 +403,15 @@ impl<'a> AuthoringParser<'a> { } if lhs_trimmed == "kernel" { - let kernel = parse_ident_segment(rhs, rhs_abs)?; - self.analytical = Some(AnalyticalBlock { span, kernel }); + return Err(ParseError::new( + "`kernel = ...` has been renamed to `structure = ...`", + span, + )); + } + + if lhs_trimmed == "structure" { + let structure = parse_ident_segment(rhs, rhs_abs)?; + self.analytical = Some(AnalyticalBlock { span, structure }); return Ok(()); } @@ -386,7 +431,20 @@ impl<'a> AuthoringParser<'a> { return self.parse_call_assignment(call, rhs, rhs_abs, span); } - let target = parse_ident_segment(lhs, lhs_abs)?; + let target = match parse_ident_segment(lhs, lhs_abs) { + Ok(target) => target, + Err(error) => { + if self.declared_outputs_span.is_none() { + return Err(error); + } + + let target = parse_output_label_segment(lhs, lhs_abs)?; + if !self.declared_outputs.contains(&target.text) { + return Err(self.undeclared_output_error(&target.text, target.span)); + } + target + } + }; let rhs = parse_surface_rhs(rhs, rhs_abs)?; let stmt = build_assignment_statement( AssignTarget { @@ -427,16 +485,17 @@ impl<'a> AuthoringParser<'a> { } }; - let input = parse_ident_segment(call.argument, call.argument_start)?; + let input = parse_route_label_segment(call.argument, call.argument_start)?; + let route_name = input.text.clone(); let destination = parse_place_at(rhs, line_start + arrow + 2)?; - if self.routes.contains_key(&input.text) { + if self.routes.contains_key(&route_name) { return Err(ParseError::new( format!("duplicate route `{}`", input.text), input.span, )); } self.routes.insert( - input.text.clone(), + route_name.clone(), SurfaceRoute { input, destination, @@ -444,6 +503,7 @@ impl<'a> AuthoringParser<'a> { span, }, ); + self.route_order.push(route_name); Ok(()) } @@ -456,7 +516,7 @@ impl<'a> AuthoringParser<'a> { ) -> Result<(), ParseError> { match call.callee.text.as_str() { "lag" | "fa" => { - let route_name = parse_ident_segment(call.argument, call.argument_start)?; + let route_name = parse_route_label_segment(call.argument, call.argument_start)?; let value = parse_expr_at(rhs, rhs_abs)?; let property_name = match call.callee.text.as_str() { "lag" => "lag", @@ -523,7 +583,7 @@ impl<'a> AuthoringParser<'a> { self.init_statements.push(stmt); } "out" => { - let output = parse_ident_segment(call.argument, call.argument_start)?; + let output = parse_output_label_segment(call.argument, call.argument_start)?; self.validate_output_target(&output)?; self.declared_outputs.insert(output.text.clone()); self.note_output_assignment(&output); @@ -545,7 +605,7 @@ impl<'a> AuthoringParser<'a> { } other => { return Err(ParseError::new( - format!("unsupported authoring equation target `{other}`"), + format!("unsupported equation target `{other}`"), call.callee.span, )) } @@ -569,19 +629,17 @@ impl<'a> AuthoringParser<'a> { .unwrap_or(module_span); if matches!(kind, ModelKind::Analytical) - && (!self.diffusion_statements.is_empty() - || self.particles.is_some() - || !self.init_statements.is_empty()) + && (!self.diffusion_statements.is_empty() || self.particles.is_some()) { return Err(ParseError::new( - "analytical authoring models cannot declare particles, init, or noise equations", + "analytical models cannot declare particles or noise equations", kind_span, )); } if matches!(kind, ModelKind::Ode) && !self.diffusion_statements.is_empty() { return Err(ParseError::new( - "ODE authoring models cannot declare `noise(...)` equations", + "ODE models cannot declare `noise(...)` equations", self.diffusion_statements[0].span, )); } @@ -589,7 +647,7 @@ impl<'a> AuthoringParser<'a> { if matches!(kind, ModelKind::Sde) { if let Some(analytical) = &self.analytical { return Err(ParseError::new( - "SDE authoring models cannot declare an analytical kernel", + "SDE models cannot declare an analytical structure", analytical.span, )); } @@ -743,7 +801,7 @@ fn inject_infusion_rates( let rate_expr = Expr { span: surface_route.span, kind: ExprKind::Call { - callee: Ident::new("rate", surface_route.input.span), + callee: Ident::new(RATE_FUNCTION_NAME, surface_route.input.span), args: vec![Expr { span: surface_route.input.span, kind: ExprKind::Name(surface_route.input.clone()), @@ -783,13 +841,13 @@ fn parse_call_head<'a>(src: &'a str, abs_start: usize) -> Result Result, ParseErro .collect() } +fn parse_output_label_list(src: &str, abs_start: usize) -> Result, ParseError> { + split_top_level(src, ',') + .into_iter() + .map(|(segment, start)| parse_output_label_segment(segment, abs_start + start)) + .collect() +} + fn parse_covariates_list(src: &str, abs_start: usize) -> Result, ParseError> { let mut covariates = Vec::new(); for (segment, start) in split_top_level(src, ',') { @@ -880,6 +945,143 @@ fn parse_ident_segment(src: &str, abs_start: usize) -> Result )) } +fn parse_output_label_segment(src: &str, abs_start: usize) -> Result { + parse_label_segment(src, abs_start, LabelKind::Output) +} + +fn parse_route_label_segment(src: &str, abs_start: usize) -> Result { + parse_label_segment(src, abs_start, LabelKind::Route) +} + +fn parse_label_segment(src: &str, abs_start: usize, kind: LabelKind) -> Result { + let trimmed = src.trim(); + let leading = src.len() - src.trim_start().len(); + let span = Span::new(abs_start + leading, abs_start + leading + trimmed.len()); + if trimmed.is_empty() { + return Err(ParseError::new( + format!("expected {}", kind.expected()), + Span::new(abs_start, abs_start + src.len()), + )); + } + if !is_valid_output_label(trimmed) { + return Err(ParseError::new( + format!("expected {}, found `{trimmed}`", kind.expected()), + span, + )); + } + + if let Some(suffix) = bare_numeric_label(trimmed) { + let replacement = kind.canonical_label(suffix); + return Err(ParseError::new( + format!( + "bare numeric {} labels are not allowed in the DSL; use `{replacement}` instead", + kind.noun() + ), + span, + ) + .with_help(format!( + "numeric {} labels must use the `{}` prefix in authored DSL", + kind.noun(), + kind.prefix_pattern() + )) + .with_suggestion(DiagnosticSuggestion { + message: format!("use `{replacement}`"), + edits: vec![TextEdit { span, replacement }], + applicability: Applicability::Always, + })); + } + + if let Some(suffix) = canonical_numeric_suffix(trimmed, kind.wrong_prefix()) { + let replacement = kind.canonical_label(suffix); + return Err(ParseError::new( + format!( + "`{trimmed}` is {} label and cannot be used as {}; use `{replacement}` here", + kind.wrong_kind_phrase(), + kind.noun_phrase() + ), + span, + ) + .with_help(format!( + "numeric {} labels use the `{}` prefix", + kind.noun(), + kind.prefix_pattern() + )) + .with_suggestion(DiagnosticSuggestion { + message: format!("use `{replacement}`"), + edits: vec![TextEdit { span, replacement }], + applicability: Applicability::Always, + })); + } + + Ok(Ident::new(trimmed, span)) +} + +#[derive(Clone, Copy)] +enum LabelKind { + Route, + Output, +} + +impl LabelKind { + fn expected(self) -> &'static str { + match self { + Self::Route => "route label", + Self::Output => "output label", + } + } + + fn noun(self) -> &'static str { + match self { + Self::Route => "route", + Self::Output => "output", + } + } + + fn noun_phrase(self) -> &'static str { + match self { + Self::Route => "a route", + Self::Output => "an output", + } + } + + fn wrong_kind_phrase(self) -> &'static str { + match self { + Self::Route => "an output", + Self::Output => "a route", + } + } + + fn canonical_label(self, suffix: &str) -> String { + match self { + Self::Route => format!("{NUMERIC_ROUTE_PREFIX}{suffix}"), + Self::Output => format!("{NUMERIC_OUTPUT_PREFIX}{suffix}"), + } + } + + fn wrong_prefix(self) -> &'static str { + match self { + Self::Route => NUMERIC_OUTPUT_PREFIX, + Self::Output => NUMERIC_ROUTE_PREFIX, + } + } + + fn prefix_pattern(self) -> &'static str { + match self { + Self::Route => "input_", + Self::Output => "outeq_", + } + } +} + +fn bare_numeric_label(src: &str) -> Option<&str> { + (!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src) +} + +fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> { + let suffix = src.strip_prefix(prefix)?; + (!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix) +} + fn parse_place_at(src: &str, abs_start: usize) -> Result { let mut place = parse_place_fragment(src).map_err(|error| error.shifted(abs_start))?; shift_place(&mut place, abs_start); @@ -908,14 +1110,14 @@ fn parse_if_rhs(src: &str, abs_start: usize) -> Result { let rest_abs = abs_start + 2 + rest_leading; if !rest.starts_with('(') { return Err(ParseError::new( - "expected `(` after `if` in authoring conditional expression", + "expected `(` after `if` in conditional expression", Span::new(rest_abs, rest_abs + rest.len().min(1)), )); } let close = find_matching_delimiter(rest, '(', ')').ok_or_else(|| { ParseError::new( - "unclosed `(` in authoring conditional expression", + "unclosed `(` in conditional expression", Span::new(rest_abs, rest_abs + rest.len()), ) })?; @@ -925,7 +1127,7 @@ fn parse_if_rhs(src: &str, abs_start: usize) -> Result { let remaining_abs = rest_abs + close + 1; let else_index = find_top_level_keyword(remaining, "else").ok_or_else(|| { ParseError::new( - "expected `else` in authoring conditional expression", + "expected `else` in conditional expression", Span::new(remaining_abs, remaining_abs + remaining.len()), ) })?; @@ -1317,6 +1519,10 @@ fn is_valid_ident(src: &str) -> bool { chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') } +fn is_valid_output_label(src: &str) -> bool { + is_valid_ident(src) || src.chars().all(|ch| ch.is_ascii_digit()) +} + fn is_ident_byte(byte: u8) -> bool { (byte as char).is_ascii_alphanumeric() || byte == b'_' } @@ -1345,6 +1551,16 @@ fn join_covariate_spans(items: &[CovariateDecl]) -> Span { .unwrap_or_else(|| Span::empty(0)) } +fn output_statement_name(statement: &Stmt) -> Option<&str> { + match &statement.kind { + StmtKind::Assign(assign) => match &assign.target.kind { + AssignTargetKind::Name(name) => Some(name.text.as_str()), + _ => None, + }, + _ => None, + } +} + fn join_state_spans(items: &[StateDecl]) -> Span { items .iter() diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 26e0a3cb..2384432a 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use crate::{ AnalyticalKernel, ConstValue, CovariateInterpolation, Diagnostic, DiagnosticPhase, - DiagnosticReport, MathIntrinsic, ModelKind, RoutePropertyKind, Span, Symbol, SymbolId, - SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr, + DiagnosticReport, MathIntrinsic, ModelKind, RouteKind, RoutePropertyKind, Span, Symbol, + SymbolId, SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr, TypedExprKind, TypedModel, TypedModule, TypedRangeExpr, TypedStatePlace, TypedStatementBlock, TypedStmt, TypedStmtKind, TypedUnaryOp, ValueType, DSL_LOWERING_GENERIC, }; @@ -98,7 +98,9 @@ pub struct ExecutionState { pub struct ExecutionRoute { pub symbol: SymbolId, pub name: String, + pub declaration_index: usize, pub index: usize, + pub kind: Option, pub destination: RouteDestination, pub has_lag: bool, pub has_bioavailability: bool, @@ -349,7 +351,7 @@ pub enum ExecutionLoad { State(ExecutionStateRef), Derived(usize), Local(usize), - RouteInput(usize), + RouteInput { route: SymbolId, index: usize }, } #[derive(Debug, Clone, PartialEq)] @@ -531,33 +533,67 @@ impl<'a> ExecutionLowerer<'a> { next_state_offset += len; } + let uses_authoring_route_kinds = + !model.routes.is_empty() && model.routes.iter().all(|route| route.kind.is_some()); let mut route_slots = BTreeMap::new(); - let routes = model - .routes - .iter() - .enumerate() - .map(|(index, route)| { - let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?; - route_slots.insert(route.symbol, index); - let destination = - lower_route_destination(&symbol_map, &state_slots, &route.destination)?; - Ok(ExecutionRoute { - symbol: route.symbol, - name: symbol.name.clone(), - index, - destination, - has_lag: route - .properties - .iter() - .any(|property| property.kind == RoutePropertyKind::Lag), - has_bioavailability: route - .properties - .iter() - .any(|property| property.kind == RoutePropertyKind::Bioavailability), - span: route.span, - }) - }) - .collect::, LoweringError>>()?; + let mut routes = Vec::with_capacity(model.routes.len()); + let mut next_bolus_index = 0usize; + let mut next_infusion_index = 0usize; + for (declaration_index, route) in model.routes.iter().enumerate() { + let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?; + if route.kind == Some(RouteKind::Infusion) { + if let Some(property) = route.properties.first() { + let label = match property.kind { + RoutePropertyKind::Lag => "lag", + RoutePropertyKind::Bioavailability => "bioavailability", + }; + return Err(LoweringError::new( + format!( + "DSL authoring does not allow `{label}` on infusion route `{}`", + symbol.name + ), + property.span, + ) + .with_note("lag and bioavailability are bolus-only route properties")); + } + } + let index = if uses_authoring_route_kinds { + match route.kind.expect("authoring routes must preserve kind") { + RouteKind::Bolus => { + let index = next_bolus_index; + next_bolus_index += 1; + index + } + RouteKind::Infusion => { + let index = next_infusion_index; + next_infusion_index += 1; + index + } + } + } else { + declaration_index + }; + route_slots.insert(route.symbol, index); + let destination = + lower_route_destination(&symbol_map, &state_slots, &route.destination)?; + routes.push(ExecutionRoute { + symbol: route.symbol, + name: symbol.name.clone(), + declaration_index, + index, + kind: route.kind, + destination, + has_lag: route + .properties + .iter() + .any(|property| property.kind == RoutePropertyKind::Lag), + has_bioavailability: route + .properties + .iter() + .any(|property| property.kind == RoutePropertyKind::Bioavailability), + span: route.span, + }); + } let mut derived_slots = BTreeMap::new(); let derived = model @@ -607,7 +643,7 @@ impl<'a> ExecutionLowerer<'a> { analytical: model .analytical .as_ref() - .map(|analytical| analytical.kernel), + .map(|analytical| analytical.structure), }, symbol_map, parameter_slots, @@ -653,7 +689,7 @@ impl<'a> ExecutionLowerer<'a> { kernels.push(ExecutionKernel { role: KernelRole::Analytical, signature: signature_for(KernelRole::Analytical), - implementation: KernelImplementation::AnalyticalBuiltin(analytical.kernel), + implementation: KernelImplementation::AnalyticalBuiltin(analytical.structure), span: analytical.span, }); } @@ -745,7 +781,13 @@ impl<'a> ExecutionLowerer<'a> { }, route_buffer: DenseBufferLayout { kind: BufferKind::Routes, - len: self.metadata.routes.len(), + len: self + .metadata + .routes + .iter() + .map(|route| route.index + 1) + .max() + .unwrap_or(0), slots: self .metadata .routes @@ -858,7 +900,39 @@ impl<'a> ExecutionLowerer<'a> { let mut statements = Vec::with_capacity(self.model.routes.len()); let mut locals = LocalLowering::default(); + let default_value = match property_kind { + RoutePropertyKind::Lag => literal_real(0.0, self.model.span), + RoutePropertyKind::Bioavailability => literal_real(1.0, self.model.span), + }; + let route_len = self + .metadata + .routes + .iter() + .map(|route| route.index + 1) + .max() + .unwrap_or(0); + for route_index in 0..route_len { + let target_kind = match property_kind { + RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index), + RoutePropertyKind::Bioavailability => { + ExecutionTargetKind::RouteBioavailability(route_index) + } + }; + statements.push(ExecutionStmt { + kind: ExecutionStmtKind::Assign(ExecutionAssignStmt { + target: ExecutionTarget { + kind: target_kind, + span: self.model.span, + }, + value: default_value.clone(), + }), + span: self.model.span, + }); + } for route in &self.model.routes { + if route.kind == Some(RouteKind::Infusion) { + continue; + } let route_name = self.symbol_name(route.symbol)?.to_string(); let route_index = *self.route_slots.get(&route.symbol).ok_or_else(|| { LoweringError::new( @@ -872,8 +946,7 @@ impl<'a> ExecutionLowerer<'a> { .find(|property| property.kind == property_kind) { Some(property) => self.lower_expr(&property.value, &mut locals)?, - None if property_kind == RoutePropertyKind::Lag => literal_real(0.0, route.span), - None => literal_real(1.0, route.span), + None => continue, }; let target_kind = match property_kind { RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index), @@ -1098,7 +1171,10 @@ impl<'a> ExecutionLowerer<'a> { expr.span, ) })?; - ExecutionExprKind::Load(ExecutionLoad::RouteInput(route_index)) + ExecutionExprKind::Load(ExecutionLoad::RouteInput { + route: *route, + index: route_index, + }) } }, }; @@ -1439,6 +1515,163 @@ mod tests { ); } + #[test] + fn authoring_routes_share_input_indices_by_kind_local_ordinal() { + let src = r#"name = shared_authoring +kind = ode + +params = ka, ke, v, tlag, f_oral +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let lowered = crate::lower_typed_model(&typed).expect("authoring model lowers"); + + assert_eq!(lowered.abi.route_buffer.len, 1); + assert_eq!(lowered.metadata.routes.len(), 2); + assert_eq!(lowered.metadata.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(lowered.metadata.routes[1].kind, Some(RouteKind::Infusion)); + assert_eq!(lowered.metadata.routes[0].declaration_index, 0); + assert_eq!(lowered.metadata.routes[1].declaration_index, 1); + assert_eq!(lowered.metadata.routes[0].index, 0); + assert_eq!(lowered.metadata.routes[1].index, 0); + assert!(lowered.metadata.routes[0].has_lag); + assert!(lowered.metadata.routes[0].has_bioavailability); + assert!(!lowered.metadata.routes[1].has_lag); + assert!(!lowered.metadata.routes[1].has_bioavailability); + } + + #[test] + fn canonical_numeric_channel_names_flow_into_execution_metadata_and_abi() { + let src = r#"name = canonical_numeric_channels +kind = ode + +params = ke, v +states = depot, central +outputs = cp, outeq_2 + +bolus(input_10) -> depot +infusion(iv) -> central + +dx(depot) = -ke * depot +dx(central) = rate(input_10) - ke * central + +out(cp) = central / v +out(outeq_2) = depot / v +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let lowered = crate::lower_typed_model(&typed).expect("authoring model lowers"); + + assert_eq!( + lowered + .metadata + .routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["input_10", "iv"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "outeq_2"] + ); + assert_eq!( + lowered + .abi + .route_buffer + .slots + .iter() + .map(|slot| slot.name.as_str()) + .collect::>(), + vec!["input_10", "iv"] + ); + assert_eq!( + lowered + .abi + .output_buffer + .slots + .iter() + .map(|slot| slot.name.as_str()) + .collect::>(), + vec!["cp", "outeq_2"] + ); + } + + #[test] + fn authoring_routes_reject_infusion_lag_properties() { + let src = r#"name = invalid_infusion_lag +kind = ode + +params = ke, v, tlag +states = central +outputs = cp + +infusion(iv) -> central +lag(iv) = tlag + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let error = crate::lower_typed_model(&typed) + .err() + .expect("infusion lag should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `lag` on infusion route `iv`")); + } + + #[test] + fn authoring_routes_reject_infusion_bioavailability_properties() { + let src = r#"name = invalid_infusion_fa +kind = ode + +params = ke, v, f_iv +states = central +outputs = cp + +infusion(iv) -> central +fa(iv) = f_iv + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let error = crate::lower_typed_model(&typed) + .err() + .expect("infusion bioavailability should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `bioavailability` on infusion route `iv`")); + } + #[test] fn flattens_array_states_and_preserves_loop_structure() { let execution = structured_block_execution(); @@ -1538,8 +1771,8 @@ mod tests { panic!("expected statement bioavailability kernel"); }; - assert_eq!(lag_program.body.statements.len(), 2); - assert_eq!(bio_program.body.statements.len(), 2); + assert_eq!(lag_program.body.statements.len(), 3); + assert_eq!(bio_program.body.statements.len(), 3); assert!(matches!( lag_program.body.statements[1].kind, ExecutionStmtKind::Assign(ExecutionAssignStmt { diff --git a/pharmsol-dsl/src/ir.rs b/pharmsol-dsl/src/ir.rs index d1c54c90..5998431c 100644 --- a/pharmsol-dsl/src/ir.rs +++ b/pharmsol-dsl/src/ir.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::{ModelKind, Span}; +use crate::{ModelKind, RouteKind, Span}; pub type SymbolId = usize; @@ -145,6 +145,7 @@ pub struct TypedState { #[derive(Debug, Clone, PartialEq)] pub struct TypedRoute { pub symbol: SymbolId, + pub kind: Option, pub destination: TypedStatePlace, pub properties: Vec, pub span: Span, @@ -165,7 +166,7 @@ pub enum RoutePropertyKind { #[derive(Debug, Clone, PartialEq)] pub struct TypedAnalytical { - pub kernel: AnalyticalKernel, + pub structure: AnalyticalKernel, pub span: Span, } diff --git a/pharmsol-dsl/src/lib.rs b/pharmsol-dsl/src/lib.rs index 83bd9a7c..261efabf 100644 --- a/pharmsol-dsl/src/lib.rs +++ b/pharmsol-dsl/src/lib.rs @@ -1,8 +1,72 @@ //! Backend-neutral frontend crate for the pharmsol DSL. //! -//! This crate owns parsing, diagnostics, authoring desugaring, semantic -//! analysis, and execution lowering for DSL modules. `pharmsol::dsl` -//! re-exports the stable runtime-facing surface in the main crate. +//! Use this crate when you need the DSL frontend as an engineering API: parse +//! model source, inspect diagnostics, analyze names and types, and lower a +//! validated model into the execution representation that backends consume. +//! +//! Do not use this crate when you already know you want JIT compilation, +//! native AoT artifacts, WASM artifacts, or `Subject`-based prediction +//! helpers. Those runtime-facing workflows stay in the main `pharmsol` crate +//! under `pharmsol::dsl`. +//! +//! Main entrypoints: +//! +//! - [`parse_model`] and [`parse_module`] for turning DSL source text into the +//! syntax tree in [`ast`]. +//! - [`analyze_model`] and [`analyze_module`] for semantic validation and the +//! typed IR in [`ir`]. +//! - [`lower_typed_model`] and [`lower_typed_module`] for lowering typed models +//! into the execution representation in [`execution`]. +//! +//! The frontend pipeline is intentionally simple: +//! +//! 1. Parse source text into syntax. +//! 2. Analyze the syntax into a typed model. +//! 3. Lower the typed model into an [`ExecutionModel`] or [`ExecutionModule`]. +//! +//! This crate accepts both canonical `model { ... }` source and the authoring +//! shorthand used by the `pharmsol` examples. The returned diagnostics carry +//! source spans, rendered messages, and structured data for editor or UI use. +//! +//! Main modules: +//! +//! - [`ast`] for syntax-level nodes. +//! - [`diagnostic`] for spans, diagnostic codes, and rendered reports. +//! - [`ir`] for the typed intermediate representation. +//! - [`execution`] for the lowered execution model shared by JIT, AoT, and +//! WASM backends. +//! +//! Smallest parse-analyze-lower example: +//! +//! ```rust +//! use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model}; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let syntax = parse_model(source).expect("model parses"); +//! let typed = analyze_model(&syntax).expect("model analyzes"); +//! let execution = lower_typed_model(&typed).expect("model lowers"); +//! +//! assert_eq!(execution.name, "bimodal_ke"); +//! assert_eq!(execution.metadata.routes.len(), 1); +//! assert_eq!(execution.metadata.outputs.len(), 1); +//! ``` +//! +//! If you are building an authoring tool, custom compiler, or diagnostics UI, +//! stay in this crate. If you want a complete source-to-runtime workflow, +//! switch to `pharmsol::dsl` in the main crate. pub mod ast; mod authoring; @@ -15,6 +79,12 @@ mod semantic; #[cfg(test)] mod test_fixtures; +/// Canonical prefix for numeric route labels such as `input_1`. +pub const NUMERIC_ROUTE_PREFIX: &str = "input_"; +/// Canonical prefix for numeric output labels such as `outeq_1`. +pub const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; +pub(crate) const RATE_FUNCTION_NAME: &str = "rate"; + pub use ast::*; pub use diagnostic::*; pub use execution::{ diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index 7af6c681..98c6b0a4 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -106,12 +106,18 @@ struct Parser { #[derive(Clone, Copy)] enum LayoutBoundary { ModelItem, - Statement, + Statement(StatementContext), Binding, IdentItem, RouteDecl, } +#[derive(Clone, Copy, PartialEq, Eq)] +enum StatementContext { + Standard, + Outputs, +} + impl Parser { fn new(src: &str) -> Result { Ok(Self::from_tokens(lex(src)?, src.len())) @@ -557,7 +563,7 @@ impl Parser { } fn parse_route_decl(&mut self) -> Result { - let input = self.parse_ident()?; + let input = self.parse_label_name("route label")?; let arrow = self.expect_simple(|kind| matches!(kind, TokenKind::Arrow), "`->`")?; self.ensure_not_layout_boundary( arrow.span, @@ -607,6 +613,7 @@ impl Parser { Ok(RouteDecl { input: input.clone(), destination, + kind: None, properties, span: input.span.join(end_span), }) @@ -616,19 +623,19 @@ impl Parser { let start = self.bump().unwrap().span; let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; - let kernel_name = self.parse_ident()?; - if kernel_name.text != "kernel" { + let structure_name = self.parse_ident()?; + if structure_name.text != "structure" { return Err(ParseError::new( format!( - "expected `kernel = ` inside analytical block, found `{}`", - kernel_name.text + "expected `structure = ` inside analytical block, found `{}`", + structure_name.text ), - kernel_name.span, + structure_name.span, )); } let eq = self.expect_simple(|kind| matches!(kind, TokenKind::Eq), "`=`")?; - let kernel = self.parse_continuation_ident_after(&eq, "kernel identifier")?; + let structure = self.parse_continuation_ident_after(&eq, "structure identifier")?; self.consume_separators(); let end = self.expect_closing( |kind| matches!(kind, TokenKind::RBrace), @@ -637,7 +644,7 @@ impl Parser { "`analytical` block", )?; Ok(AnalyticalBlock { - kernel, + structure, span: start.join(end.span), }) } @@ -654,8 +661,13 @@ impl Parser { fn parse_statement_block(&mut self, name: &str) -> Result { let start = self.bump().unwrap().span; let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = if name == "outputs" { + StatementContext::Outputs + } else { + StatementContext::Standard + }; let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -789,8 +801,9 @@ impl Parser { fn parse_stmt_body(&mut self) -> Result, ParseError> { let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = self.current_statement_context(); let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -853,7 +866,11 @@ impl Parser { } fn parse_assign_target(&mut self) -> Result { - let name = self.parse_ident()?; + let name = if matches!(self.current_statement_context(), StatementContext::Outputs) { + self.parse_output_target_name()? + } else { + self.parse_ident()? + }; let mut span = name.span; let kind = if let Some(open) = self.take_if(|kind| matches!(kind, TokenKind::LParen)) { let args = self.parse_expr_list(&open, TokenKindMatcher::RPAREN)?; @@ -884,6 +901,34 @@ impl Parser { Ok(AssignTarget { kind, span }) } + fn parse_output_target_name(&mut self) -> Result { + self.parse_label_name("output label") + } + + fn parse_label_name(&mut self, expected: &str) -> Result { + let token = self.bump().ok_or_else(|| { + ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)) + })?; + match token.kind { + TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), + TokenKind::Number(value) + if value.is_finite() + && value >= 0.0 + && value.fract() == 0.0 + && value <= usize::MAX as f64 => + { + Ok(Ident::new((value as usize).to_string(), token.span)) + } + other => Err(ParseError::new( + format!( + "expected {expected} identifier or non-negative integer, found {}", + other.describe() + ), + token.span, + )), + } + } + fn parse_ident(&mut self) -> Result { let token = self .bump() @@ -1319,9 +1364,12 @@ impl Parser { | TokenKind::Diffusion | TokenKind::Particles ), - LayoutBoundary::Statement => match &token.kind { + LayoutBoundary::Statement(context) => match &token.kind { TokenKind::If | TokenKind::For | TokenKind::Let => true, TokenKind::Ident(_) => self.line_starts_assignment_target(index), + TokenKind::Number(_) if matches!(context, StatementContext::Outputs) => { + self.line_starts_numeric_output_assignment(index) + } _ => false, }, LayoutBoundary::Binding => self.line_starts_named_assignment(index), @@ -1378,6 +1426,26 @@ impl Parser { } } + fn line_starts_numeric_output_assignment(&self, index: usize) -> bool { + matches!( + self.tokens.get(index).map(|token| &token.kind), + Some(TokenKind::Number(_)) + ) && self + .next_same_line_index(index) + .is_some_and(|next| matches!(self.tokens[next].kind, TokenKind::Eq)) + } + + fn current_statement_context(&self) -> StatementContext { + self.layout_boundaries + .iter() + .rev() + .find_map(|boundary| match boundary { + LayoutBoundary::Statement(context) => Some(*context), + _ => None, + }) + .unwrap_or(StatementContext::Standard) + } + fn next_same_line_index(&self, index: usize) -> Option { let next = index + 1; let token = self.tokens.get(next)?; @@ -1412,7 +1480,7 @@ impl Parser { fn current_boundary_label(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "next model item starts here", - Some(LayoutBoundary::Statement) => "next statement starts here", + Some(LayoutBoundary::Statement(_)) => "next statement starts here", Some(LayoutBoundary::Binding) => "next binding starts here", Some(LayoutBoundary::IdentItem) => "next declaration starts here", Some(LayoutBoundary::RouteDecl) => "next route starts here", @@ -1423,7 +1491,7 @@ impl Parser { fn current_boundary_subject(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "model item", - Some(LayoutBoundary::Statement) => "statement", + Some(LayoutBoundary::Statement(_)) => "statement", Some(LayoutBoundary::Binding) => "binding", Some(LayoutBoundary::IdentItem) => "declaration", Some(LayoutBoundary::RouteDecl) => "route", @@ -1635,14 +1703,14 @@ out(cp) = gut ~ continuous() #[test] fn authoring_output_annotation_is_optional() { let annotated = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 out(cp) = central ~ continuous() "#; let plain = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 @@ -1658,14 +1726,14 @@ out(cp) = central #[test] fn authoring_dx_and_ddt_lower_equivalently() { let dx_src = r#" -model = derivative_alias +name = derivative_alias kind = ode states = central dx(central) = -ke * central out(cp) = central "#; let ddt_src = r#" -model = derivative_alias +name = derivative_alias kind = ode states = central ddt(central) = -ke * central @@ -1681,7 +1749,7 @@ out(cp) = central #[test] fn authoring_rejects_out_target_not_in_declared_outputs() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index 9f46500a..3f454983 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -8,7 +8,7 @@ use crate::diagnostic::{ TextEdit, DSL_SEMANTIC_GENERIC, }; use crate::ir::*; -use crate::ModelKind; +use crate::{ModelKind, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME}; const RESERVED_NAMES: &[&str] = &[ "abs", @@ -29,7 +29,7 @@ const RESERVED_NAMES: &[&str] = &[ "min", "noise", "pow", - "rate", + RATE_FUNCTION_NAME, "round", "sin", "cos", @@ -37,8 +37,6 @@ const RESERVED_NAMES: &[&str] = &[ "sqrt", ]; -const RATE_FUNCTION_NAME: &str = "rate"; - #[derive(Default)] struct SemanticAssist { context_labels: Vec<(Span, String)>, @@ -345,29 +343,30 @@ impl<'a> Analyzer<'a> { }; let analytical = if let Some(block) = sections.analytical { - let kernel = AnalyticalKernel::from_name(&block.kernel.text).ok_or_else(|| { - SemanticError::new( - format!("unknown analytical kernel `{}`", block.kernel.text), - block.kernel.span, - ) - })?; + let structure = + AnalyticalKernel::from_name(&block.structure.text).ok_or_else(|| { + SemanticError::new( + format!("unknown analytical structure `{}`", block.structure.text), + block.structure.span, + ) + })?; let state_components = states .iter() .map(|state| state.size.unwrap_or(1)) .sum::(); - if state_components != kernel.state_count() { + if state_components != structure.state_count() { return Err(SemanticError::new( format!( - "analytical kernel `{}` expects {} state value(s), but model declares {}", - block.kernel.text, - kernel.state_count(), + "analytical structure `{}` expects {} state value(s), but model declares {}", + block.structure.text, + structure.state_count(), state_components ), - block.kernel.span, + block.structure.span, )); } Some(TypedAnalytical { - kernel, + structure, span: block.span, }) } else { @@ -571,6 +570,7 @@ impl<'a> Analyzer<'a> { let mut routes = Vec::new(); if let Some(block) = block { for route in &block.routes { + self.validate_route_label_name(&route.input)?; let id = self.insert_global_symbol( &route.input.text, SymbolKind::Route, @@ -624,6 +624,7 @@ impl<'a> Analyzer<'a> { } routes.push(TypedRoute { symbol: id, + kind: route.kind, destination, properties, span: route.span, @@ -647,6 +648,9 @@ impl<'a> Analyzer<'a> { collect_bare_assignment_names(statements, &mut seen, &mut collected_idents); let mut symbols = Vec::new(); for ident in collected_idents { + if matches!(kind, SymbolKind::Output) { + self.validate_output_label_name(&ident)?; + } let id = self.insert_global_symbol( &ident.text, kind, @@ -1306,7 +1310,7 @@ impl<'a> Analyzer<'a> { span: Span, env: &BlockEnv, ) -> Result { - if callee.text == "rate" { + if callee.text == RATE_FUNCTION_NAME { if args.len() != 1 { return Err(SemanticError::new( format!( @@ -1316,12 +1320,18 @@ impl<'a> Analyzer<'a> { callee.span, )); } + if let syntax::ExprKind::Number(value) = &args[0].kind { + if let Some(suffix) = numeric_label_literal_suffix(*value) { + return Err(self.bare_numeric_route_error(args[0].span, &suffix)); + } + } let syntax::ExprKind::Name(route_name) = &args[0].kind else { return Err(SemanticError::new( "`rate` expects a route identifier argument", args[0].span, )); }; + self.validate_route_label_name(route_name)?; let route = self .globals .routes @@ -1561,7 +1571,7 @@ impl<'a> Analyzer<'a> { }) } syntax::ExprKind::Call { callee, args } => { - if callee.text == "rate" { + if callee.text == RATE_FUNCTION_NAME { return Err(SemanticError::new( "`rate(...)` cannot appear in a compile-time expression", callee.span, @@ -1615,29 +1625,32 @@ impl<'a> Analyzer<'a> { span, ))); } - if let Some(existing) = self.globals.all_names.get(name) { - return Err(SemanticAssist::default() - .context_label( - self.symbol_span(*existing), - self.symbol_declared_here(*existing), - ) - .help(format!( - "rename this declaration to a unique name such as `{}_2`", - name - )) - .replacement_suggestion( - span, - format!("{}_2", name), - format!("rename this declaration to `{}_2`", name), - Applicability::MaybeIncorrect, - ) - .apply(SemanticError::new( - format!( - "symbol name `{name}` collides with existing `{}`", - self.symbol_name(*existing) - ), - span, - ))); + if let Some(existing) = self.globals.all_names.get(name).copied() { + let existing_kind = self.symbols.get(existing).expect("valid symbol id").kind; + if !allows_route_output_name_overlap(existing_kind, kind) { + return Err(SemanticAssist::default() + .context_label( + self.symbol_span(existing), + self.symbol_declared_here(existing), + ) + .help(format!( + "rename this declaration to a unique name such as `{}_2`", + name + )) + .replacement_suggestion( + span, + format!("{}_2", name), + format!("rename this declaration to `{}_2`", name), + Applicability::MaybeIncorrect, + ) + .apply(SemanticError::new( + format!( + "symbol name `{name}` collides with existing `{}`", + self.symbol_name(existing) + ), + span, + ))); + } } let id = self.symbols.len(); self.symbols.push(PendingSymbol { @@ -1647,10 +1660,104 @@ impl<'a> Analyzer<'a> { ty, span, }); - self.globals.all_names.insert(name.to_string(), id); + self.globals.all_names.entry(name.to_string()).or_insert(id); Ok(id) } + fn validate_route_label_name(&self, label: &syntax::Ident) -> Result<(), SemanticError> { + if let Some(suffix) = bare_numeric_label(&label.text) { + return Err(self.bare_numeric_route_error(label.span, suffix)); + } + if let Some(suffix) = canonical_numeric_suffix(&label.text, NUMERIC_OUTPUT_PREFIX) { + return Err(self.wrong_prefix_route_error(label, suffix)); + } + Ok(()) + } + + fn validate_output_label_name(&self, label: &syntax::Ident) -> Result<(), SemanticError> { + if let Some(suffix) = bare_numeric_label(&label.text) { + return Err(self.bare_numeric_output_error(label.span, suffix)); + } + if let Some(suffix) = canonical_numeric_suffix(&label.text, NUMERIC_ROUTE_PREFIX) { + return Err(self.wrong_prefix_output_error(label, suffix)); + } + Ok(()) + } + + fn bare_numeric_route_error(&self, span: Span, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_ROUTE_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric route labels must use the `input_` form in authored DSL") + .replacement_suggestion( + span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "bare numeric route labels are not allowed in the DSL; use `{replacement}` instead" + ), + span, + )) + } + + fn bare_numeric_output_error(&self, span: Span, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_OUTPUT_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric output labels must use the `outeq_` form in authored DSL") + .replacement_suggestion( + span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "bare numeric output labels are not allowed in the DSL; use `{replacement}` instead" + ), + span, + )) + } + + fn wrong_prefix_route_error(&self, label: &syntax::Ident, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_ROUTE_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric route labels use the `input_` prefix") + .replacement_suggestion( + label.span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "`{}` is an output label and cannot be used as a route; use `{replacement}` here", + label.text + ), + label.span, + )) + } + + fn wrong_prefix_output_error(&self, label: &syntax::Ident, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_OUTPUT_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric output labels use the `outeq_` prefix") + .replacement_suggestion( + label.span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "`{}` is a route label and cannot be used as an output target; use `{replacement}` here", + label.text + ), + label.span, + )) + } + fn insert_local_symbol( &mut self, env: &mut BlockEnv, @@ -2130,6 +2237,27 @@ impl<'a> Analyzer<'a> { } } +fn allows_route_output_name_overlap(existing: SymbolKind, new: SymbolKind) -> bool { + matches!( + (existing, new), + (SymbolKind::Route, SymbolKind::Output) | (SymbolKind::Output, SymbolKind::Route) + ) +} + +fn bare_numeric_label(src: &str) -> Option<&str> { + (!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src) +} + +fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> { + let suffix = src.strip_prefix(prefix)?; + (!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix) +} + +fn numeric_label_literal_suffix(value: f64) -> Option { + (value.is_finite() && value >= 0.0 && value.fract() == 0.0 && value <= usize::MAX as f64) + .then(|| (value as usize).to_string()) +} + #[derive(Default)] struct Globals { all_names: BTreeMap, @@ -2651,6 +2779,7 @@ mod tests { use crate::test_fixtures::{ RECOMMENDED_STYLE_AUTHORING, RECOMMENDED_STYLE_CANONICAL, STRUCTURED_BLOCK_CORPUS, }; + use crate::RouteKind; use crate::{parse_model, parse_module}; #[test] @@ -2667,7 +2796,7 @@ mod tests { let analytical = &typed.models[2]; assert!(matches!( - analytical.analytical.as_ref().map(|value| value.kernel), + analytical.analytical.as_ref().map(|value| value.structure), Some(AnalyticalKernel::OneCompartmentWithAbsorption) )); @@ -2691,7 +2820,7 @@ mod tests { } #[test] - fn authoring_fixture_lowers_to_equivalent_typed_ir() { + fn authoring_fixture_preserves_route_kind_while_remaining_equivalent() { let authoring_surface = RECOMMENDED_STYLE_AUTHORING; let canonical = RECOMMENDED_STYLE_CANONICAL; @@ -2705,6 +2834,8 @@ mod tests { typed_model_signature(&authoring_typed), typed_model_signature(&canonical_typed) ); + assert_eq!(authoring_typed.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(canonical_typed.routes[0].kind, None); } #[test] @@ -2977,7 +3108,7 @@ model broken { lines.push(format!("particles:{:?}", model.particles)); lines.push(format!( "analytical:{:?}", - model.analytical.as_ref().map(|value| value.kernel) + model.analytical.as_ref().map(|value| value.structure) )); lines.push(format!( "derive:{}", diff --git a/pharmsol-dsl/src/test_fixtures.rs b/pharmsol-dsl/src/test_fixtures.rs index f26181e4..281a268e 100644 --- a/pharmsol-dsl/src/test_fixtures.rs +++ b/pharmsol-dsl/src/test_fixtures.rs @@ -83,7 +83,7 @@ model one_cmt_abs { oral -> depot } analytical { - kernel = one_compartment_with_absorption + structure = one_compartment_with_absorption } outputs { cp = central / v @@ -132,7 +132,7 @@ model vanco_sde { } "#; -pub(crate) const RECOMMENDED_STYLE_AUTHORING: &str = r#"model = recommended_style +pub(crate) const RECOMMENDED_STYLE_AUTHORING: &str = r#"name = recommended_style kind = ode params = ka, ke, v diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 941ffa77..335a7f86 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -1,16 +1,16 @@ -use pharmsol_dsl::{analyze_model, parse_model, parse_module}; +use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model, parse_module}; #[test] fn output_annotation_is_optional() { let annotated = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 out(cp) = central ~ continuous() "#; let plain = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 @@ -26,7 +26,7 @@ out(cp) = central #[test] fn dx_and_ddt_lower_equivalently() { let dx_src = r#" -model = derivative_alias +name = derivative_alias kind = ode params = ke states = central @@ -34,7 +34,7 @@ dx(central) = -ke * central out(cp) = central "#; let ddt_src = r#" -model = derivative_alias +name = derivative_alias kind = ode params = ke states = central @@ -51,7 +51,7 @@ out(cp) = central #[test] fn rejects_out_target_not_in_declared_outputs() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -95,7 +95,7 @@ out(cp) = central / v ~ continuous() #[test] fn rejects_out_target_not_in_declared_outputs_when_declared_later() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -122,7 +122,7 @@ ddt(central) = -ke * central #[test] fn rejects_declared_output_without_assignment() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -144,7 +144,7 @@ out(cp) = central / v #[test] fn rejects_unknown_output_annotation_name() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode states = central ddt(central) = 0 @@ -161,10 +161,329 @@ out(cp) = central ~ continous() ); } +#[test] +fn mixed_named_and_prefixed_numeric_output_labels_lower_and_round_trip() { + let src = r#" +name = mixed_output_labels +kind = ode +params = ke, v +states = central +outputs = cp, outeq_0, outeq_1 +infusion(iv) -> central +ddt(central) = -ke * central +out(cp) = central / v +out(outeq_0) = 2 * central / v +out(outeq_1) = 3 * central / v +"#; + + let module = parse_module(src).expect("mixed output labels should parse in authoring DSL"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(&model).expect("mixed output labels should analyze"); + let lowered = lower_typed_model(&typed).expect("mixed output labels should lower"); + + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "outeq_0", "outeq_1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.index) + .collect::>(), + vec![0, 1, 2] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered mixed-output model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + +#[test] +fn prefixed_numeric_route_and_output_labels_lower_and_round_trip() { + let src = r#" +name = prefixed_numeric_route_output_labels +kind = ode +params = ke, v +states = central +outputs = outeq_1 +infusion(input_1) -> central +ddt(central) = -ke * central +out(outeq_1) = central / v +"#; + + let module = parse_module(src).expect("prefixed numeric route/output labels should parse"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(model).expect("prefixed numeric route/output labels should analyze"); + let lowered = + lower_typed_model(&typed).expect("prefixed numeric route/output labels should lower"); + + assert_eq!( + lowered + .metadata + .routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["input_1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["outeq_1"] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered shared-label model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + +#[test] +fn rejects_authoring_bare_numeric_output_declarations() { + let src = r#" +name = numeric_outputs +kind = ode +states = central +outputs = 1, 2 +ddt(central) = 0 +out(1) = central +"#; + + let err = parse_model(src).expect_err("bare numeric output declarations must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric output labels are not allowed in the DSL; use `outeq_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `outeq_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_authoring_bare_numeric_route_labels() { + let src = r#" +name = numeric_routes +kind = ode +states = central +outputs = cp +infusion(1) -> central +ddt(central) = 0 +out(cp) = central +"#; + + let err = parse_model(src).expect_err("bare numeric route labels must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric route labels are not allowed in the DSL; use `input_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `input_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_structured_bare_numeric_output_targets() { + let src = r#" +model numeric_output_target { + kind ode + states { central } + outputs { + 1 = central + } +} +"#; + + let model = parse_model(src).expect("structured model parses"); + let err = analyze_model(&model).expect_err("bare numeric output target must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric output labels are not allowed in the DSL; use `outeq_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `outeq_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_structured_bare_numeric_route_labels() { + let src = r#" +model numeric_route_label { + kind ode + states { central } + routes { + 1 -> central + } + outputs { + cp = central + } +} +"#; + + let model = parse_model(src).expect("structured model parses"); + let err = analyze_model(&model).expect_err("bare numeric route label must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric route labels are not allowed in the DSL; use `input_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `input_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_rate_numeric_literals_with_prefixed_guidance() { + let src = r#" +model numeric_rate_arg { + kind ode + states { central } + routes { input_5 -> central } + dynamics { + ddt(central) = rate(5) + } + outputs { + cp = central + } +} +"#; + + let model = parse_model(src).expect("structured model parses"); + let err = analyze_model(&model).expect_err("bare numeric rate argument must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric route labels are not allowed in the DSL; use `input_5` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `input_5`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_wrong_prefix_labels_in_authored_dsl() { + let src = r#" +name = wrong_prefix_route +kind = ode +states = central +outputs = cp +infusion(outeq_1) -> central +ddt(central) = 0 +out(cp) = central +"#; + + let err = parse_model(src).expect_err("wrong-prefix route labels must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "`outeq_1` is an output label and cannot be used as a route; use `input_1` here" + ), + "{}", + rendered + ); + + let src = r#" +name = wrong_prefix_output +kind = ode +states = central +outputs = cp +infusion(iv) -> central +ddt(central) = 0 +out(input_1) = central +"#; + + let err = parse_model(src).expect_err("wrong-prefix output labels must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "`input_1` is a route label and cannot be used as an output; use `outeq_1` here" + ), + "{}", + rendered + ); +} + +#[test] +fn route_labels_still_collide_with_scalar_symbol_names() { + let src = r#" +name = route_state_collision +kind = ode +params = ke +states = central, iv +outputs = cp +infusion(iv) -> central +ddt(central) = -ke * central +ddt(iv) = 0 +out(cp) = central +"#; + + let model = parse_model(src).expect("route/state collision model parses"); + let err = analyze_model(&model).expect_err("route label should still collide with state name"); + let rendered = err.render(src); + + assert!( + rendered.contains("symbol name `iv` collides with existing `iv`"), + "{}", + rendered + ); +} + #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v diff --git a/pharmsol-macros/Cargo.toml b/pharmsol-macros/Cargo.toml index 291b888c..07e3688f 100644 --- a/pharmsol-macros/Cargo.toml +++ b/pharmsol-macros/Cargo.toml @@ -13,4 +13,4 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.106" quote = "1.0.45" -syn = { version = "2.0.117", features = ["full"] } +syn = { version = "2.0.117", features = ["full", "visit", "visit-mut"] } diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 0fa320a3..83007df2 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -4,11 +4,16 @@ //! `pharmsol` crate instead. use proc_macro::TokenStream; -use proc_macro2::TokenTree; +use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; +use std::collections::{HashMap, HashSet}; use syn::{ - parse::{Parse, ParseStream}, - ExprClosure, Ident, Pat, Token, + parse::{Parse, ParseStream, Parser}, + punctuated::Punctuated, + token, + visit::Visit, + visit_mut::VisitMut, + Expr, ExprClosure, Ident, Lit, LitInt, LitStr, Pat, Stmt, Token, }; // --------------------------------------------------------------------------- @@ -16,6 +21,12 @@ use syn::{ // --------------------------------------------------------------------------- struct OdeInput { + name: LitStr, + params: Vec, + covariates: Vec, + states: Vec, + outputs: Vec, + routes: Vec, diffeq: ExprClosure, lag: Option, fa: Option, @@ -23,45 +34,568 @@ struct OdeInput { out: ExprClosure, } +struct AnalyticalInput { + name: LitStr, + params: Vec, + covariates: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + structure: Ident, + sec: Option, + lag: Option, + fa: Option, + init: Option, + out: ExprClosure, +} + +struct SdeInput { + name: LitStr, + params: Vec, + covariates: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + particles: Expr, + drift: ExprClosure, + diffusion: ExprClosure, + lag: Option, + fa: Option, + init: Option, + out: ExprClosure, +} + +struct OdeRouteDecl { + kind: OdeRouteKind, + input: SymbolicIndex, + destination: Ident, +} + +#[derive(Clone, Copy)] +enum OdeRouteKind { + Bolus, + Infusion, +} + +struct AnalyticalKernelSpec { + runtime_path: TokenStream2, + metadata_kernel: TokenStream2, + parameter_arity: usize, + state_count: usize, +} + +struct RoutePropertyEntry { + route: SymbolicIndex, + value: Expr, +} + +#[derive(Clone)] +enum SymbolicIndex { + Ident(Ident), + Int(LitInt), +} + +impl SymbolicIndex { + fn name(&self) -> String { + match self { + Self::Ident(ident) => ident.to_string(), + Self::Int(lit) => lit.base10_digits().to_string(), + } + } + + fn ident(&self) -> Option<&Ident> { + match self { + Self::Ident(ident) => Some(ident), + Self::Int(_) => None, + } + } + + fn numeric_value(&self) -> Option { + match self { + Self::Ident(_) => None, + Self::Int(lit) => Some( + lit.base10_parse::() + .expect("validated numeric label should fit usize"), + ), + } + } + + fn numeric(value: usize) -> Self { + Self::Int(LitInt::new(&value.to_string(), Span::call_site())) + } +} + +impl Parse for SymbolicIndex { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(LitInt) { + let lit: LitInt = input.parse()?; + lit.base10_parse::().map_err(|_| { + syn::Error::new_spanned( + &lit, + "numeric declaration-first labels must be non-negative base-10 integers that fit in usize", + ) + })?; + Ok(Self::Int(lit)) + } else { + Ok(Self::Ident(input.parse()?)) + } + } +} + +impl ToTokens for SymbolicIndex { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + Self::Ident(ident) => ident.to_tokens(tokens), + Self::Int(lit) => lit.to_tokens(tokens), + } + } +} + +impl std::fmt::Display for SymbolicIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.name()) + } +} + +impl Parse for OdeRouteDecl { + fn parse(input: ParseStream) -> syn::Result { + let kind_ident: Ident = input.parse()?; + let kind = match kind_ident.to_string().as_str() { + "bolus" => OdeRouteKind::Bolus, + "infusion" => OdeRouteKind::Infusion, + other => { + return Err(syn::Error::new_spanned( + &kind_ident, + format!("unknown route kind `{other}`, expected `bolus` or `infusion`"), + )); + } + }; + + let content; + syn::parenthesized!(content in input); + let route_input: SymbolicIndex = content.parse()?; + if !content.is_empty() { + return Err(content.error("expected a single route input name inside `(...)`")); + } + + if !input.peek(Token![->]) { + return Err( + input.error("expected `->` followed by a destination state in route declaration") + ); + } + input.parse::]>()?; + let destination: Ident = input.parse()?; + + if input.peek(token::Brace) { + return Err( + input.error("route properties are not supported in declaration-first `ode!` yet") + ); + } + + Ok(Self { + kind, + input: route_input, + destination, + }) + } +} + impl Parse for OdeInput { fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut covariates = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; let mut diffeq = None; let mut lag = None; - let mut fa_val = None; + let mut fa = None; + let mut init = None; + let mut out = None; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + + match key.to_string().as_str() { + "name" => set_once_ode(&mut name, input.parse()?, &key, "name")?, + "params" => set_once_ode(&mut params, parse_ident_list(input)?, &key, "params")?, + "covariates" => set_once_ode( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, + "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?, + "outputs" => set_once_ode( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, + "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?, + "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?, + "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_ode(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_ode(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_ode(&mut out, input.parse()?, &key, "out")?, + other => { + return Err(syn::Error::new_spanned( + &key, + format!( + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, diffeq, lag, fa, init, out" + ), + )); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + let name = name.ok_or_else(|| { + syn::Error::new( + Span::call_site(), + "declaration-first `ode!` requires `name`, `params`, `states`, `outputs`, and `routes`; the old inferred-dimensions form has been removed", + ) + })?; + let params = params.ok_or_else(|| missing_required_ode_field("params"))?; + let covariates = covariates.unwrap_or_default(); + let states = states.ok_or_else(|| missing_required_ode_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_ode_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?; + let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?; + let out = out.ok_or_else(|| missing_required_ode_field("out"))?; + validate_ode_diffeq_uses_automatic_injection(&diffeq, &routes)?; + + validate_unique_idents("parameter", ¶ms, "ode!")?; + validate_unique_idents("covariate", &covariates, "ode!")?; + validate_unique_idents("state", &states, "ode!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "ode!")?; + validate_routes(&routes, &states, "ode!")?; + validate_named_binding_compatibility( + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &output_idents, + routes: &routes, + }, + OdeBindingClosures { + diffeq: &diffeq, + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + }, + )?; + + Ok(Self { + name, + params, + covariates, + states, + outputs, + routes, + diffeq, + lag, + fa, + init, + out, + }) + } +} + +impl Parse for RoutePropertyEntry { + fn parse(input: ParseStream) -> syn::Result { + let route: SymbolicIndex = input.parse()?; + input.parse::]>()?; + let value: Expr = input.parse()?; + Ok(Self { route, value }) + } +} + +impl Parse for AnalyticalInput { + fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut covariates = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; + let mut structure = None; + let mut sec = None; + let mut lag = None; + let mut fa = None; + let mut init = None; + let mut out = None; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + + match key.to_string().as_str() { + "name" => set_once_analytical(&mut name, input.parse()?, &key, "name")?, + "params" => { + set_once_analytical(&mut params, parse_ident_list(input)?, &key, "params")? + } + "covariates" => set_once_analytical( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, + "states" => { + set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")? + } + "outputs" => set_once_analytical( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, + "routes" => { + set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")? + } + "structure" => { + set_once_analytical(&mut structure, input.parse()?, &key, "structure")? + } + "sec" => set_once_analytical(&mut sec, input.parse()?, &key, "sec")?, + "lag" => set_once_analytical(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_analytical(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_analytical(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_analytical(&mut out, input.parse()?, &key, "out")?, + other => { + return Err(syn::Error::new_spanned( + &key, + format!( + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, structure, sec, lag, fa, init, out" + ), + )); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + let name = name.ok_or_else(|| missing_required_analytical_field("name"))?; + let params = params.ok_or_else(|| missing_required_analytical_field("params"))?; + let covariates = covariates.unwrap_or_default(); + let states = states.ok_or_else(|| missing_required_analytical_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_analytical_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_analytical_field("routes"))?; + let structure = structure.ok_or_else(|| missing_required_analytical_field("structure"))?; + let out = out.ok_or_else(|| missing_required_analytical_field("out"))?; + + validate_unique_idents("parameter", ¶ms, "analytical!")?; + validate_unique_idents("covariate", &covariates, "analytical!")?; + validate_unique_idents("state", &states, "analytical!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "analytical!")?; + validate_routes(&routes, &states, "analytical!")?; + + let kernel_spec = resolve_analytical_structure(&structure)?; + if params.len() < kernel_spec.parameter_arity { + return Err(syn::Error::new_spanned( + &structure, + format!( + "analytical structure `{}` requires at least {} parameter value(s), but `params` declares {}", + structure, kernel_spec.parameter_arity, params.len() + ), + )); + } + if states.len() != kernel_spec.state_count { + return Err(syn::Error::new_spanned( + &structure, + format!( + "analytical structure `{}` expects {} state value(s), but `states` declares {}", + structure, + kernel_spec.state_count, + states.len() + ), + )); + } + + validate_analytical_named_binding_compatibility( + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &output_idents, + routes: &routes, + }, + AnalyticalBindingClosures { + sec: sec.as_ref(), + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + }, + )?; + + if let Some(lag) = lag.as_ref() { + let lag_routes = + extract_route_property_routes("built-in `analytical!`", "lag", lag, &routes)?; + validate_route_property_kinds("built-in `analytical!`", "lag", &routes, &lag_routes)?; + } + + if let Some(fa) = fa.as_ref() { + let fa_routes = + extract_route_property_routes("built-in `analytical!`", "fa", fa, &routes)?; + validate_route_property_kinds("built-in `analytical!`", "fa", &routes, &fa_routes)?; + } + + Ok(Self { + name, + params, + covariates, + states, + outputs, + routes, + structure, + sec, + lag, + fa, + init, + out, + }) + } +} + +impl Parse for SdeInput { + fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut covariates = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; + let mut particles = None; + let mut drift = None; + let mut diffusion = None; + let mut lag = None; + let mut fa = None; let mut init = None; let mut out = None; while !input.is_empty() { let key: Ident = input.parse()?; input.parse::()?; - let closure: ExprClosure = input.parse()?; match key.to_string().as_str() { - "diffeq" => diffeq = Some(closure), - "lag" => lag = Some(closure), - "fa" => fa_val = Some(closure), - "init" => init = Some(closure), - "out" => out = Some(closure), + "name" => set_once_sde(&mut name, input.parse()?, &key, "name")?, + "params" => set_once_sde(&mut params, parse_ident_list(input)?, &key, "params")?, + "covariates" => set_once_sde( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, + "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?, + "outputs" => set_once_sde( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, + "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?, + "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?, + "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?, + "diffusion" => set_once_sde(&mut diffusion, input.parse()?, &key, "diffusion")?, + "lag" => set_once_sde(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_sde(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_sde(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_sde(&mut out, input.parse()?, &key, "out")?, other => { return Err(syn::Error::new_spanned( &key, - format!("unknown field `{other}`, expected: diffeq, lag, fa, init, out"), + format!( + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, particles, drift, diffusion, lag, fa, init, out" + ), )); } } - // optional trailing comma if !input.is_empty() { input.parse::()?; } } - Ok(OdeInput { - diffeq: diffeq.ok_or_else(|| input.error("missing required field `diffeq`"))?, + let name = name.ok_or_else(|| missing_required_sde_field("name"))?; + let params = params.ok_or_else(|| missing_required_sde_field("params"))?; + let covariates = covariates.unwrap_or_default(); + let states = states.ok_or_else(|| missing_required_sde_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_sde_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_sde_field("routes"))?; + let particles = particles.ok_or_else(|| missing_required_sde_field("particles"))?; + let drift = drift.ok_or_else(|| missing_required_sde_field("drift"))?; + let diffusion = diffusion.ok_or_else(|| missing_required_sde_field("diffusion"))?; + let out = out.ok_or_else(|| missing_required_sde_field("out"))?; + + validate_unique_idents("parameter", ¶ms, "sde!")?; + validate_unique_idents("covariate", &covariates, "sde!")?; + validate_unique_idents("state", &states, "sde!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "sde!")?; + validate_routes(&routes, &states, "sde!")?; + validate_sde_named_binding_compatibility( + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &output_idents, + routes: &routes, + }, + SdeBindingClosures { + drift: &drift, + diffusion: &diffusion, + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + }, + )?; + + if let Some(lag) = lag.as_ref() { + let lag_routes = + extract_route_property_routes("declaration-first `sde!`", "lag", lag, &routes)?; + validate_route_property_kinds("declaration-first `sde!`", "lag", &routes, &lag_routes)?; + } + + if let Some(fa) = fa.as_ref() { + let fa_routes = + extract_route_property_routes("declaration-first `sde!`", "fa", fa, &routes)?; + validate_route_property_kinds("declaration-first `sde!`", "fa", &routes, &fa_routes)?; + } + + Ok(Self { + name, + params, + covariates, + states, + outputs, + routes, + particles, + drift, + diffusion, lag, - fa: fa_val, + fa, init, - out: out.ok_or_else(|| input.error("missing required field `out`"))?, + out, }) } } @@ -70,7 +604,106 @@ impl Parse for OdeInput { // Helpers // --------------------------------------------------------------------------- -/// Return the identifier string for a closure parameter (empty for wildcards). +fn missing_required_ode_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in declaration-first `ode!`"), + ) +} + +fn missing_required_analytical_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in built-in `analytical!`"), + ) +} + +fn missing_required_sde_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in declaration-first `sde!`"), + ) +} + +fn set_once_ode(slot: &mut Option, value: T, key: &Ident, name: &str) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `ode!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn set_once_analytical( + slot: &mut Option, + value: T, + key: &Ident, + name: &str, +) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `analytical!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn set_once_sde(slot: &mut Option, value: T, key: &Ident, name: &str) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `sde!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn parse_ident_list(input: ParseStream) -> syn::Result> { + let content; + syn::bracketed!(content in input); + Ok(Punctuated::::parse_terminated(&content)? + .into_iter() + .collect()) +} + +fn parse_symbolic_index_list(input: ParseStream) -> syn::Result> { + let content; + syn::bracketed!(content in input); + Ok( + Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(), + ) +} + +fn parse_route_list(input: ParseStream) -> syn::Result> { + if input.peek(token::Brace) { + return Err(input.error("declaration-first macro `routes` must use `[...]`, not `{...}`")); + } + + if !input.peek(token::Bracket) { + return Err( + input.error("expected a bracketed route list like `routes: [infusion(iv) -> central]`") + ); + } + + let content; + syn::bracketed!(content in input); + Ok( + Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(), + ) +} + fn param_name(pat: &Pat) -> String { match pat { Pat::Ident(p) => p.ident.to_string(), @@ -82,208 +715,2704 @@ fn closure_param_names(c: &ExprClosure) -> Vec { c.inputs.iter().map(param_name).collect() } -/// Recursively scan `tokens` for `ident[literal_int]` patterns where the -/// ident matches one of `names`. Returns the maximum literal integer found. -fn max_literal_index(tokens: proc_macro2::TokenStream, names: &[&str]) -> Option { - let tts: Vec = tokens.into_iter().collect(); - let mut best: Option = None; - - for (i, tt) in tts.iter().enumerate() { - match tt { - TokenTree::Ident(ident) => { - let s = ident.to_string(); - if names.contains(&s.as_str()) { - if let Some(TokenTree::Group(g)) = tts.get(i + 1) { - if g.delimiter() == proc_macro2::Delimiter::Bracket { - let inner: Vec = g.stream().into_iter().collect(); - if inner.len() == 1 { - if let TokenTree::Literal(lit) = &inner[0] { - if let Ok(n) = lit.to_string().parse::() { - best = Some(best.map_or(n, |m: usize| m.max(n))); - } - } - } - } - } - } - } - // recurse into brace / paren groups (bracket groups are indexing, handled above) - TokenTree::Group(g) - if matches!( - g.delimiter(), - proc_macro2::Delimiter::Brace | proc_macro2::Delimiter::Parenthesis - ) => +fn closure_param_ident(c: &ExprClosure, index: usize) -> Option { + c.inputs.get(index).and_then(|pat| match pat { + Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()), + _ => None, + }) +} + +fn generated_ident(name: &str) -> Ident { + Ident::new(name, Span::call_site()) +} + +fn symbolic_index_idents(labels: &[SymbolicIndex]) -> Vec { + labels + .iter() + .filter_map(|label| label.ident().cloned()) + .collect() +} + +fn symbolic_index_bindings(labels: &[SymbolicIndex]) -> Vec<(SymbolicIndex, usize)> { + labels + .iter() + .cloned() + .enumerate() + .map(|(index, label)| (label, index)) + .collect() +} + +fn symbolic_numeric_binding_map(bindings: &[(SymbolicIndex, usize)]) -> HashMap { + bindings + .iter() + .filter_map(|(label, index)| label.numeric_value().map(|value| (value, *index))) + .collect() +} + +#[derive(Default)] +struct ClosureBodyUsage { + idents: HashSet, + indexed_idents: HashSet, + assigned_indexed_idents: HashSet, + contains_macro: bool, +} + +impl ClosureBodyUsage { + fn analyze(expr: &Expr) -> Self { + let mut usage = Self::default(); + usage.visit_expr(expr); + usage + } + + fn uses(&self, ident: &Ident) -> bool { + self.contains_macro || self.idents.contains(&ident.to_string()) + } + + fn mentions(&self, ident: &Ident) -> bool { + self.idents.contains(&ident.to_string()) + } + + fn indexes(&self, ident: &Ident) -> bool { + self.indexed_idents.contains(&ident.to_string()) + } + + fn assigns_index(&self, ident: &Ident) -> bool { + self.assigned_indexed_idents.contains(&ident.to_string()) + } +} + +impl<'ast> Visit<'ast> for ClosureBodyUsage { + fn visit_expr_path(&mut self, expr_path: &'ast syn::ExprPath) { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 + { + self.idents + .insert(expr_path.path.segments[0].ident.to_string()); + } + + syn::visit::visit_expr_path(self, expr_path); + } + + fn visit_expr_macro(&mut self, expr_macro: &'ast syn::ExprMacro) { + self.contains_macro = true; + syn::visit::visit_expr_macro(self, expr_macro); + } + + fn visit_stmt_macro(&mut self, stmt_macro: &'ast syn::StmtMacro) { + self.contains_macro = true; + syn::visit::visit_stmt_macro(self, stmt_macro); + } + + fn visit_expr_index(&mut self, expr_index: &'ast syn::ExprIndex) { + if let Expr::Path(expr_path) = expr_index.expr.as_ref() { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 { - if let Some(n) = max_literal_index(g.stream(), names) { - best = Some(best.map_or(n, |m: usize| m.max(n))); + self.indexed_idents + .insert(expr_path.path.segments[0].ident.to_string()); + } + } + + syn::visit::visit_expr_index(self, expr_index); + } + + fn visit_expr_assign(&mut self, expr_assign: &'ast syn::ExprAssign) { + if let Expr::Index(expr_index) = expr_assign.left.as_ref() { + if let Expr::Path(expr_path) = expr_index.expr.as_ref() { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 + { + self.assigned_indexed_idents + .insert(expr_path.path.segments[0].ident.to_string()); } } - _ => {} } + + syn::visit::visit_expr_assign(self, expr_assign); } +} - best +struct IndexRewriteTarget { + container: Ident, + labels: HashMap, } -// --------------------------------------------------------------------------- -// Proc macro -// --------------------------------------------------------------------------- +impl IndexRewriteTarget { + fn new(container: Ident, labels: HashMap) -> Self { + Self { container, labels } + } +} -/// Build an `equation::ODE` while **inferring** `nstates`, `ndrugs` and -/// `nout` from the maximum literal bracket-indices used in the closures. -/// -/// # Fields (any order, comma-separated) -/// -/// | Field | Required | Signature | -/// |----------|----------|-------------------------------------------------| -/// | `diffeq` | **yes** | `\|x, p, t, dx, bolus, rateiv, cov\| { … }` | -/// | `out` | **yes** | `\|x, p, t, cov, y\| { … }` | -/// | `init` | no | `\|p, t, cov, x\| { … }` | -/// | `lag` | no | `\|p, t, cov\| lag! { … }` | -/// | `fa` | no | `\|p, t, cov\| fa! { … }` | -/// -/// # Inference rules -/// -/// * **nstates** = max literal index of the state / derivative vectors + 1 -/// * **ndrugs** = max literal index of bolus / rateiv vectors + 1 -/// * **nout** = max literal index of the output vector + 1 -/// -/// Parameter names are taken from the closure signatures so you can name them -/// however you like. Only **literal** integer indices (e.g. `x[2]`) are -/// detected; computed indices require manual `.with_nstates()` etc. -/// -/// # Example -/// -/// ```ignore -/// use pharmsol::prelude::*; -/// -/// let ode = ode! { -/// diffeq: |x, p, _t, dx, b, rateiv, _cov| { -/// fetch_params!(p, ke, kcp, kpc, _v); -/// dx[0] = rateiv[0] + b[0] - ke * x[0] - kcp * x[0] + kpc * x[1]; -/// dx[1] = kcp * x[0] - kpc * x[1]; -/// }, -/// out: |x, p, _t, _cov, y| { -/// fetch_params!(p, _ke, _kcp, _kpc, v); -/// y[0] = x[0] / v; -/// }, -/// }; -/// // Inferred: nstates=2, ndrugs=1, nout=1 -/// ``` -#[proc_macro] -pub fn ode(input: TokenStream) -> TokenStream { - let input = syn::parse_macro_input!(input as OdeInput); +struct NumericLabelRewriter { + index_targets: Vec, + route_labels: Option>, +} - // ── Validate parameter counts ──────────────────────────────── - let de_params = closure_param_names(&input.diffeq); - if de_params.len() != 7 { - return syn::Error::new_spanned( - &input.diffeq, - "diffeq closure must have 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", - ) - .to_compile_error() - .into(); +impl NumericLabelRewriter { + fn rewrite( + expr: &Expr, + index_targets: Vec, + route_labels: Option>, + ) -> Expr { + let mut rewritten = expr.clone(); + let mut rewriter = Self { + index_targets, + route_labels, + }; + rewriter.visit_expr_mut(&mut rewritten); + rewritten } - let out_params = closure_param_names(&input.out); - if out_params.len() != 5 { - return syn::Error::new_spanned( - &input.out, - "out closure must have 5 parameters: |x, p, t, cov, y|", - ) - .to_compile_error() - .into(); + fn target_labels(&self, path: &syn::ExprPath) -> Option<&HashMap> { + if path.qself.is_some() + || path.path.leading_colon.is_some() + || path.path.segments.len() != 1 + { + return None; + } + + let ident = &path.path.segments[0].ident; + self.index_targets + .iter() + .find(|target| target.container == *ident) + .map(|target| &target.labels) } - // ── Collect names by role ──────────────────────────────────── - // diffeq positions: 0=x 3=dx 4=bolus 5=rateiv - // out positions: 0=x 4=y - // init positions: 3=x - let mut state_names: Vec = vec![ - de_params[0].clone(), - de_params[3].clone(), - out_params[0].clone(), - ]; - if let Some(ref ic) = input.init { - let ip = closure_param_names(ic); - if ip.len() >= 4 { - state_names.push(ip[3].clone()); + fn rewrite_route_macro(&self, mac: &mut syn::Macro) { + let Some(route_labels) = self.route_labels.as_ref() else { + return; + }; + if !(mac.path.is_ident("lag") || mac.path.is_ident("fa")) { + return; } + + let Ok(entries) = Punctuated::::parse_terminated + .parse2(mac.tokens.clone()) + else { + return; + }; + + let entries = entries.into_iter().map(|mut entry| { + if let Some(value) = entry.route.numeric_value() { + if let Some(internal_index) = route_labels.get(&value) { + entry.route = SymbolicIndex::numeric(*internal_index); + } + } + entry + }); + + let tokens = entries.map(|entry| { + let route = entry.route; + let value = entry.value; + quote! { #route => #value } + }); + mac.tokens = quote! { #(#tokens),* }; } - state_names.sort(); - state_names.dedup(); +} - let drug_names = [de_params[4].clone(), de_params[5].clone()]; - let output_names = [out_params[4].clone()]; +impl VisitMut for NumericLabelRewriter { + fn visit_expr_index_mut(&mut self, expr_index: &mut syn::ExprIndex) { + syn::visit_mut::visit_expr_index_mut(self, expr_index); - // filter empties (from wildcard `_` params) - let state_refs: Vec<&str> = state_names - .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - let drug_refs: Vec<&str> = drug_names - .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - let output_refs: Vec<&str> = output_names - .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - - // ── Scan closure bodies ────────────────────────────────────── - let de_tokens = input.diffeq.body.to_token_stream(); - let out_tokens = input.out.body.to_token_stream(); - let init_tokens = input.init.as_ref().map(|c| c.body.to_token_stream()); - - let max_state = [ - max_literal_index(de_tokens.clone(), &state_refs), - max_literal_index(out_tokens.clone(), &state_refs), - init_tokens.and_then(|t| max_literal_index(t, &state_refs)), - ] - .into_iter() - .flatten() - .max(); - - let max_drug = max_literal_index(de_tokens, &drug_refs); - let max_out = max_literal_index(out_tokens, &output_refs); - - let nstates = max_state.map_or(1, |n| n + 1); - let ndrugs = max_drug.map_or(1, |n| n + 1); - let nout = max_out.map_or(1, |n| n + 1); - - // ── Generate output ────────────────────────────────────────── - let diffeq = &input.diffeq; - let out = &input.out; - - let lag = input.lag.as_ref().map_or_else( - || quote! { |_, _, _| ::std::collections::HashMap::new() }, - |c| quote! { #c }, - ); + let Expr::Path(expr_path) = expr_index.expr.as_ref() else { + return; + }; + let Some(labels) = self.target_labels(expr_path) else { + return; + }; + let Expr::Lit(expr_lit) = expr_index.index.as_ref() else { + return; + }; + let Lit::Int(lit) = &expr_lit.lit else { + return; + }; + let Ok(external_index) = lit.base10_parse::() else { + return; + }; + let Some(internal_index) = labels.get(&external_index) else { + return; + }; - let fa = input.fa.as_ref().map_or_else( - || quote! { |_, _, _| ::std::collections::HashMap::new() }, - |c| quote! { #c }, - ); + *expr_index.index = Expr::Lit(syn::ExprLit { + attrs: Vec::new(), + lit: Lit::Int(LitInt::new(&internal_index.to_string(), lit.span())), + }); + } - let init = input - .init - .as_ref() - .map_or_else(|| quote! { |_, _, _, _| {} }, |c| quote! { #c }); + fn visit_expr_macro_mut(&mut self, expr_macro: &mut syn::ExprMacro) { + self.rewrite_route_macro(&mut expr_macro.mac); + syn::visit_mut::visit_expr_macro_mut(self, expr_macro); + } - quote! { - equation::ODE::new( - #diffeq, - #lag, - #fa, - #init, - #out, - ) - .with_nstates(#nstates) - .with_ndrugs(#ndrugs) - .with_nout(#nout) + fn visit_stmt_macro_mut(&mut self, stmt_macro: &mut syn::StmtMacro) { + self.rewrite_route_macro(&mut stmt_macro.mac); + syn::visit_mut::visit_stmt_macro_mut(self, stmt_macro); + } +} + +fn generate_closure_input_aliases( + closure: &ExprClosure, + internal_names: &[Ident], +) -> syn::Result { + if closure.inputs.len() != internal_names.len() { + return Err(syn::Error::new_spanned( + closure, + "internal named binding generation error: closure arity mismatch", + )); + } + + let aliases = + closure + .inputs + .iter() + .zip(internal_names.iter()) + .map(|(pattern, internal_name)| { + quote! { + let #pattern = #internal_name; + } + }); + + Ok(quote! { + #(#aliases)* + }) +} + +fn generate_supported_input_aliases( + closure: &ExprClosure, + supported_internal_names: &[&[Ident]], + error_message: &str, +) -> syn::Result { + for internal_names in supported_internal_names { + if closure.inputs.len() == internal_names.len() { + return generate_closure_input_aliases(closure, internal_names); + } + } + + Err(syn::Error::new_spanned(closure, error_message)) +} + +fn generate_parameter_bindings( + params: &[Ident], + closure: &ExprClosure, + parameter_vector: &Ident, +) -> TokenStream2 { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let bindings = params + .iter() + .enumerate() + .filter(|(_, ident)| usage.uses(ident)) + .map(|(index, ident)| { + quote! { + #[allow(unused_variables)] + let #ident = #parameter_vector[#index]; + } + }); + + quote! { + #(#bindings)* + } +} + +fn generate_mutable_parameter_bindings( + params: &[Ident], + closure: &ExprClosure, + parameter_vector: &Ident, +) -> (TokenStream2, TokenStream2) { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let used_params = params + .iter() + .enumerate() + .filter(|(_, ident)| usage.uses(ident)) + .collect::>(); + + let bindings = used_params.iter().map(|(index, ident)| { + quote! { + #[allow(unused_mut, unused_variables)] + let mut #ident = #parameter_vector[#index]; + } + }); + let writebacks = used_params.iter().map(|(index, ident)| { + quote! { + #parameter_vector[#index] = #ident; + } + }); + + (quote! { #(#bindings)* }, quote! { #(#writebacks)* }) +} + +fn generate_covariate_bindings( + covariates: &[Ident], + closure: &ExprClosure, + covariate_map: &Ident, + time: &Ident, +) -> TokenStream2 { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let used_covariates = covariates + .iter() + .filter(|ident| usage.uses(ident)) + .collect::>(); + + if used_covariates.is_empty() { + quote! {} + } else { + quote! { + ::pharmsol::fetch_cov!(#covariate_map, #time, #(#used_covariates),*); + } + } +} + +fn validate_ode_diffeq_uses_automatic_injection( + diffeq: &ExprClosure, + routes: &[OdeRouteDecl], +) -> syn::Result<()> { + match closure_param_names(diffeq).len() { + 3 => Ok(()), + 5 => { + let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref()); + let route_inputs = route_input_idents(routes); + let fourth_param = closure_param_ident(diffeq, 3); + let fifth_param = closure_param_ident(diffeq, 4); + let mentions_route_inputs = route_inputs.iter().any(|route| usage.mentions(route)); + let indexes_fifth_param = fifth_param.as_ref().is_some_and(|ident| usage.indexes(ident)); + let reads_fourth_param_as_input = fourth_param + .as_ref() + .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident)); + + if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input { + Err(syn::Error::new_spanned( + diffeq, + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx| and remove manual `bolus[...]` / `rateiv[...]` terms", + )) + } else { + Ok(()) + } + } + _ => Err(syn::Error::new_spanned( + diffeq, + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )), + } +} + +fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec { + routes + .iter() + .filter_map(|route| route.input.ident().cloned()) + .collect() +} + +fn route_input_names(routes: &[OdeRouteDecl]) -> Vec { + routes.iter().map(|route| route.input.name()).collect() +} + +fn ode_route_input_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { + let mut next_bolus_index = 0usize; + let mut next_infusion_index = 0usize; + + routes + .iter() + .map(|route| { + let index = match route.kind { + OdeRouteKind::Bolus => { + let index = next_bolus_index; + next_bolus_index += 1; + index + } + OdeRouteKind::Infusion => { + let index = next_infusion_index; + next_infusion_index += 1; + index + } + }; + (route.input.clone(), index) + }) + .collect() +} + +fn dense_index_len(bindings: &[(SymbolicIndex, usize)]) -> usize { + bindings + .iter() + .map(|(_, index)| index + 1) + .max() + .unwrap_or(0) +} + +fn validate_binding_conflicts( + left_label: &str, + left: &[Ident], + right_label: &str, + right: &[Ident], + context: &str, +) -> syn::Result<()> { + let right_names = right.iter().map(Ident::to_string).collect::>(); + + for ident in left { + let name = ident.to_string(); + if right_names.contains(&name) { + return Err(syn::Error::new_spanned( + ident, + format!( + "named {left_label} binding `{name}` conflicts with named {right_label} binding in {context}" + ), + )); + } + } + + Ok(()) +} + +fn validate_closure_param_conflicts( + closure_label: &str, + closure: &ExprClosure, + bindings: &[Ident], + binding_label: &str, +) -> syn::Result<()> { + let parameter_names = closure_param_names(closure) + .into_iter() + .filter(|name| !name.is_empty()) + .collect::>(); + + for ident in bindings { + let name = ident.to_string(); + if parameter_names.contains(&name) { + return Err(syn::Error::new_spanned( + ident, + format!( + "named {binding_label} binding `{name}` conflicts with `{closure_label}` closure parameter `{name}`" + ), + )); + } + } + + Ok(()) +} + +#[derive(Clone, Copy)] +struct NamedBindingSets<'a> { + params: &'a [Ident], + covariates: &'a [Ident], + states: &'a [Ident], + outputs: &'a [Ident], + routes: &'a [OdeRouteDecl], +} + +#[derive(Clone, Copy)] +struct CommonBindingClosures<'a> { + lag: Option<&'a ExprClosure>, + fa: Option<&'a ExprClosure>, + init: Option<&'a ExprClosure>, + out: &'a ExprClosure, +} + +#[derive(Clone, Copy)] +struct AnalyticalBindingClosures<'a> { + sec: Option<&'a ExprClosure>, + common: CommonBindingClosures<'a>, +} + +#[derive(Clone, Copy)] +struct OdeBindingClosures<'a> { + diffeq: &'a ExprClosure, + common: CommonBindingClosures<'a>, +} + +#[derive(Clone, Copy)] +struct SdeBindingClosures<'a> { + drift: &'a ExprClosure, + diffusion: &'a ExprClosure, + common: CommonBindingClosures<'a>, +} + +fn validate_named_binding_compatibility( + bindings: NamedBindingSets<'_>, + closures: OdeBindingClosures<'_>, +) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let OdeBindingClosures { + diffeq, + common: CommonBindingClosures { lag, fa, init, out }, + } = closures; + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "declaration-first `ode!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`diffeq` and `out` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`out` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`out` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "declaration-first `ode!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "declaration-first `ode!` named binding generation", + )?; + + validate_closure_param_conflicts("diffeq", diffeq, params, "parameter")?; + validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?; + validate_closure_param_conflicts("diffeq", diffeq, states, "state")?; + + if let Some(lag) = lag { + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`lag` named binding generation", + )?; + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`fa` named binding generation", + )?; + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + Ok(()) +} + +fn validate_analytical_named_binding_compatibility( + bindings: NamedBindingSets<'_>, + closures: AnalyticalBindingClosures<'_>, +) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let AnalyticalBindingClosures { + sec, + common: CommonBindingClosures { lag, fa, init, out }, + } = closures; + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "output", + outputs, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + + if let Some(sec) = sec { + validate_closure_param_conflicts("sec", sec, params, "parameter")?; + validate_closure_param_conflicts("sec", sec, covariates, "covariate")?; + } + + if let Some(lag) = lag { + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + Ok(()) +} + +fn validate_sde_named_binding_compatibility( + bindings: NamedBindingSets<'_>, + closures: SdeBindingClosures<'_>, +) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let SdeBindingClosures { + drift, + diffusion, + common: CommonBindingClosures { lag, fa, init, out }, + } = closures; + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "output", + outputs, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + + validate_closure_param_conflicts("drift", drift, params, "parameter")?; + validate_closure_param_conflicts("drift", drift, covariates, "covariate")?; + validate_closure_param_conflicts("drift", drift, states, "state")?; + validate_closure_param_conflicts("diffusion", diffusion, params, "parameter")?; + validate_closure_param_conflicts("diffusion", diffusion, states, "state")?; + + if let Some(lag) = lag { + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + Ok(()) +} + +fn generate_index_consts(idents: &[Ident]) -> TokenStream2 { + let bindings = idents.iter().enumerate().map(|(index, ident)| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }); + + quote! { + #(#bindings)* + } +} + +fn generate_mapped_index_consts(bindings: &[(SymbolicIndex, usize)]) -> TokenStream2 { + let bindings = bindings.iter().filter_map(|(label, index)| { + label.ident().map(|ident| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }) + }); + + quote! { + #(#bindings)* + } +} + +fn expand_out( + out: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + outputs: &[SymbolicIndex], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( + out, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_out + }}) +} + +fn route_property_error(macro_name: &str, label: &str, node: T) -> syn::Error { + syn::Error::new_spanned( + node, + format!( + "{macro_name} requires `{label}` to return `{label}! {{ ... }}` so route-property metadata can be synthesized" + ), + ) +} + +fn find_terminal_macro_invocation( + macro_name: &str, + label: &str, + closure: &ExprClosure, +) -> syn::Result { + match closure.body.as_ref() { + Expr::Macro(expr_macro) if expr_macro.mac.path.is_ident(label) => { + Ok(expr_macro.mac.clone()) + } + Expr::Macro(expr_macro) => Err(route_property_error(macro_name, label, expr_macro)), + Expr::Block(expr_block) => { + for stmt in expr_block.block.stmts.iter().rev() { + match stmt { + Stmt::Expr(Expr::Macro(expr_macro), _) + if expr_macro.mac.path.is_ident(label) => + { + return Ok(expr_macro.mac.clone()); + } + Stmt::Expr(Expr::Macro(expr_macro), _) => { + return Err(route_property_error(macro_name, label, expr_macro)); + } + Stmt::Expr(other, _) => { + return Err(route_property_error(macro_name, label, other)); + } + Stmt::Macro(stmt_macro) if stmt_macro.mac.path.is_ident(label) => { + return Ok(stmt_macro.mac.clone()); + } + Stmt::Macro(stmt_macro) => { + return Err(route_property_error(macro_name, label, stmt_macro)); + } + _ => continue, + } + } + + Err(route_property_error(macro_name, label, expr_block)) + } + other => Err(route_property_error(macro_name, label, other)), + } +} + +fn extract_route_property_routes( + macro_name: &str, + label: &str, + closure: &ExprClosure, + routes: &[OdeRouteDecl], +) -> syn::Result> { + let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?; + let entries = Punctuated::::parse_terminated + .parse2(macro_expr.tokens.clone())?; + let known_routes = route_input_names(routes) + .into_iter() + .collect::>(); + let mut seen = HashSet::new(); + + for entry in entries { + let route_name = entry.route.name(); + if !known_routes.contains(&route_name) { + return Err(syn::Error::new_spanned( + &entry.route, + format!( + "route `{route_name}` in `{label}!` is not declared in the `routes` section" + ), + )); + } + if !seen.insert(route_name.clone()) { + return Err(syn::Error::new_spanned( + &entry.route, + format!("duplicate route `{route_name}` in `{label}!`"), + )); + } + let _ = entry.value; + } + + Ok(seen) +} + +fn validate_route_property_kinds( + macro_name: &str, + label: &str, + routes: &[OdeRouteDecl], + property_routes: &HashSet, +) -> syn::Result<()> { + for route in routes { + if property_routes.contains(&route.input.name()) + && matches!(route.kind, OdeRouteKind::Infusion) + { + return Err(syn::Error::new_spanned( + &route.input, + format!( + "{macro_name} does not allow `{label}` on infusion route `{}`", + route.input + ), + )); + } + } + + Ok(()) +} + +fn expand_ode_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + route_bindings: &[(SymbolicIndex, usize)], +) -> syn::Result { + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "declaration-first `ode!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_ode_init( + init: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_route_metadata( + routes: &[OdeRouteDecl], + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.name(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + #lag_flag + #fa_flag + .inject_input_to_destination() + } + }) + .collect() +} + +fn expand_analytical_route_metadata( + routes: &[OdeRouteDecl], + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.name(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + #lag_flag + #fa_flag + } + }) + .collect() +} + +fn expand_sde_route_metadata( + routes: &[OdeRouteDecl], + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.name(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + .inject_input_to_destination() + #lag_flag + #fa_flag + } + }) + .collect() +} + +fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize { + states + .iter() + .position(|state| state == &route.destination) + .expect("validated route destination should exist") +} + +fn expand_injected_ode_route_terms( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(SymbolicIndex, usize)], + dx: &Ident, + bolus: &Ident, + rateiv: &Ident, +) -> TokenStream2 { + let terms = routes + .iter() + .zip(route_bindings.iter()) + .map(|(route, (_, input_index))| { + let destination = route_destination_index(route, states); + match route.kind { + OdeRouteKind::Bolus => quote! { + #dx[#destination] += #bolus[#input_index]; + }, + OdeRouteKind::Infusion => quote! { + #dx[#destination] += #rateiv[#input_index]; + }, + } + }); + + quote! { + #(#terms)* + } +} + +fn expand_injected_sde_rate_terms( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(SymbolicIndex, usize)], + dx: &Ident, + rateiv: &Ident, +) -> TokenStream2 { + let terms = routes + .iter() + .zip(route_bindings.iter()) + .filter_map(|(route, (_, input_index))| match route.kind { + OdeRouteKind::Bolus => None, + OdeRouteKind::Infusion => { + let destination = route_destination_index(route, states); + Some(quote! { + #dx[#destination] += #rateiv[#input_index]; + }) + } + }); + + quote! { + #(#terms)* + } +} + +fn expand_injected_sde_bolus_mappings( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(SymbolicIndex, usize)], +) -> TokenStream2 { + let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; + + for (route, (_, input_index)) in routes.iter().zip(route_bindings.iter()) { + if let OdeRouteKind::Bolus = route.kind { + let destination = route_destination_index(route, states); + destinations[*input_index] = quote! { Some(#destination) }; + } + } + + quote! { + .with_injected_bolus_inputs(&[#(#destinations),*]) + } +} + +fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn::Result<()> { + let mut seen = HashSet::new(); + for ident in idents { + let name = ident.to_string(); + if !seen.insert(name.clone()) { + return Err(syn::Error::new_spanned( + ident, + format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"), + )); + } + } + Ok(()) +} + +fn validate_unique_symbolic_indices( + kind: &str, + labels: &[SymbolicIndex], + macro_name: &str, +) -> syn::Result<()> { + let mut seen = HashSet::new(); + for label in labels { + let name = label.name(); + if !seen.insert(name.clone()) { + return Err(syn::Error::new_spanned( + label, + format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"), + )); + } + } + Ok(()) +} + +fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> { + let known_states = states.iter().map(Ident::to_string).collect::>(); + let mut seen_routes = HashSet::new(); + + for route in routes { + let route_name = route.input.name(); + if !seen_routes.insert(route_name.clone()) { + return Err(syn::Error::new_spanned( + &route.input, + format!("duplicate route `{route_name}` in declaration-first `{macro_name}`"), + )); + } + + if !known_states.contains(&route.destination.to_string()) { + return Err(syn::Error::new_spanned( + &route.destination, + format!( + "route destination `{}` is not declared in the `states` section", + route.destination + ), + )); + } + } + + Ok(()) +} + +fn expand_diffeq( + diffeq: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + routes: &[OdeRouteDecl], + route_bindings: &[(SymbolicIndex, usize)], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( + diffeq, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); + let body = &diffeq.body; + let dx_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) + }; + let route_terms = expand_injected_ode_route_terms( + routes, + states, + route_bindings, + &dx_binding, + &bolus, + &rateiv, + ); + + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + #route_terms + }; + __pharmsol_diffeq + }}) +} + +fn resolve_analytical_structure(structure: &Ident) -> syn::Result { + let structure_name = structure.to_string(); + let (runtime_path, metadata_kernel, parameter_arity, state_count) = match structure_name + .as_str() + { + "one_compartment" => ( + quote! { ::pharmsol::equation::one_compartment }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartment }, + 1, + 1, + ), + "one_compartment_cl" => ( + quote! { ::pharmsol::equation::one_compartment_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentCl }, + 2, + 1, + ), + "one_compartment_cl_with_absorption" => ( + quote! { ::pharmsol::equation::one_compartment_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentClWithAbsorption }, + 3, + 2, + ), + "one_compartment_with_absorption" => ( + quote! { ::pharmsol::equation::one_compartment_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentWithAbsorption }, + 2, + 2, + ), + "two_compartments" => ( + quote! { ::pharmsol::equation::two_compartments }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartments }, + 3, + 2, + ), + "two_compartments_cl" => ( + quote! { ::pharmsol::equation::two_compartments_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsCl }, + 4, + 2, + ), + "two_compartments_cl_with_absorption" => ( + quote! { ::pharmsol::equation::two_compartments_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsClWithAbsorption }, + 5, + 3, + ), + "two_compartments_with_absorption" => ( + quote! { ::pharmsol::equation::two_compartments_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsWithAbsorption }, + 4, + 3, + ), + "three_compartments" => ( + quote! { ::pharmsol::equation::three_compartments }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartments }, + 5, + 3, + ), + "three_compartments_cl" => ( + quote! { ::pharmsol::equation::three_compartments_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsCl }, + 6, + 3, + ), + "three_compartments_cl_with_absorption" => ( + quote! { ::pharmsol::equation::three_compartments_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsClWithAbsorption }, + 7, + 4, + ), + "three_compartments_with_absorption" => ( + quote! { ::pharmsol::equation::three_compartments_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsWithAbsorption }, + 6, + 4, + ), + _ => { + return Err(syn::Error::new_spanned( + structure, + format!("unknown analytical structure `{structure_name}`"), + )); + } + }; + + Ok(AnalyticalKernelSpec { + runtime_path, + metadata_kernel, + parameter_arity, + state_count, + }) +} + +fn expand_analytical_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + route_bindings: &[(SymbolicIndex, usize)], +) -> syn::Result { + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "built-in `analytical!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_analytical_sec( + sec: &ExprClosure, + params: &[Ident], + covariates: &[Ident], +) -> syn::Result { + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + sec, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `sec` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|", + )?; + let parameter_vector = if sec.inputs.len() == full_inputs.len() { + closure_param_ident(sec, 0).unwrap_or_else(|| p.clone()) + } else { + p.clone() + }; + let (parameter_bindings, parameter_writebacks) = + generate_mutable_parameter_bindings(params, sec, ¶meter_vector); + let covariate_bindings = generate_covariate_bindings(covariates, sec, &cov, &t); + let body = &sec.body; + + Ok(quote! {{ + let __pharmsol_sec: fn( + &mut ::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) = |#p: &mut ::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #parameter_bindings + #covariate_bindings + #body + #parameter_writebacks + }; + __pharmsol_sec + }}) +} + +fn expand_analytical_init( + init: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_analytical_out( + out: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + outputs: &[SymbolicIndex], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( + out, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_out + }}) +} + +fn expand_sde_drift( + drift: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + routes: &[OdeRouteDecl], + route_bindings: &[(SymbolicIndex, usize)], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( + drift, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `drift` to have either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )?; + let parameter_bindings = generate_parameter_bindings(params, drift, &p); + let covariate_bindings = generate_covariate_bindings(covariates, drift, &cov, &t); + let body = &drift.body; + let dx_binding = if drift.inputs.len() == full_inputs.len() { + closure_param_ident(drift, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(drift, 2).unwrap_or_else(|| dx.clone()) + }; + let rate_terms = + expand_injected_sde_rate_terms(routes, states, route_bindings, &dx_binding, &rateiv); + + Ok(quote! {{ + let __pharmsol_drift: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + #rate_terms + }; + __pharmsol_drift + }}) +} + +fn expand_sde_diffusion( + diffusion: &ExprClosure, + params: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let sigma = generated_ident("__pharmsol_sigma"); + let full_inputs = [p.clone(), sigma.clone()]; + let reduced_inputs = [sigma.clone()]; + let input_aliases = generate_supported_input_aliases( + diffusion, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `diffusion` to have either 2 parameters: |p, sigma| or 1 parameter: |sigma|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffusion, &p); + let body = &diffusion.body; + + Ok(quote! {{ + let __pharmsol_diffusion: fn( + &::pharmsol::simulator::V, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #sigma: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #body + }; + __pharmsol_diffusion + }}) +} + +fn expand_sde_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + route_bindings: &[(SymbolicIndex, usize)], +) -> syn::Result { + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "declaration-first `sde!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_sde_init( + init: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_sde_out( + out: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + outputs: &[SymbolicIndex], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( + out, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_out + }}) +} + +// --------------------------------------------------------------------------- +// Proc macros +// --------------------------------------------------------------------------- + +#[proc_macro] +pub fn ode(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as OdeInput); + + let route_bindings = ode_route_input_bindings(&input.routes); + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `ode!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `ode!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `ode!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `ode!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let diffeq = match expand_diffeq( + &input.diffeq, + &input.params, + &input.covariates, + &input.states, + &input.routes, + &route_bindings, + ) { + Ok(diffeq) => diffeq, + Err(error) => return error.to_compile_error().into(), + }; + + let out = match expand_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let covariates = &input.covariates; + let states = &input.states; + let outputs = &input.outputs; + let routes = expand_route_metadata(&input.routes, &lag_routes, &fa_routes); + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; + + let lag = match input.lag.as_ref() { + Some(closure) => match expand_ode_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_ode_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => { + match expand_ode_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _, _| {} }, + }; + + quote! {{ + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .parameters([#(stringify!(#params)),*]) + #covariate_metadata + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))*; + + ::pharmsol::equation::ODE::new( + #diffeq, + #lag, + #fa, + #init, + #out, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + .with_metadata(__pharmsol_metadata) + .expect("declaration-first `ode!` generated invalid metadata") + }} + .into() +} + +#[proc_macro] +pub fn analytical(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as AnalyticalInput); + let route_bindings = ode_route_input_bindings(&input.routes); + + let kernel_spec = match resolve_analytical_structure(&input.structure) { + Ok(spec) => spec, + Err(error) => return error.to_compile_error().into(), + }; + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "built-in `analytical!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "built-in `analytical!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "built-in `analytical!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "built-in `analytical!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let sec = match input.sec.as_ref() { + Some(closure) => match expand_analytical_sec(closure, &input.params, &input.covariates) { + Ok(sec) => sec, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| {} }, + }; + + let out = match expand_analytical_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let lag = match input.lag.as_ref() { + Some(closure) => { + match expand_analytical_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_analytical_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => { + match expand_analytical_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _, _| {} }, + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let covariates = &input.covariates; + let states = &input.states; + let outputs = &input.outputs; + let routes = expand_analytical_route_metadata(&input.routes, &lag_routes, &fa_routes); + let runtime_path = kernel_spec.runtime_path; + let metadata_kernel = kernel_spec.metadata_kernel; + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; + + quote! {{ + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .kind(::pharmsol::equation::ModelKind::Analytical) + .parameters([#(stringify!(#params)),*]) + #covariate_metadata + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))* + .analytical_kernel(#metadata_kernel); + + ::pharmsol::equation::Analytical::new( + #runtime_path, + #sec, + #lag, + #fa, + #init, + #out, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + .with_metadata(__pharmsol_metadata) + .expect("built-in `analytical!` generated invalid metadata") + }} + .into() +} + +#[proc_macro] +pub fn sde(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as SdeInput); + let route_bindings = ode_route_input_bindings(&input.routes); + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `sde!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `sde!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `sde!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `sde!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let drift = match expand_sde_drift( + &input.drift, + &input.params, + &input.covariates, + &input.states, + &input.routes, + &route_bindings, + ) { + Ok(drift) => drift, + Err(error) => return error.to_compile_error().into(), + }; + + let diffusion = match expand_sde_diffusion(&input.diffusion, &input.params, &input.states) { + Ok(diffusion) => diffusion, + Err(error) => return error.to_compile_error().into(), + }; + + let lag = match input.lag.as_ref() { + Some(closure) => match expand_sde_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_sde_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => { + match expand_sde_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _, _| {} }, + }; + + let out = match expand_sde_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let covariates = &input.covariates; + let states = &input.states; + let outputs = &input.outputs; + let particles = &input.particles; + let routes = expand_sde_route_metadata(&input.routes, &lag_routes, &fa_routes); + let bolus_mappings = + expand_injected_sde_bolus_mappings(&input.routes, &input.states, &route_bindings); + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; + + quote! {{ + let __pharmsol_particles: usize = #particles; + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .kind(::pharmsol::equation::ModelKind::Sde) + .parameters([#(stringify!(#params)),*]) + #covariate_metadata + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))* + .particles(__pharmsol_particles); + + ::pharmsol::equation::SDE::new( + #drift, + #diffusion, + #lag, + #fa, + #init, + #out, + __pharmsol_particles, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + #bolus_mappings + .with_metadata(__pharmsol_metadata) + .expect("declaration-first `sde!` generated invalid metadata") + }} + .into() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_removed_legacy_form() { + let error = syn::parse_str::( + "diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("legacy macro form must fail"); + + assert!(error + .to_string() + .contains("requires `name`, `params`, `states`, `outputs`, and `routes`")); + assert!(error + .to_string() + .contains("old inferred-dimensions form has been removed")); + } + + #[test] + fn validates_route_destinations() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> peripheral], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown route destination must fail"); + + assert!(error + .to_string() + .contains("route destination `peripheral` is not declared in the `states` section")); + } + + #[test] + fn rejects_named_binding_collisions() { + let error = syn::parse_str::( + "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: [infusion(iv) -> central], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("parameter/state binding collisions must fail"); + + assert!(error + .to_string() + .contains("named parameter binding `central` conflicts with named state binding")); + } + + #[test] + fn ode_route_bindings_share_inputs_by_kind_local_ordinal() { + let input = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: [bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot], diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + ) + .expect("declaration-first ode input should parse"); + + let bindings = ode_route_input_bindings(&input.routes); + + assert_eq!(dense_index_len(&bindings), 2); + assert_eq!(bindings[0].0.name(), "oral"); + assert_eq!(bindings[0].1, 0); + assert_eq!(bindings[1].0.name(), "iv"); + assert_eq!(bindings[1].1, 0); + assert_eq!(bindings[2].0.name(), "sc"); + assert_eq!(bindings[2].1, 1); + } + + #[test] + fn generated_parameter_bindings_only_include_referenced_locals_in_hot_closures() { + let params = vec![generated_ident("ke"), generated_ident("v")]; + let closure = syn::parse_str::( + "|x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }", + ) + .expect("closure should parse"); + + let bindings = + generate_parameter_bindings(¶ms, &closure, &generated_ident("__pharmsol_p")) + .to_string(); + + assert!( + bindings.contains("let ke = __pharmsol_p [0usize] ;") + || bindings.contains("let ke = __pharmsol_p [ 0 ] ;") + ); + assert!(!bindings.contains("let v =")); + } + + #[test] + fn generated_parameter_bindings_fall_back_to_all_params_for_stmt_macros() { + let params = vec![generated_ident("ka"), generated_ident("tlag")]; + let closure = syn::parse_str::("|_p, _t, _cov| { lag! { oral => tlag } }") + .expect("closure should parse"); + + let bindings = + generate_parameter_bindings(¶ms, &closure, &generated_ident("__pharmsol_p")) + .to_string(); + + assert!(bindings.contains("let ka =")); + assert!(bindings.contains("let tlag =")); + } + + #[test] + fn analytical_accepts_extra_parameters_beyond_kernel_arity() { + let input = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", + ) + .expect("extra declared parameters should be allowed"); + + assert_eq!(input.params.len(), 5); + assert_eq!(input.covariates.len(), 2); + assert!(input.sec.is_some()); + assert_eq!(input.states.len(), 2); + } + + #[test] + fn analytical_rejects_unknown_structure() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: mystery, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown analytical structure must fail"); + + assert!(error + .to_string() + .contains("unknown analytical structure `mystery`")); + } + + #[test] + fn analytical_rejects_insufficient_kernel_parameters() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("insufficient kernel parameters must fail"); + + assert!(error + .to_string() + .contains("requires at least 2 parameter value(s)")); + } + + #[test] + fn analytical_rejects_unknown_route_property_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown lag route must fail"); + + assert!(error + .to_string() + .contains("route `iv` in `lag!` is not declared in the `routes` section")); + } + + #[test] + fn analytical_rejects_infusion_lag_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("infusion lag must fail"); + + assert!(error + .to_string() + .contains("built-in `analytical!` does not allow `lag` on infusion route `iv`")); + } + + #[test] + fn sde_requires_particles() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: [infusion(iv) -> central], drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("missing particles must fail"); + + assert!(error + .to_string() + .contains("missing required field `particles` in declaration-first `sde!`")); + } + + #[test] + fn sde_rejects_unknown_route_property_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown lag route must fail"); + + assert!(error + .to_string() + .contains("route `oral` in `lag!` is not declared in the `routes` section")); + } + + #[test] + fn sde_rejects_infusion_lag_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("infusion lag must fail"); + + assert!(error + .to_string() + .contains("declaration-first `sde!` does not allow `lag` on infusion route `iv`")); + } + + #[test] + fn rejects_braced_route_lists() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("braced route lists must fail"); + + assert!(error + .to_string() + .contains("declaration-first macro `routes` must use `[...]`, not `{...}`")); } - .into() } diff --git a/src/data/builder.rs b/src/data/builder.rs index 18aa17fe..ed0a57a8 100644 --- a/src/data/builder.rs +++ b/src/data/builder.rs @@ -1,6 +1,21 @@ +//! Builder API for constructing [`Subject`] schedules in Rust. +//! +//! Use `Subject::builder(...)` when you want to describe a subject directly in +//! code with a schedule-oriented API. This is the preferred high-level +//! path for hand-written datasets. +//! +//! Builder methods accept public input and output labels. Prefer stable strings +//! such as `"depot"`, `"iv"`, and `"cp"`. Numeric values are accepted, but +//! they remain public labels rather than automatically becoming dense internal +//! indices. + use crate::{data::*, Censor}; -/// Extension trait for creating [Subject] instances using the builder pattern +/// Extension trait that enables `Subject::builder(...)`. +/// +/// Most users do not need to import [`SubjectBuilder`] directly. Import this +/// trait from the crate root or [`crate::prelude`] and then start with +/// `Subject::builder("id")`. pub trait SubjectBuilderExt { /// Create a new SubjectBuilder with the specified ID /// @@ -14,8 +29,8 @@ pub trait SubjectBuilderExt { /// use pharmsol::*; /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) - /// .observation(1.0, 10.5, 0) + /// .bolus(0.0, 100.0, "depot") + /// .observation(1.0, 10.5, "cp") /// .build(); /// ``` fn builder(id: impl Into) -> SubjectBuilder; @@ -34,11 +49,37 @@ impl SubjectBuilderExt for Subject { } } -/// Builder for creating [Subject] instances with a fluent API +/// Builder for creating [`Subject`] values with a fluent API. +/// +/// Use [`SubjectBuilder`] when you want to author common dose and observation +/// schedules directly in Rust without constructing low-level event values by +/// hand. +/// +/// A builder instance accumulates events inside the current [`Occasion`]. +/// [`SubjectBuilder::repeat`] duplicates the most recently added event at later +/// times, and [`SubjectBuilder::reset`] closes the current occasion and starts a +/// new one with fresh occasion-local state. +/// +/// Input and output arguments are public labels. Prefer stable model-facing +/// names such as `"depot"`, `"iv"`, and `"cp"`. +/// +/// # Example +/// +/// ```rust +/// use pharmsol::*; +/// +/// let subject = Subject::builder("patient_001") +/// .bolus(0.0, 100.0, "depot") +/// .repeat(1, 24.0) +/// .observation(1.0, 12.3, "cp") +/// .missing_observation(25.0, "cp") +/// .reset() +/// .bolus(0.0, 80.0, "depot") +/// .observation(1.0, 10.1, "cp") +/// .build(); /// -/// The [SubjectBuilder] allows for constructing complex subject data with a -/// chainable, readable syntax. Events like doses and observations can be -/// added sequentially, and the builder handles organizing them into occasions. +/// assert_eq!(subject.occasions().len(), 2); +/// ``` #[derive(Debug, Clone)] pub struct SubjectBuilder { id: String, @@ -49,52 +90,54 @@ pub struct SubjectBuilder { } impl SubjectBuilder { - /// Add an event to the current occasion + /// Add a fully constructed event to the current occasion. /// - /// # Arguments - /// - /// * `event` - The event to add + /// Use this when you want to mix builder convenience methods with direct + /// [`Event`] values. pub fn event(mut self, event: Event) -> Self { self.last_added_event = Some(event.clone()); self.current_occasion.add_event(event); self } - /// Add a bolus dosing event + /// Add an instantaneous dose. /// /// # Arguments /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose - pub fn bolus(self, time: f64, amount: f64, input: usize) -> Self { + /// * `input` - Public input label receiving the dose + /// + /// Prefer stable route names such as `"depot"` or `"iv"` when the model + /// declares named routes. + pub fn bolus(self, time: f64, amount: f64, input: impl ToString) -> Self { let bolus = Bolus::new(time, amount, input, self.current_occasion.index()); let event = Event::Bolus(bolus); self.event(event) } - /// Add an infusion event + /// Add a continuous dose over a duration. /// /// # Arguments /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - Public input label receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn infusion(self, time: f64, amount: f64, input: usize, duration: f64) -> Self { + pub fn infusion(self, time: f64, amount: f64, input: impl ToString, duration: f64) -> Self { let infusion = Infusion::new(time, amount, input, duration, self.current_occasion.index()); let event = Event::Infusion(infusion); self.event(event) } - /// Add an observation + /// Add an observed value at a given time. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation - pub fn observation(self, time: f64, value: f64, outeq: usize) -> Self { + /// * `outeq` - Public output label for this observation + pub fn observation(self, time: f64, value: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, Some(value), @@ -107,18 +150,19 @@ impl SubjectBuilder { self.event(event) } - /// Add a censored observation + /// Add an observed value with explicit censoring information. + /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this - /// observation + /// * `outeq` - Public output label for this observation + /// * `censoring` - Censoring status for the observation value pub fn censored_observation( self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, censoring: Censor, ) -> Self { let observation = Observation::new( @@ -133,13 +177,16 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation + /// Add a prediction-only observation slot. /// /// # Arguments /// /// * `time` - Time of the observation - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation - pub fn missing_observation(self, time: f64, outeq: usize) -> Self { + /// * `outeq` - Public output label for this observation + /// + /// Use this when you want a prediction at a time point but do not have an + /// observed value. + pub fn missing_observation(self, time: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, None, @@ -152,20 +199,20 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation with a specific error polynomial + /// Add an observed value with an explicit assay error polynomial. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation + /// * `outeq` - Public output label for this observation /// * `errorpoly` - Error polynomial coefficients (c0, c1, c2, c3) - /// * `censored` - Whether the observation is censored + /// * `censored` - Censoring status for the observation value pub fn observation_with_error( self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) -> Self { @@ -181,7 +228,10 @@ impl SubjectBuilder { self.event(event) } - /// Repeat the last event `n` times, separated by some interval `delta` + /// Repeat the last event `n` times, separated by `delta`. + /// + /// The repeated events keep the same label, value, censoring state, and + /// error polynomial as the original event. Only the event time changes. /// /// # Arguments /// @@ -193,9 +243,8 @@ impl SubjectBuilder { /// ```rust /// use pharmsol::*; /// - /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) // First dose at time 0 + /// .bolus(0.0, 100.0, "depot") // First dose at time 0 /// .repeat(3, 24.0) // Repeat the dose at times 24, 48, and 72 /// .build(); /// ``` @@ -255,12 +304,14 @@ impl SubjectBuilder { self } - /// Complete the current occasion and start a new one + /// Complete the current occasion and start a new one. /// /// This finalizes the current occasion, adds it to the subject, /// and creates a new occasion for subsequent events. - /// This is useful if a patient has new observations at some other occasion. - /// Note that all states are reset! + /// Use this when the subject should begin a new occasion with reset state. + /// + /// Covariates collected since the previous reset are attached to the + /// finished occasion. The new occasion starts empty and its state is reset. pub fn reset(mut self) -> Self { let block_index = self.current_occasion.index() + 1; self.current_occasion.sort(); @@ -274,7 +325,7 @@ impl SubjectBuilder { self } - /// Add a covariate value at a specific time + /// Add a covariate value at a specific time. /// /// Multiple calls for the same covariate at different times will create /// linear interpolation between the time points. @@ -300,7 +351,7 @@ impl SubjectBuilder { self } - /// Finalize and build the Subject + /// Finalize and build the [`Subject`]. /// /// This completes the current occasion and returns a new Subject with all /// the accumulated data. diff --git a/src/data/error_model.rs b/src/data/error_model.rs index 609cad9e..0d52932f 100644 --- a/src/data/error_model.rs +++ b/src/data/error_model.rs @@ -1,6 +1,9 @@ -use std::hash::{Hash, Hasher}; +use std::{ + collections::BTreeMap, + hash::{Hash, Hasher}, +}; -use crate::simulator::likelihood::Prediction; +use crate::{data::event::OutputLabel, simulator::likelihood::Prediction}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -120,7 +123,11 @@ impl ErrorPoly { impl From> for AssayErrorModels { fn from(models: Vec) -> Self { - Self { models } + Self { + models, + output_lookup: BTreeMap::new(), + named_models: BTreeMap::new(), + } } } @@ -140,6 +147,8 @@ impl From> for AssayErrorModels { #[derive(Serialize, Debug, Clone, Deserialize)] pub struct AssayErrorModels { models: Vec, + output_lookup: BTreeMap, + named_models: BTreeMap, } /// Deprecated alias for [`AssayErrorModels`]. @@ -159,12 +168,149 @@ impl Default for AssayErrorModels { } impl AssayErrorModels { - /// Create a new instance of [`AssayErrorModels`] + /// Create a new reusable label-first [`AssayErrorModels`] definition. /// - /// # Returns - /// A new instance of [AssayErrorModels]. + /// Output labels are resolved once per equation when the error models are + /// used through simulation or likelihood entrypoints. + /// + /// This lets the same public definition be reused safely across multiple + /// equations while keeping the dense bound representation internal to the + /// runtime path. + /// + /// ```rust + /// # use pharmsol::prelude::*; + /// let error_models = AssayErrorModels::new() + /// .add("cp", AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0))?; + /// # Ok::<(), pharmsol::data::error_model::ErrorModelError>(()) + /// ``` pub fn new() -> Self { - Self { models: vec![] } + Self::empty() + } + + pub(crate) fn assert_compatible_output_names( + &self, + outputs: I, + ) -> Result<(), ErrorModelError> + where + I: IntoIterator, + S: AsRef, + { + if self.output_lookup.is_empty() { + return Ok(()); + } + + let expected = self.bound_output_names(); + let found = outputs + .into_iter() + .map(|output| output.as_ref().to_string()) + .collect::>(); + if expected == found { + return Ok(()); + } + + Err(ErrorModelError::IncompatibleOutputContext { expected, found }) + } + + pub(crate) fn bind_to(&self, context: &impl crate::Equation) -> Result { + self.bind_output_names(context.assay_error_models().bound_output_names()) + } + + pub(crate) fn bind_output_names(&self, outputs: I) -> Result + where + I: IntoIterator, + S: AsRef, + { + let outputs = outputs + .into_iter() + .map(|output| output.as_ref().to_string()) + .collect::>(); + + if !self.output_lookup.is_empty() { + self.assert_compatible_output_names(outputs.iter().map(String::as_str))?; + return Ok(self.clone()); + } + + if self.named_models.is_empty() { + return Ok(self.clone()); + } + + let mut bound = Self::with_output_names(outputs.iter().map(String::as_str)); + bound.models = self.models.clone(); + + for (label, model) in &self.named_models { + bound = bound.add(label.clone(), model.clone())?; + } + + Ok(bound) + } + + /// Create an unbound error-model set for dense-slot callers. + /// + /// This keeps the pre-existing numeric-slot setup path available for low-level + /// tests or workflows that deliberately operate on dense output indices. + pub(crate) fn empty() -> Self { + Self { + models: vec![], + output_lookup: BTreeMap::new(), + named_models: BTreeMap::new(), + } + } + + /// Create an error-model set with output labels resolved up front. + /// + /// This is the label-aware constructor for public workflows. It binds names + /// to dense output slots once during setup so that likelihood evaluation can + /// keep using direct vector indexing with no additional runtime lookup cost. + pub(crate) fn with_output_names(outputs: I) -> Self + where + I: IntoIterator, + S: AsRef, + { + let output_lookup = outputs + .into_iter() + .enumerate() + .map(|(index, output)| (OutputLabel::new(output.as_ref()), index)) + .collect(); + + Self { + models: vec![], + output_lookup, + named_models: BTreeMap::new(), + } + } + + fn bound_output_names(&self) -> Vec { + let mut names = self + .output_lookup + .iter() + .map(|(label, index)| (*index, label.to_string())) + .collect::>(); + names.sort_by_key(|(index, _)| *index); + names.into_iter().map(|(_, label)| label).collect() + } + + fn resolve_output_binding(&self, outeq: impl ToString) -> Result { + let label = OutputLabel::new(outeq); + self.output_lookup + .get(&label) + .copied() + .or_else(|| label.index()) + .ok_or_else(|| ErrorModelError::UnknownOutputLabel(label.to_string())) + } + + fn insert_model_at( + &mut self, + outeq: usize, + model: AssayErrorModel, + ) -> Result<(), ErrorModelError> { + if outeq >= self.models.len() { + self.models.resize(outeq + 1, AssayErrorModel::None); + } + if self.models[outeq] != AssayErrorModel::None { + return Err(ErrorModelError::ExistingOutputEquation(outeq)); + } + self.models[outeq] = model; + Ok(()) } /// Get the error model for a specific output equation @@ -182,22 +328,36 @@ impl AssayErrorModels { Ok(&self.models[outeq]) } - /// Add a new error model for a specific output equation + /// Add a new error model for a specific output equation or declared label. /// # Arguments - /// * `outeq` - The index of the output equation for which to add the error model. + /// * `outeq` - The output slot index or public output label. /// * `model` - The [AssayErrorModel] to add for the specified output equation. /// # Returns /// A new instance of AssayErrorModels with the added model. /// # Errors - /// If the output equation index is invalid or if a model already exists for that output equation, an [ErrorModelError::ExistingOutputEquation] is returned. - pub fn add(mut self, outeq: usize, model: AssayErrorModel) -> Result { - if outeq >= self.models.len() { - self.models.resize(outeq + 1, AssayErrorModel::None); + /// If the output label is unknown or if a model already exists for that output equation, an error is returned. + pub fn add( + mut self, + outeq: impl ToString, + model: AssayErrorModel, + ) -> Result { + let label = OutputLabel::new(outeq); + + if !self.output_lookup.is_empty() { + let outeq = self.resolve_output_binding(label.clone())?; + self.insert_model_at(outeq, model)?; + return Ok(self); } - if self.models[outeq] != AssayErrorModel::None { - return Err(ErrorModelError::ExistingOutputEquation(outeq)); + + if let Some(outeq) = label.index() { + self.insert_model_at(outeq, model)?; + return Ok(self); } - self.models[outeq] = model; + + if self.named_models.contains_key(&label) { + return Err(ErrorModelError::ExistingOutputLabel(label.to_string())); + } + self.named_models.insert(label, model); Ok(self) } /// Returns an iterator over the error models in the collection. @@ -222,6 +382,27 @@ impl AssayErrorModels { pub fn hash(&self) -> u64 { let mut hasher = ahash::AHasher::default(); + for (label, model) in &self.named_models { + 3u8.hash(&mut hasher); + label.hash(&mut hasher); + + match model { + AssayErrorModel::Additive { lambda, .. } => { + 0u8.hash(&mut hasher); + lambda.value().to_bits().hash(&mut hasher); + lambda.is_fixed().hash(&mut hasher); + } + AssayErrorModel::Proportional { gamma, .. } => { + 1u8.hash(&mut hasher); + gamma.value().to_bits().hash(&mut hasher); + gamma.is_fixed().hash(&mut hasher); + } + AssayErrorModel::None => { + 2u8.hash(&mut hasher); + } + } + } + for outeq in 0..self.models.len() { // Find the model with the matching outeq ID @@ -249,12 +430,16 @@ impl AssayErrorModels { } /// Returns the number of error models in the collection. pub fn len(&self) -> usize { + if self.models.is_empty() && !self.named_models.is_empty() && self.output_lookup.is_empty() + { + return self.named_models.len(); + } self.models.len() } /// Returns whether the collection contains no error models. pub fn is_empty(&self) -> bool { - self.models.is_empty() + self.models.is_empty() && self.named_models.is_empty() } /// Returns the error polynomial associated with the specified output equation. @@ -943,8 +1128,19 @@ pub enum ErrorModelError { NonFiniteSigma, #[error("The output equation index {0} is invalid")] InvalidOutputEquation(usize), + #[error("The output label `{0}` is not declared in this error model context")] + UnknownOutputLabel(String), + #[error("The output label `{0}` already exists in this assay error model specification")] + ExistingOutputLabel(String), #[error("The output equation number {0} already exists")] ExistingOutputEquation(usize), + #[error( + "Assay error models were bound for outputs {expected:?} but used with outputs {found:?}" + )] + IncompatibleOutputContext { + expected: Vec, + found: Vec, + }, #[error("An output equation does not have an error model defined")] MissingErrorModel, #[error("The output equation index {0} is of type ErrorModel::None")] @@ -1029,7 +1225,7 @@ mod tests { #[test] fn test_error_models_add_single() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!(models.len(), 1); } @@ -1038,7 +1234,7 @@ mod tests { let model1 = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add(0, model1) .unwrap() .add(1, model2) @@ -1047,12 +1243,101 @@ mod tests { assert_eq!(models.len(), 2); } + #[test] + fn test_error_models_add_label_with_output_names() { + let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); + let models = AssayErrorModels::with_output_names(["cp", "effect"]) + .add("effect", model) + .unwrap(); + + assert_eq!(models.len(), 2); + assert!(models.error_model(1).is_ok()); + } + + #[test] + fn test_error_models_bind_output_names() { + let error_models = AssayErrorModels::new() + .add( + "effect", + AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), + ) + .unwrap(); + + let models = error_models.bind_output_names(["cp", "effect"]).unwrap(); + assert_eq!(models.len(), 2); + assert!(models.error_model(1).is_ok()); + } + + #[test] + fn test_error_models_add_unknown_label_fails() { + let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); + let result = AssayErrorModels::with_output_names(["cp"]).add("effect", model); + + assert!(result.is_err()); + match result { + Err(ErrorModelError::UnknownOutputLabel(label)) => assert_eq!(label, "effect"), + _ => panic!("Expected UnknownOutputLabel error"), + } + } + + #[test] + fn test_error_models_duplicate_label_fails() { + let result = AssayErrorModels::new() + .add( + "cp", + AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), + ) + .unwrap() + .add( + "cp", + AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0), + ); + + match result { + Err(ErrorModelError::ExistingOutputLabel(label)) => assert_eq!(label, "cp"), + _ => panic!("Expected ExistingOutputLabel error"), + } + } + + #[test] + fn test_bound_error_models_reject_mismatched_output_context() { + let error_models = AssayErrorModels::new() + .add( + "cp", + AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), + ) + .unwrap() + .bind_output_names(["cp", "effect"]) + .unwrap(); + + match error_models.assert_compatible_output_names(["effect", "cp"]) { + Err(ErrorModelError::IncompatibleOutputContext { expected, found }) => { + assert_eq!(expected, vec!["cp".to_string(), "effect".to_string()]); + assert_eq!(found, vec!["effect".to_string(), "cp".to_string()]); + } + _ => panic!("Expected IncompatibleOutputContext error"), + } + } + + #[test] + fn test_error_models_sigma_from_label_bound_output() { + let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); + let models = AssayErrorModels::with_output_names(["cp"]) + .add("cp", model) + .unwrap(); + + let observation = Observation::new(0.0, Some(20.0), 0, None, 0, Censor::None); + let prediction = observation.to_prediction(10.0, vec![]); + + assert_eq!(models.sigma(&prediction).unwrap(), (26.0_f64).sqrt()); + } + #[test] fn test_error_models_add_duplicate_outeq_fails() { let model1 = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); - let result = AssayErrorModels::new() + let result = AssayErrorModels::empty() .add(0, model1) .unwrap() .add(0, model2); // Same outeq should fail @@ -1067,7 +1352,7 @@ mod tests { #[test] fn test_error_models_factor() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!(models.factor(0).unwrap(), 5.0); } @@ -1075,7 +1360,7 @@ mod tests { #[test] fn test_error_models_factor_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.factor(1); assert!(result.is_err()); @@ -1088,7 +1373,7 @@ mod tests { #[test] fn test_error_models_set_factor() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!(models.factor(0).unwrap(), 5.0); models.set_factor(0, 10.0).unwrap(); @@ -1098,7 +1383,7 @@ mod tests { #[test] fn test_error_models_set_factor_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.set_factor(1, 10.0); assert!(result.is_err()); @@ -1112,7 +1397,7 @@ mod tests { fn test_error_models_errorpoly() { let poly = ErrorPoly::new(1.0, 2.0, 3.0, 4.0); let model = AssayErrorModel::additive(poly, 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let retrieved_poly = models.errorpoly(0).unwrap(); assert_eq!(retrieved_poly.coefficients(), (1.0, 2.0, 3.0, 4.0)); @@ -1121,7 +1406,7 @@ mod tests { #[test] fn test_error_models_errorpoly_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.errorpoly(1); assert!(result.is_err()); @@ -1136,7 +1421,7 @@ mod tests { let poly1 = ErrorPoly::new(1.0, 2.0, 3.0, 4.0); let poly2 = ErrorPoly::new(5.0, 6.0, 7.0, 8.0); let model = AssayErrorModel::additive(poly1, 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!( models.errorpoly(0).unwrap().coefficients(), @@ -1152,7 +1437,7 @@ mod tests { #[test] fn test_error_models_set_errorpoly_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.set_errorpoly(1, ErrorPoly::new(5.0, 6.0, 7.0, 8.0)); assert!(result.is_err()); @@ -1165,7 +1450,7 @@ mod tests { #[test] fn test_error_models_sigma() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 0, None, 0, Censor::None); let prediction = observation.to_prediction(10.0, vec![]); @@ -1178,7 +1463,7 @@ mod tests { #[test] fn test_error_models_sigma_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 1, None, 0, Censor::None); // outeq=1 not in models let prediction = observation.to_prediction(10.0, vec![]); @@ -1194,7 +1479,7 @@ mod tests { #[test] fn test_error_models_variance() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 0, None, 0, Censor::None); let prediction = observation.to_prediction(10.0, vec![]); @@ -1207,7 +1492,7 @@ mod tests { #[test] fn test_error_models_variance_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 1, None, 0, Censor::None); // outeq=1 not in models let prediction = observation.to_prediction(10.0, vec![]); @@ -1223,7 +1508,7 @@ mod tests { #[test] fn test_error_models_sigma_from_value() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let sigma = models.sigma_from_value(0, 20.0).unwrap(); assert_eq!(sigma, (26.0_f64).sqrt()); @@ -1232,7 +1517,7 @@ mod tests { #[test] fn test_error_models_sigma_from_value_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.sigma_from_value(1, 20.0); assert!(result.is_err()); @@ -1245,7 +1530,7 @@ mod tests { #[test] fn test_error_models_variance_from_value() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let variance = models.variance_from_value(0, 20.0).unwrap(); let expected_sigma = (26.0_f64).sqrt(); @@ -1255,7 +1540,7 @@ mod tests { #[test] fn test_error_models_variance_from_value_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.variance_from_value(1, 20.0); assert!(result.is_err()); @@ -1270,13 +1555,13 @@ mod tests { let model1 = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); - let models1 = AssayErrorModels::new() + let models1 = AssayErrorModels::empty() .add(0, model1.clone()) .unwrap() .add(1, model2.clone()) .unwrap(); - let models2 = AssayErrorModels::new() + let models2 = AssayErrorModels::empty() .add(0, model1) .unwrap() .add(1, model2) @@ -1292,13 +1577,13 @@ mod tests { let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); // Add in different orders - let models1 = AssayErrorModels::new() + let models1 = AssayErrorModels::empty() .add(0, model1.clone()) .unwrap() .add(1, model2.clone()) .unwrap(); - let models2 = AssayErrorModels::new() + let models2 = AssayErrorModels::empty() .add(1, model2) .unwrap() .add(0, model1) @@ -1314,7 +1599,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.1); - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1343,7 +1628,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 2.0); - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1463,7 +1748,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 2.0); - let mut models = AssayErrorModels::new() + let mut models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1523,8 +1808,8 @@ mod tests { let model1_variable = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model1_fixed = AssayErrorModel::additive_fixed(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models1 = AssayErrorModels::new().add(0, model1_variable).unwrap(); - let models2 = AssayErrorModels::new().add(0, model1_fixed).unwrap(); + let models1 = AssayErrorModels::empty().add(0, model1_variable).unwrap(); + let models2 = AssayErrorModels::empty().add(0, model1_fixed).unwrap(); // Different fixed/variable states should produce different hashes assert_ne!(models1.hash(), models2.hash()); @@ -1536,7 +1821,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 2.0); - let mut models = AssayErrorModels::new() + let mut models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1610,7 +1895,7 @@ mod tests { #[test] fn error_model_hash_deterministic() { - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), @@ -1621,13 +1906,13 @@ mod tests { #[test] fn error_model_hash_differs_on_value() { - let a = AssayErrorModels::new() + let a = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), ) .unwrap(); - let b = AssayErrorModels::new() + let b = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 10.0), @@ -1638,13 +1923,13 @@ mod tests { #[test] fn error_model_hash_differs_on_type() { - let a = AssayErrorModels::new() + let a = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), ) .unwrap(); - let b = AssayErrorModels::new() + let b = AssayErrorModels::empty() .add( 0, AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), diff --git a/src/data/event.rs b/src/data/event.rs index ff88e097..d0f96f7e 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -1,3 +1,15 @@ +//! Event types and public label wrappers for subject schedules. +//! +//! These types are the low-level representation behind the higher-level +//! builder and parsing APIs. Most users can start with +//! [`crate::data::builder::SubjectBuilder`], then inspect or transform +//! [`Event`] values after construction. +//! +//! Dose events carry an [`InputLabel`], and observations carry an +//! [`OutputLabel`]. Prefer stable strings such as `"depot"`, `"iv"`, and +//! `"cp"`. Numeric values are accepted, but they remain labels until a +//! downstream workflow explicitly interprets them as indices. + use crate::data::error_model::ErrorPoly; use crate::prelude::simulator::Prediction; use serde::{Deserialize, Serialize}; @@ -7,12 +19,16 @@ use std::fmt; // Shared Analysis Types // ============================================================================ -/// Administration route for a dosing event +/// Administration route classification used by downstream analyses. +/// +/// [`Route`] is a coarse route category, not the original public input label. +/// In the current data-side heuristic: +/// - [`Event::Infusion`] maps to [`Route::IVInfusion`] +/// - [`Event::Bolus`] with input label `0` maps to [`Route::Extravascular`] +/// - [`Event::Bolus`] with any other label maps to [`Route::IVBolus`] /// -/// Determined by the type of dose events and their target compartment: -/// - [`Event::Infusion`] → [`Route::IVInfusion`] -/// - [`Event::Bolus`] with `input >= 1` (central compartment) → [`Route::IVBolus`] -/// - [`Event::Bolus`] with `input == 0` (depot compartment) → [`Route::Extravascular`] +/// If you need the original model-facing label, read [`Bolus::input`] or +/// [`Infusion::input`] instead. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum Route { /// Intravenous bolus @@ -78,12 +94,15 @@ pub enum BLQRule { }, } -/// Represents a pharmacokinetic/pharmacodynamic event +/// One scheduled item in a subject record. +/// +/// Events are the low-level representation for doses and observations: +/// - [`Bolus`] for instantaneous input +/// - [`Infusion`] for input over a duration +/// - [`Observation`] for measured or missing outputs /// -/// Events represent key occurrences in a PK/PD profile, including: -/// - [Bolus] doses (instantaneous drug input) -/// - [Infusion]s (continuous drug input over a duration) -/// - [Observation]s (measured concentrations or other values) +/// Most users create these through `Subject::builder(...)`, row ingestion, or +/// file parsing rather than constructing them all by hand. #[derive(Serialize, Debug, Clone, Deserialize)] pub enum Event { /// A bolus dose (instantaneous drug input) @@ -93,6 +112,172 @@ pub enum Event { /// An observation of drug concentration or other measure Observation(Observation), } + +/// Public label for a dosing input or route. +/// +/// [`Bolus`] and [`Infusion`] store the original user-facing route name in +/// this type. +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct InputLabel(String); + +impl InputLabel { + /// Create a new public label. + /// + /// Prefer stable names when the model declares named routes. + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } + + /// Borrow the stored label as a string. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Try to interpret the label as a numeric index. + /// + /// This is mainly a compatibility helper for lower-level paths that still + /// operate on dense indices after label resolution. + pub fn index(&self) -> Option { + self.0.parse::().ok() + } +} + +impl From for InputLabel { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for InputLabel { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl From for InputLabel { + fn from(value: usize) -> Self { + Self(value.to_string()) + } +} + +impl AsRef for InputLabel { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl fmt::Display for InputLabel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl PartialEq for InputLabel { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } +} + +impl PartialEq for usize { + fn eq(&self, other: &InputLabel) -> bool { + other == self + } +} + +impl PartialEq for &InputLabel { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } +} + +impl PartialEq<&InputLabel> for usize { + fn eq(&self, other: &&InputLabel) -> bool { + other.eq(self) + } +} + +/// Public label for an observation output. +/// +/// [`Observation`] stores the original user-facing output name in this type. +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct OutputLabel(String); + +impl OutputLabel { + /// Create a new public label. + /// + /// Prefer stable names when the model declares named outputs. + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } + + /// Borrow the stored label as a string. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Try to interpret the label as a numeric index. + /// + /// This is mainly a compatibility helper for lower-level paths that still + /// operate on dense indices after label resolution. + pub fn index(&self) -> Option { + self.0.parse::().ok() + } +} + +impl From for OutputLabel { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for OutputLabel { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl From for OutputLabel { + fn from(value: usize) -> Self { + Self(value.to_string()) + } +} + +impl AsRef for OutputLabel { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl fmt::Display for OutputLabel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl PartialEq for OutputLabel { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } +} + +impl PartialEq for usize { + fn eq(&self, other: &OutputLabel) -> bool { + other == self + } +} + +impl PartialEq for &OutputLabel { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } +} + +impl PartialEq<&OutputLabel> for usize { + fn eq(&self, other: &&OutputLabel) -> bool { + other.eq(self) + } +} + impl Event { /// Get the time of the event pub fn time(&self) -> f64 { @@ -145,14 +330,15 @@ impl Event { } } -/// Represents an instantaneous input of drug +/// Instantaneous dose input. /// -/// A [Bolus] is a discrete amount of drug added to a specific compartment at a specific time. +/// A [`Bolus`] records one discrete amount at one time, tagged with the public +/// input label that should be matched against the model. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Bolus { time: f64, amount: f64, - input: usize, + input: InputLabel, occasion: usize, } impl Bolus { @@ -162,12 +348,12 @@ impl Bolus { /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose - pub fn new(time: f64, amount: f64, input: usize, occasion: usize) -> Self { + /// * `input` - The route label receiving the dose + pub fn new(time: f64, amount: f64, input: impl ToString, occasion: usize) -> Self { Bolus { time, amount, - input, + input: InputLabel::new(input), occasion, } } @@ -177,9 +363,16 @@ impl Bolus { self.amount } - /// Get the compartment number that receives the bolus - pub fn input(&self) -> usize { - self.input + /// Get the route label that receives the bolus + pub fn input(&self) -> &InputLabel { + &self.input + } + + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Bolus::input`] when working with the public label itself. + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the time of the bolus administration @@ -192,9 +385,9 @@ impl Bolus { self.amount = amount; } - /// Set the compartment number that receives the bolus - pub fn set_input(&mut self, input: usize) { - self.input = input; + /// Set the route label that receives the bolus + pub fn set_input(&mut self, input: impl ToString) { + self.input = InputLabel::new(input); } /// Set the time of the bolus administration @@ -207,8 +400,8 @@ impl Bolus { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the bolus - pub fn mut_input(&mut self) -> &mut usize { + /// Get a mutable reference to the route label that receives the bolus + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -228,14 +421,15 @@ impl Bolus { } } -/// Represents a continuous dose of drug over time +/// Continuous dose input over a duration. /// -/// An [Infusion] administers drug at a constant rate over a specified duration. +/// An [`Infusion`] records the total amount, start time, duration, and public +/// input label for one infusion event. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Infusion { time: f64, amount: f64, - input: usize, + input: InputLabel, duration: f64, occasion: usize, } @@ -246,13 +440,19 @@ impl Infusion { /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - The route label receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn new(time: f64, amount: f64, input: usize, duration: f64, occasion: usize) -> Self { + pub fn new( + time: f64, + amount: f64, + input: impl ToString, + duration: f64, + occasion: usize, + ) -> Self { Infusion { time, amount, - input, + input: InputLabel::new(input), duration, occasion, } @@ -263,9 +463,16 @@ impl Infusion { self.amount } - /// Get the compartment number that receives the infusion - pub fn input(&self) -> usize { - self.input + /// Get the route label that receives the infusion + pub fn input(&self) -> &InputLabel { + &self.input + } + + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Infusion::input`] when working with the public label itself. + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the duration of the infusion @@ -285,9 +492,9 @@ impl Infusion { self.amount = amount; } - /// Set the compartment number that receives the infusion - pub fn set_input(&mut self, input: usize) { - self.input = input; + /// Set the route label that receives the infusion + pub fn set_input(&mut self, input: impl ToString) { + self.input = InputLabel::new(input); } /// Set the time of the infusion administration @@ -305,8 +512,8 @@ impl Infusion { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the infusion - pub fn mut_input(&mut self) -> &mut usize { + /// Get a mutable reference to the route label that receives the infusion + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -343,12 +550,16 @@ pub enum Censor { ALOQ, } -/// Represents an observation of drug concentration or other measured value +/// Observation of a model output. +/// +/// An [`Observation`] can carry a measured value or `None` for a prediction-only +/// time point. Observations also carry the public output label, optional assay +/// error polynomial, occasion index, and censoring state. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Observation { time: f64, value: Option, - outeq: usize, + outeq: OutputLabel, errorpoly: Option, occasion: usize, censoring: Censor, @@ -360,14 +571,14 @@ impl Observation { /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation + /// * `outeq` - Output label corresponding to this observation /// * `errorpoly` - Optional error polynomial coefficients (c0, c1, c2, c3) /// * `occasion` - Occasion index /// * `censoring` - Censoring type for this observation pub(crate) fn new( time: f64, value: Option, - outeq: usize, + outeq: impl ToString, errorpoly: Option, occasion: usize, censoring: Censor, @@ -375,7 +586,7 @@ impl Observation { Observation { time, value, - outeq, + outeq: OutputLabel::new(outeq), errorpoly, occasion, censoring, @@ -387,14 +598,23 @@ impl Observation { self.time } - /// Get the value of the observation (e.g., drug concentration) + /// Get the value of the observation. + /// + /// `None` means this is a prediction-only or missing-observation slot. pub fn value(&self) -> Option { self.value } - /// Get the output equation number corresponding to this observation - pub fn outeq(&self) -> usize { - self.outeq + /// Get the output label corresponding to this observation + pub fn outeq(&self) -> &OutputLabel { + &self.outeq + } + + /// Try to interpret the output label as a numeric index. + /// + /// Prefer [`Observation::outeq`] when working with the public label itself. + pub fn outeq_index(&self) -> Option { + self.outeq.index() } /// Get the error polynomial coefficients (c0, c1, c2, c3) if available @@ -414,9 +634,9 @@ impl Observation { self.value = value; } - /// Set the output equation number corresponding to this observation - pub fn set_outeq(&mut self, outeq: usize) { - self.outeq = outeq; + /// Set the output label corresponding to this observation + pub fn set_outeq(&mut self, outeq: impl ToString) { + self.outeq = OutputLabel::new(outeq); } /// Set the [ErrorPoly] for this observation @@ -434,8 +654,8 @@ impl Observation { &mut self.value } - /// Get a mutable reference to the output equation number - pub fn mut_outeq(&mut self) -> &mut usize { + /// Get a mutable reference to the output label + pub fn mut_outeq(&mut self) -> &mut OutputLabel { &mut self.outeq } @@ -454,13 +674,19 @@ impl Observation { &mut self.occasion } - /// Create a [Prediction] from this observation + /// Create a [`Prediction`] from this observation. + /// + /// This is a low-level helper for code paths that already operate on a + /// resolved or numeric output index. Named output labels must be resolved by + /// the caller before this conversion happens. pub fn to_prediction(&self, pred: f64, state: Vec) -> Prediction { Prediction { time: self.time(), observation: self.value(), prediction: pred, - outeq: self.outeq(), + outeq: self + .outeq_index() + .expect("prediction requires a resolved or numeric output label"), errorpoly: self.errorpoly(), state, occasion: self.occasion(), @@ -539,6 +765,7 @@ mod tests { assert_eq!(bolus.time(), 2.5); assert_eq!(bolus.amount(), 100.0); assert_eq!(bolus.input(), 1); + assert_eq!(bolus.input().as_str(), "1"); } #[test] @@ -561,6 +788,7 @@ mod tests { assert_eq!(infusion.time(), 1.0); assert_eq!(infusion.amount(), 200.0); assert_eq!(infusion.input(), 1); + assert_eq!(infusion.input().as_str(), "1"); assert_eq!(infusion.duration(), 2.5); } @@ -589,6 +817,7 @@ mod tests { assert_eq!(observation.time(), 5.0); assert_eq!(observation.value(), Some(75.5)); assert_eq!(observation.outeq(), 2); + assert_eq!(observation.outeq().as_str(), "2"); assert_eq!(observation.errorpoly(), error_poly); } diff --git a/src/data/mod.rs b/src/data/mod.rs index 996c791d..28a80b32 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,32 +1,83 @@ -//! Data structures and utilities for pharmacometric modeling +//! Data structures for building pharmacometric input data. //! -//! This module provides types for representing pharmacokinetic/pharmacodynamic data, -//! including subjects, dosing events, observations, and covariates. It also includes -//! utilities for reading and manipulating this data. +//! Use this module when you need to describe what happened to each subject: +//! doses, infusions, observations, covariates, and occasion boundaries. //! -//! # Key Components +//! This module is the input side of `pharmsol`. It is where you assemble +//! subjects and datasets before simulation, estimation, or NCA. It is not where +//! you define model equations or choose a backend. For those workflows, move to +//! [`crate::simulator`], [`crate::nca`], or the feature-gated `pharmsol::dsl` +//! surface. //! -//! - **Events**: Dosing events (bolus, infusion) and observations -//! - **Covariates**: Time-varying subject characteristics -//! - **Subjects**: Collections of events and covariates for a single individual -//! - **Data**: Collections of subjects, representing a complete dataset -//! - **Error Models**: Two types for different algorithm families: -//! - [`ErrorModel`]: Observation-based (assay error) for non-parametric algorithms -//! - [`ResidualErrorModel`]: Prediction-based (residual error) for parametric algorithms +//! # Start Here //! -//! # Examples +//! Most users only need three entrypoints first: //! -//! Creating a subject with the builder pattern: +//! - [`Subject`] for one individual and their full schedule. +//! - [`Data`] for a dataset containing many subjects. +//! - `Subject::builder` for the smallest fluent API to create doses, +//! observations, and covariates in Rust. +//! +//! The main supporting types are: +//! +//! - [`Occasion`] for repeated periods within one subject. +//! - [`Event`], [`Bolus`], [`Infusion`], and [`Observation`] for explicit +//! event-level control. +//! - [`Covariate`] and [`Covariates`] for time-varying subject characteristics. +//! - [`ErrorModel`], [`ResidualErrorModel`], and [`ObservationError`] for the +//! different error surfaces used by downstream workflows. +//! +//! # Choose A Data Input Path +//! +//! - Use `Subject::builder` when you are authoring a schedule directly in Rust. +//! - Use [`row::DataRow`] and [`row::DataRowBuilder`] when your source data is +//! already row-shaped in memory. +//! - Use [`parser::read_pmetrics`] when you are loading a Pmetrics-style file +//! from disk. +//! - Use [`Event`] variants directly when you already have validated event +//! records and need lower-level control than the builder offers. +//! +//! # Label Semantics +//! +//! Dosing inputs and observation outputs use public labels. +//! +//! - The `input` on [`Bolus`] and [`Infusion`] is the route or input label that +//! will be matched against the model. +//! - The `outeq` on [`Observation`] is the output label that identifies which +//! model output the observation belongs to. +//! - Prefer stable names such as `"depot"`, `"central"`, `"iv"`, or `"cp"`. +//! - If you pass a number, it is still treated as a public label string. Use +//! numeric values only when your model intentionally declares numeric labels. +//! +//! [`Occasion`] indices are different: they are integer period markers used to +//! separate repeated dosing blocks within one subject. +//! +//! # Error Surfaces +//! +//! This module exposes three related but different error families: +//! +//! - [`ErrorModel`] for assay or measurement error driven by the observation +//! value, commonly used in non-parametric workflows. +//! - [`ResidualErrorModel`] for residual unexplained variability driven by the +//! prediction value, commonly used in parametric workflows. +//! - [`ObservationError`] for invalid or insufficient observation data during +//! profile construction and related preprocessing. +//! +//! # Example //! //! ```rust //! use pharmsol::*; //! //! let subject = Subject::builder("patient_001") -//! .bolus(0.0, 100.0, 0) -//! .observation(1.0, 10.5, 0) -//! .observation(2.0, 8.2, 0) +//! .bolus(0.0, 100.0, "depot") +//! .observation(1.0, 12.3, "cp") +//! .missing_observation(2.0, "cp") //! .covariate("weight", 0.0, 70.0) //! .build(); +//! +//! let data = Data::new(vec![subject]); +//! +//! assert_eq!(data.subjects().len(), 1); //! ``` pub mod auc; diff --git a/src/data/parser/mod.rs b/src/data/parser/mod.rs index 7bfde3ca..74a50a84 100644 --- a/src/data/parser/mod.rs +++ b/src/data/parser/mod.rs @@ -1,3 +1,15 @@ +//! File-based parsers and parser-facing row utilities. +//! +//! Use this module when your source data starts as files or parser-shaped rows. +//! It re-exports the row ingestion API from [`crate::data::row`] and provides +//! format-specific loaders such as [`read_pmetrics`]. +//! +//! Choose the entrypoint by source shape: +//! - Use [`DataRow`] or [`build_data`] when you already mapped external data into +//! canonical row fields yourself. +//! - Use [`read_pmetrics`] when the source file already follows the Pmetrics CSV +//! convention. + pub mod pmetrics; pub use crate::data::row::{build_data, DataError, DataRow, DataRowBuilder}; diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index c410d689..2c90e2a7 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -1,3 +1,12 @@ +//! Pmetrics CSV parsing and export helpers. +//! +//! This module reads and writes the Pmetrics-style tabular format while keeping +//! pharmsol's public input and output labels intact. +//! +//! `INPUT` and `OUTEQ` values are parsed as labels, not rewritten to dense +//! indices. Named values such as `iv` and `cp` are preserved exactly, and +//! numeric values such as `1` are preserved as numeric-looking labels. + use crate::{data::*, PharmsolError}; use csv::WriterBuilder; use serde::de::{MapAccess, Visitor}; @@ -10,19 +19,27 @@ use crate::data::row::DataRow; use std::fmt; use std::str::FromStr; -/// Read a Pmetrics datafile and convert it to a [Data] object +/// Read a Pmetrics CSV file into [`Data`]. +/// +/// Use [`read_pmetrics`] when the source file already follows the usual +/// Pmetrics column convention instead of mapping the file into [`DataRow`] +/// values yourself. /// -/// This function parses a Pmetrics-formatted CSV file and constructs a [Data] object containing the structured -/// pharmacokinetic/pharmacodynamic data. The function handles various data formats including doses, observations, -/// and covariates. +/// The parser normalizes header names to lowercase, preserves `INPUT` and +/// `OUTEQ` as public labels, expands `ADDL` dosing rows through the shared row +/// ingestion path, and groups rows into occasions using `EVID=4`. +/// +/// All columns not claimed by the core Pmetrics schema are treated as +/// covariates. /// /// # Arguments /// -/// * `path` - The path to the Pmetrics CSV file +/// * `path` - Path to the Pmetrics CSV file /// /// # Returns /// -/// * `Result` - A result containing either the parsed [Data] object or an error +/// A parsed [`Data`] object or a [`DataError`] if the file cannot be read or a +/// required row field is missing. /// /// # Example /// @@ -33,14 +50,25 @@ use std::str::FromStr; /// println!("Number of subjects: {}", data.subjects().len()); /// ``` /// -/// # Format details +/// # Expected columns +/// +/// The canonical columns are `ID`, `TIME`, `EVID`, `DOSE`, `DUR`, `ADDL`, +/// `II`, `INPUT`, `OUT`, `OUTEQ`, `CENS`, and optional `C0..C3` error +/// coefficients. /// -/// The Pmetrics format expects columns like ID, TIME, EVID, DOSE, DUR, etc. The function will: +/// All other numeric columns are treated as covariates. +/// +/// # Parsing behavior +/// +/// The parser will: /// - Convert all headers to lowercase for case-insensitivity /// - Group rows by subject ID /// - Create occasions based on EVID=4 events /// - Parse covariates and create appropriate interpolations /// - Handle additional doses via ADDL and II fields +/// - Preserve raw `INPUT` and `OUTEQ` labels as strings until model resolution +/// - Treat `OUT=-99` as a missing observation value, matching the common +/// Pmetrics convention /// /// For specific column definitions, see the `Row` struct. #[allow(dead_code)] @@ -72,7 +100,7 @@ pub fn read_pmetrics(path: impl Into) -> Result { build_data(data_rows) } -/// A [Row] represents a row in the Pmetrics data format +/// One row from a Pmetrics file after serde deserialization. #[derive(Deserialize, Debug, Serialize, Default, Clone)] #[serde(rename_all = "lowercase")] struct Row { @@ -94,15 +122,15 @@ struct Row { /// Dosing interval #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, - /// Input compartment - #[serde(deserialize_with = "deserialize_option_usize")] - input: Option, + /// Input label from the `INPUT` column + #[serde(deserialize_with = "deserialize_option_route_label")] + input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, - /// Corresponding output equation for the observation - #[serde(deserialize_with = "deserialize_option_usize")] - outeq: Option, + /// Output label from the `OUTEQ` column + #[serde(deserialize_with = "deserialize_option_output_label")] + outeq: Option, /// Censoring output #[serde(default, deserialize_with = "deserialize_option_censor")] cens: Option, @@ -134,12 +162,12 @@ impl Row { dur: self.dur, addl: self.addl.map(|a| a as i64), ii: self.ii, - input: self.input, + input: self.input.clone(), // Treat -99 as missing value (Pmetrics convention) out: self .out .and_then(|v| if v == -99.0 { None } else { Some(v) }), - outeq: self.outeq, + outeq: self.outeq.clone(), cens: self.cens, c0: self.c0, c1: self.c1, @@ -196,11 +224,18 @@ where } } -fn deserialize_option_usize<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_option_route_label<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserialize_option::(deserializer).map(|value| value.map(InputLabel::from)) +} + +fn deserialize_option_output_label<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserialize_option::(deserializer) + deserialize_option::(deserializer).map(|value| value.map(OutputLabel::from)) } fn deserialize_option_isize<'de, D>(deserializer: D) -> Result, D::Error> @@ -257,7 +292,14 @@ where } impl Data { - /// Write the dataset to a file in Pmetrics format + /// Write the dataset to a file in Pmetrics format. + /// + /// `INPUT` and `OUTEQ` are written using their stored public labels. Named + /// labels such as `iv` and `cp` remain named labels, and numeric-looking + /// labels are written back exactly as stored. + /// + /// Missing optional fields are emitted as `.` placeholders to match the + /// usual Pmetrics text convention. /// /// # Arguments /// @@ -496,4 +538,50 @@ mod tests { assert_eq!(second.get(11), Some(".")); assert_eq!(second.get(14), Some(".")); } + + #[test] + fn read_pmetrics_preserves_named_route_and_output_labels() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,cp,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Infusion(infusion) => assert_eq!(infusion.input().as_str(), "iv"), + _ => panic!("expected infusion event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "cp"), + _ => panic!("expected observation event"), + } + } + + #[test] + fn read_pmetrics_preserves_numeric_labels_as_strings() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,.,100,.,.,1,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,1,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Bolus(bolus) => assert_eq!(bolus.input().as_str(), "1"), + _ => panic!("expected bolus event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "1"), + _ => panic!("expected observation event"), + } + } } diff --git a/src/data/row.rs b/src/data/row.rs index b3b38ad8..fcb610ea 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -1,39 +1,56 @@ -//! Row representation of [Data] for flexible parsing +//! Row-shaped data ingestion for [`Data`] and [`Subject`] assembly. +//! +//! Use this module when your source data already looks like rows from a table, +//! CSV file, database export, or ETL pipeline. +//! +//! Choose the ingestion path by source shape: +//! - Use [`crate::data::builder::SubjectBuilder`] when you want to author a +//! schedule directly in Rust. +//! - Use [`DataRow`] and [`build_data`] when your application already has +//! validated row records in memory. +//! - Use [`crate::data::parser::read_pmetrics`] when the source file already +//! follows the Pmetrics column convention. +//! +//! [`DataRow`] keeps public route and output labels as strings. Labels such as +//! `"iv"`, `"depot"`, and `"cp"` are preserved through row parsing and later +//! resolved against model metadata by downstream workflows. //! //! # Example //! //! ```rust //! use pharmsol::data::parser::DataRow; //! -//! // Create a dosing row with ADDL expansion //! let row = DataRow::builder("subject_1", 0.0) //! .evid(1) //! .dose(100.0) -//! .input(1) -//! .addl(3) // 3 additional doses -//! .ii(12.0) // 12 hours apart +//! .input("iv") +//! .addl(3) +//! .ii(12.0) //! .build(); //! //! let events = row.into_events().unwrap(); -//! assert_eq!(events.len(), 4); // Original + 3 additional doses +//! assert_eq!(events.len(), 4); //! ``` -//! use crate::data::*; use std::collections::HashMap; use thiserror::Error; -/// A format-agnostic representation of a single data row +/// A format-agnostic representation of one input row. +/// +/// [`DataRow`] collects the canonical fields needed to turn one external row +/// into one or more [`Event`] values. /// -/// This struct represents the canonical fields needed to create pharmsol Events. -/// Consumers construct this from their source data (regardless of column names), -/// then call [`into_events()`](DataRow::into_events) to get properly parsed -/// Events with full ADDL expansion, EVID handling, censoring, etc. +/// Build this type from your own column mapping or external schema, then call +/// [`DataRow::into_events`] or [`build_data`] to assemble subjects and datasets. +/// +/// A single row can expand into several events when `ADDL` and `II` are both +/// present. /// /// # Fields /// -/// All fields use Pmetrics conventions: -/// - `input` and `outeq` are **1-indexed** (kept as-is, user must size arrays accordingly) +/// All fields use the public labeling conventions: +/// - `input` and `outeq` preserve the route and output labels from the source data /// - `evid`: 0=observation, 1=dose, 4=reset/new occasion /// - `addl`: positive=forward in time, negative=backward in time /// @@ -42,24 +59,22 @@ use thiserror::Error; /// ```rust /// use pharmsol::data::parser::DataRow; /// -/// // Observation row /// let obs = DataRow::builder("pt1", 1.0) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .build(); /// -/// // Dosing row with negative ADDL (doses before time 0) /// let dose = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) -/// .input(1) -/// .addl(-10) // 10 doses BEFORE time 0 +/// .input("iv") +/// .addl(-10) /// .ii(12.0) /// .build(); /// /// let events = dose.into_events().unwrap(); -/// // Events at times: -120, -108, -96, ..., -12, 0 +/// assert_eq!(obs.outeq.as_ref().map(|label| label.as_str()), Some("cp")); /// assert_eq!(events.len(), 11); /// ``` #[derive(Debug, Clone, Default)] @@ -78,12 +93,12 @@ pub struct DataRow { pub addl: Option, /// Interdose interval for ADDL pub ii: Option, - /// Input compartment - pub input: Option, + /// Input route label + pub input: Option, /// Observed value (for EVID=0) pub out: Option, - /// Output equation number - pub outeq: Option, + /// Output label + pub outeq: Option, /// Censoring indicator pub cens: Option, /// Error polynomial coefficients @@ -99,7 +114,7 @@ pub struct DataRow { } impl DataRow { - /// Create a new builder for constructing a DataRow + /// Create a builder for constructing one [`DataRow`]. /// /// # Arguments /// @@ -114,7 +129,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("depot") /// .build(); /// ``` pub fn builder(id: impl Into, time: f64) -> DataRowBuilder { @@ -129,13 +144,14 @@ impl DataRow { } } - /// Convert this row into pharmsol Events + /// Convert this row into one or more [`Event`] values. /// - /// This method contains all the complex parsing logic: + /// This method performs the row-level translation logic: /// - EVID interpretation (0=observation, 1=dose, 4=reset) /// - ADDL/II expansion (both positive and negative directions) /// - Infusion vs bolus detection based on DUR /// - Censoring and error polynomial handling + /// - Preservation of public input and output labels /// /// # ADDL Expansion /// @@ -163,13 +179,13 @@ impl DataRow { /// let row = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("iv") /// .addl(2) /// .ii(24.0) /// .build(); /// /// let events = row.into_events().unwrap(); - /// assert_eq!(events.len(), 3); // doses at 24, 48, and 0 + /// assert_eq!(events.len(), 3); /// /// let times: Vec = events.iter().map(|e| e.time()).collect(); /// assert_eq!(times, vec![24.0, 48.0, 0.0]); @@ -180,14 +196,17 @@ impl DataRow { match self.evid { 0 => { // Observation event - events.push(Event::Observation(Observation::new( - self.time, - self.out, + let outeq = self.outeq + .clone() .ok_or_else(|| DataError::MissingObservationOuteq { id: self.id.clone(), time: self.time, - })?, // Keep 1-indexed as provided by Pmetrics + })?; + events.push(Event::Observation(Observation::new( + self.time, + self.out, + outeq, self.get_errorpoly(), 0, // occasion set later self.cens.unwrap_or(Censor::None), @@ -196,10 +215,13 @@ impl DataRow { 1 | 4 => { // Dosing event (1) or reset with dose (4) - let input = self.input.ok_or_else(|| DataError::MissingBolusInput { - id: self.id.clone(), - time: self.time, - })?; // Keep 1-indexed as provided by Pmetrics + let input = self + .input + .clone() + .ok_or_else(|| DataError::MissingBolusInput { + id: self.id.clone(), + time: self.time, + })?; let event = if self.dur.unwrap_or(0.0) > 0.0 { // Infusion @@ -281,7 +303,11 @@ impl DataRow { } } -/// Builder for constructing DataRow with a fluent API +/// Fluent builder for [`DataRow`]. +/// +/// Use [`DataRowBuilder`] when you have row-shaped data in memory and want to +/// construct rows incrementally before calling [`DataRow::into_events`] or +/// [`build_data`]. /// /// # Example /// @@ -292,7 +318,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 1.5) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .cens(Censor::None) /// .covariate("weight", 70.0) /// .covariate("age", 45.0) @@ -367,12 +393,13 @@ impl DataRowBuilder { self } - /// Set the input compartment (1-indexed) + /// Set the input route label. /// - /// Required for EVID=1 (dosing events). - /// Kept as 1-indexed; user must size state arrays accordingly. - pub fn input(mut self, input: usize) -> Self { - self.row.input = Some(input); + /// Required for EVID=1 dosing rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. + pub fn input(mut self, input: impl ToString) -> Self { + self.row.input = Some(InputLabel::new(input)); self } @@ -384,12 +411,13 @@ impl DataRowBuilder { self } - /// Set the output equation (1-indexed) + /// Set the output label. /// - /// Required for EVID=0 (observation events). - /// Will be converted to 0-indexed internally. - pub fn outeq(mut self, outeq: usize) -> Self { - self.row.outeq = Some(outeq); + /// Required for EVID=0 observation rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. + pub fn outeq(mut self, outeq: impl ToString) -> Self { + self.row.outeq = Some(OutputLabel::new(outeq)); self } @@ -430,13 +458,18 @@ impl DataRowBuilder { } } -/// Build a [Data] object from an iterator of [DataRow]s +/// Build a [`Data`] object from row-shaped input. /// -/// This function handles all the complex assembly logic: +/// This function assembles rows into subjects and occasions: /// - Groups rows by subject ID /// - Splits into occasions at EVID=4 boundaries /// - Converts rows to events via [`DataRow::into_events()`] /// - Builds covariates from row covariate data +/// - Preserves per-subject row order within each occasion block +/// +/// Use this when you already have a collection of [`DataRow`] values in memory. +/// If your source file is a Pmetrics CSV, use [`crate::data::parser::read_pmetrics`] +/// instead. /// /// # Example /// @@ -444,23 +477,21 @@ impl DataRowBuilder { /// use pharmsol::data::parser::{DataRow, build_data}; /// /// let rows = vec![ -/// // Subject 1, Occasion 0 /// DataRow::builder("pt1", 0.0) -/// .evid(1).dose(100.0).input(1).build(), +/// .evid(1).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 1.0) -/// .evid(0).out(50.0).outeq(1).build(), -/// // Subject 1, Occasion 1 (EVID=4 starts new occasion) +/// .evid(0).out(50.0).outeq("cp").build(), /// DataRow::builder("pt1", 24.0) -/// .evid(4).dose(100.0).input(1).build(), +/// .evid(4).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 25.0) -/// .evid(0).out(48.0).outeq(1).build(), -/// // Subject 2 +/// .evid(0).out(48.0).outeq("cp").build(), /// DataRow::builder("pt2", 0.0) -/// .evid(1).dose(50.0).input(1).build(), +/// .evid(1).dose(50.0).input("iv").build(), /// ]; /// /// let data = build_data(rows).unwrap(); /// assert_eq!(data.subjects().len(), 2); +/// assert_eq!(data.subjects()[0].occasions().len(), 2); /// ``` pub fn build_data(rows: impl IntoIterator) -> Result { // Group rows by subject ID @@ -556,14 +587,14 @@ pub enum DataError { /// Required observation value (OUT) is missing #[error("Observation OUT is missing for {id} at time {time}")] MissingObservationOut { id: String, time: f64 }, - /// Required observation output equation (OUTEQ) is missing - #[error("Observation OUTEQ is missing in for {id} at time {time}")] + /// Required observation output label (`OUTEQ`) is missing + #[error("Observation OUTEQ is missing for {id} at time {time}")] MissingObservationOuteq { id: String, time: f64 }, /// Required infusion dose amount is missing #[error("Infusion amount (DOSE) is missing for {id} at time {time}")] MissingInfusionDose { id: String, time: f64 }, - /// Required infusion input compartment is missing - #[error("Infusion compartment (INPUT) is missing for {id} at time {time}")] + /// Required infusion input label (`INPUT`) is missing + #[error("Infusion input label (INPUT) is missing for {id} at time {time}")] MissingInfusionInput { id: String, time: f64 }, /// Required infusion duration is missing #[error("Infusion duration (DUR) is missing for {id} at time {time}")] @@ -571,8 +602,8 @@ pub enum DataError { /// Required bolus dose amount is missing #[error("Bolus amount (DOSE) is missing for {id} at time {time}")] MissingBolusDose { id: String, time: f64 }, - /// Required bolus input compartment is missing - #[error("Bolus compartment (INPUT) is missing for {id} at time {time}")] + /// Required bolus input label (`INPUT`) is missing + #[error("Bolus input label (INPUT) is missing for {id} at time {time}")] MissingBolusInput { id: String, time: f64 }, } diff --git a/src/data/structs.rs b/src/data/structs.rs index 82cd3faf..a116d703 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -59,6 +59,10 @@ impl Data { self.subjects.iter().collect() } + pub(crate) fn subjects_slice(&self) -> &[Subject] { + &self.subjects + } + /// Add a subject to the dataset /// /// # Arguments @@ -180,17 +184,18 @@ impl Data { let old_events = occasion.process_events(None, true); // Create a set of existing (time, outeq) pairs for fast lookup - let existing_obs: std::collections::HashSet<(u64, usize)> = old_events - .iter() - .filter_map(|event| match event { - Event::Observation(obs) => { - // Convert to microseconds for consistent comparison - let time_key = (obs.time() * 1e6).round() as u64; - Some((time_key, obs.outeq())) - } - _ => None, - }) - .collect(); + let existing_obs: std::collections::HashSet<(u64, OutputLabel)> = + old_events + .iter() + .filter_map(|event| match event { + Event::Observation(obs) => { + // Convert to microseconds for consistent comparison + let time_key = (obs.time() * 1e6).round() as u64; + Some((time_key, obs.outeq().clone())) + } + _ => None, + }) + .collect(); // Generate new observation times let mut new_events = Vec::new(); @@ -198,13 +203,13 @@ impl Data { while time < last_time { let time_key = (time * 1e6).round() as u64; - for &outeq in &outeq_values { + for outeq in &outeq_values { // Only add if this (time, outeq) combination doesn't exist - if !existing_obs.contains(&(time_key, outeq)) { + if !existing_obs.contains(&(time_key, outeq.clone())) { let obs = Observation::new( time, None, - outeq, + outeq.clone(), None, occasion.index, Censor::None, @@ -273,15 +278,15 @@ impl Data { self.subjects.is_empty() } - /// Get a vector of all unique output equations (outeq) across all subjects - pub fn get_output_equations(&self) -> Vec { + /// Get a vector of all unique output labels (outeq) across all subjects + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let mut outeq_values: Vec = self + let mut outeq_values: Vec = self .subjects .iter() .flat_map(|subject| subject.get_output_equations()) .collect(); - outeq_values.sort_unstable(); + outeq_values.sort(); outeq_values.dedup(); outeq_values } @@ -396,14 +401,14 @@ impl Subject { self.occasions.iter_mut() } - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let outeq_values: Vec = self + let outeq_values: Vec = self .occasions .iter() .flat_map(|occasion| { occasion.events.iter().filter_map(|event| match event { - Event::Observation(obs) => Some(obs.outeq()), + Event::Observation(obs) => Some(obs.outeq().clone()), _ => None, }) }) @@ -598,8 +603,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let lagtime = fn_lag(&spp.clone().into(), time, covariates); - if let Some(l) = lagtime.get(&bolus.input()) { - *bolus.mut_time() += l; + if let Some(input) = bolus.input_index() { + if let Some(l) = lagtime.get(&input) { + *bolus.mut_time() += l; + } } } } @@ -615,8 +622,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let fa = fn_fa(&spp.clone().into(), time, covariates); - if let Some(f) = fa.get(&bolus.input()) { - bolus.set_amount(bolus.amount() * f); + if let Some(input) = bolus.input_index() { + if let Some(f) = fa.get(&input) { + bolus.set_amount(bolus.amount() * f); + } } } } @@ -703,7 +712,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: Option, censored: Censor, ) { @@ -713,7 +722,7 @@ impl Occasion { } /// Add a missing [Observation] event to the [Occasion] - pub fn add_missing_observation(&mut self, time: f64, outeq: usize) { + pub fn add_missing_observation(&mut self, time: f64, outeq: impl ToString) { let observation = Observation::new(time, None, outeq, None, self.index, Censor::None); self.add_event(Event::Observation(observation)); } @@ -725,7 +734,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) { @@ -741,13 +750,13 @@ impl Occasion { } /// Add a [Bolus] event to the [Occasion] - pub fn add_bolus(&mut self, time: f64, amount: f64, input: usize) { + pub fn add_bolus(&mut self, time: f64, amount: f64, input: impl ToString) { let bolus = Bolus::new(time, amount, input, self.index); self.add_event(Event::Bolus(bolus)); } /// Add an [Infusion] event to the [Occasion] - pub fn add_infusion(&mut self, time: f64, amount: f64, input: usize, duration: f64) { + pub fn add_infusion(&mut self, time: f64, amount: f64, input: impl ToString, duration: f64) { let infusion = Infusion::new(time, amount, input, duration, self.index); self.add_event(Event::Infusion(infusion)); } @@ -775,17 +784,6 @@ impl Occasion { .unwrap_or(0.0) } - pub(crate) fn infusions_ref(&self) -> Vec<&Infusion> { - //TODO this can be pre-computed when the struct is initially created - self.events - .iter() - .filter_map(|event| match event { - Event::Infusion(infusion) => Some(infusion), - _ => None, - }) - .collect() - } - /// Get an iterator over all events /// /// # Returns @@ -967,7 +965,7 @@ impl Occasion { for event in &self.events { if let Event::Observation(obs) = event { - if obs.outeq() == outeq { + if obs.outeq_index() == Some(outeq) { if let Some(value) = obs.value() { times.push(obs.time()); concs.push(value); diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 3749f183..0dd38252 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -37,18 +37,23 @@ use pharmsol_dsl::ModelKind; use pharmsol_dsl::{analyze_module, lower_typed_model, parse_module, ExecutionModel}; use pharmsol_dsl::{Diagnostic, DiagnosticReport, LoweringError, ParseError, SemanticError}; +/// ABI version for native AoT artifacts produced by this crate. pub const AOT_API_VERSION: u32 = 1; #[cfg(feature = "dsl-aot")] +/// Selects the compilation target for a native ahead-of-time artifact. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub enum NativeAotTarget { + /// Compile for the current host toolchain target. #[default] Host, + /// Compile for an explicit Rust target triple. Triple(String), } #[cfg(feature = "dsl-aot")] impl NativeAotTarget { + /// Create a target selector for an explicit Rust target triple. pub fn triple(target: impl Into) -> Self { Self::Triple(target.into()) } @@ -62,15 +67,24 @@ impl NativeAotTarget { } #[cfg(feature = "dsl-aot")] +/// Options that control native ahead-of-time artifact export. +/// +/// AoT export writes a small template crate under [`template_root`](Self::template_root), +/// builds a native shared library, and then copies the resulting artifact to +/// [`output`](Self::output) or a generated default path. #[derive(Debug, Clone, PartialEq, Eq)] pub struct NativeAotCompileOptions { + /// Target triple selection for the emitted artifact. pub target: NativeAotTarget, + /// Optional final artifact location. pub output: Option, + /// Working directory used for the temporary template crate and build output. pub template_root: PathBuf, } #[cfg(feature = "dsl-aot")] impl NativeAotCompileOptions { + /// Create AoT options rooted at a template build directory. pub fn new(template_root: PathBuf) -> Self { Self { target: NativeAotTarget::Host, @@ -79,17 +93,20 @@ impl NativeAotCompileOptions { } } + /// Set the final artifact output path. pub fn with_output(mut self, output: PathBuf) -> Self { self.output = Some(output); self } + /// Set the compilation target triple. pub fn with_target(mut self, target: NativeAotTarget) -> Self { self.target = target; self } } +/// Error produced while exporting, reading, or loading a native AoT artifact. #[derive(Error)] pub enum AotError { #[error(transparent)] @@ -151,6 +168,43 @@ impl fmt::Debug for AotError { } #[cfg(feature = "dsl-aot")] +/// Parse DSL source, lower one selected model, and export a native AoT artifact. +/// +/// Use this when you want a reusable native artifact that can be loaded later +/// with [`load_aot_model`] or [`crate::dsl::load_runtime_artifact`]. +/// +/// This function requires the `dsl-aot` feature. Loading the resulting artifact +/// later requires `dsl-aot-load`. +/// +/// ```rust,no_run +/// use std::path::PathBuf; +/// +/// use pharmsol::dsl::{compile_module_source_to_aot, load_aot_model, NativeAotCompileOptions}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let artifact = compile_module_source_to_aot( +/// source, +/// Some("bimodal_ke"), +/// NativeAotCompileOptions::new(PathBuf::from("target/doc-aot-build")), +/// |_, _| {}, +/// )?; +/// let loaded = load_aot_model(&artifact)?; +/// # let _ = loaded; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_module_source_to_aot( source: &str, model_name: Option<&str>, @@ -184,6 +238,10 @@ pub fn compile_module_source_to_aot( } #[cfg(feature = "dsl-aot")] +/// Export a lowered execution model as a native AoT artifact. +/// +/// Use this lower-level entrypoint when you already own the frontend pipeline +/// and only need artifact generation. pub fn export_execution_model_to_aot( model: &ExecutionModel, options: NativeAotCompileOptions, @@ -240,6 +298,10 @@ pub fn export_execution_model_to_aot( } #[cfg(feature = "dsl-aot-load")] +/// Read only the metadata from a native AoT artifact. +/// +/// This is useful when you need to inspect model identity, routes, outputs, or +/// buffer sizes without loading the executable kernels. pub fn read_aot_model_info(path: impl AsRef) -> Result { let library = unsafe { Library::new(path.as_ref()) } .map_err(|error| AotError::Load(error.to_string()))?; @@ -248,6 +310,7 @@ pub fn read_aot_model_info(path: impl AsRef) -> Result) -> Result { let path = path.as_ref(); let library = @@ -534,23 +597,62 @@ mod tests { other => panic!("expected ode model, got {other:?}"), }; - let oral = jit.route_index("oral").expect("jit oral route"); - let iv = jit.route_index("iv").expect("jit iv route"); - let cp = jit.output_index("cp").expect("jit cp output"); - assert_eq!(aot.route_index("oral"), Some(oral)); - assert_eq!(aot.route_index("iv"), Some(iv)); - assert_eq!(aot.output_index("cp"), Some(cp)); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("jit oral route"); + let iv = jit + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index) + .expect("jit iv route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("jit cp output"); + assert_eq!( + aot.info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index), + Some(oral) + ); + assert_eq!( + aot.info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index), + Some(iv) + ); + assert_eq!( + aot.info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index), + Some(cp) + ); let subject = crate::Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; diff --git a/src/dsl/compiled_backend_abi.rs b/src/dsl/compiled_backend_abi.rs index 26a2f825..8717c416 100644 --- a/src/dsl/compiled_backend_abi.rs +++ b/src/dsl/compiled_backend_abi.rs @@ -324,7 +324,9 @@ mod tests { }], routes: vec![NativeRouteInfo { name: "iv".to_string(), + declaration_index: 0, index: 0, + kind: None, destination_offset: 1, inject_input_to_destination: true, }], diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index ed516387..c41d4ed4 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -83,6 +83,11 @@ pub type JitAnalyticalModel = NativeAnalyticalModel; pub type JitSdeModel = NativeSdeModel; pub type CompiledJitModel = CompiledNativeModel; +/// Error reported while lowering an execution model into native in-process JIT +/// code. +/// +/// The error retains the backend diagnostic so callers can render the message +/// against the original DSL source when available. #[derive(Clone, PartialEq, Eq)] pub struct JitCompileError { diagnostic: Box, @@ -214,6 +219,10 @@ struct LoweredValue { ty: ValueType, } +/// Compile one lowered execution model into a reusable JIT kernel artifact. +/// +/// This builds the raw Cranelift-compiled kernel bundle for all roles present in +/// the model. Most callers should use [`compile_execution_model_to_jit`] instead. pub fn compile_execution_artifact( model: &ExecutionModel, ) -> Result { @@ -731,7 +740,7 @@ fn lower_load( ExecutionLoad::Parameter(index) => load_fixed(builder, env.args.params, *index, ty), ExecutionLoad::Covariate(index) => load_fixed(builder, env.args.covariates, *index, ty), ExecutionLoad::Derived(index) => load_fixed(builder, env.args.derived, *index, ty), - ExecutionLoad::RouteInput(index) => load_fixed(builder, env.args.routes, *index, ty), + ExecutionLoad::RouteInput { index, .. } => load_fixed(builder, env.args.routes, *index, ty), ExecutionLoad::Local(index) => { let binding = env.locals.get(index).ok_or_else(|| { JitCompileError::new(format!("unknown local slot {index}"), Some(span)) @@ -1217,6 +1226,41 @@ fn state_address( Ok(builder.ins().iadd(base, byte_offset)) } +/// Compile an [`ExecutionModel`](pharmsol_dsl::ExecutionModel) to the native +/// in-process JIT backend. +/// +/// Use this low-level entrypoint when you already own the parse, analyze, and +/// lower steps and want the JIT backend directly instead of the higher-level +/// runtime facade. +/// +/// This function requires the `dsl-jit` feature. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{ +/// analyze_model, compile_execution_model_to_jit, lower_typed_model, parse_model, +/// }; +/// +/// let parsed = parse_model( +/// r#" +/// model implicit_route_injection { +/// kind ode +/// states { central } +/// routes { iv -> central } +/// dynamics { +/// ddt(central) = 0 +/// } +/// outputs { +/// cp = central +/// } +/// } +/// "#, +/// )?; +/// let typed = analyze_model(&parsed)?; +/// let execution = lower_typed_model(&typed)?; +/// let compiled = compile_execution_model_to_jit(&execution)?; +/// # let _ = compiled; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_execution_model_to_jit( model: &ExecutionModel, ) -> Result { @@ -1229,6 +1273,7 @@ pub fn compile_execution_model_to_jit( } } +/// Compile an ODE execution model to the native in-process JIT backend. pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Ode { return Err(JitCompileError::new( @@ -1245,6 +1290,7 @@ pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result Result { @@ -1263,6 +1309,7 @@ pub fn compile_analytical_model_to_jit( )) } +/// Compile an SDE execution model to the native in-process JIT backend. pub fn compile_sde_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Sde { return Err(JitCompileError::new( @@ -1330,6 +1377,119 @@ mod tests { assert!(debugged.contains("error[DSL4000]"), "{}", debugged); } + #[test] + fn authoring_runtime_shares_input_between_bolus_and_infusion_routes() { + let source = r#" +name = shared_authoring +kind = ode + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + let parsed = pharmsol_dsl::parse_model(source).expect("authoring model parses"); + let typed = pharmsol_dsl::analyze_model(&parsed).expect("authoring model analyzes"); + let model = pharmsol_dsl::lower_typed_model(&typed).expect("authoring model lowers"); + let jit = compile_ode_model_to_jit(&model) + .expect("compile jit ode model") + .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); + + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let iv = jit + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index) + .expect("iv route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(iv, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("ode") + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .observation(0.5, 0.0, "cp") + .observation(1.0, 0.0, "cp") + .observation(2.0, 0.0, "cp") + .observation(6.0, 0.0, "cp") + .observation(7.0, 0.0, "cp") + .observation(9.0, 0.0, "cp") + .build(); + + let reference_subject = Subject::builder("ode") + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 0, 2.0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(7.0, 0.0, 0) + .observation(9.0, 0.0, 0) + .build(); + + let support = vec![1.2, 0.15, 40.0]; + let jit_predictions = jit + .estimate_predictions(&jit_subject, &support) + .expect("jit predictions"); + + let reference = ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + let ka = p[0]; + let ke = p[1]; + dx[0] = -ka * x[0] + bolus[0]; + dx[1] = ka * x[0] - ke * x[1] + rateiv[0]; + }, + |_p, _t, _cov| std::collections::HashMap::new(), + |_p, _t, _cov| std::collections::HashMap::new(), + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + y[0] = x[1] / p[2]; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); + + let reference_predictions = reference + .estimate_predictions(&reference_subject, &support) + .expect("reference ode predictions"); + + for (jit_pred, reference_pred) in jit_predictions + .predictions() + .iter() + .zip(reference_predictions.predictions()) + { + assert_relative_eq!( + jit_pred.prediction(), + reference_pred.prediction(), + max_relative = 1e-4 + ); + } + } + fn slot_index(layout: &DenseBufferLayout, name: &str) -> usize { layout .slots @@ -1403,27 +1563,58 @@ mod tests { .expect("compile jit ode model") .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); - let oral = jit.route_index("oral").expect("oral route"); - let iv = jit.route_index("iv").expect("iv route"); - let cp = jit.output_index("cp").expect("cp output"); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let iv = jit + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index) + .expect("iv route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 1); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("ode") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") + .build(); - let subject = Subject::builder("ode") + let reference_subject = Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 1, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(7.0, 0) + .missing_observation(9.0, 0) .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit predictions"); let reference = ODE::new( @@ -1468,7 +1659,7 @@ mod tests { .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference ode predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1489,20 +1680,42 @@ mod tests { let model = load_corpus_model("one_cmt_abs"); let jit = compile_analytical_model_to_jit(&model).expect("compile jit analytical model"); - let oral = jit.route_index("oral").expect("oral route"); - let cp = jit.output_index("cp").expect("cp output"); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("analytical") - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + let reference_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.0, 0.15, 25.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit analytical predictions"); let reference = equation::Analytical::new( @@ -1520,7 +1733,7 @@ mod tests { .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference analytical predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1543,21 +1756,44 @@ mod tests { .expect("compile jit sde model") .with_particles(64); - let oral = jit.route_index("oral").expect("oral route"); - let cp = jit.output_index("cp").expect("cp output"); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("sde") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("sde") + let reference_subject = Subject::builder("sde") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.1, 0.2, 0.12, 0.08, 15.0, 0.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit sde predictions"); let reference = SDE::new( @@ -1594,7 +1830,7 @@ mod tests { .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference sde predictions"); for (jit_pred, reference_pred) in jit_predictions diff --git a/src/dsl/mod.rs b/src/dsl/mod.rs index 563e4cf6..f536c377 100644 --- a/src/dsl/mod.rs +++ b/src/dsl/mod.rs @@ -1,9 +1,94 @@ //! Public DSL facade for pharmsol. //! -//! The backend-neutral frontend is being extracted into `pharmsol-dsl`. -//! Frontend syntax, diagnostics, semantic analysis, and lowering now come -//! from `pharmsol-dsl`, while runtime and backend compilation entrypoints -//! remain owned by `pharmsol`. +//! Use this module when you want to work with pharmsol models as source text +//! and stay inside the main crate for the full workflow: parse DSL source, +//! inspect diagnostics, lower to the execution model, compile to a runtime +//! backend, load saved artifacts, and run predictions. +//! +//! Use the `pharmsol-dsl` crate directly only when you need the backend-neutral +//! frontend as an engineering API. That crate owns parsing, diagnostics, +//! semantic analysis, and lowering. This module re-exports that stable +//! frontend surface and adds the backend-specific entrypoints that stay owned +//! by `pharmsol`. +//! +//! Main entrypoints: +//! +//! - [`parse_model`], [`parse_module`], [`analyze_model`], and +//! [`analyze_module`] for frontend-only validation and inspection. +//! - [`lower_typed_model`] and [`lower_typed_module`] for lowering typed models +//! into the execution representation used by the runtime backends. +//! - [`compile_module_source_to_runtime`] and [`compile_execution_model_to_runtime`] +//! for the one-stop compile-and-run path. +//! - [`load_runtime_artifact`], [`load_aot_model`], and +//! [`load_runtime_wasm_bytes`] for loading saved artifacts back into a model +//! you can execute. +//! +//! Common workflow choices: +//! +//! - Frontend only: parse, analyze, and lower when you need diagnostics, +//! authoring tools, or your own backend. +//! - In-process execution: compile straight to [`RuntimeCompilationTarget`] and +//! keep everything inside the current process. +//! - Native artifact shipping: export a native AoT artifact, then load it later +//! on a compatible host. +//! - WASM artifact shipping: emit `.wasm` bytes or a bundled module for browser +//! or portable runtime use. +//! +//! Feature map: +//! +//! - `dsl-core`: enables this facade and the frontend re-exports from +//! `pharmsol-dsl`. +//! - `dsl-jit`: enables in-process JIT compilation through +//! [`compile_module_source_to_runtime`] with +//! [`RuntimeCompilationTarget::Jit`], plus the lower-level JIT compile +//! entrypoints. +//! - `dsl-aot`: enables native ahead-of-time artifact export through +//! [`compile_module_source_to_aot`] and [`export_execution_model_to_aot`]. +//! - `dsl-aot-load`: enables native AoT artifact loading through +//! [`load_aot_model`] and [`read_aot_model_info`]. +//! - `dsl-wasm-compile`: enables WASM artifact emission through +//! [`compile_module_source_to_wasm_bytes`], +//! [`compile_module_source_to_wasm_module`], and the browser loader helpers. +//! - `dsl-wasm`: enables host-side WASM loading and runtime execution on +//! non-browser native hosts. This includes +//! [`compile_module_source_to_runtime_wasm`], [`load_runtime_wasm_bytes`], +//! [`read_wasm_model_info`], and [`read_wasm_model_info_bytes`]. +//! +//! Smallest compile-to-runtime example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! # let _ = model; +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` +//! +//! For a lower-level frontend pipeline without backend selection, use +//! `pharmsol-dsl`. For a complete runtime path inside the main crate, stay in +//! [`pharmsol::dsl`](self). #[cfg(any(feature = "dsl-aot", feature = "dsl-aot-load"))] mod aot; diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 8e48a022..27f2416a 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -1,48 +1,87 @@ +use std::collections::BTreeMap; + use serde::{Deserialize, Serialize}; use pharmsol_dsl::execution::{ ExecutionExpr, ExecutionExprKind, ExecutionLoad, ExecutionModel, ExecutionStmt, ExecutionStmtKind, KernelImplementation, KernelRole, }; -use pharmsol_dsl::{AnalyticalKernel, ModelKind}; +use pharmsol_dsl::{AnalyticalKernel, ModelKind, RouteKind}; +/// Public metadata extracted from a compiled backend model. +/// +/// This is the shared inspection surface returned by the native AoT, WASM, and +/// runtime loaders. It keeps public labels and buffer sizes available without +/// exposing backend-specific kernel details. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeModelInfo { + /// Public model name. pub name: String, + /// High-level model family. pub kind: ModelKind, + /// Parameter names in support-point order. pub parameters: Vec, + /// Declared covariates and their dense runtime indices. pub covariates: Vec, + /// Declared routes together with declaration-order and dense runtime indices. pub routes: Vec, + /// Declared outputs and their dense runtime indices. pub outputs: Vec, + /// Length of the state buffer used during execution. pub state_len: usize, + /// Length of the derived-value buffer used during execution. pub derived_len: usize, + /// Length of the output buffer used during execution. pub output_len: usize, + /// Length of the dense route-input buffer used during execution. pub route_len: usize, + /// Analytical kernel metadata when the compiled model is analytical. pub analytical: Option, + /// Particle count when the compiled model is stochastic. pub particles: Option, } +/// Metadata for one compiled covariate. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeCovariateInfo { + /// Public covariate name. pub name: String, + /// Dense runtime covariate index. pub index: usize, } +/// Metadata for one compiled route. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeRouteInfo { + /// Public route label. pub name: String, + /// Route position in declaration order. + #[serde(default)] + pub declaration_index: usize, + /// Dense runtime route-input index. pub index: usize, + /// Coarse route kind when declared in metadata. + #[serde(default)] + pub kind: Option, + /// Dense destination state offset used by compiled kernels. pub destination_offset: usize, + /// Whether the compiled backend injects the route input into the destination + /// state automatically when the model does not read the route input + /// explicitly. pub inject_input_to_destination: bool, } +/// Metadata for one compiled output. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeOutputInfo { + /// Public output label. pub name: String, + /// Dense runtime output index. pub index: usize, } impl NativeModelInfo { + /// Build public compiled-model metadata from a lowered execution model. pub fn from_execution_model(model: &ExecutionModel) -> Self { let explicit_route_input_usage = explicit_route_input_usage(model); Self { @@ -69,10 +108,12 @@ impl NativeModelInfo { .iter() .map(|route| NativeRouteInfo { name: route.name.clone(), + declaration_index: route.declaration_index, index: route.index, + kind: route.kind, destination_offset: route.destination.state_offset, inject_input_to_destination: !explicit_route_input_usage - .get(route.index) + .get(route.declaration_index) .copied() .unwrap_or(false), }) @@ -97,6 +138,12 @@ impl NativeModelInfo { } fn explicit_route_input_usage(model: &ExecutionModel) -> Vec { + let declaration_slots = model + .metadata + .routes + .iter() + .map(|route| (route.symbol, route.declaration_index)) + .collect::>(); let Some(kernel) = (match model.kind { ModelKind::Ode => model.kernel(KernelRole::Dynamics), ModelKind::Sde => model.kernel(KernelRole::Drift), @@ -107,54 +154,193 @@ fn explicit_route_input_usage(model: &ExecutionModel) -> Vec { let mut usage = vec![false; model.metadata.routes.len()]; if let KernelImplementation::Statements(program) = &kernel.implementation { - mark_route_inputs_in_statements(&program.body.statements, &mut usage); + mark_route_inputs_in_statements(&program.body.statements, &declaration_slots, &mut usage); } usage } -fn mark_route_inputs_in_statements(statements: &[ExecutionStmt], usage: &mut [bool]) { +fn mark_route_inputs_in_statements( + statements: &[ExecutionStmt], + declaration_slots: &BTreeMap, + usage: &mut [bool], +) { for statement in statements { match &statement.kind { ExecutionStmtKind::Let(let_stmt) => { - mark_route_inputs_in_expr(&let_stmt.value, usage); + mark_route_inputs_in_expr(&let_stmt.value, declaration_slots, usage); } ExecutionStmtKind::Assign(assign_stmt) => { - mark_route_inputs_in_expr(&assign_stmt.value, usage); + mark_route_inputs_in_expr(&assign_stmt.value, declaration_slots, usage); } ExecutionStmtKind::If(if_stmt) => { - mark_route_inputs_in_expr(&if_stmt.condition, usage); - mark_route_inputs_in_statements(&if_stmt.then_branch, usage); + mark_route_inputs_in_expr(&if_stmt.condition, declaration_slots, usage); + mark_route_inputs_in_statements(&if_stmt.then_branch, declaration_slots, usage); if let Some(else_branch) = &if_stmt.else_branch { - mark_route_inputs_in_statements(else_branch, usage); + mark_route_inputs_in_statements(else_branch, declaration_slots, usage); } } ExecutionStmtKind::For(for_stmt) => { - mark_route_inputs_in_expr(&for_stmt.range.start, usage); - mark_route_inputs_in_expr(&for_stmt.range.end, usage); - mark_route_inputs_in_statements(&for_stmt.body, usage); + mark_route_inputs_in_expr(&for_stmt.range.start, declaration_slots, usage); + mark_route_inputs_in_expr(&for_stmt.range.end, declaration_slots, usage); + mark_route_inputs_in_statements(&for_stmt.body, declaration_slots, usage); } } } } -fn mark_route_inputs_in_expr(expr: &ExecutionExpr, usage: &mut [bool]) { +fn mark_route_inputs_in_expr( + expr: &ExecutionExpr, + declaration_slots: &BTreeMap, + usage: &mut [bool], +) { match &expr.kind { ExecutionExprKind::Literal(_) => {} - ExecutionExprKind::Load(ExecutionLoad::RouteInput(index)) => { - if let Some(slot) = usage.get_mut(*index) { + ExecutionExprKind::Load(ExecutionLoad::RouteInput { route, .. }) => { + if let Some(slot) = declaration_slots + .get(route) + .and_then(|index| usage.get_mut(*index)) + { *slot = true; } } ExecutionExprKind::Load(_) => {} - ExecutionExprKind::Unary { expr, .. } => mark_route_inputs_in_expr(expr, usage), + ExecutionExprKind::Unary { expr, .. } => { + mark_route_inputs_in_expr(expr, declaration_slots, usage) + } ExecutionExprKind::Binary { lhs, rhs, .. } => { - mark_route_inputs_in_expr(lhs, usage); - mark_route_inputs_in_expr(rhs, usage); + mark_route_inputs_in_expr(lhs, declaration_slots, usage); + mark_route_inputs_in_expr(rhs, declaration_slots, usage); } ExecutionExprKind::Call { args, .. } => { for arg in args { - mark_route_inputs_in_expr(arg, usage); + mark_route_inputs_in_expr(arg, declaration_slots, usage); } } } } + +#[cfg(test)] +mod tests { + use super::*; + use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model}; + + fn load_model_info(src: &str) -> NativeModelInfo { + let model = parse_model(src).expect("model parses"); + let typed = analyze_model(&model).expect("model analyzes"); + let lowered = lower_typed_model(&typed).expect("model lowers"); + NativeModelInfo::from_execution_model(&lowered) + } + + #[test] + fn declaration_first_routes_inject_by_default() { + let info = load_model_info( + r#" +model implicit_route_injection { + kind ode + states { central } + routes { iv -> central } + dynamics { + ddt(central) = 0 + } + outputs { + cp = central + } +} +"#, + ); + + assert_eq!(info.routes.len(), 1); + assert!(info.routes[0].inject_input_to_destination); + } + + #[test] + fn explicit_rate_usage_disables_automatic_injection() { + let info = load_model_info( + r#" +model explicit_route_usage { + kind ode + states { central } + routes { iv -> central } + dynamics { + ddt(central) = rate(iv) + } + outputs { + cp = central + } +} +"#, + ); + + assert_eq!(info.routes.len(), 1); + assert!(!info.routes[0].inject_input_to_destination); + } + + #[test] + fn authoring_shared_input_routes_keep_declaration_specific_injection() { + let info = load_model_info( + r#" +name = shared_authoring +kind = ode + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#, + ); + + assert_eq!(info.route_len, 1); + assert_eq!(info.routes.len(), 2); + assert_eq!(info.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(info.routes[1].kind, Some(RouteKind::Infusion)); + assert_eq!(info.routes[0].index, 0); + assert_eq!(info.routes[1].index, 0); + assert!(info.routes[0].inject_input_to_destination); + assert!(!info.routes[1].inject_input_to_destination); + } + + #[test] + fn native_model_info_preserves_canonical_numeric_channel_names() { + let info = load_model_info( + r#" +name = canonical_numeric_channels +kind = ode + +params = ke, v +states = depot, central +outputs = cp, outeq_2 + +bolus(input_10) -> depot +infusion(iv) -> central + +dx(depot) = -ke * depot +dx(central) = rate(input_10) - ke * central + +out(cp) = central / v +out(outeq_2) = depot / v +"#, + ); + + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["input_10", "iv"] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "outeq_2"] + ); + } +} diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 4a94715f..7197084e 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::HashMap; use std::sync::Arc; use diffsol::{ @@ -14,22 +15,25 @@ use cranelift_jit::JITModule; #[cfg(feature = "dsl-aot-load")] use libloading::Library; use pharmsol_dsl::execution::KernelRole; -use pharmsol_dsl::AnalyticalKernel; +use pharmsol_dsl::{AnalyticalKernel, RouteKind, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ - data::{Covariates, Infusion}, + data::error_model::AssayErrorModels, + data::{Covariates, Infusion, InputLabel, OutputLabel}, simulator::{ + cache::{PredictionCache, DEFAULT_CACHE_SIZE}, equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, sde::simulate_sde_event_with, + EqnKind, Equation, EquationPriv, EquationTypes, }, likelihood::{Prediction, SubjectPredictions}, - M, V, + Fa, Lag, M, T, V, }, - Event, Observation, PharmsolError, Subject, + Event, Observation, Occasion, PharmsolError, Subject, }; pub type DenseKernelFn = unsafe extern "C" fn( @@ -264,12 +268,74 @@ impl RuntimeArtifact for NativeExecutionArtifact { #[derive(Clone, Debug)] struct SharedNativeModel { info: Arc, + route_semantics: Arc, artifact: Arc, } +#[derive(Clone, Debug)] +struct RouteInputSemantics { + bolus_destinations: Vec>, + infusion_inputs: Vec, + injected_infusion_destinations: Vec>, +} + +impl RouteInputSemantics { + fn from_model_info(info: &NativeModelInfo) -> Self { + let mut bolus_destinations = vec![None; info.route_len]; + let mut infusion_inputs = vec![false; info.route_len]; + let mut injected_infusion_destinations = vec![None; info.route_len]; + + for route in &info.routes { + match route.kind { + Some(RouteKind::Bolus) => { + bolus_destinations[route.index] = Some(route.destination_offset); + } + Some(RouteKind::Infusion) => { + infusion_inputs[route.index] = true; + if route.inject_input_to_destination { + injected_infusion_destinations[route.index] = + Some(route.destination_offset); + } + } + None => { + bolus_destinations[route.index] = Some(route.destination_offset); + infusion_inputs[route.index] = true; + if route.inject_input_to_destination { + injected_infusion_destinations[route.index] = + Some(route.destination_offset); + } + } + } + } + + Self { + bolus_destinations, + infusion_inputs, + injected_infusion_destinations, + } + } + + fn supports_input(&self, input: usize, kind: RouteKind) -> bool { + match kind { + RouteKind::Bolus => self + .bolus_destinations + .get(input) + .copied() + .flatten() + .is_some(), + RouteKind::Infusion => self.infusion_inputs.get(input).copied().unwrap_or(false), + } + } + + fn bolus_destination(&self, input: usize) -> Option { + self.bolus_destinations.get(input).copied().flatten() + } +} + impl SharedNativeModel { fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { Self { + route_semantics: Arc::new(RouteInputSemantics::from_model_info(&info)), info: Arc::new(info), artifact: Arc::new(artifact), } @@ -291,6 +357,20 @@ impl SharedNativeModel { .map(|output| output.index) } + fn metadata_route_index_for_label(&self, label: &str) -> Option { + self.route_index(label).or_else(|| { + canonical_numeric_alias(label, NUMERIC_ROUTE_PREFIX) + .and_then(|alias| self.route_index(alias.as_str())) + }) + } + + fn metadata_output_index_for_label(&self, label: &str) -> Option { + self.output_index(label).or_else(|| { + canonical_numeric_alias(label, NUMERIC_OUTPUT_PREFIX) + .and_then(|alias| self.output_index(alias.as_str())) + }) + } + fn validate_support_point(&self, support_point: &[f64]) -> Result<(), PharmsolError> { if support_point.len() != self.info.parameters.len() { return Err(PharmsolError::OtherError(format!( @@ -313,6 +393,69 @@ impl SharedNativeModel { Ok(()) } + fn validate_output(&self, outeq: usize) -> Result<(), PharmsolError> { + if outeq >= self.info.output_len { + return Err(PharmsolError::OuteqOutOfRange { + outeq, + nout: self.info.output_len, + }); + } + Ok(()) + } + + fn validate_input_for_kind(&self, input: usize, kind: RouteKind) -> Result<(), PharmsolError> { + self.validate_input(input)?; + if self.route_semantics.supports_input(input, kind) { + return Ok(()); + } + + Err(PharmsolError::UnsupportedInputRouteKind { input, kind }) + } + + fn resolve_input_label( + &self, + label: &InputLabel, + kind: RouteKind, + ) -> Result { + let input = self + .metadata_route_index_for_label(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + self.validate_input_for_kind(input, kind)?; + Ok(input) + } + + fn resolve_output_label(&self, label: &OutputLabel) -> Result { + self.metadata_output_index_for_label(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }) + } + + fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { + let mut events = occasion.process_events(None, true); + + for event in events.iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(events) + } + fn fill_cov_buffer(&self, covariates: &Covariates, time: f64, buf: &mut [f64]) { for covariate in &self.info.covariates { buf[covariate.index] = match covariates.get_covariate(&covariate.name) { @@ -323,9 +466,14 @@ impl SharedNativeModel { } fn apply_route_inputs_to_rates(&self, rates: &mut [f64], route_inputs: &[f64]) { - for route in &self.info.routes { - if route.inject_input_to_destination { - rates[route.destination_offset] += route_inputs[route.index]; + for (input, destination) in self + .route_semantics + .injected_infusion_destinations + .iter() + .enumerate() + { + if let Some(destination) = destination { + rates[*destination] += route_inputs[input]; } } } @@ -451,7 +599,13 @@ impl SharedNativeModel { for event in events.iter_mut() { if let Event::Bolus(bolus) = event { - self.validate_input(bolus.input())?; + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + self.validate_input_for_kind(input, RouteKind::Bolus)?; if self.artifact.has_kernel(KernelRole::RouteLag) { lag_values.fill(0.0); @@ -477,7 +631,7 @@ impl SharedNativeModel { lag_values.as_mut_ptr(), )?; } - let lag = lag_values[bolus.input()]; + let lag = lag_values[input]; if lag != 0.0 { *bolus.mut_time() += lag; } @@ -507,7 +661,7 @@ impl SharedNativeModel { fa_values.as_mut_ptr(), )?; } - let factor = fa_values[bolus.input()]; + let factor = fa_values[input]; if factor != 1.0 { bolus.set_amount(bolus.amount() * factor); } @@ -525,9 +679,14 @@ impl SharedNativeModel { input: usize, amount: f64, ) -> Result<(), PharmsolError> { - self.validate_input(input)?; - let destination = &self.info.routes[input]; - state[destination.destination_offset] += amount; + self.validate_input_for_kind(input, RouteKind::Bolus)?; + let destination = self.route_semantics.bolus_destination(input).ok_or( + PharmsolError::UnsupportedInputRouteKind { + input, + kind: RouteKind::Bolus, + }, + )?; + state[destination] += amount; Ok(()) } @@ -564,13 +723,13 @@ impl SharedNativeModel { &cov_buf, &mut outputs, )?; - if observation.outeq() >= outputs.len() { - return Err(PharmsolError::OuteqOutOfRange { - outeq: observation.outeq(), - nout: outputs.len(), - }); - } - Ok(observation.to_prediction(outputs[observation.outeq()], state.to_vec())) + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + self.validate_output(outeq)?; + Ok(observation.to_prediction(outputs[outeq], state.to_vec())) } } @@ -580,6 +739,7 @@ pub struct NativeOdeModel { solver: OdeSolver, rtol: f64, atol: f64, + cache: Option, } #[derive(Clone, Debug)] @@ -607,6 +767,7 @@ impl NativeOdeModel { solver: OdeSolver::default(), rtol: DEFAULT_ODE_RTOL, atol: DEFAULT_ODE_ATOL, + cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -621,14 +782,6 @@ impl NativeOdeModel { self } - pub fn route_index(&self, name: &str) -> Option { - self.shared.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.shared.output_index(name) - } - pub fn info(&self) -> &NativeModelInfo { self.shared.info.as_ref() } @@ -647,17 +800,14 @@ impl NativeOdeModel { let support_vector: V = DVector::from_vec(support_point.to_vec()).into(); for occasion in subject.occasions() { - let infusion_refs = occasion.infusions_ref(); - let infusions = infusion_refs + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared.validate_input(infusion.input())?; - } - - let mut events = occasion.process_events(None, true); let session = RefCell::new(self.shared.artifact.start_session()?); let mut route_session = session.borrow_mut(); self.shared.apply_route_properties( @@ -742,20 +892,20 @@ impl NativeOdeModel { }, NalgebraContext, ); + let support_point_vec = support_point.to_vec(); let problem = OdeBuilder::::new() .atol(vec![self.atol]) .rtol(self.rtol) .t0(occasion.initial_time()) .h0(1e-3) - .p(support_point.to_vec()) + .p(support_point_vec.clone()) .build_from_eqn(PMProblem::with_params_v( diffeq, self.shared.info.state_len, self.shared.info.route_len, - support_point.to_vec(), support_vector.clone(), occasion.covariates(), - infusion_refs.as_slice(), + infusions.iter(), initial_state, )?)?; @@ -813,9 +963,15 @@ impl NativeOdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; self.shared.apply_bolus( solver.state_mut().y.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -880,19 +1036,195 @@ impl NativeOdeModel { } } -impl NativeAnalyticalModel { - pub(crate) fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { - Self { - shared: Arc::new(SharedNativeModel::new(info, artifact)), +fn runtime_no_lag(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +fn runtime_no_fa(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +#[inline(always)] +fn runtime_ode_predictions( + model: &NativeOdeModel, + subject: &Subject, + support_point: &[f64], +) -> Result { + if let Some(cache) = &model.cache { + let key = ( + subject.hash(), + crate::simulator::equation::spphash(support_point), + ); + if let Some(cached) = cache.get(&key) { + return Ok(cached); + } + + let result = model.estimate_predictions(subject, support_point)?; + cache.insert(key, result.clone()); + Ok(result) + } else { + model.estimate_predictions(subject, support_point) + } +} + +impl crate::simulator::equation::Cache for NativeOdeModel { + fn with_cache_capacity(mut self, size: u64) -> Self { + self.cache = Some(PredictionCache::new(size)); + self + } + + fn enable_cache(mut self) -> Self { + self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); + self + } + + fn clear_cache(&self) { + if let Some(cache) = &self.cache { + cache.invalidate_all(); } } - pub fn route_index(&self, name: &str) -> Option { - self.shared.route_index(name) + fn disable_cache(mut self) -> Self { + self.cache = None; + self } +} - pub fn output_index(&self, name: &str) -> Option { - self.shared.output_index(name) +impl EquationTypes for NativeOdeModel { + type S = V; + type P = SubjectPredictions; +} + +impl EquationPriv for NativeOdeModel { + fn lag(&self) -> &Lag { + &(runtime_no_lag as Lag) + } + + fn fa(&self) -> &Fa { + &(runtime_no_fa as Fa) + } + + fn get_nstates(&self) -> usize { + self.shared.info.state_len + } + + fn get_ndrugs(&self) -> usize { + self.shared.info.route_len + } + + fn get_nouteqs(&self) -> usize { + self.shared.info.output_len + } + + fn metadata(&self) -> Option<&crate::ValidatedModelMetadata> { + None + } + + fn solve( + &self, + _state: &mut Self::S, + _support_point: &[f64], + _covariates: &Covariates, + _infusions: &[Infusion], + _start_time: f64, + _end_time: f64, + ) -> Result<(), PharmsolError> { + unimplemented!("solve is not used for runtime ODE models") + } + + fn process_observation( + &self, + _support_point: &[f64], + _observation: &Observation, + _error_models: Option<&AssayErrorModels>, + _time: f64, + _covariates: &Covariates, + _x: &mut Self::S, + _likelihood: &mut Vec, + _output: &mut Self::P, + ) -> Result<(), PharmsolError> { + unimplemented!("process_observation is not used for runtime ODE models") + } + + fn initial_state( + &self, + _support_point: &[f64], + _covariates: &Covariates, + _occasion_index: usize, + ) -> Self::S { + V::zeros(self.shared.info.state_len, NalgebraContext) + } +} + +impl Equation for NativeOdeModel { + fn estimate_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + Ok(self + .estimate_log_likelihood(subject, support_point, error_models)? + .exp()) + } + + fn estimate_log_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + let bound_error_models = self.bind_error_models(error_models)?; + let predictions = runtime_ode_predictions(self, subject, support_point)?; + predictions.log_likelihood(&bound_error_models) + } + + fn kind() -> EqnKind { + EqnKind::ODE + } + + fn assay_error_models(&self) -> AssayErrorModels { + AssayErrorModels::with_output_names( + self.info() + .outputs + .iter() + .map(|output| output.name.as_str()), + ) + } + + fn estimate_predictions( + &self, + subject: &Subject, + support_point: &[f64], + ) -> Result { + runtime_ode_predictions(self, subject, support_point) + } + + fn simulate_subject( + &self, + subject: &Subject, + support_point: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::P, Option), PharmsolError> { + let bound_error_models = match error_models { + Some(error_models) => Some(self.bind_error_models(error_models)?), + None => None, + }; + + let predictions = runtime_ode_predictions(self, subject, support_point)?; + let likelihood = match bound_error_models.as_ref() { + Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), + None => None, + }; + Ok((predictions, likelihood)) + } +} + +impl NativeAnalyticalModel { + pub(crate) fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { + Self { + shared: Arc::new(SharedNativeModel::new(info, artifact)), + } } pub fn info(&self) -> &NativeModelInfo { @@ -912,17 +1244,14 @@ impl NativeAnalyticalModel { let mut output = SubjectPredictions::default(); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared.validate_input(infusion.input())?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -941,8 +1270,12 @@ impl NativeAnalyticalModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { - self.shared - .apply_bolus(&mut state, bolus.input(), bolus.amount())? + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; + self.shared.apply_bolus(&mut state, input, bolus.amount())? } Event::Infusion(_) => {} Event::Observation(observation) => { @@ -959,6 +1292,7 @@ impl NativeAnalyticalModel { if let Some(next_event) = events.get(index + 1) { self.solve_interval( + &mut *session, &mut state, support_point, occasion.covariates(), @@ -973,8 +1307,10 @@ impl NativeAnalyticalModel { Ok(output) } + #[allow(clippy::too_many_arguments)] fn solve_interval( &self, + session: &mut dyn KernelSession, state: &mut [f64], support_point: &[f64], covariates: &Covariates, @@ -1001,11 +1337,25 @@ impl NativeAnalyticalModel { breakpoints.dedup_by(|lhs, rhs| (*lhs - *rhs).abs() < 1e-12); let mut current = breakpoints[0]; - let projected = project_analytical_parameters(&self.shared.info, support_point)?; + let mut cov_buf = vec![0.0; self.shared.info.covariates.len()]; + let mut derived = vec![0.0; self.shared.info.derived_len]; for next in breakpoints.iter().copied().skip(1) { let dt = next - current; - let route_inputs = active_route_inputs(infusions, current, self.shared.info.route_len); + let route_inputs = + interval_route_inputs(infusions, current, next, self.shared.info.route_len); + self.shared.refresh_derived( + session, + next, + state, + support_point, + covariates, + &route_inputs, + &mut derived, + &mut cov_buf, + )?; + let projected = + project_analytical_parameters(&self.shared.info, support_point, &derived)?; let next_state = apply_analytical_kernel( self.shared.info.analytical.ok_or_else(|| { PharmsolError::OtherError(format!( @@ -1041,14 +1391,6 @@ impl NativeSdeModel { self } - pub fn route_index(&self, name: &str) -> Option { - self.shared.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.shared.output_index(name) - } - pub fn info(&self) -> &NativeModelInfo { self.shared.info.as_ref() } @@ -1066,17 +1408,14 @@ impl NativeSdeModel { let mut output = Array2::from_shape_fn((self.nparticles, 0), |_| Prediction::default()); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared.validate_input(infusion.input())?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -1098,10 +1437,15 @@ impl NativeSdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; for particle in &mut particles { self.shared.apply_bolus( particle.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -1292,11 +1636,33 @@ impl NativeSdeModel { fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> Vec { let mut values = vec![0.0; route_len]; for infusion in infusions { - if infusion.input() < route_len + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && time >= infusion.time() && time <= infusion.time() + infusion.duration() { - values[infusion.input()] += infusion.amount() / infusion.duration(); + values[input] += infusion.amount() / infusion.duration(); + } + } + values +} + +fn interval_route_inputs( + infusions: &[Infusion], + start_time: f64, + end_time: f64, + route_len: usize, +) -> Vec { + let mut values = vec![0.0; route_len]; + for infusion in infusions { + let finish = infusion.time() + infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && start_time >= infusion.time() && end_time <= finish { + values[input] += infusion.amount() / infusion.duration(); } } values @@ -1320,9 +1686,17 @@ fn sort_events(events: &mut [Event]) { }); } +fn canonical_numeric_alias(label: &str, prefix: &str) -> Option { + if label.is_empty() || !label.chars().all(|ch| ch.is_ascii_digit()) { + return None; + } + Some(format!("{prefix}{label}")) +} + fn project_analytical_parameters( info: &NativeModelInfo, support_point: &[f64], + derived: &[f64], ) -> Result { let kernel = info.analytical.ok_or_else(|| { PharmsolError::OtherError(format!( @@ -1339,6 +1713,13 @@ fn project_analytical_parameters( support_point.len() ))); } + + // Analytical authoring models can project kernel arguments through a derive + // kernel by declaring exactly the built-in kernel arity in `derived`. + if derived.len() == arity { + return Ok(V::from_vec(derived.to_vec(), NalgebraContext)); + } + Ok(V::from_vec( support_point[..arity].to_vec(), NalgebraContext, @@ -1483,3 +1864,104 @@ fn apply_analytical_kernel( } } } + +#[cfg(test)] +mod tests { + use super::{ + canonical_numeric_alias, KernelSession, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, + RuntimeArtifact, RuntimeBackend, SharedNativeModel, NUMERIC_OUTPUT_PREFIX, + NUMERIC_ROUTE_PREFIX, + }; + use crate::PharmsolError; + use pharmsol_dsl::execution::KernelRole; + use pharmsol_dsl::{ModelKind, RouteKind}; + + #[derive(Debug)] + struct DummyArtifact; + + impl RuntimeArtifact for DummyArtifact { + fn backend(&self) -> RuntimeBackend { + panic!("dummy artifact backend should not be used in tests") + } + + fn has_kernel(&self, _role: KernelRole) -> bool { + false + } + + fn start_session(&self) -> Result, PharmsolError> { + panic!("dummy artifact sessions should not be used in tests") + } + } + + fn bolus_only_shared_model() -> SharedNativeModel { + SharedNativeModel::new( + NativeModelInfo { + name: "bolus_only".to_string(), + kind: ModelKind::Ode, + parameters: Vec::new(), + covariates: Vec::new(), + routes: vec![NativeRouteInfo { + name: "oral".to_string(), + declaration_index: 0, + index: 0, + kind: Some(RouteKind::Bolus), + destination_offset: 0, + inject_input_to_destination: false, + }], + outputs: vec![NativeOutputInfo { + name: "cp".to_string(), + index: 0, + }], + state_len: 1, + derived_len: 0, + output_len: 1, + route_len: 1, + analytical: None, + particles: None, + }, + DummyArtifact, + ) + } + + #[test] + fn canonical_numeric_alias_maps_bare_numeric_labels_to_contextual_prefixes() { + assert_eq!( + canonical_numeric_alias("1", NUMERIC_ROUTE_PREFIX), + Some("input_1".to_string()) + ); + assert_eq!( + canonical_numeric_alias("10", NUMERIC_OUTPUT_PREFIX), + Some("outeq_10".to_string()) + ); + } + + #[test] + fn canonical_numeric_alias_ignores_symbolic_and_prefixed_labels() { + assert_eq!(canonical_numeric_alias("iv", NUMERIC_ROUTE_PREFIX), None); + assert_eq!( + canonical_numeric_alias("input_1", NUMERIC_ROUTE_PREFIX), + None + ); + assert_eq!( + canonical_numeric_alias("outeq_2", NUMERIC_OUTPUT_PREFIX), + None + ); + } + + #[test] + fn validate_input_for_kind_reports_structured_route_kind_error() { + let model = bolus_only_shared_model(); + + let error = model + .validate_input_for_kind(0, RouteKind::Infusion) + .expect_err("bolus-only route should reject infusion usage"); + + assert!(matches!( + error, + PharmsolError::UnsupportedInputRouteKind { + input: 0, + kind: RouteKind::Infusion, + } + )); + } +} diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1d49d82a..7ccb132d 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -1,3 +1,82 @@ +//! Unified runtime entrypoints for DSL-backed models. +//! +//! Use this module when you already know you want an executable model and need +//! one backend-neutral surface for compile, load, and prediction workflows. +//! It normalizes the backend-specific JIT, native AoT, and WASM entrypoints so +//! callers can choose a deployment target without rewriting the downstream +//! prediction code. +//! +//! Use the backend modules directly only when you need a backend-specific +//! artifact or compile control: +//! +//! - [`super::jit`] for direct in-process JIT compilation. +//! - [`compile_module_source_to_aot`][crate::dsl::compile_module_source_to_aot] for native artifact export and reload. +//! - [`compile_module_source_to_wasm_bytes`][crate::dsl::compile_module_source_to_wasm_bytes] and [`load_runtime_wasm_bytes`] for portable WASM bytes, +//! browser-loader assets, and host-side WASM loading. +//! +//! Main entrypoints: +//! +//! - [`compile_module_source_to_runtime`] for the one-stop source-to-runtime +//! path. +//! - [`compile_execution_model_to_runtime`] when you already have an +//! [`ExecutionModel`](pharmsol_dsl::ExecutionModel). +//! - [`load_runtime_artifact`] and [`load_runtime_wasm_bytes`] when the model +//! has already been compiled and stored elsewhere. +//! - [`CompiledRuntimeModel::estimate_predictions`] for backend-neutral +//! execution against a [`Subject`](crate::Subject). +//! +//! Backend choice guide: +//! +//! - [`RuntimeCompilationTarget::Jit`] keeps compilation and execution inside +//! the current process. Use it for native interactive workflows and tests. +//! - [`RuntimeCompilationTarget::NativeAot`] emits a native artifact and reloads +//! it into the same runtime model shape. Use it when you want reusable native +//! artifacts and can control the target platform. +//! - [`RuntimeCompilationTarget::Wasm`] emits portable WASM bytes and reloads +//! them into the host-side runtime adapter. Choose this target when you need a portable +//! artifact or browser-aligned deployment story. +//! +//! Smallest compile-and-run example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! use pharmsol::prelude::*; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .missing_observation(2.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.2, 50.0])?; +//! assert!(predictions.as_subject().is_some()); +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` + use std::fmt; use std::path::Path; @@ -39,24 +118,39 @@ pub type RuntimeOdeModel = NativeOdeModel; pub type RuntimeAnalyticalModel = NativeAnalyticalModel; pub type RuntimeSdeModel = NativeSdeModel; +/// Selects which backend should produce the executable runtime model. +/// +/// This enum is the main backend-switching point for +/// [`compile_module_source_to_runtime`] and +/// [`compile_execution_model_to_runtime`]. #[derive(Debug, Clone, PartialEq, Eq)] pub enum RuntimeCompilationTarget { + /// Compile and execute the model inside the current native process. #[cfg(feature = "dsl-jit")] Jit, + /// Export a native artifact and reload it as a runtime model. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot(NativeAotCompileOptions), + /// Emit WASM bytes and reload them through the host-side WASM runtime. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Identifies the on-disk artifact format for [`load_runtime_artifact`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RuntimeArtifactFormat { + /// A native ahead-of-time artifact produced by the AoT compiler. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot, + /// A WASM artifact produced by the WASM compiler. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Backend-neutral prediction output from a compiled runtime model. +/// +/// ODE and analytical models return subject predictions. SDE models return the +/// particle matrix used by the stochastic workflow. #[derive(Clone, Debug)] pub enum RuntimePredictions { Subject(SubjectPredictions), @@ -93,6 +187,10 @@ impl RuntimePredictions { } } +/// Executable runtime model returned by the backend-neutral runtime surface. +/// +/// This type hides the concrete backend and keeps the prediction entrypoint the +/// same across JIT, native AoT, and WASM-based flows. #[derive(Clone, Debug)] pub enum CompiledRuntimeModel { Ode(RuntimeOdeModel), @@ -131,22 +229,6 @@ impl CompiledRuntimeModel { self.info().kind } - pub fn route_index(&self, name: &str) -> Option { - match self { - Self::Ode(model) => model.route_index(name), - Self::Analytical(model) => model.route_index(name), - Self::Sde(model) => model.route_index(name), - } - } - - pub fn output_index(&self, name: &str) -> Option { - match self { - Self::Ode(model) => model.output_index(name), - Self::Analytical(model) => model.output_index(name), - Self::Sde(model) => model.output_index(name), - } - } - pub fn estimate_predictions( &self, subject: &Subject, @@ -166,6 +248,8 @@ impl CompiledRuntimeModel { } } +/// Errors produced while parsing, lowering, compiling, loading, or executing a +/// runtime DSL model. #[derive(Error)] pub enum RuntimeError { #[error("failed to parse DSL source: {0}")] @@ -231,6 +315,10 @@ impl fmt::Debug for RuntimeError { } } +/// Parse, analyze, lower, compile, and return a runtime model in one step. +/// +/// Use this when your input is DSL source text and you want the shortest path +/// from source to predictions. pub fn compile_module_source_to_runtime( source: &str, model_name: Option<&str>, @@ -269,6 +357,10 @@ pub fn compile_module_source_to_runtime( }) } +/// Compile a lowered execution model to a selected runtime backend. +/// +/// Use this when you already own the frontend pipeline and only need the final +/// backend step. pub fn compile_execution_model_to_runtime( model: &ExecutionModel, target: RuntimeCompilationTarget, @@ -309,6 +401,7 @@ pub fn compile_execution_model_to_runtime( } } +/// Load a previously compiled native AoT or WASM artifact from disk. pub fn load_runtime_artifact( path: impl AsRef, format: RuntimeArtifactFormat, @@ -330,6 +423,7 @@ pub fn load_runtime_artifact( } #[cfg(feature = "dsl-wasm")] +/// Compile DSL source straight to a host-side runtime model via the WASM path. pub fn compile_module_source_to_runtime_wasm( source: &str, model_name: Option<&str>, @@ -339,6 +433,8 @@ pub fn compile_module_source_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Compile a lowered execution model straight to a host-side runtime model via +/// the WASM path. pub fn compile_execution_model_to_runtime_wasm( model: &ExecutionModel, ) -> Result { @@ -347,6 +443,7 @@ pub fn compile_execution_model_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Load a runtime model from in-memory WASM bytes. pub fn load_runtime_wasm_bytes(bytes: &[u8]) -> Result { let (info, artifact) = load_wasm_artifact_bytes(bytes)?; Ok(runtime_model_from_parts(info, artifact)) @@ -377,11 +474,110 @@ mod tests { use super::*; use crate::dsl::compile_sde_model_to_jit; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; + use crate::PharmsolError; use crate::SubjectBuilderExt; use approx::assert_relative_eq; - use pharmsol_dsl::{DiagnosticPhase, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; + use pharmsol_dsl::{DiagnosticPhase, RouteKind, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; use tempfile::tempdir; + const MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL: &str = r#" +name = multi_digit_output_runtime +kind = ode + +params = ke, v +states = central +outputs = outeq_2, outeq_10, outeq_11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(outeq_10) = central / v ~ continuous() +out(outeq_2) = central / v ~ continuous() +out(outeq_11) = central / v ~ continuous() +"#; + + const NUMERIC_ROUTE_LABELS_RUNTIME_DSL: &str = r#" +name = prefixed_numeric_route_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(input_10) -> central +bolus(input_11) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + const SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = prefixed_numeric_route_output_runtime +kind = ode + +params = ke, v +states = central +outputs = outeq_1 + +infusion(input_1) -> central + +dx(central) = -ke * central + +out(outeq_1) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + fn corpus_source() -> &'static str { STRUCTURED_BLOCK_CORPUS } @@ -397,17 +593,17 @@ mod tests { pharmsol_dsl::lower_typed_model(model).expect("lower corpus model") } - fn ode_subject(output: usize, oral: usize, iv: usize) -> Subject { + fn ode_subject() -> Subject { Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, output) - .missing_observation(1.0, output) - .missing_observation(2.0, output) - .missing_observation(6.0, output) - .missing_observation(7.0, output) - .missing_observation(9.0, output) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build() } @@ -421,6 +617,155 @@ mod tests { .collect() } + fn compile_runtime_backend_matrix( + source: &str, + model_name: &str, + work_dir: &std::path::Path, + ) -> ( + CompiledRuntimeModel, + CompiledRuntimeModel, + CompiledRuntimeModel, + ) { + let jit = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Jit, + |_, _| {}, + ) + .expect("compile jit runtime model"); + let aot = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::NativeAot( + NativeAotCompileOptions::new(work_dir.join(format!("{model_name}-aot-build"))) + .with_output(work_dir.join(format!("{model_name}.pkm"))), + ), + |_, _| {}, + ) + .expect("compile aot runtime model"); + let wasm = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Wasm, + |_, _| {}, + ) + .expect("compile wasm runtime model"); + + (jit, aot, wasm) + } + + fn compiled_route_input_index(model: &CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .routes + .iter() + .find(|route| route.name == name) + .map(|route| route.index) + } + + fn compiled_output_slot_index(model: &CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .outputs + .iter() + .find(|output| output.name == name) + .map(|output| output.index) + } + + fn numeric_route_subject() -> Subject { + Subject::builder("numeric-route-runtime") + .bolus(0.0, 120.0, "input_10") + .bolus(1.0, 80.0, "input_11") + .missing_observation(0.5, "cp") + .missing_observation(1.5, "cp") + .build() + } + + fn numeric_route_alias_subject() -> Subject { + Subject::builder("numeric-route-runtime-alias") + .bolus(0.0, 120.0, "10") + .bolus(1.0, 80.0, "11") + .missing_observation(0.5, "cp") + .missing_observation(1.5, "cp") + .build() + } + + fn shared_numeric_route_output_subject() -> Subject { + Subject::builder("prefixed-numeric-route-output-runtime") + .infusion(0.0, 120.0, "input_1", 1.0) + .missing_observation(0.5, "outeq_1") + .missing_observation(1.5, "outeq_1") + .build() + } + + fn shared_numeric_route_output_alias_subject() -> Subject { + Subject::builder("raw-numeric-route-output-runtime") + .infusion(0.0, 120.0, "1", 1.0) + .missing_observation(0.5, "1") + .missing_observation(1.5, "1") + .build() + } + + fn mismatched_route_kind_subject() -> Subject { + Subject::builder("mismatched-route-kind-runtime") + .infusion(0.0, 120.0, "10", 1.0) + .missing_observation(0.5, "cp") + .build() + } + + fn assert_unknown_output_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == expected_label + )); + } + + fn assert_unknown_input_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == expected_label + )); + } + + fn assert_unsupported_input_route_kind( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_input: usize, + expected_kind: RouteKind, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("mismatched route kind should fail"); + + match error { + RuntimeError::Runtime(PharmsolError::UnsupportedInputRouteKind { input, kind }) + if input == expected_input && kind == expected_kind => {} + other => panic!( + "expected UnsupportedInputRouteKind {{ input: {expected_input}, kind: {:?} }}, got {other:?}", + expected_kind + ), + } + } + #[test] fn runtime_backend_matrix_matches_ode_predictions() { let work_dir = tempdir().expect("tempdir"); @@ -460,10 +805,177 @@ mod tests { vec!["ka", "cl", "v", "tlag", "f_oral"] ); - let oral = jit.route_index("oral").expect("oral route"); - let iv = jit.route_index("iv").expect("iv route"); - let cp = jit.output_index("cp").expect("cp output"); - let subject = ode_subject(cp, oral, iv); + assert!(compiled_route_input_index(&jit, "oral").is_some()); + assert!(compiled_route_input_index(&jit, "iv").is_some()); + assert_eq!(compiled_output_slot_index(&jit, "cp"), Some(0)); + let subject = ode_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_reports_route_kind_mismatch() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let subject = mismatched_route_kind_subject(); + + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "prefixed_numeric_route_runtime", + work_dir.path(), + ); + let expected_input = + compiled_route_input_index(&jit, "input_10").expect("input_10 route index"); + + for model in [&jit, &aot, &wasm] { + assert_unsupported_input_route_kind( + model, + &subject, + &support, + expected_input, + RouteKind::Infusion, + ); + } + } + + #[test] + fn runtime_backend_matrix_preserves_multi_digit_output_label_order() { + let work_dir = tempdir().expect("tempdir"); + let (jit, aot, wasm) = compile_runtime_backend_matrix( + MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL, + "multi_digit_output_runtime", + work_dir.path(), + ); + + assert_eq!(compiled_output_slot_index(&jit, "outeq_2"), Some(0)); + assert_eq!(compiled_output_slot_index(&jit, "outeq_10"), Some(1)); + assert_eq!(compiled_output_slot_index(&jit, "outeq_11"), Some(2)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_2"), Some(0)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_10"), Some(1)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_11"), Some(2)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_2"), Some(0)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_10"), Some(1)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_11"), Some(2)); + } + + #[test] + fn runtime_backend_matrix_supports_prefixed_multi_digit_numeric_route_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "prefixed_numeric_route_runtime", + work_dir.path(), + ); + + assert_eq!(compiled_route_input_index(&jit, "input_10"), Some(0)); + assert_eq!(compiled_route_input_index(&jit, "input_11"), Some(1)); + assert_eq!(compiled_route_input_index(&aot, "input_10"), Some(0)); + assert_eq!(compiled_route_input_index(&aot, "input_11"), Some(1)); + assert_eq!(compiled_route_input_index(&wasm, "input_10"), Some(0)); + assert_eq!(compiled_route_input_index(&wasm, "input_11"), Some(1)); + + let subject = numeric_route_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_resolves_raw_numeric_route_labels_against_prefixed_metadata() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "prefixed_numeric_route_runtime", + work_dir.path(), + ); + + let subject = numeric_route_alias_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_supports_prefixed_numeric_route_and_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL, + "prefixed_numeric_route_output_runtime", + work_dir.path(), + ); + + assert_eq!(compiled_route_input_index(&jit, "input_1"), Some(0)); + assert_eq!(compiled_output_slot_index(&jit, "outeq_1"), Some(0)); + assert_eq!(compiled_route_input_index(&aot, "input_1"), Some(0)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_1"), Some(0)); + assert_eq!(compiled_route_input_index(&wasm, "input_1"), Some(0)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_1"), Some(0)); + + let subject = shared_numeric_route_output_subject(); let jit_values = subject_values( &jit.estimate_predictions(&subject, &support) @@ -489,6 +1001,80 @@ mod tests { } } + #[test] + fn runtime_backend_matrix_resolves_shared_raw_numeric_route_and_output_aliases() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL, + "prefixed_numeric_route_output_runtime", + work_dir.path(), + ); + + let subject = shared_numeric_route_output_alias_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_output_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + + assert_unknown_output_label(&jit, &subject, &support, "10"); + assert_unknown_output_label(&aot, &subject, &support, "10"); + assert_unknown_output_label(&wasm, &subject, &support, "10"); + } + + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_input_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_input_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + + assert_unknown_input_label(&jit, &subject, &support, "10"); + assert_unknown_input_label(&aot, &subject, &support, "10"); + assert_unknown_input_label(&wasm, &subject, &support, "10"); + } + #[test] fn runtime_compile_preserves_parse_diagnostic_structure() { let source = "model broken { kind ode outputs { cp = 1 + } }"; diff --git a/src/dsl/rust_backend.rs b/src/dsl/rust_backend.rs index 19b3e5cc..850e13b7 100644 --- a/src/dsl/rust_backend.rs +++ b/src/dsl/rust_backend.rs @@ -264,7 +264,7 @@ fn emit_load(load: &ExecutionLoad, ty: ValueType) -> Result { ExecutionLoad::Covariate(index) => format!("load_f64(covariates, {index})"), ExecutionLoad::Derived(index) => format!("load_f64(derived, {index})"), ExecutionLoad::Local(index) => return Ok(format!("local_{index}")), - ExecutionLoad::RouteInput(index) => format!("load_f64(routes, {index})"), + ExecutionLoad::RouteInput { index, .. } => format!("load_f64(routes, {index})"), ExecutionLoad::State(state) => { let index = emit_state_ref_index(state)?; format!("load_f64(states, {index})") diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index 16884952..e95b799a 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -406,11 +406,39 @@ impl RuntimeArtifact for WasmExecutionArtifact { } } +/// Read only the metadata from a compiled WASM artifact on disk. +/// +/// Use this when you need model identity, route labels, output labels, or +/// buffer sizes without loading the executable runtime wrapper. pub fn read_wasm_model_info(path: impl AsRef) -> Result { let (info, _) = load_wasm_artifact(path)?; Ok(info) } +/// Read only the metadata from in-memory compiled WASM bytes. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{compile_module_source_to_wasm_bytes, read_wasm_model_info_bytes}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let bytes = compile_module_source_to_wasm_bytes(source, Some("bimodal_ke"))?; +/// let info = read_wasm_model_info_bytes(&bytes)?; +/// assert_eq!(info.name, "bimodal_ke"); +/// # Ok::<(), Box>(()) +/// ``` pub fn read_wasm_model_info_bytes(bytes: &[u8]) -> Result { let (info, _) = load_wasm_artifact_bytes(bytes)?; Ok(info) @@ -778,7 +806,9 @@ mod tests { covariates: Vec::new(), routes: vec![NativeRouteInfo { name: "oral".to_string(), + declaration_index: 0, index: 0, + kind: None, destination_offset: 0, inject_input_to_destination: true, }], diff --git a/src/dsl/wasm_compile.rs b/src/dsl/wasm_compile.rs index 66995e8a..cda4727d 100644 --- a/src/dsl/wasm_compile.rs +++ b/src/dsl/wasm_compile.rs @@ -19,15 +19,24 @@ use pharmsol_dsl::{ LoweringError, ParseError, SemanticError, }; +/// ABI version for compiled WASM artifacts produced by this crate. pub const WASM_API_VERSION: u32 = 1; +/// Default entry capacity for [`WasmCompileCache`]. pub const DEFAULT_WASM_COMPILE_CACHE_CAPACITY: usize = 32; static BROWSER_LOADER_SOURCE: OnceLock = OnceLock::new(); +/// Portable WASM artifact bundle produced by the WASM compiler path. +/// +/// The bundle includes the raw WASM bytes, model metadata, and a browser loader +/// source string that can instantiate the model in JavaScript. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct CompiledWasmModule { + /// Raw compiled WASM bytes. pub wasm_bytes: Vec, + /// Serialized model metadata and kernel availability. pub metadata: CompiledModelInfoEnvelope, + /// JavaScript loader source for browser-side instantiation. pub browser_loader_source: String, } @@ -52,6 +61,7 @@ struct WasmCompileCacheState { lru: VecDeque, } +/// In-memory LRU cache for repeated WASM compilation from the same DSL source. #[derive(Debug)] pub struct WasmCompileCache { capacity: usize, @@ -65,6 +75,7 @@ impl Default for WasmCompileCache { } impl WasmCompileCache { + /// Create a new compile cache with at least one entry of capacity. pub fn new(capacity: usize) -> Self { Self { capacity: capacity.max(1), @@ -72,10 +83,12 @@ impl WasmCompileCache { } } + /// Return the configured cache capacity. pub fn capacity(&self) -> usize { self.capacity } + /// Return the number of cached compiled modules. pub fn entry_count(&self) -> usize { self.state .lock() @@ -84,6 +97,7 @@ impl WasmCompileCache { .len() } + /// Remove all cached compiled modules. pub fn clear(&self) { let mut state = self .state @@ -93,6 +107,8 @@ impl WasmCompileCache { state.lru.clear(); } + /// Compile DSL source to a full WASM module bundle, reusing the cache when + /// possible. pub fn compile_module_source_to_wasm_module( &self, source: &str, @@ -108,6 +124,7 @@ impl WasmCompileCache { Ok(compiled) } + /// Compile DSL source to raw WASM bytes, reusing the cache when possible. pub fn compile_module_source_to_wasm_bytes( &self, source: &str, @@ -145,6 +162,8 @@ impl WasmCompileCache { } } +/// Error produced while compiling, inspecting, or loading a DSL-backed WASM +/// artifact. #[derive(Error)] pub enum WasmError { #[error(transparent)] @@ -224,10 +243,12 @@ impl fmt::Debug for WasmError { } } +/// Compile a lowered execution model to raw WASM bytes. pub fn compile_execution_model_to_wasm_bytes(model: &ExecutionModel) -> Result, WasmError> { emit_execution_model_to_wasm_bytes(model, WASM_API_VERSION) } +/// Compile a lowered execution model to a portable WASM bundle. pub fn compile_execution_model_to_wasm_module( model: &ExecutionModel, ) -> Result { @@ -238,6 +259,7 @@ pub fn compile_execution_model_to_wasm_module( }) } +/// Parse DSL source, lower one selected model, and return raw WASM bytes. pub fn compile_module_source_to_wasm_bytes( source: &str, model_name: Option<&str>, @@ -245,6 +267,35 @@ pub fn compile_module_source_to_wasm_bytes( Ok(compile_module_source_to_wasm_module(source, model_name)?.wasm_bytes) } +/// Parse DSL source, lower one selected model, and return the full WASM bundle. +/// +/// Use this when you want a portable artifact for browser or host-side loading +/// together with the browser loader source. +/// +/// This function requires `dsl-wasm-compile`. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{browser_loader_source, compile_module_source_to_wasm_module}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let compiled = compile_module_source_to_wasm_module(source, Some("bimodal_ke"))?; +/// let loader = browser_loader_source(); +/// # let _ = (compiled, loader); +/// # Ok::<(), pharmsol::dsl::WasmError>(()) +/// ``` pub fn compile_module_source_to_wasm_module( source: &str, model_name: Option<&str>, @@ -282,6 +333,10 @@ fn compile_module_source_to_wasm_module_uncached( compile_execution_model_to_wasm_module(&execution) } +/// Return the JavaScript loader source for browser-side WASM model execution. +/// +/// This helper is useful when you want to ship compiled WASM bytes together +/// with the minimal browser glue code that understands the pharmsol ABI. pub fn browser_loader_source() -> String { BROWSER_LOADER_SOURCE .get_or_init(build_browser_loader_source) @@ -848,7 +903,7 @@ mod tests { }; const SIMPLE_SOURCE: &str = r#" -model = example_ode +name = example_ode kind = ode params = ke, v @@ -901,7 +956,7 @@ out(cp) = central / v ~ continuous() cache .compile_module_source_to_wasm_module( r#" -model = second_ode +name = second_ode kind = ode params = ke, v @@ -949,7 +1004,7 @@ out(cp) = central / v ~ continuous() #[test] fn compile_module_source_to_wasm_module_preserves_semantic_diagnostic_structure() { let source = r#" -model = broken +name = broken kind = ode states = central @@ -995,7 +1050,7 @@ out(cp) = central ~ continuous() #[test] fn compile_module_source_to_wasm_module_preserves_lowering_diagnostic_structure() { let source = r#" -model = broken +name = broken kind = ode states = transit[4], central diff --git a/src/dsl/wasm_direct_emitter.rs b/src/dsl/wasm_direct_emitter.rs index 2d92ad1d..857f2ac7 100644 --- a/src/dsl/wasm_direct_emitter.rs +++ b/src/dsl/wasm_direct_emitter.rs @@ -922,7 +922,7 @@ fn emit_load( function.instruction(&Instruction::LocalGet(local.wasm_local)); emit_cast_stack(local.ty, target_ty, function, state.model_name) } - ExecutionLoad::RouteInput(index) => emit_dense_load( + ExecutionLoad::RouteInput { index, .. } => emit_dense_load( function, KERNEL_PARAM_ROUTES, *index, diff --git a/src/error/mod.rs b/src/error/mod.rs index 1316b8a4..1e97aee0 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -1,5 +1,7 @@ use thiserror::Error; +use pharmsol_dsl::RouteKind; + #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] use crate::data::error_model::ErrorModelError; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] @@ -37,7 +39,13 @@ pub enum PharmsolError { ZeroLikelihood, #[error("Missing observation in prediction")] MissingObservation, - #[error("Input channel {input} is out of range (ndrugs = {ndrugs})")] + #[error("Input label `{label}` could not be resolved to a route input")] + UnknownInputLabel { label: String }, + #[error("Output label `{label}` could not be resolved to an output")] + UnknownOutputLabel { label: String }, + #[error("Input index {input} does not support route kind {kind:?}")] + UnsupportedInputRouteKind { input: usize, kind: RouteKind }, + #[error("Input index {input} is out of range (ndrugs = {ndrugs})")] InputOutOfRange { input: usize, ndrugs: usize }, #[error("Output equation {outeq} is out of range (nout = {nout})")] OuteqOutOfRange { outeq: usize, nout: usize }, diff --git a/src/lib.rs b/src/lib.rs index f2691579..9c9e40b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,105 @@ +//! `pharmsol` is a Rust library for pharmacometric work. +//! +//! You can use it to: +//! +//! - build PK/PD datasets from dose and observation events +//! - simulate analytical, ODE, and SDE models +//! - run non-compartmental analysis (NCA) +//! - compile and run models from the pharmsol DSL when the DSL features are enabled +//! +//! Most users start in one of these places: +//! +//! - [`prelude`] for the common types, traits, and macros +//! - [`data`] to build subjects, occasions, events, and covariates +//! - [`simulator`] to define models and generate predictions +//! - [`nca`] to calculate NCA metrics from the same data structures +//! - [`optimize`] for optimizer-oriented workflows +//! +//! The DSL runtime surface is feature-gated. When you enable `dsl-core`, the +//! `pharmsol::dsl` module adds parsing, analysis, lowering, compile, and runtime +//! entrypoints for models written as DSL source text. +//! +//! ## Quick Start +//! +//! This example shows the smallest full workflow: define a model, build a +//! subject, and generate predictions. +//! +//! ```rust +//! use pharmsol::prelude::*; +//! +//! let model = analytical! { +//! name: "one_cmt_iv", +//! params: [ke, v], +//! states: [central], +//! outputs: [cp], +//! routes: [ +//! infusion(iv) -> central, +//! ], +//! structure: one_compartment, +//! out: |x, _p, _t, _cov, y| { +//! y[cp] = x[central] / v; +//! }, +//! }; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.022, 194.0])?; +//! assert_eq!(predictions.flat_predictions().len(), 2); +//! # Ok::<(), pharmsol::PharmsolError>(()) +//! ``` +//! +//! ## Choose A Workflow +//! +//! Use this guide when you are deciding where to start. +//! +//! | Task | Start Here | Notes | +//! | --- | --- | --- | +//! | Build subject data | [`data`] or [`prelude`] | Best when you already know dose times, labels, and observations. | +//! | Simulate a model written in Rust | [`simulator`] or [`prelude`] | Supports analytical, ODE, and SDE models. | +//! | Run NCA | [`nca`] or [`prelude`] | Reuses the same `Subject`, `Occasion`, and `Data` types. | +//! | Use optimization helpers | [`optimize`] | Intended for advanced workflows. | +//! | Parse or compile DSL source | `pharmsol::dsl` | Requires one or more DSL features. | +//! +//! ## Feature Guide +//! +//! Core simulation and NCA APIs do not need extra crate features on native +//! targets. +//! +//! DSL work is feature-gated: +//! +//! - `dsl-core`: exposes the `pharmsol::dsl` facade and frontend types +//! - `dsl-jit`: adds in-process JIT compilation +//! - `dsl-aot`: adds native ahead-of-time artifact compilation +//! - `dsl-aot-load`: adds native artifact loading +//! - `dsl-wasm-compile`: adds WASM artifact generation +//! - `dsl-wasm`: adds WASM runtime loading and execution +//! +//! ## Labels And Indices +//! +//! Public data APIs use route labels and output labels such as `"iv"`, +//! `"oral"`, and `"cp"`. +//! +//! Use labels in builders and parsed data unless you are deliberately working +//! with dense internal indices from a lower-level API. +//! +//! ## Platform Notes +//! +//! The main `data`, `simulator`, `nca`, and `optimize` modules are documented +//! for native targets. Some surfaces are not built on `wasm32-unknown-unknown`. +//! The DSL runtime also has feature-specific platform limits. +//! +//! ## Next Stops +//! +//! - Start with [`prelude`] if you want one import for the common workflow. +//! - Open [`data`] if you need to construct subjects or parse input files. +//! - Open [`simulator`] if you need predictions from analytical, ODE, or SDE models. +//! - Open [`nca`] if you need exposure and terminal metrics. +//! - Use `pharmsol::dsl` if the model comes from source text instead of Rust code. + #[cfg(feature = "dsl-aot")] mod build_support; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] @@ -28,37 +130,52 @@ pub use crate::data::Interpolation::*; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::data::*; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::equation::*; -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::optimize::effect::get_e2; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::optimize::spp::SppOptimizer; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::simulator::equation::analytical::*; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::simulator::equation::metadata; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::simulator::equation::{ self, ode::{ExplicitRkTableau, OdeSolver, SdirkTableau}, - ODE, + Analytical, AnalyticalKernel, Cache, Equation, ModelKind, ModelMetadata, ModelMetadataError, + NameDomain, Predictions, RouteInputPolicy, RouteKind, State, ValidatedModelMetadata, ODE, SDE, }; pub use error::PharmsolError; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use nalgebra::dmatrix; -pub use pharmsol_macros::ode; +pub use pharmsol_macros::{analytical, ode, sde}; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use std::collections::HashMap; -/// Prelude module that re-exports all commonly used types and traits. +/// Common imports for the main pharmsol workflow. +/// +/// Use the prelude when you want one import that covers the common public API: /// -/// Use `use pharmsol::prelude::*;` to import everything needed for basic -/// pharmacometric modeling. +/// - subject and dataset types +/// - subject builders and events +/// - simulation types and prediction results +/// - NCA traits and option types +/// - declaration-first macros such as [`crate::ode`] and [`crate::analytical`] +/// +/// This is the fastest way to get started with examples, scripts, and small +/// applications. +/// +/// If you need a narrower import surface, use the modules directly instead. /// /// # Example /// ```rust /// use pharmsol::prelude::*; /// /// let subject = Subject::builder("patient_001") -/// .bolus(0.0, 100.0, 0) -/// .observation(1.0, 10.5, 0) +/// .infusion(0.0, 100.0, "iv", 1.0) +/// .missing_observation(1.0, "cp") /// .build(); +/// +/// assert_eq!(subject.id(), "patient_001"); /// ``` #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod prelude { @@ -92,7 +209,7 @@ pub mod prelude { pub use crate::data::auc::{auc, auc_interval, aumc, interpolate_linear}; #[allow(deprecated)] - // Simulator submodule for internal use and advanced users + // Simulator submodule for organized access to simulation types. pub mod simulator { pub use crate::simulator::{ cache::{self, PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE}, @@ -136,6 +253,8 @@ pub mod prelude { // Re-export macros (they are exported at crate root via #[macro_export]) #[doc(inline)] + pub use crate::analytical; + #[doc(inline)] pub use crate::fa; #[doc(inline)] pub use crate::fetch_cov; @@ -145,6 +264,8 @@ pub mod prelude { pub use crate::lag; #[doc(inline)] pub use crate::ode; + #[doc(inline)] + pub use crate::sde; } #[macro_export] diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 0aff0936..1dcc3a8f 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -8,6 +8,8 @@ pub mod two_compartment_models; use diffsol::{NalgebraContext, Vector, VectorHost}; pub use one_compartment_cl_models::*; pub use one_compartment_models::*; +use pharmsol_dsl::ModelKind; +use thiserror::Error; pub use three_compartment_cl_models::*; pub use three_compartment_models::*; pub use two_compartment_cl_models::*; @@ -15,12 +17,26 @@ pub use two_compartment_models::*; use super::spphash; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, + ValidatedModelMetadata, +}; use crate::data::error_model::AssayErrorModels; use crate::simulator::cache::{PredictionCache, DEFAULT_CACHE_SIZE}; use crate::PharmsolError; -use crate::{ - data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject, -}; +use crate::{data::Covariates, simulator::*, Observation, Subject}; + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum AnalyticalMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("analytical model declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error("analytical model declares {declared} route metadata entries but model has {expected} inputs")] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("analytical model declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} /// Model equation using analytical solutions. /// @@ -35,6 +51,7 @@ pub struct Analytical { init: Init, out: Out, neqs: Neqs, + metadata: Option, cache: Option, } @@ -88,6 +105,7 @@ impl Analytical { init, out, neqs: Neqs::default(), + metadata: None, cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -95,20 +113,86 @@ impl Analytical { /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } + + /// Attach validated handwritten-model metadata to this analytical model. + pub fn with_metadata( + mut self, + metadata: ModelMetadata, + ) -> Result { + let metadata = metadata.validate_for(ModelKind::Analytical)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + /// Access the validated metadata attached to this analytical model, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), AnalyticalMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(AnalyticalMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_input_count(); + if declared_routes != neqs.ndrugs { + return Err(AnalyticalMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(AnalyticalMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for Analytical { @@ -184,6 +268,11 @@ impl EquationPriv for Analytical { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -227,13 +316,19 @@ impl EquationPriv for Analytical { let s = inf.time(); let e = s + inf.duration(); if current_t >= s && next_t <= e { - if inf.input() >= self.get_ndrugs() { + let input = + inf.input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: inf.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: inf.input(), + input, ndrugs: self.get_ndrugs(), }); } - rateiv[inf.input()] += inf.amount() / inf.duration(); + rateiv[input] += inf.amount() / inf.duration(); } } @@ -271,7 +366,12 @@ impl EquationPriv for Analytical { covariates, &mut y, ); - let pred = y[observation.outeq()]; + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + let pred = y[outeq]; let pred = observation.to_prediction(pred, x.as_slice().to_vec()); if let Some(error_models) = error_models { likelihood.push(pred.log_likelihood(error_models)?.exp()); @@ -302,6 +402,7 @@ pub(crate) mod tests { use crate::SubjectBuilderExt; use approx::assert_relative_eq; use diffsol::Vector; + use pharmsol_dsl::AnalyticalKernel; use std::collections::HashMap; pub(crate) enum SubjectInfo { @@ -423,6 +524,212 @@ pub(crate) mod tests { assert_eq!(predictions.predictions()[0].prediction(), 4.0); } + fn simple_analytical() -> Analytical { + let eq = |x: &V, _p: &V, _dt: f64, _rateiv: &V, _cov: &Covariates| x.clone(); + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[0]; + }; + + Analytical::new(eq, seq_eq, lag, fa, init, out) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + #[test] + fn handwritten_analytical_metadata_exposes_name_lookup() { + let analytical = simple_analytical() + .with_metadata( + super::super::metadata::new("one_cmt_analytical") + .parameters(["ke", "v"]) + .covariates([super::super::Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate"); + let metadata = analytical.metadata().expect("metadata exists"); + + assert_eq!(analytical.parameter_index("ke"), Some(0)); + assert_eq!(analytical.parameter_index("v"), Some(1)); + assert_eq!(analytical.covariate_index("wt"), Some(0)); + assert_eq!(analytical.state_index("central"), Some(0)); + assert!(metadata.route("iv").is_some()); + assert!(metadata.output("cp").is_some()); + assert_eq!(metadata.kind(), ModelKind::Analytical); + } + + #[test] + fn handwritten_analytical_metadata_resolves_raw_numeric_aliases_against_canonical_labels() { + let eq = |x: &V, _p: &V, dt: f64, rateiv: &V, _cov: &Covariates| { + let mut next = x.clone(); + next[0] += rateiv[0] * dt; + next + }; + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[0]; + }; + + let analytical = Analytical::new(eq, seq_eq, lag, fa, init, out) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("numeric_alias_analytical") + .states(["central"]) + .outputs(["outeq_1"]) + .route(super::super::Route::infusion("input_1").to_state("central")), + ) + .expect("metadata attachment should validate"); + + let canonical = Subject::builder("canonical") + .infusion(0.0, 100.0, "input_1", 1.0) + .observation(1.0, 0.0, "outeq_1") + .build(); + let aliased = Subject::builder("aliased") + .infusion(0.0, 100.0, "1", 1.0) + .observation(1.0, 0.0, "1") + .build(); + + let canonical_predictions = analytical + .estimate_predictions(&canonical, &[]) + .expect("canonical labels should simulate"); + let aliased_predictions = analytical + .estimate_predictions(&aliased, &[]) + .expect("raw numeric aliases should simulate"); + + assert_relative_eq!( + canonical_predictions.predictions()[0].prediction(), + aliased_predictions.predictions()[0].prediction(), + epsilon = 1e-10 + ); + } + + #[test] + fn handwritten_analytical_without_metadata_keeps_raw_path() { + let analytical = simple_analytical(); + + assert!(analytical.metadata().is_none()); + assert_eq!(analytical.state_index("central"), None); + } + + #[test] + fn handwritten_analytical_rejects_dimension_mismatches() { + let error = simple_analytical() + .with_metadata( + super::super::metadata::new("wrong_outputs") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp", "auc"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect_err("output-count mismatches must fail"); + + assert_eq!( + error, + AnalyticalMetadataError::OutputCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_analytical_rejects_particles_metadata() { + let error = simple_analytical() + .with_metadata( + super::super::metadata::new("invalid_particles") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")) + .particles(64), + ) + .expect_err("analytical metadata must reject particles"); + + assert_eq!( + error, + AnalyticalMetadataError::Validation(ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Analytical, + }) + ); + } + + #[test] + fn built_in_analytical_models_can_advertise_kernel_identity() { + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + let analytical = + Analytical::new(one_compartment_with_absorption, seq_eq, lag, fa, init, out) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("one_cmt_abs") + .parameters(["ka", "ke", "v"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral").to_state("gut"), + super::super::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("built-in analytical metadata should validate"); + + assert_eq!( + analytical + .metadata() + .expect("metadata exists") + .analytical_kernel(), + Some(AnalyticalKernel::OneCompartmentWithAbsorption) + ); + let metadata = analytical.metadata().expect("metadata exists"); + assert_eq!( + metadata.route("oral").map(|route| route.input_index()), + Some(0) + ); + assert_eq!( + metadata.route("iv").map(|route| route.input_index()), + Some(0) + ); + } + + #[test] + fn changing_dimensions_after_metadata_clears_analytical_metadata() { + let analytical = simple_analytical() + .with_metadata( + super::super::metadata::new("one_cmt_analytical") + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate") + .with_ndrugs(2); + + assert!(analytical.metadata().is_none()); + } + fn assert_pm_wrapper_matches_native( native: AnalyticalEq, wrapper: AnalyticalEq, @@ -563,12 +870,13 @@ impl Equation for Analytical { support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = self.bind_error_models(error_models)?; let ypred = _subject_predictions(self, subject, support_point)?; - ypred.log_likelihood(error_models) + ypred.log_likelihood(&bound_error_models) } - fn kind() -> crate::EqnKind { - crate::EqnKind::Analytical + fn kind() -> EqnKind { + EqnKind::Analytical } } @@ -598,6 +906,7 @@ fn _estimate_likelihood( support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = ode.bind_error_models(error_models)?; let ypred = _subject_predictions(ode, subject, support_point)?; - Ok(ypred.log_likelihood(error_models)?.exp()) + Ok(ypred.log_likelihood(&bound_error_models)?.exp()) } diff --git a/src/simulator/equation/meta.rs b/src/simulator/equation/meta.rs deleted file mode 100644 index 1b38ae35..00000000 --- a/src/simulator/equation/meta.rs +++ /dev/null @@ -1,64 +0,0 @@ -#[repr(C)] -#[derive(Debug, Clone)] -/// Model metadata container. -/// -/// This structure holds the metadata associated with a pharmacometric model, -/// including parameter names and other model-specific information that needs -/// to be preserved across simulation and estimation activities. -/// -/// # Examples -/// -/// ``` -/// use pharmsol::simulator::equation::Meta; -/// -/// let model_metadata = Meta::new(vec!["CL", "V", "KA"]); -/// assert_eq!(model_metadata.get_params().len(), 3); -/// ``` -pub struct Meta { - params: Vec, -} - -impl Meta { - /// Creates a new metadata container with the specified parameter names. - /// - /// # Arguments - /// - /// * `params` - A vector of string slices representing parameter names - /// - /// # Returns - /// - /// A new `Meta` instance containing the converted parameter names - /// - /// # Examples - /// - /// ``` - /// use pharmsol::simulator::equation::Meta; - /// - /// let metadata = Meta::new(vec!["CL", "V", "KA"]); - /// ``` - pub fn new(params: Vec<&str>) -> Self { - let params = params.iter().map(|x| x.to_string()).collect(); - Meta { params } - } - - /// Retrieves the parameter names stored in this metadata container. - /// - /// # Returns - /// - /// A reference to the vector of parameter names - /// - /// # Examples - /// - /// ``` - /// use pharmsol::simulator::equation::Meta; - /// - /// let metadata = Meta::new(vec!["CL", "V", "KA"]); - /// let params = metadata.get_params(); - /// assert_eq!(params[0], "CL"); - /// assert_eq!(params[1], "V"); - /// assert_eq!(params[2], "KA"); - /// ``` - pub fn get_params(&self) -> &Vec { - &self.params - } -} diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs new file mode 100644 index 00000000..a512381e --- /dev/null +++ b/src/simulator/equation/metadata.rs @@ -0,0 +1,1281 @@ +//! Metadata builders and validated metadata views for handwritten models. +//! +//! Use this module when a handwritten [`crate::ODE`], [`crate::Analytical`], or +//! [`crate::SDE`] model should expose the same public names that appear in data +//! rows, subject builders, or parsed files. +//! +//! Metadata gives names to parameters, covariates, states, routes, and outputs. +//! After validation, the execution layer can resolve public labels such as +//! `"iv"` and `"cp"` against those declarations before simulation. +//! +//! Without metadata, handwritten models fall back to numeric labels. With +//! metadata, labels are matched by name. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.name(), "one_cmt"); +//! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); +//! assert!(metadata.output("cp").is_some()); +//! ``` + +use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind}; +use std::fmt; +use thiserror::Error; + +/// Shorthand for [`ModelMetadata::new`]. +pub fn new(name: impl Into) -> ModelMetadata { + ModelMetadata::new(name) +} + +/// Validation errors for handwritten model metadata. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum ModelMetadataError { + #[error("model kind is required for metadata validation")] + MissingModelKind, + #[error("metadata declares kind `{declared:?}` but validation requested `{requested:?}`")] + ModelKindConflict { + declared: ModelKind, + requested: ModelKind, + }, + #[error("duplicate {domain} name `{name}`")] + DuplicateName { domain: NameDomain, name: String }, + #[error("route `{route}` must declare a destination state")] + MissingRouteDestination { route: String }, + #[error("route `{route}` targets unknown state `{destination}`")] + UnknownRouteDestination { route: String, destination: String }, + #[error("infusion route `{route}` cannot declare lag")] + InfusionLagNotAllowed { route: String }, + #[error("infusion route `{route}` cannot declare bioavailability")] + InfusionBioavailabilityNotAllowed { route: String }, + #[error("{kind:?} metadata cannot declare particles")] + ParticlesNotAllowed { kind: ModelKind }, + #[error("Sde metadata requires particles")] + MissingParticles, + #[error( + "metadata declares {declared} particle(s) but validation provided {fallback} fallback particle(s)" + )] + ParticleCountConflict { declared: usize, fallback: usize }, + #[error("{kind:?} metadata cannot declare an analytical kernel")] + AnalyticalKernelNotAllowed { kind: ModelKind }, +} + +/// Name domain used in duplicate-name validation messages. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NameDomain { + Parameter, + Covariate, + State, + Route, + Output, +} + +impl fmt::Display for NameDomain { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let domain = match self { + Self::Parameter => "parameter", + Self::Covariate => "covariate", + Self::State => "state", + Self::Route => "route", + Self::Output => "output", + }; + f.write_str(domain) + } +} + +/// Validated metadata view used by the execution layer. +/// +/// This type is what handwritten equation builders store after metadata has +/// passed validation. It provides stable lookup helpers from public names to the +/// dense indices used during execution. +/// +/// Route lookups expose two different indices: +/// - [`ValidatedModelMetadata::route_declaration_index`] is the route position in +/// declaration order. +/// - [`ValidatedModelMetadata::route_index`] is the dense execution input index +/// for that route kind. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedModelMetadata { + name: String, + kind: ModelKind, + parameters: Vec, + covariates: Vec, + states: Vec, + routes: Vec, + route_input_count: usize, + outputs: Vec, + particles: Option, + analytical: Option, +} + +impl ValidatedModelMetadata { + /// Get the public model name. + pub fn name(&self) -> &str { + &self.name + } + + /// Get the validated model family. + pub fn kind(&self) -> ModelKind { + self.kind + } + + pub fn parameters(&self) -> &[Parameter] { + &self.parameters + } + + pub fn covariates(&self) -> &[Covariate] { + &self.covariates + } + + pub fn states(&self) -> &[State] { + &self.states + } + + pub fn routes(&self) -> &[ValidatedRoute] { + &self.routes + } + + /// Get the number of dense execution input slots needed for routes. + /// + /// This is the maximum of the bolus-route count and infusion-route count. + pub fn route_input_count(&self) -> usize { + self.route_input_count + } + + pub fn outputs(&self) -> &[Output] { + &self.outputs + } + + pub fn particles(&self) -> Option { + self.particles + } + + pub fn analytical_kernel(&self) -> Option { + self.analytical + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.parameters + .iter() + .position(|parameter| parameter.name() == name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.covariates + .iter() + .position(|covariate| covariate.name() == name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.states.iter().position(|state| state.name() == name) + } + + /// Look up a route by public name and return its declaration-order index. + pub fn route_declaration_index(&self, name: &str) -> Option { + self.routes.iter().position(|route| route.name() == name) + } + + /// Look up an output by public name and return its dense output index. + pub(crate) fn output_index(&self, name: &str) -> Option { + self.outputs.iter().position(|output| output.name() == name) + } + + pub fn parameter(&self, name: &str) -> Option<&Parameter> { + self.parameter_index(name) + .map(|index| &self.parameters[index]) + } + + pub fn covariate(&self, name: &str) -> Option<&Covariate> { + self.covariate_index(name) + .map(|index| &self.covariates[index]) + } + + pub fn state(&self, name: &str) -> Option<&State> { + self.state_index(name).map(|index| &self.states[index]) + } + + pub fn route(&self, name: &str) -> Option<&ValidatedRoute> { + self.route_declaration_index(name) + .map(|index| &self.routes[index]) + } + + pub fn output(&self, name: &str) -> Option<&Output> { + self.output_index(name).map(|index| &self.outputs[index]) + } +} + +/// One validated route declaration with resolved execution details. +/// +/// A validated route keeps both the declaration-order index and the dense input +/// index used during execution. Those values can differ from each other when a +/// model mixes bolus and infusion routes. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedRoute { + name: String, + kind: RouteKind, + declaration_index: usize, + input_index: usize, + destination: String, + destination_index: usize, + has_lag: bool, + has_bioavailability: bool, + input_policy: Option, +} + +impl ValidatedRoute { + /// Get the public route name used for label matching. + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> RouteKind { + self.kind + } + + /// Get the declaration-order index for this route. + pub fn declaration_index(&self) -> usize { + self.declaration_index + } + + /// Get the dense execution input index for this route kind. + pub fn input_index(&self) -> usize { + self.input_index + } + + /// Get the destination state name. + pub fn destination(&self) -> &str { + &self.destination + } + + /// Get the destination state index in model order. + pub fn destination_index(&self) -> usize { + self.destination_index + } + + pub fn has_lag(&self) -> bool { + self.has_lag + } + + pub fn has_bioavailability(&self) -> bool { + self.has_bioavailability + } + + pub fn input_policy(&self) -> Option { + self.input_policy + } +} + +/// Builder for handwritten model metadata. +/// +/// Use [`ModelMetadata`] to declare the public names that should be attached to +/// a handwritten equation. After validation, the resulting metadata can be +/// attached to handwritten [`crate::ODE`], [`crate::Analytical`], and +/// [`crate::SDE`] models through their `with_metadata(...)` methods. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelMetadata { + name: String, + kind: Option, + parameters: Vec, + covariates: Vec, + states: Vec, + routes: Vec, + outputs: Vec, + particles: Option, + analytical: Option, +} + +impl ModelMetadata { + /// Create a new metadata builder with a model name. + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + kind: None, + parameters: Vec::new(), + covariates: Vec::new(), + states: Vec::new(), + routes: Vec::new(), + outputs: Vec::new(), + particles: None, + analytical: None, + } + } + + /// Set the model kind explicitly. + pub fn kind(mut self, kind: ModelKind) -> Self { + self.kind = Some(kind); + self + } + + /// Replace the ordered parameter list. + pub fn parameters(mut self, parameters: I) -> Self + where + I: IntoIterator, + Parameter: From, + { + self.parameters = parameters.into_iter().map(Parameter::from).collect(); + self + } + + /// Replace the ordered covariate list. + pub fn covariates(mut self, covariates: I) -> Self + where + I: IntoIterator, + { + self.covariates = covariates.into_iter().collect(); + self + } + + /// Replace the ordered state list. + pub fn states(mut self, states: I) -> Self + where + I: IntoIterator, + State: From, + { + self.states = states.into_iter().map(State::from).collect(); + self + } + + /// Add one route declaration. + pub fn route(mut self, route: Route) -> Self { + self.routes.push(route); + self + } + + /// Extend with multiple route declarations. + pub fn routes(mut self, routes: I) -> Self + where + I: IntoIterator, + { + self.routes.extend(routes); + self + } + + /// Replace the ordered output list. + pub fn outputs(mut self, outputs: I) -> Self + where + I: IntoIterator, + Output: From, + { + self.outputs = outputs.into_iter().map(Output::from).collect(); + self + } + + /// Set the particle count for stochastic models. + pub fn particles(mut self, particles: usize) -> Self { + self.particles = Some(particles); + self + } + + /// Set the analytical kernel identity for built-in analytical models. + pub fn analytical_kernel(mut self, analytical: AnalyticalKernel) -> Self { + self.analytical = Some(analytical); + self + } + + /// Get the model name. + pub fn name(&self) -> &str { + &self.name + } + + /// Get the explicit model kind, if already declared. + pub fn kind_decl(&self) -> Option { + self.kind + } + + /// Get the ordered parameter metadata. + pub fn parameters_decl(&self) -> &[Parameter] { + &self.parameters + } + + /// Get the ordered covariate metadata. + pub fn covariates_decl(&self) -> &[Covariate] { + &self.covariates + } + + /// Get the ordered state metadata. + pub fn states_decl(&self) -> &[State] { + &self.states + } + + /// Get the ordered route metadata. + pub fn routes_decl(&self) -> &[Route] { + &self.routes + } + + /// Get the ordered output metadata. + pub fn outputs_decl(&self) -> &[Output] { + &self.outputs + } + + /// Get the declared particle count. + pub fn particles_decl(&self) -> Option { + self.particles + } + + /// Get the declared analytical kernel identity. + pub fn analytical_kernel_decl(&self) -> Option { + self.analytical + } + + /// Validate this metadata using its declared kind. + /// + /// Use this when the metadata itself already declares whether the model is + /// ODE, analytical, or SDE. + pub fn validate(self) -> Result { + self.validate_internal(None, None) + } + + /// Validate this metadata for a specific model kind. + /// + /// Use this when the equation type determines the model family and you want + /// validation to enforce that family explicitly. + pub fn validate_for( + self, + kind: ModelKind, + ) -> Result { + self.validate_internal(Some(kind), None) + } + + /// Validate this metadata for a specific model kind, using a fallback + /// particle count when the metadata itself does not declare one. + pub fn validate_for_with_particles( + self, + kind: ModelKind, + fallback_particles: usize, + ) -> Result { + self.validate_internal(Some(kind), Some(fallback_particles)) + } + + fn validate_internal( + self, + requested_kind: Option, + fallback_particles: Option, + ) -> Result { + let kind = resolve_kind(self.kind, requested_kind)?; + validate_unique_names(&self.parameters, NameDomain::Parameter, Parameter::name)?; + validate_unique_names(&self.covariates, NameDomain::Covariate, Covariate::name)?; + validate_unique_names(&self.states, NameDomain::State, State::name)?; + validate_unique_names(&self.routes, NameDomain::Route, Route::name)?; + validate_unique_names(&self.outputs, NameDomain::Output, Output::name)?; + + let particles = resolve_particles(kind, self.particles, fallback_particles)?; + validate_kind_specific_fields(kind, self.analytical, particles)?; + + let (routes, route_input_count) = validate_routes(self.routes, &self.states)?; + + Ok(ValidatedModelMetadata { + name: self.name, + kind, + parameters: self.parameters, + covariates: self.covariates, + states: self.states, + routes, + route_input_count, + outputs: self.outputs, + particles, + analytical: self.analytical, + }) + } +} + +/// One named parameter in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Parameter { + name: String, +} + +impl Parameter { + /// Create a named parameter declaration. + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for Parameter +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// One named covariate plus interpolation semantics. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Covariate { + name: String, + interpolation: Option, +} + +impl Covariate { + /// Create a named covariate without an explicit interpolation policy. + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + interpolation: None, + } + } + + /// Create a continuous covariate that uses linear interpolation. + pub fn continuous(name: impl Into) -> Self { + Self::new(name).with_interpolation(CovariateInterpolation::Linear) + } + + /// Create a covariate that uses last-observation-carried-forward semantics. + pub fn locf(name: impl Into) -> Self { + Self::new(name).with_interpolation(CovariateInterpolation::Locf) + } + + /// Set the interpolation policy explicitly. + pub fn with_interpolation(mut self, interpolation: CovariateInterpolation) -> Self { + self.interpolation = Some(interpolation); + self + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn interpolation(&self) -> Option { + self.interpolation + } +} + +/// One named state in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct State { + name: String, +} + +impl State { + /// Create a named state declaration. + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for State +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// One named output in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Output { + name: String, +} + +impl Output { + /// Create a named output declaration. + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for Output +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// Route declaration kind. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteKind { + /// Instantaneous dose input. + Bolus, + /// Dose input over a duration. + Infusion, +} + +/// How route inputs should be interpreted by the execution layer. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteInputPolicy { + /// Inject the resolved input directly into the declared destination state. + InjectToDestination, + /// Expect the low-level execution path to provide an explicit input vector. + ExplicitInputVector, +} + +/// One named route declaration. +/// +/// Route names are the public labels matched against dose events such as `iv` +/// or `oral`. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Route { + name: String, + kind: RouteKind, + destination: Option, + has_lag: bool, + has_bioavailability: bool, + input_policy: Option, +} + +impl Route { + /// Create a named bolus route declaration. + pub fn bolus(name: impl Into) -> Self { + Self::new(name, RouteKind::Bolus) + } + + /// Create a named infusion route declaration. + pub fn infusion(name: impl Into) -> Self { + Self::new(name, RouteKind::Infusion) + } + + /// Create a route declaration with an explicit kind. + pub fn new(name: impl Into, kind: RouteKind) -> Self { + Self { + name: name.into(), + kind, + destination: None, + has_lag: false, + has_bioavailability: false, + input_policy: None, + } + } + + /// Declare which state this route targets. + pub fn to_state(mut self, destination: impl Into) -> Self { + self.destination = Some(destination.into()); + self + } + + /// Mark this route as supporting lag handling. + pub fn with_lag(mut self) -> Self { + self.has_lag = true; + self + } + + /// Mark this route as supporting bioavailability handling. + pub fn with_bioavailability(mut self) -> Self { + self.has_bioavailability = true; + self + } + + /// Request direct injection into the destination state at execution time. + pub fn inject_input_to_destination(mut self) -> Self { + self.input_policy = Some(RouteInputPolicy::InjectToDestination); + self + } + + /// Request an explicit low-level input vector at execution time. + pub fn expect_explicit_input(mut self) -> Self { + self.input_policy = Some(RouteInputPolicy::ExplicitInputVector); + self + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> RouteKind { + self.kind + } + + pub fn destination(&self) -> Option<&str> { + self.destination.as_deref() + } + + pub fn has_lag(&self) -> bool { + self.has_lag + } + + pub fn has_bioavailability(&self) -> bool { + self.has_bioavailability + } + + pub fn input_policy(&self) -> Option { + self.input_policy + } +} + +fn resolve_kind( + declared_kind: Option, + requested_kind: Option, +) -> Result { + match (declared_kind, requested_kind) { + (Some(declared), Some(requested)) if declared != requested => { + Err(ModelMetadataError::ModelKindConflict { + declared, + requested, + }) + } + (Some(declared), _) => Ok(declared), + (None, Some(requested)) => Ok(requested), + (None, None) => Err(ModelMetadataError::MissingModelKind), + } +} + +fn resolve_particles( + kind: ModelKind, + declared_particles: Option, + fallback_particles: Option, +) -> Result, ModelMetadataError> { + let particles = match (declared_particles, fallback_particles) { + (Some(declared), Some(fallback)) if declared != fallback => { + return Err(ModelMetadataError::ParticleCountConflict { declared, fallback }); + } + (Some(declared), _) => Some(declared), + (None, Some(fallback)) => Some(fallback), + (None, None) => None, + }; + + match kind { + ModelKind::Ode | ModelKind::Analytical if particles.is_some() => { + Err(ModelMetadataError::ParticlesNotAllowed { kind }) + } + ModelKind::Sde if particles.is_none() => Err(ModelMetadataError::MissingParticles), + _ => Ok(particles), + } +} + +fn validate_kind_specific_fields( + kind: ModelKind, + analytical: Option, + particles: Option, +) -> Result<(), ModelMetadataError> { + match kind { + ModelKind::Ode => { + if analytical.is_some() { + return Err(ModelMetadataError::AnalyticalKernelNotAllowed { kind }); + } + if particles.is_some() { + return Err(ModelMetadataError::ParticlesNotAllowed { kind }); + } + } + ModelKind::Analytical => { + if particles.is_some() { + return Err(ModelMetadataError::ParticlesNotAllowed { kind }); + } + } + ModelKind::Sde => { + if analytical.is_some() { + return Err(ModelMetadataError::AnalyticalKernelNotAllowed { kind }); + } + } + } + Ok(()) +} + +fn validate_unique_names( + values: &[T], + domain: NameDomain, + name_of: impl Fn(&T) -> &str, +) -> Result<(), ModelMetadataError> { + let mut names = std::collections::HashSet::with_capacity(values.len()); + for value in values { + let name = name_of(value); + if !names.insert(name) { + return Err(ModelMetadataError::DuplicateName { + domain, + name: name.to_string(), + }); + } + } + Ok(()) +} + +fn validate_routes( + routes: Vec, + states: &[State], +) -> Result<(Vec, usize), ModelMetadataError> { + let mut bolus_inputs = 0; + let mut infusion_inputs = 0; + let mut validated_routes = Vec::with_capacity(routes.len()); + + for (declaration_index, route) in routes.into_iter().enumerate() { + let input_index = match route.kind { + RouteKind::Bolus => { + let index = bolus_inputs; + bolus_inputs += 1; + index + } + RouteKind::Infusion => { + let index = infusion_inputs; + infusion_inputs += 1; + index + } + }; + + validated_routes.push(validate_route( + route, + declaration_index, + input_index, + states, + )?); + } + + Ok((validated_routes, bolus_inputs.max(infusion_inputs))) +} + +fn validate_route( + route: Route, + declaration_index: usize, + input_index: usize, + states: &[State], +) -> Result { + if route.kind == RouteKind::Infusion && route.has_lag { + return Err(ModelMetadataError::InfusionLagNotAllowed { + route: route.name.clone(), + }); + } + + if route.kind == RouteKind::Infusion && route.has_bioavailability { + return Err(ModelMetadataError::InfusionBioavailabilityNotAllowed { + route: route.name.clone(), + }); + } + + let destination = + route + .destination + .clone() + .ok_or_else(|| ModelMetadataError::MissingRouteDestination { + route: route.name.clone(), + })?; + let destination_index = states + .iter() + .position(|state| state.name() == destination) + .ok_or_else(|| ModelMetadataError::UnknownRouteDestination { + route: route.name.clone(), + destination: destination.clone(), + })?; + + Ok(ValidatedRoute { + name: route.name, + kind: route.kind, + declaration_index, + input_index, + destination, + destination_index, + has_lag: route.has_lag, + has_bioavailability: route.has_bioavailability, + input_policy: route.input_policy, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn builds_ode_metadata_shape() { + let metadata = new("bimodal_ke") + .kind(ModelKind::Ode) + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")); + + assert_eq!(metadata.name(), "bimodal_ke"); + assert_eq!(metadata.kind_decl(), Some(ModelKind::Ode)); + assert_eq!(metadata.parameters_decl()[0].name(), "ke"); + assert_eq!(metadata.parameters_decl()[1].name(), "v"); + assert_eq!(metadata.states_decl()[0].name(), "central"); + assert_eq!(metadata.outputs_decl()[0].name(), "cp"); + assert_eq!(metadata.routes_decl()[0].name(), "iv"); + assert_eq!(metadata.routes_decl()[0].kind(), RouteKind::Infusion); + assert_eq!(metadata.routes_decl()[0].destination(), Some("central")); + } + + #[test] + fn builds_analytical_metadata_shape() { + let metadata = new("one_cmt_abs") + .kind(ModelKind::Analytical) + .parameters(["ka", "ke", "v"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route(Route::bolus("oral").to_state("gut").with_bioavailability()) + .route(Route::infusion("iv").to_state("central")) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption); + + assert_eq!(metadata.kind_decl(), Some(ModelKind::Analytical)); + assert_eq!(metadata.states_decl()[0].name(), "gut"); + assert_eq!(metadata.states_decl()[1].name(), "central"); + assert_eq!(metadata.routes_decl()[0].kind(), RouteKind::Bolus); + assert!(metadata.routes_decl()[0].has_bioavailability()); + assert_eq!( + metadata.analytical_kernel_decl(), + Some(AnalyticalKernel::OneCompartmentWithAbsorption) + ); + } + + #[test] + fn builds_sde_metadata_shape() { + let metadata = new("one_cmt_sde") + .kind(ModelKind::Sde) + .parameters(["ke", "sigma", "v"]) + .covariates([Covariate::continuous("wt"), Covariate::locf("age")]) + .states(["central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(128); + + assert_eq!(metadata.kind_decl(), Some(ModelKind::Sde)); + assert_eq!(metadata.covariates_decl()[0].name(), "wt"); + assert_eq!( + metadata.covariates_decl()[0].interpolation(), + Some(CovariateInterpolation::Linear) + ); + assert_eq!(metadata.covariates_decl()[1].name(), "age"); + assert_eq!( + metadata.covariates_decl()[1].interpolation(), + Some(CovariateInterpolation::Locf) + ); + assert_eq!(metadata.particles_decl(), Some(128)); + assert_eq!( + metadata.routes_decl()[0].input_policy(), + Some(RouteInputPolicy::InjectToDestination) + ); + } + + #[test] + fn validates_metadata_and_exposes_lookup_helpers() { + let metadata = new("bimodal_ke") + .kind(ModelKind::Ode) + .parameters(["ke", "v"]) + .covariates([Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect("metadata should validate"); + + assert_eq!(metadata.parameter_index("ke"), Some(0)); + assert_eq!(metadata.parameter_index("v"), Some(1)); + assert_eq!(metadata.covariate_index("wt"), Some(0)); + assert_eq!(metadata.state_index("central"), Some(0)); + assert!(metadata.route("iv").is_some()); + assert_eq!(metadata.route_declaration_index("iv"), Some(0)); + assert_eq!(metadata.route_input_count(), 1); + assert_eq!(metadata.output_index("cp"), Some(0)); + assert_eq!( + metadata.route("iv").expect("route exists").destination(), + "central" + ); + assert_eq!( + metadata + .route("iv") + .expect("route exists") + .declaration_index(), + 0 + ); + assert_eq!(metadata.route("iv").expect("route exists").input_index(), 0); + assert_eq!( + metadata + .route("iv") + .expect("route exists") + .destination_index(), + 0 + ); + } + + #[test] + fn duplicate_names_fail_validation() { + let error = new("dup_params") + .kind(ModelKind::Ode) + .parameters(["ke", "ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("duplicate parameters must fail"); + + assert_eq!( + error, + ModelMetadataError::DuplicateName { + domain: NameDomain::Parameter, + name: "ke".to_string(), + } + ); + } + + #[test] + fn missing_route_destination_fails_validation() { + let error = new("missing_route_destination") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv")) + .validate() + .expect_err("route destination is required"); + + assert_eq!( + error, + ModelMetadataError::MissingRouteDestination { + route: "iv".to_string(), + } + ); + } + + #[test] + fn unknown_route_destination_fails_validation() { + let error = new("unknown_route_destination") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("peripheral")) + .validate() + .expect_err("unknown destinations must fail"); + + assert_eq!( + error, + ModelMetadataError::UnknownRouteDestination { + route: "iv".to_string(), + destination: "peripheral".to_string(), + } + ); + } + + #[test] + fn shared_input_routes_preserve_declaration_and_input_identity() { + let metadata = new("shared_input") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + Route::bolus("oral").to_state("gut"), + Route::infusion("iv").to_state("central"), + ]) + .validate() + .expect("shared-input metadata should validate"); + + assert_eq!(metadata.routes().len(), 2); + assert_eq!(metadata.route_input_count(), 1); + assert_eq!(metadata.route_declaration_index("oral"), Some(0)); + assert_eq!(metadata.route_declaration_index("iv"), Some(1)); + assert_eq!(metadata.route("oral").expect("oral route").input_index(), 0); + assert_eq!(metadata.route("iv").expect("iv route").input_index(), 0); + assert_eq!( + metadata + .route("oral") + .expect("oral route") + .declaration_index(), + 0 + ); + assert_eq!( + metadata.route("iv").expect("iv route").declaration_index(), + 1 + ); + } + + #[test] + fn infusion_routes_reject_lag_and_bioavailability() { + let lag_error = new("infusion_lag") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central").with_lag()) + .validate() + .expect_err("infusion lag must fail"); + + assert_eq!( + lag_error, + ModelMetadataError::InfusionLagNotAllowed { + route: "iv".to_string(), + } + ); + + let fa_error = new("infusion_fa") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .with_bioavailability(), + ) + .validate() + .expect_err("infusion bioavailability must fail"); + + assert_eq!( + fa_error, + ModelMetadataError::InfusionBioavailabilityNotAllowed { + route: "iv".to_string(), + } + ); + } + + #[test] + fn validate_requires_or_accepts_a_kind() { + let error = new("kind_required") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("kindless metadata needs explicit validation kind"); + + assert_eq!(error, ModelMetadataError::MissingModelKind); + + let validated = new("kind_override") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for(ModelKind::Ode) + .expect("caller-provided kind should validate"); + + assert_eq!(validated.kind(), ModelKind::Ode); + } + + #[test] + fn conflicting_kinds_fail_validation() { + let error = new("kind_conflict") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for(ModelKind::Sde) + .expect_err("conflicting kinds must fail"); + + assert_eq!( + error, + ModelMetadataError::ModelKindConflict { + declared: ModelKind::Ode, + requested: ModelKind::Sde, + } + ); + } + + #[test] + fn particles_are_rejected_for_ode_and_analytical() { + let ode_error = new("ode_particles") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate() + .expect_err("ODE metadata cannot declare particles"); + + assert_eq!( + ode_error, + ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Ode, + } + ); + + let analytical_error = new("analytical_particles") + .kind(ModelKind::Analytical) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate() + .expect_err("Analytical metadata cannot declare particles"); + + assert_eq!( + analytical_error, + ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Analytical, + } + ); + } + + #[test] + fn analytical_kernel_is_limited_to_analytical_models() { + let error = new("ode_kernel") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .analytical_kernel(AnalyticalKernel::OneCompartment) + .validate() + .expect_err("ODE metadata cannot declare an analytical kernel"); + + assert_eq!( + error, + ModelMetadataError::AnalyticalKernelNotAllowed { + kind: ModelKind::Ode, + } + ); + } + + #[test] + fn sde_requires_particles_or_a_fallback_count() { + let error = new("sde_missing_particles") + .kind(ModelKind::Sde) + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("SDE metadata requires particles"); + + assert_eq!(error, ModelMetadataError::MissingParticles); + + let validated = new("sde_fallback_particles") + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for_with_particles(ModelKind::Sde, 128) + .expect("fallback particle count should satisfy SDE validation"); + + assert_eq!(validated.kind(), ModelKind::Sde); + assert_eq!(validated.particles(), Some(128)); + } + + #[test] + fn conflicting_particle_counts_fail_validation() { + let error = new("sde_particle_conflict") + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate_for_with_particles(ModelKind::Sde, 128) + .expect_err("mismatched particle counts must fail"); + + assert_eq!( + error, + ModelMetadataError::ParticleCountConflict { + declared: 64, + fallback: 128, + } + ); + } +} diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 39cd741f..94ca5ccf 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -1,27 +1,78 @@ +//! Handwritten equation families and their shared simulation interfaces. +//! +//! This module is the public home for handwritten [`ODE`], [`Analytical`], and +//! [`SDE`] models, plus the shared [`Equation`] trait and the metadata types +//! that attach public names to parameters, states, routes, and outputs. +//! +//! Use this module when you want to: +//! - choose between deterministic ODE, analytical, and stochastic SDE models +//! - attach metadata so dataset labels such as `"iv"` and `"cp"` resolve by +//! name instead of by dense numeric index +//! - work with prediction or likelihood APIs across equation families +//! +//! # Equation Families +//! +//! - [`ODE`] for deterministic models that must be numerically integrated. +//! - [`Analytical`] for supported closed-form models. +//! - [`SDE`] for stochastic models that use particles. +//! +//! # Labels And Metadata +//! +//! Input and output labels arrive from public data APIs as strings. +//! +//! - Without metadata, handwritten models fall back to numeric labels such as +//! `0` or `1`. +//! - With [`metadata::ModelMetadata`] attached, route and output labels are +//! resolved by name against the declared routes and outputs before +//! simulation. +//! +//! That label-first path is the preferred public workflow for current authoring. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); +//! assert!(metadata.output("cp").is_some()); +//! ``` + use std::fmt::Debug; pub mod analytical; -pub mod meta; +pub mod metadata; pub mod ode; pub mod sde; pub use analytical::*; -pub use meta::*; +pub use metadata::*; pub use ode::*; +pub use pharmsol_dsl::{AnalyticalKernel, ModelKind}; +use pharmsol_dsl::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; pub use sde::*; use crate::{ error_model::AssayErrorModels, simulator::{Fa, Lag}, - Covariates, Event, Infusion, Observation, PharmsolError, Subject, + Covariates, Event, Infusion, InputLabel, Observation, Occasion, OutputLabel, PharmsolError, + Subject, }; use super::likelihood::Prediction; /// Trait for state vectors that can receive bolus doses. pub trait State { - /// Add a bolus dose to the state at the specified input compartment. + /// Add a bolus dose to the state at the specified resolved input index. /// /// # Parameters - /// - `input`: The compartment index + /// - `input`: The resolved dense input index used by the execution layer /// - `amount`: The bolus amount fn add_bolus(&mut self, input: usize, amount: f64); } @@ -112,7 +163,7 @@ pub trait Cache: Sized { fn disable_cache(self) -> Self; } -/// Trait defining the associated types for equations. +/// Associated state and prediction container types for an equation family. pub trait EquationTypes { /// The state vector type type S: State + Debug; @@ -128,6 +179,7 @@ pub(crate) trait EquationPriv: EquationTypes { fn get_nstates(&self) -> usize; fn get_ndrugs(&self) -> usize; fn get_nouteqs(&self) -> usize; + fn metadata(&self) -> Option<&ValidatedModelMetadata>; fn solve( &self, state: &mut Self::S, @@ -140,6 +192,93 @@ pub(crate) trait EquationPriv: EquationTypes { fn nparticles(&self) -> usize { 1 } + + fn resolve_input_label( + &self, + label: &InputLabel, + expected_kind: RouteKind, + ) -> Result { + if let Some(metadata) = self.metadata() { + let route = metadata + .route(label.as_str()) + .or_else(|| { + canonical_numeric_alias(label.as_str(), NUMERIC_ROUTE_PREFIX) + .and_then(|alias| metadata.route(alias.as_str())) + }) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + + if route.kind() != expected_kind { + return Err(PharmsolError::UnsupportedInputRouteKind { + input: route.input_index(), + kind: match expected_kind { + RouteKind::Bolus => pharmsol_dsl::RouteKind::Bolus, + RouteKind::Infusion => pharmsol_dsl::RouteKind::Infusion, + }, + }); + } + + return Ok(route.input_index()); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + }) + } + + fn resolve_output_label(&self, label: &OutputLabel) -> Result { + if let Some(metadata) = self.metadata() { + return metadata + .output_index(label.as_str()) + .or_else(|| { + canonical_numeric_alias(label.as_str(), NUMERIC_OUTPUT_PREFIX) + .and_then(|alias| metadata.output_index(alias.as_str())) + }) + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }) + } + + fn resolve_occasion_events( + &self, + occasion: &Occasion, + support_point: &[f64], + covariates: &Covariates, + ) -> Result, PharmsolError> { + let mut resolved = occasion.clone(); + + for event in resolved.events_iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(resolved.process_events( + Some((self.fa(), self.lag(), support_point, covariates)), + true, + )) + } #[allow(dead_code)] fn is_sde(&self) -> bool { false @@ -180,13 +319,20 @@ pub(crate) trait EquationPriv: EquationTypes { ) -> Result<(), PharmsolError> { match event { Event::Bolus(bolus) => { - if bolus.input() >= self.get_ndrugs() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: self.get_ndrugs(), }); } - x.add_bolus(bolus.input(), bolus.amount()); + x.add_bolus(input, bolus.amount()); } Event::Infusion(infusion) => { infusions.push(infusion.clone()); @@ -219,11 +365,22 @@ pub(crate) trait EquationPriv: EquationTypes { } } -/// Trait for model equations that can be simulated. +fn canonical_numeric_alias(label: &str, prefix: &str) -> Option { + if label.is_empty() || !label.chars().all(|ch| ch.is_ascii_digit()) { + return None; + } + Some(format!("{prefix}{label}")) +} + +/// Trait for handwritten model equations that can be simulated. +/// +/// [`Equation`] is the shared interface implemented by handwritten [`ODE`], +/// [`Analytical`], and [`SDE`] models. /// -/// This trait defines the interface for different types of model equations -/// (ODE, SDE, analytical) that can be simulated to generate predictions -/// and estimate parameters. +/// Subject data enters this layer through public labels on dose and observation +/// events. If metadata is attached to the equation, those labels are resolved by +/// name before simulation. Otherwise, the execution layer expects numeric labels +/// that can be interpreted as dense indices. /// /// # Likelihood Calculation /// @@ -232,6 +389,14 @@ pub(crate) trait EquationPriv: EquationTypes { /// is provided for backward compatibility. #[allow(private_bounds)] pub trait Equation: EquationPriv + 'static + Clone + Sync { + #[doc(hidden)] + fn bind_error_models( + &self, + error_models: &AssayErrorModels, + ) -> Result { + Ok(error_models.bind_to(self)?) + } + /// Estimate the likelihood of the subject given the support point and error model. /// /// **Deprecated**: Use [`estimate_log_likelihood`](Self::estimate_log_likelihood) instead @@ -309,6 +474,22 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { self.get_nstates() } + /// Build a label-aware [`AssayErrorModels`] set for this equation. + /// + /// Handwritten equations resolve output labels from attached metadata. + /// Equations without metadata fall back to an explicit unbound set so dense + /// output-slot workflows remain available without adding runtime lookup cost. + #[doc(hidden)] + fn assay_error_models(&self) -> AssayErrorModels { + self.metadata() + .map(|metadata| { + AssayErrorModels::with_output_names( + metadata.outputs().iter().map(|output| output.name()), + ) + }) + .unwrap_or_else(AssayErrorModels::empty) + } + /// Simulate a subject with given parameters and optionally calculate likelihood. /// /// # Parameters @@ -324,6 +505,11 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { support_point: &[f64], error_models: Option<&AssayErrorModels>, ) -> Result<(Self::P, Option), PharmsolError> { + let bound_error_models = match error_models { + Some(error_models) => Some(self.bind_error_models(error_models)?), + None => None, + }; + let mut output = Self::P::new(self.nparticles()); let mut likelihood = Vec::new(); for occasion in subject.occasions() { @@ -331,16 +517,13 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { let mut x = self.initial_state(support_point, covariates, occasion.index()); let mut infusions = Vec::new(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; for (index, event) in events.iter().enumerate() { self.simulate_event( support_point, event, events.get(index + 1), - error_models, + bound_error_models.as_ref(), covariates, &mut x, &mut infusions, @@ -349,11 +532,14 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { )?; } } - let ll = error_models.map(|_| likelihood.iter().product::()); + let ll = bound_error_models + .as_ref() + .map(|_| likelihood.iter().product::()); Ok((output, ll)) } } +/// Runtime family tag for handwritten equations. #[repr(C)] #[derive(Clone, Debug)] pub enum EqnKind { diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index eed65e7a..9f5ab3c1 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -1,9 +1,8 @@ use crate::{Covariates, Infusion, PharmsolError}; use diffsol::{ ConstantOp, LinearOp, MatrixCommon, NalgebraContext, NalgebraMat, NonLinearOp, - NonLinearOpJacobian, OdeEquations, OdeEquationsRef, Op, UnitCallable, Vector, VectorCommon, + NonLinearOpJacobian, OdeEquations, OdeEquationsRef, Op, UnitCallable, Vector, }; -use nalgebra::DVector; use std::{cell::RefCell, cmp::Ordering}; type M = NalgebraMat; type V = ::V; @@ -11,13 +10,13 @@ type C = ::C; type T = ::T; #[derive(Debug, Clone)] -struct InfusionChannel { +struct InfusionTrack { input: usize, event_times: Vec, cumulative_rates: Vec, } -impl InfusionChannel { +impl InfusionTrack { fn new(input: usize, mut events: Vec<(f64, f64)>) -> Self { events.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal)); @@ -63,24 +62,31 @@ impl InfusionChannel { #[derive(Debug, Clone, Default)] struct InfusionSchedule { - channels: Vec, + tracks: Vec, } impl InfusionSchedule { - fn new(ndrugs: usize, infusions: &[&Infusion]) -> Result { - if ndrugs == 0 || infusions.is_empty() { - return Ok(Self { - channels: Vec::new(), - }); + fn new<'a, I>(ndrugs: usize, infusions: I) -> Result + where + I: IntoIterator, + { + if ndrugs == 0 { + return Ok(Self { tracks: Vec::new() }); } let mut per_input: Vec> = vec![Vec::new(); ndrugs]; + let mut saw_infusion = false; for infusion in infusions { + saw_infusion = true; if infusion.duration() <= 0.0 { continue; } - let input = infusion.input(); + let input = infusion + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: infusion.input().to_string(), + })?; if input >= ndrugs { return Err(PharmsolError::InputOutOfRange { input, ndrugs }); } @@ -90,27 +96,31 @@ impl InfusionSchedule { per_input[input].push((infusion.time() + infusion.duration(), -rate)); } - let channels = per_input + if !saw_infusion { + return Ok(Self { tracks: Vec::new() }); + } + + let tracks = per_input .into_iter() .enumerate() .filter_map(|(input, events)| { if events.is_empty() { None } else { - Some(InfusionChannel::new(input, events)) + Some(InfusionTrack::new(input, events)) } }) .collect(); - Ok(Self { channels }) + Ok(Self { tracks }) } fn fill_rate_vector(&self, time: f64, rateiv: &mut V) { rateiv.fill(0.0); - for channel in &self.channels { - let rate = channel.rate_at(time); + for track in &self.tracks { + let rate = track.rate_at(time); if rate != 0.0 { - rateiv[channel.input] = rate; + rateiv[track.input] = rate; } } } @@ -317,7 +327,6 @@ where nstates: usize, nparams: usize, init: V, - p: Vec, p_as_v: V, zero_bolus: V, covariates: &'a Covariates, @@ -332,17 +341,19 @@ where /// Creates a new PMProblem with a pre-converted parameter vector. /// This avoids an allocation when the caller already has a V representation. #[allow(clippy::too_many_arguments)] - pub fn with_params_v( + pub fn with_params_v<'b, I>( func: F, nstates: usize, ndrugs: usize, - p: Vec, p_as_v: V, covariates: &'a Covariates, - infusions: &[&'a Infusion], + infusions: I, init: V, - ) -> Result { - let nparams = p.len(); + ) -> Result + where + I: IntoIterator, + { + let nparams = p_as_v.len(); let rateiv_buffer = RefCell::new(V::zeros(ndrugs, NalgebraContext)); let infusion_schedule = InfusionSchedule::new(ndrugs, infusions)?; // Pre-allocate zero bolus vector @@ -353,7 +364,6 @@ where nstates, nparams, init, - p, p_as_v, zero_bolus, covariates, @@ -430,13 +440,10 @@ where } fn get_params(&self, p: &mut V) { - // Avoid unnecessary cloning by directly copying values from self.p - if p.len() == self.p.len() { - for i in 0..self.p.len() { - p[i] = self.p[i]; - } + if p.len() == self.p_as_v.len() { + p.copy_from(&self.p_as_v); } else { - p.copy_from(&DVector::from_vec(self.p.clone()).into()); + *p = self.p_as_v.clone(); } } @@ -453,13 +460,9 @@ where } fn set_params(&mut self, p: &V) { - if self.p.len() == p.len() { - for i in 0..p.len() { - self.p[i] = p[i]; - self.p_as_v[i] = p[i]; - } + if self.p_as_v.len() == p.len() { + self.p_as_v.copy_from(p); } else { - self.p = p.inner().iter().cloned().collect(); self.p_as_v = p.clone(); } } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 17b04235..13f0c2f3 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -27,8 +27,13 @@ use diffsol::{ OdeSolverStopReason, Vector, VectorHost, }; use nalgebra::DVector; +use pharmsol_dsl::ModelKind; +use thiserror::Error; -use super::{Equation, EquationPriv, EquationTypes, State}; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, State, + ValidatedModelMetadata, +}; const RTOL: f64 = 1e-4; const ATOL: f64 = 1e-4; @@ -76,6 +81,18 @@ pub enum ExplicitRkTableau { Tsit45, } +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum OdeMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("ODE declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error("ODE declares {declared} route metadata entries but model has {expected} inputs")] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("ODE declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} + #[derive(Clone, Debug)] pub struct ODE { diffeq: DiffEq, @@ -87,6 +104,7 @@ pub struct ODE { solver: OdeSolver, rtol: f64, atol: f64, + metadata: Option, cache: Option, } @@ -102,6 +120,7 @@ impl ODE { solver: OdeSolver::default(), rtol: RTOL, atol: ATOL, + metadata: None, cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -109,18 +128,21 @@ impl ODE { /// Set the number of state variables (ODE compartments). pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } @@ -136,6 +158,66 @@ impl ODE { self.atol = atol; self } + + /// Attach validated handwritten-model metadata to this ODE. + pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { + let metadata = metadata.validate_for(ModelKind::Ode)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + /// Access the validated metadata attached to this ODE, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), OdeMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(OdeMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_input_count(); + if declared_routes != neqs.ndrugs { + return Err(OdeMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(OdeMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for ODE { @@ -174,8 +256,9 @@ fn _estimate_likelihood( support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = ode.bind_error_models(error_models)?; let ypred = _subject_predictions(ode, subject, support_point)?; - Ok(ypred.log_likelihood(error_models)?.exp()) + Ok(ypred.log_likelihood(&bound_error_models)?.exp()) } #[inline(always)] @@ -238,6 +321,11 @@ impl EquationPriv for ODE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -280,7 +368,7 @@ impl EquationPriv for ODE { impl ODE { /// Generic event-loop runner, parameterized over the concrete solver type. #[allow(clippy::too_many_arguments)] - fn run_events<'a, S: OdeSolverMethod<'a, PMProblem<'a, DiffEq>>>( + fn run_events<'a, F, S>( &self, solver: &mut S, events: &[Event], @@ -295,20 +383,31 @@ impl ODE { y_out: &mut V, likelihood: &mut Vec, output: &mut SubjectPredictions, - ) -> Result<(), PharmsolError> { + ) -> Result<(), PharmsolError> + where + F: Fn(&V, &V, f64, &mut V, &V, &V, &Covariates) + 'a, + S: OdeSolverMethod<'a, PMProblem<'a, F>>, + { for (index, event) in events.iter().enumerate() { let next_event = events.get(index + 1); match event { Event::Bolus(bolus) => { - if bolus.input() >= bolus_v.len() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= bolus_v.len() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: bolus_v.len(), }); } bolus_v.fill(0.0); - bolus_v[bolus.input()] = bolus.amount(); + bolus_v[input] = bolus.amount(); state_with_bolus.fill(0.0); state_without_bolus.fill(0.0); @@ -348,7 +447,12 @@ impl ODE { covariates, y_out, ); - let pred = y_out[observation.outeq()]; + let outeq = observation.outeq_index().ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + } + })?; + let pred = y_out[outeq]; let pred = observation.to_prediction(pred, solver.state().y.as_slice().to_vec()); if let Some(error_models) = error_models { @@ -410,18 +514,27 @@ impl Equation for ODE { _estimate_likelihood(self, subject, support_point, error_models) } + fn estimate_predictions( + &self, + subject: &Subject, + support_point: &[f64], + ) -> Result { + _subject_predictions(self, subject, support_point) + } + fn estimate_log_likelihood( &self, subject: &Subject, support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = self.bind_error_models(error_models)?; let ypred = _subject_predictions(self, subject, support_point)?; - ypred.log_likelihood(error_models) + ypred.log_likelihood(&bound_error_models) } - fn kind() -> crate::EqnKind { - crate::EqnKind::ODE + fn kind() -> EqnKind { + EqnKind::ODE } fn simulate_subject( @@ -430,6 +543,11 @@ impl Equation for ODE { support_point: &[f64], error_models: Option<&AssayErrorModels>, ) -> Result<(Self::P, Option), PharmsolError> { + let bound_error_models = match error_models { + Some(error_models) => Some(self.bind_error_models(error_models)?), + None => None, + }; + let mut output = Self::P::new(self.nparticles()); // Preallocate likelihood vector @@ -446,7 +564,8 @@ impl Equation for ODE { let zero_bolus = V::zeros(ndrugs, NalgebraContext); let zero_rateiv = V::zeros(ndrugs, NalgebraContext); let mut bolus_v = V::zeros(ndrugs, NalgebraContext); - let spp_v: V = DVector::from_vec(support_point.to_vec()).into(); + let support_point_vec = support_point.to_vec(); + let spp_v: V = DVector::from_vec(support_point_vec.clone()).into(); // Pre-allocate output vector for observations let mut y_out = V::zeros(self.get_nouteqs(), NalgebraContext); @@ -454,26 +573,26 @@ impl Equation for ODE { // Iterate over occasions for occasion in subject.occasions() { let covariates = occasion.covariates(); - let infusions = occasion.infusions_ref(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; let problem = OdeBuilder::::new() .atol(vec![self.atol]) .rtol(self.rtol) .t0(occasion.initial_time()) .h0(1e-3) - .p(support_point.to_vec()) + .p(support_point_vec.clone()) .build_from_eqn(PMProblem::with_params_v( - self.diffeq, + move |x, p, t, dx, bolus, rateiv, cov| { + (self.diffeq)(x, p, t, dx, bolus, rateiv, cov); + }, nstates, ndrugs, - support_point.to_vec(), spp_v.clone(), covariates, - infusions.as_slice(), + events.iter().filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion), + _ => None, + }), self.initial_state(support_point, covariates, occasion.index()), )?)?; @@ -486,7 +605,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -505,7 +624,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -524,7 +643,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -543,7 +662,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -556,7 +675,332 @@ impl Equation for ODE { } } } - let ll = error_models.map(|_| likelihood.iter().product::()); + let ll = bound_error_models + .as_ref() + .map(|_| likelihood.iter().product::()); Ok((output, ll)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{fa, lag, Subject, SubjectBuilderExt}; + use approx::assert_relative_eq; + use std::sync::atomic::{AtomicUsize, Ordering}; + + static PREDICTION_CACHE_DIFFEQ_CALLS: AtomicUsize = AtomicUsize::new(0); + + fn simple_ode() -> ODE { + ODE::new( + |_x, _p, _t, _dx, _b, _rateiv, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, _y| {}, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + fn route_policy_subject() -> Subject { + Subject::builder("route_policy") + .bolus(0.0, 100.0, "oral") + .infusion(0.0, 100.0, "iv", 1.0) + .observation(1.0, 0.0, "cp") + .build() + } + + fn explicit_route_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + b: &V, + rateiv: &V, + _cov: &Covariates, + ) { + dx[0] = b[0] + rateiv[0]; + } + + fn injected_route_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + _b: &V, + _rateiv: &V, + _cov: &Covariates, + ) { + dx[0] = 0.0; + } + + fn zero_lag(_p: &V, _t: f64, _cov: &Covariates) -> std::collections::HashMap { + std::collections::HashMap::new() + } + + fn unit_fa(_p: &V, _t: f64, _cov: &Covariates) -> std::collections::HashMap { + std::collections::HashMap::new() + } + + fn zero_init(_p: &V, _t: f64, _cov: &Covariates, _x: &mut V) {} + + fn state_output(x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V) { + y[0] = x[0]; + } + + fn counting_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + _b: &V, + _rateiv: &V, + _cov: &Covariates, + ) { + PREDICTION_CACHE_DIFFEQ_CALLS.fetch_add(1, Ordering::SeqCst); + dx[0] = 0.0; + } + + #[test] + fn handwritten_ode_metadata_exposes_name_lookup() { + let ode = simple_ode() + .with_metadata( + super::super::metadata::new("bimodal_ke") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate"); + let metadata = ode.metadata().expect("metadata exists"); + + assert_eq!(ode.parameter_index("ke"), Some(0)); + assert_eq!(ode.parameter_index("v"), Some(1)); + assert_eq!(ode.state_index("central"), Some(0)); + assert!(metadata.route("iv").is_some()); + assert!(metadata.output("cp").is_some()); + assert_eq!(metadata.kind(), ModelKind::Ode); + } + + #[test] + fn handwritten_ode_without_metadata_keeps_raw_path() { + let ode = simple_ode(); + + assert!(ode.metadata().is_none()); + assert_eq!(ode.state_index("central"), None); + } + + #[test] + fn handwritten_ode_rejects_dimension_mismatches() { + let error = simple_ode() + .with_metadata( + super::super::metadata::new("wrong_outputs") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp", "auc"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect_err("output-count mismatches must fail"); + + assert_eq!( + error, + OdeMetadataError::OutputCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_ode_rejects_invalid_metadata() { + let error = simple_ode() + .with_metadata( + super::super::metadata::new("missing_destination") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv")), + ) + .expect_err("invalid metadata must fail during attachment"); + + assert_eq!( + error, + OdeMetadataError::Validation(ModelMetadataError::MissingRouteDestination { + route: "iv".to_string(), + }) + ); + } + + #[test] + fn handwritten_ode_defaults_to_explicit_route_vectors() { + let ode = ODE::new( + explicit_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("explicit_routes") + .states(["central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral").to_state("central"), + super::super::Route::infusion("iv").to_state("central"), + ]), + ) + .expect("metadata attachment should validate"); + + let predictions = ode + .simulate_subject(&route_policy_subject(), &[], None) + .expect("simulation should succeed") + .0; + let metadata = ode.metadata().expect("metadata exists"); + + assert_eq!( + metadata.route("oral").map(|route| route.input_index()), + Some(0) + ); + assert_eq!( + metadata.route("iv").map(|route| route.input_index()), + Some(0) + ); + assert_relative_eq!( + predictions.predictions()[0].prediction(), + 200.0, + epsilon = 1e-6 + ); + } + + #[test] + fn handwritten_ode_metadata_input_policy_is_descriptive_only() { + let ode = ODE::new( + injected_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("injected_routes") + .states(["central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + super::super::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("metadata attachment should validate"); + + let predictions = ode + .simulate_subject(&route_policy_subject(), &[], None) + .expect("simulation should succeed") + .0; + + assert_relative_eq!( + predictions.predictions()[0].prediction(), + 0.0, + epsilon = 1e-6 + ); + } + + #[test] + fn handwritten_ode_metadata_resolves_raw_numeric_aliases_against_canonical_labels() { + let ode = ODE::new( + explicit_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("numeric_alias_ode") + .states(["central"]) + .outputs(["outeq_1"]) + .route(super::super::Route::infusion("input_1").to_state("central")), + ) + .expect("metadata attachment should validate"); + + let canonical = Subject::builder("canonical") + .infusion(0.0, 100.0, "input_1", 1.0) + .observation(1.0, 0.0, "outeq_1") + .build(); + let aliased = Subject::builder("aliased") + .infusion(0.0, 100.0, "1", 1.0) + .observation(1.0, 0.0, "1") + .build(); + + let canonical_predictions = ode + .simulate_subject(&canonical, &[], None) + .expect("canonical labels should simulate") + .0; + let aliased_predictions = ode + .simulate_subject(&aliased, &[], None) + .expect("raw numeric aliases should simulate") + .0; + + assert_relative_eq!( + canonical_predictions.predictions()[0].prediction(), + aliased_predictions.predictions()[0].prediction(), + epsilon = 1e-6 + ); + } + + #[test] + fn changing_dimensions_after_metadata_clears_route_metadata() { + let ode = simple_ode() + .with_metadata( + super::super::metadata::new("bimodal_ke") + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate") + .with_ndrugs(2); + + assert!(ode.metadata().is_none()); + } + + #[test] + fn handwritten_ode_estimate_predictions_uses_prediction_cache() { + PREDICTION_CACHE_DIFFEQ_CALLS.store(0, Ordering::SeqCst); + + let ode = ODE::new(counting_kernel, zero_lag, unit_fa, zero_init, state_output) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1); + let subject = Subject::builder("cached_predictions") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .build(); + + let first = ode + .estimate_predictions(&subject, &[]) + .expect("first prediction run should succeed"); + let first_calls = PREDICTION_CACHE_DIFFEQ_CALLS.load(Ordering::SeqCst); + assert!(first_calls > 0); + + let second = ode + .estimate_predictions(&subject, &[]) + .expect("second prediction run should succeed"); + let second_calls = PREDICTION_CACHE_DIFFEQ_CALLS.load(Ordering::SeqCst); + + assert_eq!(first.predictions().len(), second.predictions().len()); + assert_eq!(first_calls, second_calls); + } +} diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index af8ea246..c5b01435 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -3,8 +3,10 @@ mod em; use diffsol::{NalgebraContext, Vector}; use nalgebra::DVector; use ndarray::{concatenate, Array2, Axis}; +use pharmsol_dsl::ModelKind; use rand::{rng, RngExt}; use rayon::prelude::*; +use thiserror::Error; use crate::{ data::{Covariates, Infusion}, @@ -21,7 +23,57 @@ use diffsol::VectorCommon; use crate::PharmsolError; -use super::{Equation, EquationPriv, EquationTypes, Predictions, State}; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, Predictions, + State, ValidatedModelMetadata, +}; + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum SdeMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("SDE declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error("SDE declares {declared} route metadata entries but model has {expected} inputs")] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("SDE declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} + +#[derive(Clone, Debug, Default)] +struct InjectedBolusMappings { + destinations: Vec>, +} + +impl InjectedBolusMappings { + fn explicit(ndrugs: usize) -> Self { + Self { + destinations: vec![None; ndrugs], + } + } + + fn from_destinations(ndrugs: usize, destinations: &[Option]) -> Self { + let mut mappings = Self::explicit(ndrugs); + for (input, destination) in destinations.iter().copied().take(ndrugs).enumerate() { + mappings.destinations[input] = destination; + } + mappings + } + + fn invalidate_for_ndrugs(&mut self, ndrugs: usize) { + *self = Self::explicit(ndrugs); + } + + fn apply(&self, state: &mut [DVector], input: usize, amount: f64) -> bool { + let Some(destination) = self.destinations.get(input).copied().flatten() else { + return false; + }; + state.par_iter_mut().for_each(|particle| { + particle[destination] += amount; + }); + true + } +} /// Simulate a stochastic differential equation (SDE) event. /// @@ -44,7 +96,7 @@ use super::{Equation, EquationPriv, EquationTypes, Predictions, State}; /// The state vector at time `tf` after simulation. #[inline(always)] #[allow(clippy::too_many_arguments)] -pub(crate) fn simulate_sde_event( +fn simulate_sde_event( drift: &Drift, difussion: &Diffusion, x: V, @@ -70,7 +122,10 @@ pub(crate) fn simulate_sde_event( let mut rateiv = V::zeros(ndrugs, NalgebraContext); for infusion in &infusion_events { if time >= infusion.time() && time <= infusion.duration() + infusion.time() { - rateiv[infusion.input()] += infusion.amount() / infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + rateiv[input] += infusion.amount() / infusion.duration(); } } @@ -133,6 +188,8 @@ pub struct SDE { out: Out, neqs: Neqs, nparticles: usize, + metadata: Option, + injected_bolus_mappings: InjectedBolusMappings, cache: Option, } @@ -164,6 +221,8 @@ impl SDE { out, neqs: Neqs::default(), nparticles, + metadata: None, + injected_bolus_mappings: InjectedBolusMappings::default(), cache: Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)), } } @@ -171,20 +230,92 @@ impl SDE { /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } + + /// Attach validated handwritten-model metadata to this SDE model. + pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { + let metadata = metadata.validate_for_with_particles(ModelKind::Sde, self.nparticles)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + #[doc(hidden)] + pub fn with_injected_bolus_inputs(mut self, destinations: &[Option]) -> Self { + self.injected_bolus_mappings = + InjectedBolusMappings::from_destinations(self.neqs.ndrugs, destinations); + self + } + + /// Access the validated metadata attached to this SDE model, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + self.injected_bolus_mappings + .invalidate_for_ndrugs(self.neqs.ndrugs); + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), SdeMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(SdeMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_input_count(); + if declared_routes != neqs.ndrugs { + return Err(SdeMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(SdeMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for SDE { @@ -328,6 +459,11 @@ impl EquationPriv for SDE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -386,7 +522,10 @@ impl EquationPriv for SDE { covariates, &mut y, ); - *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec()); + let outeq = observation + .outeq_index() + .expect("resolved observations should use numeric output labels"); + *p = observation.to_prediction(y[outeq], x[i].as_slice().to_vec()); }); let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?; *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap(); @@ -435,6 +574,67 @@ impl EquationPriv for SDE { } x } + + fn simulate_event( + &self, + support_point: &[f64], + event: &crate::Event, + next_event: Option<&crate::Event>, + error_models: Option<&AssayErrorModels>, + covariates: &Covariates, + x: &mut Self::S, + infusions: &mut Vec, + likelihood: &mut Vec, + output: &mut Self::P, + ) -> Result<(), PharmsolError> { + match event { + crate::Event::Bolus(bolus) => { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { + return Err(PharmsolError::InputOutOfRange { + input, + ndrugs: self.get_ndrugs(), + }); + } + if !self.injected_bolus_mappings.apply(x, input, bolus.amount()) { + x.add_bolus(input, bolus.amount()); + } + } + crate::Event::Infusion(infusion) => { + infusions.push(infusion.clone()); + } + crate::Event::Observation(observation) => { + self.process_observation( + support_point, + observation, + error_models, + event.time(), + covariates, + x, + likelihood, + output, + )?; + } + } + + if let Some(next_event) = next_event { + self.solve( + x, + support_point, + covariates, + infusions, + event.time(), + next_event.time(), + )?; + } + Ok(()) + } } impl Equation for SDE { @@ -475,8 +675,8 @@ impl Equation for SDE { } } - fn kind() -> crate::EqnKind { - crate::EqnKind::SDE + fn kind() -> EqnKind { + EqnKind::SDE } } @@ -533,3 +733,326 @@ fn sysresample(q: &[f64]) -> Vec { } i } + +#[cfg(test)] +mod tests { + use super::*; + use crate::simulator::equation::{self, Covariate, Route}; + use crate::SubjectBuilderExt; + use crate::{fa, fetch_params, lag}; + + fn simple_sde() -> SDE { + let drift = |x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx[0] = rateiv[0] - x[0]; + }; + let diffusion = |_p: &V, g: &mut V| { + g[0] = 1.0; + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x[0] = 0.0; + }; + let out = |x: &V, p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }; + + SDE::new(drift, diffusion, lag, fa, init, out, 128) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + fn route_policy_sde(drift: Drift) -> SDE { + let diffusion = |_p: &V, sigma: &mut V| { + sigma.fill(0.0); + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + SDE::new(drift, diffusion, lag, fa, init, out, 16) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + } + + #[test] + fn handwritten_sde_metadata_exposes_name_lookup_and_particles() { + let sde = simple_sde() + .with_metadata( + equation::metadata::new("one_cmt_sde") + .parameters(["ke", "v"]) + .covariates([Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect("SDE metadata attachment should validate"); + + let metadata = sde.metadata().expect("metadata exists"); + assert_eq!(metadata.kind(), ModelKind::Sde); + assert_eq!(metadata.particles(), Some(128)); + assert_eq!(sde.parameter_index("ke"), Some(0)); + assert_eq!(sde.parameter_index("v"), Some(1)); + assert_eq!(sde.covariate_index("wt"), Some(0)); + assert_eq!(sde.state_index("central"), Some(0)); + assert!(metadata.route("iv").is_some()); + assert!(metadata.output("cp").is_some()); + } + + #[test] + fn handwritten_sde_metadata_resolves_raw_numeric_aliases_against_canonical_labels() { + let drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + dx[1] = rateiv[0]; + }; + let diffusion = |_p: &V, sigma: &mut V| { + sigma.fill(0.0); + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + let sde = SDE::new(drift, diffusion, lag, fa, init, out, 16) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_alias_sde") + .states(["depot", "central"]) + .outputs(["outeq_1"]) + .route(Route::infusion("input_1").to_state("central")) + .particles(16), + ) + .expect("SDE metadata attachment should validate"); + + let canonical = Subject::builder("canonical") + .infusion(0.0, 100.0, "input_1", 1.0) + .observation(1.0, 0.0, "outeq_1") + .build(); + let aliased = Subject::builder("aliased") + .infusion(0.0, 100.0, "1", 1.0) + .observation(1.0, 0.0, "1") + .build(); + + let canonical_predictions = sde + .estimate_predictions(&canonical, &[]) + .expect("canonical labels should simulate"); + let aliased_predictions = sde + .estimate_predictions(&aliased, &[]) + .expect("raw numeric aliases should simulate"); + + assert!( + (canonical_predictions[[0, 0]].prediction() - aliased_predictions[[0, 0]].prediction()) + .abs() + < 1e-10 + ); + } + + #[test] + fn handwritten_sde_without_metadata_keeps_raw_path() { + let sde = simple_sde(); + + assert!(sde.metadata().is_none()); + assert_eq!(sde.parameter_index("ke"), None); + } + + #[test] + fn handwritten_sde_rejects_dimension_mismatches() { + let error = simple_sde() + .with_metadata( + equation::metadata::new("bad_sde") + .parameters(["ke", "v"]) + .states(["central", "peripheral"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect_err("mismatched state metadata must fail"); + + assert_eq!( + error, + SdeMetadataError::StateCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_sde_rejects_particle_mismatch() { + let error = simple_sde() + .with_metadata( + equation::metadata::new("particle_conflict") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64), + ) + .expect_err("mismatched SDE particles must fail"); + + assert_eq!( + error, + SdeMetadataError::Validation(ModelMetadataError::ParticleCountConflict { + declared: 64, + fallback: 128, + }) + ); + } + + #[test] + fn changing_dimensions_after_metadata_clears_sde_metadata() { + let sde = simple_sde() + .with_metadata( + equation::metadata::new("one_cmt_sde") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect("metadata attachment should validate") + .with_nout(2); + + assert!(sde.metadata().is_none()); + } + + #[test] + fn sde_metadata_input_policy_is_descriptive_only_for_bolus_routes() { + let zero_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, _rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + }; + + let explicit = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("explicit_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(Route::bolus("oral").to_state("central")) + .particles(16), + ) + .expect("explicit metadata should validate"); + + let injected = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("injected_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate"); + + let subject = Subject::builder("bolus_route") + .bolus(0.0, 100.0, "oral") + .missing_observation(0.1, "cp") + .build(); + + let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); + let injected_predictions = injected.estimate_predictions(&subject, &[0.0]).unwrap(); + + assert_eq!(explicit_predictions[[0, 0]].prediction(), 0.0); + assert_eq!(injected_predictions[[0, 0]].prediction(), 0.0); + } + + #[test] + fn sde_metadata_input_policy_does_not_change_explicit_rateiv_behavior() { + let rateiv_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + dx[1] = rateiv[0]; + }; + + let explicit = route_policy_sde(rateiv_drift) + .with_metadata( + equation::metadata::new("explicit_infusion") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(16), + ) + .expect("explicit metadata should validate"); + + let injected = route_policy_sde(rateiv_drift) + .with_metadata( + equation::metadata::new("injected_infusion") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate"); + + let subject = Subject::builder("infusion_route") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(1.0, "cp") + .build(); + + let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); + let injected_predictions = injected.estimate_predictions(&subject, &[0.0]).unwrap(); + + let explicit_prediction = explicit_predictions[[0, 0]].prediction(); + let injected_prediction = injected_predictions[[0, 0]].prediction(); + + assert!(explicit_prediction > 0.0); + assert!((injected_prediction - explicit_prediction).abs() < 1e-8); + } + + #[test] + fn clearing_sde_metadata_preserves_raw_bolus_behavior() { + let zero_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, _rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + }; + + let sde = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("injected_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate") + .with_nout(1); + + let subject = Subject::builder("bolus_route") + .bolus(0.0, 100.0, 0) + .missing_observation(0.1, 0) + .build(); + + let predictions = sde.estimate_predictions(&subject, &[0.0]).unwrap(); + + assert!(sde.metadata().is_none()); + assert_eq!(predictions[[0, 0]].prediction(), 0.0); + } +} diff --git a/src/simulator/likelihood/matrix.rs b/src/simulator/likelihood/matrix.rs index dae2cac4..9512a191 100644 --- a/src/simulator/likelihood/matrix.rs +++ b/src/simulator/likelihood/matrix.rs @@ -47,16 +47,20 @@ pub fn log_likelihood_matrix( error_models: &AssayErrorModels, progress: bool, ) -> Result, PharmsolError> { - let mut log_psi: Array2 = Array2::default((subjects.len(), support_points.nrows()).f()); - - let subjects_vec = subjects.subjects(); + let n_support_points = support_points.nrows(); + let mut log_psi: Array2 = Array2::default((subjects.len(), n_support_points).f()); + let subject_slice = subjects.subjects_slice(); + let support_point_rows = support_points + .axis_iter(Axis(0)) + .map(|row| row.to_vec()) + .collect::>(); let progress_tracker = if progress { - let total = subjects_vec.len() * support_points.nrows(); + let total = subject_slice.len() * n_support_points; println!( "Computing log-likelihood matrix: {} subjects × {} support points...", - subjects_vec.len(), - support_points.nrows() + subject_slice.len(), + n_support_points ); Some(ProgressTracker::new(total)) } else { @@ -68,26 +72,20 @@ pub fn log_likelihood_matrix( .into_par_iter() .enumerate() .try_for_each(|(i, mut row)| { - row.axis_iter_mut(Axis(0)) - .into_par_iter() - .enumerate() - .try_for_each(|(j, mut element)| { - let subject = subjects_vec.get(i).unwrap(); - match equation.estimate_log_likelihood( - subject, - &support_points.row(j).to_vec(), - error_models, - ) { - Ok(log_likelihood) => { - element.fill(log_likelihood); - if let Some(ref tracker) = progress_tracker { - tracker.inc(); - } - } - Err(e) => return Err(e), - }; - Ok(()) - }) + let subject = &subject_slice[i]; + + for (element, support_point) in row.iter_mut().zip(support_point_rows.iter()) { + *element = equation.estimate_log_likelihood( + subject, + support_point.as_slice(), + error_models, + )?; + if let Some(ref tracker) = progress_tracker { + tracker.inc(); + } + } + + Ok(()) }); if let Some(tracker) = progress_tracker { diff --git a/src/simulator/likelihood/mod.rs b/src/simulator/likelihood/mod.rs index c703dee7..17d07a19 100644 --- a/src/simulator/likelihood/mod.rs +++ b/src/simulator/likelihood/mod.rs @@ -111,8 +111,8 @@ pub fn log_likelihood_batch( parameters: &Array2, residual_error_models: &crate::ResidualErrorModels, ) -> Result, PharmsolError> { - let subjects_vec = subjects.subjects(); - let n_subjects = subjects_vec.len(); + let subject_slice = subjects.subjects_slice(); + let n_subjects = subject_slice.len(); if parameters.nrows() != n_subjects { return Err(PharmsolError::OtherError(format!( @@ -123,10 +123,10 @@ pub fn log_likelihood_batch( } // Parallel computation across subjects - let results: Vec = (0..n_subjects) - .into_par_iter() - .map(|i| { - let subject = &subjects_vec[i]; + let results: Vec = subject_slice + .par_iter() + .enumerate() + .map(|(i, subject)| { let params = parameters.row(i).to_vec(); // Simulate to get predictions @@ -223,7 +223,7 @@ mod tests { }; // Create error model with additive error - let error_models = crate::AssayErrorModels::new() + let error_models = crate::AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 1.0, 0.0, 0.0), 0.0), @@ -270,7 +270,7 @@ mod tests { ]; let subject_predictions = SubjectPredictions::from(predictions); - let error_models = crate::AssayErrorModels::new() + let error_models = crate::AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 1.0, 0.0, 0.0), 0.0), @@ -294,7 +294,7 @@ mod tests { #[test] fn test_empty_predictions_have_neutral_log_likelihood() { let preds = SubjectPredictions::default(); - let errors = crate::AssayErrorModels::new(); + let errors = crate::AssayErrorModels::empty(); assert_eq!(preds.log_likelihood(&errors).unwrap(), 0.0); // log(1) = 0 } @@ -305,7 +305,9 @@ mod tests { preds.add_prediction(obs.to_prediction(1.0, vec![])); let error_model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 0.0); - let errors = crate::AssayErrorModels::new().add(0, error_model).unwrap(); + let errors = crate::AssayErrorModels::empty() + .add(0, error_model) + .unwrap(); let log_lik = preds.log_likelihood(&errors).unwrap(); assert!(log_lik.is_finite()); diff --git a/src/simulator/likelihood/prediction.rs b/src/simulator/likelihood/prediction.rs index fbd60230..b1610f56 100644 --- a/src/simulator/likelihood/prediction.rs +++ b/src/simulator/likelihood/prediction.rs @@ -233,7 +233,7 @@ mod tests { } fn create_error_models() -> AssayErrorModels { - AssayErrorModels::new() + AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 0.0), diff --git a/src/simulator/likelihood/subject.rs b/src/simulator/likelihood/subject.rs index 77d8963b..5c3346fe 100644 --- a/src/simulator/likelihood/subject.rs +++ b/src/simulator/likelihood/subject.rs @@ -169,7 +169,7 @@ mod tests { use crate::Censor; fn create_error_models() -> AssayErrorModels { - AssayErrorModels::new() + AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 1.0, 0.0, 0.0), 0.0), @@ -199,7 +199,7 @@ mod tests { preds.add_prediction(obs.to_prediction(1.0, vec![])); let error_model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 0.0); - let errors = AssayErrorModels::new().add(0, error_model).unwrap(); + let errors = AssayErrorModels::empty().add(0, error_model).unwrap(); let log_lik = preds.log_likelihood(&errors).unwrap(); assert!(log_lik.is_finite()); diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 5cea84fe..058ca125 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -200,7 +200,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; /// /// # Fields /// - `nstates`: Number of state variables (ODE compartments) -/// - `ndrugs`: Number of drug input channels (size of bolus[] and rateiv[]) +/// - `ndrugs`: Number of drug inputs (size of bolus[] and rateiv[]) /// - `nout`: Number of output equations /// /// # Defaults @@ -218,7 +218,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; pub struct Neqs { /// Number of state variables pub nstates: usize, - /// Number of drug input channels (bolus/rateiv size) + /// Number of drug inputs (bolus/rateiv size) pub ndrugs: usize, /// Number of output equations pub nout: usize, diff --git a/src/test_fixtures.rs b/src/test_fixtures.rs index 7fb1610e..91d21e90 100644 --- a/src/test_fixtures.rs +++ b/src/test_fixtures.rs @@ -83,7 +83,7 @@ model one_cmt_abs { oral -> depot } analytical { - kernel = one_compartment_with_absorption + structure = one_compartment_with_absorption } outputs { cp = central / v diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs new file mode 100644 index 00000000..44075cff --- /dev/null +++ b/tests/analytical_macro_lowering.rs @@ -0,0 +1,492 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; + +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("analytical-macro-iv") + .infusion(0.0, 120.0, input, 1.0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("analytical-macro-oral") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + +fn shared_input_subject() -> Subject { + Subject::builder("analytical-macro-shared") + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .build() +} + +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { + Subject::builder("analytical-macro-covariates") + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) + .missing_observation(8.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +fn macro_one_compartment() -> equation::Analytical { + analytical! { + name: "one_cpt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + structure: one_compartment, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_one_compartment() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_iv") + .kind(equation::ModelKind::Analytical) + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(equation::Route::infusion("iv").to_state("central")) + .analytical_kernel(equation::AnalyticalKernel::OneCompartment), + ) + .expect("handwritten analytical metadata should validate") +} + +fn macro_one_compartment_with_absorption() -> equation::Analytical { + analytical! { + name: "one_cmt_abs", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + ], + structure: one_compartment_with_absorption, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + init: |_t, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_one_compartment_with_absorption() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + ) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten absorption metadata should validate") +} + +fn macro_shared_input_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_shared", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + infusion(iv) -> central, + ], + structure: one_compartment_with_absorption, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_shared_input_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_shared") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten shared-input analytical metadata should validate") +} + +fn macro_covariate_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_covariates", + params: [ka, ke, v, tlag, f_oral, base_gut, base_central, tvke], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + infusion(iv) -> central, + ], + structure: one_compartment_with_absorption, + sec: |_t| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + ke = tvke * wt_scale * renal_scale; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_covariate_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_covariates") + .kind(equation::ModelKind::Analytical) + .parameters([ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten covariate analytical metadata should validate") +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +#[test] +fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_one_compartment(); + let handwritten_model = handwritten_one_compartment(); + let subject = infusion_subject("iv", "cp"); + let support_point = [0.2, 10.0]; + let macro_metadata = macro_model + .metadata() + .expect("macro analytical metadata exists"); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_model.state_index("central"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { + let macro_model = macro_one_compartment_with_absorption(); + let handwritten_model = handwritten_one_compartment_with_absorption(); + let subject = oral_subject("oral", "cp"); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro metadata exists"); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!( + macro_metadata.analytical_kernel(), + Some(equation::AnalyticalKernel::OneCompartmentWithAbsorption) + ); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro absorption model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten absorption model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_analytical(); + let handwritten_model = handwritten_shared_input_analytical(); + let subject = shared_input_subject(); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro metadata exists"); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(1)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro shared-input analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-input analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_covariates_lower_to_handwritten_behavior() { + let macro_model = macro_covariate_analytical(); + let handwritten_model = handwritten_covariate_analytical(); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + + let subject = covariate_subject("oral", "iv", "cp"); + let support_point = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro covariate analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs new file mode 100644 index 00000000..38fd4c9d --- /dev/null +++ b/tests/authoring_parity_corpus.rs @@ -0,0 +1,1752 @@ +#[cfg(feature = "dsl-jit")] +use approx::assert_relative_eq; +#[cfg(feature = "dsl-jit")] +use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; +#[cfg(feature = "dsl-jit")] +use pharmsol::equation::RouteInputPolicy; +use pharmsol::equation::{ + self, AnalyticalKernel, RouteKind as HandwrittenRouteKind, ValidatedModelMetadata, +}; +use pharmsol::prelude::*; +#[cfg(feature = "dsl-jit")] +use pharmsol::Predictions; +use pharmsol_dsl::{ + analyze_model, lower_typed_model, parse_model, CovariateInterpolation, ExecutionModel, + ModelKind, RouteKind as DslRouteKind, +}; + +const ODE_DSL: &str = r#" +name = one_cmt_oral_iv +kind = ode + +params = ka, cl, v, tlag, f_oral +covariates = wt @linear +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - (cl / v) * central + +out(cp) = central / (v * (wt / 70.0)) ~ continuous() +"#; + +const ODE_MACRO_DSL: &str = r#" +name = one_cmt_oral_covariate_parity +kind = ode + +params = ka, cl, v, tlag, f_oral +covariates = wt @linear +states = depot, central +outputs = cp + +bolus(oral) -> depot +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - (cl / v) * central + +out(cp) = central / (v * (wt / 70.0)) ~ continuous() +"#; + +const ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL: &str = r#" +name = multi_digit_output_order +kind = ode + +params = ke, v +states = central +outputs = outeq_2, outeq_10, outeq_11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(outeq_10) = central / v ~ continuous() +out(outeq_2) = central / v ~ continuous() +out(outeq_11) = central / v ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL: &str = r#" +name = authoring_prefixed_numeric_routes +kind = ode + +states = first, second +outputs = cp + +bolus(input_10) -> first +bolus(input_11) -> second + +dx(first) = 0 +dx(second) = 0 + +out(cp) = first + second ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL: &str = r#"model structured_numeric_routes { + kind ode + states { + first, + second, + } + routes { + input_10 -> first + input_11 -> second + } + dynamics { + ddt(first) = 0 + ddt(second) = 0 + } + outputs { + cp = first + second + } +} +"#; + +const ODE_INVALID_INFUSION_LAG_DSL: &str = r#" +name = invalid_infusion_lag_parity +kind = ode + +params = ke, v, tlag +states = central +outputs = cp + +infusion(iv) -> central +lag(iv) = tlag + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_SHARED_INPUT_DSL: &str = r#" +name = shared_input_one_cpt +kind = ode + +params = ka, ke, v, tlag, f_oral +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL: &str = r#" +name = mixed_output_labels_runtime +kind = ode + +params = ke, v +states = central +outputs = cp, outeq_0, outeq_1 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +out(outeq_0) = 2 * central / v ~ continuous() +out(outeq_1) = 3 * central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + +const ANALYTICAL_DSL: &str = r#" +name = one_cmt_abs_parity +kind = analytical + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +structure = one_compartment_with_absorption + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ANALYTICAL_RUNTIME_SHARED_INPUT_DSL: &str = r#" +name = one_cmt_abs_shared +kind = analytical + +params = ka, ke, v, tlag, f_oral +states = gut, central +outputs = cp + +bolus(oral) -> gut +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral +structure = one_compartment_with_absorption + +out(cp) = central / v ~ continuous() +"#; + +const SDE_DSL: &str = r#" +name = one_cmt_sde_parity +kind = sde + +params = ka, ke, v, sigma +covariates = wt @locf +states = depot, central +outputs = cp + +bolus(oral) -> depot +particles = 256 + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central +noise(central) = sigma + +out(cp) = central / (v * wt) ~ continuous() +"#; + +const SDE_MACRO_DSL: &str = r#" +name = one_cmt_sde_macro_parity +kind = sde + +params = ka, ke, v, sigma +states = depot, central +outputs = cp + +bolus(oral) -> depot +particles = 256 + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central +noise(central) = sigma + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const SDE_RUNTIME_SHARED_INPUT_DSL: &str = r#" +name = one_cmt_shared_sde +kind = sde + +params = ka, ke, sigma_ke, v, tlag, f_oral +states = gut, central +outputs = cp +particles = 8 + +bolus(oral) -> gut +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(gut) = -ka * gut +dx(central) = ka * gut - ke * central +noise(central) = sigma_ke + +out(cp) = central / v ~ continuous() +"#; + +#[derive(Clone, Debug, PartialEq, Eq)] +struct MetadataParityView { + name: String, + kind: ModelKind, + parameters: Vec, + covariates: Vec, + states: Vec, + route_input_count: usize, + routes: Vec, + outputs: Vec, + analytical_kernel: Option, + particles: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct NamedIndex { + name: String, + index: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct CovariateParity { + name: String, + index: usize, + interpolation: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct RouteParity { + name: String, + kind: Option, + declaration_index: usize, + input_index: usize, + destination_name: String, + destination_index: usize, + has_lag: bool, + has_bioavailability: bool, +} + +#[cfg(feature = "dsl-jit")] +#[derive(Clone, Debug, PartialEq, Eq)] +struct RouteInputPolicyParity { + name: String, + declaration_index: usize, + input_index: usize, + input_policy: RouteInputPolicy, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum RouteKindParity { + Bolus, + Infusion, +} + +impl RouteKindParity { + fn from_dsl(kind: DslRouteKind) -> Self { + match kind { + DslRouteKind::Bolus => Self::Bolus, + DslRouteKind::Infusion => Self::Infusion, + } + } + + fn from_handwritten(kind: HandwrittenRouteKind) -> Self { + match kind { + HandwrittenRouteKind::Bolus => Self::Bolus, + HandwrittenRouteKind::Infusion => Self::Infusion, + } + } +} + +fn load_execution_model(src: &str) -> ExecutionModel { + let parsed = parse_model(src).expect("DSL model should parse"); + let typed = analyze_model(&parsed).expect("DSL model should analyze"); + lower_typed_model(&typed).expect("DSL model should lower") +} + +#[cfg(feature = "dsl-jit")] +fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntimeModel { + dsl::compile_module_source_to_runtime( + src, + Some(model_name), + RuntimeCompilationTarget::Jit, + |_, _| {}, + ) + .expect("DSL runtime model should compile") +} + +#[cfg(feature = "dsl-jit")] +fn compiled_route_input_index(model: &dsl::CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .routes + .iter() + .find(|route| route.name == name) + .map(|route| route.index) +} + +#[cfg(feature = "dsl-jit")] +fn compiled_output_slot_index(model: &dsl::CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .outputs + .iter() + .find(|output| output.name == name) + .map(|output| output.index) +} + +#[cfg(feature = "dsl-jit")] +fn shared_input_prediction_subject() -> Subject { + Subject::builder("authoring-parity-shared-input") + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .build() +} + +fn dsl_metadata_view(src: &str) -> MetadataParityView { + let model = load_execution_model(src); + + let parameters = model + .metadata + .parameters + .iter() + .map(|parameter| NamedIndex { + name: parameter.name.clone(), + index: parameter.index, + }) + .collect(); + let covariates = model + .metadata + .covariates + .iter() + .map(|covariate| CovariateParity { + name: covariate.name.clone(), + index: covariate.index, + interpolation: covariate.interpolation, + }) + .collect(); + let states = model + .metadata + .states + .iter() + .map(|state| NamedIndex { + name: state.name.clone(), + index: state.offset, + }) + .collect(); + let outputs = model + .metadata + .outputs + .iter() + .map(|output| NamedIndex { + name: output.name.clone(), + index: output.index, + }) + .collect(); + let routes = model + .metadata + .routes + .iter() + .map(|route| RouteParity { + name: route.name.clone(), + kind: route.kind.map(RouteKindParity::from_dsl), + declaration_index: route.declaration_index, + input_index: route.index, + destination_name: route.destination.state_name.clone(), + destination_index: route.destination.state_offset, + has_lag: route.has_lag, + has_bioavailability: route.has_bioavailability, + }) + .collect(); + + MetadataParityView { + name: model.name, + kind: model.kind, + parameters, + covariates, + states, + route_input_count: model.abi.route_buffer.len, + routes, + outputs, + analytical_kernel: model.metadata.analytical, + particles: model.metadata.particles, + } +} + +#[cfg(feature = "dsl-jit")] +fn dsl_route_input_policy_view(src: &str) -> Vec { + let model = load_execution_model(src); + let info = dsl::NativeModelInfo::from_execution_model(&model); + + info.routes + .into_iter() + .map(|route| RouteInputPolicyParity { + name: route.name, + declaration_index: route.declaration_index, + input_index: route.index, + input_policy: if route.inject_input_to_destination { + RouteInputPolicy::InjectToDestination + } else { + RouteInputPolicy::ExplicitInputVector + }, + }) + .collect() +} + +fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityView { + MetadataParityView { + name: metadata.name().to_string(), + kind: metadata.kind(), + parameters: metadata + .parameters() + .iter() + .enumerate() + .map(|(index, parameter)| NamedIndex { + name: parameter.name().to_string(), + index, + }) + .collect(), + covariates: metadata + .covariates() + .iter() + .enumerate() + .map(|(index, covariate)| CovariateParity { + name: covariate.name().to_string(), + index, + interpolation: covariate.interpolation(), + }) + .collect(), + states: metadata + .states() + .iter() + .enumerate() + .map(|(index, state)| NamedIndex { + name: state.name().to_string(), + index, + }) + .collect(), + route_input_count: metadata.route_input_count(), + routes: metadata + .routes() + .iter() + .map(|route| RouteParity { + name: route.name().to_string(), + kind: Some(RouteKindParity::from_handwritten(route.kind())), + declaration_index: route.declaration_index(), + input_index: route.input_index(), + destination_name: route.destination().to_string(), + destination_index: route.destination_index(), + has_lag: route.has_lag(), + has_bioavailability: route.has_bioavailability(), + }) + .collect(), + outputs: metadata + .outputs() + .iter() + .enumerate() + .map(|(index, output)| NamedIndex { + name: output.name().to_string(), + index, + }) + .collect(), + analytical_kernel: metadata.analytical_kernel(), + particles: metadata.particles(), + } +} + +#[cfg(feature = "dsl-jit")] +fn handwritten_route_input_policy_view( + metadata: &ValidatedModelMetadata, +) -> Vec { + metadata + .routes() + .iter() + .map(|route| RouteInputPolicyParity { + name: route.name().to_string(), + declaration_index: route.declaration_index(), + input_index: route.input_index(), + input_policy: route + .input_policy() + .expect("route input policy should be explicit in this handwritten fixture"), + }) + .collect() +} + +fn macro_ode_model() -> equation::ODE { + ode! { + name: "one_cmt_oral_covariate_parity", + params: [ka, cl, v, tlag, f_oral], + covariates: [wt], + states: [depot, central], + outputs: [cp], + routes: [ + bolus(oral) -> depot, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - (cl / v) * x[central]; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, t, cov, y| { + fetch_cov!(cov, t, wt); + y[cp] = x[central] / (v * (wt / 70.0)); + }, + } +} + +fn handwritten_ode_macro_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_covariate_parity") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + ), + ) + .expect("handwritten macro-shape ODE metadata should validate") +} + +fn handwritten_ode_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_iv") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_input_macro_ode() -> equation::ODE { + ode! { + name: "shared_input_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [cp], + routes: [ + bolus(oral) -> depot, + infusion(iv) -> central, + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_input_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_input_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .inject_input_to_destination(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("handwritten shared-input ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_mismatched_shared_input_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_input_one_cpt_mismatched") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("mismatched shared-input ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_input_macro_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_shared", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + infusion(iv) -> central, + ], + structure: one_compartment_with_absorption, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_input_handwritten_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_shared") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten shared-input analytical metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_input_macro_sde() -> equation::SDE { + sde! { + name: "one_cmt_shared_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: [ + bolus(oral) -> gut, + infusion(iv) -> central, + ], + drift: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[gut] = 0.0; + sigma[central] = 0.0 * sigma_ke; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + init: |_p, _t, _cov, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_input_handwritten_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_shared_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten shared-input SDE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn assert_prediction_vectors_close(left: &[f64], right: &[f64], tolerance: f64) { + assert_eq!(left.len(), right.len()); + for (left_value, right_value) in left.iter().zip(right.iter()) { + let diff = (left_value - right_value).abs(); + assert!( + diff <= tolerance, + "prediction mismatch: left={left_value:.12}, right={right_value:.12}, diff={diff:.12}, tolerance={tolerance:.12}" + ); + } +} + +#[cfg(feature = "dsl-jit")] +fn assert_prediction_vectors_diverge(left: &[f64], right: &[f64], tolerance: f64) { + assert_eq!(left.len(), right.len()); + assert!( + left.iter() + .zip(right.iter()) + .any(|(left_value, right_value)| (left_value - right_value).abs() > tolerance), + "expected prediction vectors to diverge beyond tolerance {tolerance:.12}" + ); +} + +#[cfg(feature = "dsl-jit")] +fn particle_prediction_means(predictions: &ndarray::Array2) -> Vec { + predictions + .get_predictions() + .into_iter() + .map(|prediction| prediction.prediction()) + .collect() +} + +fn macro_analytical_model() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_parity", + params: [ka, ke, v], + states: [depot, central], + outputs: [cp], + routes: [ + bolus(oral) -> depot, + ], + structure: one_compartment_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_analytical_model() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_parity") + .kind(ModelKind::Analytical) + .parameters(["ka", "ke", "v"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(equation::Route::bolus("oral").to_state("depot")) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten analytical metadata should validate") +} + +fn macro_sde_model() -> equation::SDE { + sde! { + name: "one_cmt_sde_macro_parity", + params: [ka, ke, v, sigma], + states: [depot, central], + outputs: [cp], + particles: 256, + routes: [ + bolus(oral) -> depot, + ], + drift: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + diffusion: |_p, sigma_values| { + sigma_values[depot] = 0.0; + sigma_values[central] = sigma; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_sde_model() -> equation::SDE { + equation::SDE::new( + |_x, _p, _t, dx, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + 256, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_parity") + .kind(ModelKind::Sde) + .parameters(["ka", "ke", "v", "sigma"]) + .covariates([equation::Covariate::locf("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination(), + ) + .particles(256), + ) + .expect("handwritten SDE metadata should validate") +} + +fn handwritten_sde_macro_model() -> equation::SDE { + equation::SDE::new( + |_x, _p, _t, dx, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, sigma_values| { + sigma_values[0] = 0.0; + sigma_values[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + 256, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_macro_parity") + .kind(ModelKind::Sde) + .parameters(["ka", "ke", "v", "sigma"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination(), + ) + .particles(256), + ) + .expect("handwritten macro-shape SDE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn mismatched_ode_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_iv") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .expect_explicit_input() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("mismatched ODE metadata should validate") +} + +#[test] +fn ode_dsl_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_ode_model(); + let dsl_view = dsl_metadata_view(ODE_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[test] +fn ode_dsl_declared_output_order_controls_dense_indices_for_multi_digit_labels() { + let dsl_view = dsl_metadata_view(ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL); + + assert_eq!( + dsl_view.outputs, + vec![ + NamedIndex { + name: "outeq_2".to_string(), + index: 0, + }, + NamedIndex { + name: "outeq_10".to_string(), + index: 1, + }, + NamedIndex { + name: "outeq_11".to_string(), + index: 2, + }, + ] + ); +} + +#[test] +fn ode_authoring_dsl_supports_prefixed_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "input_10".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "input_11".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + +#[test] +fn ode_structured_dsl_supports_prefixed_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "input_10".to_string(), + kind: None, + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "input_11".to_string(), + kind: None, + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + +#[test] +fn ode_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { + let handwritten = handwritten_ode_macro_model(); + let macro_model = macro_ode_model(); + let dsl_view = dsl_metadata_view(ODE_MACRO_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten macro-shape ODE metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[test] +fn analytical_dsl_macro_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_analytical_model(); + let macro_model = macro_analytical_model(); + let dsl_view = dsl_metadata_view(ANALYTICAL_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten analytical metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro analytical metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[test] +fn sde_dsl_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_sde_model(); + let dsl_view = dsl_metadata_view(SDE_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[test] +fn sde_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { + let handwritten = handwritten_sde_macro_model(); + let macro_model = macro_sde_model(); + let dsl_view = dsl_metadata_view(SDE_MACRO_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten macro-shape SDE metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_route_input_policies_agree_with_handwritten_metadata() { + let dsl_view = dsl_route_input_policy_view(ODE_DSL); + let handwritten = handwritten_ode_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("handwritten ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn sde_route_input_policies_agree_with_handwritten_metadata() { + let dsl_view = dsl_route_input_policy_view(SDE_DSL); + let handwritten = handwritten_sde_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("handwritten SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn route_input_policy_mismatches_are_detected_explicitly() { + let dsl_view = dsl_route_input_policy_view(ODE_DSL); + let handwritten = mismatched_ode_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("mismatched handwritten metadata should exist"), + ); + + assert_ne!(handwritten_view, dsl_view); + assert_eq!(dsl_view[0].name, "oral"); + assert_eq!( + dsl_view[0].input_policy, + RouteInputPolicy::InjectToDestination + ); + assert_eq!( + handwritten_view[0].input_policy, + RouteInputPolicy::ExplicitInputVector + ); +} + +#[test] +fn invalid_dsl_infusion_route_properties_fail_explicitly() { + let model = + parse_model(ODE_INVALID_INFUSION_LAG_DSL).expect("invalid DSL fixture should parse"); + let typed = analyze_model(&model).expect("invalid DSL fixture should analyze"); + let error = lower_typed_model(&typed).expect_err("infusion lag should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `lag` on infusion route `iv`")); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { + let runtime_model = + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let macro_model = runtime_shared_input_macro_ode(); + let handwritten_model = runtime_shared_input_handwritten_ode(); + let macro_metadata = macro_model.metadata().expect("macro ODE metadata exists"); + let handwritten_metadata = handwritten_model + .metadata() + .expect("handwritten ODE metadata exists"); + + let oral = compiled_route_input_index(&runtime_model, "oral") + .expect("runtime oral route should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); + let subject = shared_input_prediction_subject(); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(cp, 0); + assert_eq!( + macro_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + macro_metadata.route("iv").map(|route| route.input_index()), + Some(iv) + ); + assert_eq!( + handwritten_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + handwritten_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime ODE model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro ODE model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten ODE model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-4); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-4); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { + let runtime_model = + compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_INPUT_DSL, "one_cmt_abs_shared"); + let macro_model = runtime_shared_input_macro_analytical(); + let handwritten_model = runtime_shared_input_handwritten_analytical(); + let macro_metadata = macro_model + .metadata() + .expect("macro analytical metadata exists"); + let handwritten_metadata = handwritten_model + .metadata() + .expect("handwritten analytical metadata exists"); + + let oral = compiled_route_input_index(&runtime_model, "oral") + .expect("runtime oral route should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); + let subject = shared_input_prediction_subject(); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(cp, 0); + assert_eq!( + macro_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + macro_metadata.route("iv").map(|route| route.input_index()), + Some(iv) + ); + assert_eq!( + handwritten_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + handwritten_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime analytical model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => { + panic!("analytical runtime should return subject predictions") + } + }; + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-8); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-8); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { + let runtime_model = + compile_runtime_jit_model(SDE_RUNTIME_SHARED_INPUT_DSL, "one_cmt_shared_sde"); + let macro_model = runtime_shared_input_macro_sde(); + let handwritten_model = runtime_shared_input_handwritten_sde(); + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); + let handwritten_metadata = handwritten_model + .metadata() + .expect("handwritten SDE metadata exists"); + + let oral = compiled_route_input_index(&runtime_model, "oral") + .expect("runtime oral route should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); + let subject = shared_input_prediction_subject(); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(cp, 0); + assert_eq!( + macro_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + macro_metadata.route("iv").map(|route| route.input_index()), + Some(iv) + ); + assert_eq!( + handwritten_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + handwritten_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime SDE model should simulate") + { + RuntimePredictions::Particles(predictions) => particle_prediction_means(&predictions), + RuntimePredictions::Subject(_) => panic!("SDE runtime should return particle predictions"), + }; + let macro_predictions = particle_prediction_means( + ¯o_model + .estimate_predictions(&subject, &support_point) + .expect("macro SDE model should simulate"), + ); + let handwritten_predictions = particle_prediction_means( + &handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten SDE model should simulate"), + ); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-4); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-4); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn route_input_policy_runtime_mismatches_are_detected_explicitly() { + let runtime_model = + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let mismatched_model = runtime_mismatched_shared_input_ode(); + let mismatched_metadata = mismatched_model + .metadata() + .expect("mismatched handwritten metadata exists"); + + let oral = compiled_route_input_index(&runtime_model, "oral") + .expect("runtime oral route should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); + let subject = shared_input_prediction_subject(); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(cp, 0); + assert_eq!( + mismatched_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + mismatched_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime ODE model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + let mismatched_predictions = mismatched_model + .estimate_predictions(&subject, &support_point) + .expect("mismatched handwritten ODE should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_diverge(&runtime_predictions, &mismatched_predictions, 1e-4); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_preserves_mixed_output_labels() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL, + "mixed_output_labels_runtime", + ); + let subject = Subject::builder("runtime-mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(0.5, "outeq_0") + .missing_observation(0.5, "outeq_1") + .build(); + let support_point = [0.2, 10.0]; + + assert_eq!(compiled_output_slot_index(&runtime_model, "cp"), Some(0)); + assert_eq!( + compiled_output_slot_index(&runtime_model, "outeq_0"), + Some(1) + ); + assert_eq!( + compiled_output_slot_index(&runtime_model, "outeq_1"), + Some(2) + ); + + let predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime mixed-output model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + + assert_eq!(predictions.len(), 3); + assert_relative_eq!(predictions[1], 2.0 * predictions[0], epsilon = 1e-6); + assert_relative_eq!(predictions[2], 3.0 * predictions[0], epsilon = 1e-6); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_output_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL, + "undeclared_numeric_output_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == "10" + )); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_input_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL, + "undeclared_numeric_input_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == "10" + )); +} diff --git a/tests/browser-e2e/site/app.mjs b/tests/browser-e2e/site/app.mjs index 87b68eda..c939c670 100644 --- a/tests/browser-e2e/site/app.mjs +++ b/tests/browser-e2e/site/app.mjs @@ -4,7 +4,7 @@ const precompiledInputs = Object.freeze({ }); const compileFlowSource = ` -model = example_ode +name = example_ode kind = ode params = ke, v @@ -25,7 +25,7 @@ const compileFlowInputs = Object.freeze({ }); const invalidCompileSource = ` -model = broken +name = broken kind = ode states = central diff --git a/tests/full_feature_dsl_backend_parity.rs b/tests/full_feature_dsl_backend_parity.rs new file mode 100644 index 00000000..b83157dd --- /dev/null +++ b/tests/full_feature_dsl_backend_parity.rs @@ -0,0 +1,208 @@ +#[path = "support/runtime_corpus.rs"] +mod runtime_corpus; + +#[cfg(all(feature = "dsl-jit", feature = "dsl-wasm"))] +mod tests { + use super::runtime_corpus::{self as corpus, CorpusCase}; + use pharmsol::dsl::{CompiledRuntimeModel, RuntimeBackend}; + + fn owned_names(names: &[&str]) -> Vec { + names.iter().map(|name| (*name).to_owned()).collect() + } + + fn assert_info_matches( + left_label: &str, + left: &CompiledRuntimeModel, + right_label: &str, + right: &CompiledRuntimeModel, + ) { + assert_eq!( + left.info(), + right.info(), + "{left_label} model info diverged from {right_label}" + ); + } + + fn assert_ode_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "ode_full_feature_parity"); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ]) + ); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.index) + .collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + } + + fn assert_analytical_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "analytical_full_feature_parity"); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + ); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.index) + .collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + } + + fn assert_full_backend_parity( + case: CorpusCase, + assert_public_shape: fn(&CompiledRuntimeModel), + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(case)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + assert_public_shape(&jit); + corpus::assert_runtime_model_matches_reference(case, "runtime-jit", &jit)?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(case, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + assert_public_shape(&aot); + corpus::assert_runtime_model_matches_reference(case, "runtime-native-aot", &aot)?; + } + + let wasm = corpus::compile_runtime_wasm_model(case)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + assert_public_shape(&wasm); + corpus::assert_runtime_model_matches_reference(case, "runtime-wasm", &wasm)?; + + assert_info_matches("runtime-jit", &jit, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_info_matches("runtime-jit", &jit, "runtime-native-aot", &aot); + assert_info_matches("runtime-native-aot", &aot, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-native-aot", + &aot, + )?; + corpus::assert_runtime_models_match_each_other( + case, + "runtime-native-aot", + &aot, + "runtime-wasm", + &wasm, + )?; + } + + Ok(()) + } + + #[test] + fn ode_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity(CorpusCase::OdeFull, assert_ode_full_public_shape) + } + + #[test] + fn analytical_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity( + CorpusCase::AnalyticalFull, + assert_analytical_full_public_shape, + ) + } +} diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs new file mode 100644 index 00000000..4fe7b442 --- /dev/null +++ b/tests/full_feature_macro_parity.rs @@ -0,0 +1,480 @@ +use pharmsol::prelude::*; + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn macro_ode_model() -> equation::ODE { + ode! { + name: "ode_full_feature_parity", + params: [ka, ke, kcp, kpc, v, tlag, f_oral, base_depot, base_central, base_peripheral], + covariates: [wt, renal], + states: [depot, central, peripheral], + outputs: [cp], + routes: [ + bolus(oral) -> depot, + bolus(load) -> central, + infusion(iv) -> central, + ], + diffeq: |x, _t, dx| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] + - (adjusted_ke + adjusted_kcp) * x[central] + + kpc * x[peripheral]; + dx[peripheral] = adjusted_kcp * x[central] - kpc * x[peripheral]; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[depot] = base_depot + 0.05 * wt; + x[central] = base_central + 0.1 * renal; + x[peripheral] = base_peripheral + 0.02 * wt; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_ode_model() -> equation::ODE { + equation::ODE::new( + |x, p, t, dx, bolus, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + kcp, + kpc, + _v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[0] = bolus[0] - ka * x[0]; + dx[1] = + bolus[1] + ka * x[0] + rateiv[0] - (adjusted_ke + adjusted_kcp) * x[1] + kpc * x[2]; + dx[2] = adjusted_kcp * x[1] - kpc * x[2]; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + _f_oral, + base_depot, + base_central, + base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_depot + 0.05 * wt; + x[1] = base_central + 0.1 * renal; + x[2] = base_peripheral + 0.02 * wt; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(3) + .with_ndrugs(2) + .with_nout(1) + .with_metadata( + equation::metadata::new("ode_full_feature_parity") + .parameters([ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["depot", "central", "peripheral"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .inject_input_to_destination(), + equation::Route::bolus("load") + .to_state("central") + .inject_input_to_destination(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("handwritten ODE metadata should validate") +} + +fn build_ode_subject() -> Subject { + Subject::builder("macro-vs-handwritten-ode-full-features") + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +fn macro_analytical_model() -> equation::Analytical { + analytical! { + name: "analytical_full_feature_parity", + params: [ka, ke, v, tlag, f_oral, base_gut, base_central, tvke], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + bolus(load) -> central, + infusion(iv) -> central, + ], + structure: one_compartment_with_absorption, + sec: |_t| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + ke = tvke * wt_scale * renal_scale; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_analytical_model() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(2) + .with_nout(1) + .with_metadata( + equation::metadata::new("analytical_full_feature_parity") + .kind(equation::ModelKind::Analytical) + .parameters([ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::bolus("load").to_state("central"), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten analytical metadata should validate") +} + +fn build_analytical_subject() -> Subject { + Subject::builder("macro-vs-handwritten-analytical-full-features") + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +#[test] +fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_ode_model(); + let handwritten_ode = handwritten_ode_model(); + let macro_metadata = macro_ode.metadata().expect("macro ODE metadata exists"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let oral = macro_metadata + .route("oral") + .map(|route| route.input_index()) + .expect("oral route exists"); + let load = macro_metadata + .route("load") + .map(|route| route.input_index()) + .expect("load route exists"); + let iv = macro_metadata + .route("iv") + .map(|route| route.input_index()) + .expect("iv route exists"); + + assert_eq!(oral, iv); + assert_eq!(load, 1); + assert!(macro_metadata.output("cp").is_some()); + + let subject = build_ode_subject(); + let params = [1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0]; + + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let diff = max_abs_diff( + ¯o_predictions.flat_predictions(), + &handwritten_predictions.flat_predictions(), + ); + assert!( + diff <= 1e-10, + "macro and handwritten ODE predictions diverged: {diff:e}" + ); + + Ok(()) +} + +#[test] +fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { + let macro_analytical = macro_analytical_model(); + let handwritten_analytical = handwritten_analytical_model(); + let macro_metadata = macro_analytical + .metadata() + .expect("macro analytical metadata exists"); + + assert_eq!( + macro_analytical.metadata(), + handwritten_analytical.metadata() + ); + + let oral = macro_metadata + .route("oral") + .map(|route| route.input_index()) + .expect("oral route exists"); + let load = macro_metadata + .route("load") + .map(|route| route.input_index()) + .expect("load route exists"); + let iv = macro_metadata + .route("iv") + .map(|route| route.input_index()) + .expect("iv route exists"); + + assert_eq!(oral, iv); + assert_eq!(load, 1); + assert!(macro_metadata.output("cp").is_some()); + + let subject = build_analytical_subject(); + let params = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; + + let macro_predictions = macro_analytical.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_analytical.estimate_predictions(&subject, ¶ms)?; + + let diff = max_abs_diff( + ¯o_predictions.flat_predictions(), + &handwritten_predictions.flat_predictions(), + ); + assert!( + diff <= 1e-10, + "macro and handwritten analytical predictions diverged: {diff:e}" + ); + + Ok(()) +} diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs new file mode 100644 index 00000000..8e5540ee --- /dev/null +++ b/tests/ode_macro_lowering.rs @@ -0,0 +1,633 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::data::read_pmetrics; +use pharmsol::prelude::*; +use tempfile::NamedTempFile; + +fn write_pmetrics_fixture(contents: &str) -> NamedTempFile { + let file = NamedTempFile::new().expect("create temporary Pmetrics fixture"); + std::fs::write(file.path(), contents).expect("write temporary Pmetrics fixture"); + file +} + +fn subject_for_route(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("macro-lowering") + .infusion(0.0, 100.0, input, 1.0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + +fn subject_for_shared_input() -> Subject { + Subject::builder("macro-shared-input") + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .build() +} + +fn subject_for_covariates(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("macro-covariates") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .covariate("wt", 0.0, 70.0) + .build() +} + +fn subject_for_numeric_bolus_route(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("numeric-bolus-route") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + +fn injected_macro_ode() -> equation::ODE { + ode! { + name: "injected_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn injected_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("injected_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten injected metadata should validate") +} + +fn numeric_label_macro_ode() -> equation::ODE { + ode! { + name: "numeric_label_one_cpt", + params: [ke, v], + states: [central], + outputs: [outeq_1], + routes: [ + infusion(input_1) -> central, + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; + }, + out: |x, _t, y| { + y[outeq_1] = x[central] / v; + }, + } +} + +fn numeric_label_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_label_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["outeq_1"]) + .route( + equation::Route::infusion("input_1") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten numeric-label metadata should validate") +} + +fn shared_input_macro_ode() -> equation::ODE { + ode! { + name: "shared_input_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [cp], + routes: [ + bolus(oral) -> depot, + infusion(iv) -> central, + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn shared_input_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_input_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .inject_input_to_destination(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("handwritten shared-input metadata should validate") +} + +fn numeric_route_property_macro_ode() -> equation::ODE { + ode! { + name: "numeric_route_property_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [outeq_1], + routes: [ + bolus(input_1) -> depot, + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + lag: |_t| { + lag! { input_1 => tlag } + }, + fa: |_t| { + fa! { input_1 => f_oral } + }, + out: |x, _t, y| { + y[outeq_1] = x[central] / v; + }, + } +} + +fn numeric_route_property_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_route_property_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["outeq_1"]) + .route( + equation::Route::bolus("input_1") + .to_state("depot") + .with_lag() + .with_bioavailability() + .inject_input_to_destination(), + ), + ) + .expect("handwritten numeric route-property metadata should validate") +} + +fn mixed_output_labels_macro_ode() -> equation::ODE { + ode! { + name: "mixed_output_labels_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp, outeq_0, outeq_1], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + y[outeq_0] = 2.0 * x[central] / v; + y[outeq_1] = 3.0 * x[central] / v; + }, + } +} + +fn mixed_output_labels_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + y[1] = 2.0 * x[0] / v; + y[2] = 3.0 * x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(3) + .with_metadata( + equation::metadata::new("mixed_output_labels_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp", "outeq_0", "outeq_1"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten mixed-output metadata should validate") +} + +fn covariate_macro_ode() -> equation::ODE { + ode! { + name: "covariate_one_cpt", + params: [ka, ke, v], + covariates: [wt], + states: [gut, central], + outputs: [cp], + routes: [ + bolus(oral) -> gut, + ], + diffeq: |x, _t, dx| { + let scaled_ke = ke * (wt / 70.0); + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - scaled_ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn covariate_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, t, dx, bolus, _rateiv, cov| { + fetch_cov!(cov, t, wt); + fetch_params!(p, ka, ke, _v); + let scaled_ke = ke * (wt / 70.0); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] - scaled_ke * x[1]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("covariate_one_cpt") + .parameters(["ka", "ke", "v"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination(), + ), + ) + .expect("handwritten covariate metadata should validate") +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +#[test] +fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = injected_macro_ode(); + let handwritten_ode = injected_handwritten_ode(); + let subject = subject_for_route("iv", "cp"); + let support_point = [0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro injected model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro injected model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten injected model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_numeric_labels_lower_to_dense_slots() { + let macro_ode = numeric_label_macro_ode(); + let handwritten_ode = numeric_label_handwritten_ode(); + let subject = subject_for_route("1", "1"); + let support_point = [0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro numeric-label model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert!(macro_metadata.route("input_1").is_some()); + assert!(macro_metadata.output("outeq_1").is_some()); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro numeric-label model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten numeric-label model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = shared_input_macro_ode(); + let handwritten_ode = shared_input_handwritten_ode(); + let subject = subject_for_shared_input(); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_ode + .metadata() + .expect("macro shared-input model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_ode.state_index("depot"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro shared-input model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-input model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_mixed_output_labels_lower_to_dense_slots() { + let macro_ode = mixed_output_labels_macro_ode(); + let handwritten_ode = mixed_output_labels_handwritten_ode(); + let subject = Subject::builder("mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "0") + .missing_observation(2.0, "1") + .build(); + let support_point = [0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro mixed-output model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert!(macro_metadata.output("cp").is_some()); + assert!(macro_metadata.output("outeq_0").is_some()); + assert!(macro_metadata.output("outeq_1").is_some()); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro mixed-output model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten mixed-output model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_numeric_route_properties_lower_to_dense_slots() { + let macro_ode = numeric_route_property_macro_ode(); + let handwritten_ode = numeric_route_property_handwritten_ode(); + let subject = subject_for_numeric_bolus_route("1", "1"); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_ode + .metadata() + .expect("macro numeric route-property model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert!(macro_metadata.route("input_1").is_some()); + assert!(macro_metadata.output("outeq_1").is_some()); + assert_eq!(macro_ode.state_index("depot"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro numeric route-property model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten numeric route-property model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_named_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,cp,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read named-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = injected_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro named-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = injected_macro_ode() + .estimate_predictions(&subject_for_route("iv", "cp"), &support_point) + .expect("macro internal-index model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + +#[test] +fn macro_numeric_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,1,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,1,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read numeric-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = numeric_label_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro numeric-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = numeric_label_macro_ode() + .estimate_predictions(&subject_for_route("1", "1"), &support_point) + .expect("macro internal-index numeric-label model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + +#[test] +fn macro_covariate_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = covariate_macro_ode(); + let handwritten_ode = covariate_handwritten_ode(); + let subject = subject_for_covariates("oral", "cp"); + let support_point = [1.0, 0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro covariate model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_metadata.covariates().len(), 1); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_ode.state_index("gut"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro covariate model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} diff --git a/tests/ode_optimizations.rs b/tests/ode_optimizations.rs index 488ae0f3..e50ecf10 100644 --- a/tests/ode_optimizations.rs +++ b/tests/ode_optimizations.rs @@ -900,7 +900,7 @@ fn likelihood_calculation_matches_analytical() { .with_nstates(1) .with_nout(1); - let error_models = AssayErrorModels::new() + let error_models = AssayErrorModels::default() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 0.1, 0.0, 0.0), 0.0), diff --git a/tests/runtime_backend_matrix.rs b/tests/runtime_backend_matrix.rs index 6a207398..fdabc94d 100644 --- a/tests/runtime_backend_matrix.rs +++ b/tests/runtime_backend_matrix.rs @@ -84,6 +84,84 @@ mod tests { Ok(()) } + #[test] + fn analytical_full_runtime_backend_matrix_matches_reference_predictions( + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(CorpusCase::AnalyticalFull)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-jit", + &jit, + )?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(CorpusCase::AnalyticalFull, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-native-aot", + &aot, + )?; + + let wasm = corpus::compile_runtime_wasm_model(CorpusCase::AnalyticalFull)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-wasm", + &wasm, + )?; + corpus::assert_runtime_models_match_each_other( + CorpusCase::AnalyticalFull, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + Ok(()) + } + + #[test] + fn ode_full_runtime_backend_matrix_matches_reference_predictions( + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(CorpusCase::OdeFull)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + corpus::assert_runtime_model_matches_reference(CorpusCase::OdeFull, "runtime-jit", &jit)?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(CorpusCase::OdeFull, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + corpus::assert_runtime_model_matches_reference( + CorpusCase::OdeFull, + "runtime-native-aot", + &aot, + )?; + + let wasm = corpus::compile_runtime_wasm_model(CorpusCase::OdeFull)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + corpus::assert_runtime_model_matches_reference(CorpusCase::OdeFull, "runtime-wasm", &wasm)?; + corpus::assert_runtime_models_match_each_other( + CorpusCase::OdeFull, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + Ok(()) + } + #[test] fn sde_runtime_backend_matrix_matches_reference_predictions( ) -> Result<(), Box> { diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs new file mode 100644 index 00000000..7980ccd4 --- /dev/null +++ b/tests/sde_macro_lowering.rs @@ -0,0 +1,591 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; +use pharmsol::Predictions; + +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("sde-macro-iv") + .infusion(0.0, 120.0, input, 1.0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("sde-macro-oral") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + +fn shared_input_subject() -> Subject { + Subject::builder("sde-macro-shared") + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .build() +} + +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { + Subject::builder("sde-macro-covariates") + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) + .missing_observation(8.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +fn prediction_means(predictions: &ndarray::Array2) -> Vec { + predictions + .get_predictions() + .into_iter() + .map(|prediction| prediction.prediction()) + .collect() +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +fn macro_infusion_sde() -> equation::SDE { + sde! { + name: "one_cpt_sde", + params: [ke, sigma_ke, v], + states: [central], + outputs: [cp], + particles: 16, + routes: [ + infusion(iv) -> central, + ], + drift: |x, _t, dx| { + dx[central] = -ke * x[central]; + }, + diffusion: |sigma| { + sigma[central] = sigma_ke; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_infusion_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ke, _sigma_ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |p, sigma| { + fetch_params!(p, _ke, sigma_ke, _v); + sigma[0] = sigma_ke; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, _sigma_ke, v); + y[0] = x[0] / v; + }, + 16, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ke", "sigma_ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("handwritten SDE metadata should validate") +} + +fn macro_absorption_sde() -> equation::SDE { + sde! { + name: "one_cmt_abs_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: [ + bolus(oral) -> gut, + ], + drift: |x, _t, dx| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |sigma| { + sigma[gut] = 0.0 * sigma_ke; + sigma[central] = sigma_ke; + }, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + init: |_t, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_absorption_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, _rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, sigma| { + fetch_params!(p, _ka, _ke, sigma_ke, _v, _tlag, _f_oral); + sigma[0] = 0.0 * sigma_ke; + sigma[1] = sigma_ke; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + ) + .particles(8), + ) + .expect("handwritten absorption SDE metadata should validate") +} + +fn macro_shared_input_sde() -> equation::SDE { + sde! { + name: "one_cmt_shared_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: [ + bolus(oral) -> gut, + infusion(iv) -> central, + ], + drift: |x, _t, dx| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |sigma| { + sigma[gut] = 0.0; + sigma[central] = 0.0; + }, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + init: |_t, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_shared_input_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_shared_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten shared-input SDE metadata should validate") +} + +fn macro_covariate_sde() -> equation::SDE { + sde! { + name: "one_cmt_sde_covariates", + params: [ka, ke, sigma_ke, v, tlag, f_oral, base_gut, base_central], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: [ + bolus(oral) -> gut, + infusion(iv) -> central, + ], + drift: |x, _t, dx| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - adjusted_ke * x[central]; + }, + diffusion: |sigma| { + sigma[gut] = 0.0 * sigma_ke; + sigma[central] = 0.0 * sigma_ke; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_covariate_sde() -> equation::SDE { + equation::SDE::new( + |x, p, t, dx, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + _sigma_ke, + _v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - adjusted_ke * x[1]; + }, + |p, sigma| { + fetch_params!( + p, + _ka, + _ke, + sigma_ke, + _v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); + sigma[0] = 0.0 * sigma_ke; + sigma[1] = 0.0 * sigma_ke; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_covariates") + .kind(equation::ModelKind::Sde) + .parameters([ + "ka", + "ke", + "sigma_ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten covariate SDE metadata should validate") +} + +#[test] +fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_infusion_sde(); + let handwritten_model = handwritten_infusion_sde(); + let subject = infusion_subject("iv", "cp"); + let support_point = [0.2, 0.0, 10.0]; + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_model.state_index("central"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro SDE model should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten SDE model should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { + let macro_model = macro_absorption_sde(); + let handwritten_model = handwritten_absorption_sde(); + let subject = oral_subject("oral", "cp"); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_model.state_index("gut"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro absorption SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten absorption SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_sde(); + let handwritten_model = handwritten_shared_input_sde(); + let subject = shared_input_subject(); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(1)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro shared-input SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-input SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_covariates_lower_to_handwritten_behavior() { + let macro_model = macro_covariate_sde(); + let handwritten_model = handwritten_covariate_sde(); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + + let subject = covariate_subject("oral", "iv", "cp"); + let support_point = [1.0, 0.16, 0.0, 32.0, 0.5, 0.8, 3.0, 14.0]; + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro covariate SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index fe363aa7..97b958c2 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -12,7 +12,7 @@ pub const OBSERVATION_TIMES: [f64; 7] = [0.5, 1.0, 2.0, 3.0, 4.0, 6.0, 8.0]; pub const SUPPORT_POINT: [f64; 2] = [1.2, 50.0]; pub const AUTHORING_DSL: &str = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v @@ -55,6 +55,14 @@ fn subject_for_indices(route_index: usize, output_index: usize) -> Subject { builder.build() } +fn subject_for_labels(route_label: &str, output_label: &str) -> Subject { + let mut builder = Subject::builder(MODEL_NAME).infusion(0.0, 500.0, route_label, 0.5); + for time in OBSERVATION_TIMES { + builder = builder.missing_observation(time, output_label); + } + builder.build() +} + pub fn subject() -> Subject { subject_for_indices(0, 0) } @@ -65,12 +73,27 @@ pub fn subject() -> Subject { feature = "dsl-wasm" ))] pub fn subject_for_runtime_model(model: &pharmsol::dsl::CompiledRuntimeModel) -> Subject { - let route_index = model - .route_index("iv") - .or_else(|| model.route_index("input_0")) - .expect("bimodal_ke route is available"); - let output_index = model.output_index("cp").expect("cp output is available"); - subject_for_indices(route_index, output_index) + let route_label = if model.info().routes.iter().any(|route| route.name == "iv") { + "iv" + } else if model + .info() + .routes + .iter() + .any(|route| route.name == "input_0") + { + "input_0" + } else { + panic!("bimodal_ke route is available"); + }; + assert!( + model + .info() + .outputs + .iter() + .any(|output| output.name == "cp"), + "cp output is available" + ); + subject_for_labels(route_label, "cp") } pub fn reference_values() -> Result, Box> { diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 0a9917a0..0f32fada 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -19,7 +19,7 @@ use pharmsol::{equation, fa, fetch_cov, fetch_params, lag, Subject, SubjectBuild use tempfile::{tempdir, TempDir}; const ODE_SOURCE: &str = r#" -model = one_cmt_oral_iv +name = one_cmt_oral_iv kind = ode params = ka, cl, v, tlag, f_oral @@ -43,8 +43,40 @@ dx(central) = ka * depot - ke * central out(cp) = central / v ~ continuous() "#; +const ODE_FULL_SOURCE: &str = r#" +name = ode_full_feature_parity +kind = ode + +params = ka, ke, kcp, kpc, v, tlag, f_oral, base_depot, base_central, base_peripheral +covariates = wt@linear, renal@linear +derived = adjusted_ke, adjusted_kcp, adjusted_v +states = depot, central, peripheral +outputs = cp + +bolus(oral) -> depot +bolus(load) -> central +infusion(iv) -> central + +lag(oral) = tlag * sqrt(wt / 70.0) * pow(90.0 / renal, 0.1) +fa(oral) = min(max(f_oral * pow(renal / 90.0, 0.1), 0.0), 1.0) + +adjusted_ke = ke * pow(wt / 70.0, 0.75) * pow(renal / 90.0, 0.25) +adjusted_kcp = kcp * pow(wt / 70.0, 0.25) +adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)) + +dx(depot) = -ka * depot +dx(central) = ka * depot - (adjusted_ke + adjusted_kcp) * central + kpc * peripheral +dx(peripheral) = adjusted_kcp * central - kpc * peripheral + +init(depot) = base_depot + 0.05 * wt +init(central) = base_central + 0.1 * renal +init(peripheral) = base_peripheral + 0.02 * wt + +out(cp) = central / adjusted_v ~ continuous() +"#; + const ANALYTICAL_SOURCE: &str = r#" -model = one_cmt_abs +name = one_cmt_abs kind = analytical params = ka, ke, v, tlag, f_oral @@ -56,13 +88,41 @@ bolus(oral) -> depot lag(oral) = tlag fa(oral) = f_oral -kernel = one_compartment_with_absorption +structure = one_compartment_with_absorption out(cp) = central / v ~ continuous() "#; +const ANALYTICAL_FULL_SOURCE: &str = r#" +name = analytical_full_feature_parity +kind = analytical + +params = ka, ke, v, tlag, f_oral, base_gut, base_central, tvke +covariates = wt@linear, renal@linear +derived = ka_proj, ke_proj +states = gut, central +outputs = cp + +bolus(oral) -> gut +bolus(load) -> central +infusion(iv) -> central + +lag(oral) = tlag * sqrt(wt / 70.0) * pow(90.0 / renal, 0.1) +fa(oral) = min(max(f_oral * pow(renal / 90.0, 0.1), 0.0), 1.0) + +ka_proj = ka +ke_proj = tvke * pow(wt / 70.0, 0.75) * pow(renal / 90.0, 0.25) + +structure = one_compartment_with_absorption + +init(gut) = base_gut + 0.03 * wt +init(central) = base_central + 0.08 * renal + +out(cp) = central / (v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0))) ~ continuous() +"#; + const SDE_SOURCE: &str = r#" -model = vanco_sde +name = vanco_sde kind = sde params = ka, ke0, kcp, kpc, vol, ske @@ -90,7 +150,9 @@ pub const SDE_PARTICLE_COUNT: usize = 16; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CorpusCase { Ode, + OdeFull, Analytical, + AnalyticalFull, Sde, } @@ -98,7 +160,9 @@ impl CorpusCase { pub fn label(self) -> &'static str { match self { Self::Ode => "dsl-ode-one_cmt_oral_iv", + Self::OdeFull => "dsl-ode-full-feature-parity", Self::Analytical => "dsl-analytical-one_cmt_abs", + Self::AnalyticalFull => "dsl-analytical-full-feature-parity", Self::Sde => "dsl-sde-vanco_sde", } } @@ -106,7 +170,9 @@ impl CorpusCase { pub fn model_name(self) -> &'static str { match self { Self::Ode => "one_cmt_oral_iv", + Self::OdeFull => "ode_full_feature_parity", Self::Analytical => "one_cmt_abs", + Self::AnalyticalFull => "analytical_full_feature_parity", Self::Sde => "vanco_sde", } } @@ -114,7 +180,9 @@ impl CorpusCase { fn source(self) -> &'static str { match self { Self::Ode => ODE_SOURCE, + Self::OdeFull => ODE_FULL_SOURCE, Self::Analytical => ANALYTICAL_SOURCE, + Self::AnalyticalFull => ANALYTICAL_FULL_SOURCE, Self::Sde => SDE_SOURCE, } } @@ -122,7 +190,9 @@ impl CorpusCase { pub fn tolerance(self) -> f64 { match self { Self::Ode => 1e-4, + Self::OdeFull => 1e-4, Self::Analytical => 1e-8, + Self::AnalyticalFull => 1e-8, Self::Sde => 1e-4, } } @@ -130,59 +200,170 @@ impl CorpusCase { pub fn support_point(self) -> &'static [f64] { match self { Self::Ode => &[1.2, 5.0, 40.0, 0.5, 0.8], + Self::OdeFull => &[1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0], Self::Analytical => &[1.0, 0.15, 25.0, 0.5, 0.8], + Self::AnalyticalFull => &[1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16], Self::Sde => &[1.1, 0.2, 0.12, 0.08, 15.0, 0.0], } } fn runtime_subject(self, model: &CompiledRuntimeModel) -> Result> { - let cp = model - .output_index("cp") + model + .info() + .outputs + .iter() + .find(|output| output.name == "cp") .ok_or_else(|| io::Error::other(format!("{}: missing cp output", self.label())))?; let subject = match self { Self::Ode => { - let oral = model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; - let iv = model.route_index("iv").ok_or_else(|| { - io::Error::other(format!("{}: missing iv route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") + .build() + } + Self::OdeFull => { + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "load") + .ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; + Subject::builder(self.label()) + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) .build() } Self::Analytical => { - let oral = model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + Subject::builder(self.label()) + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build() + } + Self::AnalyticalFull => { + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "load") + .ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; Subject::builder(self.label()) - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) .build() } Self::Sde => { - let oral = model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build() } }; @@ -193,9 +374,15 @@ impl CorpusCase { fn reference_predictions(self) -> Result> { match self { Self::Ode => Ok(ExpectedPredictions::Subject(reference_ode_predictions()?)), + Self::OdeFull => Ok(ExpectedPredictions::Subject( + reference_ode_full_predictions()?, + )), Self::Analytical => Ok(ExpectedPredictions::Subject( reference_analytical_predictions()?, )), + Self::AnalyticalFull => Ok(ExpectedPredictions::Subject( + reference_analytical_full_predictions()?, + )), Self::Sde => Ok(ExpectedPredictions::Particles(reference_sde_predictions()?)), } } @@ -605,6 +792,137 @@ fn reference_ode_predictions() -> Result> { )?) } +fn reference_ode_full_predictions() -> Result> { + Ok(equation::ODE::new( + |x, p, t, dx, bolus, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + kcp, + kpc, + _v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[0] = bolus[0] - ka * x[0]; + dx[1] = + bolus[1] + ka * x[0] + rateiv[0] - (adjusted_ke + adjusted_kcp) * x[1] + kpc * x[2]; + dx[2] = adjusted_kcp * x[1] - kpc * x[2]; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + _f_oral, + base_depot, + base_central, + base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_depot + 0.05 * wt; + x[1] = base_central + 0.1 * renal; + x[2] = base_peripheral + 0.02 * wt; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(3) + .with_ndrugs(2) + .with_nout(1) + .estimate_predictions( + &Subject::builder(CorpusCase::OdeFull.label()) + .bolus(0.0, 80.0, 1) + .bolus(1.0, 120.0, 0) + .infusion(6.0, 150.0, 0, 2.5) + .missing_observation(0.25, 0) + .missing_observation(0.75, 0) + .missing_observation(1.5, 0) + .missing_observation(3.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .missing_observation(12.0, 0) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build(), + CorpusCase::OdeFull.support_point(), + )?) +} + fn reference_analytical_predictions() -> Result> { Ok(equation::Analytical::new( one_compartment_with_absorption, @@ -638,6 +956,110 @@ fn reference_analytical_predictions() -> Result Result> { + Ok(equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(2) + .with_nout(1) + .estimate_predictions( + &Subject::builder(CorpusCase::AnalyticalFull.label()) + .bolus(0.0, 60.0, 1) + .bolus(1.0, 100.0, 0) + .infusion(6.0, 140.0, 0, 2.0) + .missing_observation(0.25, 0) + .missing_observation(0.75, 0) + .missing_observation(1.5, 0) + .missing_observation(3.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .missing_observation(12.0, 0) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build(), + CorpusCase::AnalyticalFull.support_point(), + )?) +} + fn reference_sde_predictions() -> Result, Box> { Ok(SDE::new( |x, p, _t, dx, _rateiv, _cov| { diff --git a/tests/test_pf.rs b/tests/test_pf.rs index 1f9482f2..3a195ff0 100644 --- a/tests/test_pf.rs +++ b/tests/test_pf.rs @@ -34,7 +34,7 @@ fn test_particle_filter_likelihood() { .with_nstates(2) .with_nout(1); - let ems = AssayErrorModels::new() + let ems = AssayErrorModels::default() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.5, 0.0, 0.0, 0.0), 0.0),