From c353134e25a0054392307e97498925eaa7610205 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 08:16:33 +0100 Subject: [PATCH 01/22] Initial pass --- .gitignore | 1 + README.md | 115 +- examples/analytical_readme.rs | 35 + examples/analytical_vs_ode.rs | 337 +- examples/compare_solvers.rs | 75 +- examples/covariates.rs | 79 +- examples/dsl_runtime_jit.rs | 2 +- examples/macro_vs_handwritten_one_cpt.rs | 110 + examples/macro_vs_handwritten_two_cpt.rs | 130 + examples/ode_readme.rs | 69 +- examples/one_compartment.rs | 99 +- examples/sde_readme.rs | 41 + examples/two_compartment.rs | 79 +- pharmsol-dsl/src/ast.rs | 11 +- pharmsol-dsl/src/authoring.rs | 53 +- pharmsol-dsl/src/execution.rs | 244 +- pharmsol-dsl/src/ir.rs | 5 +- pharmsol-dsl/src/parser.rs | 25 +- pharmsol-dsl/src/semantic.rs | 28 +- pharmsol-dsl/src/test_fixtures.rs | 4 +- .../tests/dsl_authoring_edge_cases.rs | 18 +- pharmsol-macros/Cargo.toml | 2 +- pharmsol-macros/src/lib.rs | 2732 +++++++++++++++-- src/dsl/compiled_backend_abi.rs | 2 + src/dsl/jit.rs | 87 +- src/dsl/model_info.rs | 150 +- src/dsl/native.rs | 107 +- src/dsl/rust_backend.rs | 2 +- src/dsl/wasm.rs | 8 +- src/dsl/wasm_compile.rs | 8 +- src/dsl/wasm_direct_emitter.rs | 2 +- src/lib.rs | 21 +- src/simulator/equation/analytical/mod.rs | 257 +- src/simulator/equation/meta.rs | 64 - src/simulator/equation/metadata.rs | 1211 ++++++++ src/simulator/equation/mod.rs | 5 +- src/simulator/equation/ode/mod.rs | 342 ++- src/simulator/equation/sde/mod.rs | 476 ++- src/test_fixtures.rs | 2 +- tests/analytical_macro_lowering.rs | 299 ++ tests/authoring_parity_corpus.rs | 1365 ++++++++ tests/browser-e2e/site/app.mjs | 4 +- tests/ode_macro_lowering.rs | 376 +++ tests/sde_macro_lowering.rs | 359 +++ tests/support/bimodal_ke.rs | 2 +- tests/support/runtime_corpus.rs | 8 +- 46 files changed, 8616 insertions(+), 835 deletions(-) create mode 100644 examples/analytical_readme.rs create mode 100644 examples/macro_vs_handwritten_one_cpt.rs create mode 100644 examples/macro_vs_handwritten_two_cpt.rs create mode 100644 examples/sde_readme.rs delete mode 100644 src/simulator/equation/meta.rs create mode 100644 src/simulator/equation/metadata.rs create mode 100644 tests/analytical_macro_lowering.rs create mode 100644 tests/authoring_parity_corpus.rs create mode 100644 tests/ode_macro_lowering.rs create mode 100644 tests/sde_macro_lowering.rs diff --git a/.gitignore b/.gitignore index a34b170b..5d7e9858 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ Cargo.lock /paper_files paper.html /tests/browser-e2e/node_modules/ +docs/ diff --git a/README.md b/README.md index b1620b8e..d24aea87 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ A high-performance Rust library for pharmacokinetic/pharmacodynamic (PK/PD) simu ## Installation -Add `pharmsol` to your `Cargo.toml`, either manually or using +Add `pharmsol` to `Cargo.toml`: ```bash cargo add pharmsol @@ -16,65 +16,75 @@ cargo add pharmsol ## Quick Start +Most Rust-first workflows start with one of the equation macros: `analytical!`, +`ode!`, or `sde!`. Here is a simple one-compartment IV infusion model using `analytical!`: + ```rust -use pharmsol::*; +use pharmsol::prelude::*; + +let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, +}; + +let iv = analytical.route_index("iv").unwrap(); +let cp = analytical.output_index("cp").unwrap(); -// Create a subject with an IV infusion and observations let subject = Subject::builder("patient_001") - .infusion(0.0, 500.0, 0, 0.5) // 500 units over 0.5 hours - .observation(0.5, 1.645, 0) - .observation(1.0, 1.216, 0) - .observation(2.0, 0.462, 0) - .observation(4.0, 0.063, 0) + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) .build(); -// Define parameters: ke (elimination rate), v (volume) -let ke = 1.022; -let v = 194.0; - -// Use the built-in one-compartment analytical solution -let analytical = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; // Concentration = Amount / Volume - }, - (1, 1), // (compartments, outputs) -); - -// Get predictions -let predictions = analytical.estimate_predictions(&subject, &vec![ke, v]).unwrap(); +let predictions = analytical + .estimate_predictions(&subject, &[1.022, 194.0]) + .unwrap(); ``` -## ODE-Based Models +## Modeling Surfaces -For custom or complex models, define your own ODEs: +Here is the same one-compartment IV setup written as an ODE: ```rust -use pharmsol::*; +use pharmsol::prelude::*; -let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, _v); - // One-compartment model with IV infusion support - dx[0] = -ke * x[0] + rateiv[0]; +let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - (1, 1), -); +}; ``` -## Supported Analytical Models +See [examples/analytical_readme.rs](examples/analytical_readme.rs), +[examples/ode_readme.rs](examples/ode_readme.rs), +[examples/sde_readme.rs](examples/sde_readme.rs), +[examples/analytical_vs_ode.rs](examples/analytical_vs_ode.rs), and +[examples/compare_solvers.rs](examples/compare_solvers.rs). For migration-oriented notes, +see [docs/analytical-authoring-migration.md](docs/analytical-authoring-migration.md) and +[docs/ode-authoring-migration.md](docs/ode-authoring-migration.md). + +### Built-In Analytical Kernels - [x] One-compartment with IV infusion - [x] One-compartment with IV infusion and oral absorption @@ -83,6 +93,21 @@ let ode = equation::ODE::new( - [x] Three-compartment with IV infusion - [x] Three-compartment with IV infusion and oral absorption +## DSL and Runtime Targets + +If the model needs to be loaded or compiled at runtime, pharmsol also provides a DSL with +the same broad modeling coverage: ODE, analytical, and SDE authoring. The DSL can target +an in-process JIT runtime, native ahead-of-time artifacts, or WASM bundles depending on +how you want to ship and execute the model. + +- `dsl-jit`: compile DSL source into a runtime model inside the current process. +- `dsl-aot` and `dsl-aot-load`: emit a native artifact and load it later. +- `dsl-wasm`: compile and execute portable WASM model artifacts. + +See [examples/dsl_runtime_jit.rs](examples/dsl_runtime_jit.rs) for the in-repo JIT flow. +The companion `pharmsol-examples` crate includes end-to-end native AOT and WASM runtime +examples. + ## Performance Analytical solutions provide 20-33× speedups compared to equivalent ODE formulations. See [benchmarks](benches/) for details. diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs new file mode 100644 index 00000000..fdcd5b1a --- /dev/null +++ b/examples/analytical_readme.rs @@ -0,0 +1,35 @@ +fn main() -> Result<(), pharmsol::PharmsolError> { + use pharmsol::prelude::*; + + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + + let subject = Subject::builder("analytical_readme") + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .build(); + + let predictions = analytical.estimate_predictions(&subject, &[1.022, 194.0])?; + + println!("times => {:?}", predictions.flat_times()); + println!("predictions => {:?}", predictions.flat_predictions()); + + Ok(()) +} \ No newline at end of file diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 97112f15..290d6632 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -4,6 +4,8 @@ //! two-compartment IV, two-compartment oral), this example runs both the //! closed-form analytical solution and the equivalent ODE, then prints //! the predictions side by side so you can verify they match. +//! Both authoring paths use the declaration-first macro surface so the +//! example stays on the preferred public authoring story. //! //! cargo run --release --example analytical_vs_ode @@ -11,29 +13,29 @@ use pharmsol::prelude::*; // ── Subjects ─────────────────────────────────────────────────────── -fn subject_iv() -> Subject { +fn subject_iv(input: usize, output: usize) -> Subject { Subject::builder("1") - .infusion(0.0, 500.0, 0, 0.5) - .observation(0.5, 0.0, 0) - .observation(1.0, 0.0, 0) - .observation(2.0, 0.0, 0) - .observation(4.0, 0.0, 0) - .observation(8.0, 0.0, 0) - .observation(12.0, 0.0, 0) - .observation(24.0, 0.0, 0) + .infusion(0.0, 500.0, input, 0.5) + .observation(0.5, 0.0, output) + .observation(1.0, 0.0, output) + .observation(2.0, 0.0, output) + .observation(4.0, 0.0, output) + .observation(8.0, 0.0, output) + .observation(12.0, 0.0, output) + .observation(24.0, 0.0, output) .build() } -fn subject_oral() -> Subject { +fn subject_oral(input: usize, output: usize) -> Subject { Subject::builder("1") - .bolus(0.0, 500.0, 0) - .observation(0.5, 0.0, 0) - .observation(1.0, 0.0, 0) - .observation(2.0, 0.0, 0) - .observation(4.0, 0.0, 0) - .observation(8.0, 0.0, 0) - .observation(12.0, 0.0, 0) - .observation(24.0, 0.0, 0) + .bolus(0.0, 500.0, input) + .observation(0.5, 0.0, output) + .observation(1.0, 0.0, output) + .observation(2.0, 0.0, output) + .observation(4.0, 0.0, output) + .observation(8.0, 0.0, output) + .observation(12.0, 0.0, output) + .observation(24.0, 0.0, output) .build() } @@ -64,168 +66,181 @@ fn print_comparison(label: &str, analytical: &SubjectPredictions, ode: &SubjectP // ── One-compartment IV ───────────────────────────────────────────── -fn one_cmt_iv(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(1) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, _v); - dx[0] = -ke * x[0] + rateiv[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(1) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn one_cmt_iv(params: &[f64]) { + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_iv(iv, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("One-compartment IV", &pred_a, &pred_o); } // ── One-compartment oral ─────────────────────────────────────────── -fn one_cmt_oral(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - one_compartment_with_absorption, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { - fetch_params!(p, ka, ke, _v); - dx[0] = -ka * x[0]; - dx[1] = ka * x[0] - ke * x[1]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn one_cmt_oral(params: &[f64]) { + let analytical = analytical! { + name: "one_cmt_oral", + params: [ka, ke, v], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + structure: one_compartment_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "one_cmt_oral", + params: [ka, ke, v], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let oral = analytical.route_index("oral").expect("oral route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_oral(oral, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("One-compartment oral", &pred_a, &pred_o); } // ── Two-compartment IV ───────────────────────────────────────────── -fn two_cmt_iv(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - two_compartments, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _k12, _k21, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, k12, k21, _v); - dx[0] = -ke * x[0] - k12 * x[0] + k21 * x[1] + rateiv[0]; - dx[1] = k12 * x[0] - k21 * x[1]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _k12, _k21, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn two_cmt_iv(params: &[f64]) { + let analytical = analytical! { + name: "two_cmt_iv", + params: [ke, k12, k21, v], + states: [central, peripheral], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: two_compartments, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "two_cmt_iv", + params: [ke, k12, k21, v], + states: [central, peripheral], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - k12 * x[central] + k21 * x[peripheral]; + dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_iv(iv, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("Two-compartment IV", &pred_a, &pred_o); } // ── Two-compartment oral ─────────────────────────────────────────── -fn two_cmt_oral(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - two_compartments_with_absorption, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _k12, _k21, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(3) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { - fetch_params!(p, ka, ke, k12, k21, _v); - dx[0] = -ka * x[0]; - dx[1] = ka * x[0] - ke * x[1] - k12 * x[1] + k21 * x[2]; - dx[2] = k12 * x[1] - k21 * x[2]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _k12, _k21, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(3) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn two_cmt_oral(params: &[f64]) { + let analytical = analytical! { + name: "two_cmt_oral", + params: [ka, ke, k12, k21, v], + states: [gut, central, peripheral], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + structure: two_compartments_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "two_cmt_oral", + params: [ka, ke, k12, k21, v], + states: [gut, central, peripheral], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central] - k12 * x[central] + k21 * x[peripheral]; + dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let oral = analytical.route_index("oral").expect("oral route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_oral(oral, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("Two-compartment oral", &pred_a, &pred_o); } // ── Main ─────────────────────────────────────────────────────────── fn main() { - let iv = subject_iv(); - let oral = subject_oral(); - - one_cmt_iv(&iv, &[0.1, 50.0]); // ke, v - one_cmt_oral(&oral, &[1.0, 0.1, 50.0]); // ka, ke, v - two_cmt_iv(&iv, &[0.1, 0.3, 0.2, 50.0]); // ke, k12, k21, v - two_cmt_oral(&oral, &[1.0, 0.1, 0.3, 0.2, 50.0]); // ka, ke, k12, k21, v + one_cmt_iv(&[0.1, 50.0]); // ke, v + one_cmt_oral(&[1.0, 0.1, 50.0]); // ka, ke, v + two_cmt_iv(&[0.1, 0.3, 0.2, 50.0]); // ke, k12, k21, v + two_cmt_oral(&[1.0, 0.1, 0.3, 0.2, 50.0]); // ka, ke, k12, k21, v } diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index 3a34b424..ebec4caa 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -1,4 +1,4 @@ -//! Shows how to select different ODE solvers for the same model. +//! Shows how to select different ODE solvers for the same declaration-first model. //! //! pharmsol wraps diffsol's solver families: //! @@ -14,19 +14,26 @@ use std::time::Instant; use pharmsol::prelude::*; // ── Model ────────────────────────────────────────────────────────── -// Two-compartment IV model. The solver is the only thing that changes -// between runs — the ODE, output equation and dimensions stay the same. +// Two-compartment IV model. The solver is the only thing that changes +// between runs; the declaration-first `ode!` surface and the generated +// metadata stay the same. fn two_cpt(solver: OdeSolver) -> equation::ODE { ode! { - diffeq: |x, p, _t, dx, b, rateiv, _cov| { - fetch_params!(p, ke, kcp, kpc, _v); - dx[0] = rateiv[0] + b[0] - ke * x[0] - kcp * x[0] + kpc * x[1]; - dx[1] = kcp * x[0] - kpc * x[1]; + name: "two_cpt", + params: [ke, kcp, kpc, v], + states: [central, peripheral], + outputs: [cp], + routes: { + bolus(load) -> central, + infusion(iv) -> central, }, - out: |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _kcp, _kpc, v); - y[0] = x[0] / v; + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, } .with_solver(solver) @@ -35,30 +42,42 @@ fn two_cpt(solver: OdeSolver) -> equation::ODE { // ── Main ─────────────────────────────────────────────────────────── fn main() { - let subject = Subject::builder("id1") - .bolus(0.0, 100.0, 0) - .infusion(12.0, 200.0, 0, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(4.0, 0) - .missing_observation(8.0, 0) - .missing_observation(12.0, 0) - .missing_observation(12.5, 0) - .missing_observation(13.0, 0) - .missing_observation(14.0, 0) - .missing_observation(16.0, 0) - .missing_observation(24.0, 0) - .build(); - - let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V - // Run each solver and collect predictions let bdf = two_cpt(OdeSolver::Bdf); let tsit45 = two_cpt(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let trbdf2 = two_cpt(OdeSolver::Sdirk(SdirkTableau::TrBdf2)); let esdirk34 = two_cpt(OdeSolver::Sdirk(SdirkTableau::Esdirk34)); + // Both declarations resolve to the same shared input channel, so subject + // authoring still uses one numeric index for the loading bolus and the + // maintenance infusion. + let load = bdf.route_index("load").expect("load route exists"); + let iv = bdf.route_index("iv").expect("iv route exists"); + let cp = bdf.output_index("cp").expect("cp output exists"); + + assert_eq!( + load, iv, + "mixed IV declarations should share one numeric channel" + ); + + let subject = Subject::builder("id1") + .bolus(0.0, 100.0, iv) + .infusion(12.0, 200.0, iv, 2.0) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .missing_observation(12.5, cp) + .missing_observation(13.0, cp) + .missing_observation(14.0, cp) + .missing_observation(16.0, cp) + .missing_observation(24.0, cp) + .build(); + + let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V + let results: Vec<(&str, equation::ODE)> = vec![ ("Bdf", bdf), ("Sdirk(TrBdf2)", trbdf2), diff --git a/examples/covariates.rs b/examples/covariates.rs index f9b97f29..0e85b9bf 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -1,61 +1,54 @@ fn main() { use pharmsol::prelude::*; - // Create a subject with a bolus dose, observations, and covariates - let subject = Subject::builder("id1") - // Administer a bolus dose of 100 units at time 0 - .bolus(0.0, 100.0, 0) - // Give two additional doses at 2-hour intervals - .repeat(2, 2.0) - .observation(0.5, 0.1, 0) - .observation(1.0, 0.4, 0) - .observation(2.0, 1.0, 0) - .observation(2.5, 1.1, 0) - // Creatinine covariate changes over time, with initial value of 80 at time 0 - .covariate("creatinine", 0.0, 80.0) - // New obseration of creatinine at time 6 hours - // The value will be linearly interpolated between time 0 and time 6 - .covariate("creatinine", 1.0, 40.0) - // For age, the covariate is constant over time, as there are no changes - .covariate("age", 0.0, 25.0) - .missing_observation(8.0, 0) - .build(); - - let ode = equation::ODE::new( - |x, p, t, dx, b, _rateiv, cov| { + let ode = ode! { + name: "one_cmt_covariates", + params: [ka, ke, tlag, v], + covariates: [creatinine, age], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _p, t, dx, cov| { // Macro to get the (possibly interpolated) covariate values at time `t` fetch_cov!(cov, t, creatinine, age); - // Macro to fetch parameter values from `p` - // Note the order must match the order in which parameters are defined later - fetch_params!(p, ka, ke, _tlag, _v); - let ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); + let scaled_ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); - //Struct - dx[0] = -ka * x[0] + b[0]; - dx[1] = ka * x[0] - ke * x[1]; + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - scaled_ke * x[central]; }, // This blocks defines the lag-time of the bolus dose - |p, _t, _cov| { - fetch_params!(p, _ka, _ke, tlag, _v); + lag: |_p, _t, _cov| { // Macro used to define the lag-time for the input of the bolus dose - lag! {0=>tlag} + lag! { oral => tlag } }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); - + out: |x, _p, _t, _cov, y| { // Define the predicted concentration as the amount in the central compartment divided by volume - y[0] = x[1] / v; + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_nout(1); + }; + + let oral = ode.route_index("oral").expect("oral route exists"); + let cp = ode.output_index("cp").expect("cp output exists"); + + // Create a subject with metadata-backed route and output names instead of + // hard-coded numeric indices. + let subject = Subject::builder("id1") + .bolus(0.0, 100.0, oral) + .repeat(2, 2.0) + .observation(0.5, 0.1, cp) + .observation(1.0, 0.4, cp) + .observation(2.0, 1.0, cp) + .observation(2.5, 1.1, cp) + .covariate("creatinine", 0.0, 80.0) + .covariate("creatinine", 1.0, 40.0) + .covariate("age", 0.0, 25.0) + .missing_observation(8.0, cp) + .build(); // Define parameter values - // Note that the order matters and should correspond to the order in which parameters are fetched in the model - // This is subject to change in future versions let ka = 1.0; // Absorption rate constant let ke = 0.2; // Elimination rate constant let tlag = 0.0; // Lag time diff --git a/examples/dsl_runtime_jit.rs b/examples/dsl_runtime_jit.rs index 655981bc..932acaae 100644 --- a/examples/dsl_runtime_jit.rs +++ b/examples/dsl_runtime_jit.rs @@ -8,7 +8,7 @@ fn main() -> Result<(), Box> { use pharmsol::prelude::*; let model_source = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs new file mode 100644 index 00000000..4d8f74d0 --- /dev/null +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -0,0 +1,110 @@ +//! Compares a declaration-first macro ODE with the equivalent handwritten ODE. +//! +//! This is the advanced comparison path for users who want to confirm that the +//! preferred macro surface and the low-level API produce the same metadata and +//! predictions on the same one-compartment IV problem. + +use pharmsol::prelude::*; + +fn macro_model() -> equation::ODE { + ode! { + name: "one_cpt_macro_parity", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_macro_parity") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten one-compartment metadata should validate") +} + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn main() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_model(); + let handwritten_ode = handwritten_model(); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let iv = macro_ode.route_index("iv").expect("iv route exists"); + let cp = macro_ode.output_index("cp").expect("cp output exists"); + + assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); + assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + + let subject = Subject::builder("macro-vs-handwritten-one-cpt") + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .missing_observation(8.0, cp) + .build(); + + let params = [1.022, 194.0]; + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let macro_flat = macro_predictions.flat_predictions(); + let handwritten_flat = handwritten_predictions.flat_predictions(); + let diff = max_abs_diff(¯o_flat, &handwritten_flat); + + assert!( + diff <= 1e-10, + "macro and handwritten one-compartment predictions diverged: {diff:e}" + ); + + println!("one-compartment parity max abs diff: {diff:e}"); + for ((time, macro_pred), handwritten_pred) in macro_predictions + .flat_times() + .iter() + .zip(macro_flat.iter()) + .zip(handwritten_flat.iter()) + { + println!("t={time:>4.1} macro={macro_pred:>12.8} handwritten={handwritten_pred:>12.8}"); + } + + Ok(()) +} diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs new file mode 100644 index 00000000..9ab1a675 --- /dev/null +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -0,0 +1,130 @@ +//! Compares a declaration-first macro ODE with the equivalent handwritten ODE +//! on a two-compartment IV problem that shares one numeric input channel across +//! a loading bolus and a maintenance infusion. +//! +//! This keeps the macro story as the default surface while showing the +//! low-level API as an explicit advanced comparison path. + +use pharmsol::prelude::*; + +fn macro_model() -> equation::ODE { + ode! { + name: "two_cpt_shared_channel_parity", + params: [ke, kcp, kpc, v], + states: [central, peripheral], + outputs: [cp], + routes: { + bolus(load) -> central, + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ke, kcp, kpc, _v); + dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; + dx[1] = kcp * x[0] - kpc * x[1]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, _kcp, _kpc, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("two_cpt_shared_channel_parity") + .parameters(["ke", "kcp", "kpc", "v"]) + .states(["central", "peripheral"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("load") + .to_state("central") + .inject_input_to_destination(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("handwritten two-compartment metadata should validate") +} + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn main() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_model(); + let handwritten_ode = handwritten_model(); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let load = macro_ode.route_index("load").expect("load route exists"); + let iv = macro_ode.route_index("iv").expect("iv route exists"); + let cp = macro_ode.output_index("cp").expect("cp output exists"); + + assert_eq!( + load, iv, + "load and iv should share one numeric input channel" + ); + assert_eq!(handwritten_ode.route_index("load"), Some(load)); + assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); + assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + + let subject = Subject::builder("macro-vs-handwritten-two-cpt") + .bolus(0.0, 100.0, load) + .infusion(12.0, 200.0, iv, 2.0) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .missing_observation(12.5, cp) + .missing_observation(13.0, cp) + .missing_observation(14.0, cp) + .missing_observation(16.0, cp) + .missing_observation(24.0, cp) + .build(); + + let params = [0.1, 0.05, 0.03, 50.0]; + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let macro_flat = macro_predictions.flat_predictions(); + let handwritten_flat = handwritten_predictions.flat_predictions(); + let diff = max_abs_diff(¯o_flat, &handwritten_flat); + + assert!( + diff <= 1e-10, + "macro and handwritten two-compartment predictions diverged: {diff:e}" + ); + + println!("two-compartment parity max abs diff: {diff:e}"); + for ((time, macro_pred), handwritten_pred) in macro_predictions + .flat_times() + .iter() + .zip(macro_flat.iter()) + .zip(handwritten_flat.iter()) + { + println!("t={time:>5.1} macro={macro_pred:>12.8} handwritten={handwritten_pred:>12.8}"); + } + + Ok(()) +} diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index 51765af1..a0174801 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -1,43 +1,40 @@ -fn main() { +fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - let subject = Subject::builder("id1") - .bolus(0.0, 100.0, 0) - .repeat(2, 0.5) - .observation(0.5, 0.1, 0) - .observation(1.0, 0.4, 0) - .observation(2.0, 1.0, 0) - .observation(2.5, 1.1, 0) - .covariate("wt", 0.0, 80.0) - .covariate("wt", 1.0, 83.0) - .covariate("age", 0.0, 25.0) - .build(); - println!("{subject}"); - let ode = equation::ODE::new( - |x, p, _t, dx, b, _rateiv, _cov| { - // fetch_cov!(cov, t,); - fetch_params!(p, ka, ke, _tlag, _v); - //Struct - dx[0] = -ka * x[0] + b[0]; - dx[1] = ka * x[0] - ke * x[1]; + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, }, - |p, _t, _cov| { - fetch_params!(p, _ka, _ke, tlag, _v); - lag! {0=>tlag} + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); - y[0] = x[1] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_ndrugs(5) - .with_nout(1); + }; + + let iv = ode.route_index("iv").expect("iv route exists"); + let cp = ode.output_index("cp").expect("cp output exists"); + + let subject = Subject::builder("id1") + .infusion(0.0, 100.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .build(); + + let predictions = ode.estimate_predictions(&subject, &[1.022, 194.0])?; + println!( + "state central => {}", + ode.state_index("central").expect("central state exists") + ); + println!("prediction times => {:?}", predictions.flat_times()); + println!("predictions => {:?}", predictions.flat_predictions()); - let op = ode - .estimate_predictions(&subject, &[0.3, 0.5, 0.1, 70.0]) - .unwrap(); - println!("{:#?}", op.flat_predictions()); + Ok(()) } diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index a66112b5..aafdf2b2 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -1,67 +1,58 @@ fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - // Create a subject using the builder pattern - let subject = Subject::builder("Nikola Tesla") - // An initial infusion of 500 units over 0.5 time units - .infusion(0., 500.0, 0, 0.5) - // Observations at various time points - .observation(0.5, 1.645, 0) - .observation(1., 1.216, 0) - .observation(2., 0.462, 0) - .observation(3., 0.169, 0) - .observation(4., 0.063, 0) - .observation(6., 0.009, 0) - .observation(8., 0.001, 0) - // A missing observation, to force the simulator to predict to this time point - // For missing observations, predictions are made but no likelihood contribution is computed - .missing_observation(12.0, 0) - // Build the subject - .build(); - - // Define the one-compartment analytical solution function - let an = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - // Calculate the output concentration, here defined as amount over volume - y[0] = x[0] / v; + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, }, - ) - .with_nstates(1) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - // Macro to fetch parameters from the parameter vector - // This exposes them as local variables - fetch_params!(p, ke, _v); + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; - // Define the ODE for the one-compartment model - // Note that rateiv is used to include infusion rates - dx[0] = -ke * x[0] + rateiv[0]; + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - // Calculate the output concentration, here defined as amount over volume - y[0] = x[0] / v; + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - ) - .with_nstates(1) - .with_nout(1); + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + + // Create a subject using metadata-backed route and output names instead of + // hard-coded indices. + let subject = Subject::builder("Nikola Tesla") + .infusion(0., 500.0, iv, 0.5) + .observation(0.5, 1.645, cp) + .observation(1., 1.216, cp) + .observation(2., 0.462, cp) + .observation(3., 0.169, cp) + .observation(4., 0.063, cp) + .observation(6., 0.009, cp) + .observation(8., 0.001, cp) + .missing_observation(12.0, cp) + .build(); // Define the error models for the observations let ems = AssayErrorModels::new(). // For this example, we use a simple additive error model with 5% error add( - 0, + cp, AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), )?; @@ -70,9 +61,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let v = 194.0; // Volume of distribution // Compute likelihoods and predictions for both models - let analytical_likelihoods = an.estimate_log_likelihood(&subject, &[ke, v], &ems)?; + let analytical_likelihoods = analytical.estimate_log_likelihood(&subject, &[ke, v], &ems)?; - let analytical_predictions = an.estimate_predictions(&subject, &[ke, v])?; + let analytical_predictions = analytical.estimate_predictions(&subject, &[ke, v])?; let ode_likelihoods = ode.estimate_log_likelihood(&subject, &[ke, v], &ems)?; diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs new file mode 100644 index 00000000..b3385bed --- /dev/null +++ b/examples/sde_readme.rs @@ -0,0 +1,41 @@ +fn main() -> Result<(), pharmsol::PharmsolError> { + use pharmsol::prelude::*; + + let sde = sde! { + name: "one_cmt_sde", + params: [ke, sigma_ke, v], + states: [central], + outputs: [cp], + particles: 16, + routes: { + infusion(iv) -> central, + }, + drift: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[central] = sigma_ke; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = sde.route_index("iv").expect("iv route exists"); + let cp = sde.output_index("cp").expect("cp output exists"); + + let subject = Subject::builder("sde_readme") + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .build(); + + let predictions = sde.estimate_predictions(&subject, &[1.022, 0.0, 194.0])?; + + println!("first prediction => {}", predictions[[0, 0]].prediction()); + println!("prediction grid shape => {:?}", predictions.dim()); + + Ok(()) +} \ No newline at end of file diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index d81aa1f8..5634704c 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -3,6 +3,9 @@ /// This example demonstrates how to implement a two-compartment pharmacokinetic model /// with weight-based covariate scaling using pharmsol. /// +/// It uses the declaration-first `ode!` surface so the route, covariate, +/// state, and output metadata stay aligned with the generated execution path. +/// /// The two-compartment model describes drug distribution between: /// - Central compartment (x[0]): where drug enters and is eliminated /// - Peripheral compartment (x[1]): a tissue compartment in equilibrium with central @@ -18,36 +21,21 @@ fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - // Create a subject using the builder pattern - let subject = Subject::builder("subject_001") - // An infusion of 500 mg over 0.5 hours (1000 mg/hr rate) - .infusion(0.0, 500.0, 0, 0.5) - // Weight covariate at baseline (85 kg reference weight) - .covariate("wt", 0.0, 70.0) - // Observations at various time points (concentration in mg/L) - .observation(0.5, 8.5, 0) - .observation(1.0, 6.2, 0) - .observation(2.0, 4.1, 0) - .observation(4.0, 2.3, 0) - .observation(6.0, 1.5, 0) - .observation(8.0, 1.1, 0) - .observation(12.0, 0.7, 0) - // Missing observation to force prediction at this time point - .missing_observation(24.0, 0) - .build(); - - // Define the two-compartment ODE model - let ode = equation::ODE::new( - // Primary differential equation block - |x, p, t, dx, _b, rateiv, cov| { + let ode = ode! { + name: "two_cmt_wt", + params: [cl, v, vp, q], + covariates: [wt], + states: [central, peripheral], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, t, dx, cov| { // Fetch the (possibly interpolated) weight covariate at time t fetch_cov!(cov, t, wt); - // Fetch parameters from the parameter vector // CL: Clearance (L/hr), V: Central volume (L) // Vp: Peripheral volume (L), Q: Inter-compartmental clearance (L/hr) - fetch_params!(p, cl, v, vp, q); - // Weight-based allometric scaling // Reference weight is 85 kg let wt_ratio = wt / 85.0; @@ -64,36 +52,43 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let kpc = q_scaled / vp_scaled; // Peripheral to central rate constant // Two-compartment model differential equations - // Central compartment: elimination + distribution + infusion input - dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0]; + // Central compartment: elimination + distribution + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; // Peripheral compartment: distribution equilibrium - dx[1] = kcp * x[0] - kpc * x[1]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; }, - // Lag time block (no lag in this model) - |_p, _t, _cov| lag! {}, - // Bioavailability block (100% for IV, so not needed) - |_p, _t, _cov| fa! {}, - // Secondary equations block (not used here) - |_p, _t, _cov, _x| {}, // Output equation block - calculates observed concentration - |x, p, t, cov, y| { + out: |x, _p, t, cov, y| { fetch_cov!(cov, t, wt); - fetch_params!(p, _cl, v, _vp, _q); // Calculate scaled volume for concentration let wt_ratio = wt / 85.0; let v_scaled = v * wt_ratio; // Concentration = Amount / Volume - y[0] = x[0] / v_scaled; + y[cp] = x[central] / v_scaled; }, - // Model dimensions: (number of compartments, number of outputs) - ) - .with_nstates(2) - .with_nout(1); + }; + + let iv = ode.route_index("iv").expect("iv route exists"); + let cp = ode.output_index("cp").expect("cp output exists"); + + // Create a subject using metadata-backed route and output names instead of + // hard-coded numeric indices. + let subject = Subject::builder("subject_001") + .infusion(0.0, 500.0, iv, 0.5) + .covariate("wt", 0.0, 70.0) + .observation(0.5, 8.5, cp) + .observation(1.0, 6.2, cp) + .observation(2.0, 4.1, cp) + .observation(4.0, 2.3, cp) + .observation(6.0, 1.5, cp) + .observation(8.0, 1.1, cp) + .observation(12.0, 0.7, cp) + .missing_observation(24.0, cp) + .build(); // Define parameter values - // Note: order must match the fetch_params! macro order let cl = 5.0; // Clearance (L/hr) let v = 50.0; // Central volume of distribution (L) let vp = 100.0; // Peripheral volume of distribution (L) diff --git a/pharmsol-dsl/src/ast.rs b/pharmsol-dsl/src/ast.rs index d43e7404..6cff4483 100644 --- a/pharmsol-dsl/src/ast.rs +++ b/pharmsol-dsl/src/ast.rs @@ -111,10 +111,17 @@ pub struct RoutesBlock { pub span: Span, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum RouteKind { + Bolus, + Infusion, +} + #[derive(Debug, Clone, PartialEq)] pub struct RouteDecl { pub input: Ident, pub destination: Place, + pub kind: Option, pub properties: Vec, pub span: Span, } @@ -141,7 +148,7 @@ pub struct StatementBlock { #[derive(Debug, Clone, PartialEq)] pub struct AnalyticalBlock { - pub kernel: Ident, + pub structure: Ident, pub span: Span, } @@ -491,7 +498,7 @@ fn write_analytical_block( indent(out, indent_level); writeln!(out, "analytical {{")?; indent(out, indent_level + 1); - writeln!(out, "kernel = {}", block.kernel.text)?; + writeln!(out, "structure = {}", block.structure.text)?; indent(out, indent_level); write!(out, "}}") } diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index db32b9c8..09bb309a 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -12,7 +12,7 @@ pub(super) fn parse_module(src: &str) -> Result { struct AuthoringParser<'a> { src: &'a str, - model_name: Option, + name: Option, explicit_kind: Option<(ModelKind, Span)>, parameters: Vec, constants: Vec, @@ -24,6 +24,7 @@ struct AuthoringParser<'a> { assigned_outputs: BTreeMap, declared_outputs_span: Option, routes: BTreeMap, + route_order: Vec, route_modifiers: BTreeMap>, derive_statements: Vec, derivative_statements: Vec, @@ -68,7 +69,7 @@ impl<'a> AuthoringParser<'a> { fn new(src: &'a str) -> Self { Self { src, - model_name: None, + name: None, explicit_kind: None, parameters: Vec::new(), constants: Vec::new(), @@ -80,6 +81,7 @@ impl<'a> AuthoringParser<'a> { assigned_outputs: BTreeMap::new(), declared_outputs_span: None, routes: BTreeMap::new(), + route_order: Vec::new(), route_modifiers: BTreeMap::new(), derive_statements: Vec::new(), derivative_statements: Vec::new(), @@ -132,11 +134,15 @@ impl<'a> AuthoringParser<'a> { } let surface_routes = std::mem::take(&mut self.routes); + let route_order = std::mem::take(&mut self.route_order); let mut route_modifiers = std::mem::take(&mut self.route_modifiers); let mut routes = Vec::with_capacity(surface_routes.len()); - for (route_name, route) in &surface_routes { + for route_name in route_order { + let Some(route) = surface_routes.get(&route_name) else { + continue; + }; let mut span = route.span; - let properties = route_modifiers.remove(route_name).unwrap_or_default(); + let properties = route_modifiers.remove(&route_name).unwrap_or_default(); if !properties.is_empty() { span = properties .iter() @@ -145,6 +151,10 @@ impl<'a> AuthoringParser<'a> { routes.push(RouteDecl { input: route.input.clone(), destination: route.destination.clone(), + kind: Some(match route.kind { + SurfaceRouteKind::Bolus => RouteKind::Bolus, + SurfaceRouteKind::Infusion => RouteKind::Infusion, + }), properties, span, }); @@ -169,7 +179,7 @@ impl<'a> AuthoringParser<'a> { inject_infusion_rates(&surface_routes, &routes, &mut derivative_statements); let name = self - .model_name + .name .unwrap_or_else(|| Ident::new(DEFAULT_MODEL_NAME, module_span)); let mut items = Vec::new(); @@ -298,9 +308,19 @@ impl<'a> AuthoringParser<'a> { if let Some(rest) = lhs_trimmed.strip_prefix("model") { if !rest.trim().is_empty() { - return Err(ParseError::new("expected `model = `", span)); + return Err(ParseError::new("expected `name = `", span)); } - self.model_name = Some(parse_ident_segment(rhs, rhs_abs)?); + return Err(ParseError::new( + "`model = ...` has been renamed to `name = ...`", + span, + )); + } + + if let Some(rest) = lhs_trimmed.strip_prefix("name") { + if !rest.trim().is_empty() { + return Err(ParseError::new("expected `name = `", span)); + } + self.name = Some(parse_ident_segment(rhs, rhs_abs)?); return Ok(()); } @@ -365,8 +385,15 @@ impl<'a> AuthoringParser<'a> { } if lhs_trimmed == "kernel" { - let kernel = parse_ident_segment(rhs, rhs_abs)?; - self.analytical = Some(AnalyticalBlock { span, kernel }); + return Err(ParseError::new( + "`kernel = ...` has been renamed to `structure = ...`", + span, + )); + } + + if lhs_trimmed == "structure" { + let structure = parse_ident_segment(rhs, rhs_abs)?; + self.analytical = Some(AnalyticalBlock { span, structure }); return Ok(()); } @@ -428,15 +455,16 @@ impl<'a> AuthoringParser<'a> { }; let input = parse_ident_segment(call.argument, call.argument_start)?; + let route_name = input.text.clone(); let destination = parse_place_at(rhs, line_start + arrow + 2)?; - if self.routes.contains_key(&input.text) { + if self.routes.contains_key(&route_name) { return Err(ParseError::new( format!("duplicate route `{}`", input.text), input.span, )); } self.routes.insert( - input.text.clone(), + route_name.clone(), SurfaceRoute { input, destination, @@ -444,6 +472,7 @@ impl<'a> AuthoringParser<'a> { span, }, ); + self.route_order.push(route_name); Ok(()) } @@ -589,7 +618,7 @@ impl<'a> AuthoringParser<'a> { if matches!(kind, ModelKind::Sde) { if let Some(analytical) = &self.analytical { return Err(ParseError::new( - "SDE authoring models cannot declare an analytical kernel", + "SDE authoring models cannot declare an analytical structure", analytical.span, )); } diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 26e0a3cb..0904bec2 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use crate::{ AnalyticalKernel, ConstValue, CovariateInterpolation, Diagnostic, DiagnosticPhase, - DiagnosticReport, MathIntrinsic, ModelKind, RoutePropertyKind, Span, Symbol, SymbolId, - SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr, + DiagnosticReport, MathIntrinsic, ModelKind, RouteKind, RoutePropertyKind, Span, Symbol, + SymbolId, SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr, TypedExprKind, TypedModel, TypedModule, TypedRangeExpr, TypedStatePlace, TypedStatementBlock, TypedStmt, TypedStmtKind, TypedUnaryOp, ValueType, DSL_LOWERING_GENERIC, }; @@ -98,7 +98,9 @@ pub struct ExecutionState { pub struct ExecutionRoute { pub symbol: SymbolId, pub name: String, + pub declaration_index: usize, pub index: usize, + pub kind: Option, pub destination: RouteDestination, pub has_lag: bool, pub has_bioavailability: bool, @@ -349,7 +351,7 @@ pub enum ExecutionLoad { State(ExecutionStateRef), Derived(usize), Local(usize), - RouteInput(usize), + RouteInput { route: SymbolId, index: usize }, } #[derive(Debug, Clone, PartialEq)] @@ -531,33 +533,67 @@ impl<'a> ExecutionLowerer<'a> { next_state_offset += len; } + let uses_authoring_route_kinds = + !model.routes.is_empty() && model.routes.iter().all(|route| route.kind.is_some()); let mut route_slots = BTreeMap::new(); - let routes = model - .routes - .iter() - .enumerate() - .map(|(index, route)| { - let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?; - route_slots.insert(route.symbol, index); - let destination = - lower_route_destination(&symbol_map, &state_slots, &route.destination)?; - Ok(ExecutionRoute { - symbol: route.symbol, - name: symbol.name.clone(), - index, - destination, - has_lag: route - .properties - .iter() - .any(|property| property.kind == RoutePropertyKind::Lag), - has_bioavailability: route - .properties - .iter() - .any(|property| property.kind == RoutePropertyKind::Bioavailability), - span: route.span, - }) - }) - .collect::, LoweringError>>()?; + let mut routes = Vec::with_capacity(model.routes.len()); + let mut next_bolus_index = 0usize; + let mut next_infusion_index = 0usize; + for (declaration_index, route) in model.routes.iter().enumerate() { + let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?; + if route.kind == Some(RouteKind::Infusion) { + for property in &route.properties { + let label = match property.kind { + RoutePropertyKind::Lag => "lag", + RoutePropertyKind::Bioavailability => "bioavailability", + }; + return Err(LoweringError::new( + format!( + "DSL authoring does not allow `{label}` on infusion route `{}`", + symbol.name + ), + property.span, + ) + .with_note("lag and bioavailability are bolus-only route properties")); + } + } + let index = if uses_authoring_route_kinds { + match route.kind.expect("authoring routes must preserve kind") { + RouteKind::Bolus => { + let index = next_bolus_index; + next_bolus_index += 1; + index + } + RouteKind::Infusion => { + let index = next_infusion_index; + next_infusion_index += 1; + index + } + } + } else { + declaration_index + }; + route_slots.insert(route.symbol, index); + let destination = + lower_route_destination(&symbol_map, &state_slots, &route.destination)?; + routes.push(ExecutionRoute { + symbol: route.symbol, + name: symbol.name.clone(), + declaration_index, + index, + kind: route.kind, + destination, + has_lag: route + .properties + .iter() + .any(|property| property.kind == RoutePropertyKind::Lag), + has_bioavailability: route + .properties + .iter() + .any(|property| property.kind == RoutePropertyKind::Bioavailability), + span: route.span, + }); + } let mut derived_slots = BTreeMap::new(); let derived = model @@ -607,7 +643,7 @@ impl<'a> ExecutionLowerer<'a> { analytical: model .analytical .as_ref() - .map(|analytical| analytical.kernel), + .map(|analytical| analytical.structure), }, symbol_map, parameter_slots, @@ -653,7 +689,7 @@ impl<'a> ExecutionLowerer<'a> { kernels.push(ExecutionKernel { role: KernelRole::Analytical, signature: signature_for(KernelRole::Analytical), - implementation: KernelImplementation::AnalyticalBuiltin(analytical.kernel), + implementation: KernelImplementation::AnalyticalBuiltin(analytical.structure), span: analytical.span, }); } @@ -745,7 +781,13 @@ impl<'a> ExecutionLowerer<'a> { }, route_buffer: DenseBufferLayout { kind: BufferKind::Routes, - len: self.metadata.routes.len(), + len: self + .metadata + .routes + .iter() + .map(|route| route.index + 1) + .max() + .unwrap_or(0), slots: self .metadata .routes @@ -858,7 +900,39 @@ impl<'a> ExecutionLowerer<'a> { let mut statements = Vec::with_capacity(self.model.routes.len()); let mut locals = LocalLowering::default(); + let default_value = match property_kind { + RoutePropertyKind::Lag => literal_real(0.0, self.model.span), + RoutePropertyKind::Bioavailability => literal_real(1.0, self.model.span), + }; + let route_len = self + .metadata + .routes + .iter() + .map(|route| route.index + 1) + .max() + .unwrap_or(0); + for route_index in 0..route_len { + let target_kind = match property_kind { + RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index), + RoutePropertyKind::Bioavailability => { + ExecutionTargetKind::RouteBioavailability(route_index) + } + }; + statements.push(ExecutionStmt { + kind: ExecutionStmtKind::Assign(ExecutionAssignStmt { + target: ExecutionTarget { + kind: target_kind, + span: self.model.span, + }, + value: default_value.clone(), + }), + span: self.model.span, + }); + } for route in &self.model.routes { + if route.kind == Some(RouteKind::Infusion) { + continue; + } let route_name = self.symbol_name(route.symbol)?.to_string(); let route_index = *self.route_slots.get(&route.symbol).ok_or_else(|| { LoweringError::new( @@ -872,8 +946,7 @@ impl<'a> ExecutionLowerer<'a> { .find(|property| property.kind == property_kind) { Some(property) => self.lower_expr(&property.value, &mut locals)?, - None if property_kind == RoutePropertyKind::Lag => literal_real(0.0, route.span), - None => literal_real(1.0, route.span), + None => continue, }; let target_kind = match property_kind { RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index), @@ -1098,7 +1171,10 @@ impl<'a> ExecutionLowerer<'a> { expr.span, ) })?; - ExecutionExprKind::Load(ExecutionLoad::RouteInput(route_index)) + ExecutionExprKind::Load(ExecutionLoad::RouteInput { + route: *route, + index: route_index, + }) } }, }; @@ -1439,6 +1515,100 @@ mod tests { ); } + #[test] + fn authoring_routes_share_channel_indices_by_kind_local_ordinal() { + let src = r#"name = shared_authoring +kind = ode + +params = ka, ke, v, tlag, f_oral +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let lowered = crate::lower_typed_model(&typed).expect("authoring model lowers"); + + assert_eq!(lowered.abi.route_buffer.len, 1); + assert_eq!(lowered.metadata.routes.len(), 2); + assert_eq!(lowered.metadata.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(lowered.metadata.routes[1].kind, Some(RouteKind::Infusion)); + assert_eq!(lowered.metadata.routes[0].declaration_index, 0); + assert_eq!(lowered.metadata.routes[1].declaration_index, 1); + assert_eq!(lowered.metadata.routes[0].index, 0); + assert_eq!(lowered.metadata.routes[1].index, 0); + assert!(lowered.metadata.routes[0].has_lag); + assert!(lowered.metadata.routes[0].has_bioavailability); + assert!(!lowered.metadata.routes[1].has_lag); + assert!(!lowered.metadata.routes[1].has_bioavailability); + } + + #[test] + fn authoring_routes_reject_infusion_lag_properties() { + let src = r#"name = invalid_infusion_lag +kind = ode + +params = ke, v, tlag +states = central +outputs = cp + +infusion(iv) -> central +lag(iv) = tlag + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let error = crate::lower_typed_model(&typed) + .err() + .expect("infusion lag should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `lag` on infusion route `iv`")); + } + + #[test] + fn authoring_routes_reject_infusion_bioavailability_properties() { + let src = r#"name = invalid_infusion_fa +kind = ode + +params = ke, v, f_iv +states = central +outputs = cp + +infusion(iv) -> central +fa(iv) = f_iv + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let error = crate::lower_typed_model(&typed) + .err() + .expect("infusion bioavailability should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `bioavailability` on infusion route `iv`")); + } + #[test] fn flattens_array_states_and_preserves_loop_structure() { let execution = structured_block_execution(); @@ -1538,8 +1708,8 @@ mod tests { panic!("expected statement bioavailability kernel"); }; - assert_eq!(lag_program.body.statements.len(), 2); - assert_eq!(bio_program.body.statements.len(), 2); + assert_eq!(lag_program.body.statements.len(), 3); + assert_eq!(bio_program.body.statements.len(), 3); assert!(matches!( lag_program.body.statements[1].kind, ExecutionStmtKind::Assign(ExecutionAssignStmt { diff --git a/pharmsol-dsl/src/ir.rs b/pharmsol-dsl/src/ir.rs index d1c54c90..5998431c 100644 --- a/pharmsol-dsl/src/ir.rs +++ b/pharmsol-dsl/src/ir.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::{ModelKind, Span}; +use crate::{ModelKind, RouteKind, Span}; pub type SymbolId = usize; @@ -145,6 +145,7 @@ pub struct TypedState { #[derive(Debug, Clone, PartialEq)] pub struct TypedRoute { pub symbol: SymbolId, + pub kind: Option, pub destination: TypedStatePlace, pub properties: Vec, pub span: Span, @@ -165,7 +166,7 @@ pub enum RoutePropertyKind { #[derive(Debug, Clone, PartialEq)] pub struct TypedAnalytical { - pub kernel: AnalyticalKernel, + pub structure: AnalyticalKernel, pub span: Span, } diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index 7af6c681..f07fbd50 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -607,6 +607,7 @@ impl Parser { Ok(RouteDecl { input: input.clone(), destination, + kind: None, properties, span: input.span.join(end_span), }) @@ -616,19 +617,19 @@ impl Parser { let start = self.bump().unwrap().span; let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; - let kernel_name = self.parse_ident()?; - if kernel_name.text != "kernel" { + let structure_name = self.parse_ident()?; + if structure_name.text != "structure" { return Err(ParseError::new( format!( - "expected `kernel = ` inside analytical block, found `{}`", - kernel_name.text + "expected `structure = ` inside analytical block, found `{}`", + structure_name.text ), - kernel_name.span, + structure_name.span, )); } let eq = self.expect_simple(|kind| matches!(kind, TokenKind::Eq), "`=`")?; - let kernel = self.parse_continuation_ident_after(&eq, "kernel identifier")?; + let structure = self.parse_continuation_ident_after(&eq, "structure identifier")?; self.consume_separators(); let end = self.expect_closing( |kind| matches!(kind, TokenKind::RBrace), @@ -637,7 +638,7 @@ impl Parser { "`analytical` block", )?; Ok(AnalyticalBlock { - kernel, + structure, span: start.join(end.span), }) } @@ -1635,14 +1636,14 @@ out(cp) = gut ~ continuous() #[test] fn authoring_output_annotation_is_optional() { let annotated = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 out(cp) = central ~ continuous() "#; let plain = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 @@ -1658,14 +1659,14 @@ out(cp) = central #[test] fn authoring_dx_and_ddt_lower_equivalently() { let dx_src = r#" -model = derivative_alias +name = derivative_alias kind = ode states = central dx(central) = -ke * central out(cp) = central "#; let ddt_src = r#" -model = derivative_alias +name = derivative_alias kind = ode states = central ddt(central) = -ke * central @@ -1681,7 +1682,7 @@ out(cp) = central #[test] fn authoring_rejects_out_target_not_in_declared_outputs() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index 9f46500a..b328934f 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -345,29 +345,29 @@ impl<'a> Analyzer<'a> { }; let analytical = if let Some(block) = sections.analytical { - let kernel = AnalyticalKernel::from_name(&block.kernel.text).ok_or_else(|| { + let structure = AnalyticalKernel::from_name(&block.structure.text).ok_or_else(|| { SemanticError::new( - format!("unknown analytical kernel `{}`", block.kernel.text), - block.kernel.span, + format!("unknown analytical structure `{}`", block.structure.text), + block.structure.span, ) })?; let state_components = states .iter() .map(|state| state.size.unwrap_or(1)) .sum::(); - if state_components != kernel.state_count() { + if state_components != structure.state_count() { return Err(SemanticError::new( format!( - "analytical kernel `{}` expects {} state value(s), but model declares {}", - block.kernel.text, - kernel.state_count(), + "analytical structure `{}` expects {} state value(s), but model declares {}", + block.structure.text, + structure.state_count(), state_components ), - block.kernel.span, + block.structure.span, )); } Some(TypedAnalytical { - kernel, + structure, span: block.span, }) } else { @@ -624,6 +624,7 @@ impl<'a> Analyzer<'a> { } routes.push(TypedRoute { symbol: id, + kind: route.kind, destination, properties, span: route.span, @@ -2652,6 +2653,7 @@ mod tests { RECOMMENDED_STYLE_AUTHORING, RECOMMENDED_STYLE_CANONICAL, STRUCTURED_BLOCK_CORPUS, }; use crate::{parse_model, parse_module}; + use crate::RouteKind; #[test] fn analyzes_structured_block_corpus() { @@ -2667,7 +2669,7 @@ mod tests { let analytical = &typed.models[2]; assert!(matches!( - analytical.analytical.as_ref().map(|value| value.kernel), + analytical.analytical.as_ref().map(|value| value.structure), Some(AnalyticalKernel::OneCompartmentWithAbsorption) )); @@ -2691,7 +2693,7 @@ mod tests { } #[test] - fn authoring_fixture_lowers_to_equivalent_typed_ir() { + fn authoring_fixture_preserves_route_kind_while_remaining_equivalent() { let authoring_surface = RECOMMENDED_STYLE_AUTHORING; let canonical = RECOMMENDED_STYLE_CANONICAL; @@ -2705,6 +2707,8 @@ mod tests { typed_model_signature(&authoring_typed), typed_model_signature(&canonical_typed) ); + assert_eq!(authoring_typed.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(canonical_typed.routes[0].kind, None); } #[test] @@ -2977,7 +2981,7 @@ model broken { lines.push(format!("particles:{:?}", model.particles)); lines.push(format!( "analytical:{:?}", - model.analytical.as_ref().map(|value| value.kernel) + model.analytical.as_ref().map(|value| value.structure) )); lines.push(format!( "derive:{}", diff --git a/pharmsol-dsl/src/test_fixtures.rs b/pharmsol-dsl/src/test_fixtures.rs index f26181e4..281a268e 100644 --- a/pharmsol-dsl/src/test_fixtures.rs +++ b/pharmsol-dsl/src/test_fixtures.rs @@ -83,7 +83,7 @@ model one_cmt_abs { oral -> depot } analytical { - kernel = one_compartment_with_absorption + structure = one_compartment_with_absorption } outputs { cp = central / v @@ -132,7 +132,7 @@ model vanco_sde { } "#; -pub(crate) const RECOMMENDED_STYLE_AUTHORING: &str = r#"model = recommended_style +pub(crate) const RECOMMENDED_STYLE_AUTHORING: &str = r#"name = recommended_style kind = ode params = ka, ke, v diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 941ffa77..797be3e9 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -3,14 +3,14 @@ use pharmsol_dsl::{analyze_model, parse_model, parse_module}; #[test] fn output_annotation_is_optional() { let annotated = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 out(cp) = central ~ continuous() "#; let plain = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 @@ -26,7 +26,7 @@ out(cp) = central #[test] fn dx_and_ddt_lower_equivalently() { let dx_src = r#" -model = derivative_alias +name = derivative_alias kind = ode params = ke states = central @@ -34,7 +34,7 @@ dx(central) = -ke * central out(cp) = central "#; let ddt_src = r#" -model = derivative_alias +name = derivative_alias kind = ode params = ke states = central @@ -51,7 +51,7 @@ out(cp) = central #[test] fn rejects_out_target_not_in_declared_outputs() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -95,7 +95,7 @@ out(cp) = central / v ~ continuous() #[test] fn rejects_out_target_not_in_declared_outputs_when_declared_later() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -122,7 +122,7 @@ ddt(central) = -ke * central #[test] fn rejects_declared_output_without_assignment() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -144,7 +144,7 @@ out(cp) = central / v #[test] fn rejects_unknown_output_annotation_name() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode states = central ddt(central) = 0 @@ -164,7 +164,7 @@ out(cp) = central ~ continous() #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v diff --git a/pharmsol-macros/Cargo.toml b/pharmsol-macros/Cargo.toml index 291b888c..d9fe58ec 100644 --- a/pharmsol-macros/Cargo.toml +++ b/pharmsol-macros/Cargo.toml @@ -13,4 +13,4 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.106" quote = "1.0.45" -syn = { version = "2.0.117", features = ["full"] } +syn = { version = "2.0.117", features = ["full", "visit"] } diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 0fa320a3..7607871a 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -4,11 +4,15 @@ //! `pharmsol` crate instead. use proc_macro::TokenStream; -use proc_macro2::TokenTree; +use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; +use std::collections::HashSet; use syn::{ - parse::{Parse, ParseStream}, - ExprClosure, Ident, Pat, Token, + parse::{Parse, ParseStream, Parser}, + punctuated::Punctuated, + token, + visit::Visit, + Expr, ExprClosure, Ident, LitStr, Pat, Stmt, Token, }; // --------------------------------------------------------------------------- @@ -16,6 +20,13 @@ use syn::{ // --------------------------------------------------------------------------- struct OdeInput { + name: LitStr, + params: Vec, + covariates: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + diffeq_mode: OdeDiffeqMode, diffeq: ExprClosure, lag: Option, fa: Option, @@ -23,45 +34,438 @@ struct OdeInput { out: ExprClosure, } +struct AnalyticalInput { + name: LitStr, + params: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + structure: Ident, + lag: Option, + fa: Option, + init: Option, + out: ExprClosure, +} + +struct SdeInput { + name: LitStr, + params: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + particles: Expr, + drift: ExprClosure, + diffusion: ExprClosure, + lag: Option, + fa: Option, + init: Option, + out: ExprClosure, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum OdeDiffeqMode { + InjectedRouteInputs, + ExplicitRouteVectors, +} + +struct OdeRouteDecl { + kind: OdeRouteKind, + input: Ident, + destination: Ident, +} + +#[derive(Clone, Copy)] +enum OdeRouteKind { + Bolus, + Infusion, +} + +struct AnalyticalKernelSpec { + runtime_path: TokenStream2, + metadata_kernel: TokenStream2, + parameter_arity: usize, + state_count: usize, +} + +struct RoutePropertyEntry { + route: Ident, + value: Expr, +} + +impl Parse for OdeRouteDecl { + fn parse(input: ParseStream) -> syn::Result { + let kind_ident: Ident = input.parse()?; + let kind = match kind_ident.to_string().as_str() { + "bolus" => OdeRouteKind::Bolus, + "infusion" => OdeRouteKind::Infusion, + other => { + return Err(syn::Error::new_spanned( + &kind_ident, + format!("unknown route kind `{other}`, expected `bolus` or `infusion`"), + )); + } + }; + + let content; + syn::parenthesized!(content in input); + let route_input: Ident = content.parse()?; + if !content.is_empty() { + return Err(content.error("expected a single route input name inside `(...)`")); + } + + if !input.peek(Token![->]) { + return Err( + input.error("expected `->` followed by a destination state in route declaration") + ); + } + input.parse::]>()?; + let destination: Ident = input.parse()?; + + if input.peek(token::Brace) { + return Err( + input.error("route properties are not supported in declaration-first `ode!` yet") + ); + } + + Ok(Self { + kind, + input: route_input, + destination, + }) + } +} + impl Parse for OdeInput { fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut covariates = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; let mut diffeq = None; let mut lag = None; - let mut fa_val = None; + let mut fa = None; + let mut init = None; + let mut out = None; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + + match key.to_string().as_str() { + "name" => set_once_ode(&mut name, input.parse()?, &key, "name")?, + "params" => set_once_ode(&mut params, parse_ident_list(input)?, &key, "params")?, + "covariates" => set_once_ode( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, + "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?, + "outputs" => set_once_ode(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?, + "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?, + "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_ode(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_ode(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_ode(&mut out, input.parse()?, &key, "out")?, + other => { + return Err(syn::Error::new_spanned( + &key, + format!( + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, diffeq, lag, fa, init, out" + ), + )); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + let name = name.ok_or_else(|| { + syn::Error::new( + Span::call_site(), + "declaration-first `ode!` requires `name`, `params`, `states`, `outputs`, and `routes`; the old inferred-dimensions form has been removed", + ) + })?; + let params = params.ok_or_else(|| missing_required_ode_field("params"))?; + let covariates = covariates.unwrap_or_default(); + let states = states.ok_or_else(|| missing_required_ode_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_ode_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?; + let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?; + let out = out.ok_or_else(|| missing_required_ode_field("out"))?; + let diffeq_mode = classify_diffeq_mode(&diffeq)?; + + validate_unique_idents("parameter", ¶ms, "ode!")?; + validate_unique_idents("covariate", &covariates, "ode!")?; + validate_unique_idents("state", &states, "ode!")?; + validate_unique_idents("output", &outputs, "ode!")?; + validate_routes(&routes, &states, "ode!")?; + validate_named_binding_compatibility( + ¶ms, + &states, + &outputs, + &routes, + &diffeq, + &out, + diffeq_mode, + )?; + + Ok(Self { + name, + params, + covariates, + states, + outputs, + routes, + diffeq_mode, + diffeq, + lag, + fa, + init, + out, + }) + } +} + +impl Parse for RoutePropertyEntry { + fn parse(input: ParseStream) -> syn::Result { + let route: Ident = input.parse()?; + input.parse::]>()?; + let value: Expr = input.parse()?; + Ok(Self { route, value }) + } +} + +impl Parse for AnalyticalInput { + fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; + let mut structure = None; + let mut lag = None; + let mut fa = None; + let mut init = None; + let mut out = None; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + + match key.to_string().as_str() { + "name" => set_once_analytical(&mut name, input.parse()?, &key, "name")?, + "params" => { + set_once_analytical(&mut params, parse_ident_list(input)?, &key, "params")? + } + "states" => { + set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")? + } + "outputs" => { + set_once_analytical(&mut outputs, parse_ident_list(input)?, &key, "outputs")? + } + "routes" => { + set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")? + } + "structure" => { + set_once_analytical(&mut structure, input.parse()?, &key, "structure")? + } + "lag" => set_once_analytical(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_analytical(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_analytical(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_analytical(&mut out, input.parse()?, &key, "out")?, + other => { + return Err(syn::Error::new_spanned( + &key, + format!( + "unknown field `{other}`, expected one of: name, params, states, outputs, routes, structure, lag, fa, init, out" + ), + )); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + let name = name.ok_or_else(|| missing_required_analytical_field("name"))?; + let params = params.ok_or_else(|| missing_required_analytical_field("params"))?; + let states = states.ok_or_else(|| missing_required_analytical_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_analytical_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_analytical_field("routes"))?; + let structure = structure.ok_or_else(|| missing_required_analytical_field("structure"))?; + let out = out.ok_or_else(|| missing_required_analytical_field("out"))?; + + validate_unique_idents("parameter", ¶ms, "analytical!")?; + validate_unique_idents("state", &states, "analytical!")?; + validate_unique_idents("output", &outputs, "analytical!")?; + validate_routes(&routes, &states, "analytical!")?; + + let kernel_spec = resolve_analytical_structure(&structure)?; + if params.len() < kernel_spec.parameter_arity { + return Err(syn::Error::new_spanned( + &structure, + format!( + "analytical structure `{}` requires at least {} parameter value(s), but `params` declares {}", + structure, kernel_spec.parameter_arity, params.len() + ), + )); + } + if states.len() != kernel_spec.state_count { + return Err(syn::Error::new_spanned( + &structure, + format!( + "analytical structure `{}` expects {} state value(s), but `states` declares {}", + structure, + kernel_spec.state_count, + states.len() + ), + )); + } + + validate_analytical_named_binding_compatibility( + ¶ms, + &states, + &outputs, + &routes, + lag.as_ref(), + fa.as_ref(), + init.as_ref(), + &out, + )?; + + if let Some(lag) = lag.as_ref() { + let lag_routes = + extract_route_property_routes("built-in `analytical!`", "lag", lag, &routes)?; + validate_route_property_kinds("built-in `analytical!`", "lag", &routes, &lag_routes)?; + } + + if let Some(fa) = fa.as_ref() { + let fa_routes = + extract_route_property_routes("built-in `analytical!`", "fa", fa, &routes)?; + validate_route_property_kinds("built-in `analytical!`", "fa", &routes, &fa_routes)?; + } + + Ok(Self { + name, + params, + states, + outputs, + routes, + structure, + lag, + fa, + init, + out, + }) + } +} + +impl Parse for SdeInput { + fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; + let mut particles = None; + let mut drift = None; + let mut diffusion = None; + let mut lag = None; + let mut fa = None; let mut init = None; let mut out = None; while !input.is_empty() { let key: Ident = input.parse()?; input.parse::()?; - let closure: ExprClosure = input.parse()?; match key.to_string().as_str() { - "diffeq" => diffeq = Some(closure), - "lag" => lag = Some(closure), - "fa" => fa_val = Some(closure), - "init" => init = Some(closure), - "out" => out = Some(closure), + "name" => set_once_sde(&mut name, input.parse()?, &key, "name")?, + "params" => set_once_sde(&mut params, parse_ident_list(input)?, &key, "params")?, + "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?, + "outputs" => set_once_sde(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?, + "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?, + "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?, + "diffusion" => set_once_sde(&mut diffusion, input.parse()?, &key, "diffusion")?, + "lag" => set_once_sde(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_sde(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_sde(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_sde(&mut out, input.parse()?, &key, "out")?, other => { return Err(syn::Error::new_spanned( &key, - format!("unknown field `{other}`, expected: diffeq, lag, fa, init, out"), + format!( + "unknown field `{other}`, expected one of: name, params, states, outputs, routes, particles, drift, diffusion, lag, fa, init, out" + ), )); } } - // optional trailing comma if !input.is_empty() { input.parse::()?; } } - Ok(OdeInput { - diffeq: diffeq.ok_or_else(|| input.error("missing required field `diffeq`"))?, + let name = name.ok_or_else(|| missing_required_sde_field("name"))?; + let params = params.ok_or_else(|| missing_required_sde_field("params"))?; + let states = states.ok_or_else(|| missing_required_sde_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_sde_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_sde_field("routes"))?; + let particles = particles.ok_or_else(|| missing_required_sde_field("particles"))?; + let drift = drift.ok_or_else(|| missing_required_sde_field("drift"))?; + let diffusion = diffusion.ok_or_else(|| missing_required_sde_field("diffusion"))?; + let out = out.ok_or_else(|| missing_required_sde_field("out"))?; + + validate_unique_idents("parameter", ¶ms, "sde!")?; + validate_unique_idents("state", &states, "sde!")?; + validate_unique_idents("output", &outputs, "sde!")?; + validate_routes(&routes, &states, "sde!")?; + validate_sde_named_binding_compatibility( + ¶ms, + &states, + &outputs, + &routes, + &drift, + &diffusion, + lag.as_ref(), + fa.as_ref(), + init.as_ref(), + &out, + )?; + + if let Some(lag) = lag.as_ref() { + let lag_routes = + extract_route_property_routes("declaration-first `sde!`", "lag", lag, &routes)?; + validate_route_property_kinds("declaration-first `sde!`", "lag", &routes, &lag_routes)?; + } + + if let Some(fa) = fa.as_ref() { + let fa_routes = + extract_route_property_routes("declaration-first `sde!`", "fa", fa, &routes)?; + validate_route_property_kinds("declaration-first `sde!`", "fa", &routes, &fa_routes)?; + } + + Ok(Self { + name, + params, + states, + outputs, + routes, + particles, + drift, + diffusion, lag, - fa: fa_val, + fa, init, - out: out.ok_or_else(|| input.error("missing required field `out`"))?, + out, }) } } @@ -70,7 +474,86 @@ impl Parse for OdeInput { // Helpers // --------------------------------------------------------------------------- -/// Return the identifier string for a closure parameter (empty for wildcards). +fn missing_required_ode_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in declaration-first `ode!`"), + ) +} + +fn missing_required_analytical_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in built-in `analytical!`"), + ) +} + +fn missing_required_sde_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in declaration-first `sde!`"), + ) +} + +fn set_once_ode(slot: &mut Option, value: T, key: &Ident, name: &str) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `ode!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn set_once_analytical( + slot: &mut Option, + value: T, + key: &Ident, + name: &str, +) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `analytical!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn set_once_sde(slot: &mut Option, value: T, key: &Ident, name: &str) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `sde!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn parse_ident_list(input: ParseStream) -> syn::Result> { + let content; + syn::bracketed!(content in input); + Ok(Punctuated::::parse_terminated(&content)? + .into_iter() + .collect()) +} + +fn parse_route_list(input: ParseStream) -> syn::Result> { + let content; + syn::braced!(content in input); + Ok( + Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(), + ) +} + fn param_name(pat: &Pat) -> String { match pat { Pat::Ident(p) => p.ident.to_string(), @@ -82,208 +565,2057 @@ fn closure_param_names(c: &ExprClosure) -> Vec { c.inputs.iter().map(param_name).collect() } -/// Recursively scan `tokens` for `ident[literal_int]` patterns where the -/// ident matches one of `names`. Returns the maximum literal integer found. -fn max_literal_index(tokens: proc_macro2::TokenStream, names: &[&str]) -> Option { - let tts: Vec = tokens.into_iter().collect(); - let mut best: Option = None; - - for (i, tt) in tts.iter().enumerate() { - match tt { - TokenTree::Ident(ident) => { - let s = ident.to_string(); - if names.contains(&s.as_str()) { - if let Some(TokenTree::Group(g)) = tts.get(i + 1) { - if g.delimiter() == proc_macro2::Delimiter::Bracket { - let inner: Vec = g.stream().into_iter().collect(); - if inner.len() == 1 { - if let TokenTree::Literal(lit) = &inner[0] { - if let Ok(n) = lit.to_string().parse::() { - best = Some(best.map_or(n, |m: usize| m.max(n))); - } - } - } - } - } +fn closure_param_ident(c: &ExprClosure, index: usize) -> Option { + c.inputs.get(index).and_then(|pat| match pat { + Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()), + _ => None, + }) +} + +fn generated_ident(name: &str) -> Ident { + Ident::new(name, Span::call_site()) +} + +#[derive(Default)] +struct ClosureBodyUsage { + idents: HashSet, + contains_macro: bool, +} + +impl ClosureBodyUsage { + fn analyze(expr: &Expr) -> Self { + let mut usage = Self::default(); + usage.visit_expr(expr); + usage + } + + fn uses(&self, ident: &Ident) -> bool { + self.contains_macro || self.idents.contains(&ident.to_string()) + } +} + +impl<'ast> Visit<'ast> for ClosureBodyUsage { + fn visit_expr_path(&mut self, expr_path: &'ast syn::ExprPath) { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 + { + self.idents + .insert(expr_path.path.segments[0].ident.to_string()); + } + + syn::visit::visit_expr_path(self, expr_path); + } + + fn visit_expr_macro(&mut self, expr_macro: &'ast syn::ExprMacro) { + self.contains_macro = true; + syn::visit::visit_expr_macro(self, expr_macro); + } + + fn visit_stmt_macro(&mut self, stmt_macro: &'ast syn::StmtMacro) { + self.contains_macro = true; + syn::visit::visit_stmt_macro(self, stmt_macro); + } +} + +fn generate_closure_input_aliases( + closure: &ExprClosure, + internal_names: &[Ident], +) -> syn::Result { + if closure.inputs.len() != internal_names.len() { + return Err(syn::Error::new_spanned( + closure, + "internal named binding generation error: closure arity mismatch", + )); + } + + let aliases = + closure + .inputs + .iter() + .zip(internal_names.iter()) + .map(|(pattern, internal_name)| { + quote! { + let #pattern = #internal_name; } + }); + + Ok(quote! { + #(#aliases)* + }) +} + +fn generate_parameter_bindings( + params: &[Ident], + closure: &ExprClosure, + parameter_vector: &Ident, +) -> TokenStream2 { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let bindings = params + .iter() + .enumerate() + .filter(|(_, ident)| usage.uses(ident)) + .map(|(index, ident)| { + quote! { + #[allow(unused_variables)] + let #ident = #parameter_vector[#index]; } - // recurse into brace / paren groups (bracket groups are indexing, handled above) - TokenTree::Group(g) - if matches!( - g.delimiter(), - proc_macro2::Delimiter::Brace | proc_macro2::Delimiter::Parenthesis - ) => - { - if let Some(n) = max_literal_index(g.stream(), names) { - best = Some(best.map_or(n, |m: usize| m.max(n))); + }); + + quote! { + #(#bindings)* + } +} + +fn classify_diffeq_mode(diffeq: &ExprClosure) -> syn::Result { + match closure_param_names(diffeq).len() { + 5 => Ok(OdeDiffeqMode::InjectedRouteInputs), + 7 => Ok(OdeDiffeqMode::ExplicitRouteVectors), + _ => Err(syn::Error::new_spanned( + diffeq, + "declaration-first `ode!` requires `diffeq` to have either 5 parameters: |x, p, t, dx, cov| or 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", + )), + } +} + +fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec { + routes.iter().map(|route| route.input.clone()).collect() +} + +fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { + let mut next_bolus_index = 0usize; + let mut next_infusion_index = 0usize; + + routes + .iter() + .map(|route| { + let index = match route.kind { + OdeRouteKind::Bolus => { + let index = next_bolus_index; + next_bolus_index += 1; + index } - } - _ => {} + OdeRouteKind::Infusion => { + let index = next_infusion_index; + next_infusion_index += 1; + index + } + }; + (route.input.clone(), index) + }) + .collect() +} + +fn dense_index_len(bindings: &[(Ident, usize)]) -> usize { + bindings + .iter() + .map(|(_, index)| index + 1) + .max() + .unwrap_or(0) +} + +fn validate_binding_conflicts( + left_label: &str, + left: &[Ident], + right_label: &str, + right: &[Ident], + context: &str, +) -> syn::Result<()> { + let right_names = right.iter().map(Ident::to_string).collect::>(); + + for ident in left { + let name = ident.to_string(); + if right_names.contains(&name) { + return Err(syn::Error::new_spanned( + ident, + format!( + "named {left_label} binding `{name}` conflicts with named {right_label} binding in {context}" + ), + )); } } - best + Ok(()) } -// --------------------------------------------------------------------------- -// Proc macro -// --------------------------------------------------------------------------- +fn validate_closure_param_conflicts( + closure_label: &str, + closure: &ExprClosure, + bindings: &[Ident], + binding_label: &str, +) -> syn::Result<()> { + let parameter_names = closure_param_names(closure) + .into_iter() + .filter(|name| !name.is_empty()) + .collect::>(); -/// Build an `equation::ODE` while **inferring** `nstates`, `ndrugs` and -/// `nout` from the maximum literal bracket-indices used in the closures. -/// -/// # Fields (any order, comma-separated) -/// -/// | Field | Required | Signature | -/// |----------|----------|-------------------------------------------------| -/// | `diffeq` | **yes** | `\|x, p, t, dx, bolus, rateiv, cov\| { … }` | -/// | `out` | **yes** | `\|x, p, t, cov, y\| { … }` | -/// | `init` | no | `\|p, t, cov, x\| { … }` | -/// | `lag` | no | `\|p, t, cov\| lag! { … }` | -/// | `fa` | no | `\|p, t, cov\| fa! { … }` | -/// -/// # Inference rules -/// -/// * **nstates** = max literal index of the state / derivative vectors + 1 -/// * **ndrugs** = max literal index of bolus / rateiv vectors + 1 -/// * **nout** = max literal index of the output vector + 1 -/// -/// Parameter names are taken from the closure signatures so you can name them -/// however you like. Only **literal** integer indices (e.g. `x[2]`) are -/// detected; computed indices require manual `.with_nstates()` etc. -/// -/// # Example -/// -/// ```ignore -/// use pharmsol::prelude::*; -/// -/// let ode = ode! { -/// diffeq: |x, p, _t, dx, b, rateiv, _cov| { -/// fetch_params!(p, ke, kcp, kpc, _v); -/// dx[0] = rateiv[0] + b[0] - ke * x[0] - kcp * x[0] + kpc * x[1]; -/// dx[1] = kcp * x[0] - kpc * x[1]; -/// }, -/// out: |x, p, _t, _cov, y| { -/// fetch_params!(p, _ke, _kcp, _kpc, v); -/// y[0] = x[0] / v; -/// }, -/// }; -/// // Inferred: nstates=2, ndrugs=1, nout=1 -/// ``` -#[proc_macro] -pub fn ode(input: TokenStream) -> TokenStream { - let input = syn::parse_macro_input!(input as OdeInput); + for ident in bindings { + let name = ident.to_string(); + if parameter_names.contains(&name) { + return Err(syn::Error::new_spanned( + ident, + format!( + "named {binding_label} binding `{name}` conflicts with `{closure_label}` closure parameter `{name}`" + ), + )); + } + } - // ── Validate parameter counts ──────────────────────────────── - let de_params = closure_param_names(&input.diffeq); - if de_params.len() != 7 { - return syn::Error::new_spanned( - &input.diffeq, - "diffeq closure must have 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", - ) - .to_compile_error() - .into(); + Ok(()) +} + +fn validate_named_binding_compatibility( + params: &[Ident], + states: &[Ident], + outputs: &[Ident], + routes: &[OdeRouteDecl], + diffeq: &ExprClosure, + out: &ExprClosure, + diffeq_mode: OdeDiffeqMode, +) -> syn::Result<()> { + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`diffeq` and `out` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`out` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`out` named binding generation", + )?; + + validate_closure_param_conflicts("diffeq", diffeq, params, "parameter")?; + validate_closure_param_conflicts("diffeq", diffeq, states, "state")?; + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + if diffeq_mode == OdeDiffeqMode::ExplicitRouteVectors { + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`diffeq` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`diffeq` named binding generation", + )?; + validate_closure_param_conflicts("diffeq", diffeq, &route_inputs, "route")?; } - let out_params = closure_param_names(&input.out); - if out_params.len() != 5 { - return syn::Error::new_spanned( - &input.out, - "out closure must have 5 parameters: |x, p, t, cov, y|", - ) - .to_compile_error() - .into(); + Ok(()) +} + +fn validate_analytical_named_binding_compatibility( + params: &[Ident], + states: &[Ident], + outputs: &[Ident], + routes: &[OdeRouteDecl], + lag: Option<&ExprClosure>, + fa: Option<&ExprClosure>, + init: Option<&ExprClosure>, + out: &ExprClosure, +) -> syn::Result<()> { + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "output", + outputs, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + + if let Some(lag) = lag { + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; } - // ── Collect names by role ──────────────────────────────────── - // diffeq positions: 0=x 3=dx 4=bolus 5=rateiv - // out positions: 0=x 4=y - // init positions: 3=x - let mut state_names: Vec = vec![ - de_params[0].clone(), - de_params[3].clone(), - out_params[0].clone(), - ]; - if let Some(ref ic) = input.init { - let ip = closure_param_names(ic); - if ip.len() >= 4 { - state_names.push(ip[3].clone()); - } + if let Some(fa) = fa { + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; } - state_names.sort(); - state_names.dedup(); - let drug_names = [de_params[4].clone(), de_params[5].clone()]; - let output_names = [out_params[4].clone()]; + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } - // filter empties (from wildcard `_` params) - let state_refs: Vec<&str> = state_names - .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - let drug_refs: Vec<&str> = drug_names - .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - let output_refs: Vec<&str> = output_names - .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - - // ── Scan closure bodies ────────────────────────────────────── - let de_tokens = input.diffeq.body.to_token_stream(); - let out_tokens = input.out.body.to_token_stream(); - let init_tokens = input.init.as_ref().map(|c| c.body.to_token_stream()); - - let max_state = [ - max_literal_index(de_tokens.clone(), &state_refs), - max_literal_index(out_tokens.clone(), &state_refs), - init_tokens.and_then(|t| max_literal_index(t, &state_refs)), - ] - .into_iter() - .flatten() - .max(); - - let max_drug = max_literal_index(de_tokens, &drug_refs); - let max_out = max_literal_index(out_tokens, &output_refs); - - let nstates = max_state.map_or(1, |n| n + 1); - let ndrugs = max_drug.map_or(1, |n| n + 1); - let nout = max_out.map_or(1, |n| n + 1); - - // ── Generate output ────────────────────────────────────────── - let diffeq = &input.diffeq; - let out = &input.out; - - let lag = input.lag.as_ref().map_or_else( - || quote! { |_, _, _| ::std::collections::HashMap::new() }, - |c| quote! { #c }, - ); - - let fa = input.fa.as_ref().map_or_else( - || quote! { |_, _, _| ::std::collections::HashMap::new() }, - |c| quote! { #c }, - ); + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; - let init = input - .init - .as_ref() - .map_or_else(|| quote! { |_, _, _, _| {} }, |c| quote! { #c }); + Ok(()) +} + +fn validate_sde_named_binding_compatibility( + params: &[Ident], + states: &[Ident], + outputs: &[Ident], + routes: &[OdeRouteDecl], + drift: &ExprClosure, + diffusion: &ExprClosure, + lag: Option<&ExprClosure>, + fa: Option<&ExprClosure>, + init: Option<&ExprClosure>, + out: &ExprClosure, +) -> syn::Result<()> { + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "output", + outputs, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + + validate_closure_param_conflicts("drift", drift, params, "parameter")?; + validate_closure_param_conflicts("drift", drift, states, "state")?; + validate_closure_param_conflicts("diffusion", diffusion, params, "parameter")?; + validate_closure_param_conflicts("diffusion", diffusion, states, "state")?; + + if let Some(lag) = lag { + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + Ok(()) +} + +fn generate_index_consts(idents: &[Ident]) -> TokenStream2 { + let bindings = idents.iter().enumerate().map(|(index, ident)| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }); quote! { - equation::ODE::new( - #diffeq, - #lag, - #fa, - #init, - #out, - ) - .with_nstates(#nstates) - .with_ndrugs(#ndrugs) - .with_nout(#nout) + #(#bindings)* + } +} + +fn generate_mapped_index_consts(bindings: &[(Ident, usize)]) -> TokenStream2 { + let bindings = bindings.iter().map(|(ident, index)| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }); + + quote! { + #(#bindings)* + } +} + +fn expand_out( + out: &ExprClosure, + params: &[Ident], + states: &[Ident], + outputs: &[Ident], +) -> syn::Result { + if closure_param_names(out).len() != 5 { + return Err(syn::Error::new_spanned( + out, + "declaration-first `ode!` requires `out` to have 5 parameters: |x, p, t, cov, y|", + )); + } + + let state_consts = generate_index_consts(states); + let output_consts = generate_index_consts(outputs); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let input_aliases = generate_closure_input_aliases( + out, + &[x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()], + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let body = &out.body; + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #body + }; + __pharmsol_out + }}) +} + +fn route_property_error(macro_name: &str, label: &str, node: T) -> syn::Error { + syn::Error::new_spanned( + node, + format!( + "{macro_name} requires `{label}` to return `{label}! {{ ... }}` so route-property metadata can be synthesized" + ), + ) +} + +fn find_terminal_macro_invocation( + macro_name: &str, + label: &str, + closure: &ExprClosure, +) -> syn::Result { + match closure.body.as_ref() { + Expr::Macro(expr_macro) if expr_macro.mac.path.is_ident(label) => { + Ok(expr_macro.mac.clone()) + } + Expr::Macro(expr_macro) => Err(route_property_error(macro_name, label, expr_macro)), + Expr::Block(expr_block) => { + for stmt in expr_block.block.stmts.iter().rev() { + match stmt { + Stmt::Expr(Expr::Macro(expr_macro), _) + if expr_macro.mac.path.is_ident(label) => + { + return Ok(expr_macro.mac.clone()); + } + Stmt::Expr(Expr::Macro(expr_macro), _) => { + return Err(route_property_error(macro_name, label, expr_macro)); + } + Stmt::Expr(other, _) => { + return Err(route_property_error(macro_name, label, other)); + } + Stmt::Macro(stmt_macro) if stmt_macro.mac.path.is_ident(label) => { + return Ok(stmt_macro.mac.clone()); + } + Stmt::Macro(stmt_macro) => { + return Err(route_property_error(macro_name, label, stmt_macro)); + } + _ => continue, + } + } + + Err(route_property_error(macro_name, label, expr_block)) + } + other => Err(route_property_error(macro_name, label, other)), + } +} + +fn extract_route_property_routes( + macro_name: &str, + label: &str, + closure: &ExprClosure, + routes: &[OdeRouteDecl], +) -> syn::Result> { + let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?; + let entries = Punctuated::::parse_terminated + .parse2(macro_expr.tokens.clone())?; + let known_routes = route_input_idents(routes) + .into_iter() + .map(|route| route.to_string()) + .collect::>(); + let mut seen = HashSet::new(); + + for entry in entries { + let route_name = entry.route.to_string(); + if !known_routes.contains(&route_name) { + return Err(syn::Error::new_spanned( + &entry.route, + format!( + "route `{route_name}` in `{label}!` is not declared in the `routes` section" + ), + )); + } + if !seen.insert(route_name.clone()) { + return Err(syn::Error::new_spanned( + &entry.route, + format!("duplicate route `{route_name}` in `{label}!`"), + )); + } + let _ = entry.value; + } + + Ok(seen) +} + +fn validate_route_property_kinds( + macro_name: &str, + label: &str, + routes: &[OdeRouteDecl], + property_routes: &HashSet, +) -> syn::Result<()> { + for route in routes { + if property_routes.contains(&route.input.to_string()) + && matches!(route.kind, OdeRouteKind::Infusion) + { + return Err(syn::Error::new_spanned( + &route.input, + format!( + "{macro_name} does not allow `{label}` on infusion route `{}`", + route.input + ), + )); + } + } + + Ok(()) +} + +fn expand_ode_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + if closure_param_names(closure).len() != 3 { + return Err(syn::Error::new_spanned( + closure, + format!( + "declaration-first `ode!` requires `{label}` to have 3 parameters: |p, t, cov|" + ), + )); + } + + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let input_aliases = + generate_closure_input_aliases(closure, &[p.clone(), t.clone(), cov.clone()])?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let body = &closure.body; + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_route_metadata( + routes: &[OdeRouteDecl], + diffeq_mode: OdeDiffeqMode, + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.to_string(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let input_policy = match diffeq_mode { + OdeDiffeqMode::InjectedRouteInputs => quote! { .inject_input_to_destination() }, + OdeDiffeqMode::ExplicitRouteVectors => quote! { .expect_explicit_input() }, + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + #lag_flag + #fa_flag + #input_policy + } + }) + .collect() +} + +fn expand_analytical_route_metadata( + routes: &[OdeRouteDecl], + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.to_string(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + #lag_flag + #fa_flag + } + }) + .collect() +} + +fn expand_sde_route_metadata( + routes: &[OdeRouteDecl], + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.to_string(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + .inject_input_to_destination() + #lag_flag + #fa_flag + } + }) + .collect() +} + +fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize { + states + .iter() + .position(|state| state == &route.destination) + .expect("validated route destination should exist") +} + +fn expand_injected_ode_route_terms( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(Ident, usize)], + dx: &Ident, + bolus: &Ident, + rateiv: &Ident, +) -> TokenStream2 { + let terms = routes + .iter() + .zip(route_bindings.iter()) + .map(|(route, (_, channel_index))| { + let destination = route_destination_index(route, states); + match route.kind { + OdeRouteKind::Bolus => quote! { + #dx[#destination] += #bolus[#channel_index]; + }, + OdeRouteKind::Infusion => quote! { + #dx[#destination] += #rateiv[#channel_index]; + }, + } + }); + + quote! { + #(#terms)* + } +} + +fn expand_injected_sde_rate_terms( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(Ident, usize)], + dx: &Ident, + rateiv: &Ident, +) -> TokenStream2 { + let terms = + routes + .iter() + .zip(route_bindings.iter()) + .filter_map(|(route, (_, channel_index))| match route.kind { + OdeRouteKind::Bolus => None, + OdeRouteKind::Infusion => { + let destination = route_destination_index(route, states); + Some(quote! { + #dx[#destination] += #rateiv[#channel_index]; + }) + } + }); + + quote! { + #(#terms)* + } +} + +fn expand_injected_sde_bolus_mappings( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(Ident, usize)], +) -> TokenStream2 { + let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; + + for (route, (_, channel_index)) in routes.iter().zip(route_bindings.iter()) { + if let OdeRouteKind::Bolus = route.kind { + let destination = route_destination_index(route, states); + destinations[*channel_index] = quote! { Some(#destination) }; + } + } + + quote! { + .with_injected_bolus_inputs(&[#(#destinations),*]) + } +} + +fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn::Result<()> { + let mut seen = HashSet::new(); + for ident in idents { + let name = ident.to_string(); + if !seen.insert(name.clone()) { + return Err(syn::Error::new_spanned( + ident, + format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"), + )); + } + } + Ok(()) +} + +fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> { + let known_states = states.iter().map(Ident::to_string).collect::>(); + let mut seen_routes = HashSet::new(); + + for route in routes { + let route_name = route.input.to_string(); + if !seen_routes.insert(route_name.clone()) { + return Err(syn::Error::new_spanned( + &route.input, + format!("duplicate route `{route_name}` in declaration-first `{macro_name}`"), + )); + } + + if !known_states.contains(&route.destination.to_string()) { + return Err(syn::Error::new_spanned( + &route.destination, + format!( + "route destination `{}` is not declared in the `states` section", + route.destination + ), + )); + } + } + + Ok(()) +} + +fn expand_diffeq( + diffeq: &ExprClosure, + params: &[Ident], + states: &[Ident], + routes: &[OdeRouteDecl], + route_bindings: &[(Ident, usize)], + diffeq_mode: OdeDiffeqMode, +) -> syn::Result { + let state_consts = generate_index_consts(states); + + match diffeq_mode { + OdeDiffeqMode::ExplicitRouteVectors => { + let route_consts = generate_mapped_index_consts(route_bindings); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let input_aliases = generate_closure_input_aliases( + diffeq, + &[ + x.clone(), + p.clone(), + t.clone(), + dx.clone(), + bolus.clone(), + rateiv.clone(), + cov.clone(), + ], + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let body = &diffeq.body; + + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #route_consts + #parameter_bindings + #body + }; + __pharmsol_diffeq + }}) + } + OdeDiffeqMode::InjectedRouteInputs => { + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let input_aliases = generate_closure_input_aliases( + diffeq, + &[x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()], + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let body = &diffeq.body; + let dx_binding = closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()); + let route_terms = expand_injected_ode_route_terms( + routes, + states, + route_bindings, + &dx_binding, + &bolus, + &rateiv, + ); + + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #body + #route_terms + }; + __pharmsol_diffeq + }}) + } + } +} + +fn resolve_analytical_structure(structure: &Ident) -> syn::Result { + let structure_name = structure.to_string(); + let (runtime_path, metadata_kernel, parameter_arity, state_count) = match structure_name + .as_str() + { + "one_compartment" => ( + quote! { ::pharmsol::equation::one_compartment }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartment }, + 1, + 1, + ), + "one_compartment_cl" => ( + quote! { ::pharmsol::equation::one_compartment_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentCl }, + 2, + 1, + ), + "one_compartment_cl_with_absorption" => ( + quote! { ::pharmsol::equation::one_compartment_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentClWithAbsorption }, + 3, + 2, + ), + "one_compartment_with_absorption" => ( + quote! { ::pharmsol::equation::one_compartment_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentWithAbsorption }, + 2, + 2, + ), + "two_compartments" => ( + quote! { ::pharmsol::equation::two_compartments }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartments }, + 3, + 2, + ), + "two_compartments_cl" => ( + quote! { ::pharmsol::equation::two_compartments_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsCl }, + 4, + 2, + ), + "two_compartments_cl_with_absorption" => ( + quote! { ::pharmsol::equation::two_compartments_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsClWithAbsorption }, + 5, + 3, + ), + "two_compartments_with_absorption" => ( + quote! { ::pharmsol::equation::two_compartments_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsWithAbsorption }, + 4, + 3, + ), + "three_compartments" => ( + quote! { ::pharmsol::equation::three_compartments }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartments }, + 5, + 3, + ), + "three_compartments_cl" => ( + quote! { ::pharmsol::equation::three_compartments_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsCl }, + 6, + 3, + ), + "three_compartments_cl_with_absorption" => ( + quote! { ::pharmsol::equation::three_compartments_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsClWithAbsorption }, + 7, + 4, + ), + "three_compartments_with_absorption" => ( + quote! { ::pharmsol::equation::three_compartments_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsWithAbsorption }, + 6, + 4, + ), + _ => { + return Err(syn::Error::new_spanned( + structure, + format!("unknown analytical structure `{structure_name}`"), + )); + } + }; + + Ok(AnalyticalKernelSpec { + runtime_path, + metadata_kernel, + parameter_arity, + state_count, + }) +} + +fn expand_analytical_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + if closure_param_names(closure).len() != 3 { + return Err(syn::Error::new_spanned( + closure, + format!("built-in `analytical!` requires `{label}` to have 3 parameters: |p, t, cov|"), + )); + } + + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let input_aliases = + generate_closure_input_aliases(closure, &[p.clone(), t.clone(), cov.clone()])?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let body = &closure.body; + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_analytical_init( + init: &ExprClosure, + params: &[Ident], + states: &[Ident], +) -> syn::Result { + if closure_param_names(init).len() != 4 { + return Err(syn::Error::new_spanned( + init, + "built-in `analytical!` requires `init` to have 4 parameters: |p, t, cov, x|", + )); + } + + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let input_aliases = + generate_closure_input_aliases(init, &[p.clone(), t.clone(), cov.clone(), x.clone()])?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_analytical_out( + out: &ExprClosure, + params: &[Ident], + states: &[Ident], + outputs: &[Ident], +) -> syn::Result { + if closure_param_names(out).len() != 5 { + return Err(syn::Error::new_spanned( + out, + "built-in `analytical!` requires `out` to have 5 parameters: |x, p, t, cov, y|", + )); + } + + let state_consts = generate_index_consts(states); + let output_consts = generate_index_consts(outputs); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let input_aliases = generate_closure_input_aliases( + out, + &[x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()], + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let body = &out.body; + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #body + }; + __pharmsol_out + }}) +} + +fn expand_sde_drift( + drift: &ExprClosure, + params: &[Ident], + states: &[Ident], + routes: &[OdeRouteDecl], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + if closure_param_names(drift).len() != 5 { + return Err(syn::Error::new_spanned( + drift, + "declaration-first `sde!` requires `drift` to have 5 parameters: |x, p, t, dx, cov|", + )); + } + + let state_consts = generate_index_consts(states); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let input_aliases = generate_closure_input_aliases( + drift, + &[x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()], + )?; + let parameter_bindings = generate_parameter_bindings(params, drift, &p); + let body = &drift.body; + let dx_binding = closure_param_ident(drift, 3).unwrap_or_else(|| dx.clone()); + let rate_terms = + expand_injected_sde_rate_terms(routes, states, route_bindings, &dx_binding, &rateiv); + + Ok(quote! {{ + let __pharmsol_drift: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #body + #rate_terms + }; + __pharmsol_drift + }}) +} + +fn expand_sde_diffusion( + diffusion: &ExprClosure, + params: &[Ident], + states: &[Ident], +) -> syn::Result { + if closure_param_names(diffusion).len() != 2 { + return Err(syn::Error::new_spanned( + diffusion, + "declaration-first `sde!` requires `diffusion` to have 2 parameters: |p, sigma|", + )); + } + + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let sigma = generated_ident("__pharmsol_sigma"); + let input_aliases = generate_closure_input_aliases(diffusion, &[p.clone(), sigma.clone()])?; + let parameter_bindings = generate_parameter_bindings(params, diffusion, &p); + let body = &diffusion.body; + + Ok(quote! {{ + let __pharmsol_diffusion: fn( + &::pharmsol::simulator::V, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #sigma: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #body + }; + __pharmsol_diffusion + }}) +} + +fn expand_sde_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + if closure_param_names(closure).len() != 3 { + return Err(syn::Error::new_spanned( + closure, + format!( + "declaration-first `sde!` requires `{label}` to have 3 parameters: |p, t, cov|" + ), + )); + } + + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let input_aliases = + generate_closure_input_aliases(closure, &[p.clone(), t.clone(), cov.clone()])?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let body = &closure.body; + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_sde_init( + init: &ExprClosure, + params: &[Ident], + states: &[Ident], +) -> syn::Result { + if closure_param_names(init).len() != 4 { + return Err(syn::Error::new_spanned( + init, + "declaration-first `sde!` requires `init` to have 4 parameters: |p, t, cov, x|", + )); + } + + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let input_aliases = + generate_closure_input_aliases(init, &[p.clone(), t.clone(), cov.clone(), x.clone()])?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_sde_out( + out: &ExprClosure, + params: &[Ident], + states: &[Ident], + outputs: &[Ident], +) -> syn::Result { + if closure_param_names(out).len() != 5 { + return Err(syn::Error::new_spanned( + out, + "declaration-first `sde!` requires `out` to have 5 parameters: |x, p, t, cov, y|", + )); + } + + let state_consts = generate_index_consts(states); + let output_consts = generate_index_consts(outputs); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let input_aliases = generate_closure_input_aliases( + out, + &[x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()], + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let body = &out.body; + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #body + }; + __pharmsol_out + }}) +} + +// --------------------------------------------------------------------------- +// Proc macros +// --------------------------------------------------------------------------- + +#[proc_macro] +pub fn ode(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as OdeInput); + + let route_bindings = ode_route_channel_bindings(&input.routes); + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `ode!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `ode!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `ode!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `ode!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let diffeq = match expand_diffeq( + &input.diffeq, + &input.params, + &input.states, + &input.routes, + &route_bindings, + input.diffeq_mode, + ) { + Ok(diffeq) => diffeq, + Err(error) => return error.to_compile_error().into(), + }; + + let out = match expand_out(&input.out, &input.params, &input.states, &input.outputs) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let covariates = &input.covariates; + let states = &input.states; + let outputs = &input.outputs; + let routes = expand_route_metadata(&input.routes, input.diffeq_mode, &lag_routes, &fa_routes); + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; + + let lag = match input.lag.as_ref() { + Some(closure) => match expand_ode_route_map("lag", closure, &input.params, &route_bindings) + { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_ode_route_map("fa", closure, &input.params, &route_bindings) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = input + .init + .as_ref() + .map_or_else(|| quote! { |_, _, _, _| {} }, |closure| quote! { #closure }); + + quote! {{ + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .parameters([#(stringify!(#params)),*]) + #covariate_metadata + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))*; + + ::pharmsol::equation::ODE::new( + #diffeq, + #lag, + #fa, + #init, + #out, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + .with_metadata(__pharmsol_metadata) + .expect("declaration-first `ode!` generated invalid metadata") + }} + .into() +} + +#[proc_macro] +pub fn analytical(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as AnalyticalInput); + let route_bindings = ode_route_channel_bindings(&input.routes); + + let kernel_spec = match resolve_analytical_structure(&input.structure) { + Ok(spec) => spec, + Err(error) => return error.to_compile_error().into(), + }; + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "built-in `analytical!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "built-in `analytical!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "built-in `analytical!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "built-in `analytical!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let out = match expand_analytical_out(&input.out, &input.params, &input.states, &input.outputs) + { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let lag = match input.lag.as_ref() { + Some(closure) => { + match expand_analytical_route_map("lag", closure, &input.params, &route_bindings) { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_analytical_route_map("fa", closure, &input.params, &route_bindings) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => match expand_analytical_init(closure, &input.params, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _, _| {} }, + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let states = &input.states; + let outputs = &input.outputs; + let routes = expand_analytical_route_metadata(&input.routes, &lag_routes, &fa_routes); + let runtime_path = kernel_spec.runtime_path; + let metadata_kernel = kernel_spec.metadata_kernel; + + quote! {{ + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .kind(::pharmsol::equation::ModelKind::Analytical) + .parameters([#(stringify!(#params)),*]) + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))* + .analytical_kernel(#metadata_kernel); + + ::pharmsol::equation::Analytical::new( + #runtime_path, + |_, _, _| {}, + #lag, + #fa, + #init, + #out, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + .with_metadata(__pharmsol_metadata) + .expect("built-in `analytical!` generated invalid metadata") + }} + .into() +} + +#[proc_macro] +pub fn sde(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as SdeInput); + let route_bindings = ode_route_channel_bindings(&input.routes); + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `sde!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `sde!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `sde!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `sde!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let drift = match expand_sde_drift( + &input.drift, + &input.params, + &input.states, + &input.routes, + &route_bindings, + ) { + Ok(drift) => drift, + Err(error) => return error.to_compile_error().into(), + }; + + let diffusion = match expand_sde_diffusion(&input.diffusion, &input.params, &input.states) { + Ok(diffusion) => diffusion, + Err(error) => return error.to_compile_error().into(), + }; + + let lag = match input.lag.as_ref() { + Some(closure) => match expand_sde_route_map("lag", closure, &input.params, &route_bindings) + { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_sde_route_map("fa", closure, &input.params, &route_bindings) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => match expand_sde_init(closure, &input.params, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _, _| {} }, + }; + + let out = match expand_sde_out(&input.out, &input.params, &input.states, &input.outputs) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let states = &input.states; + let outputs = &input.outputs; + let particles = &input.particles; + let routes = expand_sde_route_metadata(&input.routes, &lag_routes, &fa_routes); + let bolus_mappings = + expand_injected_sde_bolus_mappings(&input.routes, &input.states, &route_bindings); + + quote! {{ + let __pharmsol_particles: usize = #particles; + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .kind(::pharmsol::equation::ModelKind::Sde) + .parameters([#(stringify!(#params)),*]) + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))* + .particles(__pharmsol_particles); + + ::pharmsol::equation::SDE::new( + #drift, + #diffusion, + #lag, + #fa, + #init, + #out, + __pharmsol_particles, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + #bolus_mappings + .with_metadata(__pharmsol_metadata) + .expect("declaration-first `sde!` generated invalid metadata") + }} + .into() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_removed_legacy_form() { + let error = syn::parse_str::( + "diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("legacy macro form must fail"); + + assert!(error + .to_string() + .contains("requires `name`, `params`, `states`, `outputs`, and `routes`")); + assert!(error + .to_string() + .contains("old inferred-dimensions form has been removed")); + } + + #[test] + fn validates_route_destinations() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> peripheral }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown route destination must fail"); + + assert!(error + .to_string() + .contains("route destination `peripheral` is not declared in the `states` section")); + } + + #[test] + fn rejects_named_binding_collisions() { + let error = syn::parse_str::( + "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("parameter/state binding collisions must fail"); + + assert!(error + .to_string() + .contains("named parameter binding `central` conflicts with named state binding")); + } + + #[test] + fn ode_route_bindings_share_channels_by_kind_local_ordinal() { + let input = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: { bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot }, diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + ) + .expect("declaration-first ode input should parse"); + + let bindings = ode_route_channel_bindings(&input.routes); + + assert_eq!(dense_index_len(&bindings), 2); + assert_eq!(bindings[0].0.to_string(), "oral"); + assert_eq!(bindings[0].1, 0); + assert_eq!(bindings[1].0.to_string(), "iv"); + assert_eq!(bindings[1].1, 0); + assert_eq!(bindings[2].0.to_string(), "sc"); + assert_eq!(bindings[2].1, 1); + } + + #[test] + fn generated_parameter_bindings_only_include_referenced_locals_in_hot_closures() { + let params = vec![generated_ident("ke"), generated_ident("v")]; + let closure = syn::parse_str::( + "|x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }", + ) + .expect("closure should parse"); + + let bindings = + generate_parameter_bindings(¶ms, &closure, &generated_ident("__pharmsol_p")) + .to_string(); + + assert!( + bindings.contains("let ke = __pharmsol_p [0usize] ;") + || bindings.contains("let ke = __pharmsol_p [ 0 ] ;") + ); + assert!(!bindings.contains("let v =")); + } + + #[test] + fn generated_parameter_bindings_fall_back_to_all_params_for_stmt_macros() { + let params = vec![generated_ident("ka"), generated_ident("tlag")]; + let closure = syn::parse_str::("|_p, _t, _cov| { lag! { oral => tlag } }") + .expect("closure should parse"); + + let bindings = + generate_parameter_bindings(¶ms, &closure, &generated_ident("__pharmsol_p")) + .to_string(); + + assert!(bindings.contains("let ka =")); + assert!(bindings.contains("let tlag =")); + } + + #[test] + fn analytical_accepts_extra_parameters_beyond_kernel_arity() { + let input = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v, tlag], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + ) + .expect("extra declared parameters should be allowed"); + + assert_eq!(input.params.len(), 4); + assert_eq!(input.states.len(), 2); + } + + #[test] + fn analytical_rejects_unknown_structure() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: mystery, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown analytical structure must fail"); + + assert!(error + .to_string() + .contains("unknown analytical structure `mystery`")); + } + + #[test] + fn analytical_rejects_insufficient_kernel_parameters() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("insufficient kernel parameters must fail"); + + assert!(error + .to_string() + .contains("requires at least 2 parameter value(s)")); + } + + #[test] + fn analytical_rejects_unknown_route_property_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown lag route must fail"); + + assert!(error + .to_string() + .contains("route `iv` in `lag!` is not declared in the `routes` section")); + } + + #[test] + fn analytical_rejects_infusion_lag_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("infusion lag must fail"); + + assert!(error + .to_string() + .contains("built-in `analytical!` does not allow `lag` on infusion route `iv`")); + } + + #[test] + fn sde_requires_particles() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("missing particles must fail"); + + assert!(error + .to_string() + .contains("missing required field `particles` in declaration-first `sde!`")); + } + + #[test] + fn sde_rejects_unknown_route_property_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown lag route must fail"); + + assert!(error + .to_string() + .contains("route `oral` in `lag!` is not declared in the `routes` section")); + } + + #[test] + fn sde_rejects_infusion_lag_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("infusion lag must fail"); + + assert!(error + .to_string() + .contains("declaration-first `sde!` does not allow `lag` on infusion route `iv`")); } - .into() } diff --git a/src/dsl/compiled_backend_abi.rs b/src/dsl/compiled_backend_abi.rs index 26a2f825..8717c416 100644 --- a/src/dsl/compiled_backend_abi.rs +++ b/src/dsl/compiled_backend_abi.rs @@ -324,7 +324,9 @@ mod tests { }], routes: vec![NativeRouteInfo { name: "iv".to_string(), + declaration_index: 0, index: 0, + kind: None, destination_offset: 1, inject_input_to_destination: true, }], diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index ed516387..ac71f504 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -731,7 +731,9 @@ fn lower_load( ExecutionLoad::Parameter(index) => load_fixed(builder, env.args.params, *index, ty), ExecutionLoad::Covariate(index) => load_fixed(builder, env.args.covariates, *index, ty), ExecutionLoad::Derived(index) => load_fixed(builder, env.args.derived, *index, ty), - ExecutionLoad::RouteInput(index) => load_fixed(builder, env.args.routes, *index, ty), + ExecutionLoad::RouteInput { index, .. } => { + load_fixed(builder, env.args.routes, *index, ty) + } ExecutionLoad::Local(index) => { let binding = env.locals.get(index).ok_or_else(|| { JitCompileError::new(format!("unknown local slot {index}"), Some(span)) @@ -1330,6 +1332,89 @@ mod tests { assert!(debugged.contains("error[DSL4000]"), "{}", debugged); } + #[test] + fn authoring_runtime_shares_channel_between_bolus_and_infusion_routes() { + let source = r#" +name = shared_authoring +kind = ode + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + let parsed = pharmsol_dsl::parse_model(source).expect("authoring model parses"); + let typed = pharmsol_dsl::analyze_model(&parsed).expect("authoring model analyzes"); + let model = pharmsol_dsl::lower_typed_model(&typed).expect("authoring model lowers"); + let jit = compile_ode_model_to_jit(&model) + .expect("compile jit ode model") + .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); + + let oral = jit.route_index("oral").expect("oral route"); + let iv = jit.route_index("iv").expect("iv route"); + let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(iv, 0); + + let subject = Subject::builder("ode") + .bolus(0.0, 120.0, oral) + .infusion(6.0, 60.0, iv, 2.0) + .observation(0.5, 0.0, cp) + .observation(1.0, 0.0, cp) + .observation(2.0, 0.0, cp) + .observation(6.0, 0.0, cp) + .observation(7.0, 0.0, cp) + .observation(9.0, 0.0, cp) + .build(); + + let support = vec![1.2, 0.15, 40.0]; + let jit_predictions = jit + .estimate_predictions(&subject, &support) + .expect("jit predictions"); + + let reference = ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + let ka = p[0]; + let ke = p[1]; + dx[0] = -ka * x[0] + bolus[0]; + dx[1] = ka * x[0] - ke * x[1] + rateiv[0]; + }, + |_p, _t, _cov| std::collections::HashMap::new(), + |_p, _t, _cov| std::collections::HashMap::new(), + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + y[0] = x[1] / p[2]; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); + + let reference_predictions = reference + .estimate_predictions(&subject, &support) + .expect("reference ode predictions"); + + for (jit_pred, reference_pred) in jit_predictions + .predictions() + .iter() + .zip(reference_predictions.predictions()) + { + assert_relative_eq!( + jit_pred.prediction(), + reference_pred.prediction(), + max_relative = 1e-4 + ); + } + } + fn slot_index(layout: &DenseBufferLayout, name: &str) -> usize { layout .slots diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 8e48a022..6735ca30 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -1,10 +1,12 @@ +use std::collections::BTreeMap; + use serde::{Deserialize, Serialize}; use pharmsol_dsl::execution::{ ExecutionExpr, ExecutionExprKind, ExecutionLoad, ExecutionModel, ExecutionStmt, ExecutionStmtKind, KernelImplementation, KernelRole, }; -use pharmsol_dsl::{AnalyticalKernel, ModelKind}; +use pharmsol_dsl::{AnalyticalKernel, ModelKind, RouteKind}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeModelInfo { @@ -31,7 +33,11 @@ pub struct NativeCovariateInfo { #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeRouteInfo { pub name: String, + #[serde(default)] + pub declaration_index: usize, pub index: usize, + #[serde(default)] + pub kind: Option, pub destination_offset: usize, pub inject_input_to_destination: bool, } @@ -69,10 +75,12 @@ impl NativeModelInfo { .iter() .map(|route| NativeRouteInfo { name: route.name.clone(), + declaration_index: route.declaration_index, index: route.index, + kind: route.kind, destination_offset: route.destination.state_offset, inject_input_to_destination: !explicit_route_input_usage - .get(route.index) + .get(route.declaration_index) .copied() .unwrap_or(false), }) @@ -97,6 +105,12 @@ impl NativeModelInfo { } fn explicit_route_input_usage(model: &ExecutionModel) -> Vec { + let declaration_slots = model + .metadata + .routes + .iter() + .map(|route| (route.symbol, route.declaration_index)) + .collect::>(); let Some(kernel) = (match model.kind { ModelKind::Ode => model.kernel(KernelRole::Dynamics), ModelKind::Sde => model.kernel(KernelRole::Drift), @@ -107,54 +121,152 @@ fn explicit_route_input_usage(model: &ExecutionModel) -> Vec { let mut usage = vec![false; model.metadata.routes.len()]; if let KernelImplementation::Statements(program) = &kernel.implementation { - mark_route_inputs_in_statements(&program.body.statements, &mut usage); + mark_route_inputs_in_statements(&program.body.statements, &declaration_slots, &mut usage); } usage } -fn mark_route_inputs_in_statements(statements: &[ExecutionStmt], usage: &mut [bool]) { +fn mark_route_inputs_in_statements( + statements: &[ExecutionStmt], + declaration_slots: &BTreeMap, + usage: &mut [bool], +) { for statement in statements { match &statement.kind { ExecutionStmtKind::Let(let_stmt) => { - mark_route_inputs_in_expr(&let_stmt.value, usage); + mark_route_inputs_in_expr(&let_stmt.value, declaration_slots, usage); } ExecutionStmtKind::Assign(assign_stmt) => { - mark_route_inputs_in_expr(&assign_stmt.value, usage); + mark_route_inputs_in_expr(&assign_stmt.value, declaration_slots, usage); } ExecutionStmtKind::If(if_stmt) => { - mark_route_inputs_in_expr(&if_stmt.condition, usage); - mark_route_inputs_in_statements(&if_stmt.then_branch, usage); + mark_route_inputs_in_expr(&if_stmt.condition, declaration_slots, usage); + mark_route_inputs_in_statements(&if_stmt.then_branch, declaration_slots, usage); if let Some(else_branch) = &if_stmt.else_branch { - mark_route_inputs_in_statements(else_branch, usage); + mark_route_inputs_in_statements(else_branch, declaration_slots, usage); } } ExecutionStmtKind::For(for_stmt) => { - mark_route_inputs_in_expr(&for_stmt.range.start, usage); - mark_route_inputs_in_expr(&for_stmt.range.end, usage); - mark_route_inputs_in_statements(&for_stmt.body, usage); + mark_route_inputs_in_expr(&for_stmt.range.start, declaration_slots, usage); + mark_route_inputs_in_expr(&for_stmt.range.end, declaration_slots, usage); + mark_route_inputs_in_statements(&for_stmt.body, declaration_slots, usage); } } } } -fn mark_route_inputs_in_expr(expr: &ExecutionExpr, usage: &mut [bool]) { +fn mark_route_inputs_in_expr( + expr: &ExecutionExpr, + declaration_slots: &BTreeMap, + usage: &mut [bool], +) { match &expr.kind { ExecutionExprKind::Literal(_) => {} - ExecutionExprKind::Load(ExecutionLoad::RouteInput(index)) => { - if let Some(slot) = usage.get_mut(*index) { + ExecutionExprKind::Load(ExecutionLoad::RouteInput { route, .. }) => { + if let Some(slot) = declaration_slots.get(route).and_then(|index| usage.get_mut(*index)) { *slot = true; } } ExecutionExprKind::Load(_) => {} - ExecutionExprKind::Unary { expr, .. } => mark_route_inputs_in_expr(expr, usage), + ExecutionExprKind::Unary { expr, .. } => { + mark_route_inputs_in_expr(expr, declaration_slots, usage) + } ExecutionExprKind::Binary { lhs, rhs, .. } => { - mark_route_inputs_in_expr(lhs, usage); - mark_route_inputs_in_expr(rhs, usage); + mark_route_inputs_in_expr(lhs, declaration_slots, usage); + mark_route_inputs_in_expr(rhs, declaration_slots, usage); } ExecutionExprKind::Call { args, .. } => { for arg in args { - mark_route_inputs_in_expr(arg, usage); + mark_route_inputs_in_expr(arg, declaration_slots, usage); } } } } + +#[cfg(test)] +mod tests { + use super::*; + use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model}; + + fn load_model_info(src: &str) -> NativeModelInfo { + let model = parse_model(src).expect("model parses"); + let typed = analyze_model(&model).expect("model analyzes"); + let lowered = lower_typed_model(&typed).expect("model lowers"); + NativeModelInfo::from_execution_model(&lowered) + } + + #[test] + fn declaration_first_routes_inject_by_default() { + let info = load_model_info( + r#" +model implicit_route_injection { + kind ode + states { central } + routes { iv -> central } + dynamics { + ddt(central) = 0 + } + outputs { + cp = central + } +} +"#, + ); + + assert_eq!(info.routes.len(), 1); + assert!(info.routes[0].inject_input_to_destination); + } + + #[test] + fn explicit_rate_usage_disables_automatic_injection() { + let info = load_model_info( + r#" +model explicit_route_usage { + kind ode + states { central } + routes { iv -> central } + dynamics { + ddt(central) = rate(iv) + } + outputs { + cp = central + } +} +"#, + ); + + assert_eq!(info.routes.len(), 1); + assert!(!info.routes[0].inject_input_to_destination); + } + + #[test] + fn authoring_shared_channel_routes_keep_declaration_specific_injection() { + let info = load_model_info( + r#" +name = shared_authoring +kind = ode + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#, + ); + + assert_eq!(info.route_len, 1); + assert_eq!(info.routes.len(), 2); + assert_eq!(info.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(info.routes[1].kind, Some(RouteKind::Infusion)); + assert_eq!(info.routes[0].index, 0); + assert_eq!(info.routes[1].index, 0); + assert!(info.routes[0].inject_input_to_destination); + assert!(!info.routes[1].inject_input_to_destination); + } +} diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 4a94715f..923b644d 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -14,7 +14,7 @@ use cranelift_jit::JITModule; #[cfg(feature = "dsl-aot-load")] use libloading::Library; use pharmsol_dsl::execution::KernelRole; -use pharmsol_dsl::AnalyticalKernel; +use pharmsol_dsl::{AnalyticalKernel, RouteKind}; pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, @@ -264,12 +264,72 @@ impl RuntimeArtifact for NativeExecutionArtifact { #[derive(Clone, Debug)] struct SharedNativeModel { info: Arc, + route_semantics: Arc, artifact: Arc, } +#[derive(Clone, Debug)] +struct RouteInputSemantics { + bolus_destinations: Vec>, + infusion_inputs: Vec, + injected_infusion_destinations: Vec>, +} + +impl RouteInputSemantics { + fn from_model_info(info: &NativeModelInfo) -> Self { + let mut bolus_destinations = vec![None; info.route_len]; + let mut infusion_inputs = vec![false; info.route_len]; + let mut injected_infusion_destinations = vec![None; info.route_len]; + + for route in &info.routes { + match route.kind { + Some(RouteKind::Bolus) => { + bolus_destinations[route.index] = Some(route.destination_offset); + } + Some(RouteKind::Infusion) => { + infusion_inputs[route.index] = true; + if route.inject_input_to_destination { + injected_infusion_destinations[route.index] = Some(route.destination_offset); + } + } + None => { + bolus_destinations[route.index] = Some(route.destination_offset); + infusion_inputs[route.index] = true; + if route.inject_input_to_destination { + injected_infusion_destinations[route.index] = Some(route.destination_offset); + } + } + } + } + + Self { + bolus_destinations, + infusion_inputs, + injected_infusion_destinations, + } + } + + fn supports_input(&self, input: usize, kind: RouteKind) -> bool { + match kind { + RouteKind::Bolus => self + .bolus_destinations + .get(input) + .copied() + .flatten() + .is_some(), + RouteKind::Infusion => self.infusion_inputs.get(input).copied().unwrap_or(false), + } + } + + fn bolus_destination(&self, input: usize) -> Option { + self.bolus_destinations.get(input).copied().flatten() + } +} + impl SharedNativeModel { fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { Self { + route_semantics: Arc::new(RouteInputSemantics::from_model_info(&info)), info: Arc::new(info), artifact: Arc::new(artifact), } @@ -313,6 +373,18 @@ impl SharedNativeModel { Ok(()) } + fn validate_input_for_kind(&self, input: usize, kind: RouteKind) -> Result<(), PharmsolError> { + self.validate_input(input)?; + if self.route_semantics.supports_input(input, kind) { + return Ok(()); + } + + Err(PharmsolError::OtherError(format!( + "model `{}` does not declare a {:?} route for input channel {}", + self.info.name, kind, input + ))) + } + fn fill_cov_buffer(&self, covariates: &Covariates, time: f64, buf: &mut [f64]) { for covariate in &self.info.covariates { buf[covariate.index] = match covariates.get_covariate(&covariate.name) { @@ -323,9 +395,14 @@ impl SharedNativeModel { } fn apply_route_inputs_to_rates(&self, rates: &mut [f64], route_inputs: &[f64]) { - for route in &self.info.routes { - if route.inject_input_to_destination { - rates[route.destination_offset] += route_inputs[route.index]; + for (input, destination) in self + .route_semantics + .injected_infusion_destinations + .iter() + .enumerate() + { + if let Some(destination) = destination { + rates[*destination] += route_inputs[input]; } } } @@ -451,7 +528,7 @@ impl SharedNativeModel { for event in events.iter_mut() { if let Event::Bolus(bolus) = event { - self.validate_input(bolus.input())?; + self.validate_input_for_kind(bolus.input(), RouteKind::Bolus)?; if self.artifact.has_kernel(KernelRole::RouteLag) { lag_values.fill(0.0); @@ -525,9 +602,14 @@ impl SharedNativeModel { input: usize, amount: f64, ) -> Result<(), PharmsolError> { - self.validate_input(input)?; - let destination = &self.info.routes[input]; - state[destination.destination_offset] += amount; + self.validate_input_for_kind(input, RouteKind::Bolus)?; + let destination = self.route_semantics.bolus_destination(input).ok_or_else(|| { + PharmsolError::OtherError(format!( + "model `{}` does not declare a bolus route for input channel {}", + self.info.name, input + )) + })?; + state[destination] += amount; Ok(()) } @@ -654,7 +736,8 @@ impl NativeOdeModel { .collect::>(); for infusion in &infusions { - self.shared.validate_input(infusion.input())?; + self.shared + .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; } let mut events = occasion.process_events(None, true); @@ -919,7 +1002,8 @@ impl NativeAnalyticalModel { .collect::>(); for infusion in &infusions { - self.shared.validate_input(infusion.input())?; + self.shared + .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; } let mut events = occasion.process_events(None, true); @@ -1073,7 +1157,8 @@ impl NativeSdeModel { .collect::>(); for infusion in &infusions { - self.shared.validate_input(infusion.input())?; + self.shared + .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; } let mut events = occasion.process_events(None, true); diff --git a/src/dsl/rust_backend.rs b/src/dsl/rust_backend.rs index 19b3e5cc..850e13b7 100644 --- a/src/dsl/rust_backend.rs +++ b/src/dsl/rust_backend.rs @@ -264,7 +264,7 @@ fn emit_load(load: &ExecutionLoad, ty: ValueType) -> Result { ExecutionLoad::Covariate(index) => format!("load_f64(covariates, {index})"), ExecutionLoad::Derived(index) => format!("load_f64(derived, {index})"), ExecutionLoad::Local(index) => return Ok(format!("local_{index}")), - ExecutionLoad::RouteInput(index) => format!("load_f64(routes, {index})"), + ExecutionLoad::RouteInput { index, .. } => format!("load_f64(routes, {index})"), ExecutionLoad::State(state) => { let index = emit_state_ref_index(state)?; format!("load_f64(states, {index})") diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index 16884952..56c6fcbd 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -778,7 +778,9 @@ mod tests { covariates: Vec::new(), routes: vec![NativeRouteInfo { name: "oral".to_string(), + declaration_index: 0, index: 0, + kind: None, destination_offset: 0, inject_input_to_destination: true, }], @@ -893,7 +895,7 @@ mod tests { let model_info = loader_test_model_info("api_version_export_mismatch"); let metadata = serde_json::to_vec(&CompiledModelInfoEnvelope { abi_version: WASM_API_VERSION, - model: model_info, + name: "model_info", kernels: CompiledKernelAvailability { outputs: true, ..CompiledKernelAvailability::default() @@ -922,7 +924,7 @@ mod tests { let model_info = loader_test_model_info("metadata_api_version_mismatch"); let metadata = serde_json::to_vec(&CompiledModelInfoEnvelope { abi_version: WASM_API_VERSION + 1, - model: model_info, + name: "model_info", kernels: CompiledKernelAvailability { outputs: true, ..CompiledKernelAvailability::default() @@ -948,7 +950,7 @@ mod tests { let model_info = loader_test_model_info("kernel_metadata_mismatch"); let metadata = serde_json::to_vec(&CompiledModelInfoEnvelope { abi_version: WASM_API_VERSION, - model: model_info, + name: "model_info", kernels: CompiledKernelAvailability { outputs: true, ..CompiledKernelAvailability::default() diff --git a/src/dsl/wasm_compile.rs b/src/dsl/wasm_compile.rs index 66995e8a..caa60216 100644 --- a/src/dsl/wasm_compile.rs +++ b/src/dsl/wasm_compile.rs @@ -848,7 +848,7 @@ mod tests { }; const SIMPLE_SOURCE: &str = r#" -model = example_ode +name = example_ode kind = ode params = ke, v @@ -901,7 +901,7 @@ out(cp) = central / v ~ continuous() cache .compile_module_source_to_wasm_module( r#" -model = second_ode +name = second_ode kind = ode params = ke, v @@ -949,7 +949,7 @@ out(cp) = central / v ~ continuous() #[test] fn compile_module_source_to_wasm_module_preserves_semantic_diagnostic_structure() { let source = r#" -model = broken +name = broken kind = ode states = central @@ -995,7 +995,7 @@ out(cp) = central ~ continuous() #[test] fn compile_module_source_to_wasm_module_preserves_lowering_diagnostic_structure() { let source = r#" -model = broken +name = broken kind = ode states = transit[4], central diff --git a/src/dsl/wasm_direct_emitter.rs b/src/dsl/wasm_direct_emitter.rs index 2d92ad1d..857f2ac7 100644 --- a/src/dsl/wasm_direct_emitter.rs +++ b/src/dsl/wasm_direct_emitter.rs @@ -922,7 +922,7 @@ fn emit_load( function.instruction(&Instruction::LocalGet(local.wasm_local)); emit_cast_stack(local.ty, target_ty, function, state.model_name) } - ExecutionLoad::RouteInput(index) => emit_dense_load( + ExecutionLoad::RouteInput { index, .. } => emit_dense_load( function, KERNEL_PARAM_ROUTES, *index, diff --git a/src/lib.rs b/src/lib.rs index f2691579..c84d4ee1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,28 +28,31 @@ pub use crate::data::Interpolation::*; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::data::*; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::equation::*; -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::optimize::effect::get_e2; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::optimize::spp::SppOptimizer; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::simulator::equation::analytical::*; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::simulator::equation::metadata; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::simulator::equation::{ self, ode::{ExplicitRkTableau, OdeSolver, SdirkTableau}, - ODE, + Analytical, AnalyticalKernel, Cache, Equation, ModelKind, ModelMetadata, ModelMetadataError, + NameDomain, Predictions, RouteInputPolicy, RouteKind, State, ValidatedModelMetadata, ODE, SDE, }; pub use error::PharmsolError; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use nalgebra::dmatrix; -pub use pharmsol_macros::ode; +pub use pharmsol_macros::{analytical, ode, sde}; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use std::collections::HashMap; /// Prelude module that re-exports all commonly used types and traits. /// -/// Use `use pharmsol::prelude::*;` to import everything needed for basic -/// pharmacometric modeling. +/// Importing `pharmsol::prelude::*` brings the main modeling, simulation, +/// and data APIs into scope. /// /// # Example /// ```rust @@ -92,7 +95,7 @@ pub mod prelude { pub use crate::data::auc::{auc, auc_interval, aumc, interpolate_linear}; #[allow(deprecated)] - // Simulator submodule for internal use and advanced users + // Simulator submodule for organized access to simulation types. pub mod simulator { pub use crate::simulator::{ cache::{self, PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE}, @@ -136,6 +139,8 @@ pub mod prelude { // Re-export macros (they are exported at crate root via #[macro_export]) #[doc(inline)] + pub use crate::analytical; + #[doc(inline)] pub use crate::fa; #[doc(inline)] pub use crate::fetch_cov; @@ -145,6 +150,8 @@ pub mod prelude { pub use crate::lag; #[doc(inline)] pub use crate::ode; + #[doc(inline)] + pub use crate::sde; } #[macro_export] diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 0aff0936..4734886c 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -8,6 +8,8 @@ pub mod two_compartment_models; use diffsol::{NalgebraContext, Vector, VectorHost}; pub use one_compartment_cl_models::*; pub use one_compartment_models::*; +use pharmsol_dsl::ModelKind; +use thiserror::Error; pub use three_compartment_cl_models::*; pub use three_compartment_models::*; pub use two_compartment_cl_models::*; @@ -15,12 +17,28 @@ pub use two_compartment_models::*; use super::spphash; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, + ValidatedModelMetadata, +}; use crate::data::error_model::AssayErrorModels; use crate::simulator::cache::{PredictionCache, DEFAULT_CACHE_SIZE}; use crate::PharmsolError; -use crate::{ - data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject, -}; +use crate::{data::Covariates, simulator::*, Observation, Subject}; + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum AnalyticalMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("analytical model declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error( + "analytical model declares {declared} route metadata entries but model has {expected} input channels" + )] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("analytical model declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} /// Model equation using analytical solutions. /// @@ -35,6 +53,7 @@ pub struct Analytical { init: Init, out: Out, neqs: Neqs, + metadata: Option, cache: Option, } @@ -88,6 +107,7 @@ impl Analytical { init, out, neqs: Neqs::default(), + metadata: None, cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -95,20 +115,94 @@ impl Analytical { /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } /// Set the number of drug input channels (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } + + /// Attach validated handwritten-model metadata to this analytical model. + pub fn with_metadata( + mut self, + metadata: ModelMetadata, + ) -> Result { + let metadata = metadata.validate_for(ModelKind::Analytical)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + /// Access the validated metadata attached to this analytical model, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.metadata()?.route_index(name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.metadata()?.output_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), AnalyticalMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(AnalyticalMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_channel_count(); + if declared_routes != neqs.ndrugs { + return Err(AnalyticalMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(AnalyticalMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for Analytical { @@ -302,6 +396,7 @@ pub(crate) mod tests { use crate::SubjectBuilderExt; use approx::assert_relative_eq; use diffsol::Vector; + use pharmsol_dsl::AnalyticalKernel; use std::collections::HashMap; pub(crate) enum SubjectInfo { @@ -423,6 +518,158 @@ pub(crate) mod tests { assert_eq!(predictions.predictions()[0].prediction(), 4.0); } + fn simple_analytical() -> Analytical { + let eq = |x: &V, _p: &V, _dt: f64, _rateiv: &V, _cov: &Covariates| x.clone(); + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[0]; + }; + + Analytical::new(eq, seq_eq, lag, fa, init, out) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + #[test] + fn handwritten_analytical_metadata_exposes_name_lookup() { + let analytical = simple_analytical() + .with_metadata( + super::super::metadata::new("one_cmt_analytical") + .parameters(["ke", "v"]) + .covariates([super::super::Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate"); + + assert_eq!(analytical.parameter_index("ke"), Some(0)); + assert_eq!(analytical.parameter_index("v"), Some(1)); + assert_eq!(analytical.covariate_index("wt"), Some(0)); + assert_eq!(analytical.state_index("central"), Some(0)); + assert_eq!(analytical.route_index("iv"), Some(0)); + assert_eq!(analytical.output_index("cp"), Some(0)); + assert_eq!( + analytical.metadata().expect("metadata exists").kind(), + ModelKind::Analytical + ); + } + + #[test] + fn handwritten_analytical_without_metadata_keeps_raw_path() { + let analytical = simple_analytical(); + + assert!(analytical.metadata().is_none()); + assert_eq!(analytical.state_index("central"), None); + assert_eq!(analytical.route_index("iv"), None); + assert_eq!(analytical.output_index("cp"), None); + } + + #[test] + fn handwritten_analytical_rejects_dimension_mismatches() { + let error = simple_analytical() + .with_metadata( + super::super::metadata::new("wrong_outputs") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp", "auc"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect_err("output-count mismatches must fail"); + + assert_eq!( + error, + AnalyticalMetadataError::OutputCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_analytical_rejects_particles_metadata() { + let error = simple_analytical() + .with_metadata( + super::super::metadata::new("invalid_particles") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")) + .particles(64), + ) + .expect_err("analytical metadata must reject particles"); + + assert_eq!( + error, + AnalyticalMetadataError::Validation(ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Analytical, + }) + ); + } + + #[test] + fn built_in_analytical_models_can_advertise_kernel_identity() { + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + let analytical = + Analytical::new(one_compartment_with_absorption, seq_eq, lag, fa, init, out) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("one_cmt_abs") + .parameters(["ka", "ke", "v"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral").to_state("gut"), + super::super::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("built-in analytical metadata should validate"); + + assert_eq!( + analytical + .metadata() + .expect("metadata exists") + .analytical_kernel(), + Some(AnalyticalKernel::OneCompartmentWithAbsorption) + ); + assert_eq!(analytical.route_index("oral"), Some(0)); + assert_eq!(analytical.route_index("iv"), Some(0)); + } + + #[test] + fn changing_dimensions_after_metadata_clears_analytical_metadata() { + let analytical = simple_analytical() + .with_metadata( + super::super::metadata::new("one_cmt_analytical") + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate") + .with_ndrugs(2); + + assert!(analytical.metadata().is_none()); + assert_eq!(analytical.route_index("iv"), None); + } + fn assert_pm_wrapper_matches_native( native: AnalyticalEq, wrapper: AnalyticalEq, @@ -567,8 +814,8 @@ impl Equation for Analytical { ypred.log_likelihood(error_models) } - fn kind() -> crate::EqnKind { - crate::EqnKind::Analytical + fn kind() -> EqnKind { + EqnKind::Analytical } } diff --git a/src/simulator/equation/meta.rs b/src/simulator/equation/meta.rs deleted file mode 100644 index 1b38ae35..00000000 --- a/src/simulator/equation/meta.rs +++ /dev/null @@ -1,64 +0,0 @@ -#[repr(C)] -#[derive(Debug, Clone)] -/// Model metadata container. -/// -/// This structure holds the metadata associated with a pharmacometric model, -/// including parameter names and other model-specific information that needs -/// to be preserved across simulation and estimation activities. -/// -/// # Examples -/// -/// ``` -/// use pharmsol::simulator::equation::Meta; -/// -/// let model_metadata = Meta::new(vec!["CL", "V", "KA"]); -/// assert_eq!(model_metadata.get_params().len(), 3); -/// ``` -pub struct Meta { - params: Vec, -} - -impl Meta { - /// Creates a new metadata container with the specified parameter names. - /// - /// # Arguments - /// - /// * `params` - A vector of string slices representing parameter names - /// - /// # Returns - /// - /// A new `Meta` instance containing the converted parameter names - /// - /// # Examples - /// - /// ``` - /// use pharmsol::simulator::equation::Meta; - /// - /// let metadata = Meta::new(vec!["CL", "V", "KA"]); - /// ``` - pub fn new(params: Vec<&str>) -> Self { - let params = params.iter().map(|x| x.to_string()).collect(); - Meta { params } - } - - /// Retrieves the parameter names stored in this metadata container. - /// - /// # Returns - /// - /// A reference to the vector of parameter names - /// - /// # Examples - /// - /// ``` - /// use pharmsol::simulator::equation::Meta; - /// - /// let metadata = Meta::new(vec!["CL", "V", "KA"]); - /// let params = metadata.get_params(); - /// assert_eq!(params[0], "CL"); - /// assert_eq!(params[1], "V"); - /// assert_eq!(params[2], "KA"); - /// ``` - pub fn get_params(&self) -> &Vec { - &self.params - } -} diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs new file mode 100644 index 00000000..ecf51a52 --- /dev/null +++ b/src/simulator/equation/metadata.rs @@ -0,0 +1,1211 @@ +//! Shared model metadata for handwritten simulator models. +//! +//! This module defines the public metadata contract that handwritten ODE, +//! analytical, and SDE models can attach to. The field set is intentionally +//! aligned with the public subset of the DSL/runtime metadata surface. +//! +//! Internal runtime layout details such as dense buffer lengths, derived buffer +//! shape, or ABI-specific offsets remain internal for now. + +use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind}; +use std::fmt; +use thiserror::Error; + +/// Create a new handwritten-model metadata builder. +pub fn new(name: impl Into) -> ModelMetadata { + ModelMetadata::new(name) +} + +/// Validation errors for handwritten model metadata. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum ModelMetadataError { + #[error("model kind is required for metadata validation")] + MissingModelKind, + #[error("metadata declares kind `{declared:?}` but validation requested `{requested:?}`")] + ModelKindConflict { + declared: ModelKind, + requested: ModelKind, + }, + #[error("duplicate {domain} name `{name}`")] + DuplicateName { domain: NameDomain, name: String }, + #[error("route `{route}` must declare a destination state")] + MissingRouteDestination { route: String }, + #[error("route `{route}` targets unknown state `{destination}`")] + UnknownRouteDestination { route: String, destination: String }, + #[error("infusion route `{route}` cannot declare lag")] + InfusionLagNotAllowed { route: String }, + #[error("infusion route `{route}` cannot declare bioavailability")] + InfusionBioavailabilityNotAllowed { route: String }, + #[error("{kind:?} metadata cannot declare particles")] + ParticlesNotAllowed { kind: ModelKind }, + #[error("Sde metadata requires particles")] + MissingParticles, + #[error( + "metadata declares {declared} particle(s) but validation provided {fallback} fallback particle(s)" + )] + ParticleCountConflict { declared: usize, fallback: usize }, + #[error("{kind:?} metadata cannot declare an analytical kernel")] + AnalyticalKernelNotAllowed { kind: ModelKind }, +} + +/// Name domain used in duplicate-name validation messages. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NameDomain { + Parameter, + Covariate, + State, + Route, + Output, +} + +impl fmt::Display for NameDomain { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let domain = match self { + Self::Parameter => "parameter", + Self::Covariate => "covariate", + Self::State => "state", + Self::Route => "route", + Self::Output => "output", + }; + f.write_str(domain) + } +} + +/// Immutable validated metadata view used by later attachment slices. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedModelMetadata { + name: String, + kind: ModelKind, + parameters: Vec, + covariates: Vec, + states: Vec, + routes: Vec, + route_channel_count: usize, + outputs: Vec, + particles: Option, + analytical: Option, +} + +impl ValidatedModelMetadata { + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> ModelKind { + self.kind + } + + pub fn parameters(&self) -> &[Parameter] { + &self.parameters + } + + pub fn covariates(&self) -> &[Covariate] { + &self.covariates + } + + pub fn states(&self) -> &[State] { + &self.states + } + + pub fn routes(&self) -> &[ValidatedRoute] { + &self.routes + } + + pub fn route_channel_count(&self) -> usize { + self.route_channel_count + } + + pub fn outputs(&self) -> &[Output] { + &self.outputs + } + + pub fn particles(&self) -> Option { + self.particles + } + + pub fn analytical_kernel(&self) -> Option { + self.analytical + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.parameters + .iter() + .position(|parameter| parameter.name() == name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.covariates + .iter() + .position(|covariate| covariate.name() == name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.states.iter().position(|state| state.name() == name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.route(name).map(ValidatedRoute::channel_index) + } + + pub fn route_declaration_index(&self, name: &str) -> Option { + self.routes.iter().position(|route| route.name() == name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.outputs.iter().position(|output| output.name() == name) + } + + pub fn parameter(&self, name: &str) -> Option<&Parameter> { + self.parameter_index(name) + .map(|index| &self.parameters[index]) + } + + pub fn covariate(&self, name: &str) -> Option<&Covariate> { + self.covariate_index(name) + .map(|index| &self.covariates[index]) + } + + pub fn state(&self, name: &str) -> Option<&State> { + self.state_index(name).map(|index| &self.states[index]) + } + + pub fn route(&self, name: &str) -> Option<&ValidatedRoute> { + self.route_declaration_index(name) + .map(|index| &self.routes[index]) + } + + pub fn output(&self, name: &str) -> Option<&Output> { + self.output_index(name).map(|index| &self.outputs[index]) + } +} + +/// One validated route declaration with resolved destination state index. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedRoute { + name: String, + kind: RouteKind, + declaration_index: usize, + channel_index: usize, + destination: String, + destination_index: usize, + has_lag: bool, + has_bioavailability: bool, + input_policy: Option, +} + +impl ValidatedRoute { + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> RouteKind { + self.kind + } + + pub fn declaration_index(&self) -> usize { + self.declaration_index + } + + pub fn channel_index(&self) -> usize { + self.channel_index + } + + pub fn destination(&self) -> &str { + &self.destination + } + + pub fn destination_index(&self) -> usize { + self.destination_index + } + + pub fn has_lag(&self) -> bool { + self.has_lag + } + + pub fn has_bioavailability(&self) -> bool { + self.has_bioavailability + } + + pub fn input_policy(&self) -> Option { + self.input_policy + } +} + +/// Metadata describing one handwritten simulator model. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelMetadata { + name: String, + kind: Option, + parameters: Vec, + covariates: Vec, + states: Vec, + routes: Vec, + outputs: Vec, + particles: Option, + analytical: Option, +} + +impl ModelMetadata { + /// Create a new metadata builder with a model name. + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + kind: None, + parameters: Vec::new(), + covariates: Vec::new(), + states: Vec::new(), + routes: Vec::new(), + outputs: Vec::new(), + particles: None, + analytical: None, + } + } + + /// Set the model kind explicitly. + pub fn kind(mut self, kind: ModelKind) -> Self { + self.kind = Some(kind); + self + } + + /// Replace the ordered parameter list. + pub fn parameters(mut self, parameters: I) -> Self + where + I: IntoIterator, + Parameter: From, + { + self.parameters = parameters.into_iter().map(Parameter::from).collect(); + self + } + + /// Replace the ordered covariate list. + pub fn covariates(mut self, covariates: I) -> Self + where + I: IntoIterator, + { + self.covariates = covariates.into_iter().collect(); + self + } + + /// Replace the ordered state list. + pub fn states(mut self, states: I) -> Self + where + I: IntoIterator, + State: From, + { + self.states = states.into_iter().map(State::from).collect(); + self + } + + /// Add one route declaration. + pub fn route(mut self, route: Route) -> Self { + self.routes.push(route); + self + } + + /// Extend with multiple route declarations. + pub fn routes(mut self, routes: I) -> Self + where + I: IntoIterator, + { + self.routes.extend(routes); + self + } + + /// Replace the ordered output list. + pub fn outputs(mut self, outputs: I) -> Self + where + I: IntoIterator, + Output: From, + { + self.outputs = outputs.into_iter().map(Output::from).collect(); + self + } + + /// Set the particle count for stochastic models. + pub fn particles(mut self, particles: usize) -> Self { + self.particles = Some(particles); + self + } + + /// Set the analytical kernel identity for built-in analytical models. + pub fn analytical_kernel(mut self, analytical: AnalyticalKernel) -> Self { + self.analytical = Some(analytical); + self + } + + /// Get the model name. + pub fn name(&self) -> &str { + &self.name + } + + /// Get the explicit model kind, if already declared. + pub fn kind_decl(&self) -> Option { + self.kind + } + + /// Get the ordered parameter metadata. + pub fn parameters_decl(&self) -> &[Parameter] { + &self.parameters + } + + /// Get the ordered covariate metadata. + pub fn covariates_decl(&self) -> &[Covariate] { + &self.covariates + } + + /// Get the ordered state metadata. + pub fn states_decl(&self) -> &[State] { + &self.states + } + + /// Get the ordered route metadata. + pub fn routes_decl(&self) -> &[Route] { + &self.routes + } + + /// Get the ordered output metadata. + pub fn outputs_decl(&self) -> &[Output] { + &self.outputs + } + + /// Get the declared particle count. + pub fn particles_decl(&self) -> Option { + self.particles + } + + /// Get the declared analytical kernel identity. + pub fn analytical_kernel_decl(&self) -> Option { + self.analytical + } + + /// Validate this metadata using its declared kind. + pub fn validate(self) -> Result { + self.validate_internal(None, None) + } + + /// Validate this metadata for a specific model kind. + pub fn validate_for( + self, + kind: ModelKind, + ) -> Result { + self.validate_internal(Some(kind), None) + } + + /// Validate this metadata for a specific model kind, using a fallback + /// particle count when the metadata itself does not declare one. + pub fn validate_for_with_particles( + self, + kind: ModelKind, + fallback_particles: usize, + ) -> Result { + self.validate_internal(Some(kind), Some(fallback_particles)) + } + + fn validate_internal( + self, + requested_kind: Option, + fallback_particles: Option, + ) -> Result { + let kind = resolve_kind(self.kind, requested_kind)?; + validate_unique_names(&self.parameters, NameDomain::Parameter, Parameter::name)?; + validate_unique_names(&self.covariates, NameDomain::Covariate, Covariate::name)?; + validate_unique_names(&self.states, NameDomain::State, State::name)?; + validate_unique_names(&self.routes, NameDomain::Route, Route::name)?; + validate_unique_names(&self.outputs, NameDomain::Output, Output::name)?; + + let particles = resolve_particles(kind, self.particles, fallback_particles)?; + validate_kind_specific_fields(kind, self.analytical, particles)?; + + let (routes, route_channel_count) = validate_routes(self.routes, &self.states)?; + + Ok(ValidatedModelMetadata { + name: self.name, + kind, + parameters: self.parameters, + covariates: self.covariates, + states: self.states, + routes, + route_channel_count, + outputs: self.outputs, + particles, + analytical: self.analytical, + }) + } +} + +/// One named parameter in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Parameter { + name: String, +} + +impl Parameter { + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for Parameter +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// One named covariate plus interpolation semantics. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Covariate { + name: String, + interpolation: Option, +} + +impl Covariate { + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + interpolation: None, + } + } + + pub fn continuous(name: impl Into) -> Self { + Self::new(name).with_interpolation(CovariateInterpolation::Linear) + } + + pub fn locf(name: impl Into) -> Self { + Self::new(name).with_interpolation(CovariateInterpolation::Locf) + } + + pub fn with_interpolation(mut self, interpolation: CovariateInterpolation) -> Self { + self.interpolation = Some(interpolation); + self + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn interpolation(&self) -> Option { + self.interpolation + } +} + +/// One named state in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct State { + name: String, +} + +impl State { + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for State +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// One named output in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Output { + name: String, +} + +impl Output { + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for Output +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// Route declaration kind. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteKind { + Bolus, + Infusion, +} + +/// How route inputs should be interpreted by the execution layer. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteInputPolicy { + InjectToDestination, + ExplicitInputVector, +} + +/// One named route declaration. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Route { + name: String, + kind: RouteKind, + destination: Option, + has_lag: bool, + has_bioavailability: bool, + input_policy: Option, +} + +impl Route { + pub fn bolus(name: impl Into) -> Self { + Self::new(name, RouteKind::Bolus) + } + + pub fn infusion(name: impl Into) -> Self { + Self::new(name, RouteKind::Infusion) + } + + pub fn new(name: impl Into, kind: RouteKind) -> Self { + Self { + name: name.into(), + kind, + destination: None, + has_lag: false, + has_bioavailability: false, + input_policy: None, + } + } + + pub fn to_state(mut self, destination: impl Into) -> Self { + self.destination = Some(destination.into()); + self + } + + pub fn with_lag(mut self) -> Self { + self.has_lag = true; + self + } + + pub fn with_bioavailability(mut self) -> Self { + self.has_bioavailability = true; + self + } + + pub fn inject_input_to_destination(mut self) -> Self { + self.input_policy = Some(RouteInputPolicy::InjectToDestination); + self + } + + pub fn expect_explicit_input(mut self) -> Self { + self.input_policy = Some(RouteInputPolicy::ExplicitInputVector); + self + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> RouteKind { + self.kind + } + + pub fn destination(&self) -> Option<&str> { + self.destination.as_deref() + } + + pub fn has_lag(&self) -> bool { + self.has_lag + } + + pub fn has_bioavailability(&self) -> bool { + self.has_bioavailability + } + + pub fn input_policy(&self) -> Option { + self.input_policy + } +} + +fn resolve_kind( + declared_kind: Option, + requested_kind: Option, +) -> Result { + match (declared_kind, requested_kind) { + (Some(declared), Some(requested)) if declared != requested => { + Err(ModelMetadataError::ModelKindConflict { + declared, + requested, + }) + } + (Some(declared), _) => Ok(declared), + (None, Some(requested)) => Ok(requested), + (None, None) => Err(ModelMetadataError::MissingModelKind), + } +} + +fn resolve_particles( + kind: ModelKind, + declared_particles: Option, + fallback_particles: Option, +) -> Result, ModelMetadataError> { + let particles = match (declared_particles, fallback_particles) { + (Some(declared), Some(fallback)) if declared != fallback => { + return Err(ModelMetadataError::ParticleCountConflict { declared, fallback }); + } + (Some(declared), _) => Some(declared), + (None, Some(fallback)) => Some(fallback), + (None, None) => None, + }; + + match kind { + ModelKind::Ode | ModelKind::Analytical if particles.is_some() => { + Err(ModelMetadataError::ParticlesNotAllowed { kind }) + } + ModelKind::Sde if particles.is_none() => Err(ModelMetadataError::MissingParticles), + _ => Ok(particles), + } +} + +fn validate_kind_specific_fields( + kind: ModelKind, + analytical: Option, + particles: Option, +) -> Result<(), ModelMetadataError> { + match kind { + ModelKind::Ode => { + if analytical.is_some() { + return Err(ModelMetadataError::AnalyticalKernelNotAllowed { kind }); + } + if particles.is_some() { + return Err(ModelMetadataError::ParticlesNotAllowed { kind }); + } + } + ModelKind::Analytical => { + if particles.is_some() { + return Err(ModelMetadataError::ParticlesNotAllowed { kind }); + } + } + ModelKind::Sde => { + if analytical.is_some() { + return Err(ModelMetadataError::AnalyticalKernelNotAllowed { kind }); + } + } + } + Ok(()) +} + +fn validate_unique_names( + values: &[T], + domain: NameDomain, + name_of: impl Fn(&T) -> &str, +) -> Result<(), ModelMetadataError> { + let mut names = std::collections::HashSet::with_capacity(values.len()); + for value in values { + let name = name_of(value); + if !names.insert(name) { + return Err(ModelMetadataError::DuplicateName { + domain, + name: name.to_string(), + }); + } + } + Ok(()) +} + +fn validate_routes( + routes: Vec, + states: &[State], +) -> Result<(Vec, usize), ModelMetadataError> { + let mut bolus_channels = 0; + let mut infusion_channels = 0; + let mut validated_routes = Vec::with_capacity(routes.len()); + + for (declaration_index, route) in routes.into_iter().enumerate() { + let channel_index = match route.kind { + RouteKind::Bolus => { + let index = bolus_channels; + bolus_channels += 1; + index + } + RouteKind::Infusion => { + let index = infusion_channels; + infusion_channels += 1; + index + } + }; + + validated_routes.push(validate_route( + route, + declaration_index, + channel_index, + states, + )?); + } + + Ok((validated_routes, bolus_channels.max(infusion_channels))) +} + +fn validate_route( + route: Route, + declaration_index: usize, + channel_index: usize, + states: &[State], +) -> Result { + if route.kind == RouteKind::Infusion && route.has_lag { + return Err(ModelMetadataError::InfusionLagNotAllowed { + route: route.name.clone(), + }); + } + + if route.kind == RouteKind::Infusion && route.has_bioavailability { + return Err(ModelMetadataError::InfusionBioavailabilityNotAllowed { + route: route.name.clone(), + }); + } + + let destination = + route + .destination + .clone() + .ok_or_else(|| ModelMetadataError::MissingRouteDestination { + route: route.name.clone(), + })?; + let destination_index = states + .iter() + .position(|state| state.name() == destination) + .ok_or_else(|| ModelMetadataError::UnknownRouteDestination { + route: route.name.clone(), + destination: destination.clone(), + })?; + + Ok(ValidatedRoute { + name: route.name, + kind: route.kind, + declaration_index, + channel_index, + destination, + destination_index, + has_lag: route.has_lag, + has_bioavailability: route.has_bioavailability, + input_policy: route.input_policy, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn builds_ode_metadata_shape() { + let metadata = new("bimodal_ke") + .kind(ModelKind::Ode) + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")); + + assert_eq!(metadata.name(), "bimodal_ke"); + assert_eq!(metadata.kind_decl(), Some(ModelKind::Ode)); + assert_eq!(metadata.parameters_decl()[0].name(), "ke"); + assert_eq!(metadata.parameters_decl()[1].name(), "v"); + assert_eq!(metadata.states_decl()[0].name(), "central"); + assert_eq!(metadata.outputs_decl()[0].name(), "cp"); + assert_eq!(metadata.routes_decl()[0].name(), "iv"); + assert_eq!(metadata.routes_decl()[0].kind(), RouteKind::Infusion); + assert_eq!(metadata.routes_decl()[0].destination(), Some("central")); + } + + #[test] + fn builds_analytical_metadata_shape() { + let metadata = new("one_cmt_abs") + .kind(ModelKind::Analytical) + .parameters(["ka", "ke", "v"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route(Route::bolus("oral").to_state("gut").with_bioavailability()) + .route(Route::infusion("iv").to_state("central")) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption); + + assert_eq!(metadata.kind_decl(), Some(ModelKind::Analytical)); + assert_eq!(metadata.states_decl()[0].name(), "gut"); + assert_eq!(metadata.states_decl()[1].name(), "central"); + assert_eq!(metadata.routes_decl()[0].kind(), RouteKind::Bolus); + assert!(metadata.routes_decl()[0].has_bioavailability()); + assert_eq!( + metadata.analytical_kernel_decl(), + Some(AnalyticalKernel::OneCompartmentWithAbsorption) + ); + } + + #[test] + fn builds_sde_metadata_shape() { + let metadata = new("one_cmt_sde") + .kind(ModelKind::Sde) + .parameters(["ke", "sigma", "v"]) + .covariates([Covariate::continuous("wt"), Covariate::locf("age")]) + .states(["central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(128); + + assert_eq!(metadata.kind_decl(), Some(ModelKind::Sde)); + assert_eq!(metadata.covariates_decl()[0].name(), "wt"); + assert_eq!( + metadata.covariates_decl()[0].interpolation(), + Some(CovariateInterpolation::Linear) + ); + assert_eq!(metadata.covariates_decl()[1].name(), "age"); + assert_eq!( + metadata.covariates_decl()[1].interpolation(), + Some(CovariateInterpolation::Locf) + ); + assert_eq!(metadata.particles_decl(), Some(128)); + assert_eq!( + metadata.routes_decl()[0].input_policy(), + Some(RouteInputPolicy::InjectToDestination) + ); + } + + #[test] + fn validates_metadata_and_exposes_lookup_helpers() { + let metadata = new("bimodal_ke") + .kind(ModelKind::Ode) + .parameters(["ke", "v"]) + .covariates([Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect("metadata should validate"); + + assert_eq!(metadata.parameter_index("ke"), Some(0)); + assert_eq!(metadata.parameter_index("v"), Some(1)); + assert_eq!(metadata.covariate_index("wt"), Some(0)); + assert_eq!(metadata.state_index("central"), Some(0)); + assert_eq!(metadata.route_index("iv"), Some(0)); + assert_eq!(metadata.route_declaration_index("iv"), Some(0)); + assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.output_index("cp"), Some(0)); + assert_eq!( + metadata.route("iv").expect("route exists").destination(), + "central" + ); + assert_eq!( + metadata + .route("iv") + .expect("route exists") + .declaration_index(), + 0 + ); + assert_eq!( + metadata.route("iv").expect("route exists").channel_index(), + 0 + ); + assert_eq!( + metadata + .route("iv") + .expect("route exists") + .destination_index(), + 0 + ); + } + + #[test] + fn duplicate_names_fail_validation() { + let error = new("dup_params") + .kind(ModelKind::Ode) + .parameters(["ke", "ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("duplicate parameters must fail"); + + assert_eq!( + error, + ModelMetadataError::DuplicateName { + domain: NameDomain::Parameter, + name: "ke".to_string(), + } + ); + } + + #[test] + fn missing_route_destination_fails_validation() { + let error = new("missing_route_destination") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv")) + .validate() + .expect_err("route destination is required"); + + assert_eq!( + error, + ModelMetadataError::MissingRouteDestination { + route: "iv".to_string(), + } + ); + } + + #[test] + fn unknown_route_destination_fails_validation() { + let error = new("unknown_route_destination") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("peripheral")) + .validate() + .expect_err("unknown destinations must fail"); + + assert_eq!( + error, + ModelMetadataError::UnknownRouteDestination { + route: "iv".to_string(), + destination: "peripheral".to_string(), + } + ); + } + + #[test] + fn shared_channel_routes_preserve_declaration_and_channel_identity() { + let metadata = new("shared_channel") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + Route::bolus("oral").to_state("gut"), + Route::infusion("iv").to_state("central"), + ]) + .validate() + .expect("shared-channel metadata should validate"); + + assert_eq!(metadata.routes().len(), 2); + assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_index("oral"), Some(0)); + assert_eq!(metadata.route_index("iv"), Some(0)); + assert_eq!(metadata.route_declaration_index("oral"), Some(0)); + assert_eq!(metadata.route_declaration_index("iv"), Some(1)); + assert_eq!( + metadata.route("oral").expect("oral route").channel_index(), + 0 + ); + assert_eq!(metadata.route("iv").expect("iv route").channel_index(), 0); + assert_eq!( + metadata + .route("oral") + .expect("oral route") + .declaration_index(), + 0 + ); + assert_eq!( + metadata.route("iv").expect("iv route").declaration_index(), + 1 + ); + } + + #[test] + fn infusion_routes_reject_lag_and_bioavailability() { + let lag_error = new("infusion_lag") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central").with_lag()) + .validate() + .expect_err("infusion lag must fail"); + + assert_eq!( + lag_error, + ModelMetadataError::InfusionLagNotAllowed { + route: "iv".to_string(), + } + ); + + let fa_error = new("infusion_fa") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .with_bioavailability(), + ) + .validate() + .expect_err("infusion bioavailability must fail"); + + assert_eq!( + fa_error, + ModelMetadataError::InfusionBioavailabilityNotAllowed { + route: "iv".to_string(), + } + ); + } + + #[test] + fn validate_requires_or_accepts_a_kind() { + let error = new("kind_required") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("kindless metadata needs explicit validation kind"); + + assert_eq!(error, ModelMetadataError::MissingModelKind); + + let validated = new("kind_override") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for(ModelKind::Ode) + .expect("caller-provided kind should validate"); + + assert_eq!(validated.kind(), ModelKind::Ode); + } + + #[test] + fn conflicting_kinds_fail_validation() { + let error = new("kind_conflict") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for(ModelKind::Sde) + .expect_err("conflicting kinds must fail"); + + assert_eq!( + error, + ModelMetadataError::ModelKindConflict { + declared: ModelKind::Ode, + requested: ModelKind::Sde, + } + ); + } + + #[test] + fn particles_are_rejected_for_ode_and_analytical() { + let ode_error = new("ode_particles") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate() + .expect_err("ODE metadata cannot declare particles"); + + assert_eq!( + ode_error, + ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Ode, + } + ); + + let analytical_error = new("analytical_particles") + .kind(ModelKind::Analytical) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate() + .expect_err("Analytical metadata cannot declare particles"); + + assert_eq!( + analytical_error, + ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Analytical, + } + ); + } + + #[test] + fn analytical_kernel_is_limited_to_analytical_models() { + let error = new("ode_kernel") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .analytical_kernel(AnalyticalKernel::OneCompartment) + .validate() + .expect_err("ODE metadata cannot declare an analytical kernel"); + + assert_eq!( + error, + ModelMetadataError::AnalyticalKernelNotAllowed { + kind: ModelKind::Ode, + } + ); + } + + #[test] + fn sde_requires_particles_or_a_fallback_count() { + let error = new("sde_missing_particles") + .kind(ModelKind::Sde) + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("SDE metadata requires particles"); + + assert_eq!(error, ModelMetadataError::MissingParticles); + + let validated = new("sde_fallback_particles") + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for_with_particles(ModelKind::Sde, 128) + .expect("fallback particle count should satisfy SDE validation"); + + assert_eq!(validated.kind(), ModelKind::Sde); + assert_eq!(validated.particles(), Some(128)); + } + + #[test] + fn conflicting_particle_counts_fail_validation() { + let error = new("sde_particle_conflict") + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate_for_with_particles(ModelKind::Sde, 128) + .expect_err("mismatched particle counts must fail"); + + assert_eq!( + error, + ModelMetadataError::ParticleCountConflict { + declared: 64, + fallback: 128, + } + ); + } +} diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 39cd741f..60cb2d8f 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -1,11 +1,12 @@ use std::fmt::Debug; pub mod analytical; -pub mod meta; +pub mod metadata; pub mod ode; pub mod sde; pub use analytical::*; -pub use meta::*; +pub use metadata::*; pub use ode::*; +pub use pharmsol_dsl::{AnalyticalKernel, ModelKind}; pub use sde::*; use crate::{ diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 17b04235..cafe6a96 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -27,8 +27,13 @@ use diffsol::{ OdeSolverStopReason, Vector, VectorHost, }; use nalgebra::DVector; +use pharmsol_dsl::ModelKind; +use thiserror::Error; -use super::{Equation, EquationPriv, EquationTypes, State}; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, State, + ValidatedModelMetadata, +}; const RTOL: f64 = 1e-4; const ATOL: f64 = 1e-4; @@ -76,6 +81,20 @@ pub enum ExplicitRkTableau { Tsit45, } +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum OdeMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("ODE declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error( + "ODE declares {declared} route metadata entries but model has {expected} input channels" + )] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("ODE declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} + #[derive(Clone, Debug)] pub struct ODE { diffeq: DiffEq, @@ -87,6 +106,7 @@ pub struct ODE { solver: OdeSolver, rtol: f64, atol: f64, + metadata: Option, cache: Option, } @@ -102,6 +122,7 @@ impl ODE { solver: OdeSolver::default(), rtol: RTOL, atol: ATOL, + metadata: None, cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -109,18 +130,21 @@ impl ODE { /// Set the number of state variables (ODE compartments). pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } /// Set the number of drug input channels (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } @@ -136,6 +160,74 @@ impl ODE { self.atol = atol; self } + + /// Attach validated handwritten-model metadata to this ODE. + pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { + let metadata = metadata.validate_for(ModelKind::Ode)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + /// Access the validated metadata attached to this ODE, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.metadata()?.route_index(name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.metadata()?.output_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), OdeMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(OdeMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_channel_count(); + if declared_routes != neqs.ndrugs { + return Err(OdeMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(OdeMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for ODE { @@ -280,7 +372,7 @@ impl EquationPriv for ODE { impl ODE { /// Generic event-loop runner, parameterized over the concrete solver type. #[allow(clippy::too_many_arguments)] - fn run_events<'a, S: OdeSolverMethod<'a, PMProblem<'a, DiffEq>>>( + fn run_events<'a, F, S>( &self, solver: &mut S, events: &[Event], @@ -295,7 +387,11 @@ impl ODE { y_out: &mut V, likelihood: &mut Vec, output: &mut SubjectPredictions, - ) -> Result<(), PharmsolError> { + ) -> Result<(), PharmsolError> + where + F: Fn(&V, &V, f64, &mut V, &V, &V, &Covariates) + 'a, + S: OdeSolverMethod<'a, PMProblem<'a, F>>, + { for (index, event) in events.iter().enumerate() { let next_event = events.get(index + 1); @@ -420,8 +516,8 @@ impl Equation for ODE { ypred.log_likelihood(error_models) } - fn kind() -> crate::EqnKind { - crate::EqnKind::ODE + fn kind() -> EqnKind { + EqnKind::ODE } fn simulate_subject( @@ -467,7 +563,9 @@ impl Equation for ODE { .h0(1e-3) .p(support_point.to_vec()) .build_from_eqn(PMProblem::with_params_v( - self.diffeq, + move |x, p, t, dx, bolus, rateiv, cov| { + (self.diffeq)(x, p, t, dx, bolus, rateiv, cov); + }, nstates, ndrugs, support_point.to_vec(), @@ -560,3 +658,235 @@ impl Equation for ODE { Ok((output, ll)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{fa, lag, Subject, SubjectBuilderExt}; + use approx::assert_relative_eq; + + fn simple_ode() -> ODE { + ODE::new( + |_x, _p, _t, _dx, _b, _rateiv, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, _y| {}, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + fn route_policy_subject() -> Subject { + Subject::builder("route_policy") + .bolus(0.0, 100.0, 0) + .infusion(0.0, 100.0, 0, 1.0) + .observation(1.0, 0.0, 0) + .build() + } + + fn explicit_route_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + b: &V, + rateiv: &V, + _cov: &Covariates, + ) { + dx[0] = b[0] + rateiv[0]; + } + + fn injected_route_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + _b: &V, + _rateiv: &V, + _cov: &Covariates, + ) { + dx[0] = 0.0; + } + + fn zero_lag(_p: &V, _t: f64, _cov: &Covariates) -> std::collections::HashMap { + std::collections::HashMap::new() + } + + fn unit_fa(_p: &V, _t: f64, _cov: &Covariates) -> std::collections::HashMap { + std::collections::HashMap::new() + } + + fn zero_init(_p: &V, _t: f64, _cov: &Covariates, _x: &mut V) {} + + fn state_output(x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V) { + y[0] = x[0]; + } + + #[test] + fn handwritten_ode_metadata_exposes_name_lookup() { + let ode = simple_ode() + .with_metadata( + super::super::metadata::new("bimodal_ke") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate"); + + assert_eq!(ode.parameter_index("ke"), Some(0)); + assert_eq!(ode.parameter_index("v"), Some(1)); + assert_eq!(ode.state_index("central"), Some(0)); + assert_eq!(ode.route_index("iv"), Some(0)); + assert_eq!(ode.output_index("cp"), Some(0)); + assert_eq!( + ode.metadata().expect("metadata exists").kind(), + ModelKind::Ode + ); + } + + #[test] + fn handwritten_ode_without_metadata_keeps_raw_path() { + let ode = simple_ode(); + + assert!(ode.metadata().is_none()); + assert_eq!(ode.state_index("central"), None); + assert_eq!(ode.route_index("iv"), None); + assert_eq!(ode.output_index("cp"), None); + } + + #[test] + fn handwritten_ode_rejects_dimension_mismatches() { + let error = simple_ode() + .with_metadata( + super::super::metadata::new("wrong_outputs") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp", "auc"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect_err("output-count mismatches must fail"); + + assert_eq!( + error, + OdeMetadataError::OutputCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_ode_rejects_invalid_metadata() { + let error = simple_ode() + .with_metadata( + super::super::metadata::new("missing_destination") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv")), + ) + .expect_err("invalid metadata must fail during attachment"); + + assert_eq!( + error, + OdeMetadataError::Validation(ModelMetadataError::MissingRouteDestination { + route: "iv".to_string(), + }) + ); + } + + #[test] + fn handwritten_ode_defaults_to_explicit_route_vectors() { + let ode = ODE::new( + explicit_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("explicit_routes") + .states(["central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral").to_state("central"), + super::super::Route::infusion("iv").to_state("central"), + ]), + ) + .expect("metadata attachment should validate"); + + let predictions = ode + .simulate_subject(&route_policy_subject(), &[], None) + .expect("simulation should succeed") + .0; + + assert_eq!(ode.route_index("oral").expect("oral route"), 0); + assert_eq!(ode.route_index("iv").expect("iv route"), 0); + assert_relative_eq!( + predictions.predictions()[0].prediction(), + 200.0, + epsilon = 1e-6 + ); + } + + #[test] + fn handwritten_ode_metadata_input_policy_is_descriptive_only() { + let ode = ODE::new( + injected_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("injected_routes") + .states(["central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + super::super::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("metadata attachment should validate"); + + let predictions = ode + .simulate_subject(&route_policy_subject(), &[], None) + .expect("simulation should succeed") + .0; + + assert_relative_eq!( + predictions.predictions()[0].prediction(), + 0.0, + epsilon = 1e-6 + ); + } + + #[test] + fn changing_dimensions_after_metadata_clears_route_metadata() { + let ode = simple_ode() + .with_metadata( + super::super::metadata::new("bimodal_ke") + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate") + .with_ndrugs(2); + + assert!(ode.metadata().is_none()); + assert_eq!(ode.route_index("iv"), None); + } +} diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index af8ea246..bdafbbc3 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -3,8 +3,10 @@ mod em; use diffsol::{NalgebraContext, Vector}; use nalgebra::DVector; use ndarray::{concatenate, Array2, Axis}; +use pharmsol_dsl::ModelKind; use rand::{rng, RngExt}; use rayon::prelude::*; +use thiserror::Error; use crate::{ data::{Covariates, Infusion}, @@ -21,7 +23,59 @@ use diffsol::VectorCommon; use crate::PharmsolError; -use super::{Equation, EquationPriv, EquationTypes, Predictions, State}; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, Predictions, + State, ValidatedModelMetadata, +}; + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum SdeMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("SDE declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error( + "SDE declares {declared} route metadata entries but model has {expected} input channels" + )] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("SDE declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} + +#[derive(Clone, Debug, Default)] +struct InjectedBolusMappings { + destinations: Vec>, +} + +impl InjectedBolusMappings { + fn explicit(ndrugs: usize) -> Self { + Self { + destinations: vec![None; ndrugs], + } + } + + fn from_destinations(ndrugs: usize, destinations: &[Option]) -> Self { + let mut mappings = Self::explicit(ndrugs); + for (input, destination) in destinations.iter().copied().take(ndrugs).enumerate() { + mappings.destinations[input] = destination; + } + mappings + } + + fn invalidate_for_ndrugs(&mut self, ndrugs: usize) { + *self = Self::explicit(ndrugs); + } + + fn apply(&self, state: &mut [DVector], input: usize, amount: f64) -> bool { + let Some(destination) = self.destinations.get(input).copied().flatten() else { + return false; + }; + state.par_iter_mut().for_each(|particle| { + particle[destination] += amount; + }); + true + } +} /// Simulate a stochastic differential equation (SDE) event. /// @@ -44,7 +98,7 @@ use super::{Equation, EquationPriv, EquationTypes, Predictions, State}; /// The state vector at time `tf` after simulation. #[inline(always)] #[allow(clippy::too_many_arguments)] -pub(crate) fn simulate_sde_event( +fn simulate_sde_event( drift: &Drift, difussion: &Diffusion, x: V, @@ -133,6 +187,8 @@ pub struct SDE { out: Out, neqs: Neqs, nparticles: usize, + metadata: Option, + injected_bolus_mappings: InjectedBolusMappings, cache: Option, } @@ -164,6 +220,8 @@ impl SDE { out, neqs: Neqs::default(), nparticles, + metadata: None, + injected_bolus_mappings: InjectedBolusMappings::default(), cache: Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)), } } @@ -171,20 +229,100 @@ impl SDE { /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } /// Set the number of drug input channels (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } + + /// Attach validated handwritten-model metadata to this SDE model. + pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { + let metadata = metadata.validate_for_with_particles(ModelKind::Sde, self.nparticles)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + #[doc(hidden)] + pub fn with_injected_bolus_inputs(mut self, destinations: &[Option]) -> Self { + self.injected_bolus_mappings = + InjectedBolusMappings::from_destinations(self.neqs.ndrugs, destinations); + self + } + + /// Access the validated metadata attached to this SDE model, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.metadata()?.route_index(name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.metadata()?.output_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + self.injected_bolus_mappings + .invalidate_for_ndrugs(self.neqs.ndrugs); + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), SdeMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(SdeMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_channel_count(); + if declared_routes != neqs.ndrugs { + return Err(SdeMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(SdeMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for SDE { @@ -435,6 +573,63 @@ impl EquationPriv for SDE { } x } + + fn simulate_event( + &self, + support_point: &[f64], + event: &crate::Event, + next_event: Option<&crate::Event>, + error_models: Option<&AssayErrorModels>, + covariates: &Covariates, + x: &mut Self::S, + infusions: &mut Vec, + likelihood: &mut Vec, + output: &mut Self::P, + ) -> Result<(), PharmsolError> { + match event { + crate::Event::Bolus(bolus) => { + if bolus.input() >= self.get_ndrugs() { + return Err(PharmsolError::InputOutOfRange { + input: bolus.input(), + ndrugs: self.get_ndrugs(), + }); + } + if !self + .injected_bolus_mappings + .apply(x, bolus.input(), bolus.amount()) + { + x.add_bolus(bolus.input(), bolus.amount()); + } + } + crate::Event::Infusion(infusion) => { + infusions.push(infusion.clone()); + } + crate::Event::Observation(observation) => { + self.process_observation( + support_point, + observation, + error_models, + event.time(), + covariates, + x, + likelihood, + output, + )?; + } + } + + if let Some(next_event) = next_event { + self.solve( + x, + support_point, + covariates, + infusions, + event.time(), + next_event.time(), + )?; + } + Ok(()) + } } impl Equation for SDE { @@ -475,8 +670,8 @@ impl Equation for SDE { } } - fn kind() -> crate::EqnKind { - crate::EqnKind::SDE + fn kind() -> EqnKind { + EqnKind::SDE } } @@ -533,3 +728,276 @@ fn sysresample(q: &[f64]) -> Vec { } i } + +#[cfg(test)] +mod tests { + use super::*; + use crate::simulator::equation::{self, Covariate, Route}; + use crate::SubjectBuilderExt; + use crate::{fa, fetch_params, lag}; + + fn simple_sde() -> SDE { + let drift = |x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx[0] = rateiv[0] - x[0]; + }; + let diffusion = |_p: &V, g: &mut V| { + g[0] = 1.0; + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x[0] = 0.0; + }; + let out = |x: &V, p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }; + + SDE::new(drift, diffusion, lag, fa, init, out, 128) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + fn route_policy_sde(drift: Drift) -> SDE { + let diffusion = |_p: &V, sigma: &mut V| { + sigma.fill(0.0); + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + SDE::new(drift, diffusion, lag, fa, init, out, 16) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + } + + #[test] + fn handwritten_sde_metadata_exposes_name_lookup_and_particles() { + let sde = simple_sde() + .with_metadata( + equation::metadata::new("one_cmt_sde") + .parameters(["ke", "v"]) + .covariates([Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect("SDE metadata attachment should validate"); + + let metadata = sde.metadata().expect("metadata exists"); + assert_eq!(metadata.kind(), ModelKind::Sde); + assert_eq!(metadata.particles(), Some(128)); + assert_eq!(sde.parameter_index("ke"), Some(0)); + assert_eq!(sde.parameter_index("v"), Some(1)); + assert_eq!(sde.covariate_index("wt"), Some(0)); + assert_eq!(sde.state_index("central"), Some(0)); + assert_eq!(sde.route_index("iv"), Some(0)); + assert_eq!(sde.output_index("cp"), Some(0)); + } + + #[test] + fn handwritten_sde_without_metadata_keeps_raw_path() { + let sde = simple_sde(); + + assert!(sde.metadata().is_none()); + assert_eq!(sde.parameter_index("ke"), None); + assert_eq!(sde.route_index("iv"), None); + assert_eq!(sde.output_index("cp"), None); + } + + #[test] + fn handwritten_sde_rejects_dimension_mismatches() { + let error = simple_sde() + .with_metadata( + equation::metadata::new("bad_sde") + .parameters(["ke", "v"]) + .states(["central", "peripheral"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect_err("mismatched state metadata must fail"); + + assert_eq!( + error, + SdeMetadataError::StateCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_sde_rejects_particle_mismatch() { + let error = simple_sde() + .with_metadata( + equation::metadata::new("particle_conflict") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64), + ) + .expect_err("mismatched SDE particles must fail"); + + assert_eq!( + error, + SdeMetadataError::Validation(ModelMetadataError::ParticleCountConflict { + declared: 64, + fallback: 128, + }) + ); + } + + #[test] + fn changing_dimensions_after_metadata_clears_sde_metadata() { + let sde = simple_sde() + .with_metadata( + equation::metadata::new("one_cmt_sde") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect("metadata attachment should validate") + .with_nout(2); + + assert!(sde.metadata().is_none()); + assert_eq!(sde.route_index("iv"), None); + assert_eq!(sde.output_index("cp"), None); + } + + #[test] + fn sde_metadata_input_policy_is_descriptive_only_for_bolus_routes() { + let zero_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, _rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + }; + + let explicit = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("explicit_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(Route::bolus("oral").to_state("central")) + .particles(16), + ) + .expect("explicit metadata should validate"); + + let injected = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("injected_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate"); + + let subject = Subject::builder("bolus_route") + .bolus(0.0, 100.0, 0) + .missing_observation(0.1, 0) + .build(); + + let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); + let injected_predictions = injected.estimate_predictions(&subject, &[0.0]).unwrap(); + + assert_eq!(explicit_predictions[[0, 0]].prediction(), 0.0); + assert_eq!(injected_predictions[[0, 0]].prediction(), 0.0); + } + + #[test] + fn sde_metadata_input_policy_does_not_change_explicit_rateiv_behavior() { + let rateiv_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + dx[1] = rateiv[0]; + }; + + let explicit = route_policy_sde(rateiv_drift) + .with_metadata( + equation::metadata::new("explicit_infusion") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(16), + ) + .expect("explicit metadata should validate"); + + let injected = route_policy_sde(rateiv_drift) + .with_metadata( + equation::metadata::new("injected_infusion") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate"); + + let subject = Subject::builder("infusion_route") + .infusion(0.0, 100.0, 0, 1.0) + .missing_observation(1.0, 0) + .build(); + + let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); + let injected_predictions = injected.estimate_predictions(&subject, &[0.0]).unwrap(); + + let explicit_prediction = explicit_predictions[[0, 0]].prediction(); + let injected_prediction = injected_predictions[[0, 0]].prediction(); + + assert!(explicit_prediction > 0.0); + assert!((injected_prediction - explicit_prediction).abs() < 1e-8); + } + + #[test] + fn clearing_sde_metadata_preserves_raw_bolus_behavior() { + let zero_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, _rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + }; + + let sde = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("injected_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate") + .with_nout(1); + + let subject = Subject::builder("bolus_route") + .bolus(0.0, 100.0, 0) + .missing_observation(0.1, 0) + .build(); + + let predictions = sde.estimate_predictions(&subject, &[0.0]).unwrap(); + + assert!(sde.metadata().is_none()); + assert_eq!(predictions[[0, 0]].prediction(), 0.0); + } +} diff --git a/src/test_fixtures.rs b/src/test_fixtures.rs index 7fb1610e..91d21e90 100644 --- a/src/test_fixtures.rs +++ b/src/test_fixtures.rs @@ -83,7 +83,7 @@ model one_cmt_abs { oral -> depot } analytical { - kernel = one_compartment_with_absorption + structure = one_compartment_with_absorption } outputs { cp = central / v diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs new file mode 100644 index 00000000..0842f9c2 --- /dev/null +++ b/tests/analytical_macro_lowering.rs @@ -0,0 +1,299 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; + +fn infusion_subject(input: usize) -> Subject { + Subject::builder("analytical-macro-iv") + .infusion(0.0, 120.0, input, 1.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn oral_subject(input: usize) -> Subject { + Subject::builder("analytical-macro-oral") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn shared_channel_subject(input: usize) -> Subject { + Subject::builder("analytical-macro-shared") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .build() +} + +fn macro_one_compartment() -> equation::Analytical { + analytical! { + name: "one_cpt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_one_compartment() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_iv") + .kind(equation::ModelKind::Analytical) + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(equation::Route::infusion("iv").to_state("central")) + .analytical_kernel(equation::AnalyticalKernel::OneCompartment), + ) + .expect("handwritten analytical metadata should validate") +} + +fn macro_one_compartment_with_absorption() -> equation::Analytical { + analytical! { + name: "one_cmt_abs", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + structure: one_compartment_with_absorption, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + init: |_p, _t, _cov, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_one_compartment_with_absorption() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + ) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten absorption metadata should validate") +} + +fn macro_shared_channel_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_shared", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_shared_channel_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_shared") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten shared-channel analytical metadata should validate") +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +#[test] +fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_one_compartment(); + let handwritten_model = handwritten_one_compartment(); + let subject = infusion_subject(0); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { + let macro_model = macro_one_compartment_with_absorption(); + let handwritten_model = handwritten_one_compartment_with_absorption(); + let subject = oral_subject(0); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!( + macro_model + .metadata() + .expect("macro metadata exists") + .analytical_kernel(), + Some(equation::AnalyticalKernel::OneCompartmentWithAbsorption) + ); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro absorption model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten absorption model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_channel_analytical(); + let handwritten_model = handwritten_shared_channel_analytical(); + let subject = shared_channel_subject(0); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(1)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro shared-channel analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-channel analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs new file mode 100644 index 00000000..43621e8a --- /dev/null +++ b/tests/authoring_parity_corpus.rs @@ -0,0 +1,1365 @@ +#[cfg(feature = "dsl-jit")] +use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; +#[cfg(feature = "dsl-jit")] +use pharmsol::equation::RouteInputPolicy; +use pharmsol::equation::{ + self, AnalyticalKernel, RouteKind as HandwrittenRouteKind, ValidatedModelMetadata, +}; +use pharmsol::prelude::*; +#[cfg(feature = "dsl-jit")] +use pharmsol::Predictions; +use pharmsol_dsl::{ + analyze_model, lower_typed_model, parse_model, CovariateInterpolation, ExecutionModel, + ModelKind, RouteKind as DslRouteKind, +}; + +const ODE_DSL: &str = r#" +name = one_cmt_oral_iv +kind = ode + +params = ka, cl, v, tlag, f_oral +covariates = wt @linear +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - (cl / v) * central + +out(cp) = central / (v * (wt / 70.0)) ~ continuous() +"#; + +const ODE_MACRO_DSL: &str = r#" +name = one_cmt_oral_covariate_parity +kind = ode + +params = ka, cl, v, tlag, f_oral +covariates = wt @linear +states = depot, central +outputs = cp + +bolus(oral) -> depot +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - (cl / v) * central + +out(cp) = central / (v * (wt / 70.0)) ~ continuous() +"#; + +const ODE_INVALID_INFUSION_LAG_DSL: &str = r#" +name = invalid_infusion_lag_parity +kind = ode + +params = ke, v, tlag +states = central +outputs = cp + +infusion(iv) -> central +lag(iv) = tlag + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +name = shared_channel_one_cpt +kind = ode + +params = ka, ke, v, tlag, f_oral +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + +const ANALYTICAL_DSL: &str = r#" +name = one_cmt_abs_parity +kind = analytical + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +structure = one_compartment_with_absorption + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +name = one_cmt_abs_shared +kind = analytical + +params = ka, ke, v, tlag, f_oral +states = gut, central +outputs = cp + +bolus(oral) -> gut +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral +structure = one_compartment_with_absorption + +out(cp) = central / v ~ continuous() +"#; + +const SDE_DSL: &str = r#" +name = one_cmt_sde_parity +kind = sde + +params = ka, ke, v, sigma +covariates = wt @locf +states = depot, central +outputs = cp + +bolus(oral) -> depot +particles = 256 + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central +noise(central) = sigma + +out(cp) = central / (v * wt) ~ continuous() +"#; + +const SDE_MACRO_DSL: &str = r#" +name = one_cmt_sde_macro_parity +kind = sde + +params = ka, ke, v, sigma +states = depot, central +outputs = cp + +bolus(oral) -> depot +particles = 256 + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central +noise(central) = sigma + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const SDE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +name = one_cmt_shared_sde +kind = sde + +params = ka, ke, sigma_ke, v, tlag, f_oral +states = gut, central +outputs = cp +particles = 8 + +bolus(oral) -> gut +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(gut) = -ka * gut +dx(central) = ka * gut - ke * central +noise(central) = sigma_ke + +out(cp) = central / v ~ continuous() +"#; + +#[derive(Clone, Debug, PartialEq, Eq)] +struct MetadataParityView { + name: String, + kind: ModelKind, + parameters: Vec, + covariates: Vec, + states: Vec, + route_channel_count: usize, + routes: Vec, + outputs: Vec, + analytical_kernel: Option, + particles: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct NamedIndex { + name: String, + index: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct CovariateParity { + name: String, + index: usize, + interpolation: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct RouteParity { + name: String, + kind: Option, + declaration_index: usize, + channel_index: usize, + destination_name: String, + destination_index: usize, + has_lag: bool, + has_bioavailability: bool, +} + +#[cfg(feature = "dsl-jit")] +#[derive(Clone, Debug, PartialEq, Eq)] +struct RouteInputPolicyParity { + name: String, + declaration_index: usize, + channel_index: usize, + input_policy: RouteInputPolicy, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum RouteKindParity { + Bolus, + Infusion, +} + +impl RouteKindParity { + fn from_dsl(kind: DslRouteKind) -> Self { + match kind { + DslRouteKind::Bolus => Self::Bolus, + DslRouteKind::Infusion => Self::Infusion, + } + } + + fn from_handwritten(kind: HandwrittenRouteKind) -> Self { + match kind { + HandwrittenRouteKind::Bolus => Self::Bolus, + HandwrittenRouteKind::Infusion => Self::Infusion, + } + } +} + +fn load_execution_model(src: &str) -> ExecutionModel { + let parsed = parse_model(src).expect("DSL model should parse"); + let typed = analyze_model(&parsed).expect("DSL model should analyze"); + lower_typed_model(&typed).expect("DSL model should lower") +} + +#[cfg(feature = "dsl-jit")] +fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntimeModel { + dsl::compile_module_source_to_runtime( + src, + Some(model_name), + RuntimeCompilationTarget::Jit, + |_, _| {}, + ) + .expect("DSL runtime model should compile") +} + +#[cfg(feature = "dsl-jit")] +fn shared_channel_prediction_subject(input: usize, output: usize) -> Subject { + Subject::builder("authoring-parity-shared-channel") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, output) + .missing_observation(1.0, output) + .missing_observation(2.0, output) + .missing_observation(6.5, output) + .missing_observation(7.0, output) + .missing_observation(8.0, output) + .build() +} + +fn dsl_metadata_view(src: &str) -> MetadataParityView { + let model = load_execution_model(src); + + let parameters = model + .metadata + .parameters + .iter() + .map(|parameter| NamedIndex { + name: parameter.name.clone(), + index: parameter.index, + }) + .collect(); + let covariates = model + .metadata + .covariates + .iter() + .map(|covariate| CovariateParity { + name: covariate.name.clone(), + index: covariate.index, + interpolation: covariate.interpolation, + }) + .collect(); + let states = model + .metadata + .states + .iter() + .map(|state| NamedIndex { + name: state.name.clone(), + index: state.offset, + }) + .collect(); + let outputs = model + .metadata + .outputs + .iter() + .map(|output| NamedIndex { + name: output.name.clone(), + index: output.index, + }) + .collect(); + let routes = model + .metadata + .routes + .iter() + .map(|route| RouteParity { + name: route.name.clone(), + kind: route.kind.map(RouteKindParity::from_dsl), + declaration_index: route.declaration_index, + channel_index: route.index, + destination_name: route.destination.state_name.clone(), + destination_index: route.destination.state_offset, + has_lag: route.has_lag, + has_bioavailability: route.has_bioavailability, + }) + .collect(); + + MetadataParityView { + name: model.name, + kind: model.kind, + parameters, + covariates, + states, + route_channel_count: model.abi.route_buffer.len, + routes, + outputs, + analytical_kernel: model.metadata.analytical, + particles: model.metadata.particles, + } +} + +#[cfg(feature = "dsl-jit")] +fn dsl_route_input_policy_view(src: &str) -> Vec { + let model = load_execution_model(src); + let info = dsl::NativeModelInfo::from_execution_model(&model); + + info.routes + .into_iter() + .map(|route| RouteInputPolicyParity { + name: route.name, + declaration_index: route.declaration_index, + channel_index: route.index, + input_policy: if route.inject_input_to_destination { + RouteInputPolicy::InjectToDestination + } else { + RouteInputPolicy::ExplicitInputVector + }, + }) + .collect() +} + +fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityView { + MetadataParityView { + name: metadata.name().to_string(), + kind: metadata.kind(), + parameters: metadata + .parameters() + .iter() + .enumerate() + .map(|(index, parameter)| NamedIndex { + name: parameter.name().to_string(), + index, + }) + .collect(), + covariates: metadata + .covariates() + .iter() + .enumerate() + .map(|(index, covariate)| CovariateParity { + name: covariate.name().to_string(), + index, + interpolation: covariate.interpolation(), + }) + .collect(), + states: metadata + .states() + .iter() + .enumerate() + .map(|(index, state)| NamedIndex { + name: state.name().to_string(), + index, + }) + .collect(), + route_channel_count: metadata.route_channel_count(), + routes: metadata + .routes() + .iter() + .map(|route| RouteParity { + name: route.name().to_string(), + kind: Some(RouteKindParity::from_handwritten(route.kind())), + declaration_index: route.declaration_index(), + channel_index: route.channel_index(), + destination_name: route.destination().to_string(), + destination_index: route.destination_index(), + has_lag: route.has_lag(), + has_bioavailability: route.has_bioavailability(), + }) + .collect(), + outputs: metadata + .outputs() + .iter() + .enumerate() + .map(|(index, output)| NamedIndex { + name: output.name().to_string(), + index, + }) + .collect(), + analytical_kernel: metadata.analytical_kernel(), + particles: metadata.particles(), + } +} + +#[cfg(feature = "dsl-jit")] +fn handwritten_route_input_policy_view( + metadata: &ValidatedModelMetadata, +) -> Vec { + metadata + .routes() + .iter() + .map(|route| RouteInputPolicyParity { + name: route.name().to_string(), + declaration_index: route.declaration_index(), + channel_index: route.channel_index(), + input_policy: route + .input_policy() + .expect("route input policy should be explicit in this handwritten fixture"), + }) + .collect() +} + +fn macro_ode_model() -> equation::ODE { + ode! { + name: "one_cmt_oral_covariate_parity", + params: [ka, cl, v, tlag, f_oral], + covariates: [wt], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - (cl / v) * x[central]; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, t, cov, y| { + fetch_cov!(cov, t, wt); + y[cp] = x[central] / (v * (wt / 70.0)); + }, + } +} + +fn handwritten_ode_macro_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_covariate_parity") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + ), + ) + .expect("handwritten macro-shape ODE metadata should validate") +} + +fn handwritten_ode_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_iv") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_macro_ode() -> equation::ODE { + ode! { + name: "shared_channel_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, bolus, rateiv, _cov| { + dx[depot] = bolus[oral] - ka * x[depot]; + dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_channel_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten shared-channel ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_mismatched_shared_channel_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_channel_one_cpt_mismatched") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("mismatched shared-channel ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_macro_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_shared", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_shared") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten shared-channel analytical metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_macro_sde() -> equation::SDE { + sde! { + name: "one_cmt_shared_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + drift: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[gut] = 0.0; + sigma[central] = 0.0 * sigma_ke; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + init: |_p, _t, _cov, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_handwritten_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_shared_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten shared-channel SDE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn assert_prediction_vectors_close(left: &[f64], right: &[f64], tolerance: f64) { + assert_eq!(left.len(), right.len()); + for (left_value, right_value) in left.iter().zip(right.iter()) { + let diff = (left_value - right_value).abs(); + assert!( + diff <= tolerance, + "prediction mismatch: left={left_value:.12}, right={right_value:.12}, diff={diff:.12}, tolerance={tolerance:.12}" + ); + } +} + +#[cfg(feature = "dsl-jit")] +fn assert_prediction_vectors_diverge(left: &[f64], right: &[f64], tolerance: f64) { + assert_eq!(left.len(), right.len()); + assert!( + left.iter() + .zip(right.iter()) + .any(|(left_value, right_value)| (left_value - right_value).abs() > tolerance), + "expected prediction vectors to diverge beyond tolerance {tolerance:.12}" + ); +} + +#[cfg(feature = "dsl-jit")] +fn particle_prediction_means(predictions: &ndarray::Array2) -> Vec { + predictions + .get_predictions() + .into_iter() + .map(|prediction| prediction.prediction()) + .collect() +} + +fn macro_analytical_model() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_parity", + params: [ka, ke, v], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + }, + structure: one_compartment_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_analytical_model() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_parity") + .kind(ModelKind::Analytical) + .parameters(["ka", "ke", "v"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(equation::Route::bolus("oral").to_state("depot")) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten analytical metadata should validate") +} + +fn macro_sde_model() -> equation::SDE { + sde! { + name: "one_cmt_sde_macro_parity", + params: [ka, ke, v, sigma], + states: [depot, central], + outputs: [cp], + particles: 256, + routes: { + bolus(oral) -> depot, + }, + drift: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + diffusion: |_p, sigma_values| { + sigma_values[depot] = 0.0; + sigma_values[central] = sigma; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_sde_model() -> equation::SDE { + equation::SDE::new( + |_x, _p, _t, dx, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + 256, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_parity") + .kind(ModelKind::Sde) + .parameters(["ka", "ke", "v", "sigma"]) + .covariates([equation::Covariate::locf("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination(), + ) + .particles(256), + ) + .expect("handwritten SDE metadata should validate") +} + +fn handwritten_sde_macro_model() -> equation::SDE { + equation::SDE::new( + |_x, _p, _t, dx, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, sigma_values| { + sigma_values[0] = 0.0; + sigma_values[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + 256, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_macro_parity") + .kind(ModelKind::Sde) + .parameters(["ka", "ke", "v", "sigma"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination(), + ) + .particles(256), + ) + .expect("handwritten macro-shape SDE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn mismatched_ode_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_iv") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .expect_explicit_input() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("mismatched ODE metadata should validate") +} + +#[test] +fn ode_dsl_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_ode_model(); + let dsl_view = dsl_metadata_view(ODE_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[test] +fn ode_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { + let handwritten = handwritten_ode_macro_model(); + let macro_model = macro_ode_model(); + let dsl_view = dsl_metadata_view(ODE_MACRO_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten macro-shape ODE metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[test] +fn analytical_dsl_macro_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_analytical_model(); + let macro_model = macro_analytical_model(); + let dsl_view = dsl_metadata_view(ANALYTICAL_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten analytical metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro analytical metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[test] +fn sde_dsl_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_sde_model(); + let dsl_view = dsl_metadata_view(SDE_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[test] +fn sde_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { + let handwritten = handwritten_sde_macro_model(); + let macro_model = macro_sde_model(); + let dsl_view = dsl_metadata_view(SDE_MACRO_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten macro-shape SDE metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_route_input_policies_agree_with_handwritten_metadata() { + let dsl_view = dsl_route_input_policy_view(ODE_DSL); + let handwritten = handwritten_ode_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("handwritten ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn sde_route_input_policies_agree_with_handwritten_metadata() { + let dsl_view = dsl_route_input_policy_view(SDE_DSL); + let handwritten = handwritten_sde_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("handwritten SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn route_input_policy_mismatches_are_detected_explicitly() { + let dsl_view = dsl_route_input_policy_view(ODE_DSL); + let handwritten = mismatched_ode_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("mismatched handwritten metadata should exist"), + ); + + assert_ne!(handwritten_view, dsl_view); + assert_eq!(dsl_view[0].name, "oral"); + assert_eq!( + dsl_view[0].input_policy, + RouteInputPolicy::InjectToDestination + ); + assert_eq!( + handwritten_view[0].input_policy, + RouteInputPolicy::ExplicitInputVector + ); +} + +#[test] +fn invalid_dsl_infusion_route_properties_fail_explicitly() { + let model = + parse_model(ODE_INVALID_INFUSION_LAG_DSL).expect("invalid DSL fixture should parse"); + let typed = analyze_model(&model).expect("invalid DSL fixture should analyze"); + let error = lower_typed_model(&typed) + .err() + .expect("infusion lag should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `lag` on infusion route `iv`")); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { + let runtime_model = + compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); + let macro_model = runtime_shared_channel_macro_ode(); + let handwritten_model = runtime_shared_channel_handwritten_ode(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(macro_model.route_index("oral"), Some(oral)); + assert_eq!(macro_model.route_index("iv"), Some(iv)); + assert_eq!(handwritten_model.route_index("oral"), Some(oral)); + assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime ODE model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro ODE model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten ODE model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-4); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-4); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { + let runtime_model = + compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_abs_shared"); + let macro_model = runtime_shared_channel_macro_analytical(); + let handwritten_model = runtime_shared_channel_handwritten_analytical(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(macro_model.route_index("oral"), Some(oral)); + assert_eq!(macro_model.route_index("iv"), Some(iv)); + assert_eq!(handwritten_model.route_index("oral"), Some(oral)); + assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime analytical model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => { + panic!("analytical runtime should return subject predictions") + } + }; + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-8); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-8); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { + let runtime_model = + compile_runtime_jit_model(SDE_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_shared_sde"); + let macro_model = runtime_shared_channel_macro_sde(); + let handwritten_model = runtime_shared_channel_handwritten_sde(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(macro_model.route_index("oral"), Some(oral)); + assert_eq!(macro_model.route_index("iv"), Some(iv)); + assert_eq!(handwritten_model.route_index("oral"), Some(oral)); + assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime SDE model should simulate") + { + RuntimePredictions::Particles(predictions) => particle_prediction_means(&predictions), + RuntimePredictions::Subject(_) => panic!("SDE runtime should return particle predictions"), + }; + let macro_predictions = particle_prediction_means( + ¯o_model + .estimate_predictions(&subject, &support_point) + .expect("macro SDE model should simulate"), + ); + let handwritten_predictions = particle_prediction_means( + &handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten SDE model should simulate"), + ); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-4); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-4); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn route_input_policy_runtime_mismatches_are_detected_explicitly() { + let runtime_model = + compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); + let mismatched_model = runtime_mismatched_shared_channel_ode(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(mismatched_model.route_index("oral"), Some(oral)); + assert_eq!(mismatched_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime ODE model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + let mismatched_predictions = mismatched_model + .estimate_predictions(&subject, &support_point) + .expect("mismatched handwritten ODE should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_diverge(&runtime_predictions, &mismatched_predictions, 1e-4); +} diff --git a/tests/browser-e2e/site/app.mjs b/tests/browser-e2e/site/app.mjs index 87b68eda..c939c670 100644 --- a/tests/browser-e2e/site/app.mjs +++ b/tests/browser-e2e/site/app.mjs @@ -4,7 +4,7 @@ const precompiledInputs = Object.freeze({ }); const compileFlowSource = ` -model = example_ode +name = example_ode kind = ode params = ke, v @@ -25,7 +25,7 @@ const compileFlowInputs = Object.freeze({ }); const invalidCompileSource = ` -model = broken +name = broken kind = ode states = central diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs new file mode 100644 index 00000000..45636bb1 --- /dev/null +++ b/tests/ode_macro_lowering.rs @@ -0,0 +1,376 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; + +fn subject_for_route(input: usize) -> Subject { + Subject::builder("macro-lowering") + .infusion(0.0, 100.0, input, 1.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn subject_for_shared_channel(input: usize) -> Subject { + Subject::builder("macro-shared-channel") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .build() +} + + fn subject_for_covariates(input: usize) -> Subject { + Subject::builder("macro-covariates") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .covariate("wt", 0.0, 70.0) + .build() + } + +fn injected_macro_ode() -> equation::ODE { + ode! { + name: "injected_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn injected_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("injected_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten injected metadata should validate") +} + +fn explicit_macro_ode() -> equation::ODE { + ode! { + name: "explicit_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _bolus, rateiv, _cov| { + dx[central] = rateiv[iv] - ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn explicit_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("explicit_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ), + ) + .expect("handwritten explicit metadata should validate") +} + +fn shared_channel_macro_ode() -> equation::ODE { + ode! { + name: "shared_channel_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, bolus, rateiv, _cov| { + dx[depot] = bolus[oral] - ka * x[depot]; + dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn shared_channel_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_channel_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten shared-channel metadata should validate") +} + +fn covariate_macro_ode() -> equation::ODE { + ode! { + name: "covariate_one_cpt", + params: [ka, ke, v], + covariates: [wt], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _p, t, dx, cov| { + fetch_cov!(cov, t, wt); + let scaled_ke = ke * (wt / 70.0); + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - scaled_ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn covariate_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, t, dx, bolus, _rateiv, cov| { + fetch_cov!(cov, t, wt); + fetch_params!(p, ka, ke, _v); + let scaled_ke = ke * (wt / 70.0); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] - scaled_ke * x[1]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("covariate_one_cpt") + .parameters(["ka", "ke", "v"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination(), + ), + ) + .expect("handwritten covariate metadata should validate") +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +#[test] +fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = injected_macro_ode(); + let handwritten_ode = injected_handwritten_ode(); + let subject = subject_for_route(0); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("iv"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro injected model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten injected model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = explicit_macro_ode(); + let handwritten_ode = explicit_handwritten_ode(); + let subject = subject_for_route(0); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("iv"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro explicit model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten explicit model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = shared_channel_macro_ode(); + let handwritten_ode = shared_channel_handwritten_ode(); + let subject = subject_for_shared_channel(0); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("oral"), Some(0)); + assert_eq!(macro_ode.route_index("iv"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("depot"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro shared-channel model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-channel model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_covariate_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = covariate_macro_ode(); + let handwritten_ode = covariate_handwritten_ode(); + let subject = subject_for_covariates(0); + let support_point = [1.0, 0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro covariate model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_metadata.covariates().len(), 1); + assert_eq!(macro_ode.route_index("oral"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("gut"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro covariate model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs new file mode 100644 index 00000000..05fd4c63 --- /dev/null +++ b/tests/sde_macro_lowering.rs @@ -0,0 +1,359 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; +use pharmsol::Predictions; + +fn infusion_subject(input: usize) -> Subject { + Subject::builder("sde-macro-iv") + .infusion(0.0, 120.0, input, 1.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn oral_subject(input: usize) -> Subject { + Subject::builder("sde-macro-oral") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + + fn shared_channel_subject(input: usize) -> Subject { + Subject::builder("sde-macro-shared") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .build() + } + +fn prediction_means(predictions: &ndarray::Array2) -> Vec { + predictions + .get_predictions() + .into_iter() + .map(|prediction| prediction.prediction()) + .collect() +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +fn macro_infusion_sde() -> equation::SDE { + sde! { + name: "one_cpt_sde", + params: [ke, sigma_ke, v], + states: [central], + outputs: [cp], + particles: 16, + routes: { + infusion(iv) -> central, + }, + drift: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[central] = sigma_ke; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_infusion_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ke, _sigma_ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |p, sigma| { + fetch_params!(p, _ke, sigma_ke, _v); + sigma[0] = sigma_ke; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, _sigma_ke, v); + y[0] = x[0] / v; + }, + 16, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ke", "sigma_ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("handwritten SDE metadata should validate") +} + +fn macro_absorption_sde() -> equation::SDE { + sde! { + name: "one_cmt_abs_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + }, + drift: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[gut] = 0.0 * sigma_ke; + sigma[central] = sigma_ke; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + init: |_p, _t, _cov, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_absorption_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, _rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, sigma| { + fetch_params!(p, _ka, _ke, sigma_ke, _v, _tlag, _f_oral); + sigma[0] = 0.0 * sigma_ke; + sigma[1] = sigma_ke; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + ) + .particles(8), + ) + .expect("handwritten absorption SDE metadata should validate") +} + +fn macro_shared_channel_sde() -> equation::SDE { + sde! { + name: "one_cmt_shared_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + drift: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[gut] = 0.0; + sigma[central] = 0.0; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + init: |_p, _t, _cov, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_shared_channel_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_shared_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten shared-channel SDE metadata should validate") +} + +#[test] +fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_infusion_sde(); + let handwritten_model = handwritten_infusion_sde(); + let subject = infusion_subject(0); + let support_point = [0.2, 0.0, 10.0]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro SDE model should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten SDE model should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { + let macro_model = macro_absorption_sde(); + let handwritten_model = handwritten_absorption_sde(); + let subject = oral_subject(0); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro absorption SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten absorption SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_channel_sde(); + let handwritten_model = handwritten_shared_channel_sde(); + let subject = shared_channel_subject(0); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(1)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro shared-channel SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-channel SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} \ No newline at end of file diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index fe363aa7..4c82be4f 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -12,7 +12,7 @@ pub const OBSERVATION_TIMES: [f64; 7] = [0.5, 1.0, 2.0, 3.0, 4.0, 6.0, 8.0]; pub const SUPPORT_POINT: [f64; 2] = [1.2, 50.0]; pub const AUTHORING_DSL: &str = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 0a9917a0..6a14ed33 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -19,7 +19,7 @@ use pharmsol::{equation, fa, fetch_cov, fetch_params, lag, Subject, SubjectBuild use tempfile::{tempdir, TempDir}; const ODE_SOURCE: &str = r#" -model = one_cmt_oral_iv +name = one_cmt_oral_iv kind = ode params = ka, cl, v, tlag, f_oral @@ -44,7 +44,7 @@ out(cp) = central / v ~ continuous() "#; const ANALYTICAL_SOURCE: &str = r#" -model = one_cmt_abs +name = one_cmt_abs kind = analytical params = ka, ke, v, tlag, f_oral @@ -56,13 +56,13 @@ bolus(oral) -> depot lag(oral) = tlag fa(oral) = f_oral -kernel = one_compartment_with_absorption +structure = one_compartment_with_absorption out(cp) = central / v ~ continuous() "#; const SDE_SOURCE: &str = r#" -model = vanco_sde +name = vanco_sde kind = sde params = ka, ke0, kcp, kpc, vol, ske From b0c967f197467f2f0c8a4fcc2a80f09db11bf5da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 08:18:19 +0100 Subject: [PATCH 02/22] chore: fmt --- examples/analytical_readme.rs | 2 +- examples/sde_readme.rs | 2 +- pharmsol-dsl/src/semantic.rs | 15 ++++++++------- src/dsl/jit.rs | 4 +--- src/dsl/model_info.rs | 5 ++++- src/dsl/native.rs | 21 +++++++++++++-------- tests/ode_macro_lowering.rs | 6 +++--- tests/sde_macro_lowering.rs | 8 ++++---- 8 files changed, 35 insertions(+), 28 deletions(-) diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs index fdcd5b1a..8e5b97f7 100644 --- a/examples/analytical_readme.rs +++ b/examples/analytical_readme.rs @@ -32,4 +32,4 @@ fn main() -> Result<(), pharmsol::PharmsolError> { println!("predictions => {:?}", predictions.flat_predictions()); Ok(()) -} \ No newline at end of file +} diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs index b3385bed..6106b17a 100644 --- a/examples/sde_readme.rs +++ b/examples/sde_readme.rs @@ -38,4 +38,4 @@ fn main() -> Result<(), pharmsol::PharmsolError> { println!("prediction grid shape => {:?}", predictions.dim()); Ok(()) -} \ No newline at end of file +} diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index b328934f..ac9223dd 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -345,12 +345,13 @@ impl<'a> Analyzer<'a> { }; let analytical = if let Some(block) = sections.analytical { - let structure = AnalyticalKernel::from_name(&block.structure.text).ok_or_else(|| { - SemanticError::new( - format!("unknown analytical structure `{}`", block.structure.text), - block.structure.span, - ) - })?; + let structure = + AnalyticalKernel::from_name(&block.structure.text).ok_or_else(|| { + SemanticError::new( + format!("unknown analytical structure `{}`", block.structure.text), + block.structure.span, + ) + })?; let state_components = states .iter() .map(|state| state.size.unwrap_or(1)) @@ -2652,8 +2653,8 @@ mod tests { use crate::test_fixtures::{ RECOMMENDED_STYLE_AUTHORING, RECOMMENDED_STYLE_CANONICAL, STRUCTURED_BLOCK_CORPUS, }; - use crate::{parse_model, parse_module}; use crate::RouteKind; + use crate::{parse_model, parse_module}; #[test] fn analyzes_structured_block_corpus() { diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index ac71f504..b0f1fe4a 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -731,9 +731,7 @@ fn lower_load( ExecutionLoad::Parameter(index) => load_fixed(builder, env.args.params, *index, ty), ExecutionLoad::Covariate(index) => load_fixed(builder, env.args.covariates, *index, ty), ExecutionLoad::Derived(index) => load_fixed(builder, env.args.derived, *index, ty), - ExecutionLoad::RouteInput { index, .. } => { - load_fixed(builder, env.args.routes, *index, ty) - } + ExecutionLoad::RouteInput { index, .. } => load_fixed(builder, env.args.routes, *index, ty), ExecutionLoad::Local(index) => { let binding = env.locals.get(index).ok_or_else(|| { JitCompileError::new(format!("unknown local slot {index}"), Some(span)) diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 6735ca30..0094059f 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -163,7 +163,10 @@ fn mark_route_inputs_in_expr( match &expr.kind { ExecutionExprKind::Literal(_) => {} ExecutionExprKind::Load(ExecutionLoad::RouteInput { route, .. }) => { - if let Some(slot) = declaration_slots.get(route).and_then(|index| usage.get_mut(*index)) { + if let Some(slot) = declaration_slots + .get(route) + .and_then(|index| usage.get_mut(*index)) + { *slot = true; } } diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 923b644d..5186fcf9 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -289,14 +289,16 @@ impl RouteInputSemantics { Some(RouteKind::Infusion) => { infusion_inputs[route.index] = true; if route.inject_input_to_destination { - injected_infusion_destinations[route.index] = Some(route.destination_offset); + injected_infusion_destinations[route.index] = + Some(route.destination_offset); } } None => { bolus_destinations[route.index] = Some(route.destination_offset); infusion_inputs[route.index] = true; if route.inject_input_to_destination { - injected_infusion_destinations[route.index] = Some(route.destination_offset); + injected_infusion_destinations[route.index] = + Some(route.destination_offset); } } } @@ -603,12 +605,15 @@ impl SharedNativeModel { amount: f64, ) -> Result<(), PharmsolError> { self.validate_input_for_kind(input, RouteKind::Bolus)?; - let destination = self.route_semantics.bolus_destination(input).ok_or_else(|| { - PharmsolError::OtherError(format!( - "model `{}` does not declare a bolus route for input channel {}", - self.info.name, input - )) - })?; + let destination = self + .route_semantics + .bolus_destination(input) + .ok_or_else(|| { + PharmsolError::OtherError(format!( + "model `{}` does not declare a bolus route for input channel {}", + self.info.name, input + )) + })?; state[destination] += amount; Ok(()) } diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 45636bb1..1cc4cb5c 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -23,15 +23,15 @@ fn subject_for_shared_channel(input: usize) -> Subject { .build() } - fn subject_for_covariates(input: usize) -> Subject { - Subject::builder("macro-covariates") +fn subject_for_covariates(input: usize) -> Subject { + Subject::builder("macro-covariates") .bolus(0.0, 100.0, input) .missing_observation(0.5, 0) .missing_observation(1.0, 0) .missing_observation(2.0, 0) .covariate("wt", 0.0, 70.0) .build() - } +} fn injected_macro_ode() -> equation::ODE { ode! { diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 05fd4c63..289ab127 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -20,8 +20,8 @@ fn oral_subject(input: usize) -> Subject { .build() } - fn shared_channel_subject(input: usize) -> Subject { - Subject::builder("sde-macro-shared") +fn shared_channel_subject(input: usize) -> Subject { + Subject::builder("sde-macro-shared") .bolus(0.0, 100.0, input) .infusion(6.0, 60.0, input, 2.0) .missing_observation(0.5, 0) @@ -31,7 +31,7 @@ fn oral_subject(input: usize) -> Subject { .missing_observation(7.0, 0) .missing_observation(8.0, 0) .build() - } +} fn prediction_means(predictions: &ndarray::Array2) -> Vec { predictions @@ -356,4 +356,4 @@ fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_prediction &prediction_means(¯o_predictions), &prediction_means(&handwritten_predictions), ); -} \ No newline at end of file +} From 6c6be15ce4d966842999a23373a24e9a83bb8193 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 08:26:52 +0100 Subject: [PATCH 03/22] fix test --- src/dsl/wasm.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index 56c6fcbd..f2504d44 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -895,7 +895,7 @@ mod tests { let model_info = loader_test_model_info("api_version_export_mismatch"); let metadata = serde_json::to_vec(&CompiledModelInfoEnvelope { abi_version: WASM_API_VERSION, - name: "model_info", + model: model_info, kernels: CompiledKernelAvailability { outputs: true, ..CompiledKernelAvailability::default() @@ -924,7 +924,7 @@ mod tests { let model_info = loader_test_model_info("metadata_api_version_mismatch"); let metadata = serde_json::to_vec(&CompiledModelInfoEnvelope { abi_version: WASM_API_VERSION + 1, - name: "model_info", + model: model_info, kernels: CompiledKernelAvailability { outputs: true, ..CompiledKernelAvailability::default() @@ -950,7 +950,7 @@ mod tests { let model_info = loader_test_model_info("kernel_metadata_mismatch"); let metadata = serde_json::to_vec(&CompiledModelInfoEnvelope { abi_version: WASM_API_VERSION, - name: "model_info", + model: model_info, kernels: CompiledKernelAvailability { outputs: true, ..CompiledKernelAvailability::default() From 75cba572cf7201e1ff7a826f1a461cf26353f8a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 09:36:02 +0100 Subject: [PATCH 04/22] fix test, add new test to check for full parity. checking also against DSL --- examples/covariates.rs | 9 +- examples/two_compartment.rs | 9 +- pharmsol-dsl/src/authoring.rs | 6 +- pharmsol-macros/src/lib.rs | 951 +++++++++++++++++++++++------ src/dsl/native.rs | 48 +- tests/analytical_macro_lowering.rs | 212 ++++++- tests/full_feature_macro_parity.rs | 473 ++++++++++++++ tests/ode_macro_lowering.rs | 21 +- tests/runtime_backend_matrix.rs | 79 +++ tests/sde_macro_lowering.rs | 195 +++++- tests/support/runtime_corpus.rs | 368 +++++++++++ 11 files changed, 2125 insertions(+), 246 deletions(-) create mode 100644 tests/full_feature_macro_parity.rs diff --git a/examples/covariates.rs b/examples/covariates.rs index 0e85b9bf..9aabf491 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -10,21 +10,18 @@ fn main() { routes: { bolus(oral) -> gut, }, - diffeq: |x, _p, t, dx, cov| { - // Macro to get the (possibly interpolated) covariate values at time `t` - fetch_cov!(cov, t, creatinine, age); - + diffeq: |x, _t, dx| { let scaled_ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - scaled_ke * x[central]; }, // This blocks defines the lag-time of the bolus dose - lag: |_p, _t, _cov| { + lag: |_t| { // Macro used to define the lag-time for the input of the bolus dose lag! { oral => tlag } }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { // Define the predicted concentration as the amount in the central compartment divided by volume y[cp] = x[central] / v; }, diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index 5634704c..64d554af 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -30,10 +30,7 @@ fn main() -> Result<(), pharmsol::PharmsolError> { routes: { infusion(iv) -> central, }, - diffeq: |x, _p, t, dx, cov| { - // Fetch the (possibly interpolated) weight covariate at time t - fetch_cov!(cov, t, wt); - + diffeq: |x, _t, dx| { // CL: Clearance (L/hr), V: Central volume (L) // Vp: Peripheral volume (L), Q: Inter-compartmental clearance (L/hr) // Weight-based allometric scaling @@ -58,9 +55,7 @@ fn main() -> Result<(), pharmsol::PharmsolError> { dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; }, // Output equation block - calculates observed concentration - out: |x, _p, t, cov, y| { - fetch_cov!(cov, t, wt); - + out: |x, _t, y| { // Calculate scaled volume for concentration let wt_ratio = wt / 85.0; let v_scaled = v * wt_ratio; diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 09bb309a..c81c19eb 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -598,12 +598,10 @@ impl<'a> AuthoringParser<'a> { .unwrap_or(module_span); if matches!(kind, ModelKind::Analytical) - && (!self.diffusion_statements.is_empty() - || self.particles.is_some() - || !self.init_statements.is_empty()) + && (!self.diffusion_statements.is_empty() || self.particles.is_some()) { return Err(ParseError::new( - "analytical authoring models cannot declare particles, init, or noise equations", + "analytical authoring models cannot declare particles or noise equations", kind_span, )); } diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 7607871a..54a79fe3 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -37,10 +37,12 @@ struct OdeInput { struct AnalyticalInput { name: LitStr, params: Vec, + covariates: Vec, states: Vec, outputs: Vec, routes: Vec, structure: Ident, + sec: Option, lag: Option, fa: Option, init: Option, @@ -50,6 +52,7 @@ struct AnalyticalInput { struct SdeInput { name: LitStr, params: Vec, + covariates: Vec, states: Vec, outputs: Vec, routes: Vec, @@ -198,7 +201,7 @@ impl Parse for OdeInput { let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?; let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?; let out = out.ok_or_else(|| missing_required_ode_field("out"))?; - let diffeq_mode = classify_diffeq_mode(&diffeq)?; + let diffeq_mode = classify_diffeq_mode(&diffeq, &routes)?; validate_unique_idents("parameter", ¶ms, "ode!")?; validate_unique_idents("covariate", &covariates, "ode!")?; @@ -206,13 +209,23 @@ impl Parse for OdeInput { validate_unique_idents("output", &outputs, "ode!")?; validate_routes(&routes, &states, "ode!")?; validate_named_binding_compatibility( - ¶ms, - &states, - &outputs, - &routes, - &diffeq, - &out, - diffeq_mode, + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &outputs, + routes: &routes, + }, + OdeBindingClosures { + diffeq: &diffeq, + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + diffeq_mode, + }, )?; Ok(Self { @@ -245,10 +258,12 @@ impl Parse for AnalyticalInput { fn parse(input: ParseStream) -> syn::Result { let mut name = None; let mut params = None; + let mut covariates = None; let mut states = None; let mut outputs = None; let mut routes = None; let mut structure = None; + let mut sec = None; let mut lag = None; let mut fa = None; let mut init = None; @@ -263,6 +278,12 @@ impl Parse for AnalyticalInput { "params" => { set_once_analytical(&mut params, parse_ident_list(input)?, &key, "params")? } + "covariates" => set_once_analytical( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, "states" => { set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")? } @@ -275,6 +296,7 @@ impl Parse for AnalyticalInput { "structure" => { set_once_analytical(&mut structure, input.parse()?, &key, "structure")? } + "sec" => set_once_analytical(&mut sec, input.parse()?, &key, "sec")?, "lag" => set_once_analytical(&mut lag, input.parse()?, &key, "lag")?, "fa" => set_once_analytical(&mut fa, input.parse()?, &key, "fa")?, "init" => set_once_analytical(&mut init, input.parse()?, &key, "init")?, @@ -283,7 +305,7 @@ impl Parse for AnalyticalInput { return Err(syn::Error::new_spanned( &key, format!( - "unknown field `{other}`, expected one of: name, params, states, outputs, routes, structure, lag, fa, init, out" + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, structure, sec, lag, fa, init, out" ), )); } @@ -296,6 +318,7 @@ impl Parse for AnalyticalInput { let name = name.ok_or_else(|| missing_required_analytical_field("name"))?; let params = params.ok_or_else(|| missing_required_analytical_field("params"))?; + let covariates = covariates.unwrap_or_default(); let states = states.ok_or_else(|| missing_required_analytical_field("states"))?; let outputs = outputs.ok_or_else(|| missing_required_analytical_field("outputs"))?; let routes = routes.ok_or_else(|| missing_required_analytical_field("routes"))?; @@ -303,6 +326,7 @@ impl Parse for AnalyticalInput { let out = out.ok_or_else(|| missing_required_analytical_field("out"))?; validate_unique_idents("parameter", ¶ms, "analytical!")?; + validate_unique_idents("covariate", &covariates, "analytical!")?; validate_unique_idents("state", &states, "analytical!")?; validate_unique_idents("output", &outputs, "analytical!")?; validate_routes(&routes, &states, "analytical!")?; @@ -330,14 +354,22 @@ impl Parse for AnalyticalInput { } validate_analytical_named_binding_compatibility( - ¶ms, - &states, - &outputs, - &routes, - lag.as_ref(), - fa.as_ref(), - init.as_ref(), - &out, + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &outputs, + routes: &routes, + }, + AnalyticalBindingClosures { + sec: sec.as_ref(), + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + }, )?; if let Some(lag) = lag.as_ref() { @@ -355,10 +387,12 @@ impl Parse for AnalyticalInput { Ok(Self { name, params, + covariates, states, outputs, routes, structure, + sec, lag, fa, init, @@ -371,6 +405,7 @@ impl Parse for SdeInput { fn parse(input: ParseStream) -> syn::Result { let mut name = None; let mut params = None; + let mut covariates = None; let mut states = None; let mut outputs = None; let mut routes = None; @@ -389,6 +424,12 @@ impl Parse for SdeInput { match key.to_string().as_str() { "name" => set_once_sde(&mut name, input.parse()?, &key, "name")?, "params" => set_once_sde(&mut params, parse_ident_list(input)?, &key, "params")?, + "covariates" => set_once_sde( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?, "outputs" => set_once_sde(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?, @@ -403,7 +444,7 @@ impl Parse for SdeInput { return Err(syn::Error::new_spanned( &key, format!( - "unknown field `{other}`, expected one of: name, params, states, outputs, routes, particles, drift, diffusion, lag, fa, init, out" + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, particles, drift, diffusion, lag, fa, init, out" ), )); } @@ -416,6 +457,7 @@ impl Parse for SdeInput { let name = name.ok_or_else(|| missing_required_sde_field("name"))?; let params = params.ok_or_else(|| missing_required_sde_field("params"))?; + let covariates = covariates.unwrap_or_default(); let states = states.ok_or_else(|| missing_required_sde_field("states"))?; let outputs = outputs.ok_or_else(|| missing_required_sde_field("outputs"))?; let routes = routes.ok_or_else(|| missing_required_sde_field("routes"))?; @@ -425,20 +467,28 @@ impl Parse for SdeInput { let out = out.ok_or_else(|| missing_required_sde_field("out"))?; validate_unique_idents("parameter", ¶ms, "sde!")?; + validate_unique_idents("covariate", &covariates, "sde!")?; validate_unique_idents("state", &states, "sde!")?; validate_unique_idents("output", &outputs, "sde!")?; validate_routes(&routes, &states, "sde!")?; validate_sde_named_binding_compatibility( - ¶ms, - &states, - &outputs, - &routes, - &drift, - &diffusion, - lag.as_ref(), - fa.as_ref(), - init.as_ref(), - &out, + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &outputs, + routes: &routes, + }, + SdeBindingClosures { + drift: &drift, + diffusion: &diffusion, + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + }, )?; if let Some(lag) = lag.as_ref() { @@ -456,6 +506,7 @@ impl Parse for SdeInput { Ok(Self { name, params, + covariates, states, outputs, routes, @@ -579,6 +630,8 @@ fn generated_ident(name: &str) -> Ident { #[derive(Default)] struct ClosureBodyUsage { idents: HashSet, + indexed_idents: HashSet, + assigned_indexed_idents: HashSet, contains_macro: bool, } @@ -592,6 +645,18 @@ impl ClosureBodyUsage { fn uses(&self, ident: &Ident) -> bool { self.contains_macro || self.idents.contains(&ident.to_string()) } + + fn mentions(&self, ident: &Ident) -> bool { + self.idents.contains(&ident.to_string()) + } + + fn indexes(&self, ident: &Ident) -> bool { + self.indexed_idents.contains(&ident.to_string()) + } + + fn assigns_index(&self, ident: &Ident) -> bool { + self.assigned_indexed_idents.contains(&ident.to_string()) + } } impl<'ast> Visit<'ast> for ClosureBodyUsage { @@ -616,6 +681,36 @@ impl<'ast> Visit<'ast> for ClosureBodyUsage { self.contains_macro = true; syn::visit::visit_stmt_macro(self, stmt_macro); } + + fn visit_expr_index(&mut self, expr_index: &'ast syn::ExprIndex) { + if let Expr::Path(expr_path) = expr_index.expr.as_ref() { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 + { + self.indexed_idents + .insert(expr_path.path.segments[0].ident.to_string()); + } + } + + syn::visit::visit_expr_index(self, expr_index); + } + + fn visit_expr_assign(&mut self, expr_assign: &'ast syn::ExprAssign) { + if let Expr::Index(expr_index) = expr_assign.left.as_ref() { + if let Expr::Path(expr_path) = expr_index.expr.as_ref() { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 + { + self.assigned_indexed_idents + .insert(expr_path.path.segments[0].ident.to_string()); + } + } + } + + syn::visit::visit_expr_assign(self, expr_assign); + } } fn generate_closure_input_aliases( @@ -645,6 +740,20 @@ fn generate_closure_input_aliases( }) } +fn generate_supported_input_aliases( + closure: &ExprClosure, + supported_internal_names: &[&[Ident]], + error_message: &str, +) -> syn::Result { + for internal_names in supported_internal_names { + if closure.inputs.len() == internal_names.len() { + return generate_closure_input_aliases(closure, internal_names); + } + } + + Err(syn::Error::new_spanned(closure, error_message)) +} + fn generate_parameter_bindings( params: &[Ident], closure: &ExprClosure, @@ -667,13 +776,81 @@ fn generate_parameter_bindings( } } -fn classify_diffeq_mode(diffeq: &ExprClosure) -> syn::Result { +fn generate_mutable_parameter_bindings( + params: &[Ident], + closure: &ExprClosure, + parameter_vector: &Ident, +) -> (TokenStream2, TokenStream2) { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let used_params = params + .iter() + .enumerate() + .filter(|(_, ident)| usage.uses(ident)) + .collect::>(); + + let bindings = used_params.iter().map(|(index, ident)| { + quote! { + #[allow(unused_mut, unused_variables)] + let mut #ident = #parameter_vector[#index]; + } + }); + let writebacks = used_params.iter().map(|(index, ident)| { + quote! { + #parameter_vector[#index] = #ident; + } + }); + + (quote! { #(#bindings)* }, quote! { #(#writebacks)* }) +} + +fn generate_covariate_bindings( + covariates: &[Ident], + closure: &ExprClosure, + covariate_map: &Ident, + time: &Ident, +) -> TokenStream2 { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let used_covariates = covariates + .iter() + .filter(|ident| usage.uses(ident)) + .collect::>(); + + if used_covariates.is_empty() { + quote! {} + } else { + quote! { + ::pharmsol::fetch_cov!(#covariate_map, #time, #(#used_covariates),*); + } + } +} + +fn classify_diffeq_mode( + diffeq: &ExprClosure, + routes: &[OdeRouteDecl], +) -> syn::Result { match closure_param_names(diffeq).len() { - 5 => Ok(OdeDiffeqMode::InjectedRouteInputs), + 3 => Ok(OdeDiffeqMode::InjectedRouteInputs), 7 => Ok(OdeDiffeqMode::ExplicitRouteVectors), + 5 => { + let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref()); + let route_inputs = route_input_idents(routes); + let fourth_param = closure_param_ident(diffeq, 3); + let fifth_param = closure_param_ident(diffeq, 4); + let mentions_route_inputs = route_inputs.iter().any(|route| usage.mentions(route)); + let indexes_fifth_param = fifth_param.as_ref().is_some_and(|ident| usage.indexes(ident)); + let reads_fourth_param_as_input = fourth_param + .as_ref() + .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident)); + + if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input { + Ok(OdeDiffeqMode::ExplicitRouteVectors) + } else { + Ok(OdeDiffeqMode::InjectedRouteInputs) + } + } _ => Err(syn::Error::new_spanned( diffeq, - "declaration-first `ode!` requires `diffeq` to have either 5 parameters: |x, p, t, dx, cov| or 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", + "declaration-first `ode!` requires `diffeq` to have either 3 parameters: |x, t, dx|, 5 parameters: |x, p, t, dx, cov| or |x, t, dx, bolus, rateiv|, or 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", )), } } @@ -764,17 +941,68 @@ fn validate_closure_param_conflicts( Ok(()) } -fn validate_named_binding_compatibility( - params: &[Ident], - states: &[Ident], - outputs: &[Ident], - routes: &[OdeRouteDecl], - diffeq: &ExprClosure, - out: &ExprClosure, +#[derive(Clone, Copy)] +struct NamedBindingSets<'a> { + params: &'a [Ident], + covariates: &'a [Ident], + states: &'a [Ident], + outputs: &'a [Ident], + routes: &'a [OdeRouteDecl], +} + +#[derive(Clone, Copy)] +struct CommonBindingClosures<'a> { + lag: Option<&'a ExprClosure>, + fa: Option<&'a ExprClosure>, + init: Option<&'a ExprClosure>, + out: &'a ExprClosure, +} + +#[derive(Clone, Copy)] +struct AnalyticalBindingClosures<'a> { + sec: Option<&'a ExprClosure>, + common: CommonBindingClosures<'a>, +} + +#[derive(Clone, Copy)] +struct OdeBindingClosures<'a> { + diffeq: &'a ExprClosure, + common: CommonBindingClosures<'a>, diffeq_mode: OdeDiffeqMode, +} + +#[derive(Clone, Copy)] +struct SdeBindingClosures<'a> { + drift: &'a ExprClosure, + diffusion: &'a ExprClosure, + common: CommonBindingClosures<'a>, +} + +fn validate_named_binding_compatibility( + bindings: NamedBindingSets<'_>, + closures: OdeBindingClosures<'_>, ) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let OdeBindingClosures { + diffeq, + common: CommonBindingClosures { lag, fa, init, out }, + diffeq_mode, + } = closures; let route_inputs = route_input_idents(routes); + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "declaration-first `ode!` named binding generation", + )?; validate_binding_conflicts( "parameter", params, @@ -796,12 +1024,24 @@ fn validate_named_binding_compatibility( outputs, "`out` named binding generation", )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "declaration-first `ode!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "declaration-first `ode!` named binding generation", + )?; validate_closure_param_conflicts("diffeq", diffeq, params, "parameter")?; + validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?; validate_closure_param_conflicts("diffeq", diffeq, states, "state")?; - validate_closure_param_conflicts("out", out, params, "parameter")?; - validate_closure_param_conflicts("out", out, states, "state")?; - validate_closure_param_conflicts("out", out, outputs, "output")?; if diffeq_mode == OdeDiffeqMode::ExplicitRouteVectors { validate_binding_conflicts( @@ -818,24 +1058,80 @@ fn validate_named_binding_compatibility( &route_inputs, "`diffeq` named binding generation", )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`diffeq` named binding generation", + )?; validate_closure_param_conflicts("diffeq", diffeq, &route_inputs, "route")?; } + if let Some(lag) = lag { + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`lag` named binding generation", + )?; + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`fa` named binding generation", + )?; + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + Ok(()) } fn validate_analytical_named_binding_compatibility( - params: &[Ident], - states: &[Ident], - outputs: &[Ident], - routes: &[OdeRouteDecl], - lag: Option<&ExprClosure>, - fa: Option<&ExprClosure>, - init: Option<&ExprClosure>, - out: &ExprClosure, + bindings: NamedBindingSets<'_>, + closures: AnalyticalBindingClosures<'_>, ) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let AnalyticalBindingClosures { + sec, + common: CommonBindingClosures { lag, fa, init, out }, + } = closures; let route_inputs = route_input_idents(routes); + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "`analytical!` named binding generation", + )?; validate_binding_conflicts( "parameter", params, @@ -850,6 +1146,27 @@ fn validate_analytical_named_binding_compatibility( outputs, "`analytical!` named binding generation", )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; validate_binding_conflicts( "parameter", params, @@ -879,22 +1196,31 @@ fn validate_analytical_named_binding_compatibility( "`analytical!` named binding generation", )?; + if let Some(sec) = sec { + validate_closure_param_conflicts("sec", sec, params, "parameter")?; + validate_closure_param_conflicts("sec", sec, covariates, "covariate")?; + } + if let Some(lag) = lag { validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; } if let Some(fa) = fa { validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; } if let Some(init) = init { validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; validate_closure_param_conflicts("init", init, states, "state")?; } validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; validate_closure_param_conflicts("out", out, states, "state")?; validate_closure_param_conflicts("out", out, outputs, "output")?; @@ -902,19 +1228,30 @@ fn validate_analytical_named_binding_compatibility( } fn validate_sde_named_binding_compatibility( - params: &[Ident], - states: &[Ident], - outputs: &[Ident], - routes: &[OdeRouteDecl], - drift: &ExprClosure, - diffusion: &ExprClosure, - lag: Option<&ExprClosure>, - fa: Option<&ExprClosure>, - init: Option<&ExprClosure>, - out: &ExprClosure, + bindings: NamedBindingSets<'_>, + closures: SdeBindingClosures<'_>, ) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let SdeBindingClosures { + drift, + diffusion, + common: CommonBindingClosures { lag, fa, init, out }, + } = closures; let route_inputs = route_input_idents(routes); + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "`sde!` named binding generation", + )?; validate_binding_conflicts( "parameter", params, @@ -929,6 +1266,27 @@ fn validate_sde_named_binding_compatibility( outputs, "`sde!` named binding generation", )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; validate_binding_conflicts( "parameter", params, @@ -959,26 +1317,31 @@ fn validate_sde_named_binding_compatibility( )?; validate_closure_param_conflicts("drift", drift, params, "parameter")?; + validate_closure_param_conflicts("drift", drift, covariates, "covariate")?; validate_closure_param_conflicts("drift", drift, states, "state")?; validate_closure_param_conflicts("diffusion", diffusion, params, "parameter")?; validate_closure_param_conflicts("diffusion", diffusion, states, "state")?; if let Some(lag) = lag { validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; } if let Some(fa) = fa { validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; } if let Some(init) = init { validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; validate_closure_param_conflicts("init", init, states, "state")?; } validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; validate_closure_param_conflicts("out", out, states, "state")?; validate_closure_param_conflicts("out", out, outputs, "output")?; @@ -1014,16 +1377,10 @@ fn generate_mapped_index_consts(bindings: &[(Ident, usize)]) -> TokenStream2 { fn expand_out( out: &ExprClosure, params: &[Ident], + covariates: &[Ident], states: &[Ident], outputs: &[Ident], ) -> syn::Result { - if closure_param_names(out).len() != 5 { - return Err(syn::Error::new_spanned( - out, - "declaration-first `ode!` requires `out` to have 5 parameters: |x, p, t, cov, y|", - )); - } - let state_consts = generate_index_consts(states); let output_consts = generate_index_consts(outputs); let x = generated_ident("__pharmsol_x"); @@ -1031,11 +1388,15 @@ fn expand_out( let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); let y = generated_ident("__pharmsol_y"); - let input_aliases = generate_closure_input_aliases( + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( out, - &[x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()], + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); let body = &out.body; Ok(quote! {{ @@ -1054,6 +1415,7 @@ fn expand_out( #state_consts #output_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_out @@ -1173,24 +1535,24 @@ fn expand_ode_route_map( label: &str, closure: &ExprClosure, params: &[Ident], + covariates: &[Ident], route_bindings: &[(Ident, usize)], ) -> syn::Result { - if closure_param_names(closure).len() != 3 { - return Err(syn::Error::new_spanned( - closure, - format!( - "declaration-first `ode!` requires `{label}` to have 3 parameters: |p, t, cov|" - ), - )); - } - let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); - let input_aliases = - generate_closure_input_aliases(closure, &[p.clone(), t.clone(), cov.clone()])?; + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "declaration-first `ode!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); let body = &closure.body; Ok(quote! {{ @@ -1204,12 +1566,55 @@ fn expand_ode_route_map( #input_aliases #route_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_route_map }}) } +fn expand_ode_init( + init: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_init + }}) +} + fn expand_route_metadata( routes: &[OdeRouteDecl], diffeq_mode: OdeDiffeqMode, @@ -1461,6 +1866,7 @@ fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) fn expand_diffeq( diffeq: &ExprClosure, params: &[Ident], + covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], route_bindings: &[(Ident, usize)], @@ -1478,19 +1884,29 @@ fn expand_diffeq( let bolus = generated_ident("__pharmsol_bolus"); let rateiv = generated_ident("__pharmsol_rateiv"); let cov = generated_ident("__pharmsol_cov"); - let input_aliases = generate_closure_input_aliases( + let full_inputs = [ + x.clone(), + p.clone(), + t.clone(), + dx.clone(), + bolus.clone(), + rateiv.clone(), + cov.clone(), + ]; + let reduced_inputs = [ + x.clone(), + t.clone(), + dx.clone(), + bolus.clone(), + rateiv.clone(), + ]; + let input_aliases = generate_supported_input_aliases( diffeq, - &[ - x.clone(), - p.clone(), - t.clone(), - dx.clone(), - bolus.clone(), - rateiv.clone(), - cov.clone(), - ], + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` explicit-route `diffeq` requires either 7 parameters: |x, p, t, dx, bolus, rateiv, cov| or 5 parameters: |x, t, dx, bolus, rateiv|", )?; let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); let body = &diffeq.body; Ok(quote! {{ @@ -1513,6 +1929,7 @@ fn expand_diffeq( #state_consts #route_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_diffeq @@ -1526,13 +1943,21 @@ fn expand_diffeq( let bolus = generated_ident("__pharmsol_bolus"); let rateiv = generated_ident("__pharmsol_rateiv"); let cov = generated_ident("__pharmsol_cov"); - let input_aliases = generate_closure_input_aliases( + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( diffeq, - &[x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()], + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", )?; let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); let body = &diffeq.body; - let dx_binding = closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()); + let dx_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) + }; let route_terms = expand_injected_ode_route_terms( routes, states, @@ -1561,6 +1986,7 @@ fn expand_diffeq( #input_aliases #state_consts #parameter_bindings + #covariate_bindings #body #route_terms }; @@ -1667,22 +2093,24 @@ fn expand_analytical_route_map( label: &str, closure: &ExprClosure, params: &[Ident], + covariates: &[Ident], route_bindings: &[(Ident, usize)], ) -> syn::Result { - if closure_param_names(closure).len() != 3 { - return Err(syn::Error::new_spanned( - closure, - format!("built-in `analytical!` requires `{label}` to have 3 parameters: |p, t, cov|"), - )); - } - let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); - let input_aliases = - generate_closure_input_aliases(closure, &[p.clone(), t.clone(), cov.clone()])?; + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "built-in `analytical!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); let body = &closure.body; Ok(quote! {{ @@ -1696,32 +2124,76 @@ fn expand_analytical_route_map( #input_aliases #route_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_route_map }}) } +fn expand_analytical_sec( + sec: &ExprClosure, + params: &[Ident], + covariates: &[Ident], +) -> syn::Result { + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + sec, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `sec` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|", + )?; + let parameter_vector = if sec.inputs.len() == full_inputs.len() { + closure_param_ident(sec, 0).unwrap_or_else(|| p.clone()) + } else { + p.clone() + }; + let (parameter_bindings, parameter_writebacks) = + generate_mutable_parameter_bindings(params, sec, ¶meter_vector); + let covariate_bindings = generate_covariate_bindings(covariates, sec, &cov, &t); + let body = &sec.body; + + Ok(quote! {{ + let __pharmsol_sec: fn( + &mut ::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) = |#p: &mut ::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #parameter_bindings + #covariate_bindings + #body + #parameter_writebacks + }; + __pharmsol_sec + }}) +} + fn expand_analytical_init( init: &ExprClosure, params: &[Ident], + covariates: &[Ident], states: &[Ident], ) -> syn::Result { - if closure_param_names(init).len() != 4 { - return Err(syn::Error::new_spanned( - init, - "built-in `analytical!` requires `init` to have 4 parameters: |p, t, cov, x|", - )); - } - let state_consts = generate_index_consts(states); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); let x = generated_ident("__pharmsol_x"); - let input_aliases = - generate_closure_input_aliases(init, &[p.clone(), t.clone(), cov.clone(), x.clone()])?; + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); let body = &init.body; Ok(quote! {{ @@ -1737,6 +2209,7 @@ fn expand_analytical_init( #input_aliases #state_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_init @@ -1746,16 +2219,10 @@ fn expand_analytical_init( fn expand_analytical_out( out: &ExprClosure, params: &[Ident], + covariates: &[Ident], states: &[Ident], outputs: &[Ident], ) -> syn::Result { - if closure_param_names(out).len() != 5 { - return Err(syn::Error::new_spanned( - out, - "built-in `analytical!` requires `out` to have 5 parameters: |x, p, t, cov, y|", - )); - } - let state_consts = generate_index_consts(states); let output_consts = generate_index_consts(outputs); let x = generated_ident("__pharmsol_x"); @@ -1763,11 +2230,15 @@ fn expand_analytical_out( let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); let y = generated_ident("__pharmsol_y"); - let input_aliases = generate_closure_input_aliases( + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( out, - &[x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()], + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); let body = &out.body; Ok(quote! {{ @@ -1786,6 +2257,7 @@ fn expand_analytical_out( #state_consts #output_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_out @@ -1795,17 +2267,11 @@ fn expand_analytical_out( fn expand_sde_drift( drift: &ExprClosure, params: &[Ident], + covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], route_bindings: &[(Ident, usize)], ) -> syn::Result { - if closure_param_names(drift).len() != 5 { - return Err(syn::Error::new_spanned( - drift, - "declaration-first `sde!` requires `drift` to have 5 parameters: |x, p, t, dx, cov|", - )); - } - let state_consts = generate_index_consts(states); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); @@ -1813,13 +2279,21 @@ fn expand_sde_drift( let dx = generated_ident("__pharmsol_dx"); let rateiv = generated_ident("__pharmsol_rateiv"); let cov = generated_ident("__pharmsol_cov"); - let input_aliases = generate_closure_input_aliases( + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( drift, - &[x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()], + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `drift` to have either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", )?; let parameter_bindings = generate_parameter_bindings(params, drift, &p); + let covariate_bindings = generate_covariate_bindings(covariates, drift, &cov, &t); let body = &drift.body; - let dx_binding = closure_param_ident(drift, 3).unwrap_or_else(|| dx.clone()); + let dx_binding = if drift.inputs.len() == full_inputs.len() { + closure_param_ident(drift, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(drift, 2).unwrap_or_else(|| dx.clone()) + }; let rate_terms = expand_injected_sde_rate_terms(routes, states, route_bindings, &dx_binding, &rateiv); @@ -1840,6 +2314,7 @@ fn expand_sde_drift( #input_aliases #state_consts #parameter_bindings + #covariate_bindings #body #rate_terms }; @@ -1852,17 +2327,16 @@ fn expand_sde_diffusion( params: &[Ident], states: &[Ident], ) -> syn::Result { - if closure_param_names(diffusion).len() != 2 { - return Err(syn::Error::new_spanned( - diffusion, - "declaration-first `sde!` requires `diffusion` to have 2 parameters: |p, sigma|", - )); - } - let state_consts = generate_index_consts(states); let p = generated_ident("__pharmsol_p"); let sigma = generated_ident("__pharmsol_sigma"); - let input_aliases = generate_closure_input_aliases(diffusion, &[p.clone(), sigma.clone()])?; + let full_inputs = [p.clone(), sigma.clone()]; + let reduced_inputs = [sigma.clone()]; + let input_aliases = generate_supported_input_aliases( + diffusion, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `diffusion` to have either 2 parameters: |p, sigma| or 1 parameter: |sigma|", + )?; let parameter_bindings = generate_parameter_bindings(params, diffusion, &p); let body = &diffusion.body; @@ -1885,24 +2359,24 @@ fn expand_sde_route_map( label: &str, closure: &ExprClosure, params: &[Ident], + covariates: &[Ident], route_bindings: &[(Ident, usize)], ) -> syn::Result { - if closure_param_names(closure).len() != 3 { - return Err(syn::Error::new_spanned( - closure, - format!( - "declaration-first `sde!` requires `{label}` to have 3 parameters: |p, t, cov|" - ), - )); - } - let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); - let input_aliases = - generate_closure_input_aliases(closure, &[p.clone(), t.clone(), cov.clone()])?; + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "declaration-first `sde!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); let body = &closure.body; Ok(quote! {{ @@ -1916,6 +2390,7 @@ fn expand_sde_route_map( #input_aliases #route_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_route_map @@ -1925,23 +2400,23 @@ fn expand_sde_route_map( fn expand_sde_init( init: &ExprClosure, params: &[Ident], + covariates: &[Ident], states: &[Ident], ) -> syn::Result { - if closure_param_names(init).len() != 4 { - return Err(syn::Error::new_spanned( - init, - "declaration-first `sde!` requires `init` to have 4 parameters: |p, t, cov, x|", - )); - } - let state_consts = generate_index_consts(states); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); let x = generated_ident("__pharmsol_x"); - let input_aliases = - generate_closure_input_aliases(init, &[p.clone(), t.clone(), cov.clone(), x.clone()])?; + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); let body = &init.body; Ok(quote! {{ @@ -1957,6 +2432,7 @@ fn expand_sde_init( #input_aliases #state_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_init @@ -1966,16 +2442,10 @@ fn expand_sde_init( fn expand_sde_out( out: &ExprClosure, params: &[Ident], + covariates: &[Ident], states: &[Ident], outputs: &[Ident], ) -> syn::Result { - if closure_param_names(out).len() != 5 { - return Err(syn::Error::new_spanned( - out, - "declaration-first `sde!` requires `out` to have 5 parameters: |x, p, t, cov, y|", - )); - } - let state_consts = generate_index_consts(states); let output_consts = generate_index_consts(outputs); let x = generated_ident("__pharmsol_x"); @@ -1983,11 +2453,15 @@ fn expand_sde_out( let t = generated_ident("__pharmsol_t"); let cov = generated_ident("__pharmsol_cov"); let y = generated_ident("__pharmsol_y"); - let input_aliases = generate_closure_input_aliases( + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( out, - &[x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()], + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); let body = &out.body; Ok(quote! {{ @@ -2006,6 +2480,7 @@ fn expand_sde_out( #state_consts #output_consts #parameter_bindings + #covariate_bindings #body }; __pharmsol_out @@ -2071,6 +2546,7 @@ pub fn ode(input: TokenStream) -> TokenStream { let diffeq = match expand_diffeq( &input.diffeq, &input.params, + &input.covariates, &input.states, &input.routes, &route_bindings, @@ -2080,7 +2556,13 @@ pub fn ode(input: TokenStream) -> TokenStream { Err(error) => return error.to_compile_error().into(), }; - let out = match expand_out(&input.out, &input.params, &input.states, &input.outputs) { + let out = match expand_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { Ok(out) => out, Err(error) => return error.to_compile_error().into(), }; @@ -2104,8 +2586,13 @@ pub fn ode(input: TokenStream) -> TokenStream { }; let lag = match input.lag.as_ref() { - Some(closure) => match expand_ode_route_map("lag", closure, &input.params, &route_bindings) - { + Some(closure) => match expand_ode_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { Ok(lag) => lag, Err(error) => return error.to_compile_error().into(), }, @@ -2114,7 +2601,13 @@ pub fn ode(input: TokenStream) -> TokenStream { let fa = match input.fa.as_ref() { Some(closure) => { - match expand_ode_route_map("fa", closure, &input.params, &route_bindings) { + match expand_ode_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { Ok(fa) => fa, Err(error) => return error.to_compile_error().into(), } @@ -2122,10 +2615,15 @@ pub fn ode(input: TokenStream) -> TokenStream { None => quote! { |_, _, _| ::std::collections::HashMap::new() }, }; - let init = input - .init - .as_ref() - .map_or_else(|| quote! { |_, _, _, _| {} }, |closure| quote! { #closure }); + let init = match input.init.as_ref() { + Some(closure) => { + match expand_ode_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _, _| {} }, + }; quote! {{ let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) @@ -2207,15 +2705,34 @@ pub fn analytical(input: TokenStream) -> TokenStream { None => HashSet::new(), }; - let out = match expand_analytical_out(&input.out, &input.params, &input.states, &input.outputs) - { + let sec = match input.sec.as_ref() { + Some(closure) => match expand_analytical_sec(closure, &input.params, &input.covariates) { + Ok(sec) => sec, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| {} }, + }; + + let out = match expand_analytical_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { Ok(out) => out, Err(error) => return error.to_compile_error().into(), }; let lag = match input.lag.as_ref() { Some(closure) => { - match expand_analytical_route_map("lag", closure, &input.params, &route_bindings) { + match expand_analytical_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { Ok(lag) => lag, Err(error) => return error.to_compile_error().into(), } @@ -2225,7 +2742,13 @@ pub fn analytical(input: TokenStream) -> TokenStream { let fa = match input.fa.as_ref() { Some(closure) => { - match expand_analytical_route_map("fa", closure, &input.params, &route_bindings) { + match expand_analytical_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { Ok(fa) => fa, Err(error) => return error.to_compile_error().into(), } @@ -2234,10 +2757,12 @@ pub fn analytical(input: TokenStream) -> TokenStream { }; let init = match input.init.as_ref() { - Some(closure) => match expand_analytical_init(closure, &input.params, &input.states) { - Ok(init) => init, - Err(error) => return error.to_compile_error().into(), - }, + Some(closure) => { + match expand_analytical_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } None => quote! { |_, _, _, _| {} }, }; @@ -2247,16 +2772,25 @@ pub fn analytical(input: TokenStream) -> TokenStream { let name = &input.name; let params = &input.params; + let covariates = &input.covariates; let states = &input.states; let outputs = &input.outputs; let routes = expand_analytical_route_metadata(&input.routes, &lag_routes, &fa_routes); let runtime_path = kernel_spec.runtime_path; let metadata_kernel = kernel_spec.metadata_kernel; + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; quote! {{ let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) .kind(::pharmsol::equation::ModelKind::Analytical) .parameters([#(stringify!(#params)),*]) + #covariate_metadata .states([#(stringify!(#states)),*]) .outputs([#(stringify!(#outputs)),*]) #(.route(#routes))* @@ -2264,7 +2798,7 @@ pub fn analytical(input: TokenStream) -> TokenStream { ::pharmsol::equation::Analytical::new( #runtime_path, - |_, _, _| {}, + #sec, #lag, #fa, #init, @@ -2333,6 +2867,7 @@ pub fn sde(input: TokenStream) -> TokenStream { let drift = match expand_sde_drift( &input.drift, &input.params, + &input.covariates, &input.states, &input.routes, &route_bindings, @@ -2347,8 +2882,13 @@ pub fn sde(input: TokenStream) -> TokenStream { }; let lag = match input.lag.as_ref() { - Some(closure) => match expand_sde_route_map("lag", closure, &input.params, &route_bindings) - { + Some(closure) => match expand_sde_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { Ok(lag) => lag, Err(error) => return error.to_compile_error().into(), }, @@ -2357,7 +2897,13 @@ pub fn sde(input: TokenStream) -> TokenStream { let fa = match input.fa.as_ref() { Some(closure) => { - match expand_sde_route_map("fa", closure, &input.params, &route_bindings) { + match expand_sde_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { Ok(fa) => fa, Err(error) => return error.to_compile_error().into(), } @@ -2366,14 +2912,22 @@ pub fn sde(input: TokenStream) -> TokenStream { }; let init = match input.init.as_ref() { - Some(closure) => match expand_sde_init(closure, &input.params, &input.states) { - Ok(init) => init, - Err(error) => return error.to_compile_error().into(), - }, + Some(closure) => { + match expand_sde_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } None => quote! { |_, _, _, _| {} }, }; - let out = match expand_sde_out(&input.out, &input.params, &input.states, &input.outputs) { + let out = match expand_sde_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { Ok(out) => out, Err(error) => return error.to_compile_error().into(), }; @@ -2384,18 +2938,27 @@ pub fn sde(input: TokenStream) -> TokenStream { let name = &input.name; let params = &input.params; + let covariates = &input.covariates; let states = &input.states; let outputs = &input.outputs; let particles = &input.particles; let routes = expand_sde_route_metadata(&input.routes, &lag_routes, &fa_routes); let bolus_mappings = expand_injected_sde_bolus_mappings(&input.routes, &input.states, &route_bindings); + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; quote! {{ let __pharmsol_particles: usize = #particles; let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) .kind(::pharmsol::equation::ModelKind::Sde) .parameters([#(stringify!(#params)),*]) + #covariate_metadata .states([#(stringify!(#states)),*]) .outputs([#(stringify!(#outputs)),*]) #(.route(#routes))* @@ -2520,11 +3083,13 @@ mod tests { #[test] fn analytical_accepts_extra_parameters_beyond_kernel_arity() { let input = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v, tlag], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", ) .expect("extra declared parameters should be allowed"); - assert_eq!(input.params.len(), 4); + assert_eq!(input.params.len(), 5); + assert_eq!(input.covariates.len(), 2); + assert!(input.sec.is_some()); assert_eq!(input.states.len(), 2); } diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 5186fcf9..2f0f4684 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1048,6 +1048,7 @@ impl NativeAnalyticalModel { if let Some(next_event) = events.get(index + 1) { self.solve_interval( + &mut *session, &mut state, support_point, occasion.covariates(), @@ -1064,6 +1065,7 @@ impl NativeAnalyticalModel { fn solve_interval( &self, + session: &mut dyn KernelSession, state: &mut [f64], support_point: &[f64], covariates: &Covariates, @@ -1090,11 +1092,29 @@ impl NativeAnalyticalModel { breakpoints.dedup_by(|lhs, rhs| (*lhs - *rhs).abs() < 1e-12); let mut current = breakpoints[0]; - let projected = project_analytical_parameters(&self.shared.info, support_point)?; + let mut cov_buf = vec![0.0; self.shared.info.covariates.len()]; + let mut derived = vec![0.0; self.shared.info.derived_len]; for next in breakpoints.iter().copied().skip(1) { let dt = next - current; - let route_inputs = active_route_inputs(infusions, current, self.shared.info.route_len); + let route_inputs = interval_route_inputs( + infusions, + current, + next, + self.shared.info.route_len, + ); + self.shared.refresh_derived( + session, + next, + state, + support_point, + covariates, + &route_inputs, + &mut derived, + &mut cov_buf, + )?; + let projected = + project_analytical_parameters(&self.shared.info, support_point, &derived)?; let next_state = apply_analytical_kernel( self.shared.info.analytical.ok_or_else(|| { PharmsolError::OtherError(format!( @@ -1392,6 +1412,22 @@ fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> V values } +fn interval_route_inputs( + infusions: &[Infusion], + start_time: f64, + end_time: f64, + route_len: usize, +) -> Vec { + let mut values = vec![0.0; route_len]; + for infusion in infusions { + let finish = infusion.time() + infusion.duration(); + if infusion.input() < route_len && start_time >= infusion.time() && end_time <= finish { + values[infusion.input()] += infusion.amount() / infusion.duration(); + } + } + values +} + fn sort_events(events: &mut [Event]) { events.sort_by(|lhs, rhs| { fn order(event: &Event) -> u8 { @@ -1413,6 +1449,7 @@ fn sort_events(events: &mut [Event]) { fn project_analytical_parameters( info: &NativeModelInfo, support_point: &[f64], + derived: &[f64], ) -> Result { let kernel = info.analytical.ok_or_else(|| { PharmsolError::OtherError(format!( @@ -1429,6 +1466,13 @@ fn project_analytical_parameters( support_point.len() ))); } + + // Analytical authoring models can project kernel arguments through a derive + // kernel by declaring exactly the built-in kernel arity in `derived`. + if derived.len() == arity { + return Ok(V::from_vec(derived.to_vec(), NalgebraContext)); + } + Ok(V::from_vec( support_point[..arity].to_vec(), NalgebraContext, diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index 0842f9c2..e025ec4f 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -32,6 +32,24 @@ fn shared_channel_subject(input: usize) -> Subject { .build() } +fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("analytical-macro-covariates") + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + fn macro_one_compartment() -> equation::Analytical { analytical! { name: "one_cpt_iv", @@ -42,7 +60,7 @@ fn macro_one_compartment() -> equation::Analytical { infusion(iv) -> central, }, structure: one_compartment, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -85,17 +103,17 @@ fn macro_one_compartment_with_absorption() -> equation::Analytical { bolus(oral) -> gut, }, structure: one_compartment_with_absorption, - lag: |_p, _t, _cov| { + lag: |_t| { lag! { oral => tlag } }, - fa: |_p, _t, _cov| { + fa: |_t| { fa! { oral => f_oral } }, - init: |_p, _t, _cov, x| { + init: |_t, x| { x[gut] = 0.0; x[central] = 0.0; }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -153,13 +171,13 @@ fn macro_shared_channel_analytical() -> equation::Analytical { infusion(iv) -> central, }, structure: one_compartment_with_absorption, - lag: |_p, _t, _cov| { + lag: |_t| { lag! { oral => tlag } }, - fa: |_p, _t, _cov| { + fa: |_t| { fa! { oral => f_oral } }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -204,6 +222,155 @@ fn handwritten_shared_channel_analytical() -> equation::Analytical { .expect("handwritten shared-channel analytical metadata should validate") } +fn macro_covariate_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_covariates", + params: [ka, ke, v, tlag, f_oral, base_gut, base_central, tvke], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + sec: |_t| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + ke = tvke * wt_scale * renal_scale; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_covariate_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_covariates") + .kind(equation::ModelKind::Analytical) + .parameters([ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten covariate analytical metadata should validate") +} + fn assert_prediction_match(left: &[f64], right: &[f64]) { assert_eq!(left.len(), right.len()); for (left, right) in left.iter().zip(right.iter()) { @@ -297,3 +464,32 @@ fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_pre assert_prediction_match(¯o_predictions, &handwritten_predictions); } + +#[test] +fn analytical_macro_covariates_lower_to_handwritten_behavior() { + let macro_model = macro_covariate_analytical(); + let handwritten_model = handwritten_covariate_analytical(); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + + let oral = macro_model.route_index("oral").expect("oral route exists"); + let iv = macro_model.route_index("iv").expect("iv route exists"); + let cp = macro_model.output_index("cp").expect("cp output exists"); + let subject = covariate_subject(oral, iv, cp); + let support_point = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; + + assert_eq!(oral, iv); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro covariate analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs new file mode 100644 index 00000000..620d7639 --- /dev/null +++ b/tests/full_feature_macro_parity.rs @@ -0,0 +1,473 @@ +use pharmsol::prelude::*; + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn macro_ode_model() -> equation::ODE { + ode! { + name: "ode_full_feature_parity", + params: [ka, ke, kcp, kpc, v, tlag, f_oral, base_depot, base_central, base_peripheral], + covariates: [wt, renal], + states: [depot, central, peripheral], + outputs: [cp], + routes: { + bolus(oral) -> depot, + bolus(load) -> central, + infusion(iv) -> central, + }, + diffeq: |x, _t, dx, bolus, rateiv| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[depot] = bolus[oral] - ka * x[depot]; + dx[central] = bolus[load] + ka * x[depot] + rateiv[iv] + - (adjusted_ke + adjusted_kcp) * x[central] + + kpc * x[peripheral]; + dx[peripheral] = adjusted_kcp * x[central] - kpc * x[peripheral]; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[depot] = base_depot + 0.05 * wt; + x[central] = base_central + 0.1 * renal; + x[peripheral] = base_peripheral + 0.02 * wt; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_ode_model() -> equation::ODE { + equation::ODE::new( + |x, p, t, dx, bolus, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + kcp, + kpc, + _v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[0] = bolus[0] - ka * x[0]; + dx[1] = bolus[1] + ka * x[0] + rateiv[0] + - (adjusted_ke + adjusted_kcp) * x[1] + + kpc * x[2]; + dx[2] = adjusted_kcp * x[1] - kpc * x[2]; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + _f_oral, + base_depot, + base_central, + base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_depot + 0.05 * wt; + x[1] = base_central + 0.1 * renal; + x[2] = base_peripheral + 0.02 * wt; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(3) + .with_ndrugs(2) + .with_nout(1) + .with_metadata( + equation::metadata::new("ode_full_feature_parity") + .parameters([ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["depot", "central", "peripheral"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::bolus("load") + .to_state("central") + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten ODE metadata should validate") +} + +fn build_ode_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("macro-vs-handwritten-ode-full-features") + .bolus(0.0, 80.0, load) + .bolus(1.0, 120.0, oral) + .infusion(6.0, 150.0, iv, 2.5) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +fn macro_analytical_model() -> equation::Analytical { + analytical! { + name: "analytical_full_feature_parity", + params: [ka, ke, v, tlag, f_oral, base_gut, base_central, tvke], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + bolus(load) -> central, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + sec: |_t| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + ke = tvke * wt_scale * renal_scale; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_analytical_model() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(2) + .with_nout(1) + .with_metadata( + equation::metadata::new("analytical_full_feature_parity") + .kind(equation::ModelKind::Analytical) + .parameters([ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::bolus("load").to_state("central"), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten analytical metadata should validate") +} + +fn build_analytical_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("macro-vs-handwritten-analytical-full-features") + .bolus(0.0, 60.0, load) + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +#[test] +fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_ode_model(); + let handwritten_ode = handwritten_ode_model(); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let oral = macro_ode.route_index("oral").expect("oral route exists"); + let load = macro_ode.route_index("load").expect("load route exists"); + let iv = macro_ode.route_index("iv").expect("iv route exists"); + let cp = macro_ode.output_index("cp").expect("cp output exists"); + + assert_eq!(oral, iv); + assert_eq!(load, 1); + assert_eq!(handwritten_ode.route_index("oral"), Some(oral)); + assert_eq!(handwritten_ode.route_index("load"), Some(load)); + assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); + assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + + let subject = build_ode_subject(oral, load, iv, cp); + let params = [1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0]; + + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let diff = max_abs_diff( + ¯o_predictions.flat_predictions(), + &handwritten_predictions.flat_predictions(), + ); + assert!( + diff <= 1e-10, + "macro and handwritten ODE predictions diverged: {diff:e}" + ); + + Ok(()) +} + +#[test] +fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { + let macro_analytical = macro_analytical_model(); + let handwritten_analytical = handwritten_analytical_model(); + + assert_eq!( + macro_analytical.metadata(), + handwritten_analytical.metadata() + ); + + let oral = macro_analytical + .route_index("oral") + .expect("oral route exists"); + let load = macro_analytical + .route_index("load") + .expect("load route exists"); + let iv = macro_analytical.route_index("iv").expect("iv route exists"); + let cp = macro_analytical + .output_index("cp") + .expect("cp output exists"); + + assert_eq!(oral, iv); + assert_eq!(load, 1); + assert_eq!(handwritten_analytical.route_index("oral"), Some(oral)); + assert_eq!(handwritten_analytical.route_index("load"), Some(load)); + assert_eq!(handwritten_analytical.route_index("iv"), Some(iv)); + assert_eq!(handwritten_analytical.output_index("cp"), Some(cp)); + + let subject = build_analytical_subject(oral, load, iv, cp); + let params = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; + + let macro_predictions = macro_analytical.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_analytical.estimate_predictions(&subject, ¶ms)?; + + let diff = max_abs_diff( + ¯o_predictions.flat_predictions(), + &handwritten_predictions.flat_predictions(), + ); + assert!( + diff <= 1e-10, + "macro and handwritten analytical predictions diverged: {diff:e}" + ); + + Ok(()) +} \ No newline at end of file diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 1cc4cb5c..7b068733 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -42,10 +42,10 @@ fn injected_macro_ode() -> equation::ODE { routes: { infusion(iv) -> central, }, - diffeq: |x, _p, _t, dx, _cov| { + diffeq: |x, _t, dx| { dx[central] = -ke * x[central]; }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -91,10 +91,10 @@ fn explicit_macro_ode() -> equation::ODE { routes: { infusion(iv) -> central, }, - diffeq: |x, _p, _t, dx, _bolus, rateiv, _cov| { + diffeq: |x, _t, dx, _bolus, rateiv| { dx[central] = rateiv[iv] - ke * x[central]; }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -141,17 +141,17 @@ fn shared_channel_macro_ode() -> equation::ODE { bolus(oral) -> depot, infusion(iv) -> central, }, - diffeq: |x, _p, _t, dx, bolus, rateiv, _cov| { + diffeq: |x, _t, dx, bolus, rateiv| { dx[depot] = bolus[oral] - ka * x[depot]; dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; }, - lag: |_p, _t, _cov| { + lag: |_t| { lag! { oral => tlag } }, - fa: |_p, _t, _cov| { + fa: |_t| { fa! { oral => f_oral } }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -210,13 +210,12 @@ fn covariate_macro_ode() -> equation::ODE { routes: { bolus(oral) -> gut, }, - diffeq: |x, _p, t, dx, cov| { - fetch_cov!(cov, t, wt); + diffeq: |x, _t, dx| { let scaled_ke = ke * (wt / 70.0); dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - scaled_ke * x[central]; }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } diff --git a/tests/runtime_backend_matrix.rs b/tests/runtime_backend_matrix.rs index 6a207398..14b959b2 100644 --- a/tests/runtime_backend_matrix.rs +++ b/tests/runtime_backend_matrix.rs @@ -84,6 +84,85 @@ mod tests { Ok(()) } + #[test] + fn analytical_full_runtime_backend_matrix_matches_reference_predictions( + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(CorpusCase::AnalyticalFull)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-jit", + &jit, + )?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = + corpus::compile_runtime_native_aot_model(CorpusCase::AnalyticalFull, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-native-aot", + &aot, + )?; + + let wasm = corpus::compile_runtime_wasm_model(CorpusCase::AnalyticalFull)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-wasm", + &wasm, + )?; + corpus::assert_runtime_models_match_each_other( + CorpusCase::AnalyticalFull, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + Ok(()) + } + + #[test] + fn ode_full_runtime_backend_matrix_matches_reference_predictions( + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(CorpusCase::OdeFull)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + corpus::assert_runtime_model_matches_reference(CorpusCase::OdeFull, "runtime-jit", &jit)?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(CorpusCase::OdeFull, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + corpus::assert_runtime_model_matches_reference( + CorpusCase::OdeFull, + "runtime-native-aot", + &aot, + )?; + + let wasm = corpus::compile_runtime_wasm_model(CorpusCase::OdeFull)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + corpus::assert_runtime_model_matches_reference(CorpusCase::OdeFull, "runtime-wasm", &wasm)?; + corpus::assert_runtime_models_match_each_other( + CorpusCase::OdeFull, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + Ok(()) + } + #[test] fn sde_runtime_backend_matrix_matches_reference_predictions( ) -> Result<(), Box> { diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 289ab127..0f520b92 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -33,6 +33,24 @@ fn shared_channel_subject(input: usize) -> Subject { .build() } +fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("sde-macro-covariates") + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + fn prediction_means(predictions: &ndarray::Array2) -> Vec { predictions .get_predictions() @@ -58,13 +76,13 @@ fn macro_infusion_sde() -> equation::SDE { routes: { infusion(iv) -> central, }, - drift: |x, _p, _t, dx, _cov| { + drift: |x, _t, dx| { dx[central] = -ke * x[central]; }, - diffusion: |_p, sigma| { + diffusion: |sigma| { sigma[central] = sigma_ke; }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -118,25 +136,25 @@ fn macro_absorption_sde() -> equation::SDE { routes: { bolus(oral) -> gut, }, - drift: |x, _p, _t, dx, _cov| { + drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; }, - diffusion: |_p, sigma| { + diffusion: |sigma| { sigma[gut] = 0.0 * sigma_ke; sigma[central] = sigma_ke; }, - lag: |_p, _t, _cov| { + lag: |_t| { lag! { oral => tlag } }, - fa: |_p, _t, _cov| { + fa: |_t| { fa! { oral => f_oral } }, - init: |_p, _t, _cov, x| { + init: |_t, x| { x[gut] = 0.0; x[central] = 0.0; }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -204,25 +222,25 @@ fn macro_shared_channel_sde() -> equation::SDE { bolus(oral) -> gut, infusion(iv) -> central, }, - drift: |x, _p, _t, dx, _cov| { + drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; }, - diffusion: |_p, sigma| { + diffusion: |sigma| { sigma[gut] = 0.0; sigma[central] = 0.0; }, - lag: |_p, _t, _cov| { + lag: |_t| { lag! { oral => tlag } }, - fa: |_p, _t, _cov| { + fa: |_t| { fa! { oral => f_oral } }, - init: |_p, _t, _cov, x| { + init: |_t, x| { x[gut] = 0.0; x[central] = 0.0; }, - out: |x, _p, _t, _cov, y| { + out: |x, _t, y| { y[cp] = x[central] / v; }, } @@ -281,6 +299,125 @@ fn handwritten_shared_channel_sde() -> equation::SDE { .expect("handwritten shared-channel SDE metadata should validate") } +fn macro_covariate_sde() -> equation::SDE { + sde! { + name: "one_cmt_sde_covariates", + params: [ka, ke, sigma_ke, v, tlag, f_oral, base_gut, base_central], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + drift: |x, _t, dx| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - adjusted_ke * x[central]; + }, + diffusion: |sigma| { + sigma[gut] = 0.0 * sigma_ke; + sigma[central] = 0.0 * sigma_ke; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_covariate_sde() -> equation::SDE { + equation::SDE::new( + |x, p, t, dx, rateiv, cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral, _base_gut, _base_central); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - adjusted_ke * x[1]; + }, + |p, sigma| { + fetch_params!(p, _ka, _ke, sigma_ke, _v, _tlag, _f_oral, _base_gut, _base_central); + sigma[0] = 0.0 * sigma_ke; + sigma[1] = 0.0 * sigma_ke; + }, + |p, t, cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral, _base_gut, _base_central); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral, _base_gut, _base_central); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, _f_oral, base_gut, base_central); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral, _base_gut, _base_central); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_covariates") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral", "base_gut", "base_central"]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten covariate SDE metadata should validate") +} + #[test] fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_infusion_sde(); @@ -357,3 +494,31 @@ fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_prediction &prediction_means(&handwritten_predictions), ); } + +#[test] +fn sde_macro_covariates_lower_to_handwritten_behavior() { + let macro_model = macro_covariate_sde(); + let handwritten_model = handwritten_covariate_sde(); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + + let oral = macro_model.route_index("oral").expect("oral route exists"); + let iv = macro_model.route_index("iv").expect("iv route exists"); + let cp = macro_model.output_index("cp").expect("cp output exists"); + let subject = covariate_subject(oral, iv, cp); + let support_point = [1.0, 0.16, 0.0, 32.0, 0.5, 0.8, 3.0, 14.0]; + + assert_eq!(oral, iv); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro covariate SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 6a14ed33..a1603578 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -43,6 +43,38 @@ dx(central) = ka * depot - ke * central out(cp) = central / v ~ continuous() "#; +const ODE_FULL_SOURCE: &str = r#" +name = ode_full_feature_parity +kind = ode + +params = ka, ke, kcp, kpc, v, tlag, f_oral, base_depot, base_central, base_peripheral +covariates = wt@linear, renal@linear +derived = adjusted_ke, adjusted_kcp, adjusted_v +states = depot, central, peripheral +outputs = cp + +bolus(oral) -> depot +bolus(load) -> central +infusion(iv) -> central + +lag(oral) = tlag * sqrt(wt / 70.0) * pow(90.0 / renal, 0.1) +fa(oral) = min(max(f_oral * pow(renal / 90.0, 0.1), 0.0), 1.0) + +adjusted_ke = ke * pow(wt / 70.0, 0.75) * pow(renal / 90.0, 0.25) +adjusted_kcp = kcp * pow(wt / 70.0, 0.25) +adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)) + +dx(depot) = -ka * depot +dx(central) = ka * depot - (adjusted_ke + adjusted_kcp) * central + kpc * peripheral +dx(peripheral) = adjusted_kcp * central - kpc * peripheral + +init(depot) = base_depot + 0.05 * wt +init(central) = base_central + 0.1 * renal +init(peripheral) = base_peripheral + 0.02 * wt + +out(cp) = central / adjusted_v ~ continuous() +"#; + const ANALYTICAL_SOURCE: &str = r#" name = one_cmt_abs kind = analytical @@ -61,6 +93,34 @@ structure = one_compartment_with_absorption out(cp) = central / v ~ continuous() "#; +const ANALYTICAL_FULL_SOURCE: &str = r#" +name = analytical_full_feature_parity +kind = analytical + +params = ka, ke, v, tlag, f_oral, base_gut, base_central, tvke +covariates = wt@linear, renal@linear +derived = ka_proj, ke_proj +states = gut, central +outputs = cp + +bolus(oral) -> gut +bolus(load) -> central +infusion(iv) -> central + +lag(oral) = tlag * sqrt(wt / 70.0) * pow(90.0 / renal, 0.1) +fa(oral) = min(max(f_oral * pow(renal / 90.0, 0.1), 0.0), 1.0) + +ka_proj = ka +ke_proj = tvke * pow(wt / 70.0, 0.75) * pow(renal / 90.0, 0.25) + +structure = one_compartment_with_absorption + +init(gut) = base_gut + 0.03 * wt +init(central) = base_central + 0.08 * renal + +out(cp) = central / (v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0))) ~ continuous() +"#; + const SDE_SOURCE: &str = r#" name = vanco_sde kind = sde @@ -90,7 +150,9 @@ pub const SDE_PARTICLE_COUNT: usize = 16; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CorpusCase { Ode, + OdeFull, Analytical, + AnalyticalFull, Sde, } @@ -98,7 +160,9 @@ impl CorpusCase { pub fn label(self) -> &'static str { match self { Self::Ode => "dsl-ode-one_cmt_oral_iv", + Self::OdeFull => "dsl-ode-full-feature-parity", Self::Analytical => "dsl-analytical-one_cmt_abs", + Self::AnalyticalFull => "dsl-analytical-full-feature-parity", Self::Sde => "dsl-sde-vanco_sde", } } @@ -106,7 +170,9 @@ impl CorpusCase { pub fn model_name(self) -> &'static str { match self { Self::Ode => "one_cmt_oral_iv", + Self::OdeFull => "ode_full_feature_parity", Self::Analytical => "one_cmt_abs", + Self::AnalyticalFull => "analytical_full_feature_parity", Self::Sde => "vanco_sde", } } @@ -114,7 +180,9 @@ impl CorpusCase { fn source(self) -> &'static str { match self { Self::Ode => ODE_SOURCE, + Self::OdeFull => ODE_FULL_SOURCE, Self::Analytical => ANALYTICAL_SOURCE, + Self::AnalyticalFull => ANALYTICAL_FULL_SOURCE, Self::Sde => SDE_SOURCE, } } @@ -122,7 +190,9 @@ impl CorpusCase { pub fn tolerance(self) -> f64 { match self { Self::Ode => 1e-4, + Self::OdeFull => 1e-4, Self::Analytical => 1e-8, + Self::AnalyticalFull => 1e-8, Self::Sde => 1e-4, } } @@ -130,7 +200,9 @@ impl CorpusCase { pub fn support_point(self) -> &'static [f64] { match self { Self::Ode => &[1.2, 5.0, 40.0, 0.5, 0.8], + Self::OdeFull => &[1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0], Self::Analytical => &[1.0, 0.15, 25.0, 0.5, 0.8], + Self::AnalyticalFull => &[1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16], Self::Sde => &[1.1, 0.2, 0.12, 0.08, 15.0, 0.0], } } @@ -160,6 +232,34 @@ impl CorpusCase { .missing_observation(9.0, cp) .build() } + Self::OdeFull => { + let oral = model.route_index("oral").ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + let load = model.route_index("load").ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + let iv = model.route_index("iv").ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; + Subject::builder(self.label()) + .bolus(0.0, 80.0, load) + .bolus(1.0, 120.0, oral) + .infusion(6.0, 150.0, iv, 2.5) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() + } Self::Analytical => { let oral = model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) @@ -172,6 +272,34 @@ impl CorpusCase { .missing_observation(4.0, cp) .build() } + Self::AnalyticalFull => { + let oral = model.route_index("oral").ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + let load = model.route_index("load").ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + let iv = model.route_index("iv").ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; + Subject::builder(self.label()) + .bolus(0.0, 60.0, load) + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() + } Self::Sde => { let oral = model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) @@ -193,9 +321,13 @@ impl CorpusCase { fn reference_predictions(self) -> Result> { match self { Self::Ode => Ok(ExpectedPredictions::Subject(reference_ode_predictions()?)), + Self::OdeFull => Ok(ExpectedPredictions::Subject(reference_ode_full_predictions()?)), Self::Analytical => Ok(ExpectedPredictions::Subject( reference_analytical_predictions()?, )), + Self::AnalyticalFull => Ok(ExpectedPredictions::Subject( + reference_analytical_full_predictions()?, + )), Self::Sde => Ok(ExpectedPredictions::Particles(reference_sde_predictions()?)), } } @@ -605,6 +737,138 @@ fn reference_ode_predictions() -> Result> { )?) } +fn reference_ode_full_predictions() -> Result> { + Ok(equation::ODE::new( + |x, p, t, dx, bolus, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + kcp, + kpc, + _v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[0] = bolus[0] - ka * x[0]; + dx[1] = bolus[1] + ka * x[0] + rateiv[0] + - (adjusted_ke + adjusted_kcp) * x[1] + + kpc * x[2]; + dx[2] = adjusted_kcp * x[1] - kpc * x[2]; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + _f_oral, + base_depot, + base_central, + base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_depot + 0.05 * wt; + x[1] = base_central + 0.1 * renal; + x[2] = base_peripheral + 0.02 * wt; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(3) + .with_ndrugs(2) + .with_nout(1) + .estimate_predictions( + &Subject::builder(CorpusCase::OdeFull.label()) + .bolus(0.0, 80.0, 1) + .bolus(1.0, 120.0, 0) + .infusion(6.0, 150.0, 0, 2.5) + .missing_observation(0.25, 0) + .missing_observation(0.75, 0) + .missing_observation(1.5, 0) + .missing_observation(3.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .missing_observation(12.0, 0) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build(), + CorpusCase::OdeFull.support_point(), + )?) +} + fn reference_analytical_predictions() -> Result> { Ok(equation::Analytical::new( one_compartment_with_absorption, @@ -638,6 +902,110 @@ fn reference_analytical_predictions() -> Result Result> { + Ok(equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(2) + .with_nout(1) + .estimate_predictions( + &Subject::builder(CorpusCase::AnalyticalFull.label()) + .bolus(0.0, 60.0, 1) + .bolus(1.0, 100.0, 0) + .infusion(6.0, 140.0, 0, 2.0) + .missing_observation(0.25, 0) + .missing_observation(0.75, 0) + .missing_observation(1.5, 0) + .missing_observation(3.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .missing_observation(12.0, 0) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build(), + CorpusCase::AnalyticalFull.support_point(), + )?) +} + fn reference_sde_predictions() -> Result, Box> { Ok(SDE::new( |x, p, _t, dx, _rateiv, _cov| { From 8b142a6971d959c8453ff6ab4e5d018de24724ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 09:36:18 +0100 Subject: [PATCH 05/22] chore: fmt --- src/dsl/native.rs | 8 +-- tests/full_feature_macro_parity.rs | 7 ++- tests/runtime_backend_matrix.rs | 3 +- tests/sde_macro_lowering.rs | 83 +++++++++++++++++++++++++++--- tests/support/runtime_corpus.rs | 9 ++-- 5 files changed, 87 insertions(+), 23 deletions(-) diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 2f0f4684..202fd45a 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1097,12 +1097,8 @@ impl NativeAnalyticalModel { for next in breakpoints.iter().copied().skip(1) { let dt = next - current; - let route_inputs = interval_route_inputs( - infusions, - current, - next, - self.shared.info.route_len, - ); + let route_inputs = + interval_route_inputs(infusions, current, next, self.shared.info.route_len); self.shared.refresh_derived( session, next, diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index 620d7639..71a1afa7 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -75,9 +75,8 @@ fn handwritten_ode_model() -> equation::ODE { let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); dx[0] = bolus[0] - ka * x[0]; - dx[1] = bolus[1] + ka * x[0] + rateiv[0] - - (adjusted_ke + adjusted_kcp) * x[1] - + kpc * x[2]; + dx[1] = + bolus[1] + ka * x[0] + rateiv[0] - (adjusted_ke + adjusted_kcp) * x[1] + kpc * x[2]; dx[2] = adjusted_kcp * x[1] - kpc * x[2]; }, |p, t, cov| { @@ -470,4 +469,4 @@ fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::P ); Ok(()) -} \ No newline at end of file +} diff --git a/tests/runtime_backend_matrix.rs b/tests/runtime_backend_matrix.rs index 14b959b2..fdabc94d 100644 --- a/tests/runtime_backend_matrix.rs +++ b/tests/runtime_backend_matrix.rs @@ -99,8 +99,7 @@ mod tests { )?; #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] - let aot = - corpus::compile_runtime_native_aot_model(CorpusCase::AnalyticalFull, &workspace)?; + let aot = corpus::compile_runtime_native_aot_model(CorpusCase::AnalyticalFull, &workspace)?; #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] assert_eq!(aot.backend(), RuntimeBackend::NativeAot); #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 0f520b92..876d2b23 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -345,7 +345,17 @@ fn macro_covariate_sde() -> equation::SDE { fn handwritten_covariate_sde() -> equation::SDE { equation::SDE::new( |x, p, t, dx, rateiv, cov| { - fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral, _base_gut, _base_central); + fetch_params!( + p, + ka, + ke, + _sigma_ke, + _v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); fetch_cov!(cov, t, wt, renal); let wt_scale = (wt / 70.0).powf(0.75); @@ -356,33 +366,83 @@ fn handwritten_covariate_sde() -> equation::SDE { dx[1] = ka * x[0] + rateiv[0] - adjusted_ke * x[1]; }, |p, sigma| { - fetch_params!(p, _ka, _ke, sigma_ke, _v, _tlag, _f_oral, _base_gut, _base_central); + fetch_params!( + p, + _ka, + _ke, + sigma_ke, + _v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); sigma[0] = 0.0 * sigma_ke; sigma[1] = 0.0 * sigma_ke; }, |p, t, cov| { - fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral, _base_gut, _base_central); + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central + ); fetch_cov!(cov, t, wt, renal); let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); lag! { 0 => tlag * lag_scale } }, |p, t, cov| { - fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral, _base_gut, _base_central); + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central + ); fetch_cov!(cov, t, wt, renal); let fa_scale = (renal / 90.0).powf(0.1); fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } }, |p, t, cov, x| { - fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, _f_oral, base_gut, base_central); + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central + ); fetch_cov!(cov, t, wt, renal); x[0] = base_gut + 0.03 * wt; x[1] = base_central + 0.08 * renal; }, |x, p, t, cov, y| { - fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral, _base_gut, _base_central); + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); fetch_cov!(cov, t, wt, renal); let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); @@ -396,7 +456,16 @@ fn handwritten_covariate_sde() -> equation::SDE { .with_metadata( equation::metadata::new("one_cmt_sde_covariates") .kind(equation::ModelKind::Sde) - .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral", "base_gut", "base_central"]) + .parameters([ + "ka", + "ke", + "sigma_ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + ]) .covariates([ equation::Covariate::continuous("wt"), equation::Covariate::continuous("renal"), diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index a1603578..3ed75511 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -321,7 +321,9 @@ impl CorpusCase { fn reference_predictions(self) -> Result> { match self { Self::Ode => Ok(ExpectedPredictions::Subject(reference_ode_predictions()?)), - Self::OdeFull => Ok(ExpectedPredictions::Subject(reference_ode_full_predictions()?)), + Self::OdeFull => Ok(ExpectedPredictions::Subject( + reference_ode_full_predictions()?, + )), Self::Analytical => Ok(ExpectedPredictions::Subject( reference_analytical_predictions()?, )), @@ -761,9 +763,8 @@ fn reference_ode_full_predictions() -> Result let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); dx[0] = bolus[0] - ka * x[0]; - dx[1] = bolus[1] + ka * x[0] + rateiv[0] - - (adjusted_ke + adjusted_kcp) * x[1] - + kpc * x[2]; + dx[1] = + bolus[1] + ka * x[0] + rateiv[0] - (adjusted_ke + adjusted_kcp) * x[1] + kpc * x[2]; dx[2] = adjusted_kcp * x[1] - kpc * x[2]; }, |p, t, cov| { From 9aa16951c65257c6e38c567f0c9a1e34705f9447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 09:57:00 +0100 Subject: [PATCH 06/22] chore: fix compilation ubuntu --- pharmsol-dsl/src/execution.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 0904bec2..886d570a 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -542,7 +542,7 @@ impl<'a> ExecutionLowerer<'a> { for (declaration_index, route) in model.routes.iter().enumerate() { let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?; if route.kind == Some(RouteKind::Infusion) { - for property in &route.properties { + if let Some(property) = route.properties.first() { let label = match property.kind { RoutePropertyKind::Lag => "lag", RoutePropertyKind::Bioavailability => "bioavailability", From 0a04f41a66a7fd67bc5c3bdcf46151c55e89cac2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 13:39:28 +0100 Subject: [PATCH 07/22] feature: INPUT/OUTEQ colums in the Pmetrics format are now read as Strings instead of Integer. This gives flexibility to modelers to define routes in terms of indices or named elements. All three surfaces for Model creation are also updated (macros, SDL, and ::new() constructors) --- examples/macro_vs_handwritten_one_cpt.rs | 15 +- examples/macro_vs_handwritten_two_cpt.rs | 30 +- pharmsol-dsl/src/authoring.rs | 51 +- pharmsol-dsl/src/parser.rs | 77 ++- .../tests/dsl_authoring_edge_cases.rs | 50 +- pharmsol-macros/src/lib.rs | 456 +++++++++++++++--- src/data/builder.rs | 12 +- src/data/event.rs | 145 +++++- src/data/parser/pmetrics.rs | 62 ++- src/data/row.rs | 34 +- src/data/structs.rs | 57 +-- src/dsl/native.rs | 182 +++++-- src/error/mod.rs | 4 + src/simulator/equation/analytical/mod.rs | 24 +- src/simulator/equation/mod.rs | 100 +++- src/simulator/equation/ode/closure.rs | 6 +- src/simulator/equation/ode/mod.rs | 44 +- src/simulator/equation/sde/mod.rs | 41 +- tests/analytical_macro_lowering.rs | 61 ++- tests/authoring_parity_corpus.rs | 81 +++- tests/ode_macro_lowering.rs | 360 +++++++++++++- tests/sde_macro_lowering.rs | 61 ++- 22 files changed, 1589 insertions(+), 364 deletions(-) diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index 4d8f74d0..ddff59f8 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -26,6 +26,9 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( + // Handwritten closures stay on dense internal channels. + // Public labels like `iv` and `cp` live in attached metadata, not in + // the low-level `rateiv[]` / `y[]` buffers. |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -75,12 +78,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); let subject = Subject::builder("macro-vs-handwritten-one-cpt") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") .build(); let params = [1.022, 194.0]; diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 9ab1a675..915267d6 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -29,6 +29,10 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( + // Handwritten closures stay on dense internal channels. + // Public route labels like `load` and `iv` are metadata names; the + // low-level `bolus[]`, `rateiv[]`, and `y[]` buffers remain indexed by + // dense internal slots. |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; @@ -88,19 +92,19 @@ fn main() -> Result<(), pharmsol::PharmsolError> { assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); let subject = Subject::builder("macro-vs-handwritten-two-cpt") - .bolus(0.0, 100.0, load) - .infusion(12.0, 200.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) - .missing_observation(12.5, cp) - .missing_observation(13.0, cp) - .missing_observation(14.0, cp) - .missing_observation(16.0, cp) - .missing_observation(24.0, cp) + .bolus(0.0, 100.0, "load") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") .build(); let params = [0.1, 0.05, 0.03, 50.0]; diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index c81c19eb..129f07c8 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -371,7 +371,7 @@ impl<'a> AuthoringParser<'a> { if lhs_trimmed == "outputs" { self.declared_outputs_span = Some(span); - for ident in parse_ident_list(rhs, rhs_abs)? { + for ident in parse_output_label_list(rhs, rhs_abs)? { self.declared_outputs.insert(ident.text.clone()); self.explicit_outputs.insert(ident.text, ident.span); } @@ -413,7 +413,20 @@ impl<'a> AuthoringParser<'a> { return self.parse_call_assignment(call, rhs, rhs_abs, span); } - let target = parse_ident_segment(lhs, lhs_abs)?; + let target = match parse_ident_segment(lhs, lhs_abs) { + Ok(target) => target, + Err(error) => { + if self.declared_outputs_span.is_none() { + return Err(error); + } + + let target = parse_output_label_segment(lhs, lhs_abs)?; + if !self.declared_outputs.contains(&target.text) { + return Err(self.undeclared_output_error(&target.text, target.span)); + } + target + } + }; let rhs = parse_surface_rhs(rhs, rhs_abs)?; let stmt = build_assignment_statement( AssignTarget { @@ -552,7 +565,7 @@ impl<'a> AuthoringParser<'a> { self.init_statements.push(stmt); } "out" => { - let output = parse_ident_segment(call.argument, call.argument_start)?; + let output = parse_output_label_segment(call.argument, call.argument_start)?; self.validate_output_target(&output)?; self.declared_outputs.insert(output.text.clone()); self.note_output_assignment(&output); @@ -839,6 +852,13 @@ fn parse_ident_list(src: &str, abs_start: usize) -> Result, ParseErro .collect() } +fn parse_output_label_list(src: &str, abs_start: usize) -> Result, ParseError> { + split_top_level(src, ',') + .into_iter() + .map(|(segment, start)| parse_output_label_segment(segment, abs_start + start)) + .collect() +} + fn parse_covariates_list(src: &str, abs_start: usize) -> Result, ParseError> { let mut covariates = Vec::new(); for (segment, start) in split_top_level(src, ',') { @@ -907,6 +927,27 @@ fn parse_ident_segment(src: &str, abs_start: usize) -> Result )) } +fn parse_output_label_segment(src: &str, abs_start: usize) -> Result { + let trimmed = src.trim(); + let leading = src.len() - src.trim_start().len(); + if trimmed.is_empty() { + return Err(ParseError::new( + "expected output label", + Span::new(abs_start, abs_start + src.len()), + )); + } + if !is_valid_output_label(trimmed) { + return Err(ParseError::new( + format!("expected output label, found `{trimmed}`"), + Span::new(abs_start + leading, abs_start + leading + trimmed.len()), + )); + } + Ok(Ident::new( + trimmed, + Span::new(abs_start + leading, abs_start + leading + trimmed.len()), + )) +} + fn parse_place_at(src: &str, abs_start: usize) -> Result { let mut place = parse_place_fragment(src).map_err(|error| error.shifted(abs_start))?; shift_place(&mut place, abs_start); @@ -1344,6 +1385,10 @@ fn is_valid_ident(src: &str) -> bool { chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') } +fn is_valid_output_label(src: &str) -> bool { + is_valid_ident(src) || src.chars().all(|ch| ch.is_ascii_digit()) +} + fn is_ident_byte(byte: u8) -> bool { (byte as char).is_ascii_alphanumeric() || byte == b'_' } diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index f07fbd50..c265b4df 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -106,12 +106,18 @@ struct Parser { #[derive(Clone, Copy)] enum LayoutBoundary { ModelItem, - Statement, + Statement(StatementContext), Binding, IdentItem, RouteDecl, } +#[derive(Clone, Copy, PartialEq, Eq)] +enum StatementContext { + Standard, + Outputs, +} + impl Parser { fn new(src: &str) -> Result { Ok(Self::from_tokens(lex(src)?, src.len())) @@ -655,8 +661,13 @@ impl Parser { fn parse_statement_block(&mut self, name: &str) -> Result { let start = self.bump().unwrap().span; let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = if name == "outputs" { + StatementContext::Outputs + } else { + StatementContext::Standard + }; let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -790,8 +801,9 @@ impl Parser { fn parse_stmt_body(&mut self) -> Result, ParseError> { let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = self.current_statement_context(); let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -854,7 +866,11 @@ impl Parser { } fn parse_assign_target(&mut self) -> Result { - let name = self.parse_ident()?; + let name = if matches!(self.current_statement_context(), StatementContext::Outputs) { + self.parse_output_target_name()? + } else { + self.parse_ident()? + }; let mut span = name.span; let kind = if let Some(open) = self.take_if(|kind| matches!(kind, TokenKind::LParen)) { let args = self.parse_expr_list(&open, TokenKindMatcher::RPAREN)?; @@ -885,6 +901,30 @@ impl Parser { Ok(AssignTarget { kind, span }) } + fn parse_output_target_name(&mut self) -> Result { + let token = self + .bump() + .ok_or_else(|| ParseError::new("expected output label", Span::empty(self.src_len)))?; + match token.kind { + TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), + TokenKind::Number(value) + if value.is_finite() + && value >= 0.0 + && value.fract() == 0.0 + && value <= usize::MAX as f64 => + { + Ok(Ident::new((value as usize).to_string(), token.span)) + } + other => Err(ParseError::new( + format!( + "expected output label identifier or non-negative integer, found {}", + other.describe() + ), + token.span, + )), + } + } + fn parse_ident(&mut self) -> Result { let token = self .bump() @@ -1320,9 +1360,12 @@ impl Parser { | TokenKind::Diffusion | TokenKind::Particles ), - LayoutBoundary::Statement => match &token.kind { + LayoutBoundary::Statement(context) => match &token.kind { TokenKind::If | TokenKind::For | TokenKind::Let => true, TokenKind::Ident(_) => self.line_starts_assignment_target(index), + TokenKind::Number(_) if matches!(context, StatementContext::Outputs) => { + self.line_starts_numeric_output_assignment(index) + } _ => false, }, LayoutBoundary::Binding => self.line_starts_named_assignment(index), @@ -1379,6 +1422,26 @@ impl Parser { } } + fn line_starts_numeric_output_assignment(&self, index: usize) -> bool { + matches!( + self.tokens.get(index).map(|token| &token.kind), + Some(TokenKind::Number(_)) + ) && self + .next_same_line_index(index) + .is_some_and(|next| matches!(self.tokens[next].kind, TokenKind::Eq)) + } + + fn current_statement_context(&self) -> StatementContext { + self.layout_boundaries + .iter() + .rev() + .find_map(|boundary| match boundary { + LayoutBoundary::Statement(context) => Some(*context), + _ => None, + }) + .unwrap_or(StatementContext::Standard) + } + fn next_same_line_index(&self, index: usize) -> Option { let next = index + 1; let token = self.tokens.get(next)?; @@ -1413,7 +1476,7 @@ impl Parser { fn current_boundary_label(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "next model item starts here", - Some(LayoutBoundary::Statement) => "next statement starts here", + Some(LayoutBoundary::Statement(_)) => "next statement starts here", Some(LayoutBoundary::Binding) => "next binding starts here", Some(LayoutBoundary::IdentItem) => "next declaration starts here", Some(LayoutBoundary::RouteDecl) => "next route starts here", @@ -1424,7 +1487,7 @@ impl Parser { fn current_boundary_subject(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "model item", - Some(LayoutBoundary::Statement) => "statement", + Some(LayoutBoundary::Statement(_)) => "statement", Some(LayoutBoundary::Binding) => "binding", Some(LayoutBoundary::IdentItem) => "declaration", Some(LayoutBoundary::RouteDecl) => "route", diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 797be3e9..404487dc 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -1,4 +1,4 @@ -use pharmsol_dsl::{analyze_model, parse_model, parse_module}; +use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model, parse_module}; #[test] fn output_annotation_is_optional() { @@ -161,6 +161,54 @@ out(cp) = central ~ continous() ); } +#[test] +fn mixed_named_and_numeric_output_labels_lower_and_round_trip() { + let src = r#" +name = mixed_output_labels +kind = ode +params = ke, v +states = central +outputs = cp, 0, 1 +infusion(iv) -> central +ddt(central) = -ke * central +out(cp) = central / v +out(0) = 2 * central / v +out(1) = 3 * central / v +"#; + + let module = parse_module(src).expect("mixed output labels should parse in authoring DSL"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(&model).expect("mixed output labels should analyze"); + let lowered = lower_typed_model(&typed).expect("mixed output labels should lower"); + + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "0", "1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.index) + .collect::>(), + vec![0, 1, 2] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered mixed-output model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 54a79fe3..96b9536e 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -6,13 +6,14 @@ use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use syn::{ parse::{Parse, ParseStream, Parser}, punctuated::Punctuated, token, visit::Visit, - Expr, ExprClosure, Ident, LitStr, Pat, Stmt, Token, + visit_mut::VisitMut, + Expr, ExprClosure, Ident, Lit, LitInt, LitStr, Pat, Stmt, Token, }; // --------------------------------------------------------------------------- @@ -24,7 +25,7 @@ struct OdeInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, diffeq_mode: OdeDiffeqMode, diffeq: ExprClosure, @@ -39,7 +40,7 @@ struct AnalyticalInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, structure: Ident, sec: Option, @@ -54,7 +55,7 @@ struct SdeInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, particles: Expr, drift: ExprClosure, @@ -73,7 +74,7 @@ enum OdeDiffeqMode { struct OdeRouteDecl { kind: OdeRouteKind, - input: Ident, + input: SymbolicIndex, destination: Ident, } @@ -91,10 +92,78 @@ struct AnalyticalKernelSpec { } struct RoutePropertyEntry { - route: Ident, + route: SymbolicIndex, value: Expr, } +#[derive(Clone)] +enum SymbolicIndex { + Ident(Ident), + Int(LitInt), +} + +impl SymbolicIndex { + fn name(&self) -> String { + match self { + Self::Ident(ident) => ident.to_string(), + Self::Int(lit) => lit.base10_digits().to_string(), + } + } + + fn ident(&self) -> Option<&Ident> { + match self { + Self::Ident(ident) => Some(ident), + Self::Int(_) => None, + } + } + + fn numeric_value(&self) -> Option { + match self { + Self::Ident(_) => None, + Self::Int(lit) => Some( + lit.base10_parse::() + .expect("validated numeric label should fit usize"), + ), + } + } + + fn numeric(value: usize) -> Self { + Self::Int(LitInt::new(&value.to_string(), Span::call_site())) + } +} + +impl Parse for SymbolicIndex { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(LitInt) { + let lit: LitInt = input.parse()?; + lit.base10_parse::().map_err(|_| { + syn::Error::new_spanned( + &lit, + "numeric declaration-first labels must be non-negative base-10 integers that fit in usize", + ) + })?; + Ok(Self::Int(lit)) + } else { + Ok(Self::Ident(input.parse()?)) + } + } +} + +impl ToTokens for SymbolicIndex { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + Self::Ident(ident) => ident.to_tokens(tokens), + Self::Int(lit) => lit.to_tokens(tokens), + } + } +} + +impl std::fmt::Display for SymbolicIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.name()) + } +} + impl Parse for OdeRouteDecl { fn parse(input: ParseStream) -> syn::Result { let kind_ident: Ident = input.parse()?; @@ -111,7 +180,7 @@ impl Parse for OdeRouteDecl { let content; syn::parenthesized!(content in input); - let route_input: Ident = content.parse()?; + let route_input: SymbolicIndex = content.parse()?; if !content.is_empty() { return Err(content.error("expected a single route input name inside `(...)`")); } @@ -166,7 +235,12 @@ impl Parse for OdeInput { "covariates", )?, "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?, - "outputs" => set_once_ode(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "outputs" => set_once_ode( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?, "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?, "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?, @@ -206,14 +280,16 @@ impl Parse for OdeInput { validate_unique_idents("parameter", ¶ms, "ode!")?; validate_unique_idents("covariate", &covariates, "ode!")?; validate_unique_idents("state", &states, "ode!")?; - validate_unique_idents("output", &outputs, "ode!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "ode!")?; validate_routes(&routes, &states, "ode!")?; validate_named_binding_compatibility( NamedBindingSets { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, OdeBindingClosures { @@ -247,7 +323,7 @@ impl Parse for OdeInput { impl Parse for RoutePropertyEntry { fn parse(input: ParseStream) -> syn::Result { - let route: Ident = input.parse()?; + let route: SymbolicIndex = input.parse()?; input.parse::]>()?; let value: Expr = input.parse()?; Ok(Self { route, value }) @@ -287,9 +363,12 @@ impl Parse for AnalyticalInput { "states" => { set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")? } - "outputs" => { - set_once_analytical(&mut outputs, parse_ident_list(input)?, &key, "outputs")? - } + "outputs" => set_once_analytical( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => { set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")? } @@ -328,7 +407,9 @@ impl Parse for AnalyticalInput { validate_unique_idents("parameter", ¶ms, "analytical!")?; validate_unique_idents("covariate", &covariates, "analytical!")?; validate_unique_idents("state", &states, "analytical!")?; - validate_unique_idents("output", &outputs, "analytical!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "analytical!")?; validate_routes(&routes, &states, "analytical!")?; let kernel_spec = resolve_analytical_structure(&structure)?; @@ -358,7 +439,7 @@ impl Parse for AnalyticalInput { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, AnalyticalBindingClosures { @@ -431,7 +512,12 @@ impl Parse for SdeInput { "covariates", )?, "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?, - "outputs" => set_once_sde(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "outputs" => set_once_sde( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?, "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?, "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?, @@ -469,14 +555,16 @@ impl Parse for SdeInput { validate_unique_idents("parameter", ¶ms, "sde!")?; validate_unique_idents("covariate", &covariates, "sde!")?; validate_unique_idents("state", &states, "sde!")?; - validate_unique_idents("output", &outputs, "sde!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "sde!")?; validate_routes(&routes, &states, "sde!")?; validate_sde_named_binding_compatibility( NamedBindingSets { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, SdeBindingClosures { @@ -595,6 +683,16 @@ fn parse_ident_list(input: ParseStream) -> syn::Result> { .collect()) } +fn parse_symbolic_index_list(input: ParseStream) -> syn::Result> { + let content; + syn::bracketed!(content in input); + Ok( + Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(), + ) +} + fn parse_route_list(input: ParseStream) -> syn::Result> { let content; syn::braced!(content in input); @@ -627,6 +725,29 @@ fn generated_ident(name: &str) -> Ident { Ident::new(name, Span::call_site()) } +fn symbolic_index_idents(labels: &[SymbolicIndex]) -> Vec { + labels + .iter() + .filter_map(|label| label.ident().cloned()) + .collect() +} + +fn symbolic_index_bindings(labels: &[SymbolicIndex]) -> Vec<(SymbolicIndex, usize)> { + labels + .iter() + .cloned() + .enumerate() + .map(|(index, label)| (label, index)) + .collect() +} + +fn symbolic_numeric_binding_map(bindings: &[(SymbolicIndex, usize)]) -> HashMap { + bindings + .iter() + .filter_map(|(label, index)| label.numeric_value().map(|value| (value, *index))) + .collect() +} + #[derive(Default)] struct ClosureBodyUsage { idents: HashSet, @@ -713,6 +834,124 @@ impl<'ast> Visit<'ast> for ClosureBodyUsage { } } +struct IndexRewriteTarget { + container: Ident, + labels: HashMap, +} + +impl IndexRewriteTarget { + fn new(container: Ident, labels: HashMap) -> Self { + Self { container, labels } + } +} + +struct NumericLabelRewriter { + index_targets: Vec, + route_labels: Option>, +} + +impl NumericLabelRewriter { + fn rewrite( + expr: &Expr, + index_targets: Vec, + route_labels: Option>, + ) -> Expr { + let mut rewritten = expr.clone(); + let mut rewriter = Self { + index_targets, + route_labels, + }; + rewriter.visit_expr_mut(&mut rewritten); + rewritten + } + + fn target_labels(&self, path: &syn::ExprPath) -> Option<&HashMap> { + if path.qself.is_some() + || path.path.leading_colon.is_some() + || path.path.segments.len() != 1 + { + return None; + } + + let ident = &path.path.segments[0].ident; + self.index_targets + .iter() + .find(|target| target.container == *ident) + .map(|target| &target.labels) + } + + fn rewrite_route_macro(&self, mac: &mut syn::Macro) { + let Some(route_labels) = self.route_labels.as_ref() else { + return; + }; + if !(mac.path.is_ident("lag") || mac.path.is_ident("fa")) { + return; + } + + let Ok(entries) = Punctuated::::parse_terminated + .parse2(mac.tokens.clone()) + else { + return; + }; + + let entries = entries.into_iter().map(|mut entry| { + if let Some(value) = entry.route.numeric_value() { + if let Some(internal_index) = route_labels.get(&value) { + entry.route = SymbolicIndex::numeric(*internal_index); + } + } + entry + }); + + let tokens = entries.map(|entry| { + let route = entry.route; + let value = entry.value; + quote! { #route => #value } + }); + mac.tokens = quote! { #(#tokens),* }; + } +} + +impl VisitMut for NumericLabelRewriter { + fn visit_expr_index_mut(&mut self, expr_index: &mut syn::ExprIndex) { + syn::visit_mut::visit_expr_index_mut(self, expr_index); + + let Expr::Path(expr_path) = expr_index.expr.as_ref() else { + return; + }; + let Some(labels) = self.target_labels(expr_path) else { + return; + }; + let Expr::Lit(expr_lit) = expr_index.index.as_ref() else { + return; + }; + let Lit::Int(lit) = &expr_lit.lit else { + return; + }; + let Ok(external_index) = lit.base10_parse::() else { + return; + }; + let Some(internal_index) = labels.get(&external_index) else { + return; + }; + + expr_index.index = Box::new(Expr::Lit(syn::ExprLit { + attrs: Vec::new(), + lit: Lit::Int(LitInt::new(&internal_index.to_string(), lit.span())), + })); + } + + fn visit_expr_macro_mut(&mut self, expr_macro: &mut syn::ExprMacro) { + self.rewrite_route_macro(&mut expr_macro.mac); + syn::visit_mut::visit_expr_macro_mut(self, expr_macro); + } + + fn visit_stmt_macro_mut(&mut self, stmt_macro: &mut syn::StmtMacro) { + self.rewrite_route_macro(&mut stmt_macro.mac); + syn::visit_mut::visit_stmt_macro_mut(self, stmt_macro); + } +} + fn generate_closure_input_aliases( closure: &ExprClosure, internal_names: &[Ident], @@ -856,10 +1095,17 @@ fn classify_diffeq_mode( } fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec { - routes.iter().map(|route| route.input.clone()).collect() + routes + .iter() + .filter_map(|route| route.input.ident().cloned()) + .collect() +} + +fn route_input_names(routes: &[OdeRouteDecl]) -> Vec { + routes.iter().map(|route| route.input.name()).collect() } -fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { +fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { let mut next_bolus_index = 0usize; let mut next_infusion_index = 0usize; @@ -883,7 +1129,7 @@ fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { .collect() } -fn dense_index_len(bindings: &[(Ident, usize)]) -> usize { +fn dense_index_len(bindings: &[(SymbolicIndex, usize)]) -> usize { bindings .iter() .map(|(_, index)| index + 1) @@ -1361,12 +1607,14 @@ fn generate_index_consts(idents: &[Ident]) -> TokenStream2 { } } -fn generate_mapped_index_consts(bindings: &[(Ident, usize)]) -> TokenStream2 { - let bindings = bindings.iter().map(|(ident, index)| { - quote! { - #[allow(non_upper_case_globals, dead_code)] - const #ident: usize = #index; - } +fn generate_mapped_index_consts(bindings: &[(SymbolicIndex, usize)]) -> TokenStream2 { + let bindings = bindings.iter().filter_map(|(label, index)| { + label.ident().map(|ident| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }) }); quote! { @@ -1379,10 +1627,11 @@ fn expand_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -1397,7 +1646,19 @@ fn expand_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -1480,14 +1741,13 @@ fn extract_route_property_routes( let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?; let entries = Punctuated::::parse_terminated .parse2(macro_expr.tokens.clone())?; - let known_routes = route_input_idents(routes) + let known_routes = route_input_names(routes) .into_iter() - .map(|route| route.to_string()) .collect::>(); let mut seen = HashSet::new(); for entry in entries { - let route_name = entry.route.to_string(); + let route_name = entry.route.name(); if !known_routes.contains(&route_name) { return Err(syn::Error::new_spanned( &entry.route, @@ -1515,7 +1775,7 @@ fn validate_route_property_kinds( property_routes: &HashSet, ) -> syn::Result<()> { for route in routes { - if property_routes.contains(&route.input.to_string()) + if property_routes.contains(&route.input.name()) && matches!(route.kind, OdeRouteKind::Infusion) { return Err(syn::Error::new_spanned( @@ -1536,7 +1796,7 @@ fn expand_ode_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -1553,7 +1813,11 @@ fn expand_ode_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -1626,7 +1890,7 @@ fn expand_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1671,7 +1935,7 @@ fn expand_analytical_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1711,7 +1975,7 @@ fn expand_sde_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1752,7 +2016,7 @@ fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize { fn expand_injected_ode_route_terms( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], dx: &Ident, bolus: &Ident, rateiv: &Ident, @@ -1780,7 +2044,7 @@ fn expand_injected_ode_route_terms( fn expand_injected_sde_rate_terms( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], dx: &Ident, rateiv: &Ident, ) -> TokenStream2 { @@ -1806,7 +2070,7 @@ fn expand_injected_sde_rate_terms( fn expand_injected_sde_bolus_mappings( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> TokenStream2 { let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; @@ -1836,12 +2100,30 @@ fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn Ok(()) } +fn validate_unique_symbolic_indices( + kind: &str, + labels: &[SymbolicIndex], + macro_name: &str, +) -> syn::Result<()> { + let mut seen = HashSet::new(); + for label in labels { + let name = label.name(); + if !seen.insert(name.clone()) { + return Err(syn::Error::new_spanned( + label, + format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"), + )); + } + } + Ok(()) +} + fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> { let known_states = states.iter().map(Ident::to_string).collect::>(); let mut seen_routes = HashSet::new(); for route in routes { - let route_name = route.input.to_string(); + let route_name = route.input.name(); if !seen_routes.insert(route_name.clone()) { return Err(syn::Error::new_spanned( &route.input, @@ -1869,7 +2151,7 @@ fn expand_diffeq( covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], diffeq_mode: OdeDiffeqMode, ) -> syn::Result { let state_consts = generate_index_consts(states); @@ -1907,7 +2189,25 @@ fn expand_diffeq( )?; let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let body = &diffeq.body; + let bolus_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 4).unwrap_or_else(|| bolus.clone()) + } else { + closure_param_ident(diffeq, 3).unwrap_or_else(|| bolus.clone()) + }; + let rateiv_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 5).unwrap_or_else(|| rateiv.clone()) + } else { + closure_param_ident(diffeq, 4).unwrap_or_else(|| rateiv.clone()) + }; + let route_label_map = symbolic_numeric_binding_map(route_bindings); + let body = NumericLabelRewriter::rewrite( + diffeq.body.as_ref(), + vec![ + IndexRewriteTarget::new(bolus_binding, route_label_map.clone()), + IndexRewriteTarget::new(rateiv_binding, route_label_map), + ], + None, + ); Ok(quote! {{ let __pharmsol_diffeq: fn( @@ -2094,7 +2394,7 @@ fn expand_analytical_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -2111,7 +2411,11 @@ fn expand_analytical_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -2221,10 +2525,11 @@ fn expand_analytical_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -2239,7 +2544,19 @@ fn expand_analytical_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -2270,7 +2587,7 @@ fn expand_sde_drift( covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let state_consts = generate_index_consts(states); let x = generated_ident("__pharmsol_x"); @@ -2360,7 +2677,7 @@ fn expand_sde_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -2377,7 +2694,11 @@ fn expand_sde_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -2444,10 +2765,11 @@ fn expand_sde_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -2462,7 +2784,19 @@ fn expand_sde_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -3039,11 +3373,11 @@ mod tests { let bindings = ode_route_channel_bindings(&input.routes); assert_eq!(dense_index_len(&bindings), 2); - assert_eq!(bindings[0].0.to_string(), "oral"); + assert_eq!(bindings[0].0.name(), "oral"); assert_eq!(bindings[0].1, 0); - assert_eq!(bindings[1].0.to_string(), "iv"); + assert_eq!(bindings[1].0.name(), "iv"); assert_eq!(bindings[1].1, 0); - assert_eq!(bindings[2].0.to_string(), "sc"); + assert_eq!(bindings[2].0.name(), "sc"); assert_eq!(bindings[2].1, 1); } diff --git a/src/data/builder.rs b/src/data/builder.rs index 18aa17fe..a1718dc7 100644 --- a/src/data/builder.rs +++ b/src/data/builder.rs @@ -67,7 +67,7 @@ impl SubjectBuilder { /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered /// * `input` - The compartment number receiving the dose - pub fn bolus(self, time: f64, amount: f64, input: usize) -> Self { + pub fn bolus(self, time: f64, amount: f64, input: impl ToString) -> Self { let bolus = Bolus::new(time, amount, input, self.current_occasion.index()); let event = Event::Bolus(bolus); self.event(event) @@ -81,7 +81,7 @@ impl SubjectBuilder { /// * `amount` - Total amount of drug to be administered /// * `input` - The compartment number receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn infusion(self, time: f64, amount: f64, input: usize, duration: f64) -> Self { + pub fn infusion(self, time: f64, amount: f64, input: impl ToString, duration: f64) -> Self { let infusion = Infusion::new(time, amount, input, duration, self.current_occasion.index()); let event = Event::Infusion(infusion); self.event(event) @@ -94,7 +94,7 @@ impl SubjectBuilder { /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) /// * `outeq` - Output equation number corresponding to this observation - pub fn observation(self, time: f64, value: f64, outeq: usize) -> Self { + pub fn observation(self, time: f64, value: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, Some(value), @@ -118,7 +118,7 @@ impl SubjectBuilder { self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, censoring: Censor, ) -> Self { let observation = Observation::new( @@ -139,7 +139,7 @@ impl SubjectBuilder { /// /// * `time` - Time of the observation /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation - pub fn missing_observation(self, time: f64, outeq: usize) -> Self { + pub fn missing_observation(self, time: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, None, @@ -165,7 +165,7 @@ impl SubjectBuilder { self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) -> Self { diff --git a/src/data/event.rs b/src/data/event.rs index ff88e097..46995ef5 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -93,6 +93,78 @@ pub enum Event { /// An observation of drug concentration or other measure Observation(Observation), } + +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct ChannelId(String); + +impl ChannelId { + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } + + pub fn as_str(&self) -> &str { + &self.0 + } + + pub fn index(&self) -> Option { + self.0.parse::().ok() + } +} + +impl From for ChannelId { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for ChannelId { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl From for ChannelId { + fn from(value: usize) -> Self { + Self(value.to_string()) + } +} + +impl AsRef for ChannelId { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl fmt::Display for ChannelId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl PartialEq for ChannelId { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } +} + +impl PartialEq for usize { + fn eq(&self, other: &ChannelId) -> bool { + other == self + } +} + +impl PartialEq for &ChannelId { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } +} + +impl PartialEq<&ChannelId> for usize { + fn eq(&self, other: &&ChannelId) -> bool { + other.eq(self) + } +} + impl Event { /// Get the time of the event pub fn time(&self) -> f64 { @@ -152,7 +224,7 @@ impl Event { pub struct Bolus { time: f64, amount: f64, - input: usize, + input: ChannelId, occasion: usize, } impl Bolus { @@ -163,11 +235,11 @@ impl Bolus { /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered /// * `input` - The compartment number receiving the dose - pub fn new(time: f64, amount: f64, input: usize, occasion: usize) -> Self { + pub fn new(time: f64, amount: f64, input: impl ToString, occasion: usize) -> Self { Bolus { time, amount, - input, + input: ChannelId::new(input), occasion, } } @@ -178,8 +250,12 @@ impl Bolus { } /// Get the compartment number that receives the bolus - pub fn input(&self) -> usize { - self.input + pub fn input(&self) -> &ChannelId { + &self.input + } + + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the time of the bolus administration @@ -193,8 +269,8 @@ impl Bolus { } /// Set the compartment number that receives the bolus - pub fn set_input(&mut self, input: usize) { - self.input = input; + pub fn set_input(&mut self, input: impl ToString) { + self.input = ChannelId::new(input); } /// Set the time of the bolus administration @@ -208,7 +284,7 @@ impl Bolus { } /// Get a mutable reference to the compartment number (1-indexed) that receives the bolus - pub fn mut_input(&mut self) -> &mut usize { + pub fn mut_input(&mut self) -> &mut ChannelId { &mut self.input } @@ -235,7 +311,7 @@ impl Bolus { pub struct Infusion { time: f64, amount: f64, - input: usize, + input: ChannelId, duration: f64, occasion: usize, } @@ -248,11 +324,17 @@ impl Infusion { /// * `amount` - Total amount of drug to be administered /// * `input` - The compartment number receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn new(time: f64, amount: f64, input: usize, duration: f64, occasion: usize) -> Self { + pub fn new( + time: f64, + amount: f64, + input: impl ToString, + duration: f64, + occasion: usize, + ) -> Self { Infusion { time, amount, - input, + input: ChannelId::new(input), duration, occasion, } @@ -264,8 +346,12 @@ impl Infusion { } /// Get the compartment number that receives the infusion - pub fn input(&self) -> usize { - self.input + pub fn input(&self) -> &ChannelId { + &self.input + } + + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the duration of the infusion @@ -286,8 +372,8 @@ impl Infusion { } /// Set the compartment number that receives the infusion - pub fn set_input(&mut self, input: usize) { - self.input = input; + pub fn set_input(&mut self, input: impl ToString) { + self.input = ChannelId::new(input); } /// Set the time of the infusion administration @@ -306,7 +392,7 @@ impl Infusion { } /// Get a mutable reference to the compartment number (1-indexed) that receives the infusion - pub fn mut_input(&mut self) -> &mut usize { + pub fn mut_input(&mut self) -> &mut ChannelId { &mut self.input } @@ -348,7 +434,7 @@ pub enum Censor { pub struct Observation { time: f64, value: Option, - outeq: usize, + outeq: ChannelId, errorpoly: Option, occasion: usize, censoring: Censor, @@ -367,7 +453,7 @@ impl Observation { pub(crate) fn new( time: f64, value: Option, - outeq: usize, + outeq: impl ToString, errorpoly: Option, occasion: usize, censoring: Censor, @@ -375,7 +461,7 @@ impl Observation { Observation { time, value, - outeq, + outeq: ChannelId::new(outeq), errorpoly, occasion, censoring, @@ -393,8 +479,12 @@ impl Observation { } /// Get the output equation number corresponding to this observation - pub fn outeq(&self) -> usize { - self.outeq + pub fn outeq(&self) -> &ChannelId { + &self.outeq + } + + pub fn outeq_index(&self) -> Option { + self.outeq.index() } /// Get the error polynomial coefficients (c0, c1, c2, c3) if available @@ -415,8 +505,8 @@ impl Observation { } /// Set the output equation number corresponding to this observation - pub fn set_outeq(&mut self, outeq: usize) { - self.outeq = outeq; + pub fn set_outeq(&mut self, outeq: impl ToString) { + self.outeq = ChannelId::new(outeq); } /// Set the [ErrorPoly] for this observation @@ -435,7 +525,7 @@ impl Observation { } /// Get a mutable reference to the output equation number - pub fn mut_outeq(&mut self) -> &mut usize { + pub fn mut_outeq(&mut self) -> &mut ChannelId { &mut self.outeq } @@ -460,7 +550,9 @@ impl Observation { time: self.time(), observation: self.value(), prediction: pred, - outeq: self.outeq(), + outeq: self + .outeq_index() + .expect("prediction requires a resolved or numeric output label"), errorpoly: self.errorpoly(), state, occasion: self.occasion(), @@ -539,6 +631,7 @@ mod tests { assert_eq!(bolus.time(), 2.5); assert_eq!(bolus.amount(), 100.0); assert_eq!(bolus.input(), 1); + assert_eq!(bolus.input().as_str(), "1"); } #[test] @@ -561,6 +654,7 @@ mod tests { assert_eq!(infusion.time(), 1.0); assert_eq!(infusion.amount(), 200.0); assert_eq!(infusion.input(), 1); + assert_eq!(infusion.input().as_str(), "1"); assert_eq!(infusion.duration(), 2.5); } @@ -589,6 +683,7 @@ mod tests { assert_eq!(observation.time(), 5.0); assert_eq!(observation.value(), Some(75.5)); assert_eq!(observation.outeq(), 2); + assert_eq!(observation.outeq().as_str(), "2"); assert_eq!(observation.errorpoly(), error_poly); } diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index c410d689..4554e435 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -95,14 +95,14 @@ struct Row { #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, /// Input compartment - #[serde(deserialize_with = "deserialize_option_usize")] - input: Option, + #[serde(deserialize_with = "deserialize_option_channel_id")] + input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, /// Corresponding output equation for the observation - #[serde(deserialize_with = "deserialize_option_usize")] - outeq: Option, + #[serde(deserialize_with = "deserialize_option_channel_id")] + outeq: Option, /// Censoring output #[serde(default, deserialize_with = "deserialize_option_censor")] cens: Option, @@ -134,12 +134,12 @@ impl Row { dur: self.dur, addl: self.addl.map(|a| a as i64), ii: self.ii, - input: self.input, + input: self.input.clone(), // Treat -99 as missing value (Pmetrics convention) out: self .out .and_then(|v| if v == -99.0 { None } else { Some(v) }), - outeq: self.outeq, + outeq: self.outeq.clone(), cens: self.cens, c0: self.c0, c1: self.c1, @@ -196,11 +196,11 @@ where } } -fn deserialize_option_usize<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_option_channel_id<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserialize_option::(deserializer) + deserialize_option::(deserializer).map(|value| value.map(ChannelId::from)) } fn deserialize_option_isize<'de, D>(deserializer: D) -> Result, D::Error> @@ -496,4 +496,50 @@ mod tests { assert_eq!(second.get(11), Some(".")); assert_eq!(second.get(14), Some(".")); } + + #[test] + fn read_pmetrics_preserves_named_channel_labels() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,cp,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Infusion(infusion) => assert_eq!(infusion.input().as_str(), "iv"), + _ => panic!("expected infusion event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "cp"), + _ => panic!("expected observation event"), + } + } + + #[test] + fn read_pmetrics_preserves_numeric_labels_as_strings() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,.,100,.,.,1,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,1,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Bolus(bolus) => assert_eq!(bolus.input().as_str(), "1"), + _ => panic!("expected bolus event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "1"), + _ => panic!("expected observation event"), + } + } } diff --git a/src/data/row.rs b/src/data/row.rs index b3b38ad8..f6e44e98 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -79,11 +79,11 @@ pub struct DataRow { /// Interdose interval for ADDL pub ii: Option, /// Input compartment - pub input: Option, + pub input: Option, /// Observed value (for EVID=0) pub out: Option, /// Output equation number - pub outeq: Option, + pub outeq: Option, /// Censoring indicator pub cens: Option, /// Error polynomial coefficients @@ -180,14 +180,17 @@ impl DataRow { match self.evid { 0 => { // Observation event - events.push(Event::Observation(Observation::new( - self.time, - self.out, + let outeq = self.outeq + .clone() .ok_or_else(|| DataError::MissingObservationOuteq { id: self.id.clone(), time: self.time, - })?, // Keep 1-indexed as provided by Pmetrics + })?; + events.push(Event::Observation(Observation::new( + self.time, + self.out, + outeq, self.get_errorpoly(), 0, // occasion set later self.cens.unwrap_or(Censor::None), @@ -196,10 +199,13 @@ impl DataRow { 1 | 4 => { // Dosing event (1) or reset with dose (4) - let input = self.input.ok_or_else(|| DataError::MissingBolusInput { - id: self.id.clone(), - time: self.time, - })?; // Keep 1-indexed as provided by Pmetrics + let input = self + .input + .clone() + .ok_or_else(|| DataError::MissingBolusInput { + id: self.id.clone(), + time: self.time, + })?; let event = if self.dur.unwrap_or(0.0) > 0.0 { // Infusion @@ -371,8 +377,8 @@ impl DataRowBuilder { /// /// Required for EVID=1 (dosing events). /// Kept as 1-indexed; user must size state arrays accordingly. - pub fn input(mut self, input: usize) -> Self { - self.row.input = Some(input); + pub fn input(mut self, input: impl ToString) -> Self { + self.row.input = Some(ChannelId::new(input)); self } @@ -388,8 +394,8 @@ impl DataRowBuilder { /// /// Required for EVID=0 (observation events). /// Will be converted to 0-indexed internally. - pub fn outeq(mut self, outeq: usize) -> Self { - self.row.outeq = Some(outeq); + pub fn outeq(mut self, outeq: impl ToString) -> Self { + self.row.outeq = Some(ChannelId::new(outeq)); self } diff --git a/src/data/structs.rs b/src/data/structs.rs index 82cd3faf..c977d89a 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -180,13 +180,13 @@ impl Data { let old_events = occasion.process_events(None, true); // Create a set of existing (time, outeq) pairs for fast lookup - let existing_obs: std::collections::HashSet<(u64, usize)> = old_events + let existing_obs: std::collections::HashSet<(u64, ChannelId)> = old_events .iter() .filter_map(|event| match event { Event::Observation(obs) => { // Convert to microseconds for consistent comparison let time_key = (obs.time() * 1e6).round() as u64; - Some((time_key, obs.outeq())) + Some((time_key, obs.outeq().clone())) } _ => None, }) @@ -198,13 +198,13 @@ impl Data { while time < last_time { let time_key = (time * 1e6).round() as u64; - for &outeq in &outeq_values { + for outeq in &outeq_values { // Only add if this (time, outeq) combination doesn't exist - if !existing_obs.contains(&(time_key, outeq)) { + if !existing_obs.contains(&(time_key, outeq.clone())) { let obs = Observation::new( time, None, - outeq, + outeq.clone(), None, occasion.index, Censor::None, @@ -274,14 +274,14 @@ impl Data { } /// Get a vector of all unique output equations (outeq) across all subjects - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let mut outeq_values: Vec = self + let mut outeq_values: Vec = self .subjects .iter() .flat_map(|subject| subject.get_output_equations()) .collect(); - outeq_values.sort_unstable(); + outeq_values.sort(); outeq_values.dedup(); outeq_values } @@ -396,14 +396,14 @@ impl Subject { self.occasions.iter_mut() } - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let outeq_values: Vec = self + let outeq_values: Vec = self .occasions .iter() .flat_map(|occasion| { occasion.events.iter().filter_map(|event| match event { - Event::Observation(obs) => Some(obs.outeq()), + Event::Observation(obs) => Some(obs.outeq().clone()), _ => None, }) }) @@ -598,8 +598,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let lagtime = fn_lag(&spp.clone().into(), time, covariates); - if let Some(l) = lagtime.get(&bolus.input()) { - *bolus.mut_time() += l; + if let Some(input) = bolus.input_index() { + if let Some(l) = lagtime.get(&input) { + *bolus.mut_time() += l; + } } } } @@ -615,8 +617,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let fa = fn_fa(&spp.clone().into(), time, covariates); - if let Some(f) = fa.get(&bolus.input()) { - bolus.set_amount(bolus.amount() * f); + if let Some(input) = bolus.input_index() { + if let Some(f) = fa.get(&input) { + bolus.set_amount(bolus.amount() * f); + } } } } @@ -703,7 +707,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: Option, censored: Censor, ) { @@ -713,7 +717,7 @@ impl Occasion { } /// Add a missing [Observation] event to the [Occasion] - pub fn add_missing_observation(&mut self, time: f64, outeq: usize) { + pub fn add_missing_observation(&mut self, time: f64, outeq: impl ToString) { let observation = Observation::new(time, None, outeq, None, self.index, Censor::None); self.add_event(Event::Observation(observation)); } @@ -725,7 +729,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) { @@ -741,13 +745,13 @@ impl Occasion { } /// Add a [Bolus] event to the [Occasion] - pub fn add_bolus(&mut self, time: f64, amount: f64, input: usize) { + pub fn add_bolus(&mut self, time: f64, amount: f64, input: impl ToString) { let bolus = Bolus::new(time, amount, input, self.index); self.add_event(Event::Bolus(bolus)); } /// Add an [Infusion] event to the [Occasion] - pub fn add_infusion(&mut self, time: f64, amount: f64, input: usize, duration: f64) { + pub fn add_infusion(&mut self, time: f64, amount: f64, input: impl ToString, duration: f64) { let infusion = Infusion::new(time, amount, input, duration, self.index); self.add_event(Event::Infusion(infusion)); } @@ -775,17 +779,6 @@ impl Occasion { .unwrap_or(0.0) } - pub(crate) fn infusions_ref(&self) -> Vec<&Infusion> { - //TODO this can be pre-computed when the struct is initially created - self.events - .iter() - .filter_map(|event| match event { - Event::Infusion(infusion) => Some(infusion), - _ => None, - }) - .collect() - } - /// Get an iterator over all events /// /// # Returns @@ -967,7 +960,7 @@ impl Occasion { for event in &self.events { if let Event::Observation(obs) = event { - if obs.outeq() == outeq { + if obs.outeq_index() == Some(outeq) { if let Some(value) = obs.value() { times.push(obs.time()); concs.push(value); diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 202fd45a..d2600e67 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -20,7 +20,7 @@ pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ - data::{Covariates, Infusion}, + data::{ChannelId, Covariates, Infusion}, simulator::{ equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, @@ -29,7 +29,7 @@ use crate::{ likelihood::{Prediction, SubjectPredictions}, M, V, }, - Event, Observation, PharmsolError, Subject, + Event, Observation, Occasion, PharmsolError, Subject, }; pub type DenseKernelFn = unsafe extern "C" fn( @@ -375,6 +375,16 @@ impl SharedNativeModel { Ok(()) } + fn validate_output(&self, outeq: usize) -> Result<(), PharmsolError> { + if outeq >= self.info.output_len { + return Err(PharmsolError::OuteqOutOfRange { + outeq, + nout: self.info.output_len, + }); + } + Ok(()) + } + fn validate_input_for_kind(&self, input: usize, kind: RouteKind) -> Result<(), PharmsolError> { self.validate_input(input)?; if self.route_semantics.supports_input(input, kind) { @@ -387,6 +397,62 @@ impl SharedNativeModel { ))) } + fn resolve_input_label( + &self, + label: &ChannelId, + kind: RouteKind, + ) -> Result { + if let Some(input) = self.route_index(label.as_str()) { + self.validate_input_for_kind(input, kind)?; + return Ok(input); + } + + let input = label + .index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + self.validate_input_for_kind(input, kind)?; + Ok(input) + } + + fn resolve_output_label(&self, label: &ChannelId) -> Result { + if let Some(outeq) = self.output_index(label.as_str()) { + return Ok(outeq); + } + + let outeq = label + .index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + })?; + self.validate_output(outeq)?; + Ok(outeq) + } + + fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { + let mut events = occasion.process_events(None, true); + + for event in events.iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(events) + } + fn fill_cov_buffer(&self, covariates: &Covariates, time: f64, buf: &mut [f64]) { for covariate in &self.info.covariates { buf[covariate.index] = match covariates.get_covariate(&covariate.name) { @@ -530,7 +596,13 @@ impl SharedNativeModel { for event in events.iter_mut() { if let Event::Bolus(bolus) = event { - self.validate_input_for_kind(bolus.input(), RouteKind::Bolus)?; + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + self.validate_input_for_kind(input, RouteKind::Bolus)?; if self.artifact.has_kernel(KernelRole::RouteLag) { lag_values.fill(0.0); @@ -556,7 +628,7 @@ impl SharedNativeModel { lag_values.as_mut_ptr(), )?; } - let lag = lag_values[bolus.input()]; + let lag = lag_values[input]; if lag != 0.0 { *bolus.mut_time() += lag; } @@ -586,7 +658,7 @@ impl SharedNativeModel { fa_values.as_mut_ptr(), )?; } - let factor = fa_values[bolus.input()]; + let factor = fa_values[input]; if factor != 1.0 { bolus.set_amount(bolus.amount() * factor); } @@ -651,13 +723,13 @@ impl SharedNativeModel { &cov_buf, &mut outputs, )?; - if observation.outeq() >= outputs.len() { - return Err(PharmsolError::OuteqOutOfRange { - outeq: observation.outeq(), - nout: outputs.len(), - }); - } - Ok(observation.to_prediction(outputs[observation.outeq()], state.to_vec())) + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + self.validate_output(outeq)?; + Ok(observation.to_prediction(outputs[outeq], state.to_vec())) } } @@ -734,18 +806,15 @@ impl NativeOdeModel { let support_vector: V = DVector::from_vec(support_point.to_vec()).into(); for occasion in subject.occasions() { - let infusion_refs = occasion.infusions_ref(); - let infusions = infusion_refs + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); + let infusion_refs = infusions.iter().collect::>(); let session = RefCell::new(self.shared.artifact.start_session()?); let mut route_session = session.borrow_mut(); self.shared.apply_route_properties( @@ -901,9 +970,15 @@ impl NativeOdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; self.shared.apply_bolus( solver.state_mut().y.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -1000,18 +1075,14 @@ impl NativeAnalyticalModel { let mut output = SubjectPredictions::default(); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -1030,8 +1101,12 @@ impl NativeAnalyticalModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { - self.shared - .apply_bolus(&mut state, bolus.input(), bolus.amount())? + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; + self.shared.apply_bolus(&mut state, input, bolus.amount())? } Event::Infusion(_) => {} Event::Observation(observation) => { @@ -1171,18 +1246,14 @@ impl NativeSdeModel { let mut output = Array2::from_shape_fn((self.nparticles, 0), |_| Prediction::default()); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -1204,10 +1275,15 @@ impl NativeSdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; for particle in &mut particles { self.shared.apply_bolus( particle.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -1398,11 +1474,14 @@ impl NativeSdeModel { fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> Vec { let mut values = vec![0.0; route_len]; for infusion in infusions { - if infusion.input() < route_len + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && time >= infusion.time() && time <= infusion.time() + infusion.duration() { - values[infusion.input()] += infusion.amount() / infusion.duration(); + values[input] += infusion.amount() / infusion.duration(); } } values @@ -1417,8 +1496,11 @@ fn interval_route_inputs( let mut values = vec![0.0; route_len]; for infusion in infusions { let finish = infusion.time() + infusion.duration(); - if infusion.input() < route_len && start_time >= infusion.time() && end_time <= finish { - values[infusion.input()] += infusion.amount() / infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && start_time >= infusion.time() && end_time <= finish { + values[input] += infusion.amount() / infusion.duration(); } } values diff --git a/src/error/mod.rs b/src/error/mod.rs index 1316b8a4..5145626e 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -37,6 +37,10 @@ pub enum PharmsolError { ZeroLikelihood, #[error("Missing observation in prediction")] MissingObservation, + #[error("Input label `{label}` could not be resolved to a route channel")] + UnknownInputLabel { label: String }, + #[error("Output label `{label}` could not be resolved to an output channel")] + UnknownOutputLabel { label: String }, #[error("Input channel {input} is out of range (ndrugs = {ndrugs})")] InputOutOfRange { input: usize, ndrugs: usize }, #[error("Output equation {outeq} is out of range (nout = {nout})")] diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 4734886c..b0d78481 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -278,6 +278,11 @@ impl EquationPriv for Analytical { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -321,13 +326,19 @@ impl EquationPriv for Analytical { let s = inf.time(); let e = s + inf.duration(); if current_t >= s && next_t <= e { - if inf.input() >= self.get_ndrugs() { + let input = + inf.input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: inf.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: inf.input(), + input, ndrugs: self.get_ndrugs(), }); } - rateiv[inf.input()] += inf.amount() / inf.duration(); + rateiv[input] += inf.amount() / inf.duration(); } } @@ -365,7 +376,12 @@ impl EquationPriv for Analytical { covariates, &mut y, ); - let pred = y[observation.outeq()]; + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + let pred = y[outeq]; let pred = observation.to_prediction(pred, x.as_slice().to_vec()); if let Some(error_models) = error_models { likelihood.push(pred.log_likelihood(error_models)?.exp()); diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 60cb2d8f..f3532382 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -12,7 +12,7 @@ pub use sde::*; use crate::{ error_model::AssayErrorModels, simulator::{Fa, Lag}, - Covariates, Event, Infusion, Observation, PharmsolError, Subject, + ChannelId, Covariates, Event, Infusion, Observation, Occasion, PharmsolError, Subject, }; use super::likelihood::Prediction; @@ -129,6 +129,7 @@ pub(crate) trait EquationPriv: EquationTypes { fn get_nstates(&self) -> usize; fn get_ndrugs(&self) -> usize; fn get_nouteqs(&self) -> usize; + fn metadata(&self) -> Option<&ValidatedModelMetadata>; fn solve( &self, state: &mut Self::S, @@ -141,6 +142,85 @@ pub(crate) trait EquationPriv: EquationTypes { fn nparticles(&self) -> usize { 1 } + + fn resolve_input_label( + &self, + label: &ChannelId, + expected_kind: RouteKind, + ) -> Result { + if let Some(metadata) = self.metadata() { + let route = + metadata + .route(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + + if route.kind() != expected_kind { + return Err(PharmsolError::OtherError(format!( + "input label `{}` is declared as {:?} but used as {:?}", + label, + route.kind(), + expected_kind + ))); + } + + return Ok(route.channel_index()); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + }) + } + + fn resolve_output_label(&self, label: &ChannelId) -> Result { + if let Some(metadata) = self.metadata() { + return metadata.output_index(label.as_str()).ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: label.to_string(), + } + }); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }) + } + + fn resolve_occasion_events( + &self, + occasion: &Occasion, + support_point: &[f64], + covariates: &Covariates, + ) -> Result, PharmsolError> { + let mut resolved = occasion.clone(); + + for event in resolved.events_iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(resolved.process_events( + Some((self.fa(), self.lag(), support_point, covariates)), + true, + )) + } #[allow(dead_code)] fn is_sde(&self) -> bool { false @@ -181,13 +261,20 @@ pub(crate) trait EquationPriv: EquationTypes { ) -> Result<(), PharmsolError> { match event { Event::Bolus(bolus) => { - if bolus.input() >= self.get_ndrugs() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: self.get_ndrugs(), }); } - x.add_bolus(bolus.input(), bolus.amount()); + x.add_bolus(input, bolus.amount()); } Event::Infusion(infusion) => { infusions.push(infusion.clone()); @@ -332,10 +419,7 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { let mut x = self.initial_state(support_point, covariates, occasion.index()); let mut infusions = Vec::new(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; for (index, event) in events.iter().enumerate() { self.simulate_event( support_point, diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index eed65e7a..cb9c0726 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -80,7 +80,11 @@ impl InfusionSchedule { continue; } - let input = infusion.input(); + let input = infusion + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: infusion.input().to_string(), + })?; if input >= ndrugs { return Err(PharmsolError::InputOutOfRange { input, ndrugs }); } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index cafe6a96..853b3108 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -330,6 +330,11 @@ impl EquationPriv for ODE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -397,14 +402,21 @@ impl ODE { match event { Event::Bolus(bolus) => { - if bolus.input() >= bolus_v.len() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= bolus_v.len() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: bolus_v.len(), }); } bolus_v.fill(0.0); - bolus_v[bolus.input()] = bolus.amount(); + bolus_v[input] = bolus.amount(); state_with_bolus.fill(0.0); state_without_bolus.fill(0.0); @@ -444,7 +456,12 @@ impl ODE { covariates, y_out, ); - let pred = y_out[observation.outeq()]; + let outeq = observation.outeq_index().ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + } + })?; + let pred = y_out[outeq]; let pred = observation.to_prediction(pred, solver.state().y.as_slice().to_vec()); if let Some(error_models) = error_models { @@ -550,11 +567,14 @@ impl Equation for ODE { // Iterate over occasions for occasion in subject.occasions() { let covariates = occasion.covariates(); - let infusions = occasion.infusions_ref(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; + let infusions = events + .iter() + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion), + _ => None, + }) + .collect::>(); let problem = OdeBuilder::::new() .atol(vec![self.atol]) @@ -680,9 +700,9 @@ mod tests { fn route_policy_subject() -> Subject { Subject::builder("route_policy") - .bolus(0.0, 100.0, 0) - .infusion(0.0, 100.0, 0, 1.0) - .observation(1.0, 0.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(0.0, 100.0, "iv", 1.0) + .observation(1.0, 0.0, "cp") .build() } diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index bdafbbc3..c24b615d 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -124,7 +124,10 @@ fn simulate_sde_event( let mut rateiv = V::zeros(ndrugs, NalgebraContext); for infusion in &infusion_events { if time >= infusion.time() && time <= infusion.duration() + infusion.time() { - rateiv[infusion.input()] += infusion.amount() / infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + rateiv[input] += infusion.amount() / infusion.duration(); } } @@ -466,6 +469,11 @@ impl EquationPriv for SDE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -524,7 +532,10 @@ impl EquationPriv for SDE { covariates, &mut y, ); - *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec()); + let outeq = observation + .outeq_index() + .expect("resolved observations should use numeric output labels"); + *p = observation.to_prediction(y[outeq], x[i].as_slice().to_vec()); }); let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?; *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap(); @@ -588,17 +599,21 @@ impl EquationPriv for SDE { ) -> Result<(), PharmsolError> { match event { crate::Event::Bolus(bolus) => { - if bolus.input() >= self.get_ndrugs() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: self.get_ndrugs(), }); } - if !self - .injected_bolus_mappings - .apply(x, bolus.input(), bolus.amount()) - { - x.add_bolus(bolus.input(), bolus.amount()); + if !self.injected_bolus_mappings.apply(x, input, bolus.amount()) { + x.add_bolus(input, bolus.amount()); } } crate::Event::Infusion(infusion) => { @@ -909,8 +924,8 @@ mod tests { .expect("injected metadata should validate"); let subject = Subject::builder("bolus_route") - .bolus(0.0, 100.0, 0) - .missing_observation(0.1, 0) + .bolus(0.0, 100.0, "oral") + .missing_observation(0.1, "cp") .build(); let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); @@ -954,8 +969,8 @@ mod tests { .expect("injected metadata should validate"); let subject = Subject::builder("infusion_route") - .infusion(0.0, 100.0, 0, 1.0) - .missing_observation(1.0, 0) + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(1.0, "cp") .build(); let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index e025ec4f..c55719f5 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -1,47 +1,47 @@ use approx::assert_relative_eq; use pharmsol::prelude::*; -fn infusion_subject(input: usize) -> Subject { +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("analytical-macro-iv") .infusion(0.0, 120.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn oral_subject(input: usize) -> Subject { +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("analytical-macro-oral") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn shared_channel_subject(input: usize) -> Subject { +fn shared_channel_subject() -> Subject { Subject::builder("analytical-macro-shared") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { Subject::builder("analytical-macro-covariates") .bolus(1.0, 100.0, oral) .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) .missing_observation(8.0, cp) .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) @@ -382,7 +382,7 @@ fn assert_prediction_match(left: &[f64], right: &[f64]) { fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_one_compartment(); let handwritten_model = handwritten_one_compartment(); - let subject = infusion_subject(0); + let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -408,7 +408,7 @@ fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { let macro_model = macro_one_compartment_with_absorption(); let handwritten_model = handwritten_one_compartment_with_absorption(); - let subject = oral_subject(0); + let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -441,7 +441,7 @@ fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_shared_channel_analytical(); let handwritten_model = handwritten_shared_channel_analytical(); - let subject = shared_channel_subject(0); + let subject = shared_channel_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -472,14 +472,9 @@ fn analytical_macro_covariates_lower_to_handwritten_behavior() { assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - let oral = macro_model.route_index("oral").expect("oral route exists"); - let iv = macro_model.route_index("iv").expect("iv route exists"); - let cp = macro_model.output_index("cp").expect("cp output exists"); - let subject = covariate_subject(oral, iv, cp); + let subject = covariate_subject("oral", "iv", "cp"); let support_point = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; - assert_eq!(oral, iv); - let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) .expect("macro covariate analytical model should simulate") diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index 43621e8a..67f91c7a 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -1,3 +1,4 @@ +use approx::assert_relative_eq; #[cfg(feature = "dsl-jit")] use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; #[cfg(feature = "dsl-jit")] @@ -88,6 +89,24 @@ dx(central) = ka * depot - ke * central out(cp) = central / v ~ continuous() "#; +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL: &str = r#" +name = mixed_output_labels_runtime +kind = ode + +params = ke, v +states = central +outputs = cp, 0, 1 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +out(0) = 2 * central / v ~ continuous() +out(1) = 3 * central / v ~ continuous() +"#; + const ANALYTICAL_DSL: &str = r#" name = one_cmt_abs_parity kind = analytical @@ -267,16 +286,16 @@ fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntim } #[cfg(feature = "dsl-jit")] -fn shared_channel_prediction_subject(input: usize, output: usize) -> Subject { +fn shared_channel_prediction_subject() -> Subject { Subject::builder("authoring-parity-shared-channel") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, output) - .missing_observation(1.0, output) - .missing_observation(2.0, output) - .missing_observation(6.5, output) - .missing_observation(7.0, output) - .missing_observation(8.0, output) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } @@ -1192,11 +1211,12 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1241,11 +1261,12 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1292,11 +1313,12 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1340,11 +1362,12 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(mismatched_model.route_index("oral"), Some(oral)); assert_eq!(mismatched_model.route_index("iv"), Some(iv)); @@ -1363,3 +1386,35 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { assert_prediction_vectors_diverge(&runtime_predictions, &mismatched_predictions, 1e-4); } + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_preserves_mixed_output_labels() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL, + "mixed_output_labels_runtime", + ); + let subject = Subject::builder("runtime-mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(0.5, "0") + .missing_observation(0.5, "1") + .build(); + let support_point = [0.2, 10.0]; + + assert_eq!(runtime_model.output_index("cp"), Some(0)); + assert_eq!(runtime_model.output_index("0"), Some(1)); + assert_eq!(runtime_model.output_index("1"), Some(2)); + + let predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime mixed-output model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + + assert_eq!(predictions.len(), 3); + assert_relative_eq!(predictions[1], 2.0 * predictions[0], epsilon = 1e-6); + assert_relative_eq!(predictions[2], 3.0 * predictions[0], epsilon = 1e-6); +} diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 7b068733..a556f428 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -1,38 +1,55 @@ use approx::assert_relative_eq; +use pharmsol::prelude::data::read_pmetrics; use pharmsol::prelude::*; +use tempfile::NamedTempFile; -fn subject_for_route(input: usize) -> Subject { +fn write_pmetrics_fixture(contents: &str) -> NamedTempFile { + let file = NamedTempFile::new().expect("create temporary Pmetrics fixture"); + std::fs::write(file.path(), contents).expect("write temporary Pmetrics fixture"); + file +} + +fn subject_for_route(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("macro-lowering") .infusion(0.0, 100.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn subject_for_shared_channel(input: usize) -> Subject { +fn subject_for_shared_channel() -> Subject { Subject::builder("macro-shared-channel") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn subject_for_covariates(input: usize) -> Subject { +fn subject_for_covariates(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("macro-covariates") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .covariate("wt", 0.0, 70.0) .build() } +fn subject_for_numeric_bolus_route(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("numeric-bolus-route") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + fn injected_macro_ode() -> equation::ODE { ode! { name: "injected_one_cpt", @@ -131,6 +148,55 @@ fn explicit_handwritten_ode() -> equation::ODE { .expect("handwritten explicit metadata should validate") } +fn numeric_label_macro_ode() -> equation::ODE { + ode! { + name: "numeric_label_one_cpt", + params: [ke, v], + states: [central], + outputs: [1], + routes: { + infusion(1) -> central, + }, + diffeq: |x, _t, dx, _bolus, rateiv| { + dx[central] = rateiv[1] - ke * x[central]; + }, + out: |x, _t, y| { + y[1] = x[central] / v; + }, + } +} + +fn numeric_label_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_label_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["1"]) + .route( + equation::Route::infusion("1") + .to_state("central") + .expect_explicit_input(), + ), + ) + .expect("handwritten numeric-label metadata should validate") +} + fn shared_channel_macro_ode() -> equation::ODE { ode! { name: "shared_channel_one_cpt", @@ -200,6 +266,124 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .expect("handwritten shared-channel metadata should validate") } +fn numeric_route_property_macro_ode() -> equation::ODE { + ode! { + name: "numeric_route_property_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [1], + routes: { + bolus(1) -> depot, + }, + diffeq: |x, _t, dx, bolus, _rateiv| { + dx[depot] = bolus[1] - ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + lag: |_t| { + lag! { 1 => tlag } + }, + fa: |_t| { + fa! { 1 => f_oral } + }, + out: |x, _t, y| { + y[1] = x[central] / v; + }, + } +} + +fn numeric_route_property_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_route_property_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["1"]) + .route( + equation::Route::bolus("1") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + ), + ) + .expect("handwritten numeric route-property metadata should validate") +} + +fn mixed_output_labels_macro_ode() -> equation::ODE { + ode! { + name: "mixed_output_labels_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp, 0, 1], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _t, dx, _bolus, rateiv| { + dx[central] = rateiv[iv] - ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + y[0] = 2.0 * x[central] / v; + y[1] = 3.0 * x[central] / v; + }, + } +} + +fn mixed_output_labels_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + y[1] = 2.0 * x[0] / v; + y[2] = 3.0 * x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(3) + .with_metadata( + equation::metadata::new("mixed_output_labels_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp", "0", "1"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ), + ) + .expect("handwritten mixed-output metadata should validate") +} + fn covariate_macro_ode() -> equation::ODE { ode! { name: "covariate_one_cpt", @@ -267,7 +451,7 @@ fn assert_prediction_match(left: &[f64], right: &[f64]) { fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = injected_macro_ode(); let handwritten_ode = injected_handwritten_ode(); - let subject = subject_for_route(0); + let subject = subject_for_route("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -293,7 +477,7 @@ fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = explicit_macro_ode(); let handwritten_ode = explicit_handwritten_ode(); - let subject = subject_for_route(0); + let subject = subject_for_route("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -315,11 +499,37 @@ fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { assert_prediction_match(¯o_predictions, &handwritten_predictions); } +#[test] +fn macro_numeric_labels_lower_to_dense_slots() { + let macro_ode = numeric_label_macro_ode(); + let handwritten_ode = numeric_label_handwritten_ode(); + let subject = subject_for_route("1", "1"); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("1"), Some(0)); + assert_eq!(macro_ode.output_index("1"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro numeric-label model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten numeric-label model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + #[test] fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = shared_channel_macro_ode(); let handwritten_ode = shared_channel_handwritten_ode(); - let subject = subject_for_shared_channel(0); + let subject = subject_for_shared_channel(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -343,11 +553,119 @@ fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() assert_prediction_match(¯o_predictions, &handwritten_predictions); } +#[test] +fn macro_mixed_output_labels_lower_to_dense_slots() { + let macro_ode = mixed_output_labels_macro_ode(); + let handwritten_ode = mixed_output_labels_handwritten_ode(); + let subject = Subject::builder("mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "0") + .missing_observation(2.0, "1") + .build(); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.output_index("0"), Some(1)); + assert_eq!(macro_ode.output_index("1"), Some(2)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro mixed-output model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten mixed-output model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_numeric_route_properties_lower_to_dense_slots() { + let macro_ode = numeric_route_property_macro_ode(); + let handwritten_ode = numeric_route_property_handwritten_ode(); + let subject = subject_for_numeric_bolus_route("1", "1"); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("1"), Some(0)); + assert_eq!(macro_ode.output_index("1"), Some(0)); + assert_eq!(macro_ode.state_index("depot"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro numeric route-property model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten numeric route-property model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_named_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,cp,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read named-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = explicit_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro named-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = explicit_macro_ode() + .estimate_predictions(&subject_for_route("iv", "cp"), &support_point) + .expect("macro internal-index model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + +#[test] +fn macro_numeric_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,1,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,1,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read numeric-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = numeric_label_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro numeric-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = numeric_label_macro_ode() + .estimate_predictions(&subject_for_route("1", "1"), &support_point) + .expect("macro internal-index numeric-label model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + #[test] fn macro_covariate_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = covariate_macro_ode(); let handwritten_ode = covariate_handwritten_ode(); - let subject = subject_for_covariates(0); + let subject = subject_for_covariates("oral", "cp"); let support_point = [1.0, 0.2, 10.0]; let macro_metadata = macro_ode .metadata() diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 876d2b23..13d21a2b 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -2,47 +2,47 @@ use approx::assert_relative_eq; use pharmsol::prelude::*; use pharmsol::Predictions; -fn infusion_subject(input: usize) -> Subject { +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("sde-macro-iv") .infusion(0.0, 120.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn oral_subject(input: usize) -> Subject { +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("sde-macro-oral") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn shared_channel_subject(input: usize) -> Subject { +fn shared_channel_subject() -> Subject { Subject::builder("sde-macro-shared") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { Subject::builder("sde-macro-covariates") .bolus(1.0, 100.0, oral) .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) .missing_observation(8.0, cp) .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) @@ -491,7 +491,7 @@ fn handwritten_covariate_sde() -> equation::SDE { fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_infusion_sde(); let handwritten_model = handwritten_infusion_sde(); - let subject = infusion_subject(0); + let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 0.0, 10.0]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -516,7 +516,7 @@ fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { let macro_model = macro_absorption_sde(); let handwritten_model = handwritten_absorption_sde(); - let subject = oral_subject(0); + let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -541,7 +541,7 @@ fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_shared_channel_sde(); let handwritten_model = handwritten_shared_channel_sde(); - let subject = shared_channel_subject(0); + let subject = shared_channel_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -571,14 +571,9 @@ fn sde_macro_covariates_lower_to_handwritten_behavior() { assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - let oral = macro_model.route_index("oral").expect("oral route exists"); - let iv = macro_model.route_index("iv").expect("iv route exists"); - let cp = macro_model.output_index("cp").expect("cp output exists"); - let subject = covariate_subject(oral, iv, cp); + let subject = covariate_subject("oral", "iv", "cp"); let support_point = [1.0, 0.16, 0.0, 32.0, 0.5, 0.8, 3.0, 14.0]; - assert_eq!(oral, iv); - let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) .expect("macro covariate SDE should simulate"); From 8a94a3180eeabbceb55316af7b16b2b9b871abfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 13:57:19 +0100 Subject: [PATCH 08/22] chore: update test --- tests/full_feature_macro_parity.rs | 52 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index 71a1afa7..e3175f84 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -197,19 +197,19 @@ fn handwritten_ode_model() -> equation::ODE { .expect("handwritten ODE metadata should validate") } -fn build_ode_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { +fn build_ode_subject() -> Subject { Subject::builder("macro-vs-handwritten-ode-full-features") - .bolus(0.0, 80.0, load) - .bolus(1.0, 120.0, oral) - .infusion(6.0, 150.0, iv, 2.5) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -368,19 +368,19 @@ fn handwritten_analytical_model() -> equation::Analytical { .expect("handwritten analytical metadata should validate") } -fn build_analytical_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { +fn build_analytical_subject() -> Subject { Subject::builder("macro-vs-handwritten-analytical-full-features") - .bolus(0.0, 60.0, load) - .bolus(1.0, 100.0, oral) - .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -407,7 +407,7 @@ fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::Pharmsol assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); - let subject = build_ode_subject(oral, load, iv, cp); + let subject = build_ode_subject(); let params = [1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0]; let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; @@ -453,7 +453,7 @@ fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::P assert_eq!(handwritten_analytical.route_index("iv"), Some(iv)); assert_eq!(handwritten_analytical.output_index("cp"), Some(cp)); - let subject = build_analytical_subject(oral, load, iv, cp); + let subject = build_analytical_subject(); let params = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; let macro_predictions = macro_analytical.estimate_predictions(&subject, ¶ms)?; From fe3916a80a9a938e5a61ba3a84ea0106c722ac76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 14:19:42 +0100 Subject: [PATCH 09/22] chore: channeless --- examples/compare_solvers.rs | 4 +- examples/macro_vs_handwritten_one_cpt.rs | 2 +- examples/macro_vs_handwritten_two_cpt.rs | 13 +- pharmsol-dsl/src/execution.rs | 2 +- pharmsol-macros/src/lib.rs | 47 +++--- src/data/event.rs | 177 ++++++++++++----------- src/data/parser/pmetrics.rs | 21 ++- src/data/row.rs | 24 +-- src/data/structs.rs | 33 +++-- src/dsl/jit.rs | 2 +- src/dsl/model_info.rs | 2 +- src/dsl/native.rs | 10 +- src/error/mod.rs | 6 +- src/simulator/equation/analytical/mod.rs | 8 +- src/simulator/equation/metadata.rs | 62 ++++---- src/simulator/equation/mod.rs | 9 +- src/simulator/equation/ode/closure.rs | 22 ++- src/simulator/equation/ode/mod.rs | 8 +- src/simulator/equation/sde/mod.rs | 8 +- src/simulator/mod.rs | 4 +- tests/analytical_macro_lowering.rs | 20 +-- tests/authoring_parity_corpus.rs | 94 ++++++------ tests/ode_macro_lowering.rs | 26 ++-- tests/sde_macro_lowering.rs | 20 +-- 24 files changed, 312 insertions(+), 312 deletions(-) diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index ebec4caa..ad705931 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -48,7 +48,7 @@ fn main() { let trbdf2 = two_cpt(OdeSolver::Sdirk(SdirkTableau::TrBdf2)); let esdirk34 = two_cpt(OdeSolver::Sdirk(SdirkTableau::Esdirk34)); - // Both declarations resolve to the same shared input channel, so subject + // Both declarations resolve to the same shared input, so subject // authoring still uses one numeric index for the loading bolus and the // maintenance infusion. let load = bdf.route_index("load").expect("load route exists"); @@ -57,7 +57,7 @@ fn main() { assert_eq!( load, iv, - "mixed IV declarations should share one numeric channel" + "mixed IV declarations should share one numeric input" ); let subject = Subject::builder("id1") diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index ddff59f8..be9edb2a 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -26,7 +26,7 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( - // Handwritten closures stay on dense internal channels. + // Handwritten closures stay on dense internal slots. // Public labels like `iv` and `cp` live in attached metadata, not in // the low-level `rateiv[]` / `y[]` buffers. |x, p, _t, dx, _bolus, rateiv, _cov| { diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 915267d6..114024bd 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -1,5 +1,5 @@ //! Compares a declaration-first macro ODE with the equivalent handwritten ODE -//! on a two-compartment IV problem that shares one numeric input channel across +//! on a two-compartment IV problem that shares one numeric input across //! a loading bolus and a maintenance infusion. //! //! This keeps the macro story as the default surface while showing the @@ -9,7 +9,7 @@ use pharmsol::prelude::*; fn macro_model() -> equation::ODE { ode! { - name: "two_cpt_shared_channel_parity", + name: "two_cpt_shared_input_parity", params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], @@ -29,7 +29,7 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( - // Handwritten closures stay on dense internal channels. + // Handwritten closures stay on dense internal slots. // Public route labels like `load` and `iv` are metadata names; the // low-level `bolus[]`, `rateiv[]`, and `y[]` buffers remain indexed by // dense internal slots. @@ -50,7 +50,7 @@ fn handwritten_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("two_cpt_shared_channel_parity") + equation::metadata::new("two_cpt_shared_input_parity") .parameters(["ke", "kcp", "kpc", "v"]) .states(["central", "peripheral"]) .outputs(["cp"]) @@ -83,10 +83,7 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let iv = macro_ode.route_index("iv").expect("iv route exists"); let cp = macro_ode.output_index("cp").expect("cp output exists"); - assert_eq!( - load, iv, - "load and iv should share one numeric input channel" - ); + assert_eq!(load, iv, "load and iv should share one numeric input"); assert_eq!(handwritten_ode.route_index("load"), Some(load)); assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 886d570a..8bac1d69 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -1516,7 +1516,7 @@ mod tests { } #[test] - fn authoring_routes_share_channel_indices_by_kind_local_ordinal() { + fn authoring_routes_share_input_indices_by_kind_local_ordinal() { let src = r#"name = shared_authoring kind = ode diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 96b9536e..0d143184 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -1105,7 +1105,7 @@ fn route_input_names(routes: &[OdeRouteDecl]) -> Vec { routes.iter().map(|route| route.input.name()).collect() } -fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { +fn ode_route_input_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { let mut next_bolus_index = 0usize; let mut next_infusion_index = 0usize; @@ -2024,14 +2024,14 @@ fn expand_injected_ode_route_terms( let terms = routes .iter() .zip(route_bindings.iter()) - .map(|(route, (_, channel_index))| { + .map(|(route, (_, input_index))| { let destination = route_destination_index(route, states); match route.kind { OdeRouteKind::Bolus => quote! { - #dx[#destination] += #bolus[#channel_index]; + #dx[#destination] += #bolus[#input_index]; }, OdeRouteKind::Infusion => quote! { - #dx[#destination] += #rateiv[#channel_index]; + #dx[#destination] += #rateiv[#input_index]; }, } }); @@ -2048,19 +2048,18 @@ fn expand_injected_sde_rate_terms( dx: &Ident, rateiv: &Ident, ) -> TokenStream2 { - let terms = - routes - .iter() - .zip(route_bindings.iter()) - .filter_map(|(route, (_, channel_index))| match route.kind { - OdeRouteKind::Bolus => None, - OdeRouteKind::Infusion => { - let destination = route_destination_index(route, states); - Some(quote! { - #dx[#destination] += #rateiv[#channel_index]; - }) - } - }); + let terms = routes + .iter() + .zip(route_bindings.iter()) + .filter_map(|(route, (_, input_index))| match route.kind { + OdeRouteKind::Bolus => None, + OdeRouteKind::Infusion => { + let destination = route_destination_index(route, states); + Some(quote! { + #dx[#destination] += #rateiv[#input_index]; + }) + } + }); quote! { #(#terms)* @@ -2074,10 +2073,10 @@ fn expand_injected_sde_bolus_mappings( ) -> TokenStream2 { let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; - for (route, (_, channel_index)) in routes.iter().zip(route_bindings.iter()) { + for (route, (_, input_index)) in routes.iter().zip(route_bindings.iter()) { if let OdeRouteKind::Bolus = route.kind { let destination = route_destination_index(route, states); - destinations[*channel_index] = quote! { Some(#destination) }; + destinations[*input_index] = quote! { Some(#destination) }; } } @@ -2829,7 +2828,7 @@ fn expand_sde_out( pub fn ode(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as OdeInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let lag_routes = match input.lag.as_ref() { Some(closure) => match extract_route_property_routes( @@ -2986,7 +2985,7 @@ pub fn ode(input: TokenStream) -> TokenStream { #[proc_macro] pub fn analytical(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as AnalyticalInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let kernel_spec = match resolve_analytical_structure(&input.structure) { Ok(spec) => spec, @@ -3150,7 +3149,7 @@ pub fn analytical(input: TokenStream) -> TokenStream { #[proc_macro] pub fn sde(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as SdeInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let lag_routes = match input.lag.as_ref() { Some(closure) => match extract_route_property_routes( @@ -3364,13 +3363,13 @@ mod tests { } #[test] - fn ode_route_bindings_share_channels_by_kind_local_ordinal() { + fn ode_route_bindings_share_inputs_by_kind_local_ordinal() { let input = syn::parse_str::( "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: { bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot }, diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", ) .expect("declaration-first ode input should parse"); - let bindings = ode_route_channel_bindings(&input.routes); + let bindings = ode_route_input_bindings(&input.routes); assert_eq!(dense_index_len(&bindings), 2); assert_eq!(bindings[0].0.name(), "oral"); diff --git a/src/data/event.rs b/src/data/event.rs index 46995ef5..bff4c700 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -94,76 +94,85 @@ pub enum Event { Observation(Observation), } -#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] -pub struct ChannelId(String); +macro_rules! impl_label_type { + ($name:ident) => { + #[derive( + Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, + )] + pub struct $name(String); + + impl $name { + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } -impl ChannelId { - pub fn new(label: impl ToString) -> Self { - Self(label.to_string()) - } + pub fn as_str(&self) -> &str { + &self.0 + } - pub fn as_str(&self) -> &str { - &self.0 - } + pub fn index(&self) -> Option { + self.0.parse::().ok() + } + } - pub fn index(&self) -> Option { - self.0.parse::().ok() - } -} + impl From for $name { + fn from(value: String) -> Self { + Self(value) + } + } -impl From for ChannelId { - fn from(value: String) -> Self { - Self(value) - } -} + impl From<&str> for $name { + fn from(value: &str) -> Self { + Self(value.to_string()) + } + } -impl From<&str> for ChannelId { - fn from(value: &str) -> Self { - Self(value.to_string()) - } -} + impl From for $name { + fn from(value: usize) -> Self { + Self(value.to_string()) + } + } -impl From for ChannelId { - fn from(value: usize) -> Self { - Self(value.to_string()) - } -} + impl AsRef for $name { + fn as_ref(&self) -> &str { + self.as_str() + } + } -impl AsRef for ChannelId { - fn as_ref(&self) -> &str { - self.as_str() - } -} + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } + } -impl fmt::Display for ChannelId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.as_str()) - } -} + impl PartialEq for $name { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } + } -impl PartialEq for ChannelId { - fn eq(&self, other: &usize) -> bool { - self.index() == Some(*other) - } -} + impl PartialEq<$name> for usize { + fn eq(&self, other: &$name) -> bool { + other == self + } + } -impl PartialEq for usize { - fn eq(&self, other: &ChannelId) -> bool { - other == self - } -} + impl PartialEq for &$name { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } + } -impl PartialEq for &ChannelId { - fn eq(&self, other: &usize) -> bool { - (**self).eq(other) - } + impl PartialEq<&$name> for usize { + fn eq(&self, other: &&$name) -> bool { + other.eq(self) + } + } + }; } -impl PartialEq<&ChannelId> for usize { - fn eq(&self, other: &&ChannelId) -> bool { - other.eq(self) - } -} +impl_label_type!(InputLabel); +impl_label_type!(OutputLabel); impl Event { /// Get the time of the event @@ -224,7 +233,7 @@ impl Event { pub struct Bolus { time: f64, amount: f64, - input: ChannelId, + input: InputLabel, occasion: usize, } impl Bolus { @@ -234,12 +243,12 @@ impl Bolus { /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose + /// * `input` - The route label receiving the dose pub fn new(time: f64, amount: f64, input: impl ToString, occasion: usize) -> Self { Bolus { time, amount, - input: ChannelId::new(input), + input: InputLabel::new(input), occasion, } } @@ -249,8 +258,8 @@ impl Bolus { self.amount } - /// Get the compartment number that receives the bolus - pub fn input(&self) -> &ChannelId { + /// Get the route label that receives the bolus + pub fn input(&self) -> &InputLabel { &self.input } @@ -268,9 +277,9 @@ impl Bolus { self.amount = amount; } - /// Set the compartment number that receives the bolus + /// Set the route label that receives the bolus pub fn set_input(&mut self, input: impl ToString) { - self.input = ChannelId::new(input); + self.input = InputLabel::new(input); } /// Set the time of the bolus administration @@ -283,8 +292,8 @@ impl Bolus { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the bolus - pub fn mut_input(&mut self) -> &mut ChannelId { + /// Get a mutable reference to the route label that receives the bolus + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -311,7 +320,7 @@ impl Bolus { pub struct Infusion { time: f64, amount: f64, - input: ChannelId, + input: InputLabel, duration: f64, occasion: usize, } @@ -322,7 +331,7 @@ impl Infusion { /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - The route label receiving the dose /// * `duration` - Duration of the infusion in time units pub fn new( time: f64, @@ -334,7 +343,7 @@ impl Infusion { Infusion { time, amount, - input: ChannelId::new(input), + input: InputLabel::new(input), duration, occasion, } @@ -345,8 +354,8 @@ impl Infusion { self.amount } - /// Get the compartment number that receives the infusion - pub fn input(&self) -> &ChannelId { + /// Get the route label that receives the infusion + pub fn input(&self) -> &InputLabel { &self.input } @@ -371,9 +380,9 @@ impl Infusion { self.amount = amount; } - /// Set the compartment number that receives the infusion + /// Set the route label that receives the infusion pub fn set_input(&mut self, input: impl ToString) { - self.input = ChannelId::new(input); + self.input = InputLabel::new(input); } /// Set the time of the infusion administration @@ -391,8 +400,8 @@ impl Infusion { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the infusion - pub fn mut_input(&mut self) -> &mut ChannelId { + /// Get a mutable reference to the route label that receives the infusion + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -434,7 +443,7 @@ pub enum Censor { pub struct Observation { time: f64, value: Option, - outeq: ChannelId, + outeq: OutputLabel, errorpoly: Option, occasion: usize, censoring: Censor, @@ -446,7 +455,7 @@ impl Observation { /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation + /// * `outeq` - Output label corresponding to this observation /// * `errorpoly` - Optional error polynomial coefficients (c0, c1, c2, c3) /// * `occasion` - Occasion index /// * `censoring` - Censoring type for this observation @@ -461,7 +470,7 @@ impl Observation { Observation { time, value, - outeq: ChannelId::new(outeq), + outeq: OutputLabel::new(outeq), errorpoly, occasion, censoring, @@ -478,8 +487,8 @@ impl Observation { self.value } - /// Get the output equation number corresponding to this observation - pub fn outeq(&self) -> &ChannelId { + /// Get the output label corresponding to this observation + pub fn outeq(&self) -> &OutputLabel { &self.outeq } @@ -504,9 +513,9 @@ impl Observation { self.value = value; } - /// Set the output equation number corresponding to this observation + /// Set the output label corresponding to this observation pub fn set_outeq(&mut self, outeq: impl ToString) { - self.outeq = ChannelId::new(outeq); + self.outeq = OutputLabel::new(outeq); } /// Set the [ErrorPoly] for this observation @@ -524,8 +533,8 @@ impl Observation { &mut self.value } - /// Get a mutable reference to the output equation number - pub fn mut_outeq(&mut self) -> &mut ChannelId { + /// Get a mutable reference to the output label + pub fn mut_outeq(&mut self) -> &mut OutputLabel { &mut self.outeq } diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index 4554e435..89943f6e 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -95,14 +95,14 @@ struct Row { #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, /// Input compartment - #[serde(deserialize_with = "deserialize_option_channel_id")] - input: Option, + #[serde(deserialize_with = "deserialize_option_route_label")] + input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, /// Corresponding output equation for the observation - #[serde(deserialize_with = "deserialize_option_channel_id")] - outeq: Option, + #[serde(deserialize_with = "deserialize_option_output_label")] + outeq: Option, /// Censoring output #[serde(default, deserialize_with = "deserialize_option_censor")] cens: Option, @@ -196,11 +196,18 @@ where } } -fn deserialize_option_channel_id<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_option_route_label<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserialize_option::(deserializer).map(|value| value.map(ChannelId::from)) + deserialize_option::(deserializer).map(|value| value.map(InputLabel::from)) +} + +fn deserialize_option_output_label<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserialize_option::(deserializer).map(|value| value.map(OutputLabel::from)) } fn deserialize_option_isize<'de, D>(deserializer: D) -> Result, D::Error> @@ -498,7 +505,7 @@ mod tests { } #[test] - fn read_pmetrics_preserves_named_channel_labels() { + fn read_pmetrics_preserves_named_route_and_output_labels() { let file = NamedTempFile::new().unwrap(); std::fs::write( file.path(), diff --git a/src/data/row.rs b/src/data/row.rs index f6e44e98..b9a807c1 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -32,8 +32,8 @@ use thiserror::Error; /// /// # Fields /// -/// All fields use Pmetrics conventions: -/// - `input` and `outeq` are **1-indexed** (kept as-is, user must size arrays accordingly) +/// All fields use the public labeling conventions: +/// - `input` and `outeq` preserve the route and output labels from the source data /// - `evid`: 0=observation, 1=dose, 4=reset/new occasion /// - `addl`: positive=forward in time, negative=backward in time /// @@ -78,12 +78,12 @@ pub struct DataRow { pub addl: Option, /// Interdose interval for ADDL pub ii: Option, - /// Input compartment - pub input: Option, + /// Input route label + pub input: Option, /// Observed value (for EVID=0) pub out: Option, - /// Output equation number - pub outeq: Option, + /// Output label + pub outeq: Option, /// Censoring indicator pub cens: Option, /// Error polynomial coefficients @@ -373,12 +373,12 @@ impl DataRowBuilder { self } - /// Set the input compartment (1-indexed) + /// Set the input route label /// /// Required for EVID=1 (dosing events). - /// Kept as 1-indexed; user must size state arrays accordingly. + /// Preserved as the public route label until model resolution. pub fn input(mut self, input: impl ToString) -> Self { - self.row.input = Some(ChannelId::new(input)); + self.row.input = Some(InputLabel::new(input)); self } @@ -390,12 +390,12 @@ impl DataRowBuilder { self } - /// Set the output equation (1-indexed) + /// Set the output label /// /// Required for EVID=0 (observation events). - /// Will be converted to 0-indexed internally. + /// Preserved as the public output label until model resolution. pub fn outeq(mut self, outeq: impl ToString) -> Self { - self.row.outeq = Some(ChannelId::new(outeq)); + self.row.outeq = Some(OutputLabel::new(outeq)); self } diff --git a/src/data/structs.rs b/src/data/structs.rs index c977d89a..d7d123b1 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -180,17 +180,18 @@ impl Data { let old_events = occasion.process_events(None, true); // Create a set of existing (time, outeq) pairs for fast lookup - let existing_obs: std::collections::HashSet<(u64, ChannelId)> = old_events - .iter() - .filter_map(|event| match event { - Event::Observation(obs) => { - // Convert to microseconds for consistent comparison - let time_key = (obs.time() * 1e6).round() as u64; - Some((time_key, obs.outeq().clone())) - } - _ => None, - }) - .collect(); + let existing_obs: std::collections::HashSet<(u64, OutputLabel)> = + old_events + .iter() + .filter_map(|event| match event { + Event::Observation(obs) => { + // Convert to microseconds for consistent comparison + let time_key = (obs.time() * 1e6).round() as u64; + Some((time_key, obs.outeq().clone())) + } + _ => None, + }) + .collect(); // Generate new observation times let mut new_events = Vec::new(); @@ -273,10 +274,10 @@ impl Data { self.subjects.is_empty() } - /// Get a vector of all unique output equations (outeq) across all subjects - pub fn get_output_equations(&self) -> Vec { + /// Get a vector of all unique output labels (outeq) across all subjects + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let mut outeq_values: Vec = self + let mut outeq_values: Vec = self .subjects .iter() .flat_map(|subject| subject.get_output_equations()) @@ -396,9 +397,9 @@ impl Subject { self.occasions.iter_mut() } - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let outeq_values: Vec = self + let outeq_values: Vec = self .occasions .iter() .flat_map(|occasion| { diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index b0f1fe4a..5504ab08 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1331,7 +1331,7 @@ mod tests { } #[test] - fn authoring_runtime_shares_channel_between_bolus_and_infusion_routes() { + fn authoring_runtime_shares_input_between_bolus_and_infusion_routes() { let source = r#" name = shared_authoring kind = ode diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 0094059f..d9a2fdbd 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -243,7 +243,7 @@ model explicit_route_usage { } #[test] - fn authoring_shared_channel_routes_keep_declaration_specific_injection() { + fn authoring_shared_input_routes_keep_declaration_specific_injection() { let info = load_model_info( r#" name = shared_authoring diff --git a/src/dsl/native.rs b/src/dsl/native.rs index d2600e67..c1ce8eac 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -20,7 +20,7 @@ pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ - data::{ChannelId, Covariates, Infusion}, + data::{Covariates, Infusion, InputLabel, OutputLabel}, simulator::{ equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, @@ -392,14 +392,14 @@ impl SharedNativeModel { } Err(PharmsolError::OtherError(format!( - "model `{}` does not declare a {:?} route for input channel {}", + "model `{}` does not declare a {:?} route for input {}", self.info.name, kind, input ))) } fn resolve_input_label( &self, - label: &ChannelId, + label: &InputLabel, kind: RouteKind, ) -> Result { if let Some(input) = self.route_index(label.as_str()) { @@ -416,7 +416,7 @@ impl SharedNativeModel { Ok(input) } - fn resolve_output_label(&self, label: &ChannelId) -> Result { + fn resolve_output_label(&self, label: &OutputLabel) -> Result { if let Some(outeq) = self.output_index(label.as_str()) { return Ok(outeq); } @@ -682,7 +682,7 @@ impl SharedNativeModel { .bolus_destination(input) .ok_or_else(|| { PharmsolError::OtherError(format!( - "model `{}` does not declare a bolus route for input channel {}", + "model `{}` does not declare a bolus route for input index {}", self.info.name, input )) })?; diff --git a/src/error/mod.rs b/src/error/mod.rs index 5145626e..c8f70b58 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -37,11 +37,11 @@ pub enum PharmsolError { ZeroLikelihood, #[error("Missing observation in prediction")] MissingObservation, - #[error("Input label `{label}` could not be resolved to a route channel")] + #[error("Input label `{label}` could not be resolved to a route input")] UnknownInputLabel { label: String }, - #[error("Output label `{label}` could not be resolved to an output channel")] + #[error("Output label `{label}` could not be resolved to an output")] UnknownOutputLabel { label: String }, - #[error("Input channel {input} is out of range (ndrugs = {ndrugs})")] + #[error("Input index {input} is out of range (ndrugs = {ndrugs})")] InputOutOfRange { input: usize, ndrugs: usize }, #[error("Output equation {outeq} is out of range (nout = {nout})")] OuteqOutOfRange { outeq: usize, nout: usize }, diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index b0d78481..1dd4bbb5 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -32,9 +32,7 @@ pub enum AnalyticalMetadataError { Validation(#[from] ModelMetadataError), #[error("analytical model declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "analytical model declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("analytical model declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("analytical model declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -119,7 +117,7 @@ impl Analytical { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -186,7 +184,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(AnalyticalMetadataError::RouteCountMismatch { expected: neqs.ndrugs, diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs index ecf51a52..fecab7e2 100644 --- a/src/simulator/equation/metadata.rs +++ b/src/simulator/equation/metadata.rs @@ -80,7 +80,7 @@ pub struct ValidatedModelMetadata { covariates: Vec, states: Vec, routes: Vec, - route_channel_count: usize, + route_input_count: usize, outputs: Vec, particles: Option, analytical: Option, @@ -111,8 +111,8 @@ impl ValidatedModelMetadata { &self.routes } - pub fn route_channel_count(&self) -> usize { - self.route_channel_count + pub fn route_input_count(&self) -> usize { + self.route_input_count } pub fn outputs(&self) -> &[Output] { @@ -144,7 +144,7 @@ impl ValidatedModelMetadata { } pub fn route_index(&self, name: &str) -> Option { - self.route(name).map(ValidatedRoute::channel_index) + self.route(name).map(ValidatedRoute::input_index) } pub fn route_declaration_index(&self, name: &str) -> Option { @@ -185,7 +185,7 @@ pub struct ValidatedRoute { name: String, kind: RouteKind, declaration_index: usize, - channel_index: usize, + input_index: usize, destination: String, destination_index: usize, has_lag: bool, @@ -206,8 +206,8 @@ impl ValidatedRoute { self.declaration_index } - pub fn channel_index(&self) -> usize { - self.channel_index + pub fn input_index(&self) -> usize { + self.input_index } pub fn destination(&self) -> &str { @@ -416,7 +416,7 @@ impl ModelMetadata { let particles = resolve_particles(kind, self.particles, fallback_particles)?; validate_kind_specific_fields(kind, self.analytical, particles)?; - let (routes, route_channel_count) = validate_routes(self.routes, &self.states)?; + let (routes, route_input_count) = validate_routes(self.routes, &self.states)?; Ok(ValidatedModelMetadata { name: self.name, @@ -425,7 +425,7 @@ impl ModelMetadata { covariates: self.covariates, states: self.states, routes, - route_channel_count, + route_input_count, outputs: self.outputs, particles, analytical: self.analytical, @@ -730,20 +730,20 @@ fn validate_routes( routes: Vec, states: &[State], ) -> Result<(Vec, usize), ModelMetadataError> { - let mut bolus_channels = 0; - let mut infusion_channels = 0; + let mut bolus_inputs = 0; + let mut infusion_inputs = 0; let mut validated_routes = Vec::with_capacity(routes.len()); for (declaration_index, route) in routes.into_iter().enumerate() { - let channel_index = match route.kind { + let input_index = match route.kind { RouteKind::Bolus => { - let index = bolus_channels; - bolus_channels += 1; + let index = bolus_inputs; + bolus_inputs += 1; index } RouteKind::Infusion => { - let index = infusion_channels; - infusion_channels += 1; + let index = infusion_inputs; + infusion_inputs += 1; index } }; @@ -751,18 +751,18 @@ fn validate_routes( validated_routes.push(validate_route( route, declaration_index, - channel_index, + input_index, states, )?); } - Ok((validated_routes, bolus_channels.max(infusion_channels))) + Ok((validated_routes, bolus_inputs.max(infusion_inputs))) } fn validate_route( route: Route, declaration_index: usize, - channel_index: usize, + input_index: usize, states: &[State], ) -> Result { if route.kind == RouteKind::Infusion && route.has_lag { @@ -796,7 +796,7 @@ fn validate_route( name: route.name, kind: route.kind, declaration_index, - channel_index, + input_index, destination, destination_index, has_lag: route.has_lag, @@ -902,7 +902,7 @@ mod tests { assert_eq!(metadata.state_index("central"), Some(0)); assert_eq!(metadata.route_index("iv"), Some(0)); assert_eq!(metadata.route_declaration_index("iv"), Some(0)); - assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_input_count(), 1); assert_eq!(metadata.output_index("cp"), Some(0)); assert_eq!( metadata.route("iv").expect("route exists").destination(), @@ -915,10 +915,7 @@ mod tests { .declaration_index(), 0 ); - assert_eq!( - metadata.route("iv").expect("route exists").channel_index(), - 0 - ); + assert_eq!(metadata.route("iv").expect("route exists").input_index(), 0); assert_eq!( metadata .route("iv") @@ -988,8 +985,8 @@ mod tests { } #[test] - fn shared_channel_routes_preserve_declaration_and_channel_identity() { - let metadata = new("shared_channel") + fn shared_input_routes_preserve_declaration_and_input_identity() { + let metadata = new("shared_input") .kind(ModelKind::Ode) .parameters(["ke"]) .states(["gut", "central"]) @@ -999,19 +996,16 @@ mod tests { Route::infusion("iv").to_state("central"), ]) .validate() - .expect("shared-channel metadata should validate"); + .expect("shared-input metadata should validate"); assert_eq!(metadata.routes().len(), 2); - assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_input_count(), 1); assert_eq!(metadata.route_index("oral"), Some(0)); assert_eq!(metadata.route_index("iv"), Some(0)); assert_eq!(metadata.route_declaration_index("oral"), Some(0)); assert_eq!(metadata.route_declaration_index("iv"), Some(1)); - assert_eq!( - metadata.route("oral").expect("oral route").channel_index(), - 0 - ); - assert_eq!(metadata.route("iv").expect("iv route").channel_index(), 0); + assert_eq!(metadata.route("oral").expect("oral route").input_index(), 0); + assert_eq!(metadata.route("iv").expect("iv route").input_index(), 0); assert_eq!( metadata .route("oral") diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index f3532382..c5a97958 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -12,7 +12,8 @@ pub use sde::*; use crate::{ error_model::AssayErrorModels, simulator::{Fa, Lag}, - ChannelId, Covariates, Event, Infusion, Observation, Occasion, PharmsolError, Subject, + Covariates, Event, Infusion, InputLabel, Observation, Occasion, OutputLabel, PharmsolError, + Subject, }; use super::likelihood::Prediction; @@ -145,7 +146,7 @@ pub(crate) trait EquationPriv: EquationTypes { fn resolve_input_label( &self, - label: &ChannelId, + label: &InputLabel, expected_kind: RouteKind, ) -> Result { if let Some(metadata) = self.metadata() { @@ -165,7 +166,7 @@ pub(crate) trait EquationPriv: EquationTypes { ))); } - return Ok(route.channel_index()); + return Ok(route.input_index()); } label @@ -175,7 +176,7 @@ pub(crate) trait EquationPriv: EquationTypes { }) } - fn resolve_output_label(&self, label: &ChannelId) -> Result { + fn resolve_output_label(&self, label: &OutputLabel) -> Result { if let Some(metadata) = self.metadata() { return metadata.output_index(label.as_str()).ok_or_else(|| { PharmsolError::UnknownOutputLabel { diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index cb9c0726..47f2a81e 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -11,13 +11,13 @@ type C = ::C; type T = ::T; #[derive(Debug, Clone)] -struct InfusionChannel { +struct InfusionTrack { input: usize, event_times: Vec, cumulative_rates: Vec, } -impl InfusionChannel { +impl InfusionTrack { fn new(input: usize, mut events: Vec<(f64, f64)>) -> Self { events.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal)); @@ -63,15 +63,13 @@ impl InfusionChannel { #[derive(Debug, Clone, Default)] struct InfusionSchedule { - channels: Vec, + tracks: Vec, } impl InfusionSchedule { fn new(ndrugs: usize, infusions: &[&Infusion]) -> Result { if ndrugs == 0 || infusions.is_empty() { - return Ok(Self { - channels: Vec::new(), - }); + return Ok(Self { tracks: Vec::new() }); } let mut per_input: Vec> = vec![Vec::new(); ndrugs]; @@ -94,27 +92,27 @@ impl InfusionSchedule { per_input[input].push((infusion.time() + infusion.duration(), -rate)); } - let channels = per_input + let tracks = per_input .into_iter() .enumerate() .filter_map(|(input, events)| { if events.is_empty() { None } else { - Some(InfusionChannel::new(input, events)) + Some(InfusionTrack::new(input, events)) } }) .collect(); - Ok(Self { channels }) + Ok(Self { tracks }) } fn fill_rate_vector(&self, time: f64, rateiv: &mut V) { rateiv.fill(0.0); - for channel in &self.channels { - let rate = channel.rate_at(time); + for track in &self.tracks { + let rate = track.rate_at(time); if rate != 0.0 { - rateiv[channel.input] = rate; + rateiv[track.input] = rate; } } } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 853b3108..c65f16a9 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -87,9 +87,7 @@ pub enum OdeMetadataError { Validation(#[from] ModelMetadataError), #[error("ODE declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "ODE declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("ODE declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("ODE declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -134,7 +132,7 @@ impl ODE { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -211,7 +209,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(OdeMetadataError::RouteCountMismatch { expected: neqs.ndrugs, diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index c24b615d..43a1d48a 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -34,9 +34,7 @@ pub enum SdeMetadataError { Validation(#[from] ModelMetadataError), #[error("SDE declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "SDE declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("SDE declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("SDE declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -236,7 +234,7 @@ impl SDE { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -309,7 +307,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(SdeMetadataError::RouteCountMismatch { expected: neqs.ndrugs, diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 5cea84fe..058ca125 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -200,7 +200,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; /// /// # Fields /// - `nstates`: Number of state variables (ODE compartments) -/// - `ndrugs`: Number of drug input channels (size of bolus[] and rateiv[]) +/// - `ndrugs`: Number of drug inputs (size of bolus[] and rateiv[]) /// - `nout`: Number of output equations /// /// # Defaults @@ -218,7 +218,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; pub struct Neqs { /// Number of state variables pub nstates: usize, - /// Number of drug input channels (bolus/rateiv size) + /// Number of drug inputs (bolus/rateiv size) pub ndrugs: usize, /// Number of output equations pub nout: usize, diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index c55719f5..796cb55e 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -19,7 +19,7 @@ fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { .build() } -fn shared_channel_subject() -> Subject { +fn shared_input_subject() -> Subject { Subject::builder("analytical-macro-shared") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) @@ -160,7 +160,7 @@ fn handwritten_one_compartment_with_absorption() -> equation::Analytical { .expect("handwritten absorption metadata should validate") } -fn macro_shared_channel_analytical() -> equation::Analytical { +fn macro_shared_input_analytical() -> equation::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], @@ -183,7 +183,7 @@ fn macro_shared_channel_analytical() -> equation::Analytical { } } -fn handwritten_shared_channel_analytical() -> equation::Analytical { +fn handwritten_shared_input_analytical() -> equation::Analytical { equation::Analytical::new( equation::one_compartment_with_absorption, |_p, _t, _cov| {}, @@ -219,7 +219,7 @@ fn handwritten_shared_channel_analytical() -> equation::Analytical { ]) .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), ) - .expect("handwritten shared-channel analytical metadata should validate") + .expect("handwritten shared-input analytical metadata should validate") } fn macro_covariate_analytical() -> equation::Analytical { @@ -438,10 +438,10 @@ fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { } #[test] -fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_model = macro_shared_channel_analytical(); - let handwritten_model = handwritten_shared_channel_analytical(); - let subject = shared_channel_subject(); +fn analytical_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_analytical(); + let handwritten_model = handwritten_shared_input_analytical(); + let subject = shared_input_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -453,12 +453,12 @@ fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_pre let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel analytical model should simulate") + .expect("macro shared-input analytical model should simulate") .flat_predictions() .to_vec(); let handwritten_predictions = handwritten_model .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel analytical model should simulate") + .expect("handwritten shared-input analytical model should simulate") .flat_predictions() .to_vec(); diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index 67f91c7a..37a5891a 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -70,8 +70,8 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const ODE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" -name = shared_channel_one_cpt +const ODE_RUNTIME_SHARED_INPUT_DSL: &str = r#" +name = shared_input_one_cpt kind = ode params = ka, ke, v, tlag, f_oral @@ -122,7 +122,7 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +const ANALYTICAL_RUNTIME_SHARED_INPUT_DSL: &str = r#" name = one_cmt_abs_shared kind = analytical @@ -177,7 +177,7 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const SDE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +const SDE_RUNTIME_SHARED_INPUT_DSL: &str = r#" name = one_cmt_shared_sde kind = sde @@ -205,7 +205,7 @@ struct MetadataParityView { parameters: Vec, covariates: Vec, states: Vec, - route_channel_count: usize, + route_input_count: usize, routes: Vec, outputs: Vec, analytical_kernel: Option, @@ -230,7 +230,7 @@ struct RouteParity { name: String, kind: Option, declaration_index: usize, - channel_index: usize, + input_index: usize, destination_name: String, destination_index: usize, has_lag: bool, @@ -242,7 +242,7 @@ struct RouteParity { struct RouteInputPolicyParity { name: String, declaration_index: usize, - channel_index: usize, + input_index: usize, input_policy: RouteInputPolicy, } @@ -286,8 +286,8 @@ fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntim } #[cfg(feature = "dsl-jit")] -fn shared_channel_prediction_subject() -> Subject { - Subject::builder("authoring-parity-shared-channel") +fn shared_input_prediction_subject() -> Subject { + Subject::builder("authoring-parity-shared-input") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) .missing_observation(0.5, "cp") @@ -347,7 +347,7 @@ fn dsl_metadata_view(src: &str) -> MetadataParityView { name: route.name.clone(), kind: route.kind.map(RouteKindParity::from_dsl), declaration_index: route.declaration_index, - channel_index: route.index, + input_index: route.index, destination_name: route.destination.state_name.clone(), destination_index: route.destination.state_offset, has_lag: route.has_lag, @@ -361,7 +361,7 @@ fn dsl_metadata_view(src: &str) -> MetadataParityView { parameters, covariates, states, - route_channel_count: model.abi.route_buffer.len, + route_input_count: model.abi.route_buffer.len, routes, outputs, analytical_kernel: model.metadata.analytical, @@ -379,7 +379,7 @@ fn dsl_route_input_policy_view(src: &str) -> Vec { .map(|route| RouteInputPolicyParity { name: route.name, declaration_index: route.declaration_index, - channel_index: route.index, + input_index: route.index, input_policy: if route.inject_input_to_destination { RouteInputPolicy::InjectToDestination } else { @@ -421,7 +421,7 @@ fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityV index, }) .collect(), - route_channel_count: metadata.route_channel_count(), + route_input_count: metadata.route_input_count(), routes: metadata .routes() .iter() @@ -429,7 +429,7 @@ fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityV name: route.name().to_string(), kind: Some(RouteKindParity::from_handwritten(route.kind())), declaration_index: route.declaration_index(), - channel_index: route.channel_index(), + input_index: route.input_index(), destination_name: route.destination().to_string(), destination_index: route.destination_index(), has_lag: route.has_lag(), @@ -460,7 +460,7 @@ fn handwritten_route_input_policy_view( .map(|route| RouteInputPolicyParity { name: route.name().to_string(), declaration_index: route.declaration_index(), - channel_index: route.channel_index(), + input_index: route.input_index(), input_policy: route .input_policy() .expect("route input policy should be explicit in this handwritten fixture"), @@ -565,9 +565,9 @@ fn handwritten_ode_model() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_ode() -> equation::ODE { +fn runtime_shared_input_macro_ode() -> equation::ODE { ode! { - name: "shared_channel_one_cpt", + name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], @@ -592,7 +592,7 @@ fn runtime_shared_channel_macro_ode() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_ode() -> equation::ODE { +fn runtime_shared_input_handwritten_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -617,7 +617,7 @@ fn runtime_shared_channel_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt") + equation::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -632,11 +632,11 @@ fn runtime_shared_channel_handwritten_ode() -> equation::ODE { .expect_explicit_input(), ]), ) - .expect("handwritten shared-channel ODE metadata should validate") + .expect("handwritten shared-input ODE metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_mismatched_shared_channel_ode() -> equation::ODE { +fn runtime_mismatched_shared_input_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, _bolus, _rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -661,7 +661,7 @@ fn runtime_mismatched_shared_channel_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt_mismatched") + equation::metadata::new("shared_input_one_cpt_mismatched") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -676,11 +676,11 @@ fn runtime_mismatched_shared_channel_ode() -> equation::ODE { .expect_explicit_input(), ]), ) - .expect("mismatched shared-channel ODE metadata should validate") + .expect("mismatched shared-input ODE metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_analytical() -> equation::Analytical { +fn runtime_shared_input_macro_analytical() -> equation::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], @@ -704,7 +704,7 @@ fn runtime_shared_channel_macro_analytical() -> equation::Analytical { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { +fn runtime_shared_input_handwritten_analytical() -> equation::Analytical { equation::Analytical::new( equation::one_compartment_with_absorption, |_p, _t, _cov| {}, @@ -740,11 +740,11 @@ fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { ]) .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), ) - .expect("handwritten shared-channel analytical metadata should validate") + .expect("handwritten shared-input analytical metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_sde() -> equation::SDE { +fn runtime_shared_input_macro_sde() -> equation::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], @@ -780,7 +780,7 @@ fn runtime_shared_channel_macro_sde() -> equation::SDE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_sde() -> equation::SDE { +fn runtime_shared_input_handwritten_sde() -> equation::SDE { equation::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); @@ -830,7 +830,7 @@ fn runtime_shared_channel_handwritten_sde() -> equation::SDE { ]) .particles(8), ) - .expect("handwritten shared-channel SDE metadata should validate") + .expect("handwritten shared-input SDE metadata should validate") } #[cfg(feature = "dsl-jit")] @@ -1196,11 +1196,11 @@ fn invalid_dsl_infusion_route_properties_fail_explicitly() { #[cfg(feature = "dsl-jit")] #[test] -fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); - let macro_model = runtime_shared_channel_macro_ode(); - let handwritten_model = runtime_shared_channel_handwritten_ode(); + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let macro_model = runtime_shared_input_macro_ode(); + let handwritten_model = runtime_shared_input_handwritten_ode(); let oral = runtime_model .route_index("oral") @@ -1211,7 +1211,7 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); @@ -1246,11 +1246,11 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha #[cfg(feature = "dsl-jit")] #[test] -fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_abs_shared"); - let macro_model = runtime_shared_channel_macro_analytical(); - let handwritten_model = runtime_shared_channel_handwritten_analytical(); + compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_INPUT_DSL, "one_cmt_abs_shared"); + let macro_model = runtime_shared_input_macro_analytical(); + let handwritten_model = runtime_shared_input_handwritten_analytical(); let oral = runtime_model .route_index("oral") @@ -1261,7 +1261,7 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); @@ -1298,11 +1298,11 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan #[cfg(feature = "dsl-jit")] #[test] -fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(SDE_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_shared_sde"); - let macro_model = runtime_shared_channel_macro_sde(); - let handwritten_model = runtime_shared_channel_handwritten_sde(); + compile_runtime_jit_model(SDE_RUNTIME_SHARED_INPUT_DSL, "one_cmt_shared_sde"); + let macro_model = runtime_shared_input_macro_sde(); + let handwritten_model = runtime_shared_input_handwritten_sde(); let oral = runtime_model .route_index("oral") @@ -1313,7 +1313,7 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); @@ -1350,8 +1350,8 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha #[test] fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let runtime_model = - compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); - let mismatched_model = runtime_mismatched_shared_channel_ode(); + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let mismatched_model = runtime_mismatched_shared_input_ode(); let oral = runtime_model .route_index("oral") @@ -1362,7 +1362,7 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index a556f428..480f7e80 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -18,8 +18,8 @@ fn subject_for_route(input: impl ToString, outeq: impl ToString) -> Subject { .build() } -fn subject_for_shared_channel() -> Subject { - Subject::builder("macro-shared-channel") +fn subject_for_shared_input() -> Subject { + Subject::builder("macro-shared-input") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) .missing_observation(0.5, "cp") @@ -197,9 +197,9 @@ fn numeric_label_handwritten_ode() -> equation::ODE { .expect("handwritten numeric-label metadata should validate") } -fn shared_channel_macro_ode() -> equation::ODE { +fn shared_input_macro_ode() -> equation::ODE { ode! { - name: "shared_channel_one_cpt", + name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], @@ -223,7 +223,7 @@ fn shared_channel_macro_ode() -> equation::ODE { } } -fn shared_channel_handwritten_ode() -> equation::ODE { +fn shared_input_handwritten_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -248,7 +248,7 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt") + equation::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -263,7 +263,7 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .expect_explicit_input(), ]), ) - .expect("handwritten shared-channel metadata should validate") + .expect("handwritten shared-input metadata should validate") } fn numeric_route_property_macro_ode() -> equation::ODE { @@ -526,10 +526,10 @@ fn macro_numeric_labels_lower_to_dense_slots() { } #[test] -fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_ode = shared_channel_macro_ode(); - let handwritten_ode = shared_channel_handwritten_ode(); - let subject = subject_for_shared_channel(); +fn macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = shared_input_macro_ode(); + let handwritten_ode = shared_input_handwritten_ode(); + let subject = subject_for_shared_input(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -541,12 +541,12 @@ fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() let macro_predictions = macro_ode .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel model should simulate") + .expect("macro shared-input model should simulate") .flat_predictions() .to_vec(); let handwritten_predictions = handwritten_ode .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel model should simulate") + .expect("handwritten shared-input model should simulate") .flat_predictions() .to_vec(); diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 13d21a2b..05b5cb27 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -20,7 +20,7 @@ fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { .build() } -fn shared_channel_subject() -> Subject { +fn shared_input_subject() -> Subject { Subject::builder("sde-macro-shared") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) @@ -211,7 +211,7 @@ fn handwritten_absorption_sde() -> equation::SDE { .expect("handwritten absorption SDE metadata should validate") } -fn macro_shared_channel_sde() -> equation::SDE { +fn macro_shared_input_sde() -> equation::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], @@ -246,7 +246,7 @@ fn macro_shared_channel_sde() -> equation::SDE { } } -fn handwritten_shared_channel_sde() -> equation::SDE { +fn handwritten_shared_input_sde() -> equation::SDE { equation::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); @@ -296,7 +296,7 @@ fn handwritten_shared_channel_sde() -> equation::SDE { ]) .particles(8), ) - .expect("handwritten shared-channel SDE metadata should validate") + .expect("handwritten shared-input SDE metadata should validate") } fn macro_covariate_sde() -> equation::SDE { @@ -538,10 +538,10 @@ fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { } #[test] -fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_model = macro_shared_channel_sde(); - let handwritten_model = handwritten_shared_channel_sde(); - let subject = shared_channel_subject(); +fn sde_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_sde(); + let handwritten_model = handwritten_shared_input_sde(); + let subject = shared_input_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -553,10 +553,10 @@ fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_prediction let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel SDE should simulate"); + .expect("macro shared-input SDE should simulate"); let handwritten_predictions = handwritten_model .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel SDE should simulate"); + .expect("handwritten shared-input SDE should simulate"); assert_prediction_match( &prediction_means(¯o_predictions), From dac11081f25ef10c386479f9431d8fffa86155a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:22:09 +0100 Subject: [PATCH 10/22] fix: The new string-based mapper for inputs/outeq had an issue when ordering indices, lso the implementation was not complete over the DSL frontend --- pharmsol-dsl/src/authoring.rs | 44 +++- pharmsol-dsl/src/parser.rs | 10 +- src/dsl/aot.rs | 16 +- src/dsl/jit.rs | 122 +++++++--- src/dsl/native.rs | 23 +- src/dsl/runtime.rs | 285 +++++++++++++++++++++-- tests/authoring_parity_corpus.rs | 238 +++++++++++++++++++ tests/full_feature_dsl_backend_parity.rs | 201 ++++++++++++++++ tests/support/bimodal_ke.rs | 23 +- tests/support/runtime_corpus.rs | 102 ++++---- 10 files changed, 924 insertions(+), 140 deletions(-) create mode 100644 tests/full_feature_dsl_backend_parity.rs diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 129f07c8..7b0b4dd6 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -20,6 +20,7 @@ struct AuthoringParser<'a> { states: Vec, declared_derived: BTreeSet, declared_outputs: BTreeSet, + explicit_output_order: Vec, explicit_outputs: BTreeMap, assigned_outputs: BTreeMap, declared_outputs_span: Option, @@ -77,6 +78,7 @@ impl<'a> AuthoringParser<'a> { states: Vec::new(), declared_derived: BTreeSet::new(), declared_outputs: BTreeSet::new(), + explicit_output_order: Vec::new(), explicit_outputs: BTreeMap::new(), assigned_outputs: BTreeMap::new(), declared_outputs_span: None, @@ -175,6 +177,20 @@ impl<'a> AuthoringParser<'a> { )); } + if !self.explicit_output_order.is_empty() { + let output_order = self + .explicit_output_order + .iter() + .enumerate() + .map(|(index, name)| (name.clone(), index)) + .collect::>(); + self.output_statements.sort_by_key(|statement| { + output_statement_name(statement) + .and_then(|name| output_order.get(name).copied()) + .unwrap_or(usize::MAX) + }); + } + let mut derivative_statements = std::mem::take(&mut self.derivative_statements); inject_infusion_rates(&surface_routes, &routes, &mut derivative_statements); @@ -372,6 +388,7 @@ impl<'a> AuthoringParser<'a> { if lhs_trimmed == "outputs" { self.declared_outputs_span = Some(span); for ident in parse_output_label_list(rhs, rhs_abs)? { + self.explicit_output_order.push(ident.text.clone()); self.declared_outputs.insert(ident.text.clone()); self.explicit_outputs.insert(ident.text, ident.span); } @@ -467,7 +484,7 @@ impl<'a> AuthoringParser<'a> { } }; - let input = parse_ident_segment(call.argument, call.argument_start)?; + let input = parse_label_segment(call.argument, call.argument_start, "route label")?; let route_name = input.text.clone(); let destination = parse_place_at(rhs, line_start + arrow + 2)?; if self.routes.contains_key(&route_name) { @@ -498,7 +515,8 @@ impl<'a> AuthoringParser<'a> { ) -> Result<(), ParseError> { match call.callee.text.as_str() { "lag" | "fa" => { - let route_name = parse_ident_segment(call.argument, call.argument_start)?; + let route_name = + parse_label_segment(call.argument, call.argument_start, "route label")?; let value = parse_expr_at(rhs, rhs_abs)?; let property_name = match call.callee.text.as_str() { "lag" => "lag", @@ -928,17 +946,25 @@ fn parse_ident_segment(src: &str, abs_start: usize) -> Result } fn parse_output_label_segment(src: &str, abs_start: usize) -> Result { + parse_label_segment(src, abs_start, "output label") +} + +fn parse_label_segment( + src: &str, + abs_start: usize, + expected: &str, +) -> Result { let trimmed = src.trim(); let leading = src.len() - src.trim_start().len(); if trimmed.is_empty() { return Err(ParseError::new( - "expected output label", + format!("expected {expected}"), Span::new(abs_start, abs_start + src.len()), )); } if !is_valid_output_label(trimmed) { return Err(ParseError::new( - format!("expected output label, found `{trimmed}`"), + format!("expected {expected}, found `{trimmed}`"), Span::new(abs_start + leading, abs_start + leading + trimmed.len()), )); } @@ -1417,6 +1443,16 @@ fn join_covariate_spans(items: &[CovariateDecl]) -> Span { .unwrap_or_else(|| Span::empty(0)) } +fn output_statement_name(statement: &Stmt) -> Option<&str> { + match &statement.kind { + StmtKind::Assign(assign) => match &assign.target.kind { + AssignTargetKind::Name(name) => Some(name.text.as_str()), + _ => None, + }, + _ => None, + } +} + fn join_state_spans(items: &[StateDecl]) -> Span { items .iter() diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index c265b4df..fe844c37 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -563,7 +563,7 @@ impl Parser { } fn parse_route_decl(&mut self) -> Result { - let input = self.parse_ident()?; + let input = self.parse_label_name("route label")?; let arrow = self.expect_simple(|kind| matches!(kind, TokenKind::Arrow), "`->`")?; self.ensure_not_layout_boundary( arrow.span, @@ -902,9 +902,13 @@ impl Parser { } fn parse_output_target_name(&mut self) -> Result { + self.parse_label_name("output label") + } + + fn parse_label_name(&mut self, expected: &str) -> Result { let token = self .bump() - .ok_or_else(|| ParseError::new("expected output label", Span::empty(self.src_len)))?; + .ok_or_else(|| ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)))?; match token.kind { TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), TokenKind::Number(value) @@ -917,7 +921,7 @@ impl Parser { } other => Err(ParseError::new( format!( - "expected output label identifier or non-negative integer, found {}", + "expected {expected} identifier or non-negative integer, found {}", other.describe() ), token.span, diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 3749f183..2a46409a 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -543,14 +543,14 @@ mod tests { let subject = crate::Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index 5504ab08..a440c51d 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1360,21 +1360,33 @@ out(cp) = central / v ~ continuous() let cp = jit.output_index("cp").expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("ode") + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .observation(0.5, 0.0, "cp") + .observation(1.0, 0.0, "cp") + .observation(2.0, 0.0, "cp") + .observation(6.0, 0.0, "cp") + .observation(7.0, 0.0, "cp") + .observation(9.0, 0.0, "cp") + .build(); - let subject = Subject::builder("ode") - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .observation(0.5, 0.0, cp) - .observation(1.0, 0.0, cp) - .observation(2.0, 0.0, cp) - .observation(6.0, 0.0, cp) - .observation(7.0, 0.0, cp) - .observation(9.0, 0.0, cp) + let reference_subject = Subject::builder("ode") + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 0, 2.0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(7.0, 0.0, 0) + .observation(9.0, 0.0, 0) .build(); let support = vec![1.2, 0.15, 40.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit predictions"); let reference = ODE::new( @@ -1397,7 +1409,7 @@ out(cp) = central / v ~ continuous() .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference ode predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1491,22 +1503,35 @@ out(cp) = central / v ~ continuous() let cp = jit.output_index("cp").expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 1); + assert_eq!(cp, 0); - let subject = Subject::builder("ode") + let jit_subject = Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") + .build(); + + let reference_subject = Subject::builder("ode") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 1, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(7.0, 0) + .missing_observation(9.0, 0) .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit predictions"); let reference = ODE::new( @@ -1551,7 +1576,7 @@ out(cp) = central / v ~ continuous() .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference ode predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1574,18 +1599,28 @@ out(cp) = central / v ~ continuous() let oral = jit.route_index("oral").expect("oral route"); let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("analytical") - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + let reference_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.0, 0.15, 25.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit analytical predictions"); let reference = equation::Analytical::new( @@ -1603,7 +1638,7 @@ out(cp) = central / v ~ continuous() .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference analytical predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1628,19 +1663,30 @@ out(cp) = central / v ~ continuous() let oral = jit.route_index("oral").expect("oral route"); let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("sde") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("sde") + let reference_subject = Subject::builder("sde") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.1, 0.2, 0.12, 0.08, 15.0, 0.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit sde predictions"); let reference = SDE::new( @@ -1677,7 +1723,7 @@ out(cp) = central / v ~ continuous() .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference sde predictions"); for (jit_pred, reference_pred) in jit_predictions diff --git a/src/dsl/native.rs b/src/dsl/native.rs index c1ce8eac..6df2f05d 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -402,13 +402,8 @@ impl SharedNativeModel { label: &InputLabel, kind: RouteKind, ) -> Result { - if let Some(input) = self.route_index(label.as_str()) { - self.validate_input_for_kind(input, kind)?; - return Ok(input); - } - - let input = label - .index() + let input = self + .route_index(label.as_str()) .ok_or_else(|| PharmsolError::UnknownInputLabel { label: label.to_string(), })?; @@ -417,17 +412,11 @@ impl SharedNativeModel { } fn resolve_output_label(&self, label: &OutputLabel) -> Result { - if let Some(outeq) = self.output_index(label.as_str()) { - return Ok(outeq); - } - - let outeq = label - .index() - .ok_or_else(|| PharmsolError::UnknownOutputLabel { + self.output_index(label.as_str()).ok_or_else(|| { + PharmsolError::UnknownOutputLabel { label: label.to_string(), - })?; - self.validate_output(outeq)?; - Ok(outeq) + } + }) } fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1d49d82a..1d8a1327 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -376,12 +376,96 @@ fn runtime_model_from_parts( mod tests { use super::*; use crate::dsl::compile_sde_model_to_jit; + use crate::PharmsolError; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; use crate::SubjectBuilderExt; use approx::assert_relative_eq; use pharmsol_dsl::{DiagnosticPhase, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; use tempfile::tempdir; + const MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL: &str = r#" +name = multi_digit_output_runtime +kind = ode + +params = ke, v +states = central +outputs = 2, 10, 11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(10) = central / v ~ continuous() +out(2) = central / v ~ continuous() +out(11) = central / v ~ continuous() +"#; + + const NUMERIC_ROUTE_LABELS_RUNTIME_DSL: &str = r#" +name = numeric_route_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(10) -> central +bolus(11) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + fn corpus_source() -> &'static str { STRUCTURED_BLOCK_CORPUS } @@ -397,17 +481,17 @@ mod tests { pharmsol_dsl::lower_typed_model(model).expect("lower corpus model") } - fn ode_subject(output: usize, oral: usize, iv: usize) -> Subject { + fn ode_subject() -> Subject { Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, output) - .missing_observation(1.0, output) - .missing_observation(2.0, output) - .missing_observation(6.0, output) - .missing_observation(7.0, output) - .missing_observation(9.0, output) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build() } @@ -421,6 +505,80 @@ mod tests { .collect() } + fn compile_runtime_backend_matrix( + source: &str, + model_name: &str, + work_dir: &std::path::Path, + ) -> (CompiledRuntimeModel, CompiledRuntimeModel, CompiledRuntimeModel) { + let jit = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Jit, + |_, _| {}, + ) + .expect("compile jit runtime model"); + let aot = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::NativeAot( + NativeAotCompileOptions::new(work_dir.join(format!("{model_name}-aot-build"))) + .with_output(work_dir.join(format!("{model_name}.pkm"))), + ), + |_, _| {}, + ) + .expect("compile aot runtime model"); + let wasm = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Wasm, + |_, _| {}, + ) + .expect("compile wasm runtime model"); + + (jit, aot, wasm) + } + + fn numeric_route_subject() -> Subject { + Subject::builder("numeric-route-runtime") + .bolus(0.0, 120.0, "10") + .bolus(1.0, 80.0, "11") + .missing_observation(0.5, "cp") + .missing_observation(1.5, "cp") + .build() + } + + fn assert_unknown_output_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == expected_label + )); + } + + fn assert_unknown_input_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == expected_label + )); + } + #[test] fn runtime_backend_matrix_matches_ode_predictions() { let work_dir = tempdir().expect("tempdir"); @@ -460,10 +618,73 @@ mod tests { vec!["ka", "cl", "v", "tlag", "f_oral"] ); - let oral = jit.route_index("oral").expect("oral route"); - let iv = jit.route_index("iv").expect("iv route"); - let cp = jit.output_index("cp").expect("cp output"); - let subject = ode_subject(cp, oral, iv); + assert!(jit.route_index("oral").is_some()); + assert!(jit.route_index("iv").is_some()); + assert_eq!(jit.output_index("cp"), Some(0)); + let subject = ode_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_preserves_multi_digit_output_label_order() { + let work_dir = tempdir().expect("tempdir"); + let (jit, aot, wasm) = compile_runtime_backend_matrix( + MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL, + "multi_digit_output_runtime", + work_dir.path(), + ); + + assert_eq!(jit.output_index("2"), Some(0)); + assert_eq!(jit.output_index("10"), Some(1)); + assert_eq!(jit.output_index("11"), Some(2)); + assert_eq!(aot.output_index("2"), Some(0)); + assert_eq!(aot.output_index("10"), Some(1)); + assert_eq!(aot.output_index("11"), Some(2)); + assert_eq!(wasm.output_index("2"), Some(0)); + assert_eq!(wasm.output_index("10"), Some(1)); + assert_eq!(wasm.output_index("11"), Some(2)); + } + + #[test] + fn runtime_backend_matrix_supports_multi_digit_numeric_route_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "numeric_route_runtime", + work_dir.path(), + ); + + assert_eq!(jit.route_index("10"), Some(0)); + assert_eq!(jit.route_index("11"), Some(1)); + assert_eq!(aot.route_index("10"), Some(0)); + assert_eq!(aot.route_index("11"), Some(1)); + assert_eq!(wasm.route_index("10"), Some(0)); + assert_eq!(wasm.route_index("11"), Some(1)); + + let subject = numeric_route_subject(); let jit_values = subject_values( &jit.estimate_predictions(&subject, &support) @@ -489,6 +710,44 @@ mod tests { } } + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_output_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + + assert_unknown_output_label(&jit, &subject, &support, "10"); + assert_unknown_output_label(&aot, &subject, &support, "10"); + assert_unknown_output_label(&wasm, &subject, &support, "10"); + } + + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_input_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_input_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + + assert_unknown_input_label(&jit, &subject, &support, "10"); + assert_unknown_input_label(&aot, &subject, &support, "10"); + assert_unknown_input_label(&wasm, &subject, &support, "10"); + } + #[test] fn runtime_compile_preserves_parse_diagnostic_structure() { let source = "model broken { kind ode outputs { cp = 1 + } }"; diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index 37a5891a..c7164d71 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -53,6 +53,59 @@ dx(central) = ka * depot - (cl / v) * central out(cp) = central / (v * (wt / 70.0)) ~ continuous() "#; +const ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL: &str = r#" +name = multi_digit_output_order +kind = ode + +params = ke, v +states = central +outputs = 2, 10, 11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(10) = central / v ~ continuous() +out(2) = central / v ~ continuous() +out(11) = central / v ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL: &str = r#" +name = authoring_numeric_routes +kind = ode + +states = first, second +outputs = cp + +bolus(10) -> first +bolus(11) -> second + +dx(first) = 0 +dx(second) = 0 + +out(cp) = first + second ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL: &str = r#"model structured_numeric_routes { + kind ode + states { + first, + second, + } + routes { + 10 -> first + 11 -> second + } + dynamics { + ddt(first) = 0 + ddt(second) = 0 + } + outputs { + cp = first + second + } +} +"#; + const ODE_INVALID_INFUSION_LAG_DSL: &str = r#" name = invalid_infusion_lag_parity kind = ode @@ -107,6 +160,58 @@ out(0) = 2 * central / v ~ continuous() out(1) = 3 * central / v ~ continuous() "#; +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + const ANALYTICAL_DSL: &str = r#" name = one_cmt_abs_parity kind = analytical @@ -1056,6 +1161,93 @@ fn ode_dsl_and_handwritten_metadata_agree_on_public_shape() { assert_eq!(handwritten_view, dsl_view); } +#[test] +fn ode_dsl_declared_output_order_controls_dense_indices_for_multi_digit_labels() { + let dsl_view = dsl_metadata_view(ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL); + + assert_eq!( + dsl_view.outputs, + vec![ + NamedIndex { + name: "2".to_string(), + index: 0, + }, + NamedIndex { + name: "10".to_string(), + index: 1, + }, + NamedIndex { + name: "11".to_string(), + index: 2, + }, + ] + ); +} + +#[test] +fn ode_authoring_dsl_supports_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "10".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "11".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + +#[test] +fn ode_structured_dsl_supports_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "10".to_string(), + kind: None, + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "11".to_string(), + kind: None, + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + #[test] fn ode_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { let handwritten = handwritten_ode_macro_model(); @@ -1418,3 +1610,49 @@ fn ode_runtime_jit_preserves_mixed_output_labels() { assert_relative_eq!(predictions[1], 2.0 * predictions[0], epsilon = 1e-6); assert_relative_eq!(predictions[2], 3.0 * predictions[0], epsilon = 1e-6); } + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_output_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL, + "undeclared_numeric_output_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == "10" + )); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_input_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL, + "undeclared_numeric_input_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == "10" + )); +} diff --git a/tests/full_feature_dsl_backend_parity.rs b/tests/full_feature_dsl_backend_parity.rs new file mode 100644 index 00000000..1aba8213 --- /dev/null +++ b/tests/full_feature_dsl_backend_parity.rs @@ -0,0 +1,201 @@ +#[path = "support/runtime_corpus.rs"] +mod runtime_corpus; + +#[cfg(all(feature = "dsl-jit", feature = "dsl-wasm"))] +mod tests { + use super::runtime_corpus::{self as corpus, CorpusCase}; + use pharmsol::dsl::{CompiledRuntimeModel, RuntimeBackend}; + + fn owned_names(names: &[&str]) -> Vec { + names.iter().map(|name| (*name).to_owned()).collect() + } + + fn assert_info_matches( + left_label: &str, + left: &CompiledRuntimeModel, + right_label: &str, + right: &CompiledRuntimeModel, + ) { + assert_eq!( + left.info(), + right.info(), + "{left_label} model info diverged from {right_label}" + ); + } + + fn assert_ode_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "ode_full_feature_parity"); + assert_eq!(info.parameters, owned_names(&[ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ])); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes.iter().map(|route| route.index).collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + assert_eq!(model.route_index("oral"), Some(0)); + assert_eq!(model.route_index("load"), Some(1)); + assert_eq!(model.route_index("iv"), Some(0)); + assert_eq!(model.output_index("cp"), Some(0)); + } + + fn assert_analytical_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "analytical_full_feature_parity"); + assert_eq!(info.parameters, owned_names(&[ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ])); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes.iter().map(|route| route.index).collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + assert_eq!(model.route_index("oral"), Some(0)); + assert_eq!(model.route_index("load"), Some(1)); + assert_eq!(model.route_index("iv"), Some(0)); + assert_eq!(model.output_index("cp"), Some(0)); + } + + fn assert_full_backend_parity( + case: CorpusCase, + assert_public_shape: fn(&CompiledRuntimeModel), + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(case)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + assert_public_shape(&jit); + corpus::assert_runtime_model_matches_reference(case, "runtime-jit", &jit)?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(case, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + assert_public_shape(&aot); + corpus::assert_runtime_model_matches_reference(case, "runtime-native-aot", &aot)?; + } + + let wasm = corpus::compile_runtime_wasm_model(case)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + assert_public_shape(&wasm); + corpus::assert_runtime_model_matches_reference(case, "runtime-wasm", &wasm)?; + + assert_info_matches("runtime-jit", &jit, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_info_matches("runtime-jit", &jit, "runtime-native-aot", &aot); + assert_info_matches("runtime-native-aot", &aot, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-native-aot", + &aot, + )?; + corpus::assert_runtime_models_match_each_other( + case, + "runtime-native-aot", + &aot, + "runtime-wasm", + &wasm, + )?; + } + + Ok(()) + } + + #[test] + fn ode_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity(CorpusCase::OdeFull, assert_ode_full_public_shape) + } + + #[test] + fn analytical_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity(CorpusCase::AnalyticalFull, assert_analytical_full_public_shape) + } +} \ No newline at end of file diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index 4c82be4f..6e7e5f8e 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -55,6 +55,14 @@ fn subject_for_indices(route_index: usize, output_index: usize) -> Subject { builder.build() } +fn subject_for_labels(route_label: &str, output_label: &str) -> Subject { + let mut builder = Subject::builder(MODEL_NAME).infusion(0.0, 500.0, route_label, 0.5); + for time in OBSERVATION_TIMES { + builder = builder.missing_observation(time, output_label); + } + builder.build() +} + pub fn subject() -> Subject { subject_for_indices(0, 0) } @@ -65,12 +73,15 @@ pub fn subject() -> Subject { feature = "dsl-wasm" ))] pub fn subject_for_runtime_model(model: &pharmsol::dsl::CompiledRuntimeModel) -> Subject { - let route_index = model - .route_index("iv") - .or_else(|| model.route_index("input_0")) - .expect("bimodal_ke route is available"); - let output_index = model.output_index("cp").expect("cp output is available"); - subject_for_indices(route_index, output_index) + let route_label = if model.route_index("iv").is_some() { + "iv" + } else if model.route_index("input_0").is_some() { + "input_0" + } else { + panic!("bimodal_ke route is available"); + }; + model.output_index("cp").expect("cp output is available"); + subject_for_labels(route_label, "cp") } pub fn reference_values() -> Result, Box> { diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 3ed75511..1ca8ae78 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -208,52 +208,52 @@ impl CorpusCase { } fn runtime_subject(self, model: &CompiledRuntimeModel) -> Result> { - let cp = model + model .output_index("cp") .ok_or_else(|| io::Error::other(format!("{}: missing cp output", self.label())))?; let subject = match self { Self::Ode => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build() } Self::OdeFull => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let load = model.route_index("load").ok_or_else(|| { + model.route_index("load").ok_or_else(|| { io::Error::other(format!("{}: missing load route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 80.0, load) - .bolus(1.0, 120.0, oral) - .infusion(6.0, 150.0, iv, 2.5) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -261,39 +261,39 @@ impl CorpusCase { .build() } Self::Analytical => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build() } Self::AnalyticalFull => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let load = model.route_index("load").ok_or_else(|| { + model.route_index("load").ok_or_else(|| { io::Error::other(format!("{}: missing load route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 60.0, load) - .bolus(1.0, 100.0, oral) - .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -301,16 +301,16 @@ impl CorpusCase { .build() } Self::Sde => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build() } }; From ecf05753b1db510e9cffbe873fdb480b3561a9a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:22:23 +0100 Subject: [PATCH 11/22] chore: fmt --- pharmsol-dsl/src/authoring.rs | 6 +-- pharmsol-dsl/src/parser.rs | 6 +-- src/dsl/native.rs | 17 +++--- src/dsl/runtime.rs | 8 ++- tests/full_feature_dsl_backend_parity.rs | 67 +++++++++++++++--------- 5 files changed, 59 insertions(+), 45 deletions(-) diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 7b0b4dd6..0496c0fc 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -949,11 +949,7 @@ fn parse_output_label_segment(src: &str, abs_start: usize) -> Result Result { +fn parse_label_segment(src: &str, abs_start: usize, expected: &str) -> Result { let trimmed = src.trim(); let leading = src.len() - src.trim_start().len(); if trimmed.is_empty() { diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index fe844c37..98c6b0a4 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -906,9 +906,9 @@ impl Parser { } fn parse_label_name(&mut self, expected: &str) -> Result { - let token = self - .bump() - .ok_or_else(|| ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)))?; + let token = self.bump().ok_or_else(|| { + ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)) + })?; match token.kind { TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), TokenKind::Number(value) diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 6df2f05d..97c41013 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -402,21 +402,20 @@ impl SharedNativeModel { label: &InputLabel, kind: RouteKind, ) -> Result { - let input = self - .route_index(label.as_str()) - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: label.to_string(), - })?; + let input = + self.route_index(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; self.validate_input_for_kind(input, kind)?; Ok(input) } fn resolve_output_label(&self, label: &OutputLabel) -> Result { - self.output_index(label.as_str()).ok_or_else(|| { - PharmsolError::UnknownOutputLabel { + self.output_index(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownOutputLabel { label: label.to_string(), - } - }) + }) } fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1d8a1327..ba6dd5cd 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -376,8 +376,8 @@ fn runtime_model_from_parts( mod tests { use super::*; use crate::dsl::compile_sde_model_to_jit; - use crate::PharmsolError; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; + use crate::PharmsolError; use crate::SubjectBuilderExt; use approx::assert_relative_eq; use pharmsol_dsl::{DiagnosticPhase, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; @@ -509,7 +509,11 @@ out(cp) = central / v ~ continuous() source: &str, model_name: &str, work_dir: &std::path::Path, - ) -> (CompiledRuntimeModel, CompiledRuntimeModel, CompiledRuntimeModel) { + ) -> ( + CompiledRuntimeModel, + CompiledRuntimeModel, + CompiledRuntimeModel, + ) { let jit = compile_module_source_to_runtime( source, Some(model_name), diff --git a/tests/full_feature_dsl_backend_parity.rs b/tests/full_feature_dsl_backend_parity.rs index 1aba8213..929e7243 100644 --- a/tests/full_feature_dsl_backend_parity.rs +++ b/tests/full_feature_dsl_backend_parity.rs @@ -27,18 +27,21 @@ mod tests { let info = model.info(); assert_eq!(info.name, "ode_full_feature_parity"); - assert_eq!(info.parameters, owned_names(&[ - "ka", - "ke", - "kcp", - "kpc", - "v", - "tlag", - "f_oral", - "base_depot", - "base_central", - "base_peripheral", - ])); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ]) + ); assert_eq!( info.covariates .iter() @@ -61,7 +64,10 @@ mod tests { vec![0, 1, 2] ); assert_eq!( - info.routes.iter().map(|route| route.index).collect::>(), + info.routes + .iter() + .map(|route| route.index) + .collect::>(), vec![0, 1, 0] ); assert_eq!( @@ -81,16 +87,19 @@ mod tests { let info = model.info(); assert_eq!(info.name, "analytical_full_feature_parity"); - assert_eq!(info.parameters, owned_names(&[ - "ka", - "ke", - "v", - "tlag", - "f_oral", - "base_gut", - "base_central", - "tvke", - ])); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + ); assert_eq!( info.covariates .iter() @@ -113,7 +122,10 @@ mod tests { vec![0, 1, 2] ); assert_eq!( - info.routes.iter().map(|route| route.index).collect::>(), + info.routes + .iter() + .map(|route| route.index) + .collect::>(), vec![0, 1, 0] ); assert_eq!( @@ -196,6 +208,9 @@ mod tests { #[test] fn analytical_full_feature_dsl_matches_handwritten_across_backends( ) -> Result<(), Box> { - assert_full_backend_parity(CorpusCase::AnalyticalFull, assert_analytical_full_public_shape) + assert_full_backend_parity( + CorpusCase::AnalyticalFull, + assert_analytical_full_public_shape, + ) } -} \ No newline at end of file +} From b65d795657cc537b6c4de2929d2e608bd7b8575c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:26:14 +0100 Subject: [PATCH 12/22] chore: update README.md --- README.md | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index d24aea87..cf9865d0 100644 --- a/README.md +++ b/README.md @@ -36,15 +36,12 @@ let analytical = analytical! { }, }; -let iv = analytical.route_index("iv").unwrap(); -let cp = analytical.output_index("cp").unwrap(); - let subject = Subject::builder("patient_001") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = analytical @@ -121,12 +118,12 @@ use pharmsol::prelude::*; use pharmsol::nca::NCAOptions; let subject = Subject::builder("patient_001") - .bolus(0.0, 100.0, 0) // 100 mg oral dose - .observation(0.5, 5.0, 0) - .observation(1.0, 10.0, 0) - .observation(2.0, 8.0, 0) - .observation(4.0, 4.0, 0) - .observation(8.0, 2.0, 0) + .bolus(0.0, 100.0, "oral") // 100 mg oral dose + .observation(0.5, 5.0, "cp") + .observation(1.0, 10.0, "cp") + .observation(2.0, 8.0, "cp") + .observation(4.0, 4.0, "cp") + .observation(8.0, 2.0, "cp") .build(); let result = subject.nca(&NCAOptions::default()).expect("NCA failed"); From 627085bf7d5fedd9f5a4e56b141295564ae20014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:52:07 +0100 Subject: [PATCH 13/22] chore: {}->[] :D --- README.md | 8 +- examples/analytical_readme.rs | 4 +- examples/analytical_vs_ode.rs | 32 +-- examples/compare_solvers.rs | 4 +- examples/covariates.rs | 4 +- examples/macro_vs_handwritten_one_cpt.rs | 4 +- examples/macro_vs_handwritten_two_cpt.rs | 4 +- examples/ode_readme.rs | 4 +- examples/one_compartment.rs | 8 +- examples/sde_readme.rs | 4 +- examples/two_compartment.rs | 4 +- pharmsol-macros/src/lib.rs | 305 ++++++++--------------- tests/analytical_macro_lowering.rs | 16 +- tests/authoring_parity_corpus.rs | 34 +-- tests/full_feature_macro_parity.rs | 20 +- tests/ode_macro_lowering.rs | 131 +++------- tests/sde_macro_lowering.rs | 16 +- 17 files changed, 213 insertions(+), 389 deletions(-) diff --git a/README.md b/README.md index cf9865d0..73932de2 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,9 @@ let analytical = analytical! { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -61,9 +61,9 @@ let ode = ode! { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs index 8e5b97f7..676f07b9 100644 --- a/examples/analytical_readme.rs +++ b/examples/analytical_readme.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 290d6632..3fd58fd1 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -72,9 +72,9 @@ fn one_cmt_iv(params: &[f64]) { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -86,9 +86,9 @@ fn one_cmt_iv(params: &[f64]) { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, @@ -114,9 +114,9 @@ fn one_cmt_oral(params: &[f64]) { params: [ka, ke, v], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: one_compartment_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -128,9 +128,9 @@ fn one_cmt_oral(params: &[f64]) { params: [ka, ke, v], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -157,9 +157,9 @@ fn two_cmt_iv(params: &[f64]) { params: [ke, k12, k21, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: two_compartments, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -171,9 +171,9 @@ fn two_cmt_iv(params: &[f64]) { params: [ke, k12, k21, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - k12 * x[central] + k21 * x[peripheral]; dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; @@ -200,9 +200,9 @@ fn two_cmt_oral(params: &[f64]) { params: [ka, ke, k12, k21, v], states: [gut, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: two_compartments_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -214,9 +214,9 @@ fn two_cmt_oral(params: &[f64]) { params: [ka, ke, k12, k21, v], states: [gut, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central] - k12 * x[central] + k21 * x[peripheral]; diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index ad705931..a8067485 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -24,10 +24,10 @@ fn two_cpt(solver: OdeSolver) -> equation::ODE { params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(load) -> central, infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; diff --git a/examples/covariates.rs b/examples/covariates.rs index 9aabf491..180a0173 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -7,9 +7,9 @@ fn main() { covariates: [creatinine, age], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _t, dx| { let scaled_ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index be9edb2a..c7b088a5 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -12,9 +12,9 @@ fn macro_model() -> equation::ODE { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 114024bd..377e1e88 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -13,10 +13,10 @@ fn macro_model() -> equation::ODE { params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(load) -> central, infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index a0174801..7b436d0b 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index aafdf2b2..021e06f2 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -20,9 +20,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs index 6106b17a..cc47cdda 100644 --- a/examples/sde_readme.rs +++ b/examples/sde_readme.rs @@ -7,9 +7,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { states: [central], outputs: [cp], particles: 16, - routes: { + routes: [ infusion(iv) -> central, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index 64d554af..fdba715e 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -27,9 +27,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { covariates: [wt], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _t, dx| { // CL: Clearance (L/hr), V: Central volume (L) // Vp: Peripheral volume (L), Q: Inter-compartmental clearance (L/hr) diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 0d143184..7e483951 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -27,7 +27,6 @@ struct OdeInput { states: Vec, outputs: Vec, routes: Vec, - diffeq_mode: OdeDiffeqMode, diffeq: ExprClosure, lag: Option, fa: Option, @@ -66,12 +65,6 @@ struct SdeInput { out: ExprClosure, } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum OdeDiffeqMode { - InjectedRouteInputs, - ExplicitRouteVectors, -} - struct OdeRouteDecl { kind: OdeRouteKind, input: SymbolicIndex, @@ -275,7 +268,7 @@ impl Parse for OdeInput { let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?; let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?; let out = out.ok_or_else(|| missing_required_ode_field("out"))?; - let diffeq_mode = classify_diffeq_mode(&diffeq, &routes)?; + validate_ode_diffeq_uses_automatic_injection(&diffeq, &routes)?; validate_unique_idents("parameter", ¶ms, "ode!")?; validate_unique_idents("covariate", &covariates, "ode!")?; @@ -300,7 +293,6 @@ impl Parse for OdeInput { init: init.as_ref(), out: &out, }, - diffeq_mode, }, )?; @@ -311,7 +303,6 @@ impl Parse for OdeInput { states, outputs, routes, - diffeq_mode, diffeq, lag, fa, @@ -694,8 +685,18 @@ fn parse_symbolic_index_list(input: ParseStream) -> syn::Result syn::Result> { + if input.peek(token::Brace) { + return Err(input.error("declaration-first macro `routes` must use `[...]`, not `{...}`")); + } + + if !input.peek(token::Bracket) { + return Err( + input.error("expected a bracketed route list like `routes: [infusion(iv) -> central]`") + ); + } + let content; - syn::braced!(content in input); + syn::bracketed!(content in input); Ok( Punctuated::::parse_terminated(&content)? .into_iter() @@ -1063,13 +1064,12 @@ fn generate_covariate_bindings( } } -fn classify_diffeq_mode( +fn validate_ode_diffeq_uses_automatic_injection( diffeq: &ExprClosure, routes: &[OdeRouteDecl], -) -> syn::Result { +) -> syn::Result<()> { match closure_param_names(diffeq).len() { - 3 => Ok(OdeDiffeqMode::InjectedRouteInputs), - 7 => Ok(OdeDiffeqMode::ExplicitRouteVectors), + 3 => Ok(()), 5 => { let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref()); let route_inputs = route_input_idents(routes); @@ -1082,14 +1082,17 @@ fn classify_diffeq_mode( .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident)); if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input { - Ok(OdeDiffeqMode::ExplicitRouteVectors) + Err(syn::Error::new_spanned( + diffeq, + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx| and remove manual `bolus[...]` / `rateiv[...]` terms", + )) } else { - Ok(OdeDiffeqMode::InjectedRouteInputs) + Ok(()) } } _ => Err(syn::Error::new_spanned( diffeq, - "declaration-first `ode!` requires `diffeq` to have either 3 parameters: |x, t, dx|, 5 parameters: |x, p, t, dx, cov| or |x, t, dx, bolus, rateiv|, or 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", )), } } @@ -1214,7 +1217,6 @@ struct AnalyticalBindingClosures<'a> { struct OdeBindingClosures<'a> { diffeq: &'a ExprClosure, common: CommonBindingClosures<'a>, - diffeq_mode: OdeDiffeqMode, } #[derive(Clone, Copy)] @@ -1238,7 +1240,6 @@ fn validate_named_binding_compatibility( let OdeBindingClosures { diffeq, common: CommonBindingClosures { lag, fa, init, out }, - diffeq_mode, } = closures; let route_inputs = route_input_idents(routes); @@ -1289,31 +1290,6 @@ fn validate_named_binding_compatibility( validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?; validate_closure_param_conflicts("diffeq", diffeq, states, "state")?; - if diffeq_mode == OdeDiffeqMode::ExplicitRouteVectors { - validate_binding_conflicts( - "parameter", - params, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_binding_conflicts( - "state", - states, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_binding_conflicts( - "covariate", - covariates, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_closure_param_conflicts("diffeq", diffeq, &route_inputs, "route")?; - } - if let Some(lag) = lag { validate_binding_conflicts( "covariate", @@ -1881,7 +1857,6 @@ fn expand_ode_init( fn expand_route_metadata( routes: &[OdeRouteDecl], - diffeq_mode: OdeDiffeqMode, lag_routes: &HashSet, fa_routes: &HashSet, ) -> Vec { @@ -1899,10 +1874,6 @@ fn expand_route_metadata( quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } } }; - let input_policy = match diffeq_mode { - OdeDiffeqMode::InjectedRouteInputs => quote! { .inject_input_to_destination() }, - OdeDiffeqMode::ExplicitRouteVectors => quote! { .expect_explicit_input() }, - }; let lag_flag = if lag_routes.contains(&route_name) { quote! { .with_lag() } } else { @@ -1919,7 +1890,7 @@ fn expand_route_metadata( .to_state(stringify!(#destination)) #lag_flag #fa_flag - #input_policy + .inject_input_to_destination() } }) .collect() @@ -2151,148 +2122,64 @@ fn expand_diffeq( states: &[Ident], routes: &[OdeRouteDecl], route_bindings: &[(SymbolicIndex, usize)], - diffeq_mode: OdeDiffeqMode, ) -> syn::Result { let state_consts = generate_index_consts(states); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( + diffeq, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); + let body = &diffeq.body; + let dx_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) + }; + let route_terms = expand_injected_ode_route_terms( + routes, + states, + route_bindings, + &dx_binding, + &bolus, + &rateiv, + ); - match diffeq_mode { - OdeDiffeqMode::ExplicitRouteVectors => { - let route_consts = generate_mapped_index_consts(route_bindings); - let x = generated_ident("__pharmsol_x"); - let p = generated_ident("__pharmsol_p"); - let t = generated_ident("__pharmsol_t"); - let dx = generated_ident("__pharmsol_dx"); - let bolus = generated_ident("__pharmsol_bolus"); - let rateiv = generated_ident("__pharmsol_rateiv"); - let cov = generated_ident("__pharmsol_cov"); - let full_inputs = [ - x.clone(), - p.clone(), - t.clone(), - dx.clone(), - bolus.clone(), - rateiv.clone(), - cov.clone(), - ]; - let reduced_inputs = [ - x.clone(), - t.clone(), - dx.clone(), - bolus.clone(), - rateiv.clone(), - ]; - let input_aliases = generate_supported_input_aliases( - diffeq, - &[&full_inputs, &reduced_inputs], - "declaration-first `ode!` explicit-route `diffeq` requires either 7 parameters: |x, p, t, dx, bolus, rateiv, cov| or 5 parameters: |x, t, dx, bolus, rateiv|", - )?; - let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); - let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let bolus_binding = if diffeq.inputs.len() == full_inputs.len() { - closure_param_ident(diffeq, 4).unwrap_or_else(|| bolus.clone()) - } else { - closure_param_ident(diffeq, 3).unwrap_or_else(|| bolus.clone()) - }; - let rateiv_binding = if diffeq.inputs.len() == full_inputs.len() { - closure_param_ident(diffeq, 5).unwrap_or_else(|| rateiv.clone()) - } else { - closure_param_ident(diffeq, 4).unwrap_or_else(|| rateiv.clone()) - }; - let route_label_map = symbolic_numeric_binding_map(route_bindings); - let body = NumericLabelRewriter::rewrite( - diffeq.body.as_ref(), - vec![ - IndexRewriteTarget::new(bolus_binding, route_label_map.clone()), - IndexRewriteTarget::new(rateiv_binding, route_label_map), - ], - None, - ); - - Ok(quote! {{ - let __pharmsol_diffeq: fn( - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - f64, - &mut ::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::data::Covariates, - ) = |#x: &::pharmsol::simulator::V, - #p: &::pharmsol::simulator::V, - #t: f64, - #dx: &mut ::pharmsol::simulator::V, - #bolus: &::pharmsol::simulator::V, - #rateiv: &::pharmsol::simulator::V, - #cov: &::pharmsol::data::Covariates| { - #input_aliases - #state_consts - #route_consts - #parameter_bindings - #covariate_bindings - #body - }; - __pharmsol_diffeq - }}) - } - OdeDiffeqMode::InjectedRouteInputs => { - let x = generated_ident("__pharmsol_x"); - let p = generated_ident("__pharmsol_p"); - let t = generated_ident("__pharmsol_t"); - let dx = generated_ident("__pharmsol_dx"); - let bolus = generated_ident("__pharmsol_bolus"); - let rateiv = generated_ident("__pharmsol_rateiv"); - let cov = generated_ident("__pharmsol_cov"); - let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; - let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; - let input_aliases = generate_supported_input_aliases( - diffeq, - &[&full_inputs, &reduced_inputs], - "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", - )?; - let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); - let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let body = &diffeq.body; - let dx_binding = if diffeq.inputs.len() == full_inputs.len() { - closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) - } else { - closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) - }; - let route_terms = expand_injected_ode_route_terms( - routes, - states, - route_bindings, - &dx_binding, - &bolus, - &rateiv, - ); - - Ok(quote! {{ - let __pharmsol_diffeq: fn( - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - f64, - &mut ::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::data::Covariates, - ) = |#x: &::pharmsol::simulator::V, - #p: &::pharmsol::simulator::V, - #t: f64, - #dx: &mut ::pharmsol::simulator::V, - #bolus: &::pharmsol::simulator::V, - #rateiv: &::pharmsol::simulator::V, - #cov: &::pharmsol::data::Covariates| { - #input_aliases - #state_consts - #parameter_bindings - #covariate_bindings - #body - #route_terms - }; - __pharmsol_diffeq - }}) - } - } + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + #route_terms + }; + __pharmsol_diffeq + }}) } fn resolve_analytical_structure(structure: &Ident) -> syn::Result { @@ -2883,7 +2770,6 @@ pub fn ode(input: TokenStream) -> TokenStream { &input.states, &input.routes, &route_bindings, - input.diffeq_mode, ) { Ok(diffeq) => diffeq, Err(error) => return error.to_compile_error().into(), @@ -2909,7 +2795,7 @@ pub fn ode(input: TokenStream) -> TokenStream { let covariates = &input.covariates; let states = &input.states; let outputs = &input.outputs; - let routes = expand_route_metadata(&input.routes, input.diffeq_mode, &lag_routes, &fa_routes); + let routes = expand_route_metadata(&input.routes, &lag_routes, &fa_routes); let covariate_metadata = if covariates.is_empty() { quote! {} } else { @@ -3339,7 +3225,7 @@ mod tests { #[test] fn validates_route_destinations() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> peripheral }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> peripheral], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown route destination must fail"); @@ -3352,7 +3238,7 @@ mod tests { #[test] fn rejects_named_binding_collisions() { let error = syn::parse_str::( - "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: [infusion(iv) -> central], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("parameter/state binding collisions must fail"); @@ -3365,7 +3251,7 @@ mod tests { #[test] fn ode_route_bindings_share_inputs_by_kind_local_ordinal() { let input = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: { bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot }, diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: [bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot], diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", ) .expect("declaration-first ode input should parse"); @@ -3416,7 +3302,7 @@ mod tests { #[test] fn analytical_accepts_extra_parameters_beyond_kernel_arity() { let input = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", ) .expect("extra declared parameters should be allowed"); @@ -3429,7 +3315,7 @@ mod tests { #[test] fn analytical_rejects_unknown_structure() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: mystery, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: mystery, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown analytical structure must fail"); @@ -3442,7 +3328,7 @@ mod tests { #[test] fn analytical_rejects_insufficient_kernel_parameters() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", ) .err() .expect("insufficient kernel parameters must fail"); @@ -3455,7 +3341,7 @@ mod tests { #[test] fn analytical_rejects_unknown_route_property_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown lag route must fail"); @@ -3468,7 +3354,7 @@ mod tests { #[test] fn analytical_rejects_infusion_lag_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", ) .err() .expect("infusion lag must fail"); @@ -3481,7 +3367,7 @@ mod tests { #[test] fn sde_requires_particles() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: [infusion(iv) -> central], drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("missing particles must fail"); @@ -3494,7 +3380,7 @@ mod tests { #[test] fn sde_rejects_unknown_route_property_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown lag route must fail"); @@ -3507,7 +3393,7 @@ mod tests { #[test] fn sde_rejects_infusion_lag_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", ) .err() .expect("infusion lag must fail"); @@ -3516,4 +3402,17 @@ mod tests { .to_string() .contains("declaration-first `sde!` does not allow `lag` on infusion route `iv`")); } + + #[test] + fn rejects_braced_route_lists() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("braced route lists must fail"); + + assert!(error + .to_string() + .contains("declaration-first macro `routes` must use `[...]`, not `{...}`")); + } } diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index 796cb55e..f527978f 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -56,9 +56,9 @@ fn macro_one_compartment() -> equation::Analytical { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _t, y| { y[cp] = x[central] / v; @@ -99,9 +99,9 @@ fn macro_one_compartment_with_absorption() -> equation::Analytical { params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: one_compartment_with_absorption, lag: |_t| { lag! { oral => tlag } @@ -166,10 +166,10 @@ fn macro_shared_input_analytical() -> equation::Analytical { params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, lag: |_t| { lag! { oral => tlag } @@ -229,10 +229,10 @@ fn macro_covariate_analytical() -> equation::Analytical { covariates: [wt, renal], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, sec: |_t| { let wt_scale = (wt / 70.0).powf(0.75); diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index c7164d71..be80f10e 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -580,9 +580,9 @@ fn macro_ode_model() -> equation::ODE { covariates: [wt], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - (cl / v) * x[central]; @@ -676,13 +676,13 @@ fn runtime_shared_input_macro_ode() -> equation::ODE { params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, infusion(iv) -> central, - }, - diffeq: |x, _p, _t, dx, bolus, rateiv, _cov| { - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_p, _t, _cov| { lag! { oral => tlag } @@ -731,10 +731,10 @@ fn runtime_shared_input_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) .expect("handwritten shared-input ODE metadata should validate") @@ -791,10 +791,10 @@ fn runtime_shared_input_macro_analytical() -> equation::Analytical { params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { oral => tlag } @@ -856,10 +856,10 @@ fn runtime_shared_input_macro_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -976,9 +976,9 @@ fn macro_analytical_model() -> equation::Analytical { params: [ka, ke, v], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, - }, + ], structure: one_compartment_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -1019,9 +1019,9 @@ fn macro_sde_model() -> equation::SDE { states: [depot, central], outputs: [cp], particles: 256, - routes: { + routes: [ bolus(oral) -> depot, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - ke * x[central]; diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index e3175f84..5017902e 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -14,19 +14,19 @@ fn macro_ode_model() -> equation::ODE { covariates: [wt, renal], states: [depot, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, bolus(load) -> central, infusion(iv) -> central, - }, - diffeq: |x, _t, dx, bolus, rateiv| { + ], + diffeq: |x, _t, dx| { let wt_scale = (wt / 70.0).powf(0.75); let renal_scale = (renal / 90.0).powf(0.25); let adjusted_ke = ke * wt_scale * renal_scale; let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = bolus[load] + ka * x[depot] + rateiv[iv] + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - (adjusted_ke + adjusted_kcp) * x[central] + kpc * x[peripheral]; dx[peripheral] = adjusted_kcp * x[central] - kpc * x[peripheral]; @@ -185,13 +185,13 @@ fn handwritten_ode_model() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::bolus("load") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) .expect("handwritten ODE metadata should validate") @@ -224,11 +224,11 @@ fn macro_analytical_model() -> equation::Analytical { covariates: [wt, renal], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, bolus(load) -> central, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, sec: |_t| { let wt_scale = (wt / 70.0).powf(0.75); diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 480f7e80..99e0eeab 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -56,9 +56,9 @@ fn injected_macro_ode() -> equation::ODE { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _t, dx| { dx[central] = -ke * x[central]; }, @@ -99,66 +99,17 @@ fn injected_handwritten_ode() -> equation::ODE { .expect("handwritten injected metadata should validate") } -fn explicit_macro_ode() -> equation::ODE { - ode! { - name: "explicit_one_cpt", - params: [ke, v], - states: [central], - outputs: [cp], - routes: { - infusion(iv) -> central, - }, - diffeq: |x, _t, dx, _bolus, rateiv| { - dx[central] = rateiv[iv] - ke * x[central]; - }, - out: |x, _t, y| { - y[cp] = x[central] / v; - }, - } -} - -fn explicit_handwritten_ode() -> equation::ODE { - equation::ODE::new( - |x, p, _t, dx, _bolus, rateiv, _cov| { - fetch_params!(p, ke, _v); - dx[0] = rateiv[0] - ke * x[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(1) - .with_ndrugs(1) - .with_nout(1) - .with_metadata( - equation::metadata::new("explicit_one_cpt") - .parameters(["ke", "v"]) - .states(["central"]) - .outputs(["cp"]) - .route( - equation::Route::infusion("iv") - .to_state("central") - .expect_explicit_input(), - ), - ) - .expect("handwritten explicit metadata should validate") -} - fn numeric_label_macro_ode() -> equation::ODE { ode! { name: "numeric_label_one_cpt", params: [ke, v], states: [central], outputs: [1], - routes: { + routes: [ infusion(1) -> central, - }, - diffeq: |x, _t, dx, _bolus, rateiv| { - dx[central] = rateiv[1] - ke * x[central]; + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; }, out: |x, _t, y| { y[1] = x[central] / v; @@ -191,7 +142,7 @@ fn numeric_label_handwritten_ode() -> equation::ODE { .route( equation::Route::infusion("1") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ), ) .expect("handwritten numeric-label metadata should validate") @@ -203,13 +154,13 @@ fn shared_input_macro_ode() -> equation::ODE { params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, infusion(iv) -> central, - }, - diffeq: |x, _t, dx, bolus, rateiv| { - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_t| { lag! { oral => tlag } @@ -257,10 +208,10 @@ fn shared_input_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) .expect("handwritten shared-input metadata should validate") @@ -272,11 +223,11 @@ fn numeric_route_property_macro_ode() -> equation::ODE { params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [1], - routes: { + routes: [ bolus(1) -> depot, - }, - diffeq: |x, _t, dx, bolus, _rateiv| { - dx[depot] = bolus[1] - ka * x[depot]; + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_t| { @@ -325,7 +276,7 @@ fn numeric_route_property_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), ), ) .expect("handwritten numeric route-property metadata should validate") @@ -337,11 +288,11 @@ fn mixed_output_labels_macro_ode() -> equation::ODE { params: [ke, v], states: [central], outputs: [cp, 0, 1], - routes: { + routes: [ infusion(iv) -> central, - }, - diffeq: |x, _t, dx, _bolus, rateiv| { - dx[central] = rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; }, out: |x, _t, y| { y[cp] = x[central] / v; @@ -378,7 +329,7 @@ fn mixed_output_labels_handwritten_ode() -> equation::ODE { .route( equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ), ) .expect("handwritten mixed-output metadata should validate") @@ -391,9 +342,9 @@ fn covariate_macro_ode() -> equation::ODE { covariates: [wt], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _t, dx| { let scaled_ke = ke * (wt / 70.0); dx[gut] = -ka * x[gut]; @@ -473,32 +424,6 @@ fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { assert_prediction_match(¯o_predictions, &handwritten_predictions); } -#[test] -fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { - let macro_ode = explicit_macro_ode(); - let handwritten_ode = explicit_handwritten_ode(); - let subject = subject_for_route("iv", "cp"); - let support_point = [0.2, 10.0]; - - assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.route_index("iv"), Some(0)); - assert_eq!(macro_ode.output_index("cp"), Some(0)); - assert_eq!(macro_ode.state_index("central"), Some(0)); - - let macro_predictions = macro_ode - .estimate_predictions(&subject, &support_point) - .expect("macro explicit model should simulate") - .flat_predictions() - .to_vec(); - let handwritten_predictions = handwritten_ode - .estimate_predictions(&subject, &support_point) - .expect("handwritten explicit model should simulate") - .flat_predictions() - .to_vec(); - - assert_prediction_match(¯o_predictions, &handwritten_predictions); -} - #[test] fn macro_numeric_labels_lower_to_dense_slots() { let macro_ode = numeric_label_macro_ode(); @@ -622,12 +547,12 @@ fn macro_named_labels_resolve_from_pmetrics_ingestion() { let subject = &data.subjects()[0]; let support_point = [0.2, 10.0]; - let pmetrics_predictions = explicit_macro_ode() + let pmetrics_predictions = injected_macro_ode() .estimate_predictions(subject, &support_point) .expect("macro named-label model should simulate") .flat_predictions() .to_vec(); - let manual_predictions = explicit_macro_ode() + let manual_predictions = injected_macro_ode() .estimate_predictions(&subject_for_route("iv", "cp"), &support_point) .expect("macro internal-index model should simulate") .flat_predictions() diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 05b5cb27..474c7bab 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -73,9 +73,9 @@ fn macro_infusion_sde() -> equation::SDE { states: [central], outputs: [cp], particles: 16, - routes: { + routes: [ infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { dx[central] = -ke * x[central]; }, @@ -133,9 +133,9 @@ fn macro_absorption_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, - }, + ], drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -218,10 +218,10 @@ fn macro_shared_input_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -307,10 +307,10 @@ fn macro_covariate_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { let wt_scale = (wt / 70.0).powf(0.75); let renal_scale = (renal / 90.0).powf(0.25); From 8cac67c5cf8a806e03769e4ad8fc54e287baca90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 16:16:31 +0100 Subject: [PATCH 14/22] chore: update examples to use new route API --- examples/analytical_readme.rs | 13 +++++-------- examples/compare_solvers.rs | 34 +++++++++++++--------------------- examples/dsl_runtime_jit.rs | 24 ++++++++---------------- examples/ode_readme.rs | 13 +++++-------- examples/sde_readme.rs | 13 +++++-------- 5 files changed, 36 insertions(+), 61 deletions(-) diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs index 676f07b9..8451b478 100644 --- a/examples/analytical_readme.rs +++ b/examples/analytical_readme.rs @@ -15,15 +15,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = analytical.route_index("iv").expect("iv route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("analytical_readme") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = analytical.estimate_predictions(&subject, &[1.022, 194.0])?; diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index a8067485..58813081 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -51,29 +51,21 @@ fn main() { // Both declarations resolve to the same shared input, so subject // authoring still uses one numeric index for the loading bolus and the // maintenance infusion. - let load = bdf.route_index("load").expect("load route exists"); - let iv = bdf.route_index("iv").expect("iv route exists"); - let cp = bdf.output_index("cp").expect("cp output exists"); - - assert_eq!( - load, iv, - "mixed IV declarations should share one numeric input" - ); let subject = Subject::builder("id1") - .bolus(0.0, 100.0, iv) - .infusion(12.0, 200.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) - .missing_observation(12.5, cp) - .missing_observation(13.0, cp) - .missing_observation(14.0, cp) - .missing_observation(16.0, cp) - .missing_observation(24.0, cp) + .bolus(0.0, 100.0, "iv") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") .build(); let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V diff --git a/examples/dsl_runtime_jit.rs b/examples/dsl_runtime_jit.rs index 932acaae..3f7d1efe 100644 --- a/examples/dsl_runtime_jit.rs +++ b/examples/dsl_runtime_jit.rs @@ -43,24 +43,16 @@ out(cp) = central / v on_compile_event, )?; - // 2. Resolve the route and output indices declared by the model. - let iv = model - .route_index("iv") - .ok_or_else(|| io::Error::other("missing iv route"))?; - let cp = model - .output_index("cp") - .ok_or_else(|| io::Error::other("missing cp output"))?; - // 3. Define the subject data. let subject = Subject::builder("bimodal_ke") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(3.0, cp) - .missing_observation(4.0, cp) - .missing_observation(6.0, cp) - .missing_observation(8.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(3.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(8.0, "cp") .build(); // 4. Estimate predictions for one support point. diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index 7b436d0b..2989895f 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -17,15 +17,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = ode.route_index("iv").expect("iv route exists"); - let cp = ode.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("id1") - .infusion(0.0, 100.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 100.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = ode.estimate_predictions(&subject, &[1.022, 194.0])?; diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs index cc47cdda..97b5fed4 100644 --- a/examples/sde_readme.rs +++ b/examples/sde_readme.rs @@ -21,15 +21,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = sde.route_index("iv").expect("iv route exists"); - let cp = sde.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("sde_readme") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = sde.estimate_predictions(&subject, &[1.022, 0.0, 194.0])?; From 4b6d962dbe63437198e1180b982011ab59281791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 16:21:31 +0100 Subject: [PATCH 15/22] chore: Julian made a mistake :P --- examples/compare_solvers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index 58813081..5d8fdbb6 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -53,7 +53,7 @@ fn main() { // maintenance infusion. let subject = Subject::builder("id1") - .bolus(0.0, 100.0, "iv") + .bolus(0.0, 100.0, "load") .infusion(12.0, 200.0, "iv", 2.0) .missing_observation(0.5, "cp") .missing_observation(1.0, "cp") From 7b24e09715b7c1cf217f20ebe00bb05072d1028c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 20:42:01 +0100 Subject: [PATCH 16/22] feature: DSL supports numeric input/outeq. Implement Equation Trait for runtime environments. --- pharmsol-dsl/src/semantic.rs | 20 +- .../tests/dsl_authoring_edge_cases.rs | 71 +++++++ src/dsl/native.rs | 177 +++++++++++++++++- src/dsl/runtime.rs | 66 +++++++ 4 files changed, 328 insertions(+), 6 deletions(-) diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index ac9223dd..6a5e3b91 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -1617,11 +1617,13 @@ impl<'a> Analyzer<'a> { span, ))); } - if let Some(existing) = self.globals.all_names.get(name) { + if let Some(existing) = self.globals.all_names.get(name).copied() { + let existing_kind = self.symbols.get(existing).expect("valid symbol id").kind; + if !allows_route_output_name_overlap(existing_kind, kind) { return Err(SemanticAssist::default() .context_label( - self.symbol_span(*existing), - self.symbol_declared_here(*existing), + self.symbol_span(existing), + self.symbol_declared_here(existing), ) .help(format!( "rename this declaration to a unique name such as `{}_2`", @@ -1636,10 +1638,11 @@ impl<'a> Analyzer<'a> { .apply(SemanticError::new( format!( "symbol name `{name}` collides with existing `{}`", - self.symbol_name(*existing) + self.symbol_name(existing) ), span, ))); + } } let id = self.symbols.len(); self.symbols.push(PendingSymbol { @@ -1649,7 +1652,7 @@ impl<'a> Analyzer<'a> { ty, span, }); - self.globals.all_names.insert(name.to_string(), id); + self.globals.all_names.entry(name.to_string()).or_insert(id); Ok(id) } @@ -2132,6 +2135,13 @@ impl<'a> Analyzer<'a> { } } +fn allows_route_output_name_overlap(existing: SymbolKind, new: SymbolKind) -> bool { + matches!( + (existing, new), + (SymbolKind::Route, SymbolKind::Output) | (SymbolKind::Output, SymbolKind::Route) + ) +} + #[derive(Default)] struct Globals { all_names: BTreeMap, diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 404487dc..3f4cb494 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -209,6 +209,77 @@ out(1) = 3 * central / v assert_eq!(rendered, reparsed.to_string()); } +#[test] +fn shared_numeric_route_and_output_labels_lower_and_round_trip() { + let src = r#" +name = shared_numeric_route_output_labels +kind = ode +params = ke, v +states = central +outputs = 1 +infusion(1) -> central +ddt(central) = -ke * central +out(1) = central / v +"#; + + let module = parse_module(src).expect("shared numeric route/output labels should parse"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(model).expect("shared numeric route/output labels should analyze"); + let lowered = lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); + + assert_eq!( + lowered + .metadata + .routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["1"] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered shared-label model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + +#[test] +fn route_labels_still_collide_with_scalar_symbol_names() { + let src = r#" +name = route_state_collision +kind = ode +params = ke +states = central, iv +outputs = cp +infusion(iv) -> central +ddt(central) = -ke * central +ddt(iv) = 0 +out(cp) = central +"#; + + let model = parse_model(src).expect("route/state collision model parses"); + let err = analyze_model(&model).expect_err("route label should still collide with state name"); + let rendered = err.render(src); + + assert!( + rendered.contains("symbol name `iv` collides with existing `iv`"), + "{}", + rendered + ); +} + #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 97c41013..d9598172 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::HashMap; use std::sync::Arc; use diffsol::{ @@ -20,14 +21,17 @@ pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ + data::error_model::AssayErrorModels, data::{Covariates, Infusion, InputLabel, OutputLabel}, simulator::{ + cache::{PredictionCache, DEFAULT_CACHE_SIZE}, equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, sde::simulate_sde_event_with, + EqnKind, Equation, EquationPriv, EquationTypes, }, likelihood::{Prediction, SubjectPredictions}, - M, V, + Fa, Lag, M, T, V, }, Event, Observation, Occasion, PharmsolError, Subject, }; @@ -727,6 +731,7 @@ pub struct NativeOdeModel { solver: OdeSolver, rtol: f64, atol: f64, + cache: Option, } #[derive(Clone, Debug)] @@ -754,6 +759,7 @@ impl NativeOdeModel { solver: OdeSolver::default(), rtol: DEFAULT_ODE_RTOL, atol: DEFAULT_ODE_ATOL, + cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -1031,6 +1037,175 @@ impl NativeOdeModel { } } +fn runtime_no_lag(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +fn runtime_no_fa(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +#[inline(always)] +fn runtime_ode_predictions( + model: &NativeOdeModel, + subject: &Subject, + support_point: &[f64], +) -> Result { + if let Some(cache) = &model.cache { + let key = ( + subject.hash(), + crate::simulator::equation::spphash(support_point), + ); + if let Some(cached) = cache.get(&key) { + return Ok(cached); + } + + let result = model.estimate_predictions(subject, support_point)?; + cache.insert(key, result.clone()); + Ok(result) + } else { + model.estimate_predictions(subject, support_point) + } +} + +impl crate::simulator::equation::Cache for NativeOdeModel { + fn with_cache_capacity(mut self, size: u64) -> Self { + self.cache = Some(PredictionCache::new(size)); + self + } + + fn enable_cache(mut self) -> Self { + self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); + self + } + + fn clear_cache(&self) { + if let Some(cache) = &self.cache { + cache.invalidate_all(); + } + } + + fn disable_cache(mut self) -> Self { + self.cache = None; + self + } +} + +impl EquationTypes for NativeOdeModel { + type S = V; + type P = SubjectPredictions; +} + +impl EquationPriv for NativeOdeModel { + fn lag(&self) -> &Lag { + &(runtime_no_lag as Lag) + } + + fn fa(&self) -> &Fa { + &(runtime_no_fa as Fa) + } + + fn get_nstates(&self) -> usize { + self.shared.info.state_len + } + + fn get_ndrugs(&self) -> usize { + self.shared.info.route_len + } + + fn get_nouteqs(&self) -> usize { + self.shared.info.output_len + } + + fn metadata(&self) -> Option<&crate::ValidatedModelMetadata> { + None + } + + fn solve( + &self, + _state: &mut Self::S, + _support_point: &[f64], + _covariates: &Covariates, + _infusions: &[Infusion], + _start_time: f64, + _end_time: f64, + ) -> Result<(), PharmsolError> { + unimplemented!("solve is not used for runtime ODE models") + } + + fn process_observation( + &self, + _support_point: &[f64], + _observation: &Observation, + _error_models: Option<&AssayErrorModels>, + _time: f64, + _covariates: &Covariates, + _x: &mut Self::S, + _likelihood: &mut Vec, + _output: &mut Self::P, + ) -> Result<(), PharmsolError> { + unimplemented!("process_observation is not used for runtime ODE models") + } + + fn initial_state( + &self, + _support_point: &[f64], + _covariates: &Covariates, + _occasion_index: usize, + ) -> Self::S { + V::zeros(self.shared.info.state_len, NalgebraContext) + } +} + +impl Equation for NativeOdeModel { + fn estimate_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + Ok(self + .estimate_log_likelihood(subject, support_point, error_models)? + .exp()) + } + + fn estimate_log_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + let predictions = runtime_ode_predictions(self, subject, support_point)?; + predictions.log_likelihood(error_models) + } + + fn kind() -> EqnKind { + EqnKind::ODE + } + + fn estimate_predictions( + &self, + subject: &Subject, + support_point: &[f64], + ) -> Result { + runtime_ode_predictions(self, subject, support_point) + } + + fn simulate_subject( + &self, + subject: &Subject, + support_point: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::P, Option), PharmsolError> { + let predictions = runtime_ode_predictions(self, subject, support_point)?; + let likelihood = match error_models { + Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), + None => None, + }; + Ok((predictions, likelihood)) + } +} + impl NativeAnalyticalModel { pub(crate) fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { Self { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index ba6dd5cd..59c399ab 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -414,6 +414,21 @@ bolus(11) -> central dx(central) = -ke * central out(cp) = central / v ~ continuous() +"#; + + const SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = shared_numeric_route_output_runtime +kind = ode + +params = ke, v +states = central +outputs = 1 + +infusion(1) -> central + +dx(central) = -ke * central + +out(1) = central / v ~ continuous() "#; const UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" @@ -551,6 +566,14 @@ out(cp) = central / v ~ continuous() .build() } + fn shared_numeric_route_output_subject() -> Subject { + Subject::builder("shared-numeric-route-output-runtime") + .infusion(0.0, 120.0, "1", 1.0) + .missing_observation(0.5, "1") + .missing_observation(1.5, "1") + .build() + } + fn assert_unknown_output_label( model: &CompiledRuntimeModel, subject: &Subject, @@ -714,6 +737,49 @@ out(cp) = central / v ~ continuous() } } + #[test] + fn runtime_backend_matrix_supports_shared_numeric_route_and_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL, + "shared_numeric_route_output_runtime", + work_dir.path(), + ); + + assert_eq!(jit.route_index("1"), Some(0)); + assert_eq!(jit.output_index("1"), Some(0)); + assert_eq!(aot.route_index("1"), Some(0)); + assert_eq!(aot.output_index("1"), Some(0)); + assert_eq!(wasm.route_index("1"), Some(0)); + assert_eq!(wasm.output_index("1"), Some(0)); + + let subject = shared_numeric_route_output_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + #[test] fn runtime_backend_matrix_rejects_undeclared_numeric_output_labels() { let work_dir = tempdir().expect("tempdir"); From 6a067e9543e1938d6b5116534d93da59c0f2efd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 20:42:18 +0100 Subject: [PATCH 17/22] chore: fmt --- pharmsol-dsl/src/semantic.rs | 44 +++++++++---------- .../tests/dsl_authoring_edge_cases.rs | 3 +- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index 6a5e3b91..f20288a9 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -1620,28 +1620,28 @@ impl<'a> Analyzer<'a> { if let Some(existing) = self.globals.all_names.get(name).copied() { let existing_kind = self.symbols.get(existing).expect("valid symbol id").kind; if !allows_route_output_name_overlap(existing_kind, kind) { - return Err(SemanticAssist::default() - .context_label( - self.symbol_span(existing), - self.symbol_declared_here(existing), - ) - .help(format!( - "rename this declaration to a unique name such as `{}_2`", - name - )) - .replacement_suggestion( - span, - format!("{}_2", name), - format!("rename this declaration to `{}_2`", name), - Applicability::MaybeIncorrect, - ) - .apply(SemanticError::new( - format!( - "symbol name `{name}` collides with existing `{}`", - self.symbol_name(existing) - ), - span, - ))); + return Err(SemanticAssist::default() + .context_label( + self.symbol_span(existing), + self.symbol_declared_here(existing), + ) + .help(format!( + "rename this declaration to a unique name such as `{}_2`", + name + )) + .replacement_suggestion( + span, + format!("{}_2", name), + format!("rename this declaration to `{}_2`", name), + Applicability::MaybeIncorrect, + ) + .apply(SemanticError::new( + format!( + "symbol name `{name}` collides with existing `{}`", + self.symbol_name(existing) + ), + span, + ))); } } let id = self.symbols.len(); diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 3f4cb494..4d1651f5 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -228,7 +228,8 @@ out(1) = central / v .first() .expect("authoring DSL should produce one model"); let typed = analyze_model(model).expect("shared numeric route/output labels should analyze"); - let lowered = lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); + let lowered = + lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); assert_eq!( lowered From 444c031b65f2932365ec474d821e8b350f9acdc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Wed, 6 May 2026 20:21:42 +0100 Subject: [PATCH 18/22] chore: documentation --- src/data/builder.rs | 117 +++++++++++++++++++-------- src/data/event.rs | 102 +++++++++++++++++++----- src/data/mod.rs | 85 ++++++++++++++++---- src/data/parser/mod.rs | 12 +++ src/data/parser/pmetrics.rs | 59 +++++++++++--- src/data/row.rs | 121 +++++++++++++++++----------- src/dsl/aot.rs | 63 +++++++++++++++ src/dsl/jit.rs | 47 +++++++++++ src/dsl/mod.rs | 93 +++++++++++++++++++++- src/dsl/model_info.rs | 33 ++++++++ src/dsl/runtime.rs | 113 ++++++++++++++++++++++++++ src/dsl/wasm.rs | 28 +++++++ src/dsl/wasm_compile.rs | 55 +++++++++++++ src/lib.rs | 124 +++++++++++++++++++++++++++-- src/simulator/equation/metadata.rs | 103 +++++++++++++++++++++--- src/simulator/equation/mod.rs | 67 ++++++++++++++-- 16 files changed, 1066 insertions(+), 156 deletions(-) diff --git a/src/data/builder.rs b/src/data/builder.rs index a1718dc7..ed0a57a8 100644 --- a/src/data/builder.rs +++ b/src/data/builder.rs @@ -1,6 +1,21 @@ +//! Builder API for constructing [`Subject`] schedules in Rust. +//! +//! Use `Subject::builder(...)` when you want to describe a subject directly in +//! code with a schedule-oriented API. This is the preferred high-level +//! path for hand-written datasets. +//! +//! Builder methods accept public input and output labels. Prefer stable strings +//! such as `"depot"`, `"iv"`, and `"cp"`. Numeric values are accepted, but +//! they remain public labels rather than automatically becoming dense internal +//! indices. + use crate::{data::*, Censor}; -/// Extension trait for creating [Subject] instances using the builder pattern +/// Extension trait that enables `Subject::builder(...)`. +/// +/// Most users do not need to import [`SubjectBuilder`] directly. Import this +/// trait from the crate root or [`crate::prelude`] and then start with +/// `Subject::builder("id")`. pub trait SubjectBuilderExt { /// Create a new SubjectBuilder with the specified ID /// @@ -14,8 +29,8 @@ pub trait SubjectBuilderExt { /// use pharmsol::*; /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) - /// .observation(1.0, 10.5, 0) + /// .bolus(0.0, 100.0, "depot") + /// .observation(1.0, 10.5, "cp") /// .build(); /// ``` fn builder(id: impl Into) -> SubjectBuilder; @@ -34,11 +49,37 @@ impl SubjectBuilderExt for Subject { } } -/// Builder for creating [Subject] instances with a fluent API +/// Builder for creating [`Subject`] values with a fluent API. +/// +/// Use [`SubjectBuilder`] when you want to author common dose and observation +/// schedules directly in Rust without constructing low-level event values by +/// hand. +/// +/// A builder instance accumulates events inside the current [`Occasion`]. +/// [`SubjectBuilder::repeat`] duplicates the most recently added event at later +/// times, and [`SubjectBuilder::reset`] closes the current occasion and starts a +/// new one with fresh occasion-local state. +/// +/// Input and output arguments are public labels. Prefer stable model-facing +/// names such as `"depot"`, `"iv"`, and `"cp"`. +/// +/// # Example +/// +/// ```rust +/// use pharmsol::*; +/// +/// let subject = Subject::builder("patient_001") +/// .bolus(0.0, 100.0, "depot") +/// .repeat(1, 24.0) +/// .observation(1.0, 12.3, "cp") +/// .missing_observation(25.0, "cp") +/// .reset() +/// .bolus(0.0, 80.0, "depot") +/// .observation(1.0, 10.1, "cp") +/// .build(); /// -/// The [SubjectBuilder] allows for constructing complex subject data with a -/// chainable, readable syntax. Events like doses and observations can be -/// added sequentially, and the builder handles organizing them into occasions. +/// assert_eq!(subject.occasions().len(), 2); +/// ``` #[derive(Debug, Clone)] pub struct SubjectBuilder { id: String, @@ -49,37 +90,39 @@ pub struct SubjectBuilder { } impl SubjectBuilder { - /// Add an event to the current occasion + /// Add a fully constructed event to the current occasion. /// - /// # Arguments - /// - /// * `event` - The event to add + /// Use this when you want to mix builder convenience methods with direct + /// [`Event`] values. pub fn event(mut self, event: Event) -> Self { self.last_added_event = Some(event.clone()); self.current_occasion.add_event(event); self } - /// Add a bolus dosing event + /// Add an instantaneous dose. /// /// # Arguments /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose + /// * `input` - Public input label receiving the dose + /// + /// Prefer stable route names such as `"depot"` or `"iv"` when the model + /// declares named routes. pub fn bolus(self, time: f64, amount: f64, input: impl ToString) -> Self { let bolus = Bolus::new(time, amount, input, self.current_occasion.index()); let event = Event::Bolus(bolus); self.event(event) } - /// Add an infusion event + /// Add a continuous dose over a duration. /// /// # Arguments /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - Public input label receiving the dose /// * `duration` - Duration of the infusion in time units pub fn infusion(self, time: f64, amount: f64, input: impl ToString, duration: f64) -> Self { let infusion = Infusion::new(time, amount, input, duration, self.current_occasion.index()); @@ -87,13 +130,13 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation + /// Add an observed value at a given time. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation + /// * `outeq` - Public output label for this observation pub fn observation(self, time: f64, value: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, @@ -107,13 +150,14 @@ impl SubjectBuilder { self.event(event) } - /// Add a censored observation + /// Add an observed value with explicit censoring information. + /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this - /// observation + /// * `outeq` - Public output label for this observation + /// * `censoring` - Censoring status for the observation value pub fn censored_observation( self, time: f64, @@ -133,12 +177,15 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation + /// Add a prediction-only observation slot. /// /// # Arguments /// /// * `time` - Time of the observation - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation + /// * `outeq` - Public output label for this observation + /// + /// Use this when you want a prediction at a time point but do not have an + /// observed value. pub fn missing_observation(self, time: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, @@ -152,15 +199,15 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation with a specific error polynomial + /// Add an observed value with an explicit assay error polynomial. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation + /// * `outeq` - Public output label for this observation /// * `errorpoly` - Error polynomial coefficients (c0, c1, c2, c3) - /// * `censored` - Whether the observation is censored + /// * `censored` - Censoring status for the observation value pub fn observation_with_error( self, time: f64, @@ -181,7 +228,10 @@ impl SubjectBuilder { self.event(event) } - /// Repeat the last event `n` times, separated by some interval `delta` + /// Repeat the last event `n` times, separated by `delta`. + /// + /// The repeated events keep the same label, value, censoring state, and + /// error polynomial as the original event. Only the event time changes. /// /// # Arguments /// @@ -193,9 +243,8 @@ impl SubjectBuilder { /// ```rust /// use pharmsol::*; /// - /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) // First dose at time 0 + /// .bolus(0.0, 100.0, "depot") // First dose at time 0 /// .repeat(3, 24.0) // Repeat the dose at times 24, 48, and 72 /// .build(); /// ``` @@ -255,12 +304,14 @@ impl SubjectBuilder { self } - /// Complete the current occasion and start a new one + /// Complete the current occasion and start a new one. /// /// This finalizes the current occasion, adds it to the subject, /// and creates a new occasion for subsequent events. - /// This is useful if a patient has new observations at some other occasion. - /// Note that all states are reset! + /// Use this when the subject should begin a new occasion with reset state. + /// + /// Covariates collected since the previous reset are attached to the + /// finished occasion. The new occasion starts empty and its state is reset. pub fn reset(mut self) -> Self { let block_index = self.current_occasion.index() + 1; self.current_occasion.sort(); @@ -274,7 +325,7 @@ impl SubjectBuilder { self } - /// Add a covariate value at a specific time + /// Add a covariate value at a specific time. /// /// Multiple calls for the same covariate at different times will create /// linear interpolation between the time points. @@ -300,7 +351,7 @@ impl SubjectBuilder { self } - /// Finalize and build the Subject + /// Finalize and build the [`Subject`]. /// /// This completes the current occasion and returns a new Subject with all /// the accumulated data. diff --git a/src/data/event.rs b/src/data/event.rs index bff4c700..02a4c9a7 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -1,3 +1,15 @@ +//! Event types and public label wrappers for subject schedules. +//! +//! These types are the low-level representation behind the higher-level +//! builder and parsing APIs. Most users can start with +//! [`crate::data::builder::SubjectBuilder`], then inspect or transform +//! [`Event`] values after construction. +//! +//! Dose events carry an [`InputLabel`], and observations carry an +//! [`OutputLabel`]. Prefer stable strings such as `"depot"`, `"iv"`, and +//! `"cp"`. Numeric values are accepted, but they remain labels until a +//! downstream workflow explicitly interprets them as indices. + use crate::data::error_model::ErrorPoly; use crate::prelude::simulator::Prediction; use serde::{Deserialize, Serialize}; @@ -7,12 +19,16 @@ use std::fmt; // Shared Analysis Types // ============================================================================ -/// Administration route for a dosing event +/// Administration route classification used by downstream analyses. +/// +/// [`Route`] is a coarse route category, not the original public input label. +/// In the current data-side heuristic: +/// - [`Event::Infusion`] maps to [`Route::IVInfusion`] +/// - [`Event::Bolus`] with input label `0` maps to [`Route::Extravascular`] +/// - [`Event::Bolus`] with any other label maps to [`Route::IVBolus`] /// -/// Determined by the type of dose events and their target compartment: -/// - [`Event::Infusion`] → [`Route::IVInfusion`] -/// - [`Event::Bolus`] with `input >= 1` (central compartment) → [`Route::IVBolus`] -/// - [`Event::Bolus`] with `input == 0` (depot compartment) → [`Route::Extravascular`] +/// If you need the original model-facing label, read [`Bolus::input`] or +/// [`Infusion::input`] instead. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum Route { /// Intravenous bolus @@ -78,12 +94,15 @@ pub enum BLQRule { }, } -/// Represents a pharmacokinetic/pharmacodynamic event +/// One scheduled item in a subject record. +/// +/// Events are the low-level representation for doses and observations: +/// - [`Bolus`] for instantaneous input +/// - [`Infusion`] for input over a duration +/// - [`Observation`] for measured or missing outputs /// -/// Events represent key occurrences in a PK/PD profile, including: -/// - [Bolus] doses (instantaneous drug input) -/// - [Infusion]s (continuous drug input over a duration) -/// - [Observation]s (measured concentrations or other values) +/// Most users create these through `Subject::builder(...)`, row ingestion, or +/// file parsing rather than constructing them all by hand. #[derive(Serialize, Debug, Clone, Deserialize)] pub enum Event { /// A bolus dose (instantaneous drug input) @@ -95,21 +114,31 @@ pub enum Event { } macro_rules! impl_label_type { - ($name:ident) => { + ($(#[$meta:meta])* $name:ident) => { + $(#[$meta])* #[derive( Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, )] pub struct $name(String); impl $name { + /// Create a new public label. + /// + /// Prefer stable names when the model declares named routes or + /// outputs. pub fn new(label: impl ToString) -> Self { Self(label.to_string()) } + /// Borrow the stored label as a string. pub fn as_str(&self) -> &str { &self.0 } + /// Try to interpret the label as a numeric index. + /// + /// This is mainly a compatibility helper for lower-level paths that + /// still operate on dense indices after label resolution. pub fn index(&self) -> Option { self.0.parse::().ok() } @@ -171,8 +200,20 @@ macro_rules! impl_label_type { }; } -impl_label_type!(InputLabel); -impl_label_type!(OutputLabel); +impl_label_type!( + /// Public label for a dosing input or route. + /// + /// [`Bolus`] and [`Infusion`] store the original user-facing route name in + /// this type. + InputLabel +); +impl_label_type!( + /// Public label for an observation output. + /// + /// [`Observation`] stores the original user-facing output name in this + /// type. + OutputLabel +); impl Event { /// Get the time of the event @@ -226,9 +267,10 @@ impl Event { } } -/// Represents an instantaneous input of drug +/// Instantaneous dose input. /// -/// A [Bolus] is a discrete amount of drug added to a specific compartment at a specific time. +/// A [`Bolus`] records one discrete amount at one time, tagged with the public +/// input label that should be matched against the model. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Bolus { time: f64, @@ -263,6 +305,9 @@ impl Bolus { &self.input } + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Bolus::input`] when working with the public label itself. pub fn input_index(&self) -> Option { self.input.index() } @@ -313,9 +358,10 @@ impl Bolus { } } -/// Represents a continuous dose of drug over time +/// Continuous dose input over a duration. /// -/// An [Infusion] administers drug at a constant rate over a specified duration. +/// An [`Infusion`] records the total amount, start time, duration, and public +/// input label for one infusion event. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Infusion { time: f64, @@ -359,6 +405,9 @@ impl Infusion { &self.input } + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Infusion::input`] when working with the public label itself. pub fn input_index(&self) -> Option { self.input.index() } @@ -438,7 +487,11 @@ pub enum Censor { ALOQ, } -/// Represents an observation of drug concentration or other measured value + /// Observation of a model output. + /// + /// An [`Observation`] can carry a measured value or `None` for a prediction-only + /// time point. Observations also carry the public output label, optional assay + /// error polynomial, occasion index, and censoring state. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Observation { time: f64, @@ -482,7 +535,9 @@ impl Observation { self.time } - /// Get the value of the observation (e.g., drug concentration) + /// Get the value of the observation. + /// + /// `None` means this is a prediction-only or missing-observation slot. pub fn value(&self) -> Option { self.value } @@ -492,6 +547,9 @@ impl Observation { &self.outeq } + /// Try to interpret the output label as a numeric index. + /// + /// Prefer [`Observation::outeq`] when working with the public label itself. pub fn outeq_index(&self) -> Option { self.outeq.index() } @@ -553,7 +611,11 @@ impl Observation { &mut self.occasion } - /// Create a [Prediction] from this observation + /// Create a [`Prediction`] from this observation. + /// + /// This is a low-level helper for code paths that already operate on a + /// resolved or numeric output index. Named output labels must be resolved by + /// the caller before this conversion happens. pub fn to_prediction(&self, pred: f64, state: Vec) -> Prediction { Prediction { time: self.time(), diff --git a/src/data/mod.rs b/src/data/mod.rs index 996c791d..28a80b32 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,32 +1,83 @@ -//! Data structures and utilities for pharmacometric modeling +//! Data structures for building pharmacometric input data. //! -//! This module provides types for representing pharmacokinetic/pharmacodynamic data, -//! including subjects, dosing events, observations, and covariates. It also includes -//! utilities for reading and manipulating this data. +//! Use this module when you need to describe what happened to each subject: +//! doses, infusions, observations, covariates, and occasion boundaries. //! -//! # Key Components +//! This module is the input side of `pharmsol`. It is where you assemble +//! subjects and datasets before simulation, estimation, or NCA. It is not where +//! you define model equations or choose a backend. For those workflows, move to +//! [`crate::simulator`], [`crate::nca`], or the feature-gated `pharmsol::dsl` +//! surface. //! -//! - **Events**: Dosing events (bolus, infusion) and observations -//! - **Covariates**: Time-varying subject characteristics -//! - **Subjects**: Collections of events and covariates for a single individual -//! - **Data**: Collections of subjects, representing a complete dataset -//! - **Error Models**: Two types for different algorithm families: -//! - [`ErrorModel`]: Observation-based (assay error) for non-parametric algorithms -//! - [`ResidualErrorModel`]: Prediction-based (residual error) for parametric algorithms +//! # Start Here //! -//! # Examples +//! Most users only need three entrypoints first: //! -//! Creating a subject with the builder pattern: +//! - [`Subject`] for one individual and their full schedule. +//! - [`Data`] for a dataset containing many subjects. +//! - `Subject::builder` for the smallest fluent API to create doses, +//! observations, and covariates in Rust. +//! +//! The main supporting types are: +//! +//! - [`Occasion`] for repeated periods within one subject. +//! - [`Event`], [`Bolus`], [`Infusion`], and [`Observation`] for explicit +//! event-level control. +//! - [`Covariate`] and [`Covariates`] for time-varying subject characteristics. +//! - [`ErrorModel`], [`ResidualErrorModel`], and [`ObservationError`] for the +//! different error surfaces used by downstream workflows. +//! +//! # Choose A Data Input Path +//! +//! - Use `Subject::builder` when you are authoring a schedule directly in Rust. +//! - Use [`row::DataRow`] and [`row::DataRowBuilder`] when your source data is +//! already row-shaped in memory. +//! - Use [`parser::read_pmetrics`] when you are loading a Pmetrics-style file +//! from disk. +//! - Use [`Event`] variants directly when you already have validated event +//! records and need lower-level control than the builder offers. +//! +//! # Label Semantics +//! +//! Dosing inputs and observation outputs use public labels. +//! +//! - The `input` on [`Bolus`] and [`Infusion`] is the route or input label that +//! will be matched against the model. +//! - The `outeq` on [`Observation`] is the output label that identifies which +//! model output the observation belongs to. +//! - Prefer stable names such as `"depot"`, `"central"`, `"iv"`, or `"cp"`. +//! - If you pass a number, it is still treated as a public label string. Use +//! numeric values only when your model intentionally declares numeric labels. +//! +//! [`Occasion`] indices are different: they are integer period markers used to +//! separate repeated dosing blocks within one subject. +//! +//! # Error Surfaces +//! +//! This module exposes three related but different error families: +//! +//! - [`ErrorModel`] for assay or measurement error driven by the observation +//! value, commonly used in non-parametric workflows. +//! - [`ResidualErrorModel`] for residual unexplained variability driven by the +//! prediction value, commonly used in parametric workflows. +//! - [`ObservationError`] for invalid or insufficient observation data during +//! profile construction and related preprocessing. +//! +//! # Example //! //! ```rust //! use pharmsol::*; //! //! let subject = Subject::builder("patient_001") -//! .bolus(0.0, 100.0, 0) -//! .observation(1.0, 10.5, 0) -//! .observation(2.0, 8.2, 0) +//! .bolus(0.0, 100.0, "depot") +//! .observation(1.0, 12.3, "cp") +//! .missing_observation(2.0, "cp") //! .covariate("weight", 0.0, 70.0) //! .build(); +//! +//! let data = Data::new(vec![subject]); +//! +//! assert_eq!(data.subjects().len(), 1); //! ``` pub mod auc; diff --git a/src/data/parser/mod.rs b/src/data/parser/mod.rs index 7bfde3ca..74a50a84 100644 --- a/src/data/parser/mod.rs +++ b/src/data/parser/mod.rs @@ -1,3 +1,15 @@ +//! File-based parsers and parser-facing row utilities. +//! +//! Use this module when your source data starts as files or parser-shaped rows. +//! It re-exports the row ingestion API from [`crate::data::row`] and provides +//! format-specific loaders such as [`read_pmetrics`]. +//! +//! Choose the entrypoint by source shape: +//! - Use [`DataRow`] or [`build_data`] when you already mapped external data into +//! canonical row fields yourself. +//! - Use [`read_pmetrics`] when the source file already follows the Pmetrics CSV +//! convention. + pub mod pmetrics; pub use crate::data::row::{build_data, DataError, DataRow, DataRowBuilder}; diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index 89943f6e..2c90e2a7 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -1,3 +1,12 @@ +//! Pmetrics CSV parsing and export helpers. +//! +//! This module reads and writes the Pmetrics-style tabular format while keeping +//! pharmsol's public input and output labels intact. +//! +//! `INPUT` and `OUTEQ` values are parsed as labels, not rewritten to dense +//! indices. Named values such as `iv` and `cp` are preserved exactly, and +//! numeric values such as `1` are preserved as numeric-looking labels. + use crate::{data::*, PharmsolError}; use csv::WriterBuilder; use serde::de::{MapAccess, Visitor}; @@ -10,19 +19,27 @@ use crate::data::row::DataRow; use std::fmt; use std::str::FromStr; -/// Read a Pmetrics datafile and convert it to a [Data] object +/// Read a Pmetrics CSV file into [`Data`]. +/// +/// Use [`read_pmetrics`] when the source file already follows the usual +/// Pmetrics column convention instead of mapping the file into [`DataRow`] +/// values yourself. /// -/// This function parses a Pmetrics-formatted CSV file and constructs a [Data] object containing the structured -/// pharmacokinetic/pharmacodynamic data. The function handles various data formats including doses, observations, -/// and covariates. +/// The parser normalizes header names to lowercase, preserves `INPUT` and +/// `OUTEQ` as public labels, expands `ADDL` dosing rows through the shared row +/// ingestion path, and groups rows into occasions using `EVID=4`. +/// +/// All columns not claimed by the core Pmetrics schema are treated as +/// covariates. /// /// # Arguments /// -/// * `path` - The path to the Pmetrics CSV file +/// * `path` - Path to the Pmetrics CSV file /// /// # Returns /// -/// * `Result` - A result containing either the parsed [Data] object or an error +/// A parsed [`Data`] object or a [`DataError`] if the file cannot be read or a +/// required row field is missing. /// /// # Example /// @@ -33,14 +50,25 @@ use std::str::FromStr; /// println!("Number of subjects: {}", data.subjects().len()); /// ``` /// -/// # Format details +/// # Expected columns +/// +/// The canonical columns are `ID`, `TIME`, `EVID`, `DOSE`, `DUR`, `ADDL`, +/// `II`, `INPUT`, `OUT`, `OUTEQ`, `CENS`, and optional `C0..C3` error +/// coefficients. /// -/// The Pmetrics format expects columns like ID, TIME, EVID, DOSE, DUR, etc. The function will: +/// All other numeric columns are treated as covariates. +/// +/// # Parsing behavior +/// +/// The parser will: /// - Convert all headers to lowercase for case-insensitivity /// - Group rows by subject ID /// - Create occasions based on EVID=4 events /// - Parse covariates and create appropriate interpolations /// - Handle additional doses via ADDL and II fields +/// - Preserve raw `INPUT` and `OUTEQ` labels as strings until model resolution +/// - Treat `OUT=-99` as a missing observation value, matching the common +/// Pmetrics convention /// /// For specific column definitions, see the `Row` struct. #[allow(dead_code)] @@ -72,7 +100,7 @@ pub fn read_pmetrics(path: impl Into) -> Result { build_data(data_rows) } -/// A [Row] represents a row in the Pmetrics data format +/// One row from a Pmetrics file after serde deserialization. #[derive(Deserialize, Debug, Serialize, Default, Clone)] #[serde(rename_all = "lowercase")] struct Row { @@ -94,13 +122,13 @@ struct Row { /// Dosing interval #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, - /// Input compartment + /// Input label from the `INPUT` column #[serde(deserialize_with = "deserialize_option_route_label")] input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, - /// Corresponding output equation for the observation + /// Output label from the `OUTEQ` column #[serde(deserialize_with = "deserialize_option_output_label")] outeq: Option, /// Censoring output @@ -264,7 +292,14 @@ where } impl Data { - /// Write the dataset to a file in Pmetrics format + /// Write the dataset to a file in Pmetrics format. + /// + /// `INPUT` and `OUTEQ` are written using their stored public labels. Named + /// labels such as `iv` and `cp` remain named labels, and numeric-looking + /// labels are written back exactly as stored. + /// + /// Missing optional fields are emitted as `.` placeholders to match the + /// usual Pmetrics text convention. /// /// # Arguments /// diff --git a/src/data/row.rs b/src/data/row.rs index b9a807c1..fcb610ea 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -1,34 +1,51 @@ -//! Row representation of [Data] for flexible parsing +//! Row-shaped data ingestion for [`Data`] and [`Subject`] assembly. +//! +//! Use this module when your source data already looks like rows from a table, +//! CSV file, database export, or ETL pipeline. +//! +//! Choose the ingestion path by source shape: +//! - Use [`crate::data::builder::SubjectBuilder`] when you want to author a +//! schedule directly in Rust. +//! - Use [`DataRow`] and [`build_data`] when your application already has +//! validated row records in memory. +//! - Use [`crate::data::parser::read_pmetrics`] when the source file already +//! follows the Pmetrics column convention. +//! +//! [`DataRow`] keeps public route and output labels as strings. Labels such as +//! `"iv"`, `"depot"`, and `"cp"` are preserved through row parsing and later +//! resolved against model metadata by downstream workflows. //! //! # Example //! //! ```rust //! use pharmsol::data::parser::DataRow; //! -//! // Create a dosing row with ADDL expansion //! let row = DataRow::builder("subject_1", 0.0) //! .evid(1) //! .dose(100.0) -//! .input(1) -//! .addl(3) // 3 additional doses -//! .ii(12.0) // 12 hours apart +//! .input("iv") +//! .addl(3) +//! .ii(12.0) //! .build(); //! //! let events = row.into_events().unwrap(); -//! assert_eq!(events.len(), 4); // Original + 3 additional doses +//! assert_eq!(events.len(), 4); //! ``` -//! use crate::data::*; use std::collections::HashMap; use thiserror::Error; -/// A format-agnostic representation of a single data row +/// A format-agnostic representation of one input row. +/// +/// [`DataRow`] collects the canonical fields needed to turn one external row +/// into one or more [`Event`] values. /// -/// This struct represents the canonical fields needed to create pharmsol Events. -/// Consumers construct this from their source data (regardless of column names), -/// then call [`into_events()`](DataRow::into_events) to get properly parsed -/// Events with full ADDL expansion, EVID handling, censoring, etc. +/// Build this type from your own column mapping or external schema, then call +/// [`DataRow::into_events`] or [`build_data`] to assemble subjects and datasets. +/// +/// A single row can expand into several events when `ADDL` and `II` are both +/// present. /// /// # Fields /// @@ -42,24 +59,22 @@ use thiserror::Error; /// ```rust /// use pharmsol::data::parser::DataRow; /// -/// // Observation row /// let obs = DataRow::builder("pt1", 1.0) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .build(); /// -/// // Dosing row with negative ADDL (doses before time 0) /// let dose = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) -/// .input(1) -/// .addl(-10) // 10 doses BEFORE time 0 +/// .input("iv") +/// .addl(-10) /// .ii(12.0) /// .build(); /// /// let events = dose.into_events().unwrap(); -/// // Events at times: -120, -108, -96, ..., -12, 0 +/// assert_eq!(obs.outeq.as_ref().map(|label| label.as_str()), Some("cp")); /// assert_eq!(events.len(), 11); /// ``` #[derive(Debug, Clone, Default)] @@ -99,7 +114,7 @@ pub struct DataRow { } impl DataRow { - /// Create a new builder for constructing a DataRow + /// Create a builder for constructing one [`DataRow`]. /// /// # Arguments /// @@ -114,7 +129,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("depot") /// .build(); /// ``` pub fn builder(id: impl Into, time: f64) -> DataRowBuilder { @@ -129,13 +144,14 @@ impl DataRow { } } - /// Convert this row into pharmsol Events + /// Convert this row into one or more [`Event`] values. /// - /// This method contains all the complex parsing logic: + /// This method performs the row-level translation logic: /// - EVID interpretation (0=observation, 1=dose, 4=reset) /// - ADDL/II expansion (both positive and negative directions) /// - Infusion vs bolus detection based on DUR /// - Censoring and error polynomial handling + /// - Preservation of public input and output labels /// /// # ADDL Expansion /// @@ -163,13 +179,13 @@ impl DataRow { /// let row = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("iv") /// .addl(2) /// .ii(24.0) /// .build(); /// /// let events = row.into_events().unwrap(); - /// assert_eq!(events.len(), 3); // doses at 24, 48, and 0 + /// assert_eq!(events.len(), 3); /// /// let times: Vec = events.iter().map(|e| e.time()).collect(); /// assert_eq!(times, vec![24.0, 48.0, 0.0]); @@ -287,7 +303,11 @@ impl DataRow { } } -/// Builder for constructing DataRow with a fluent API +/// Fluent builder for [`DataRow`]. +/// +/// Use [`DataRowBuilder`] when you have row-shaped data in memory and want to +/// construct rows incrementally before calling [`DataRow::into_events`] or +/// [`build_data`]. /// /// # Example /// @@ -298,7 +318,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 1.5) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .cens(Censor::None) /// .covariate("weight", 70.0) /// .covariate("age", 45.0) @@ -373,10 +393,11 @@ impl DataRowBuilder { self } - /// Set the input route label + /// Set the input route label. /// - /// Required for EVID=1 (dosing events). - /// Preserved as the public route label until model resolution. + /// Required for EVID=1 dosing rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. pub fn input(mut self, input: impl ToString) -> Self { self.row.input = Some(InputLabel::new(input)); self @@ -390,10 +411,11 @@ impl DataRowBuilder { self } - /// Set the output label + /// Set the output label. /// - /// Required for EVID=0 (observation events). - /// Preserved as the public output label until model resolution. + /// Required for EVID=0 observation rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. pub fn outeq(mut self, outeq: impl ToString) -> Self { self.row.outeq = Some(OutputLabel::new(outeq)); self @@ -436,13 +458,18 @@ impl DataRowBuilder { } } -/// Build a [Data] object from an iterator of [DataRow]s +/// Build a [`Data`] object from row-shaped input. /// -/// This function handles all the complex assembly logic: +/// This function assembles rows into subjects and occasions: /// - Groups rows by subject ID /// - Splits into occasions at EVID=4 boundaries /// - Converts rows to events via [`DataRow::into_events()`] /// - Builds covariates from row covariate data +/// - Preserves per-subject row order within each occasion block +/// +/// Use this when you already have a collection of [`DataRow`] values in memory. +/// If your source file is a Pmetrics CSV, use [`crate::data::parser::read_pmetrics`] +/// instead. /// /// # Example /// @@ -450,23 +477,21 @@ impl DataRowBuilder { /// use pharmsol::data::parser::{DataRow, build_data}; /// /// let rows = vec![ -/// // Subject 1, Occasion 0 /// DataRow::builder("pt1", 0.0) -/// .evid(1).dose(100.0).input(1).build(), +/// .evid(1).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 1.0) -/// .evid(0).out(50.0).outeq(1).build(), -/// // Subject 1, Occasion 1 (EVID=4 starts new occasion) +/// .evid(0).out(50.0).outeq("cp").build(), /// DataRow::builder("pt1", 24.0) -/// .evid(4).dose(100.0).input(1).build(), +/// .evid(4).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 25.0) -/// .evid(0).out(48.0).outeq(1).build(), -/// // Subject 2 +/// .evid(0).out(48.0).outeq("cp").build(), /// DataRow::builder("pt2", 0.0) -/// .evid(1).dose(50.0).input(1).build(), +/// .evid(1).dose(50.0).input("iv").build(), /// ]; /// /// let data = build_data(rows).unwrap(); /// assert_eq!(data.subjects().len(), 2); +/// assert_eq!(data.subjects()[0].occasions().len(), 2); /// ``` pub fn build_data(rows: impl IntoIterator) -> Result { // Group rows by subject ID @@ -562,14 +587,14 @@ pub enum DataError { /// Required observation value (OUT) is missing #[error("Observation OUT is missing for {id} at time {time}")] MissingObservationOut { id: String, time: f64 }, - /// Required observation output equation (OUTEQ) is missing - #[error("Observation OUTEQ is missing in for {id} at time {time}")] + /// Required observation output label (`OUTEQ`) is missing + #[error("Observation OUTEQ is missing for {id} at time {time}")] MissingObservationOuteq { id: String, time: f64 }, /// Required infusion dose amount is missing #[error("Infusion amount (DOSE) is missing for {id} at time {time}")] MissingInfusionDose { id: String, time: f64 }, - /// Required infusion input compartment is missing - #[error("Infusion compartment (INPUT) is missing for {id} at time {time}")] + /// Required infusion input label (`INPUT`) is missing + #[error("Infusion input label (INPUT) is missing for {id} at time {time}")] MissingInfusionInput { id: String, time: f64 }, /// Required infusion duration is missing #[error("Infusion duration (DUR) is missing for {id} at time {time}")] @@ -577,8 +602,8 @@ pub enum DataError { /// Required bolus dose amount is missing #[error("Bolus amount (DOSE) is missing for {id} at time {time}")] MissingBolusDose { id: String, time: f64 }, - /// Required bolus input compartment is missing - #[error("Bolus compartment (INPUT) is missing for {id} at time {time}")] + /// Required bolus input label (`INPUT`) is missing + #[error("Bolus input label (INPUT) is missing for {id} at time {time}")] MissingBolusInput { id: String, time: f64 }, } diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 2a46409a..6557e015 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -37,18 +37,23 @@ use pharmsol_dsl::ModelKind; use pharmsol_dsl::{analyze_module, lower_typed_model, parse_module, ExecutionModel}; use pharmsol_dsl::{Diagnostic, DiagnosticReport, LoweringError, ParseError, SemanticError}; +/// ABI version for native AoT artifacts produced by this crate. pub const AOT_API_VERSION: u32 = 1; #[cfg(feature = "dsl-aot")] +/// Selects the compilation target for a native ahead-of-time artifact. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub enum NativeAotTarget { + /// Compile for the current host toolchain target. #[default] Host, + /// Compile for an explicit Rust target triple. Triple(String), } #[cfg(feature = "dsl-aot")] impl NativeAotTarget { + /// Create a target selector for an explicit Rust target triple. pub fn triple(target: impl Into) -> Self { Self::Triple(target.into()) } @@ -62,15 +67,24 @@ impl NativeAotTarget { } #[cfg(feature = "dsl-aot")] +/// Options that control native ahead-of-time artifact export. +/// +/// AoT export writes a small template crate under [`template_root`](Self::template_root), +/// builds a native shared library, and then copies the resulting artifact to +/// [`output`](Self::output) or a generated default path. #[derive(Debug, Clone, PartialEq, Eq)] pub struct NativeAotCompileOptions { + /// Target triple selection for the emitted artifact. pub target: NativeAotTarget, + /// Optional final artifact location. pub output: Option, + /// Working directory used for the temporary template crate and build output. pub template_root: PathBuf, } #[cfg(feature = "dsl-aot")] impl NativeAotCompileOptions { + /// Create AoT options rooted at a template build directory. pub fn new(template_root: PathBuf) -> Self { Self { target: NativeAotTarget::Host, @@ -79,17 +93,20 @@ impl NativeAotCompileOptions { } } + /// Set the final artifact output path. pub fn with_output(mut self, output: PathBuf) -> Self { self.output = Some(output); self } + /// Set the compilation target triple. pub fn with_target(mut self, target: NativeAotTarget) -> Self { self.target = target; self } } +/// Error produced while exporting, reading, or loading a native AoT artifact. #[derive(Error)] pub enum AotError { #[error(transparent)] @@ -151,6 +168,43 @@ impl fmt::Debug for AotError { } #[cfg(feature = "dsl-aot")] +/// Parse DSL source, lower one selected model, and export a native AoT artifact. +/// +/// Use this when you want a reusable native artifact that can be loaded later +/// with [`load_aot_model`] or [`crate::dsl::load_runtime_artifact`]. +/// +/// This function requires the `dsl-aot` feature. Loading the resulting artifact +/// later requires `dsl-aot-load`. +/// +/// ```rust,no_run +/// use std::path::PathBuf; +/// +/// use pharmsol::dsl::{compile_module_source_to_aot, load_aot_model, NativeAotCompileOptions}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let artifact = compile_module_source_to_aot( +/// source, +/// Some("bimodal_ke"), +/// NativeAotCompileOptions::new(PathBuf::from("target/doc-aot-build")), +/// |_, _| {}, +/// )?; +/// let loaded = load_aot_model(&artifact)?; +/// # let _ = loaded; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_module_source_to_aot( source: &str, model_name: Option<&str>, @@ -184,6 +238,10 @@ pub fn compile_module_source_to_aot( } #[cfg(feature = "dsl-aot")] +/// Export a lowered execution model as a native AoT artifact. +/// +/// Use this lower-level entrypoint when you already own the frontend pipeline +/// and only need artifact generation. pub fn export_execution_model_to_aot( model: &ExecutionModel, options: NativeAotCompileOptions, @@ -240,6 +298,10 @@ pub fn export_execution_model_to_aot( } #[cfg(feature = "dsl-aot-load")] +/// Read only the metadata from a native AoT artifact. +/// +/// This is useful when you need to inspect model identity, routes, outputs, or +/// buffer sizes without loading the executable kernels. pub fn read_aot_model_info(path: impl AsRef) -> Result { let library = unsafe { Library::new(path.as_ref()) } .map_err(|error| AotError::Load(error.to_string()))?; @@ -248,6 +310,7 @@ pub fn read_aot_model_info(path: impl AsRef) -> Result) -> Result { let path = path.as_ref(); let library = diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index a440c51d..684b7810 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -83,6 +83,11 @@ pub type JitAnalyticalModel = NativeAnalyticalModel; pub type JitSdeModel = NativeSdeModel; pub type CompiledJitModel = CompiledNativeModel; +/// Error reported while lowering an execution model into native in-process JIT +/// code. +/// +/// The error retains the backend diagnostic so callers can render the message +/// against the original DSL source when available. #[derive(Clone, PartialEq, Eq)] pub struct JitCompileError { diagnostic: Box, @@ -214,6 +219,10 @@ struct LoweredValue { ty: ValueType, } +/// Compile one lowered execution model into a reusable JIT kernel artifact. +/// +/// This builds the raw Cranelift-compiled kernel bundle for all roles present in +/// the model. Most callers should use [`compile_execution_model_to_jit`] instead. pub fn compile_execution_artifact( model: &ExecutionModel, ) -> Result { @@ -1217,6 +1226,41 @@ fn state_address( Ok(builder.ins().iadd(base, byte_offset)) } +/// Compile an [`ExecutionModel`](pharmsol_dsl::ExecutionModel) to the native +/// in-process JIT backend. +/// +/// Use this low-level entrypoint when you already own the parse, analyze, and +/// lower steps and want the JIT backend directly instead of the higher-level +/// runtime facade. +/// +/// This function requires the `dsl-jit` feature. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{ +/// analyze_model, compile_execution_model_to_jit, lower_typed_model, parse_model, +/// }; +/// +/// let parsed = parse_model( +/// r#" +/// model implicit_route_injection { +/// kind ode +/// states { central } +/// routes { iv -> central } +/// dynamics { +/// ddt(central) = 0 +/// } +/// outputs { +/// cp = central +/// } +/// } +/// "#, +/// )?; +/// let typed = analyze_model(&parsed)?; +/// let execution = lower_typed_model(&typed)?; +/// let compiled = compile_execution_model_to_jit(&execution)?; +/// # let _ = compiled; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_execution_model_to_jit( model: &ExecutionModel, ) -> Result { @@ -1229,6 +1273,7 @@ pub fn compile_execution_model_to_jit( } } +/// Compile an ODE execution model to the native in-process JIT backend. pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Ode { return Err(JitCompileError::new( @@ -1245,6 +1290,7 @@ pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result Result { @@ -1263,6 +1309,7 @@ pub fn compile_analytical_model_to_jit( )) } +/// Compile an SDE execution model to the native in-process JIT backend. pub fn compile_sde_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Sde { return Err(JitCompileError::new( diff --git a/src/dsl/mod.rs b/src/dsl/mod.rs index 563e4cf6..f536c377 100644 --- a/src/dsl/mod.rs +++ b/src/dsl/mod.rs @@ -1,9 +1,94 @@ //! Public DSL facade for pharmsol. //! -//! The backend-neutral frontend is being extracted into `pharmsol-dsl`. -//! Frontend syntax, diagnostics, semantic analysis, and lowering now come -//! from `pharmsol-dsl`, while runtime and backend compilation entrypoints -//! remain owned by `pharmsol`. +//! Use this module when you want to work with pharmsol models as source text +//! and stay inside the main crate for the full workflow: parse DSL source, +//! inspect diagnostics, lower to the execution model, compile to a runtime +//! backend, load saved artifacts, and run predictions. +//! +//! Use the `pharmsol-dsl` crate directly only when you need the backend-neutral +//! frontend as an engineering API. That crate owns parsing, diagnostics, +//! semantic analysis, and lowering. This module re-exports that stable +//! frontend surface and adds the backend-specific entrypoints that stay owned +//! by `pharmsol`. +//! +//! Main entrypoints: +//! +//! - [`parse_model`], [`parse_module`], [`analyze_model`], and +//! [`analyze_module`] for frontend-only validation and inspection. +//! - [`lower_typed_model`] and [`lower_typed_module`] for lowering typed models +//! into the execution representation used by the runtime backends. +//! - [`compile_module_source_to_runtime`] and [`compile_execution_model_to_runtime`] +//! for the one-stop compile-and-run path. +//! - [`load_runtime_artifact`], [`load_aot_model`], and +//! [`load_runtime_wasm_bytes`] for loading saved artifacts back into a model +//! you can execute. +//! +//! Common workflow choices: +//! +//! - Frontend only: parse, analyze, and lower when you need diagnostics, +//! authoring tools, or your own backend. +//! - In-process execution: compile straight to [`RuntimeCompilationTarget`] and +//! keep everything inside the current process. +//! - Native artifact shipping: export a native AoT artifact, then load it later +//! on a compatible host. +//! - WASM artifact shipping: emit `.wasm` bytes or a bundled module for browser +//! or portable runtime use. +//! +//! Feature map: +//! +//! - `dsl-core`: enables this facade and the frontend re-exports from +//! `pharmsol-dsl`. +//! - `dsl-jit`: enables in-process JIT compilation through +//! [`compile_module_source_to_runtime`] with +//! [`RuntimeCompilationTarget::Jit`], plus the lower-level JIT compile +//! entrypoints. +//! - `dsl-aot`: enables native ahead-of-time artifact export through +//! [`compile_module_source_to_aot`] and [`export_execution_model_to_aot`]. +//! - `dsl-aot-load`: enables native AoT artifact loading through +//! [`load_aot_model`] and [`read_aot_model_info`]. +//! - `dsl-wasm-compile`: enables WASM artifact emission through +//! [`compile_module_source_to_wasm_bytes`], +//! [`compile_module_source_to_wasm_module`], and the browser loader helpers. +//! - `dsl-wasm`: enables host-side WASM loading and runtime execution on +//! non-browser native hosts. This includes +//! [`compile_module_source_to_runtime_wasm`], [`load_runtime_wasm_bytes`], +//! [`read_wasm_model_info`], and [`read_wasm_model_info_bytes`]. +//! +//! Smallest compile-to-runtime example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! # let _ = model; +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` +//! +//! For a lower-level frontend pipeline without backend selection, use +//! `pharmsol-dsl`. For a complete runtime path inside the main crate, stay in +//! [`pharmsol::dsl`](self). #[cfg(any(feature = "dsl-aot", feature = "dsl-aot-load"))] mod aot; diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index d9a2fdbd..7dd3f72a 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -8,47 +8,80 @@ use pharmsol_dsl::execution::{ }; use pharmsol_dsl::{AnalyticalKernel, ModelKind, RouteKind}; +/// Public metadata extracted from a compiled backend model. +/// +/// This is the shared inspection surface returned by the native AoT, WASM, and +/// runtime loaders. It keeps public labels and buffer sizes available without +/// exposing backend-specific kernel details. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeModelInfo { + /// Public model name. pub name: String, + /// High-level model family. pub kind: ModelKind, + /// Parameter names in support-point order. pub parameters: Vec, + /// Declared covariates and their dense runtime indices. pub covariates: Vec, + /// Declared routes together with declaration-order and dense runtime indices. pub routes: Vec, + /// Declared outputs and their dense runtime indices. pub outputs: Vec, + /// Length of the state buffer used during execution. pub state_len: usize, + /// Length of the derived-value buffer used during execution. pub derived_len: usize, + /// Length of the output buffer used during execution. pub output_len: usize, + /// Length of the dense route-input buffer used during execution. pub route_len: usize, + /// Analytical kernel metadata when the compiled model is analytical. pub analytical: Option, + /// Particle count when the compiled model is stochastic. pub particles: Option, } +/// Metadata for one compiled covariate. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeCovariateInfo { + /// Public covariate name. pub name: String, + /// Dense runtime covariate index. pub index: usize, } +/// Metadata for one compiled route. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeRouteInfo { + /// Public route label. pub name: String, + /// Route position in declaration order. #[serde(default)] pub declaration_index: usize, + /// Dense runtime route-input index. pub index: usize, + /// Coarse route kind when declared in metadata. #[serde(default)] pub kind: Option, + /// Dense destination state offset used by compiled kernels. pub destination_offset: usize, + /// Whether the compiled backend injects the route input into the destination + /// state automatically when the model does not read the route input + /// explicitly. pub inject_input_to_destination: bool, } +/// Metadata for one compiled output. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeOutputInfo { + /// Public output label. pub name: String, + /// Dense runtime output index. pub index: usize, } impl NativeModelInfo { + /// Build public compiled-model metadata from a lowered execution model. pub fn from_execution_model(model: &ExecutionModel) -> Self { let explicit_route_input_usage = explicit_route_input_usage(model); Self { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 59c399ab..1cef784e 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -1,3 +1,82 @@ +//! Unified runtime entrypoints for DSL-backed models. +//! +//! Use this module when you already know you want an executable model and need +//! one backend-neutral surface for compile, load, and prediction workflows. +//! It normalizes the backend-specific JIT, native AoT, and WASM entrypoints so +//! callers can choose a deployment target without rewriting the downstream +//! prediction code. +//! +//! Use the backend modules directly only when you need a backend-specific +//! artifact or compile control: +//! +//! - [`super::jit`] for direct in-process JIT compilation. +//! - [`compile_module_source_to_aot`][crate::dsl::compile_module_source_to_aot] for native artifact export and reload. +//! - [`compile_module_source_to_wasm_bytes`][crate::dsl::compile_module_source_to_wasm_bytes] and [`load_runtime_wasm_bytes`] for portable WASM bytes, +//! browser-loader assets, and host-side WASM loading. +//! +//! Main entrypoints: +//! +//! - [`compile_module_source_to_runtime`] for the one-stop source-to-runtime +//! path. +//! - [`compile_execution_model_to_runtime`] when you already have an +//! [`ExecutionModel`](pharmsol_dsl::ExecutionModel). +//! - [`load_runtime_artifact`] and [`load_runtime_wasm_bytes`] when the model +//! has already been compiled and stored elsewhere. +//! - [`CompiledRuntimeModel::estimate_predictions`] for backend-neutral +//! execution against a [`Subject`](crate::Subject). +//! +//! Backend choice guide: +//! +//! - [`RuntimeCompilationTarget::Jit`] keeps compilation and execution inside +//! the current process. Use it for native interactive workflows and tests. +//! - [`RuntimeCompilationTarget::NativeAot`] emits a native artifact and reloads +//! it into the same runtime model shape. Use it when you want reusable native +//! artifacts and can control the target platform. +//! - [`RuntimeCompilationTarget::Wasm`] emits portable WASM bytes and reloads +//! them into the host-side runtime adapter. Use it when you need a portable +//! artifact or browser-aligned deployment story. +//! +//! Smallest compile-and-run example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! use pharmsol::prelude::*; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .missing_observation(2.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.2, 50.0])?; +//! assert!(predictions.as_subject().is_some()); +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` + use std::fmt; use std::path::Path; @@ -39,24 +118,39 @@ pub type RuntimeOdeModel = NativeOdeModel; pub type RuntimeAnalyticalModel = NativeAnalyticalModel; pub type RuntimeSdeModel = NativeSdeModel; +/// Selects which backend should produce the executable runtime model. +/// +/// This enum is the main backend-switching point for +/// [`compile_module_source_to_runtime`] and +/// [`compile_execution_model_to_runtime`]. #[derive(Debug, Clone, PartialEq, Eq)] pub enum RuntimeCompilationTarget { + /// Compile and execute the model inside the current native process. #[cfg(feature = "dsl-jit")] Jit, + /// Export a native artifact and reload it as a runtime model. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot(NativeAotCompileOptions), + /// Emit WASM bytes and reload them through the host-side WASM runtime. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Identifies the on-disk artifact format for [`load_runtime_artifact`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RuntimeArtifactFormat { + /// A native ahead-of-time artifact produced by the AoT compiler. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot, + /// A WASM artifact produced by the WASM compiler. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Backend-neutral prediction output from a compiled runtime model. +/// +/// ODE and analytical models return subject predictions. SDE models return the +/// particle matrix used by the stochastic workflow. #[derive(Clone, Debug)] pub enum RuntimePredictions { Subject(SubjectPredictions), @@ -93,6 +187,10 @@ impl RuntimePredictions { } } +/// Executable runtime model returned by the backend-neutral runtime surface. +/// +/// This type hides the concrete backend and keeps the prediction entrypoint the +/// same across JIT, native AoT, and WASM-based flows. #[derive(Clone, Debug)] pub enum CompiledRuntimeModel { Ode(RuntimeOdeModel), @@ -166,6 +264,8 @@ impl CompiledRuntimeModel { } } +/// Errors produced while parsing, lowering, compiling, loading, or executing a +/// runtime DSL model. #[derive(Error)] pub enum RuntimeError { #[error("failed to parse DSL source: {0}")] @@ -231,6 +331,10 @@ impl fmt::Debug for RuntimeError { } } +/// Parse, analyze, lower, compile, and return a runtime model in one step. +/// +/// Use this when your input is DSL source text and you want the shortest path +/// from source to predictions. pub fn compile_module_source_to_runtime( source: &str, model_name: Option<&str>, @@ -269,6 +373,10 @@ pub fn compile_module_source_to_runtime( }) } +/// Compile a lowered execution model to a selected runtime backend. +/// +/// Use this when you already own the frontend pipeline and only need the final +/// backend step. pub fn compile_execution_model_to_runtime( model: &ExecutionModel, target: RuntimeCompilationTarget, @@ -309,6 +417,7 @@ pub fn compile_execution_model_to_runtime( } } +/// Load a previously compiled native AoT or WASM artifact from disk. pub fn load_runtime_artifact( path: impl AsRef, format: RuntimeArtifactFormat, @@ -330,6 +439,7 @@ pub fn load_runtime_artifact( } #[cfg(feature = "dsl-wasm")] +/// Compile DSL source straight to a host-side runtime model via the WASM path. pub fn compile_module_source_to_runtime_wasm( source: &str, model_name: Option<&str>, @@ -339,6 +449,8 @@ pub fn compile_module_source_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Compile a lowered execution model straight to a host-side runtime model via +/// the WASM path. pub fn compile_execution_model_to_runtime_wasm( model: &ExecutionModel, ) -> Result { @@ -347,6 +459,7 @@ pub fn compile_execution_model_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Load a runtime model from in-memory WASM bytes. pub fn load_runtime_wasm_bytes(bytes: &[u8]) -> Result { let (info, artifact) = load_wasm_artifact_bytes(bytes)?; Ok(runtime_model_from_parts(info, artifact)) diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index f2504d44..e95b799a 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -406,11 +406,39 @@ impl RuntimeArtifact for WasmExecutionArtifact { } } +/// Read only the metadata from a compiled WASM artifact on disk. +/// +/// Use this when you need model identity, route labels, output labels, or +/// buffer sizes without loading the executable runtime wrapper. pub fn read_wasm_model_info(path: impl AsRef) -> Result { let (info, _) = load_wasm_artifact(path)?; Ok(info) } +/// Read only the metadata from in-memory compiled WASM bytes. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{compile_module_source_to_wasm_bytes, read_wasm_model_info_bytes}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let bytes = compile_module_source_to_wasm_bytes(source, Some("bimodal_ke"))?; +/// let info = read_wasm_model_info_bytes(&bytes)?; +/// assert_eq!(info.name, "bimodal_ke"); +/// # Ok::<(), Box>(()) +/// ``` pub fn read_wasm_model_info_bytes(bytes: &[u8]) -> Result { let (info, _) = load_wasm_artifact_bytes(bytes)?; Ok(info) diff --git a/src/dsl/wasm_compile.rs b/src/dsl/wasm_compile.rs index caa60216..cda4727d 100644 --- a/src/dsl/wasm_compile.rs +++ b/src/dsl/wasm_compile.rs @@ -19,15 +19,24 @@ use pharmsol_dsl::{ LoweringError, ParseError, SemanticError, }; +/// ABI version for compiled WASM artifacts produced by this crate. pub const WASM_API_VERSION: u32 = 1; +/// Default entry capacity for [`WasmCompileCache`]. pub const DEFAULT_WASM_COMPILE_CACHE_CAPACITY: usize = 32; static BROWSER_LOADER_SOURCE: OnceLock = OnceLock::new(); +/// Portable WASM artifact bundle produced by the WASM compiler path. +/// +/// The bundle includes the raw WASM bytes, model metadata, and a browser loader +/// source string that can instantiate the model in JavaScript. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct CompiledWasmModule { + /// Raw compiled WASM bytes. pub wasm_bytes: Vec, + /// Serialized model metadata and kernel availability. pub metadata: CompiledModelInfoEnvelope, + /// JavaScript loader source for browser-side instantiation. pub browser_loader_source: String, } @@ -52,6 +61,7 @@ struct WasmCompileCacheState { lru: VecDeque, } +/// In-memory LRU cache for repeated WASM compilation from the same DSL source. #[derive(Debug)] pub struct WasmCompileCache { capacity: usize, @@ -65,6 +75,7 @@ impl Default for WasmCompileCache { } impl WasmCompileCache { + /// Create a new compile cache with at least one entry of capacity. pub fn new(capacity: usize) -> Self { Self { capacity: capacity.max(1), @@ -72,10 +83,12 @@ impl WasmCompileCache { } } + /// Return the configured cache capacity. pub fn capacity(&self) -> usize { self.capacity } + /// Return the number of cached compiled modules. pub fn entry_count(&self) -> usize { self.state .lock() @@ -84,6 +97,7 @@ impl WasmCompileCache { .len() } + /// Remove all cached compiled modules. pub fn clear(&self) { let mut state = self .state @@ -93,6 +107,8 @@ impl WasmCompileCache { state.lru.clear(); } + /// Compile DSL source to a full WASM module bundle, reusing the cache when + /// possible. pub fn compile_module_source_to_wasm_module( &self, source: &str, @@ -108,6 +124,7 @@ impl WasmCompileCache { Ok(compiled) } + /// Compile DSL source to raw WASM bytes, reusing the cache when possible. pub fn compile_module_source_to_wasm_bytes( &self, source: &str, @@ -145,6 +162,8 @@ impl WasmCompileCache { } } +/// Error produced while compiling, inspecting, or loading a DSL-backed WASM +/// artifact. #[derive(Error)] pub enum WasmError { #[error(transparent)] @@ -224,10 +243,12 @@ impl fmt::Debug for WasmError { } } +/// Compile a lowered execution model to raw WASM bytes. pub fn compile_execution_model_to_wasm_bytes(model: &ExecutionModel) -> Result, WasmError> { emit_execution_model_to_wasm_bytes(model, WASM_API_VERSION) } +/// Compile a lowered execution model to a portable WASM bundle. pub fn compile_execution_model_to_wasm_module( model: &ExecutionModel, ) -> Result { @@ -238,6 +259,7 @@ pub fn compile_execution_model_to_wasm_module( }) } +/// Parse DSL source, lower one selected model, and return raw WASM bytes. pub fn compile_module_source_to_wasm_bytes( source: &str, model_name: Option<&str>, @@ -245,6 +267,35 @@ pub fn compile_module_source_to_wasm_bytes( Ok(compile_module_source_to_wasm_module(source, model_name)?.wasm_bytes) } +/// Parse DSL source, lower one selected model, and return the full WASM bundle. +/// +/// Use this when you want a portable artifact for browser or host-side loading +/// together with the browser loader source. +/// +/// This function requires `dsl-wasm-compile`. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{browser_loader_source, compile_module_source_to_wasm_module}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let compiled = compile_module_source_to_wasm_module(source, Some("bimodal_ke"))?; +/// let loader = browser_loader_source(); +/// # let _ = (compiled, loader); +/// # Ok::<(), pharmsol::dsl::WasmError>(()) +/// ``` pub fn compile_module_source_to_wasm_module( source: &str, model_name: Option<&str>, @@ -282,6 +333,10 @@ fn compile_module_source_to_wasm_module_uncached( compile_execution_model_to_wasm_module(&execution) } +/// Return the JavaScript loader source for browser-side WASM model execution. +/// +/// This helper is useful when you want to ship compiled WASM bytes together +/// with the minimal browser glue code that understands the pharmsol ABI. pub fn browser_loader_source() -> String { BROWSER_LOADER_SOURCE .get_or_init(build_browser_loader_source) diff --git a/src/lib.rs b/src/lib.rs index c84d4ee1..9c9e40b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,105 @@ +//! `pharmsol` is a Rust library for pharmacometric work. +//! +//! You can use it to: +//! +//! - build PK/PD datasets from dose and observation events +//! - simulate analytical, ODE, and SDE models +//! - run non-compartmental analysis (NCA) +//! - compile and run models from the pharmsol DSL when the DSL features are enabled +//! +//! Most users start in one of these places: +//! +//! - [`prelude`] for the common types, traits, and macros +//! - [`data`] to build subjects, occasions, events, and covariates +//! - [`simulator`] to define models and generate predictions +//! - [`nca`] to calculate NCA metrics from the same data structures +//! - [`optimize`] for optimizer-oriented workflows +//! +//! The DSL runtime surface is feature-gated. When you enable `dsl-core`, the +//! `pharmsol::dsl` module adds parsing, analysis, lowering, compile, and runtime +//! entrypoints for models written as DSL source text. +//! +//! ## Quick Start +//! +//! This example shows the smallest full workflow: define a model, build a +//! subject, and generate predictions. +//! +//! ```rust +//! use pharmsol::prelude::*; +//! +//! let model = analytical! { +//! name: "one_cmt_iv", +//! params: [ke, v], +//! states: [central], +//! outputs: [cp], +//! routes: [ +//! infusion(iv) -> central, +//! ], +//! structure: one_compartment, +//! out: |x, _p, _t, _cov, y| { +//! y[cp] = x[central] / v; +//! }, +//! }; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.022, 194.0])?; +//! assert_eq!(predictions.flat_predictions().len(), 2); +//! # Ok::<(), pharmsol::PharmsolError>(()) +//! ``` +//! +//! ## Choose A Workflow +//! +//! Use this guide when you are deciding where to start. +//! +//! | Task | Start Here | Notes | +//! | --- | --- | --- | +//! | Build subject data | [`data`] or [`prelude`] | Best when you already know dose times, labels, and observations. | +//! | Simulate a model written in Rust | [`simulator`] or [`prelude`] | Supports analytical, ODE, and SDE models. | +//! | Run NCA | [`nca`] or [`prelude`] | Reuses the same `Subject`, `Occasion`, and `Data` types. | +//! | Use optimization helpers | [`optimize`] | Intended for advanced workflows. | +//! | Parse or compile DSL source | `pharmsol::dsl` | Requires one or more DSL features. | +//! +//! ## Feature Guide +//! +//! Core simulation and NCA APIs do not need extra crate features on native +//! targets. +//! +//! DSL work is feature-gated: +//! +//! - `dsl-core`: exposes the `pharmsol::dsl` facade and frontend types +//! - `dsl-jit`: adds in-process JIT compilation +//! - `dsl-aot`: adds native ahead-of-time artifact compilation +//! - `dsl-aot-load`: adds native artifact loading +//! - `dsl-wasm-compile`: adds WASM artifact generation +//! - `dsl-wasm`: adds WASM runtime loading and execution +//! +//! ## Labels And Indices +//! +//! Public data APIs use route labels and output labels such as `"iv"`, +//! `"oral"`, and `"cp"`. +//! +//! Use labels in builders and parsed data unless you are deliberately working +//! with dense internal indices from a lower-level API. +//! +//! ## Platform Notes +//! +//! The main `data`, `simulator`, `nca`, and `optimize` modules are documented +//! for native targets. Some surfaces are not built on `wasm32-unknown-unknown`. +//! The DSL runtime also has feature-specific platform limits. +//! +//! ## Next Stops +//! +//! - Start with [`prelude`] if you want one import for the common workflow. +//! - Open [`data`] if you need to construct subjects or parse input files. +//! - Open [`simulator`] if you need predictions from analytical, ODE, or SDE models. +//! - Open [`nca`] if you need exposure and terminal metrics. +//! - Use `pharmsol::dsl` if the model comes from source text instead of Rust code. + #[cfg(feature = "dsl-aot")] mod build_support; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] @@ -49,19 +151,31 @@ pub use pharmsol_macros::{analytical, ode, sde}; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use std::collections::HashMap; -/// Prelude module that re-exports all commonly used types and traits. +/// Common imports for the main pharmsol workflow. +/// +/// Use the prelude when you want one import that covers the common public API: +/// +/// - subject and dataset types +/// - subject builders and events +/// - simulation types and prediction results +/// - NCA traits and option types +/// - declaration-first macros such as [`crate::ode`] and [`crate::analytical`] /// -/// Importing `pharmsol::prelude::*` brings the main modeling, simulation, -/// and data APIs into scope. +/// This is the fastest way to get started with examples, scripts, and small +/// applications. +/// +/// If you need a narrower import surface, use the modules directly instead. /// /// # Example /// ```rust /// use pharmsol::prelude::*; /// /// let subject = Subject::builder("patient_001") -/// .bolus(0.0, 100.0, 0) -/// .observation(1.0, 10.5, 0) +/// .infusion(0.0, 100.0, "iv", 1.0) +/// .missing_observation(1.0, "cp") /// .build(); +/// +/// assert_eq!(subject.id(), "patient_001"); /// ``` #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod prelude { diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs index fecab7e2..c7fbd4c9 100644 --- a/src/simulator/equation/metadata.rs +++ b/src/simulator/equation/metadata.rs @@ -1,17 +1,40 @@ -//! Shared model metadata for handwritten simulator models. +//! Metadata builders and validated metadata views for handwritten models. //! -//! This module defines the public metadata contract that handwritten ODE, -//! analytical, and SDE models can attach to. The field set is intentionally -//! aligned with the public subset of the DSL/runtime metadata surface. +//! Use this module when a handwritten [`crate::ODE`], [`crate::Analytical`], or +//! [`crate::SDE`] model should expose the same public names that appear in data +//! rows, subject builders, or parsed files. //! -//! Internal runtime layout details such as dense buffer lengths, derived buffer -//! shape, or ABI-specific offsets remain internal for now. +//! Metadata gives names to parameters, covariates, states, routes, and outputs. +//! After validation, the execution layer can resolve public labels such as +//! `"iv"` and `"cp"` against those declarations before simulation. +//! +//! Without metadata, handwritten models fall back to numeric labels. With +//! metadata, labels are matched by name. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.name(), "one_cmt"); +//! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); +//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! ``` use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind}; use std::fmt; use thiserror::Error; -/// Create a new handwritten-model metadata builder. +/// Shorthand for [`ModelMetadata::new`]. pub fn new(name: impl Into) -> ModelMetadata { ModelMetadata::new(name) } @@ -71,7 +94,17 @@ impl fmt::Display for NameDomain { } } -/// Immutable validated metadata view used by later attachment slices. +/// Validated metadata view used by the execution layer. +/// +/// This type is what handwritten equation builders store after metadata has +/// passed validation. It provides stable lookup helpers from public names to the +/// dense indices used during execution. +/// +/// Route lookups expose two different indices: +/// - [`ValidatedModelMetadata::route_declaration_index`] is the route position in +/// declaration order. +/// - [`ValidatedModelMetadata::route_index`] is the dense execution input index +/// for that route kind. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedModelMetadata { name: String, @@ -87,10 +120,12 @@ pub struct ValidatedModelMetadata { } impl ValidatedModelMetadata { + /// Get the public model name. pub fn name(&self) -> &str { &self.name } + /// Get the validated model family. pub fn kind(&self) -> ModelKind { self.kind } @@ -111,6 +146,9 @@ impl ValidatedModelMetadata { &self.routes } + /// Get the number of dense execution input slots needed for routes. + /// + /// This is the maximum of the bolus-route count and infusion-route count. pub fn route_input_count(&self) -> usize { self.route_input_count } @@ -143,14 +181,17 @@ impl ValidatedModelMetadata { self.states.iter().position(|state| state.name() == name) } + /// Look up a route by public name and return its dense execution input index. pub fn route_index(&self, name: &str) -> Option { self.route(name).map(ValidatedRoute::input_index) } + /// Look up a route by public name and return its declaration-order index. pub fn route_declaration_index(&self, name: &str) -> Option { self.routes.iter().position(|route| route.name() == name) } + /// Look up an output by public name and return its dense output index. pub fn output_index(&self, name: &str) -> Option { self.outputs.iter().position(|output| output.name() == name) } @@ -179,7 +220,11 @@ impl ValidatedModelMetadata { } } -/// One validated route declaration with resolved destination state index. +/// One validated route declaration with resolved execution details. +/// +/// A validated route keeps both the declaration-order index and the dense input +/// index used during execution. Those values can differ from each other when a +/// model mixes bolus and infusion routes. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedRoute { name: String, @@ -194,6 +239,7 @@ pub struct ValidatedRoute { } impl ValidatedRoute { + /// Get the public route name used for label matching. pub fn name(&self) -> &str { &self.name } @@ -202,18 +248,22 @@ impl ValidatedRoute { self.kind } + /// Get the declaration-order index for this route. pub fn declaration_index(&self) -> usize { self.declaration_index } + /// Get the dense execution input index for this route kind. pub fn input_index(&self) -> usize { self.input_index } + /// Get the destination state name. pub fn destination(&self) -> &str { &self.destination } + /// Get the destination state index in model order. pub fn destination_index(&self) -> usize { self.destination_index } @@ -231,7 +281,12 @@ impl ValidatedRoute { } } -/// Metadata describing one handwritten simulator model. +/// Builder for handwritten model metadata. +/// +/// Use [`ModelMetadata`] to declare the public names that should be attached to +/// a handwritten equation. After validation, the resulting metadata can be +/// attached to handwritten [`crate::ODE`], [`crate::Analytical`], and +/// [`crate::SDE`] models through their `with_metadata(...)` methods. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ModelMetadata { name: String, @@ -379,11 +434,17 @@ impl ModelMetadata { } /// Validate this metadata using its declared kind. + /// + /// Use this when the metadata itself already declares whether the model is + /// ODE, analytical, or SDE. pub fn validate(self) -> Result { self.validate_internal(None, None) } /// Validate this metadata for a specific model kind. + /// + /// Use this when the equation type determines the model family and you want + /// validation to enforce that family explicitly. pub fn validate_for( self, kind: ModelKind, @@ -440,6 +501,7 @@ pub struct Parameter { } impl Parameter { + /// Create a named parameter declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -466,6 +528,7 @@ pub struct Covariate { } impl Covariate { + /// Create a named covariate without an explicit interpolation policy. pub fn new(name: impl Into) -> Self { Self { name: name.into(), @@ -473,14 +536,17 @@ impl Covariate { } } + /// Create a continuous covariate that uses linear interpolation. pub fn continuous(name: impl Into) -> Self { Self::new(name).with_interpolation(CovariateInterpolation::Linear) } + /// Create a covariate that uses last-observation-carried-forward semantics. pub fn locf(name: impl Into) -> Self { Self::new(name).with_interpolation(CovariateInterpolation::Locf) } + /// Set the interpolation policy explicitly. pub fn with_interpolation(mut self, interpolation: CovariateInterpolation) -> Self { self.interpolation = Some(interpolation); self @@ -502,6 +568,7 @@ pub struct State { } impl State { + /// Create a named state declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -527,6 +594,7 @@ pub struct Output { } impl Output { + /// Create a named output declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -548,18 +616,25 @@ where /// Route declaration kind. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteKind { + /// Instantaneous dose input. Bolus, + /// Dose input over a duration. Infusion, } /// How route inputs should be interpreted by the execution layer. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteInputPolicy { + /// Inject the resolved input directly into the declared destination state. InjectToDestination, + /// Expect the low-level execution path to provide an explicit input vector. ExplicitInputVector, } /// One named route declaration. +/// +/// Route names are the public labels matched against dose events such as `iv` +/// or `oral`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Route { name: String, @@ -571,14 +646,17 @@ pub struct Route { } impl Route { + /// Create a named bolus route declaration. pub fn bolus(name: impl Into) -> Self { Self::new(name, RouteKind::Bolus) } + /// Create a named infusion route declaration. pub fn infusion(name: impl Into) -> Self { Self::new(name, RouteKind::Infusion) } + /// Create a route declaration with an explicit kind. pub fn new(name: impl Into, kind: RouteKind) -> Self { Self { name: name.into(), @@ -590,26 +668,31 @@ impl Route { } } + /// Declare which state this route targets. pub fn to_state(mut self, destination: impl Into) -> Self { self.destination = Some(destination.into()); self } + /// Mark this route as supporting lag handling. pub fn with_lag(mut self) -> Self { self.has_lag = true; self } + /// Mark this route as supporting bioavailability handling. pub fn with_bioavailability(mut self) -> Self { self.has_bioavailability = true; self } + /// Request direct injection into the destination state at execution time. pub fn inject_input_to_destination(mut self) -> Self { self.input_policy = Some(RouteInputPolicy::InjectToDestination); self } + /// Request an explicit low-level input vector at execution time. pub fn expect_explicit_input(mut self) -> Self { self.input_policy = Some(RouteInputPolicy::ExplicitInputVector); self diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index c5a97958..03e5318c 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -1,3 +1,51 @@ +//! Handwritten equation families and their shared simulation interfaces. +//! +//! This module is the public home for handwritten [`ODE`], [`Analytical`], and +//! [`SDE`] models, plus the shared [`Equation`] trait and the metadata types +//! that attach public names to parameters, states, routes, and outputs. +//! +//! Use this module when you want to: +//! - choose between deterministic ODE, analytical, and stochastic SDE models +//! - attach metadata so dataset labels such as `"iv"` and `"cp"` resolve by +//! name instead of by dense numeric index +//! - work with prediction or likelihood APIs across equation families +//! +//! # Equation Families +//! +//! - [`ODE`] for deterministic models that must be numerically integrated. +//! - [`Analytical`] for supported closed-form models. +//! - [`SDE`] for stochastic models that use particles. +//! +//! # Labels And Metadata +//! +//! Input and output labels arrive from public data APIs as strings. +//! +//! - Without metadata, handwritten models fall back to numeric labels such as +//! `0` or `1`. +//! - With [`metadata::ModelMetadata`] attached, route and output labels are +//! resolved by name against the declared routes and outputs before +//! simulation. +//! +//! That label-first path is the preferred public workflow for current authoring. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.route_index("iv"), Some(0)); +//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! ``` + use std::fmt::Debug; pub mod analytical; pub mod metadata; @@ -20,10 +68,10 @@ use super::likelihood::Prediction; /// Trait for state vectors that can receive bolus doses. pub trait State { - /// Add a bolus dose to the state at the specified input compartment. + /// Add a bolus dose to the state at the specified resolved input index. /// /// # Parameters - /// - `input`: The compartment index + /// - `input`: The resolved dense input index used by the execution layer /// - `amount`: The bolus amount fn add_bolus(&mut self, input: usize, amount: f64); } @@ -114,7 +162,7 @@ pub trait Cache: Sized { fn disable_cache(self) -> Self; } -/// Trait defining the associated types for equations. +/// Associated state and prediction container types for an equation family. pub trait EquationTypes { /// The state vector type type S: State + Debug; @@ -308,11 +356,15 @@ pub(crate) trait EquationPriv: EquationTypes { } } -/// Trait for model equations that can be simulated. +/// Trait for handwritten model equations that can be simulated. +/// +/// [`Equation`] is the shared interface implemented by handwritten [`ODE`], +/// [`Analytical`], and [`SDE`] models. /// -/// This trait defines the interface for different types of model equations -/// (ODE, SDE, analytical) that can be simulated to generate predictions -/// and estimate parameters. +/// Subject data enters this layer through public labels on dose and observation +/// events. If metadata is attached to the equation, those labels are resolved by +/// name before simulation. Otherwise, the execution layer expects numeric labels +/// that can be interpreted as dense indices. /// /// # Likelihood Calculation /// @@ -440,6 +492,7 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { } } +/// Runtime family tag for handwritten equations. #[repr(C)] #[derive(Clone, Debug)] pub enum EqnKind { From ddc0b1982ec2674f493c468da6827b29932041d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 8 May 2026 09:35:42 +0100 Subject: [PATCH 19/22] feat: Answering Markus' comments --- examples/analytical_vs_ode.rs | 44 +- examples/covariates.rs | 18 +- examples/macro_vs_handwritten_one_cpt.rs | 9 - examples/macro_vs_handwritten_two_cpt.rs | 22 +- examples/one_compartment.rs | 33 +- examples/two_compartment.rs | 24 +- pharmsol-dsl/README.md | 81 ++-- pharmsol-dsl/src/authoring.rs | 137 ++++++- pharmsol-dsl/src/execution.rs | 63 +++ pharmsol-dsl/src/lib.rs | 70 +++- pharmsol-dsl/src/semantic.rs | 120 ++++++ .../tests/dsl_authoring_edge_cases.rs | 229 ++++++++++- src/data/error_model.rs | 387 +++++++++++++++--- src/data/event.rs | 245 +++++++---- src/dsl/aot.rs | 51 ++- src/dsl/jit.rs | 80 +++- src/dsl/model_info.rs | 38 ++ src/dsl/native.rs | 108 +++-- src/dsl/runtime.rs | 205 +++++++--- src/simulator/equation/analytical/mod.rs | 88 +++- src/simulator/equation/metadata.rs | 13 +- src/simulator/equation/mod.rs | 76 +++- src/simulator/equation/ode/mod.rs | 100 +++-- src/simulator/equation/sde/mod.rs | 70 +++- src/simulator/likelihood/mod.rs | 10 +- src/simulator/likelihood/prediction.rs | 2 +- src/simulator/likelihood/subject.rs | 4 +- tests/analytical_macro_lowering.rs | 24 +- tests/authoring_parity_corpus.rs | 242 +++++++---- tests/full_feature_dsl_backend_parity.rs | 8 - tests/full_feature_macro_parity.rs | 48 ++- tests/ode_macro_lowering.rs | 75 ++-- tests/ode_optimizations.rs | 2 +- tests/sde_macro_lowering.rs | 17 +- tests/support/bimodal_ke.rs | 18 +- tests/support/runtime_corpus.rs | 115 ++++-- tests/test_pf.rs | 2 +- 37 files changed, 2200 insertions(+), 678 deletions(-) diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 3fd58fd1..48c4986f 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -13,28 +13,28 @@ use pharmsol::prelude::*; // ── Subjects ─────────────────────────────────────────────────────── -fn subject_iv(input: usize, output: usize) -> Subject { +fn subject_iv(input: impl ToString, output: impl ToString) -> Subject { Subject::builder("1") .infusion(0.0, 500.0, input, 0.5) - .observation(0.5, 0.0, output) - .observation(1.0, 0.0, output) - .observation(2.0, 0.0, output) - .observation(4.0, 0.0, output) - .observation(8.0, 0.0, output) - .observation(12.0, 0.0, output) + .observation(0.5, 0.0, output.to_string()) + .observation(1.0, 0.0, output.to_string()) + .observation(2.0, 0.0, output.to_string()) + .observation(4.0, 0.0, output.to_string()) + .observation(8.0, 0.0, output.to_string()) + .observation(12.0, 0.0, output.to_string()) .observation(24.0, 0.0, output) .build() } -fn subject_oral(input: usize, output: usize) -> Subject { +fn subject_oral(input: impl ToString, output: impl ToString) -> Subject { Subject::builder("1") .bolus(0.0, 500.0, input) - .observation(0.5, 0.0, output) - .observation(1.0, 0.0, output) - .observation(2.0, 0.0, output) - .observation(4.0, 0.0, output) - .observation(8.0, 0.0, output) - .observation(12.0, 0.0, output) + .observation(0.5, 0.0, output.to_string()) + .observation(1.0, 0.0, output.to_string()) + .observation(2.0, 0.0, output.to_string()) + .observation(4.0, 0.0, output.to_string()) + .observation(8.0, 0.0, output.to_string()) + .observation(12.0, 0.0, output.to_string()) .observation(24.0, 0.0, output) .build() } @@ -97,9 +97,7 @@ fn one_cmt_iv(params: &[f64]) { }, }; - let iv = analytical.route_index("iv").expect("iv route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - let subject = subject_iv(iv, cp); + let subject = subject_iv("iv", "cp"); let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); let pred_o = ode.estimate_predictions(&subject, params).unwrap(); @@ -140,9 +138,7 @@ fn one_cmt_oral(params: &[f64]) { }, }; - let oral = analytical.route_index("oral").expect("oral route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - let subject = subject_oral(oral, cp); + let subject = subject_oral("oral", "cp"); let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); let pred_o = ode.estimate_predictions(&subject, params).unwrap(); @@ -183,9 +179,7 @@ fn two_cmt_iv(params: &[f64]) { }, }; - let iv = analytical.route_index("iv").expect("iv route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - let subject = subject_iv(iv, cp); + let subject = subject_iv("iv", "cp"); let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); let pred_o = ode.estimate_predictions(&subject, params).unwrap(); @@ -227,9 +221,7 @@ fn two_cmt_oral(params: &[f64]) { }, }; - let oral = analytical.route_index("oral").expect("oral route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - let subject = subject_oral(oral, cp); + let subject = subject_oral("oral", "cp"); let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); let pred_o = ode.estimate_predictions(&subject, params).unwrap(); diff --git a/examples/covariates.rs b/examples/covariates.rs index 180a0173..83516ebb 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -27,22 +27,18 @@ fn main() { }, }; - let oral = ode.route_index("oral").expect("oral route exists"); - let cp = ode.output_index("cp").expect("cp output exists"); - - // Create a subject with metadata-backed route and output names instead of - // hard-coded numeric indices. + // Create a subject using route and output labels directly. let subject = Subject::builder("id1") - .bolus(0.0, 100.0, oral) + .bolus(0.0, 100.0, "oral") .repeat(2, 2.0) - .observation(0.5, 0.1, cp) - .observation(1.0, 0.4, cp) - .observation(2.0, 1.0, cp) - .observation(2.5, 1.1, cp) + .observation(0.5, 0.1, "cp") + .observation(1.0, 0.4, "cp") + .observation(2.0, 1.0, "cp") + .observation(2.5, 1.1, "cp") .covariate("creatinine", 0.0, 80.0) .covariate("creatinine", 1.0, 40.0) .covariate("age", 0.0, 25.0) - .missing_observation(8.0, cp) + .missing_observation(8.0, "cp") .build(); // Define parameter values diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index c7b088a5..5e8ec9ec 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -26,9 +26,6 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( - // Handwritten closures stay on dense internal slots. - // Public labels like `iv` and `cp` live in attached metadata, not in - // the low-level `rateiv[]` / `y[]` buffers. |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -71,12 +68,6 @@ fn main() -> Result<(), pharmsol::PharmsolError> { assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - let iv = macro_ode.route_index("iv").expect("iv route exists"); - let cp = macro_ode.output_index("cp").expect("cp output exists"); - - assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); - assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); - let subject = Subject::builder("macro-vs-handwritten-one-cpt") .infusion(0.0, 500.0, "iv", 0.5) .missing_observation(0.5, "cp") diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 377e1e88..d3c10a0f 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -29,10 +29,6 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( - // Handwritten closures stay on dense internal slots. - // Public route labels like `load` and `iv` are metadata names; the - // low-level `bolus[]`, `rateiv[]`, and `y[]` buffers remain indexed by - // dense internal slots. |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; @@ -76,17 +72,17 @@ fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { fn main() -> Result<(), pharmsol::PharmsolError> { let macro_ode = macro_model(); let handwritten_ode = handwritten_model(); + let macro_metadata = macro_ode.metadata().expect("macro metadata exists"); assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - - let load = macro_ode.route_index("load").expect("load route exists"); - let iv = macro_ode.route_index("iv").expect("iv route exists"); - let cp = macro_ode.output_index("cp").expect("cp output exists"); - - assert_eq!(load, iv, "load and iv should share one numeric input"); - assert_eq!(handwritten_ode.route_index("load"), Some(load)); - assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); - assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + assert_eq!( + macro_metadata + .route("load") + .map(|route| route.input_index()), + macro_metadata.route("iv").map(|route| route.input_index()), + "load and iv should share one numeric input" + ); + assert!(macro_metadata.output("cp").is_some()); let subject = Subject::builder("macro-vs-handwritten-two-cpt") .bolus(0.0, 100.0, "load") diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index 021e06f2..e5813e2a 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -31,28 +31,23 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = analytical.route_index("iv").expect("iv route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - - // Create a subject using metadata-backed route and output names instead of - // hard-coded indices. + // Create a subject using route and output labels directly. let subject = Subject::builder("Nikola Tesla") - .infusion(0., 500.0, iv, 0.5) - .observation(0.5, 1.645, cp) - .observation(1., 1.216, cp) - .observation(2., 0.462, cp) - .observation(3., 0.169, cp) - .observation(4., 0.063, cp) - .observation(6., 0.009, cp) - .observation(8., 0.001, cp) - .missing_observation(12.0, cp) + .infusion(0., 500.0, "iv", 0.5) + .observation(0.5, 1.645, "cp") + .observation(1., 1.216, "cp") + .observation(2., 0.462, "cp") + .observation(3., 0.169, "cp") + .observation(4., 0.063, "cp") + .observation(6., 0.009, "cp") + .observation(8., 0.001, "cp") + .missing_observation(12.0, "cp") .build(); - // Define the error models for the observations - let ems = AssayErrorModels::new(). - // For this example, we use a simple additive error model with 5% error - add( - cp, + // Define the assay error models once by label and reuse them across both + // equations. + let ems = AssayErrorModels::new().add( + "cp", AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), )?; diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index fdba715e..e6f44e32 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -65,22 +65,18 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = ode.route_index("iv").expect("iv route exists"); - let cp = ode.output_index("cp").expect("cp output exists"); - - // Create a subject using metadata-backed route and output names instead of - // hard-coded numeric indices. + // Create a subject using route and output labels directly. let subject = Subject::builder("subject_001") - .infusion(0.0, 500.0, iv, 0.5) + .infusion(0.0, 500.0, "iv", 0.5) .covariate("wt", 0.0, 70.0) - .observation(0.5, 8.5, cp) - .observation(1.0, 6.2, cp) - .observation(2.0, 4.1, cp) - .observation(4.0, 2.3, cp) - .observation(6.0, 1.5, cp) - .observation(8.0, 1.1, cp) - .observation(12.0, 0.7, cp) - .missing_observation(24.0, cp) + .observation(0.5, 8.5, "cp") + .observation(1.0, 6.2, "cp") + .observation(2.0, 4.1, "cp") + .observation(4.0, 2.3, "cp") + .observation(6.0, 1.5, "cp") + .observation(8.0, 1.1, "cp") + .observation(12.0, 0.7, "cp") + .missing_observation(24.0, "cp") .build(); // Define parameter values diff --git a/pharmsol-dsl/README.md b/pharmsol-dsl/README.md index e3a2469c..0280727b 100644 --- a/pharmsol-dsl/README.md +++ b/pharmsol-dsl/README.md @@ -1,50 +1,65 @@ # pharmsol-dsl -`pharmsol-dsl` is the extraction target for the backend-neutral frontend of the pharmsol DSL. +`pharmsol-dsl` is the backend-neutral frontend crate for the pharmsol DSL. -The crate is introduced as an internal workspace member so the codebase can move to a clean engine / DSL split without duplicating workflows or breaking the current `pharmsol::dsl` user-facing API mid-migration. +Use this crate when you need to work with model source as data: -## Current Status +- parse DSL text into syntax nodes +- inspect spans and diagnostics +- analyze names and types into typed IR +- lower validated models into the execution model used by runtime backends -Slices 1 through 7 have moved the shared frontend data modules, parsing frontend, semantic analysis, and execution lowering here, rewired `pharmsol` backend modules to consume that frontend directly, and cleaned up frontend test ownership: +Do not use this crate for JIT compilation, native AoT export or load, WASM runtime loading, or `Subject`-based prediction helpers. Those workflows stay in `pharmsol::dsl` in the main `pharmsol` crate. -- AST types -- diagnostics and spans -- typed IR -- lexer -- parser -- authoring desugaring used by the parser -- semantic analysis and semantic diagnostics -- execution lowering and execution model types -- frontend-only integration tests and authoring fixtures +## Main Pipeline -`pharmsol::dsl` now acts as a deliberate compatibility façade: it re-exports the frontend surface from this crate while keeping runtime compilation, artifact loading, and execution wrappers in `pharmsol`. +The public pipeline is: -## Planned Ownership +1. `parse_model` or `parse_module` +2. `analyze_model` or `analyze_module` +3. `lower_typed_model` or `lower_typed_module` -The crate will own the backend-neutral frontend pipeline: +The main public modules are: -- AST and syntax types -- authoring desugaring -- diagnostics and spans -- lexical analysis -- parse entrypoints -- typed IR -- execution IR and lowering -- parse / analyze / lower entrypoints +- `ast` for syntax-level nodes +- `diagnostic` for spans, codes, and rendered reports +- `ir` for the typed intermediate representation +- `execution` for the lowered execution model shared by JIT, AoT, and WASM backends -The crate will not own runtime-facing APIs such as JIT, AoT, WASM loading, or `Subject`-based prediction wrappers in the initial extraction. +The parser accepts both canonical `model { ... }` source and the authoring shorthand used by the `pharmsol` examples. -## Migration Rule +## Small Example -Until the move is complete, `pharmsol::dsl` remains the compatibility façade. +```rust +use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model}; -That means: +let source = r#" +name = bimodal_ke +kind = ode -- backend code continues to live in `pharmsol` -- frontend modules move here slice by slice -- user-facing import churn is deferred until the architecture is stable +params = ke, v +states = central +outputs = cp -## Transitional Note +infusion(iv) -> central -The temporary lexer bridge from Slice 1 is gone. Frontend-only authoring fixtures now live under `pharmsol-dsl/tests/fixtures/dsl`, while the shared structured-block corpus remains under `tests/fixtures/dsl` because both crates still consume it. +dx(central) = -ke * central +out(cp) = central / v +"#; + +let syntax = parse_model(source).expect("model parses"); +let typed = analyze_model(&syntax).expect("model analyzes"); +let execution = lower_typed_model(&typed).expect("model lowers"); + +assert_eq!(execution.name, "bimodal_ke"); +assert_eq!(execution.metadata.routes.len(), 1); +assert_eq!(execution.metadata.outputs.len(), 1); +``` + +## Boundary With `pharmsol` + +`pharmsol-dsl` owns the frontend pipeline and its data structures. + +`pharmsol::dsl` re-exports that frontend surface and adds the runtime-facing APIs for backend selection, artifact loading, and prediction execution. + +Use `pharmsol-dsl` when you are building tooling, validation, migration, or your own backend. Use `pharmsol::dsl` when you want a complete source-to-runtime workflow. diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 0496c0fc..153760a0 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -5,6 +5,8 @@ use super::diagnostic::{Applicability, DiagnosticSuggestion, ParseError, Span, T use super::parser::{parse_expr_fragment, parse_place_fragment}; const DEFAULT_MODEL_NAME: &str = "main"; +const NUMERIC_ROUTE_PREFIX: &str = "input_"; +const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; pub(super) fn parse_module(src: &str) -> Result { AuthoringParser::new(src).parse_module() @@ -484,7 +486,7 @@ impl<'a> AuthoringParser<'a> { } }; - let input = parse_label_segment(call.argument, call.argument_start, "route label")?; + let input = parse_route_label_segment(call.argument, call.argument_start)?; let route_name = input.text.clone(); let destination = parse_place_at(rhs, line_start + arrow + 2)?; if self.routes.contains_key(&route_name) { @@ -515,8 +517,7 @@ impl<'a> AuthoringParser<'a> { ) -> Result<(), ParseError> { match call.callee.text.as_str() { "lag" | "fa" => { - let route_name = - parse_label_segment(call.argument, call.argument_start, "route label")?; + let route_name = parse_route_label_segment(call.argument, call.argument_start)?; let value = parse_expr_at(rhs, rhs_abs)?; let property_name = match call.callee.text.as_str() { "lag" => "lag", @@ -946,28 +947,140 @@ fn parse_ident_segment(src: &str, abs_start: usize) -> Result } fn parse_output_label_segment(src: &str, abs_start: usize) -> Result { - parse_label_segment(src, abs_start, "output label") + parse_label_segment(src, abs_start, LabelKind::Output) } -fn parse_label_segment(src: &str, abs_start: usize, expected: &str) -> Result { +fn parse_route_label_segment(src: &str, abs_start: usize) -> Result { + parse_label_segment(src, abs_start, LabelKind::Route) +} + +fn parse_label_segment(src: &str, abs_start: usize, kind: LabelKind) -> Result { let trimmed = src.trim(); let leading = src.len() - src.trim_start().len(); + let span = Span::new(abs_start + leading, abs_start + leading + trimmed.len()); if trimmed.is_empty() { return Err(ParseError::new( - format!("expected {expected}"), + format!("expected {}", kind.expected()), Span::new(abs_start, abs_start + src.len()), )); } if !is_valid_output_label(trimmed) { return Err(ParseError::new( - format!("expected {expected}, found `{trimmed}`"), - Span::new(abs_start + leading, abs_start + leading + trimmed.len()), + format!("expected {}, found `{trimmed}`", kind.expected()), + span, )); } - Ok(Ident::new( - trimmed, - Span::new(abs_start + leading, abs_start + leading + trimmed.len()), - )) + + if let Some(suffix) = bare_numeric_label(trimmed) { + let replacement = kind.canonical_label(suffix); + return Err(ParseError::new( + format!( + "bare numeric {} labels are not allowed in the DSL; use `{replacement}` instead", + kind.noun() + ), + span, + ) + .with_help(format!( + "numeric {} labels must use the `{}` prefix in authored DSL", + kind.noun(), + kind.prefix_pattern() + )) + .with_suggestion(DiagnosticSuggestion { + message: format!("use `{replacement}`"), + edits: vec![TextEdit { span, replacement }], + applicability: Applicability::Always, + })); + } + + if let Some(suffix) = canonical_numeric_suffix(trimmed, kind.wrong_prefix()) { + let replacement = kind.canonical_label(suffix); + return Err(ParseError::new( + format!( + "`{trimmed}` is {} label and cannot be used as {}; use `{replacement}` here", + kind.wrong_kind_phrase(), + kind.noun_phrase() + ), + span, + ) + .with_help(format!( + "numeric {} labels use the `{}` prefix", + kind.noun(), + kind.prefix_pattern() + )) + .with_suggestion(DiagnosticSuggestion { + message: format!("use `{replacement}`"), + edits: vec![TextEdit { span, replacement }], + applicability: Applicability::Always, + })); + } + + Ok(Ident::new(trimmed, span)) +} + +#[derive(Clone, Copy)] +enum LabelKind { + Route, + Output, +} + +impl LabelKind { + fn expected(self) -> &'static str { + match self { + Self::Route => "route label", + Self::Output => "output label", + } + } + + fn noun(self) -> &'static str { + match self { + Self::Route => "route", + Self::Output => "output", + } + } + + fn noun_phrase(self) -> &'static str { + match self { + Self::Route => "a route", + Self::Output => "an output", + } + } + + fn wrong_kind_phrase(self) -> &'static str { + match self { + Self::Route => "an output", + Self::Output => "a route", + } + } + + fn canonical_label(self, suffix: &str) -> String { + match self { + Self::Route => format!("{NUMERIC_ROUTE_PREFIX}{suffix}"), + Self::Output => format!("{NUMERIC_OUTPUT_PREFIX}{suffix}"), + } + } + + fn wrong_prefix(self) -> &'static str { + match self { + Self::Route => NUMERIC_OUTPUT_PREFIX, + Self::Output => NUMERIC_ROUTE_PREFIX, + } + } + + fn prefix_pattern(self) -> &'static str { + match self { + Self::Route => "input_", + Self::Output => "outeq_", + } + } +} + +fn bare_numeric_label(src: &str) -> Option<&str> { + (!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src) +} + +fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> { + let suffix = src.strip_prefix(prefix)?; + (!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix) } fn parse_place_at(src: &str, abs_start: usize) -> Result { diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 8bac1d69..2384432a 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -1553,6 +1553,69 @@ out(cp) = central / v ~ continuous() assert!(!lowered.metadata.routes[1].has_bioavailability); } + #[test] + fn canonical_numeric_channel_names_flow_into_execution_metadata_and_abi() { + let src = r#"name = canonical_numeric_channels +kind = ode + +params = ke, v +states = depot, central +outputs = cp, outeq_2 + +bolus(input_10) -> depot +infusion(iv) -> central + +dx(depot) = -ke * depot +dx(central) = rate(input_10) - ke * central + +out(cp) = central / v +out(outeq_2) = depot / v +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let lowered = crate::lower_typed_model(&typed).expect("authoring model lowers"); + + assert_eq!( + lowered + .metadata + .routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["input_10", "iv"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "outeq_2"] + ); + assert_eq!( + lowered + .abi + .route_buffer + .slots + .iter() + .map(|slot| slot.name.as_str()) + .collect::>(), + vec!["input_10", "iv"] + ); + assert_eq!( + lowered + .abi + .output_buffer + .slots + .iter() + .map(|slot| slot.name.as_str()) + .collect::>(), + vec!["cp", "outeq_2"] + ); + } + #[test] fn authoring_routes_reject_infusion_lag_properties() { let src = r#"name = invalid_infusion_lag diff --git a/pharmsol-dsl/src/lib.rs b/pharmsol-dsl/src/lib.rs index 83bd9a7c..351b041c 100644 --- a/pharmsol-dsl/src/lib.rs +++ b/pharmsol-dsl/src/lib.rs @@ -1,8 +1,72 @@ //! Backend-neutral frontend crate for the pharmsol DSL. //! -//! This crate owns parsing, diagnostics, authoring desugaring, semantic -//! analysis, and execution lowering for DSL modules. `pharmsol::dsl` -//! re-exports the stable runtime-facing surface in the main crate. +//! Use this crate when you need the DSL frontend as an engineering API: parse +//! model source, inspect diagnostics, analyze names and types, and lower a +//! validated model into the execution representation that backends consume. +//! +//! Do not use this crate when you already know you want JIT compilation, +//! native AoT artifacts, WASM artifacts, or `Subject`-based prediction +//! helpers. Those runtime-facing workflows stay in the main `pharmsol` crate +//! under `pharmsol::dsl`. +//! +//! Main entrypoints: +//! +//! - [`parse_model`] and [`parse_module`] for turning DSL source text into the +//! syntax tree in [`ast`]. +//! - [`analyze_model`] and [`analyze_module`] for semantic validation and the +//! typed IR in [`ir`]. +//! - [`lower_typed_model`] and [`lower_typed_module`] for lowering typed models +//! into the execution representation in [`execution`]. +//! +//! The frontend pipeline is intentionally simple: +//! +//! 1. Parse source text into syntax. +//! 2. Analyze the syntax into a typed model. +//! 3. Lower the typed model into an [`ExecutionModel`] or [`ExecutionModule`]. +//! +//! This crate accepts both canonical `model { ... }` source and the authoring +//! shorthand used by the `pharmsol` examples. The returned diagnostics carry +//! source spans, rendered messages, and structured data for editor or UI use. +//! +//! Main modules: +//! +//! - [`ast`] for syntax-level nodes. +//! - [`diagnostic`] for spans, diagnostic codes, and rendered reports. +//! - [`ir`] for the typed intermediate representation. +//! - [`execution`] for the lowered execution model shared by JIT, AoT, and +//! WASM backends. +//! +//! Smallest parse-analyze-lower example: +//! +//! ```rust +//! use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model}; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let syntax = parse_model(source).expect("model parses"); +//! let typed = analyze_model(&syntax).expect("model analyzes"); +//! let execution = lower_typed_model(&typed).expect("model lowers"); +//! +//! assert_eq!(execution.name, "bimodal_ke"); +//! assert_eq!(execution.metadata.routes.len(), 1); +//! assert_eq!(execution.metadata.outputs.len(), 1); +//! ``` +//! +//! If you are building an authoring tool, custom compiler, or diagnostics UI, +//! stay in this crate. If you want a complete source-to-runtime workflow, +//! switch to `pharmsol::dsl` in the main crate. pub mod ast; mod authoring; diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index f20288a9..2f043ac1 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -38,6 +38,8 @@ const RESERVED_NAMES: &[&str] = &[ ]; const RATE_FUNCTION_NAME: &str = "rate"; +const NUMERIC_ROUTE_PREFIX: &str = "input_"; +const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; #[derive(Default)] struct SemanticAssist { @@ -572,6 +574,7 @@ impl<'a> Analyzer<'a> { let mut routes = Vec::new(); if let Some(block) = block { for route in &block.routes { + self.validate_route_label_name(&route.input)?; let id = self.insert_global_symbol( &route.input.text, SymbolKind::Route, @@ -649,6 +652,9 @@ impl<'a> Analyzer<'a> { collect_bare_assignment_names(statements, &mut seen, &mut collected_idents); let mut symbols = Vec::new(); for ident in collected_idents { + if matches!(kind, SymbolKind::Output) { + self.validate_output_label_name(&ident)?; + } let id = self.insert_global_symbol( &ident.text, kind, @@ -1318,12 +1324,18 @@ impl<'a> Analyzer<'a> { callee.span, )); } + if let syntax::ExprKind::Number(value) = &args[0].kind { + if let Some(suffix) = numeric_label_literal_suffix(*value) { + return Err(self.bare_numeric_route_error(args[0].span, &suffix)); + } + } let syntax::ExprKind::Name(route_name) = &args[0].kind else { return Err(SemanticError::new( "`rate` expects a route identifier argument", args[0].span, )); }; + self.validate_route_label_name(route_name)?; let route = self .globals .routes @@ -1656,6 +1668,100 @@ impl<'a> Analyzer<'a> { Ok(id) } + fn validate_route_label_name(&self, label: &syntax::Ident) -> Result<(), SemanticError> { + if let Some(suffix) = bare_numeric_label(&label.text) { + return Err(self.bare_numeric_route_error(label.span, suffix)); + } + if let Some(suffix) = canonical_numeric_suffix(&label.text, NUMERIC_OUTPUT_PREFIX) { + return Err(self.wrong_prefix_route_error(label, suffix)); + } + Ok(()) + } + + fn validate_output_label_name(&self, label: &syntax::Ident) -> Result<(), SemanticError> { + if let Some(suffix) = bare_numeric_label(&label.text) { + return Err(self.bare_numeric_output_error(label.span, suffix)); + } + if let Some(suffix) = canonical_numeric_suffix(&label.text, NUMERIC_ROUTE_PREFIX) { + return Err(self.wrong_prefix_output_error(label, suffix)); + } + Ok(()) + } + + fn bare_numeric_route_error(&self, span: Span, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_ROUTE_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric route labels must use the `input_` form in authored DSL") + .replacement_suggestion( + span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "bare numeric route labels are not allowed in the DSL; use `{replacement}` instead" + ), + span, + )) + } + + fn bare_numeric_output_error(&self, span: Span, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_OUTPUT_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric output labels must use the `outeq_` form in authored DSL") + .replacement_suggestion( + span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "bare numeric output labels are not allowed in the DSL; use `{replacement}` instead" + ), + span, + )) + } + + fn wrong_prefix_route_error(&self, label: &syntax::Ident, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_ROUTE_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric route labels use the `input_` prefix") + .replacement_suggestion( + label.span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "`{}` is an output label and cannot be used as a route; use `{replacement}` here", + label.text + ), + label.span, + )) + } + + fn wrong_prefix_output_error(&self, label: &syntax::Ident, suffix: &str) -> SemanticError { + let replacement = format!("{NUMERIC_OUTPUT_PREFIX}{suffix}"); + SemanticAssist::default() + .help("numeric output labels use the `outeq_` prefix") + .replacement_suggestion( + label.span, + replacement.clone(), + format!("use `{replacement}`"), + Applicability::Always, + ) + .apply(SemanticError::new( + format!( + "`{}` is a route label and cannot be used as an output target; use `{replacement}` here", + label.text + ), + label.span, + )) + } + fn insert_local_symbol( &mut self, env: &mut BlockEnv, @@ -2142,6 +2248,20 @@ fn allows_route_output_name_overlap(existing: SymbolKind, new: SymbolKind) -> bo ) } +fn bare_numeric_label(src: &str) -> Option<&str> { + (!src.is_empty() && src.chars().all(|ch| ch.is_ascii_digit())).then_some(src) +} + +fn canonical_numeric_suffix<'a>(src: &'a str, prefix: &str) -> Option<&'a str> { + let suffix = src.strip_prefix(prefix)?; + (!suffix.is_empty() && suffix.chars().all(|ch| ch.is_ascii_digit())).then_some(suffix) +} + +fn numeric_label_literal_suffix(value: f64) -> Option { + (value.is_finite() && value >= 0.0 && value.fract() == 0.0 && value <= usize::MAX as f64) + .then(|| (value as usize).to_string()) +} + #[derive(Default)] struct Globals { all_names: BTreeMap, diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 4d1651f5..335a7f86 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -162,18 +162,18 @@ out(cp) = central ~ continous() } #[test] -fn mixed_named_and_numeric_output_labels_lower_and_round_trip() { +fn mixed_named_and_prefixed_numeric_output_labels_lower_and_round_trip() { let src = r#" name = mixed_output_labels kind = ode params = ke, v states = central -outputs = cp, 0, 1 +outputs = cp, outeq_0, outeq_1 infusion(iv) -> central ddt(central) = -ke * central out(cp) = central / v -out(0) = 2 * central / v -out(1) = 3 * central / v +out(outeq_0) = 2 * central / v +out(outeq_1) = 3 * central / v "#; let module = parse_module(src).expect("mixed output labels should parse in authoring DSL"); @@ -191,7 +191,7 @@ out(1) = 3 * central / v .iter() .map(|output| output.name.as_str()) .collect::>(), - vec!["cp", "0", "1"] + vec!["cp", "outeq_0", "outeq_1"] ); assert_eq!( lowered @@ -210,26 +210,26 @@ out(1) = 3 * central / v } #[test] -fn shared_numeric_route_and_output_labels_lower_and_round_trip() { +fn prefixed_numeric_route_and_output_labels_lower_and_round_trip() { let src = r#" -name = shared_numeric_route_output_labels +name = prefixed_numeric_route_output_labels kind = ode params = ke, v states = central -outputs = 1 -infusion(1) -> central +outputs = outeq_1 +infusion(input_1) -> central ddt(central) = -ke * central -out(1) = central / v +out(outeq_1) = central / v "#; - let module = parse_module(src).expect("shared numeric route/output labels should parse"); + let module = parse_module(src).expect("prefixed numeric route/output labels should parse"); let model = module .models .first() .expect("authoring DSL should produce one model"); - let typed = analyze_model(model).expect("shared numeric route/output labels should analyze"); + let typed = analyze_model(model).expect("prefixed numeric route/output labels should analyze"); let lowered = - lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); + lower_typed_model(&typed).expect("prefixed numeric route/output labels should lower"); assert_eq!( lowered @@ -238,7 +238,7 @@ out(1) = central / v .iter() .map(|route| route.name.as_str()) .collect::>(), - vec!["1"] + vec!["input_1"] ); assert_eq!( lowered @@ -247,7 +247,7 @@ out(1) = central / v .iter() .map(|output| output.name.as_str()) .collect::>(), - vec!["1"] + vec!["outeq_1"] ); let rendered = module.to_string(); @@ -256,6 +256,205 @@ out(1) = central / v assert_eq!(rendered, reparsed.to_string()); } +#[test] +fn rejects_authoring_bare_numeric_output_declarations() { + let src = r#" +name = numeric_outputs +kind = ode +states = central +outputs = 1, 2 +ddt(central) = 0 +out(1) = central +"#; + + let err = parse_model(src).expect_err("bare numeric output declarations must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric output labels are not allowed in the DSL; use `outeq_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `outeq_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_authoring_bare_numeric_route_labels() { + let src = r#" +name = numeric_routes +kind = ode +states = central +outputs = cp +infusion(1) -> central +ddt(central) = 0 +out(cp) = central +"#; + + let err = parse_model(src).expect_err("bare numeric route labels must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric route labels are not allowed in the DSL; use `input_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `input_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_structured_bare_numeric_output_targets() { + let src = r#" +model numeric_output_target { + kind ode + states { central } + outputs { + 1 = central + } +} +"#; + + let model = parse_model(src).expect("structured model parses"); + let err = analyze_model(&model).expect_err("bare numeric output target must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric output labels are not allowed in the DSL; use `outeq_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `outeq_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_structured_bare_numeric_route_labels() { + let src = r#" +model numeric_route_label { + kind ode + states { central } + routes { + 1 -> central + } + outputs { + cp = central + } +} +"#; + + let model = parse_model(src).expect("structured model parses"); + let err = analyze_model(&model).expect_err("bare numeric route label must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric route labels are not allowed in the DSL; use `input_1` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `input_1`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_rate_numeric_literals_with_prefixed_guidance() { + let src = r#" +model numeric_rate_arg { + kind ode + states { central } + routes { input_5 -> central } + dynamics { + ddt(central) = rate(5) + } + outputs { + cp = central + } +} +"#; + + let model = parse_model(src).expect("structured model parses"); + let err = analyze_model(&model).expect_err("bare numeric rate argument must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "bare numeric route labels are not allowed in the DSL; use `input_5` instead" + ), + "{}", + rendered + ); + assert!( + rendered.contains("suggestion: use `input_5`"), + "{}", + rendered + ); +} + +#[test] +fn rejects_wrong_prefix_labels_in_authored_dsl() { + let src = r#" +name = wrong_prefix_route +kind = ode +states = central +outputs = cp +infusion(outeq_1) -> central +ddt(central) = 0 +out(cp) = central +"#; + + let err = parse_model(src).expect_err("wrong-prefix route labels must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "`outeq_1` is an output label and cannot be used as a route; use `input_1` here" + ), + "{}", + rendered + ); + + let src = r#" +name = wrong_prefix_output +kind = ode +states = central +outputs = cp +infusion(iv) -> central +ddt(central) = 0 +out(input_1) = central +"#; + + let err = parse_model(src).expect_err("wrong-prefix output labels must fail"); + let rendered = err.render(src); + + assert!( + rendered.contains( + "`input_1` is a route label and cannot be used as an output; use `outeq_1` here" + ), + "{}", + rendered + ); +} + #[test] fn route_labels_still_collide_with_scalar_symbol_names() { let src = r#" diff --git a/src/data/error_model.rs b/src/data/error_model.rs index 609cad9e..1e66eb60 100644 --- a/src/data/error_model.rs +++ b/src/data/error_model.rs @@ -1,6 +1,9 @@ -use std::hash::{Hash, Hasher}; +use std::{ + collections::BTreeMap, + hash::{Hash, Hasher}, +}; -use crate::simulator::likelihood::Prediction; +use crate::{data::event::OutputLabel, simulator::likelihood::Prediction}; use serde::{Deserialize, Serialize}; use thiserror::Error; @@ -120,7 +123,11 @@ impl ErrorPoly { impl From> for AssayErrorModels { fn from(models: Vec) -> Self { - Self { models } + Self { + models, + output_lookup: BTreeMap::new(), + named_models: BTreeMap::new(), + } } } @@ -140,6 +147,8 @@ impl From> for AssayErrorModels { #[derive(Serialize, Debug, Clone, Deserialize)] pub struct AssayErrorModels { models: Vec, + output_lookup: BTreeMap, + named_models: BTreeMap, } /// Deprecated alias for [`AssayErrorModels`]. @@ -159,12 +168,149 @@ impl Default for AssayErrorModels { } impl AssayErrorModels { - /// Create a new instance of [`AssayErrorModels`] + /// Create a new reusable label-first [`AssayErrorModels`] definition. /// - /// # Returns - /// A new instance of [AssayErrorModels]. + /// Output labels are resolved once per equation when the error models are + /// used through simulation or likelihood entrypoints. + /// + /// This lets the same public definition be reused safely across multiple + /// equations while keeping the dense bound representation internal to the + /// runtime path. + /// + /// ```rust + /// # use pharmsol::prelude::*; + /// let error_models = AssayErrorModels::new() + /// .add("cp", AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0))?; + /// # Ok::<(), pharmsol::data::error_model::ErrorModelError>(()) + /// ``` pub fn new() -> Self { - Self { models: vec![] } + Self::empty() + } + + pub(crate) fn assert_compatible_output_names( + &self, + outputs: I, + ) -> Result<(), ErrorModelError> + where + I: IntoIterator, + S: AsRef, + { + if self.output_lookup.is_empty() { + return Ok(()); + } + + let expected = self.bound_output_names(); + let found = outputs + .into_iter() + .map(|output| output.as_ref().to_string()) + .collect::>(); + if expected == found { + return Ok(()); + } + + Err(ErrorModelError::IncompatibleOutputContext { expected, found }) + } + + pub(crate) fn bind_to(&self, context: &impl crate::Equation) -> Result { + self.bind_output_names(context.assay_error_models().bound_output_names()) + } + + pub(crate) fn bind_output_names(&self, outputs: I) -> Result + where + I: IntoIterator, + S: AsRef, + { + let outputs = outputs + .into_iter() + .map(|output| output.as_ref().to_string()) + .collect::>(); + + if !self.output_lookup.is_empty() { + self.assert_compatible_output_names(outputs.iter().map(String::as_str))?; + return Ok(self.clone()); + } + + if self.named_models.is_empty() { + return Ok(self.clone()); + } + + let mut bound = Self::with_output_names(outputs.iter().map(String::as_str)); + bound.models = self.models.clone(); + + for (label, model) in &self.named_models { + bound = bound.add(label.clone(), model.clone())?; + } + + Ok(bound) + } + + /// Create an unbound error-model set for dense-slot callers. + /// + /// This keeps the pre-existing numeric-slot setup path available for low-level + /// tests or workflows that deliberately operate on dense output indices. + pub(crate) fn empty() -> Self { + Self { + models: vec![], + output_lookup: BTreeMap::new(), + named_models: BTreeMap::new(), + } + } + + /// Create an error-model set with output labels resolved up front. + /// + /// This is the label-aware constructor for public workflows. It binds names + /// to dense output slots once during setup so that likelihood evaluation can + /// keep using direct vector indexing with no additional runtime lookup cost. + pub(crate) fn with_output_names(outputs: I) -> Self + where + I: IntoIterator, + S: AsRef, + { + let output_lookup = outputs + .into_iter() + .enumerate() + .map(|(index, output)| (OutputLabel::new(output.as_ref()), index)) + .collect(); + + Self { + models: vec![], + output_lookup, + named_models: BTreeMap::new(), + } + } + + fn bound_output_names(&self) -> Vec { + let mut names = self + .output_lookup + .iter() + .map(|(label, index)| (*index, label.to_string())) + .collect::>(); + names.sort_by_key(|(index, _)| *index); + names.into_iter().map(|(_, label)| label).collect() + } + + fn resolve_output_binding(&self, outeq: impl ToString) -> Result { + let label = OutputLabel::new(outeq); + self.output_lookup + .get(&label) + .copied() + .or_else(|| label.index()) + .ok_or_else(|| ErrorModelError::UnknownOutputLabel(label.to_string())) + } + + fn insert_model_at( + &mut self, + outeq: usize, + model: AssayErrorModel, + ) -> Result<(), ErrorModelError> { + if outeq >= self.models.len() { + self.models.resize(outeq + 1, AssayErrorModel::None); + } + if self.models[outeq] != AssayErrorModel::None { + return Err(ErrorModelError::ExistingOutputEquation(outeq)); + } + self.models[outeq] = model; + Ok(()) } /// Get the error model for a specific output equation @@ -182,22 +328,36 @@ impl AssayErrorModels { Ok(&self.models[outeq]) } - /// Add a new error model for a specific output equation + /// Add a new error model for a specific output equation or declared label. /// # Arguments - /// * `outeq` - The index of the output equation for which to add the error model. + /// * `outeq` - The output slot index or public output label. /// * `model` - The [AssayErrorModel] to add for the specified output equation. /// # Returns /// A new instance of AssayErrorModels with the added model. /// # Errors - /// If the output equation index is invalid or if a model already exists for that output equation, an [ErrorModelError::ExistingOutputEquation] is returned. - pub fn add(mut self, outeq: usize, model: AssayErrorModel) -> Result { - if outeq >= self.models.len() { - self.models.resize(outeq + 1, AssayErrorModel::None); + /// If the output label is unknown or if a model already exists for that output equation, an error is returned. + pub fn add( + mut self, + outeq: impl ToString, + model: AssayErrorModel, + ) -> Result { + let label = OutputLabel::new(outeq); + + if !self.output_lookup.is_empty() { + let outeq = self.resolve_output_binding(label.clone())?; + self.insert_model_at(outeq, model)?; + return Ok(self); } - if self.models[outeq] != AssayErrorModel::None { - return Err(ErrorModelError::ExistingOutputEquation(outeq)); + + if let Some(outeq) = label.index() { + self.insert_model_at(outeq, model)?; + return Ok(self); } - self.models[outeq] = model; + + if self.named_models.contains_key(&label) { + return Err(ErrorModelError::ExistingOutputLabel(label.to_string())); + } + self.named_models.insert(label, model); Ok(self) } /// Returns an iterator over the error models in the collection. @@ -222,6 +382,27 @@ impl AssayErrorModels { pub fn hash(&self) -> u64 { let mut hasher = ahash::AHasher::default(); + for (label, model) in &self.named_models { + 3u8.hash(&mut hasher); + label.hash(&mut hasher); + + match model { + AssayErrorModel::Additive { lambda, .. } => { + 0u8.hash(&mut hasher); + lambda.value().to_bits().hash(&mut hasher); + lambda.is_fixed().hash(&mut hasher); + } + AssayErrorModel::Proportional { gamma, .. } => { + 1u8.hash(&mut hasher); + gamma.value().to_bits().hash(&mut hasher); + gamma.is_fixed().hash(&mut hasher); + } + AssayErrorModel::None => { + 2u8.hash(&mut hasher); + } + } + } + for outeq in 0..self.models.len() { // Find the model with the matching outeq ID @@ -249,12 +430,16 @@ impl AssayErrorModels { } /// Returns the number of error models in the collection. pub fn len(&self) -> usize { + if self.models.is_empty() && !self.named_models.is_empty() && self.output_lookup.is_empty() + { + return self.named_models.len(); + } self.models.len() } /// Returns whether the collection contains no error models. pub fn is_empty(&self) -> bool { - self.models.is_empty() + self.models.is_empty() && self.named_models.is_empty() } /// Returns the error polynomial associated with the specified output equation. @@ -943,8 +1128,19 @@ pub enum ErrorModelError { NonFiniteSigma, #[error("The output equation index {0} is invalid")] InvalidOutputEquation(usize), + #[error("The output label `{0}` is not declared in this error model context")] + UnknownOutputLabel(String), + #[error("The output label `{0}` already exists in this assay error model specification")] + ExistingOutputLabel(String), #[error("The output equation number {0} already exists")] ExistingOutputEquation(usize), + #[error( + "Assay error models were bound for outputs {expected:?} but used with outputs {found:?}" + )] + IncompatibleOutputContext { + expected: Vec, + found: Vec, + }, #[error("An output equation does not have an error model defined")] MissingErrorModel, #[error("The output equation index {0} is of type ErrorModel::None")] @@ -1029,7 +1225,7 @@ mod tests { #[test] fn test_error_models_add_single() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!(models.len(), 1); } @@ -1038,7 +1234,7 @@ mod tests { let model1 = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add(0, model1) .unwrap() .add(1, model2) @@ -1047,12 +1243,101 @@ mod tests { assert_eq!(models.len(), 2); } + #[test] + fn test_error_models_add_label_with_output_names() { + let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); + let models = AssayErrorModels::with_output_names(["cp", "effect"]) + .add("effect", model) + .unwrap(); + + assert_eq!(models.len(), 2); + assert!(matches!(models.error_model(1), Ok(_))); + } + + #[test] + fn test_error_models_bind_output_names() { + let error_models = AssayErrorModels::new() + .add( + "effect", + AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), + ) + .unwrap(); + + let models = error_models.bind_output_names(["cp", "effect"]).unwrap(); + assert_eq!(models.len(), 2); + assert!(matches!(models.error_model(1), Ok(_))); + } + + #[test] + fn test_error_models_add_unknown_label_fails() { + let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); + let result = AssayErrorModels::with_output_names(["cp"]).add("effect", model); + + assert!(result.is_err()); + match result { + Err(ErrorModelError::UnknownOutputLabel(label)) => assert_eq!(label, "effect"), + _ => panic!("Expected UnknownOutputLabel error"), + } + } + + #[test] + fn test_error_models_duplicate_label_fails() { + let result = AssayErrorModels::new() + .add( + "cp", + AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), + ) + .unwrap() + .add( + "cp", + AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0), + ); + + match result { + Err(ErrorModelError::ExistingOutputLabel(label)) => assert_eq!(label, "cp"), + _ => panic!("Expected ExistingOutputLabel error"), + } + } + + #[test] + fn test_bound_error_models_reject_mismatched_output_context() { + let error_models = AssayErrorModels::new() + .add( + "cp", + AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), + ) + .unwrap() + .bind_output_names(["cp", "effect"]) + .unwrap(); + + match error_models.assert_compatible_output_names(["effect", "cp"]) { + Err(ErrorModelError::IncompatibleOutputContext { expected, found }) => { + assert_eq!(expected, vec!["cp".to_string(), "effect".to_string()]); + assert_eq!(found, vec!["effect".to_string(), "cp".to_string()]); + } + _ => panic!("Expected IncompatibleOutputContext error"), + } + } + + #[test] + fn test_error_models_sigma_from_label_bound_output() { + let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); + let models = AssayErrorModels::with_output_names(["cp"]) + .add("cp", model) + .unwrap(); + + let observation = Observation::new(0.0, Some(20.0), 0, None, 0, Censor::None); + let prediction = observation.to_prediction(10.0, vec![]); + + assert_eq!(models.sigma(&prediction).unwrap(), (26.0_f64).sqrt()); + } + #[test] fn test_error_models_add_duplicate_outeq_fails() { let model1 = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); - let result = AssayErrorModels::new() + let result = AssayErrorModels::empty() .add(0, model1) .unwrap() .add(0, model2); // Same outeq should fail @@ -1067,7 +1352,7 @@ mod tests { #[test] fn test_error_models_factor() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!(models.factor(0).unwrap(), 5.0); } @@ -1075,7 +1360,7 @@ mod tests { #[test] fn test_error_models_factor_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.factor(1); assert!(result.is_err()); @@ -1088,7 +1373,7 @@ mod tests { #[test] fn test_error_models_set_factor() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!(models.factor(0).unwrap(), 5.0); models.set_factor(0, 10.0).unwrap(); @@ -1098,7 +1383,7 @@ mod tests { #[test] fn test_error_models_set_factor_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.set_factor(1, 10.0); assert!(result.is_err()); @@ -1112,7 +1397,7 @@ mod tests { fn test_error_models_errorpoly() { let poly = ErrorPoly::new(1.0, 2.0, 3.0, 4.0); let model = AssayErrorModel::additive(poly, 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let retrieved_poly = models.errorpoly(0).unwrap(); assert_eq!(retrieved_poly.coefficients(), (1.0, 2.0, 3.0, 4.0)); @@ -1121,7 +1406,7 @@ mod tests { #[test] fn test_error_models_errorpoly_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.errorpoly(1); assert!(result.is_err()); @@ -1136,7 +1421,7 @@ mod tests { let poly1 = ErrorPoly::new(1.0, 2.0, 3.0, 4.0); let poly2 = ErrorPoly::new(5.0, 6.0, 7.0, 8.0); let model = AssayErrorModel::additive(poly1, 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); assert_eq!( models.errorpoly(0).unwrap().coefficients(), @@ -1152,7 +1437,7 @@ mod tests { #[test] fn test_error_models_set_errorpoly_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let mut models = AssayErrorModels::new().add(0, model).unwrap(); + let mut models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.set_errorpoly(1, ErrorPoly::new(5.0, 6.0, 7.0, 8.0)); assert!(result.is_err()); @@ -1165,7 +1450,7 @@ mod tests { #[test] fn test_error_models_sigma() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 0, None, 0, Censor::None); let prediction = observation.to_prediction(10.0, vec![]); @@ -1178,7 +1463,7 @@ mod tests { #[test] fn test_error_models_sigma_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 1, None, 0, Censor::None); // outeq=1 not in models let prediction = observation.to_prediction(10.0, vec![]); @@ -1194,7 +1479,7 @@ mod tests { #[test] fn test_error_models_variance() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 0, None, 0, Censor::None); let prediction = observation.to_prediction(10.0, vec![]); @@ -1207,7 +1492,7 @@ mod tests { #[test] fn test_error_models_variance_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let observation = Observation::new(0.0, Some(20.0), 1, None, 0, Censor::None); // outeq=1 not in models let prediction = observation.to_prediction(10.0, vec![]); @@ -1223,7 +1508,7 @@ mod tests { #[test] fn test_error_models_sigma_from_value() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let sigma = models.sigma_from_value(0, 20.0).unwrap(); assert_eq!(sigma, (26.0_f64).sqrt()); @@ -1232,7 +1517,7 @@ mod tests { #[test] fn test_error_models_sigma_from_value_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.sigma_from_value(1, 20.0); assert!(result.is_err()); @@ -1245,7 +1530,7 @@ mod tests { #[test] fn test_error_models_variance_from_value() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let variance = models.variance_from_value(0, 20.0).unwrap(); let expected_sigma = (26.0_f64).sqrt(); @@ -1255,7 +1540,7 @@ mod tests { #[test] fn test_error_models_variance_from_value_invalid_outeq() { let model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models = AssayErrorModels::new().add(0, model).unwrap(); + let models = AssayErrorModels::empty().add(0, model).unwrap(); let result = models.variance_from_value(1, 20.0); assert!(result.is_err()); @@ -1270,13 +1555,13 @@ mod tests { let model1 = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); - let models1 = AssayErrorModels::new() + let models1 = AssayErrorModels::empty() .add(0, model1.clone()) .unwrap() .add(1, model2.clone()) .unwrap(); - let models2 = AssayErrorModels::new() + let models2 = AssayErrorModels::empty() .add(0, model1) .unwrap() .add(1, model2) @@ -1292,13 +1577,13 @@ mod tests { let model2 = AssayErrorModel::proportional(ErrorPoly::new(2.0, 0.0, 0.0, 0.0), 3.0); // Add in different orders - let models1 = AssayErrorModels::new() + let models1 = AssayErrorModels::empty() .add(0, model1.clone()) .unwrap() .add(1, model2.clone()) .unwrap(); - let models2 = AssayErrorModels::new() + let models2 = AssayErrorModels::empty() .add(1, model2) .unwrap() .add(0, model1) @@ -1314,7 +1599,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.1); - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1343,7 +1628,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 2.0); - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1463,7 +1748,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 2.0); - let mut models = AssayErrorModels::new() + let mut models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1523,8 +1808,8 @@ mod tests { let model1_variable = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); let model1_fixed = AssayErrorModel::additive_fixed(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0); - let models1 = AssayErrorModels::new().add(0, model1_variable).unwrap(); - let models2 = AssayErrorModels::new().add(0, model1_fixed).unwrap(); + let models1 = AssayErrorModels::empty().add(0, model1_variable).unwrap(); + let models2 = AssayErrorModels::empty().add(0, model1_fixed).unwrap(); // Different fixed/variable states should produce different hashes assert_ne!(models1.hash(), models2.hash()); @@ -1536,7 +1821,7 @@ mod tests { let proportional_model = AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 2.0); - let mut models = AssayErrorModels::new() + let mut models = AssayErrorModels::empty() .add(0, additive_model) .unwrap() .add(1, proportional_model) @@ -1610,7 +1895,7 @@ mod tests { #[test] fn error_model_hash_deterministic() { - let models = AssayErrorModels::new() + let models = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), @@ -1621,13 +1906,13 @@ mod tests { #[test] fn error_model_hash_differs_on_value() { - let a = AssayErrorModels::new() + let a = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), ) .unwrap(); - let b = AssayErrorModels::new() + let b = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 10.0), @@ -1638,13 +1923,13 @@ mod tests { #[test] fn error_model_hash_differs_on_type() { - let a = AssayErrorModels::new() + let a = AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), ) .unwrap(); - let b = AssayErrorModels::new() + let b = AssayErrorModels::empty() .add( 0, AssayErrorModel::proportional(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 5.0), diff --git a/src/data/event.rs b/src/data/event.rs index 02a4c9a7..d0f96f7e 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -113,107 +113,170 @@ pub enum Event { Observation(Observation), } -macro_rules! impl_label_type { - ($(#[$meta:meta])* $name:ident) => { - $(#[$meta])* - #[derive( - Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, - )] - pub struct $name(String); - - impl $name { - /// Create a new public label. - /// - /// Prefer stable names when the model declares named routes or - /// outputs. - pub fn new(label: impl ToString) -> Self { - Self(label.to_string()) - } +/// Public label for a dosing input or route. +/// +/// [`Bolus`] and [`Infusion`] store the original user-facing route name in +/// this type. +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct InputLabel(String); - /// Borrow the stored label as a string. - pub fn as_str(&self) -> &str { - &self.0 - } +impl InputLabel { + /// Create a new public label. + /// + /// Prefer stable names when the model declares named routes. + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } - /// Try to interpret the label as a numeric index. - /// - /// This is mainly a compatibility helper for lower-level paths that - /// still operate on dense indices after label resolution. - pub fn index(&self) -> Option { - self.0.parse::().ok() - } - } + /// Borrow the stored label as a string. + pub fn as_str(&self) -> &str { + &self.0 + } - impl From for $name { - fn from(value: String) -> Self { - Self(value) - } - } + /// Try to interpret the label as a numeric index. + /// + /// This is mainly a compatibility helper for lower-level paths that still + /// operate on dense indices after label resolution. + pub fn index(&self) -> Option { + self.0.parse::().ok() + } +} - impl From<&str> for $name { - fn from(value: &str) -> Self { - Self(value.to_string()) - } - } +impl From for InputLabel { + fn from(value: String) -> Self { + Self(value) + } +} - impl From for $name { - fn from(value: usize) -> Self { - Self(value.to_string()) - } - } +impl From<&str> for InputLabel { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} - impl AsRef for $name { - fn as_ref(&self) -> &str { - self.as_str() - } - } +impl From for InputLabel { + fn from(value: usize) -> Self { + Self(value.to_string()) + } +} - impl fmt::Display for $name { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.as_str()) - } - } +impl AsRef for InputLabel { + fn as_ref(&self) -> &str { + self.as_str() + } +} - impl PartialEq for $name { - fn eq(&self, other: &usize) -> bool { - self.index() == Some(*other) - } - } +impl fmt::Display for InputLabel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} - impl PartialEq<$name> for usize { - fn eq(&self, other: &$name) -> bool { - other == self - } - } +impl PartialEq for InputLabel { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } +} - impl PartialEq for &$name { - fn eq(&self, other: &usize) -> bool { - (**self).eq(other) - } - } +impl PartialEq for usize { + fn eq(&self, other: &InputLabel) -> bool { + other == self + } +} - impl PartialEq<&$name> for usize { - fn eq(&self, other: &&$name) -> bool { - other.eq(self) - } - } - }; +impl PartialEq for &InputLabel { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } +} + +impl PartialEq<&InputLabel> for usize { + fn eq(&self, other: &&InputLabel) -> bool { + other.eq(self) + } } -impl_label_type!( - /// Public label for a dosing input or route. +/// Public label for an observation output. +/// +/// [`Observation`] stores the original user-facing output name in this type. +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct OutputLabel(String); + +impl OutputLabel { + /// Create a new public label. /// - /// [`Bolus`] and [`Infusion`] store the original user-facing route name in - /// this type. - InputLabel -); -impl_label_type!( - /// Public label for an observation output. + /// Prefer stable names when the model declares named outputs. + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } + + /// Borrow the stored label as a string. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Try to interpret the label as a numeric index. /// - /// [`Observation`] stores the original user-facing output name in this - /// type. - OutputLabel -); + /// This is mainly a compatibility helper for lower-level paths that still + /// operate on dense indices after label resolution. + pub fn index(&self) -> Option { + self.0.parse::().ok() + } +} + +impl From for OutputLabel { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for OutputLabel { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl From for OutputLabel { + fn from(value: usize) -> Self { + Self(value.to_string()) + } +} + +impl AsRef for OutputLabel { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl fmt::Display for OutputLabel { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl PartialEq for OutputLabel { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } +} + +impl PartialEq for usize { + fn eq(&self, other: &OutputLabel) -> bool { + other == self + } +} + +impl PartialEq for &OutputLabel { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } +} + +impl PartialEq<&OutputLabel> for usize { + fn eq(&self, other: &&OutputLabel) -> bool { + other.eq(self) + } +} impl Event { /// Get the time of the event @@ -487,11 +550,11 @@ pub enum Censor { ALOQ, } - /// Observation of a model output. - /// - /// An [`Observation`] can carry a measured value or `None` for a prediction-only - /// time point. Observations also carry the public output label, optional assay - /// error polynomial, occasion index, and censoring state. +/// Observation of a model output. +/// +/// An [`Observation`] can carry a measured value or `None` for a prediction-only +/// time point. Observations also carry the public output label, optional assay +/// error polynomial, occasion index, and censoring state. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Observation { time: f64, diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 6557e015..0dd38252 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -597,12 +597,51 @@ mod tests { other => panic!("expected ode model, got {other:?}"), }; - let oral = jit.route_index("oral").expect("jit oral route"); - let iv = jit.route_index("iv").expect("jit iv route"); - let cp = jit.output_index("cp").expect("jit cp output"); - assert_eq!(aot.route_index("oral"), Some(oral)); - assert_eq!(aot.route_index("iv"), Some(iv)); - assert_eq!(aot.output_index("cp"), Some(cp)); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("jit oral route"); + let iv = jit + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index) + .expect("jit iv route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("jit cp output"); + assert_eq!( + aot.info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index), + Some(oral) + ); + assert_eq!( + aot.info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index), + Some(iv) + ); + assert_eq!( + aot.info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index), + Some(cp) + ); let subject = crate::Subject::builder("ode") .covariate("wt", 0.0, 70.0) diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index 684b7810..c41d4ed4 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1402,9 +1402,27 @@ out(cp) = central / v ~ continuous() .expect("compile jit ode model") .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); - let oral = jit.route_index("oral").expect("oral route"); - let iv = jit.route_index("iv").expect("iv route"); - let cp = jit.output_index("cp").expect("cp output"); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let iv = jit + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index) + .expect("iv route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 0); assert_eq!(cp, 0); @@ -1545,9 +1563,27 @@ out(cp) = central / v ~ continuous() .expect("compile jit ode model") .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); - let oral = jit.route_index("oral").expect("oral route"); - let iv = jit.route_index("iv").expect("iv route"); - let cp = jit.output_index("cp").expect("cp output"); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let iv = jit + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .map(|route| route.index) + .expect("iv route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 1); assert_eq!(cp, 0); @@ -1644,8 +1680,20 @@ out(cp) = central / v ~ continuous() let model = load_corpus_model("one_cmt_abs"); let jit = compile_analytical_model_to_jit(&model).expect("compile jit analytical model"); - let oral = jit.route_index("oral").expect("oral route"); - let cp = jit.output_index("cp").expect("cp output"); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); assert_eq!(oral, 0); assert_eq!(cp, 0); @@ -1708,8 +1756,20 @@ out(cp) = central / v ~ continuous() .expect("compile jit sde model") .with_particles(64); - let oral = jit.route_index("oral").expect("oral route"); - let cp = jit.output_index("cp").expect("cp output"); + let oral = jit + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .map(|route| route.index) + .expect("oral route"); + let cp = jit + .info() + .outputs + .iter() + .find(|output| output.name == "cp") + .map(|output| output.index) + .expect("cp output"); assert_eq!(oral, 0); assert_eq!(cp, 0); diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 7dd3f72a..27f2416a 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -305,4 +305,42 @@ out(cp) = central / v ~ continuous() assert!(info.routes[0].inject_input_to_destination); assert!(!info.routes[1].inject_input_to_destination); } + + #[test] + fn native_model_info_preserves_canonical_numeric_channel_names() { + let info = load_model_info( + r#" +name = canonical_numeric_channels +kind = ode + +params = ke, v +states = depot, central +outputs = cp, outeq_2 + +bolus(input_10) -> depot +infusion(iv) -> central + +dx(depot) = -ke * depot +dx(central) = rate(input_10) - ke * central + +out(cp) = central / v +out(outeq_2) = depot / v +"#, + ); + + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["input_10", "iv"] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "outeq_2"] + ); + } } diff --git a/src/dsl/native.rs b/src/dsl/native.rs index d9598172..d35a392b 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -48,6 +48,8 @@ pub type DenseKernelFn = unsafe extern "C" fn( const DEFAULT_ODE_RTOL: f64 = 1e-4; const DEFAULT_ODE_ATOL: f64 = 1e-4; +const NUMERIC_ROUTE_PREFIX: &str = "input_"; +const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum RuntimeBackend { @@ -357,6 +359,20 @@ impl SharedNativeModel { .map(|output| output.index) } + fn metadata_route_index_for_label(&self, label: &str) -> Option { + self.route_index(label).or_else(|| { + canonical_numeric_alias(label, NUMERIC_ROUTE_PREFIX) + .and_then(|alias| self.route_index(alias.as_str())) + }) + } + + fn metadata_output_index_for_label(&self, label: &str) -> Option { + self.output_index(label).or_else(|| { + canonical_numeric_alias(label, NUMERIC_OUTPUT_PREFIX) + .and_then(|alias| self.output_index(alias.as_str())) + }) + } + fn validate_support_point(&self, support_point: &[f64]) -> Result<(), PharmsolError> { if support_point.len() != self.info.parameters.len() { return Err(PharmsolError::OtherError(format!( @@ -406,17 +422,17 @@ impl SharedNativeModel { label: &InputLabel, kind: RouteKind, ) -> Result { - let input = - self.route_index(label.as_str()) - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: label.to_string(), - })?; + let input = self + .metadata_route_index_for_label(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; self.validate_input_for_kind(input, kind)?; Ok(input) } fn resolve_output_label(&self, label: &OutputLabel) -> Result { - self.output_index(label.as_str()) + self.metadata_output_index_for_label(label.as_str()) .ok_or_else(|| PharmsolError::UnknownOutputLabel { label: label.to_string(), }) @@ -774,14 +790,6 @@ impl NativeOdeModel { self } - pub fn route_index(&self, name: &str) -> Option { - self.shared.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.shared.output_index(name) - } - pub fn info(&self) -> &NativeModelInfo { self.shared.info.as_ref() } @@ -1175,14 +1183,24 @@ impl Equation for NativeOdeModel { support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = self.bind_error_models(error_models)?; let predictions = runtime_ode_predictions(self, subject, support_point)?; - predictions.log_likelihood(error_models) + predictions.log_likelihood(&bound_error_models) } fn kind() -> EqnKind { EqnKind::ODE } + fn assay_error_models(&self) -> AssayErrorModels { + AssayErrorModels::with_output_names( + self.info() + .outputs + .iter() + .map(|output| output.name.as_str()), + ) + } + fn estimate_predictions( &self, subject: &Subject, @@ -1197,8 +1215,13 @@ impl Equation for NativeOdeModel { support_point: &[f64], error_models: Option<&AssayErrorModels>, ) -> Result<(Self::P, Option), PharmsolError> { + let bound_error_models = match error_models { + Some(error_models) => Some(self.bind_error_models(error_models)?), + None => None, + }; + let predictions = runtime_ode_predictions(self, subject, support_point)?; - let likelihood = match error_models { + let likelihood = match bound_error_models.as_ref() { Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), None => None, }; @@ -1213,14 +1236,6 @@ impl NativeAnalyticalModel { } } - pub fn route_index(&self, name: &str) -> Option { - self.shared.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.shared.output_index(name) - } - pub fn info(&self) -> &NativeModelInfo { self.shared.info.as_ref() } @@ -1384,14 +1399,6 @@ impl NativeSdeModel { self } - pub fn route_index(&self, name: &str) -> Option { - self.shared.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.shared.output_index(name) - } - pub fn info(&self) -> &NativeModelInfo { self.shared.info.as_ref() } @@ -1687,6 +1694,13 @@ fn sort_events(events: &mut [Event]) { }); } +fn canonical_numeric_alias(label: &str, prefix: &str) -> Option { + if label.is_empty() || !label.chars().all(|ch| ch.is_ascii_digit()) { + return None; + } + Some(format!("{prefix}{label}")) +} + fn project_analytical_parameters( info: &NativeModelInfo, support_point: &[f64], @@ -1858,3 +1872,33 @@ fn apply_analytical_kernel( } } } + +#[cfg(test)] +mod tests { + use super::{canonical_numeric_alias, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; + + #[test] + fn canonical_numeric_alias_maps_bare_numeric_labels_to_contextual_prefixes() { + assert_eq!( + canonical_numeric_alias("1", NUMERIC_ROUTE_PREFIX), + Some("input_1".to_string()) + ); + assert_eq!( + canonical_numeric_alias("10", NUMERIC_OUTPUT_PREFIX), + Some("outeq_10".to_string()) + ); + } + + #[test] + fn canonical_numeric_alias_ignores_symbolic_and_prefixed_labels() { + assert_eq!(canonical_numeric_alias("iv", NUMERIC_ROUTE_PREFIX), None); + assert_eq!( + canonical_numeric_alias("input_1", NUMERIC_ROUTE_PREFIX), + None + ); + assert_eq!( + canonical_numeric_alias("outeq_2", NUMERIC_OUTPUT_PREFIX), + None + ); + } +} diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1cef784e..1b2ecb76 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -33,7 +33,7 @@ //! it into the same runtime model shape. Use it when you want reusable native //! artifacts and can control the target platform. //! - [`RuntimeCompilationTarget::Wasm`] emits portable WASM bytes and reloads -//! them into the host-side runtime adapter. Use it when you need a portable +//! them into the host-side runtime adapter. Choose this target when you need a portable //! artifact or browser-aligned deployment story. //! //! Smallest compile-and-run example: @@ -229,22 +229,6 @@ impl CompiledRuntimeModel { self.info().kind } - pub fn route_index(&self, name: &str) -> Option { - match self { - Self::Ode(model) => model.route_index(name), - Self::Analytical(model) => model.route_index(name), - Self::Sde(model) => model.route_index(name), - } - } - - pub fn output_index(&self, name: &str) -> Option { - match self { - Self::Ode(model) => model.output_index(name), - Self::Analytical(model) => model.output_index(name), - Self::Sde(model) => model.output_index(name), - } - } - pub fn estimate_predictions( &self, subject: &Subject, @@ -502,27 +486,27 @@ kind = ode params = ke, v states = central -outputs = 2, 10, 11 +outputs = outeq_2, outeq_10, outeq_11 infusion(iv) -> central dx(central) = -ke * central -out(10) = central / v ~ continuous() -out(2) = central / v ~ continuous() -out(11) = central / v ~ continuous() +out(outeq_10) = central / v ~ continuous() +out(outeq_2) = central / v ~ continuous() +out(outeq_11) = central / v ~ continuous() "#; const NUMERIC_ROUTE_LABELS_RUNTIME_DSL: &str = r#" -name = numeric_route_runtime +name = prefixed_numeric_route_runtime kind = ode params = ke, v states = central outputs = cp -bolus(10) -> central -bolus(11) -> central +bolus(input_10) -> central +bolus(input_11) -> central dx(central) = -ke * central @@ -530,18 +514,18 @@ out(cp) = central / v ~ continuous() "#; const SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" -name = shared_numeric_route_output_runtime +name = prefixed_numeric_route_output_runtime kind = ode params = ke, v states = central -outputs = 1 +outputs = outeq_1 -infusion(1) -> central +infusion(input_1) -> central dx(central) = -ke * central -out(1) = central / v ~ continuous() +out(outeq_1) = central / v ~ continuous() "#; const UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" @@ -670,8 +654,35 @@ out(cp) = central / v ~ continuous() (jit, aot, wasm) } + fn compiled_route_input_index(model: &CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .routes + .iter() + .find(|route| route.name == name) + .map(|route| route.index) + } + + fn compiled_output_slot_index(model: &CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .outputs + .iter() + .find(|output| output.name == name) + .map(|output| output.index) + } + fn numeric_route_subject() -> Subject { Subject::builder("numeric-route-runtime") + .bolus(0.0, 120.0, "input_10") + .bolus(1.0, 80.0, "input_11") + .missing_observation(0.5, "cp") + .missing_observation(1.5, "cp") + .build() + } + + fn numeric_route_alias_subject() -> Subject { + Subject::builder("numeric-route-runtime-alias") .bolus(0.0, 120.0, "10") .bolus(1.0, 80.0, "11") .missing_observation(0.5, "cp") @@ -680,7 +691,15 @@ out(cp) = central / v ~ continuous() } fn shared_numeric_route_output_subject() -> Subject { - Subject::builder("shared-numeric-route-output-runtime") + Subject::builder("prefixed-numeric-route-output-runtime") + .infusion(0.0, 120.0, "input_1", 1.0) + .missing_observation(0.5, "outeq_1") + .missing_observation(1.5, "outeq_1") + .build() + } + + fn shared_numeric_route_output_alias_subject() -> Subject { + Subject::builder("raw-numeric-route-output-runtime") .infusion(0.0, 120.0, "1", 1.0) .missing_observation(0.5, "1") .missing_observation(1.5, "1") @@ -758,9 +777,9 @@ out(cp) = central / v ~ continuous() vec!["ka", "cl", "v", "tlag", "f_oral"] ); - assert!(jit.route_index("oral").is_some()); - assert!(jit.route_index("iv").is_some()); - assert_eq!(jit.output_index("cp"), Some(0)); + assert!(compiled_route_input_index(&jit, "oral").is_some()); + assert!(compiled_route_input_index(&jit, "iv").is_some()); + assert_eq!(compiled_output_slot_index(&jit, "cp"), Some(0)); let subject = ode_subject(); let jit_values = subject_values( @@ -796,33 +815,33 @@ out(cp) = central / v ~ continuous() work_dir.path(), ); - assert_eq!(jit.output_index("2"), Some(0)); - assert_eq!(jit.output_index("10"), Some(1)); - assert_eq!(jit.output_index("11"), Some(2)); - assert_eq!(aot.output_index("2"), Some(0)); - assert_eq!(aot.output_index("10"), Some(1)); - assert_eq!(aot.output_index("11"), Some(2)); - assert_eq!(wasm.output_index("2"), Some(0)); - assert_eq!(wasm.output_index("10"), Some(1)); - assert_eq!(wasm.output_index("11"), Some(2)); + assert_eq!(compiled_output_slot_index(&jit, "outeq_2"), Some(0)); + assert_eq!(compiled_output_slot_index(&jit, "outeq_10"), Some(1)); + assert_eq!(compiled_output_slot_index(&jit, "outeq_11"), Some(2)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_2"), Some(0)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_10"), Some(1)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_11"), Some(2)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_2"), Some(0)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_10"), Some(1)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_11"), Some(2)); } #[test] - fn runtime_backend_matrix_supports_multi_digit_numeric_route_labels() { + fn runtime_backend_matrix_supports_prefixed_multi_digit_numeric_route_labels() { let work_dir = tempdir().expect("tempdir"); let support = vec![0.2, 10.0]; let (jit, aot, wasm) = compile_runtime_backend_matrix( NUMERIC_ROUTE_LABELS_RUNTIME_DSL, - "numeric_route_runtime", + "prefixed_numeric_route_runtime", work_dir.path(), ); - assert_eq!(jit.route_index("10"), Some(0)); - assert_eq!(jit.route_index("11"), Some(1)); - assert_eq!(aot.route_index("10"), Some(0)); - assert_eq!(aot.route_index("11"), Some(1)); - assert_eq!(wasm.route_index("10"), Some(0)); - assert_eq!(wasm.route_index("11"), Some(1)); + assert_eq!(compiled_route_input_index(&jit, "input_10"), Some(0)); + assert_eq!(compiled_route_input_index(&jit, "input_11"), Some(1)); + assert_eq!(compiled_route_input_index(&aot, "input_10"), Some(0)); + assert_eq!(compiled_route_input_index(&aot, "input_11"), Some(1)); + assert_eq!(compiled_route_input_index(&wasm, "input_10"), Some(0)); + assert_eq!(compiled_route_input_index(&wasm, "input_11"), Some(1)); let subject = numeric_route_subject(); @@ -851,21 +870,57 @@ out(cp) = central / v ~ continuous() } #[test] - fn runtime_backend_matrix_supports_shared_numeric_route_and_output_labels() { + fn runtime_backend_matrix_resolves_raw_numeric_route_labels_against_prefixed_metadata() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "prefixed_numeric_route_runtime", + work_dir.path(), + ); + + let subject = numeric_route_alias_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_supports_prefixed_numeric_route_and_output_labels() { let work_dir = tempdir().expect("tempdir"); let support = vec![0.2, 10.0]; let (jit, aot, wasm) = compile_runtime_backend_matrix( SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL, - "shared_numeric_route_output_runtime", + "prefixed_numeric_route_output_runtime", work_dir.path(), ); - assert_eq!(jit.route_index("1"), Some(0)); - assert_eq!(jit.output_index("1"), Some(0)); - assert_eq!(aot.route_index("1"), Some(0)); - assert_eq!(aot.output_index("1"), Some(0)); - assert_eq!(wasm.route_index("1"), Some(0)); - assert_eq!(wasm.output_index("1"), Some(0)); + assert_eq!(compiled_route_input_index(&jit, "input_1"), Some(0)); + assert_eq!(compiled_output_slot_index(&jit, "outeq_1"), Some(0)); + assert_eq!(compiled_route_input_index(&aot, "input_1"), Some(0)); + assert_eq!(compiled_output_slot_index(&aot, "outeq_1"), Some(0)); + assert_eq!(compiled_route_input_index(&wasm, "input_1"), Some(0)); + assert_eq!(compiled_output_slot_index(&wasm, "outeq_1"), Some(0)); let subject = shared_numeric_route_output_subject(); @@ -893,6 +948,42 @@ out(cp) = central / v ~ continuous() } } + #[test] + fn runtime_backend_matrix_resolves_shared_raw_numeric_route_and_output_aliases() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL, + "prefixed_numeric_route_output_runtime", + work_dir.path(), + ); + + let subject = shared_numeric_route_output_alias_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + #[test] fn runtime_backend_matrix_rejects_undeclared_numeric_output_labels() { let work_dir = tempdir().expect("tempdir"); diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 1dd4bbb5..1dcc3a8f 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -159,14 +159,6 @@ impl Analytical { self.metadata()?.state_index(name) } - pub fn route_index(&self, name: &str) -> Option { - self.metadata()?.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.metadata()?.output_index(name) - } - fn invalidate_metadata(&mut self) { self.metadata = None; } @@ -562,16 +554,66 @@ pub(crate) mod tests { .route(super::super::Route::infusion("iv").to_state("central")), ) .expect("metadata attachment should validate"); + let metadata = analytical.metadata().expect("metadata exists"); assert_eq!(analytical.parameter_index("ke"), Some(0)); assert_eq!(analytical.parameter_index("v"), Some(1)); assert_eq!(analytical.covariate_index("wt"), Some(0)); assert_eq!(analytical.state_index("central"), Some(0)); - assert_eq!(analytical.route_index("iv"), Some(0)); - assert_eq!(analytical.output_index("cp"), Some(0)); - assert_eq!( - analytical.metadata().expect("metadata exists").kind(), - ModelKind::Analytical + assert!(metadata.route("iv").is_some()); + assert!(metadata.output("cp").is_some()); + assert_eq!(metadata.kind(), ModelKind::Analytical); + } + + #[test] + fn handwritten_analytical_metadata_resolves_raw_numeric_aliases_against_canonical_labels() { + let eq = |x: &V, _p: &V, dt: f64, rateiv: &V, _cov: &Covariates| { + let mut next = x.clone(); + next[0] += rateiv[0] * dt; + next + }; + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[0]; + }; + + let analytical = Analytical::new(eq, seq_eq, lag, fa, init, out) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("numeric_alias_analytical") + .states(["central"]) + .outputs(["outeq_1"]) + .route(super::super::Route::infusion("input_1").to_state("central")), + ) + .expect("metadata attachment should validate"); + + let canonical = Subject::builder("canonical") + .infusion(0.0, 100.0, "input_1", 1.0) + .observation(1.0, 0.0, "outeq_1") + .build(); + let aliased = Subject::builder("aliased") + .infusion(0.0, 100.0, "1", 1.0) + .observation(1.0, 0.0, "1") + .build(); + + let canonical_predictions = analytical + .estimate_predictions(&canonical, &[]) + .expect("canonical labels should simulate"); + let aliased_predictions = analytical + .estimate_predictions(&aliased, &[]) + .expect("raw numeric aliases should simulate"); + + assert_relative_eq!( + canonical_predictions.predictions()[0].prediction(), + aliased_predictions.predictions()[0].prediction(), + epsilon = 1e-10 ); } @@ -581,8 +623,6 @@ pub(crate) mod tests { assert!(analytical.metadata().is_none()); assert_eq!(analytical.state_index("central"), None); - assert_eq!(analytical.route_index("iv"), None); - assert_eq!(analytical.output_index("cp"), None); } #[test] @@ -664,8 +704,15 @@ pub(crate) mod tests { .analytical_kernel(), Some(AnalyticalKernel::OneCompartmentWithAbsorption) ); - assert_eq!(analytical.route_index("oral"), Some(0)); - assert_eq!(analytical.route_index("iv"), Some(0)); + let metadata = analytical.metadata().expect("metadata exists"); + assert_eq!( + metadata.route("oral").map(|route| route.input_index()), + Some(0) + ); + assert_eq!( + metadata.route("iv").map(|route| route.input_index()), + Some(0) + ); } #[test] @@ -681,7 +728,6 @@ pub(crate) mod tests { .with_ndrugs(2); assert!(analytical.metadata().is_none()); - assert_eq!(analytical.route_index("iv"), None); } fn assert_pm_wrapper_matches_native( @@ -824,8 +870,9 @@ impl Equation for Analytical { support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = self.bind_error_models(error_models)?; let ypred = _subject_predictions(self, subject, support_point)?; - ypred.log_likelihood(error_models) + ypred.log_likelihood(&bound_error_models) } fn kind() -> EqnKind { @@ -859,6 +906,7 @@ fn _estimate_likelihood( support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = ode.bind_error_models(error_models)?; let ypred = _subject_predictions(ode, subject, support_point)?; - Ok(ypred.log_likelihood(error_models)?.exp()) + Ok(ypred.log_likelihood(&bound_error_models)?.exp()) } diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs index c7fbd4c9..a512381e 100644 --- a/src/simulator/equation/metadata.rs +++ b/src/simulator/equation/metadata.rs @@ -27,7 +27,7 @@ //! //! assert_eq!(metadata.name(), "one_cmt"); //! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); -//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! assert!(metadata.output("cp").is_some()); //! ``` use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind}; @@ -181,18 +181,13 @@ impl ValidatedModelMetadata { self.states.iter().position(|state| state.name() == name) } - /// Look up a route by public name and return its dense execution input index. - pub fn route_index(&self, name: &str) -> Option { - self.route(name).map(ValidatedRoute::input_index) - } - /// Look up a route by public name and return its declaration-order index. pub fn route_declaration_index(&self, name: &str) -> Option { self.routes.iter().position(|route| route.name() == name) } /// Look up an output by public name and return its dense output index. - pub fn output_index(&self, name: &str) -> Option { + pub(crate) fn output_index(&self, name: &str) -> Option { self.outputs.iter().position(|output| output.name() == name) } @@ -983,7 +978,7 @@ mod tests { assert_eq!(metadata.parameter_index("v"), Some(1)); assert_eq!(metadata.covariate_index("wt"), Some(0)); assert_eq!(metadata.state_index("central"), Some(0)); - assert_eq!(metadata.route_index("iv"), Some(0)); + assert!(metadata.route("iv").is_some()); assert_eq!(metadata.route_declaration_index("iv"), Some(0)); assert_eq!(metadata.route_input_count(), 1); assert_eq!(metadata.output_index("cp"), Some(0)); @@ -1083,8 +1078,6 @@ mod tests { assert_eq!(metadata.routes().len(), 2); assert_eq!(metadata.route_input_count(), 1); - assert_eq!(metadata.route_index("oral"), Some(0)); - assert_eq!(metadata.route_index("iv"), Some(0)); assert_eq!(metadata.route_declaration_index("oral"), Some(0)); assert_eq!(metadata.route_declaration_index("iv"), Some(1)); assert_eq!(metadata.route("oral").expect("oral route").input_index(), 0); diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 03e5318c..838124dc 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -42,8 +42,8 @@ //! .validate() //! .unwrap(); //! -//! assert_eq!(metadata.route_index("iv"), Some(0)); -//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); +//! assert!(metadata.output("cp").is_some()); //! ``` use std::fmt::Debug; @@ -66,6 +66,9 @@ use crate::{ use super::likelihood::Prediction; +const NUMERIC_ROUTE_PREFIX: &str = "input_"; +const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; + /// Trait for state vectors that can receive bolus doses. pub trait State { /// Add a bolus dose to the state at the specified resolved input index. @@ -198,12 +201,15 @@ pub(crate) trait EquationPriv: EquationTypes { expected_kind: RouteKind, ) -> Result { if let Some(metadata) = self.metadata() { - let route = - metadata - .route(label.as_str()) - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: label.to_string(), - })?; + let route = metadata + .route(label.as_str()) + .or_else(|| { + canonical_numeric_alias(label.as_str(), NUMERIC_ROUTE_PREFIX) + .and_then(|alias| metadata.route(alias.as_str())) + }) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; if route.kind() != expected_kind { return Err(PharmsolError::OtherError(format!( @@ -226,11 +232,15 @@ pub(crate) trait EquationPriv: EquationTypes { fn resolve_output_label(&self, label: &OutputLabel) -> Result { if let Some(metadata) = self.metadata() { - return metadata.output_index(label.as_str()).ok_or_else(|| { - PharmsolError::UnknownOutputLabel { + return metadata + .output_index(label.as_str()) + .or_else(|| { + canonical_numeric_alias(label.as_str(), NUMERIC_OUTPUT_PREFIX) + .and_then(|alias| metadata.output_index(alias.as_str())) + }) + .ok_or_else(|| PharmsolError::UnknownOutputLabel { label: label.to_string(), - } - }); + }); } label @@ -356,6 +366,13 @@ pub(crate) trait EquationPriv: EquationTypes { } } +fn canonical_numeric_alias(label: &str, prefix: &str) -> Option { + if label.is_empty() || !label.chars().all(|ch| ch.is_ascii_digit()) { + return None; + } + Some(format!("{prefix}{label}")) +} + /// Trait for handwritten model equations that can be simulated. /// /// [`Equation`] is the shared interface implemented by handwritten [`ODE`], @@ -373,6 +390,14 @@ pub(crate) trait EquationPriv: EquationTypes { /// is provided for backward compatibility. #[allow(private_bounds)] pub trait Equation: EquationPriv + 'static + Clone + Sync { + #[doc(hidden)] + fn bind_error_models( + &self, + error_models: &AssayErrorModels, + ) -> Result { + Ok(error_models.bind_to(self)?) + } + /// Estimate the likelihood of the subject given the support point and error model. /// /// **Deprecated**: Use [`estimate_log_likelihood`](Self::estimate_log_likelihood) instead @@ -450,6 +475,22 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { self.get_nstates() } + /// Build a label-aware [`AssayErrorModels`] set for this equation. + /// + /// Handwritten equations resolve output labels from attached metadata. + /// Equations without metadata fall back to an explicit unbound set so dense + /// output-slot workflows remain available without adding runtime lookup cost. + #[doc(hidden)] + fn assay_error_models(&self) -> AssayErrorModels { + self.metadata() + .map(|metadata| { + AssayErrorModels::with_output_names( + metadata.outputs().iter().map(|output| output.name()), + ) + }) + .unwrap_or_else(AssayErrorModels::empty) + } + /// Simulate a subject with given parameters and optionally calculate likelihood. /// /// # Parameters @@ -465,6 +506,11 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { support_point: &[f64], error_models: Option<&AssayErrorModels>, ) -> Result<(Self::P, Option), PharmsolError> { + let bound_error_models = match error_models { + Some(error_models) => Some(self.bind_error_models(error_models)?), + None => None, + }; + let mut output = Self::P::new(self.nparticles()); let mut likelihood = Vec::new(); for occasion in subject.occasions() { @@ -478,7 +524,7 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { support_point, event, events.get(index + 1), - error_models, + bound_error_models.as_ref(), covariates, &mut x, &mut infusions, @@ -487,7 +533,9 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { )?; } } - let ll = error_models.map(|_| likelihood.iter().product::()); + let ll = bound_error_models + .as_ref() + .map(|_| likelihood.iter().product::()); Ok((output, ll)) } } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index c65f16a9..f982347f 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -184,14 +184,6 @@ impl ODE { self.metadata()?.state_index(name) } - pub fn route_index(&self, name: &str) -> Option { - self.metadata()?.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.metadata()?.output_index(name) - } - fn invalidate_metadata(&mut self) { self.metadata = None; } @@ -264,8 +256,9 @@ fn _estimate_likelihood( support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = ode.bind_error_models(error_models)?; let ypred = _subject_predictions(ode, subject, support_point)?; - Ok(ypred.log_likelihood(error_models)?.exp()) + Ok(ypred.log_likelihood(&bound_error_models)?.exp()) } #[inline(always)] @@ -527,8 +520,9 @@ impl Equation for ODE { support_point: &[f64], error_models: &AssayErrorModels, ) -> Result { + let bound_error_models = self.bind_error_models(error_models)?; let ypred = _subject_predictions(self, subject, support_point)?; - ypred.log_likelihood(error_models) + ypred.log_likelihood(&bound_error_models) } fn kind() -> EqnKind { @@ -541,6 +535,11 @@ impl Equation for ODE { support_point: &[f64], error_models: Option<&AssayErrorModels>, ) -> Result<(Self::P, Option), PharmsolError> { + let bound_error_models = match error_models { + Some(error_models) => Some(self.bind_error_models(error_models)?), + None => None, + }; + let mut output = Self::P::new(self.nparticles()); // Preallocate likelihood vector @@ -602,7 +601,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -621,7 +620,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -640,7 +639,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -659,7 +658,7 @@ impl Equation for ODE { &events, &spp_v, covariates, - error_models, + bound_error_models.as_ref(), &mut bolus_v, &zero_bolus, &zero_rateiv, @@ -672,7 +671,9 @@ impl Equation for ODE { } } } - let ll = error_models.map(|_| likelihood.iter().product::()); + let ll = bound_error_models + .as_ref() + .map(|_| likelihood.iter().product::()); Ok((output, ll)) } } @@ -753,16 +754,14 @@ mod tests { .route(super::super::Route::infusion("iv").to_state("central")), ) .expect("metadata attachment should validate"); + let metadata = ode.metadata().expect("metadata exists"); assert_eq!(ode.parameter_index("ke"), Some(0)); assert_eq!(ode.parameter_index("v"), Some(1)); assert_eq!(ode.state_index("central"), Some(0)); - assert_eq!(ode.route_index("iv"), Some(0)); - assert_eq!(ode.output_index("cp"), Some(0)); - assert_eq!( - ode.metadata().expect("metadata exists").kind(), - ModelKind::Ode - ); + assert!(metadata.route("iv").is_some()); + assert!(metadata.output("cp").is_some()); + assert_eq!(metadata.kind(), ModelKind::Ode); } #[test] @@ -771,8 +770,6 @@ mod tests { assert!(ode.metadata().is_none()); assert_eq!(ode.state_index("central"), None); - assert_eq!(ode.route_index("iv"), None); - assert_eq!(ode.output_index("cp"), None); } #[test] @@ -843,9 +840,16 @@ mod tests { .simulate_subject(&route_policy_subject(), &[], None) .expect("simulation should succeed") .0; + let metadata = ode.metadata().expect("metadata exists"); - assert_eq!(ode.route_index("oral").expect("oral route"), 0); - assert_eq!(ode.route_index("iv").expect("iv route"), 0); + assert_eq!( + metadata.route("oral").map(|route| route.input_index()), + Some(0) + ); + assert_eq!( + metadata.route("iv").map(|route| route.input_index()), + Some(0) + ); assert_relative_eq!( predictions.predictions()[0].prediction(), 200.0, @@ -892,6 +896,51 @@ mod tests { ); } + #[test] + fn handwritten_ode_metadata_resolves_raw_numeric_aliases_against_canonical_labels() { + let ode = ODE::new( + explicit_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("numeric_alias_ode") + .states(["central"]) + .outputs(["outeq_1"]) + .route(super::super::Route::infusion("input_1").to_state("central")), + ) + .expect("metadata attachment should validate"); + + let canonical = Subject::builder("canonical") + .infusion(0.0, 100.0, "input_1", 1.0) + .observation(1.0, 0.0, "outeq_1") + .build(); + let aliased = Subject::builder("aliased") + .infusion(0.0, 100.0, "1", 1.0) + .observation(1.0, 0.0, "1") + .build(); + + let canonical_predictions = ode + .simulate_subject(&canonical, &[], None) + .expect("canonical labels should simulate") + .0; + let aliased_predictions = ode + .simulate_subject(&aliased, &[], None) + .expect("raw numeric aliases should simulate") + .0; + + assert_relative_eq!( + canonical_predictions.predictions()[0].prediction(), + aliased_predictions.predictions()[0].prediction(), + epsilon = 1e-6 + ); + } + #[test] fn changing_dimensions_after_metadata_clears_route_metadata() { let ode = simple_ode() @@ -905,6 +954,5 @@ mod tests { .with_ndrugs(2); assert!(ode.metadata().is_none()); - assert_eq!(ode.route_index("iv"), None); } } diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index 43a1d48a..c5b01435 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -280,14 +280,6 @@ impl SDE { self.metadata()?.state_index(name) } - pub fn route_index(&self, name: &str) -> Option { - self.metadata()?.route_index(name) - } - - pub fn output_index(&self, name: &str) -> Option { - self.metadata()?.output_index(name) - } - fn invalidate_metadata(&mut self) { self.metadata = None; self.injected_bolus_mappings @@ -812,8 +804,62 @@ mod tests { assert_eq!(sde.parameter_index("v"), Some(1)); assert_eq!(sde.covariate_index("wt"), Some(0)); assert_eq!(sde.state_index("central"), Some(0)); - assert_eq!(sde.route_index("iv"), Some(0)); - assert_eq!(sde.output_index("cp"), Some(0)); + assert!(metadata.route("iv").is_some()); + assert!(metadata.output("cp").is_some()); + } + + #[test] + fn handwritten_sde_metadata_resolves_raw_numeric_aliases_against_canonical_labels() { + let drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + dx[1] = rateiv[0]; + }; + let diffusion = |_p: &V, sigma: &mut V| { + sigma.fill(0.0); + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + let sde = SDE::new(drift, diffusion, lag, fa, init, out, 16) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_alias_sde") + .states(["depot", "central"]) + .outputs(["outeq_1"]) + .route(Route::infusion("input_1").to_state("central")) + .particles(16), + ) + .expect("SDE metadata attachment should validate"); + + let canonical = Subject::builder("canonical") + .infusion(0.0, 100.0, "input_1", 1.0) + .observation(1.0, 0.0, "outeq_1") + .build(); + let aliased = Subject::builder("aliased") + .infusion(0.0, 100.0, "1", 1.0) + .observation(1.0, 0.0, "1") + .build(); + + let canonical_predictions = sde + .estimate_predictions(&canonical, &[]) + .expect("canonical labels should simulate"); + let aliased_predictions = sde + .estimate_predictions(&aliased, &[]) + .expect("raw numeric aliases should simulate"); + + assert!( + (canonical_predictions[[0, 0]].prediction() - aliased_predictions[[0, 0]].prediction()) + .abs() + < 1e-10 + ); } #[test] @@ -822,8 +868,6 @@ mod tests { assert!(sde.metadata().is_none()); assert_eq!(sde.parameter_index("ke"), None); - assert_eq!(sde.route_index("iv"), None); - assert_eq!(sde.output_index("cp"), None); } #[test] @@ -885,8 +929,6 @@ mod tests { .with_nout(2); assert!(sde.metadata().is_none()); - assert_eq!(sde.route_index("iv"), None); - assert_eq!(sde.output_index("cp"), None); } #[test] diff --git a/src/simulator/likelihood/mod.rs b/src/simulator/likelihood/mod.rs index c703dee7..30ad4945 100644 --- a/src/simulator/likelihood/mod.rs +++ b/src/simulator/likelihood/mod.rs @@ -223,7 +223,7 @@ mod tests { }; // Create error model with additive error - let error_models = crate::AssayErrorModels::new() + let error_models = crate::AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 1.0, 0.0, 0.0), 0.0), @@ -270,7 +270,7 @@ mod tests { ]; let subject_predictions = SubjectPredictions::from(predictions); - let error_models = crate::AssayErrorModels::new() + let error_models = crate::AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 1.0, 0.0, 0.0), 0.0), @@ -294,7 +294,7 @@ mod tests { #[test] fn test_empty_predictions_have_neutral_log_likelihood() { let preds = SubjectPredictions::default(); - let errors = crate::AssayErrorModels::new(); + let errors = crate::AssayErrorModels::empty(); assert_eq!(preds.log_likelihood(&errors).unwrap(), 0.0); // log(1) = 0 } @@ -305,7 +305,9 @@ mod tests { preds.add_prediction(obs.to_prediction(1.0, vec![])); let error_model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 0.0); - let errors = crate::AssayErrorModels::new().add(0, error_model).unwrap(); + let errors = crate::AssayErrorModels::empty() + .add(0, error_model) + .unwrap(); let log_lik = preds.log_likelihood(&errors).unwrap(); assert!(log_lik.is_finite()); diff --git a/src/simulator/likelihood/prediction.rs b/src/simulator/likelihood/prediction.rs index fbd60230..b1610f56 100644 --- a/src/simulator/likelihood/prediction.rs +++ b/src/simulator/likelihood/prediction.rs @@ -233,7 +233,7 @@ mod tests { } fn create_error_models() -> AssayErrorModels { - AssayErrorModels::new() + AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 0.0), diff --git a/src/simulator/likelihood/subject.rs b/src/simulator/likelihood/subject.rs index 77d8963b..5c3346fe 100644 --- a/src/simulator/likelihood/subject.rs +++ b/src/simulator/likelihood/subject.rs @@ -169,7 +169,7 @@ mod tests { use crate::Censor; fn create_error_models() -> AssayErrorModels { - AssayErrorModels::new() + AssayErrorModels::empty() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 1.0, 0.0, 0.0), 0.0), @@ -199,7 +199,7 @@ mod tests { preds.add_prediction(obs.to_prediction(1.0, vec![])); let error_model = AssayErrorModel::additive(ErrorPoly::new(1.0, 0.0, 0.0, 0.0), 0.0); - let errors = AssayErrorModels::new().add(0, error_model).unwrap(); + let errors = AssayErrorModels::empty().add(0, error_model).unwrap(); let log_lik = preds.log_likelihood(&errors).unwrap(); assert!(log_lik.is_finite()); diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index f527978f..44075cff 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -384,10 +384,13 @@ fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { let handwritten_model = handwritten_one_compartment(); let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 10.0]; + let macro_metadata = macro_model + .metadata() + .expect("macro analytical metadata exists"); assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - assert_eq!(macro_model.route_index("iv"), Some(0)); - assert_eq!(macro_model.output_index("cp"), Some(0)); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_model.state_index("central"), Some(0)); let macro_predictions = macro_model @@ -410,16 +413,14 @@ fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { let handwritten_model = handwritten_one_compartment_with_absorption(); let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro metadata exists"); assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - assert_eq!(macro_model.route_index("oral"), Some(0)); - assert_eq!(macro_model.output_index("cp"), Some(0)); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_model.state_index("gut"), Some(0)); assert_eq!( - macro_model - .metadata() - .expect("macro metadata exists") - .analytical_kernel(), + macro_metadata.analytical_kernel(), Some(equation::AnalyticalKernel::OneCompartmentWithAbsorption) ); @@ -443,11 +444,12 @@ fn analytical_macro_shared_input_lowering_matches_handwritten_metadata_and_predi let handwritten_model = handwritten_shared_input_analytical(); let subject = shared_input_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro metadata exists"); assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - assert_eq!(macro_model.route_index("oral"), Some(0)); - assert_eq!(macro_model.route_index("iv"), Some(0)); - assert_eq!(macro_model.output_index("cp"), Some(0)); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_model.state_index("gut"), Some(0)); assert_eq!(macro_model.state_index("central"), Some(1)); diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index be80f10e..e0b86744 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "dsl-jit")] use approx::assert_relative_eq; #[cfg(feature = "dsl-jit")] use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; @@ -59,26 +60,26 @@ kind = ode params = ke, v states = central -outputs = 2, 10, 11 +outputs = outeq_2, outeq_10, outeq_11 infusion(iv) -> central dx(central) = -ke * central -out(10) = central / v ~ continuous() -out(2) = central / v ~ continuous() -out(11) = central / v ~ continuous() +out(outeq_10) = central / v ~ continuous() +out(outeq_2) = central / v ~ continuous() +out(outeq_11) = central / v ~ continuous() "#; const ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL: &str = r#" -name = authoring_numeric_routes +name = authoring_prefixed_numeric_routes kind = ode states = first, second outputs = cp -bolus(10) -> first -bolus(11) -> second +bolus(input_10) -> first +bolus(input_11) -> second dx(first) = 0 dx(second) = 0 @@ -93,8 +94,8 @@ const ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL: &str = r#"model structured_numeri second, } routes { - 10 -> first - 11 -> second + input_10 -> first + input_11 -> second } dynamics { ddt(first) = 0 @@ -149,15 +150,15 @@ kind = ode params = ke, v states = central -outputs = cp, 0, 1 +outputs = cp, outeq_0, outeq_1 infusion(iv) -> central dx(central) = -ke * central out(cp) = central / v ~ continuous() -out(0) = 2 * central / v ~ continuous() -out(1) = 3 * central / v ~ continuous() +out(outeq_0) = 2 * central / v ~ continuous() +out(outeq_1) = 3 * central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] @@ -390,6 +391,26 @@ fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntim .expect("DSL runtime model should compile") } +#[cfg(feature = "dsl-jit")] +fn compiled_route_input_index(model: &dsl::CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .routes + .iter() + .find(|route| route.name == name) + .map(|route| route.index) +} + +#[cfg(feature = "dsl-jit")] +fn compiled_output_slot_index(model: &dsl::CompiledRuntimeModel, name: &str) -> Option { + model + .info() + .outputs + .iter() + .find(|output| output.name == name) + .map(|output| output.index) +} + #[cfg(feature = "dsl-jit")] fn shared_input_prediction_subject() -> Subject { Subject::builder("authoring-parity-shared-input") @@ -1169,15 +1190,15 @@ fn ode_dsl_declared_output_order_controls_dense_indices_for_multi_digit_labels() dsl_view.outputs, vec![ NamedIndex { - name: "2".to_string(), + name: "outeq_2".to_string(), index: 0, }, NamedIndex { - name: "10".to_string(), + name: "outeq_10".to_string(), index: 1, }, NamedIndex { - name: "11".to_string(), + name: "outeq_11".to_string(), index: 2, }, ] @@ -1185,7 +1206,7 @@ fn ode_dsl_declared_output_order_controls_dense_indices_for_multi_digit_labels() } #[test] -fn ode_authoring_dsl_supports_multi_digit_numeric_route_labels() { +fn ode_authoring_dsl_supports_prefixed_multi_digit_numeric_route_labels() { let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL); assert_eq!(dsl_view.route_input_count, 2); @@ -1193,7 +1214,7 @@ fn ode_authoring_dsl_supports_multi_digit_numeric_route_labels() { dsl_view.routes, vec![ RouteParity { - name: "10".to_string(), + name: "input_10".to_string(), kind: Some(RouteKindParity::Bolus), declaration_index: 0, input_index: 0, @@ -1203,7 +1224,7 @@ fn ode_authoring_dsl_supports_multi_digit_numeric_route_labels() { has_bioavailability: false, }, RouteParity { - name: "11".to_string(), + name: "input_11".to_string(), kind: Some(RouteKindParity::Bolus), declaration_index: 1, input_index: 1, @@ -1217,7 +1238,7 @@ fn ode_authoring_dsl_supports_multi_digit_numeric_route_labels() { } #[test] -fn ode_structured_dsl_supports_multi_digit_numeric_route_labels() { +fn ode_structured_dsl_supports_prefixed_multi_digit_numeric_route_labels() { let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL); assert_eq!(dsl_view.route_input_count, 2); @@ -1225,7 +1246,7 @@ fn ode_structured_dsl_supports_multi_digit_numeric_route_labels() { dsl_view.routes, vec![ RouteParity { - name: "10".to_string(), + name: "input_10".to_string(), kind: None, declaration_index: 0, input_index: 0, @@ -1235,7 +1256,7 @@ fn ode_structured_dsl_supports_multi_digit_numeric_route_labels() { has_bioavailability: false, }, RouteParity { - name: "11".to_string(), + name: "input_11".to_string(), kind: None, declaration_index: 1, input_index: 1, @@ -1393,26 +1414,45 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); let macro_model = runtime_shared_input_macro_ode(); let handwritten_model = runtime_shared_input_handwritten_ode(); + let macro_metadata = macro_model.metadata().expect("macro ODE metadata exists"); + let handwritten_metadata = handwritten_model + .metadata() + .expect("handwritten ODE metadata exists"); - let oral = runtime_model - .route_index("oral") + let oral = compiled_route_input_index(&runtime_model, "oral") .expect("runtime oral route should exist"); - let iv = runtime_model - .route_index("iv") - .expect("runtime iv route should exist"); - let cp = runtime_model - .output_index("cp") - .expect("runtime cp output should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); assert_eq!(cp, 0); - assert_eq!(macro_model.route_index("oral"), Some(oral)); - assert_eq!(macro_model.route_index("iv"), Some(iv)); - assert_eq!(handwritten_model.route_index("oral"), Some(oral)); - assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + assert_eq!( + macro_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + macro_metadata.route("iv").map(|route| route.input_index()), + Some(iv) + ); + assert_eq!( + handwritten_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + handwritten_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); let runtime_predictions = match runtime_model .estimate_predictions(&subject, &support_point) @@ -1443,26 +1483,47 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_inpu compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_INPUT_DSL, "one_cmt_abs_shared"); let macro_model = runtime_shared_input_macro_analytical(); let handwritten_model = runtime_shared_input_handwritten_analytical(); - - let oral = runtime_model - .route_index("oral") + let macro_metadata = macro_model + .metadata() + .expect("macro analytical metadata exists"); + let handwritten_metadata = handwritten_model + .metadata() + .expect("handwritten analytical metadata exists"); + + let oral = compiled_route_input_index(&runtime_model, "oral") .expect("runtime oral route should exist"); - let iv = runtime_model - .route_index("iv") - .expect("runtime iv route should exist"); - let cp = runtime_model - .output_index("cp") - .expect("runtime cp output should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); assert_eq!(cp, 0); - assert_eq!(macro_model.route_index("oral"), Some(oral)); - assert_eq!(macro_model.route_index("iv"), Some(iv)); - assert_eq!(handwritten_model.route_index("oral"), Some(oral)); - assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + assert_eq!( + macro_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + macro_metadata.route("iv").map(|route| route.input_index()), + Some(iv) + ); + assert_eq!( + handwritten_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + handwritten_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); let runtime_predictions = match runtime_model .estimate_predictions(&subject, &support_point) @@ -1495,26 +1556,45 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape compile_runtime_jit_model(SDE_RUNTIME_SHARED_INPUT_DSL, "one_cmt_shared_sde"); let macro_model = runtime_shared_input_macro_sde(); let handwritten_model = runtime_shared_input_handwritten_sde(); + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); + let handwritten_metadata = handwritten_model + .metadata() + .expect("handwritten SDE metadata exists"); - let oral = runtime_model - .route_index("oral") + let oral = compiled_route_input_index(&runtime_model, "oral") .expect("runtime oral route should exist"); - let iv = runtime_model - .route_index("iv") - .expect("runtime iv route should exist"); - let cp = runtime_model - .output_index("cp") - .expect("runtime cp output should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); assert_eq!(cp, 0); - assert_eq!(macro_model.route_index("oral"), Some(oral)); - assert_eq!(macro_model.route_index("iv"), Some(iv)); - assert_eq!(handwritten_model.route_index("oral"), Some(oral)); - assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + assert_eq!( + macro_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + macro_metadata.route("iv").map(|route| route.input_index()), + Some(iv) + ); + assert_eq!( + handwritten_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + handwritten_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); let runtime_predictions = match runtime_model .estimate_predictions(&subject, &support_point) @@ -1544,24 +1624,34 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let runtime_model = compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); let mismatched_model = runtime_mismatched_shared_input_ode(); + let mismatched_metadata = mismatched_model + .metadata() + .expect("mismatched handwritten metadata exists"); - let oral = runtime_model - .route_index("oral") + let oral = compiled_route_input_index(&runtime_model, "oral") .expect("runtime oral route should exist"); - let iv = runtime_model - .route_index("iv") - .expect("runtime iv route should exist"); - let cp = runtime_model - .output_index("cp") - .expect("runtime cp output should exist"); + let iv = + compiled_route_input_index(&runtime_model, "iv").expect("runtime iv route should exist"); + let cp = + compiled_output_slot_index(&runtime_model, "cp").expect("runtime cp output should exist"); let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); assert_eq!(cp, 0); - assert_eq!(mismatched_model.route_index("oral"), Some(oral)); - assert_eq!(mismatched_model.route_index("iv"), Some(iv)); + assert_eq!( + mismatched_metadata + .route("oral") + .map(|route| route.input_index()), + Some(oral) + ); + assert_eq!( + mismatched_metadata + .route("iv") + .map(|route| route.input_index()), + Some(iv) + ); let runtime_predictions = match runtime_model .estimate_predictions(&subject, &support_point) @@ -1589,14 +1679,20 @@ fn ode_runtime_jit_preserves_mixed_output_labels() { let subject = Subject::builder("runtime-mixed-output-labels") .infusion(0.0, 100.0, "iv", 1.0) .missing_observation(0.5, "cp") - .missing_observation(0.5, "0") - .missing_observation(0.5, "1") + .missing_observation(0.5, "outeq_0") + .missing_observation(0.5, "outeq_1") .build(); let support_point = [0.2, 10.0]; - assert_eq!(runtime_model.output_index("cp"), Some(0)); - assert_eq!(runtime_model.output_index("0"), Some(1)); - assert_eq!(runtime_model.output_index("1"), Some(2)); + assert_eq!(compiled_output_slot_index(&runtime_model, "cp"), Some(0)); + assert_eq!( + compiled_output_slot_index(&runtime_model, "outeq_0"), + Some(1) + ); + assert_eq!( + compiled_output_slot_index(&runtime_model, "outeq_1"), + Some(2) + ); let predictions = match runtime_model .estimate_predictions(&subject, &support_point) diff --git a/tests/full_feature_dsl_backend_parity.rs b/tests/full_feature_dsl_backend_parity.rs index 929e7243..b83157dd 100644 --- a/tests/full_feature_dsl_backend_parity.rs +++ b/tests/full_feature_dsl_backend_parity.rs @@ -77,10 +77,6 @@ mod tests { .collect::>(), vec!["cp"] ); - assert_eq!(model.route_index("oral"), Some(0)); - assert_eq!(model.route_index("load"), Some(1)); - assert_eq!(model.route_index("iv"), Some(0)); - assert_eq!(model.output_index("cp"), Some(0)); } fn assert_analytical_full_public_shape(model: &CompiledRuntimeModel) { @@ -135,10 +131,6 @@ mod tests { .collect::>(), vec!["cp"] ); - assert_eq!(model.route_index("oral"), Some(0)); - assert_eq!(model.route_index("load"), Some(1)); - assert_eq!(model.route_index("iv"), Some(0)); - assert_eq!(model.output_index("cp"), Some(0)); } fn assert_full_backend_parity( diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index 5017902e..4fe7b442 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -392,20 +392,26 @@ fn build_analytical_subject() -> Subject { fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { let macro_ode = macro_ode_model(); let handwritten_ode = handwritten_ode_model(); + let macro_metadata = macro_ode.metadata().expect("macro ODE metadata exists"); assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - let oral = macro_ode.route_index("oral").expect("oral route exists"); - let load = macro_ode.route_index("load").expect("load route exists"); - let iv = macro_ode.route_index("iv").expect("iv route exists"); - let cp = macro_ode.output_index("cp").expect("cp output exists"); + let oral = macro_metadata + .route("oral") + .map(|route| route.input_index()) + .expect("oral route exists"); + let load = macro_metadata + .route("load") + .map(|route| route.input_index()) + .expect("load route exists"); + let iv = macro_metadata + .route("iv") + .map(|route| route.input_index()) + .expect("iv route exists"); assert_eq!(oral, iv); assert_eq!(load, 1); - assert_eq!(handwritten_ode.route_index("oral"), Some(oral)); - assert_eq!(handwritten_ode.route_index("load"), Some(load)); - assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); - assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + assert!(macro_metadata.output("cp").is_some()); let subject = build_ode_subject(); let params = [1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0]; @@ -429,29 +435,31 @@ fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::Pharmsol fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { let macro_analytical = macro_analytical_model(); let handwritten_analytical = handwritten_analytical_model(); + let macro_metadata = macro_analytical + .metadata() + .expect("macro analytical metadata exists"); assert_eq!( macro_analytical.metadata(), handwritten_analytical.metadata() ); - let oral = macro_analytical - .route_index("oral") + let oral = macro_metadata + .route("oral") + .map(|route| route.input_index()) .expect("oral route exists"); - let load = macro_analytical - .route_index("load") + let load = macro_metadata + .route("load") + .map(|route| route.input_index()) .expect("load route exists"); - let iv = macro_analytical.route_index("iv").expect("iv route exists"); - let cp = macro_analytical - .output_index("cp") - .expect("cp output exists"); + let iv = macro_metadata + .route("iv") + .map(|route| route.input_index()) + .expect("iv route exists"); assert_eq!(oral, iv); assert_eq!(load, 1); - assert_eq!(handwritten_analytical.route_index("oral"), Some(oral)); - assert_eq!(handwritten_analytical.route_index("load"), Some(load)); - assert_eq!(handwritten_analytical.route_index("iv"), Some(iv)); - assert_eq!(handwritten_analytical.output_index("cp"), Some(cp)); + assert!(macro_metadata.output("cp").is_some()); let subject = build_analytical_subject(); let params = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 99e0eeab..8e5540ee 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -104,15 +104,15 @@ fn numeric_label_macro_ode() -> equation::ODE { name: "numeric_label_one_cpt", params: [ke, v], states: [central], - outputs: [1], + outputs: [outeq_1], routes: [ - infusion(1) -> central, + infusion(input_1) -> central, ], diffeq: |x, _t, dx| { dx[central] = -ke * x[central]; }, out: |x, _t, y| { - y[1] = x[central] / v; + y[outeq_1] = x[central] / v; }, } } @@ -138,9 +138,9 @@ fn numeric_label_handwritten_ode() -> equation::ODE { equation::metadata::new("numeric_label_one_cpt") .parameters(["ke", "v"]) .states(["central"]) - .outputs(["1"]) + .outputs(["outeq_1"]) .route( - equation::Route::infusion("1") + equation::Route::infusion("input_1") .to_state("central") .inject_input_to_destination(), ), @@ -222,22 +222,22 @@ fn numeric_route_property_macro_ode() -> equation::ODE { name: "numeric_route_property_one_cpt", params: [ka, ke, v, tlag, f_oral], states: [depot, central], - outputs: [1], + outputs: [outeq_1], routes: [ - bolus(1) -> depot, + bolus(input_1) -> depot, ], diffeq: |x, _t, dx| { dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_t| { - lag! { 1 => tlag } + lag! { input_1 => tlag } }, fa: |_t| { - fa! { 1 => f_oral } + fa! { input_1 => f_oral } }, out: |x, _t, y| { - y[1] = x[central] / v; + y[outeq_1] = x[central] / v; }, } } @@ -270,9 +270,9 @@ fn numeric_route_property_handwritten_ode() -> equation::ODE { equation::metadata::new("numeric_route_property_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) - .outputs(["1"]) + .outputs(["outeq_1"]) .route( - equation::Route::bolus("1") + equation::Route::bolus("input_1") .to_state("depot") .with_lag() .with_bioavailability() @@ -287,7 +287,7 @@ fn mixed_output_labels_macro_ode() -> equation::ODE { name: "mixed_output_labels_one_cpt", params: [ke, v], states: [central], - outputs: [cp, 0, 1], + outputs: [cp, outeq_0, outeq_1], routes: [ infusion(iv) -> central, ], @@ -296,8 +296,8 @@ fn mixed_output_labels_macro_ode() -> equation::ODE { }, out: |x, _t, y| { y[cp] = x[central] / v; - y[0] = 2.0 * x[central] / v; - y[1] = 3.0 * x[central] / v; + y[outeq_0] = 2.0 * x[central] / v; + y[outeq_1] = 3.0 * x[central] / v; }, } } @@ -325,7 +325,7 @@ fn mixed_output_labels_handwritten_ode() -> equation::ODE { equation::metadata::new("mixed_output_labels_one_cpt") .parameters(["ke", "v"]) .states(["central"]) - .outputs(["cp", "0", "1"]) + .outputs(["cp", "outeq_0", "outeq_1"]) .route( equation::Route::infusion("iv") .to_state("central") @@ -404,10 +404,13 @@ fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { let handwritten_ode = injected_handwritten_ode(); let subject = subject_for_route("iv", "cp"); let support_point = [0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro injected model should carry metadata"); assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.route_index("iv"), Some(0)); - assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_ode.state_index("central"), Some(0)); let macro_predictions = macro_ode @@ -430,10 +433,13 @@ fn macro_numeric_labels_lower_to_dense_slots() { let handwritten_ode = numeric_label_handwritten_ode(); let subject = subject_for_route("1", "1"); let support_point = [0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro numeric-label model should carry metadata"); assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.route_index("1"), Some(0)); - assert_eq!(macro_ode.output_index("1"), Some(0)); + assert!(macro_metadata.route("input_1").is_some()); + assert!(macro_metadata.output("outeq_1").is_some()); assert_eq!(macro_ode.state_index("central"), Some(0)); let macro_predictions = macro_ode @@ -456,11 +462,14 @@ fn macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { let handwritten_ode = shared_input_handwritten_ode(); let subject = subject_for_shared_input(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_ode + .metadata() + .expect("macro shared-input model should carry metadata"); assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.route_index("oral"), Some(0)); - assert_eq!(macro_ode.route_index("iv"), Some(0)); - assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_ode.state_index("depot"), Some(0)); assert_eq!(macro_ode.state_index("central"), Some(1)); @@ -489,11 +498,14 @@ fn macro_mixed_output_labels_lower_to_dense_slots() { .missing_observation(2.0, "1") .build(); let support_point = [0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro mixed-output model should carry metadata"); assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.output_index("cp"), Some(0)); - assert_eq!(macro_ode.output_index("0"), Some(1)); - assert_eq!(macro_ode.output_index("1"), Some(2)); + assert!(macro_metadata.output("cp").is_some()); + assert!(macro_metadata.output("outeq_0").is_some()); + assert!(macro_metadata.output("outeq_1").is_some()); let macro_predictions = macro_ode .estimate_predictions(&subject, &support_point) @@ -515,10 +527,13 @@ fn macro_numeric_route_properties_lower_to_dense_slots() { let handwritten_ode = numeric_route_property_handwritten_ode(); let subject = subject_for_numeric_bolus_route("1", "1"); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + let macro_metadata = macro_ode + .metadata() + .expect("macro numeric route-property model should carry metadata"); assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.route_index("1"), Some(0)); - assert_eq!(macro_ode.output_index("1"), Some(0)); + assert!(macro_metadata.route("input_1").is_some()); + assert!(macro_metadata.output("outeq_1").is_some()); assert_eq!(macro_ode.state_index("depot"), Some(0)); assert_eq!(macro_ode.state_index("central"), Some(1)); @@ -598,8 +613,8 @@ fn macro_covariate_lowering_matches_handwritten_metadata_and_predictions() { assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); assert_eq!(macro_metadata.covariates().len(), 1); - assert_eq!(macro_ode.route_index("oral"), Some(0)); - assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_ode.state_index("gut"), Some(0)); assert_eq!(macro_ode.state_index("central"), Some(1)); diff --git a/tests/ode_optimizations.rs b/tests/ode_optimizations.rs index 488ae0f3..e50ecf10 100644 --- a/tests/ode_optimizations.rs +++ b/tests/ode_optimizations.rs @@ -900,7 +900,7 @@ fn likelihood_calculation_matches_analytical() { .with_nstates(1) .with_nout(1); - let error_models = AssayErrorModels::new() + let error_models = AssayErrorModels::default() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.0, 0.1, 0.0, 0.0), 0.0), diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 474c7bab..7980ccd4 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -493,10 +493,11 @@ fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { let handwritten_model = handwritten_infusion_sde(); let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 0.0, 10.0]; + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - assert_eq!(macro_model.route_index("iv"), Some(0)); - assert_eq!(macro_model.output_index("cp"), Some(0)); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_model.state_index("central"), Some(0)); let macro_predictions = macro_model @@ -518,10 +519,11 @@ fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { let handwritten_model = handwritten_absorption_sde(); let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - assert_eq!(macro_model.route_index("oral"), Some(0)); - assert_eq!(macro_model.output_index("cp"), Some(0)); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_model.state_index("gut"), Some(0)); let macro_predictions = macro_model @@ -543,11 +545,12 @@ fn sde_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions( let handwritten_model = handwritten_shared_input_sde(); let subject = shared_input_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + let macro_metadata = macro_model.metadata().expect("macro SDE metadata exists"); assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - assert_eq!(macro_model.route_index("oral"), Some(0)); - assert_eq!(macro_model.route_index("iv"), Some(0)); - assert_eq!(macro_model.output_index("cp"), Some(0)); + assert!(macro_metadata.route("oral").is_some()); + assert!(macro_metadata.route("iv").is_some()); + assert!(macro_metadata.output("cp").is_some()); assert_eq!(macro_model.state_index("gut"), Some(0)); assert_eq!(macro_model.state_index("central"), Some(1)); diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index 6e7e5f8e..97b958c2 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -73,14 +73,26 @@ pub fn subject() -> Subject { feature = "dsl-wasm" ))] pub fn subject_for_runtime_model(model: &pharmsol::dsl::CompiledRuntimeModel) -> Subject { - let route_label = if model.route_index("iv").is_some() { + let route_label = if model.info().routes.iter().any(|route| route.name == "iv") { "iv" - } else if model.route_index("input_0").is_some() { + } else if model + .info() + .routes + .iter() + .any(|route| route.name == "input_0") + { "input_0" } else { panic!("bimodal_ke route is available"); }; - model.output_index("cp").expect("cp output is available"); + assert!( + model + .info() + .outputs + .iter() + .any(|output| output.name == "cp"), + "cp output is available" + ); subject_for_labels(route_label, "cp") } diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 1ca8ae78..0f32fada 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -209,17 +209,30 @@ impl CorpusCase { fn runtime_subject(self, model: &CompiledRuntimeModel) -> Result> { model - .output_index("cp") + .info() + .outputs + .iter() + .find(|output| output.name == "cp") .ok_or_else(|| io::Error::other(format!("{}: missing cp output", self.label())))?; let subject = match self { Self::Ode => { - model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; - model.route_index("iv").ok_or_else(|| { - io::Error::other(format!("{}: missing iv route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) .bolus(0.0, 120.0, "oral") @@ -233,15 +246,30 @@ impl CorpusCase { .build() } Self::OdeFull => { - model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; - model.route_index("load").ok_or_else(|| { - io::Error::other(format!("{}: missing load route", self.label())) - })?; - model.route_index("iv").ok_or_else(|| { - io::Error::other(format!("{}: missing iv route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "load") + .ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; Subject::builder(self.label()) .bolus(0.0, 80.0, "load") .bolus(1.0, 120.0, "oral") @@ -261,9 +289,14 @@ impl CorpusCase { .build() } Self::Analytical => { - model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; Subject::builder(self.label()) .bolus(0.0, 100.0, "oral") .missing_observation(0.5, "cp") @@ -273,15 +306,30 @@ impl CorpusCase { .build() } Self::AnalyticalFull => { - model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; - model.route_index("load").ok_or_else(|| { - io::Error::other(format!("{}: missing load route", self.label())) - })?; - model.route_index("iv").ok_or_else(|| { - io::Error::other(format!("{}: missing iv route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "load") + .ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "iv") + .ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; Subject::builder(self.label()) .bolus(0.0, 60.0, "load") .bolus(1.0, 100.0, "oral") @@ -301,9 +349,14 @@ impl CorpusCase { .build() } Self::Sde => { - model.route_index("oral").ok_or_else(|| { - io::Error::other(format!("{}: missing oral route", self.label())) - })?; + model + .info() + .routes + .iter() + .find(|route| route.name == "oral") + .ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) .bolus(0.0, 80.0, "oral") diff --git a/tests/test_pf.rs b/tests/test_pf.rs index 1f9482f2..3a195ff0 100644 --- a/tests/test_pf.rs +++ b/tests/test_pf.rs @@ -34,7 +34,7 @@ fn test_particle_filter_likelihood() { .with_nstates(2) .with_nout(1); - let ems = AssayErrorModels::new() + let ems = AssayErrorModels::default() .add( 0, AssayErrorModel::additive(ErrorPoly::new(0.5, 0.0, 0.0, 0.0), 0.0), From 2d675fd0fc319c2a9ae7248877e008b733c1aba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 8 May 2026 21:45:12 +0100 Subject: [PATCH 20/22] feat: Answering Markus' comments --- examples/analytical_vs_ode.rs | 14 ++++++++++++-- pharmsol-macros/Cargo.toml | 2 +- pharmsol-macros/src/lib.rs | 4 ++-- src/data/error_model.rs | 4 ++-- src/dsl/native.rs | 1 + tests/authoring_parity_corpus.rs | 4 +--- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 48c4986f..fe5ada96 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -6,6 +6,9 @@ //! the predictions side by side so you can verify they match. //! Both authoring paths use the declaration-first macro surface so the //! example stays on the preferred public authoring story. +//! Built-in analytical structures are positional: the `params: [...]` +//! declaration becomes metadata, but the runtime kernel still expects values +//! in the structure's native positional order. //! //! cargo run --release --example analytical_vs_ode @@ -191,7 +194,7 @@ fn two_cmt_iv(params: &[f64]) { fn two_cmt_oral(params: &[f64]) { let analytical = analytical! { name: "two_cmt_oral", - params: [ka, ke, k12, k21, v], + params: [ke, ka, k12, k21, v], states: [gut, central, peripheral], outputs: [cp], routes: [ @@ -223,7 +226,14 @@ fn two_cmt_oral(params: &[f64]) { let subject = subject_oral("oral", "cp"); - let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + // `two_compartments_with_absorption` is positional and expects + // `ke, ka, k12, k21, v`, while the ODE closure below is authored as + // `ka, ke, k12, k21, v`. + let analytical_params = [params[1], params[0], params[2], params[3], params[4]]; + + let pred_a = analytical + .estimate_predictions(&subject, &analytical_params) + .unwrap(); let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("Two-compartment oral", &pred_a, &pred_o); } diff --git a/pharmsol-macros/Cargo.toml b/pharmsol-macros/Cargo.toml index d9fe58ec..07e3688f 100644 --- a/pharmsol-macros/Cargo.toml +++ b/pharmsol-macros/Cargo.toml @@ -13,4 +13,4 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.106" quote = "1.0.45" -syn = { version = "2.0.117", features = ["full", "visit"] } +syn = { version = "2.0.117", features = ["full", "visit", "visit-mut"] } diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 7e483951..83007df2 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -936,10 +936,10 @@ impl VisitMut for NumericLabelRewriter { return; }; - expr_index.index = Box::new(Expr::Lit(syn::ExprLit { + *expr_index.index = Expr::Lit(syn::ExprLit { attrs: Vec::new(), lit: Lit::Int(LitInt::new(&internal_index.to_string(), lit.span())), - })); + }); } fn visit_expr_macro_mut(&mut self, expr_macro: &mut syn::ExprMacro) { diff --git a/src/data/error_model.rs b/src/data/error_model.rs index 1e66eb60..0d52932f 100644 --- a/src/data/error_model.rs +++ b/src/data/error_model.rs @@ -1251,7 +1251,7 @@ mod tests { .unwrap(); assert_eq!(models.len(), 2); - assert!(matches!(models.error_model(1), Ok(_))); + assert!(models.error_model(1).is_ok()); } #[test] @@ -1265,7 +1265,7 @@ mod tests { let models = error_models.bind_output_names(["cp", "effect"]).unwrap(); assert_eq!(models.len(), 2); - assert!(matches!(models.error_model(1), Ok(_))); + assert!(models.error_model(1).is_ok()); } #[test] diff --git a/src/dsl/native.rs b/src/dsl/native.rs index d35a392b..d325226f 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1316,6 +1316,7 @@ impl NativeAnalyticalModel { Ok(output) } + #[allow(clippy::too_many_arguments)] fn solve_interval( &self, session: &mut dyn KernelSession, diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index e0b86744..38fd4c9d 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -1398,9 +1398,7 @@ fn invalid_dsl_infusion_route_properties_fail_explicitly() { let model = parse_model(ODE_INVALID_INFUSION_LAG_DSL).expect("invalid DSL fixture should parse"); let typed = analyze_model(&model).expect("invalid DSL fixture should analyze"); - let error = lower_typed_model(&typed) - .err() - .expect("infusion lag should fail during lowering"); + let error = lower_typed_model(&typed).expect_err("infusion lag should fail during lowering"); assert!(error .to_string() From 31f418f38bb692f881d044b6ca7b4a066c44a87f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Sun, 10 May 2026 08:09:29 +0100 Subject: [PATCH 21/22] chore: remove duplicated definitions --- pharmsol-dsl/src/authoring.rs | 5 ++--- pharmsol-dsl/src/lib.rs | 4 ++++ pharmsol-dsl/src/semantic.rs | 12 ++++-------- src/dsl/native.rs | 4 +--- src/simulator/equation/mod.rs | 4 +--- 5 files changed, 12 insertions(+), 17 deletions(-) diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 153760a0..984ea4db 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -3,10 +3,9 @@ use std::collections::{BTreeMap, BTreeSet}; use super::ast::*; use super::diagnostic::{Applicability, DiagnosticSuggestion, ParseError, Span, TextEdit}; use super::parser::{parse_expr_fragment, parse_place_fragment}; +use crate::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME}; const DEFAULT_MODEL_NAME: &str = "main"; -const NUMERIC_ROUTE_PREFIX: &str = "input_"; -const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; pub(super) fn parse_module(src: &str) -> Result { AuthoringParser::new(src).parse_module() @@ -802,7 +801,7 @@ fn inject_infusion_rates( let rate_expr = Expr { span: surface_route.span, kind: ExprKind::Call { - callee: Ident::new("rate", surface_route.input.span), + callee: Ident::new(RATE_FUNCTION_NAME, surface_route.input.span), args: vec![Expr { span: surface_route.input.span, kind: ExprKind::Name(surface_route.input.clone()), diff --git a/pharmsol-dsl/src/lib.rs b/pharmsol-dsl/src/lib.rs index 351b041c..6cb674a6 100644 --- a/pharmsol-dsl/src/lib.rs +++ b/pharmsol-dsl/src/lib.rs @@ -79,6 +79,10 @@ mod semantic; #[cfg(test)] mod test_fixtures; +pub const NUMERIC_ROUTE_PREFIX: &str = "input_"; +pub const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; +pub(crate) const RATE_FUNCTION_NAME: &str = "rate"; + pub use ast::*; pub use diagnostic::*; pub use execution::{ diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index 2f043ac1..3f454983 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -8,7 +8,7 @@ use crate::diagnostic::{ TextEdit, DSL_SEMANTIC_GENERIC, }; use crate::ir::*; -use crate::ModelKind; +use crate::{ModelKind, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX, RATE_FUNCTION_NAME}; const RESERVED_NAMES: &[&str] = &[ "abs", @@ -29,7 +29,7 @@ const RESERVED_NAMES: &[&str] = &[ "min", "noise", "pow", - "rate", + RATE_FUNCTION_NAME, "round", "sin", "cos", @@ -37,10 +37,6 @@ const RESERVED_NAMES: &[&str] = &[ "sqrt", ]; -const RATE_FUNCTION_NAME: &str = "rate"; -const NUMERIC_ROUTE_PREFIX: &str = "input_"; -const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; - #[derive(Default)] struct SemanticAssist { context_labels: Vec<(Span, String)>, @@ -1314,7 +1310,7 @@ impl<'a> Analyzer<'a> { span: Span, env: &BlockEnv, ) -> Result { - if callee.text == "rate" { + if callee.text == RATE_FUNCTION_NAME { if args.len() != 1 { return Err(SemanticError::new( format!( @@ -1575,7 +1571,7 @@ impl<'a> Analyzer<'a> { }) } syntax::ExprKind::Call { callee, args } => { - if callee.text == "rate" { + if callee.text == RATE_FUNCTION_NAME { return Err(SemanticError::new( "`rate(...)` cannot appear in a compile-time expression", callee.span, diff --git a/src/dsl/native.rs b/src/dsl/native.rs index d325226f..e04b7610 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -15,7 +15,7 @@ use cranelift_jit::JITModule; #[cfg(feature = "dsl-aot-load")] use libloading::Library; use pharmsol_dsl::execution::KernelRole; -use pharmsol_dsl::{AnalyticalKernel, RouteKind}; +use pharmsol_dsl::{AnalyticalKernel, RouteKind, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, @@ -48,8 +48,6 @@ pub type DenseKernelFn = unsafe extern "C" fn( const DEFAULT_ODE_RTOL: f64 = 1e-4; const DEFAULT_ODE_ATOL: f64 = 1e-4; -const NUMERIC_ROUTE_PREFIX: &str = "input_"; -const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum RuntimeBackend { diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 838124dc..37e0ff29 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -55,6 +55,7 @@ pub use analytical::*; pub use metadata::*; pub use ode::*; pub use pharmsol_dsl::{AnalyticalKernel, ModelKind}; +use pharmsol_dsl::{NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; pub use sde::*; use crate::{ @@ -66,9 +67,6 @@ use crate::{ use super::likelihood::Prediction; -const NUMERIC_ROUTE_PREFIX: &str = "input_"; -const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; - /// Trait for state vectors that can receive bolus doses. pub trait State { /// Add a bolus dose to the state at the specified resolved input index. From 8517fa8cc9bdd07bcd11884ec84f799487aaa364 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Mon, 11 May 2026 10:16:24 +0100 Subject: [PATCH 22/22] feat: applying @mhovd's suggestions --- Cargo.toml | 4 + benches/likelihood_matrix.rs | 247 ++++++++++++++++++++++++++ pharmsol-dsl/src/authoring.rs | 22 +-- pharmsol-dsl/src/lib.rs | 2 + src/data/structs.rs | 4 + src/dsl/native.rs | 100 +++++++++-- src/dsl/runtime.rs | 55 +++++- src/error/mod.rs | 4 + src/simulator/equation/mod.rs | 13 +- src/simulator/equation/ode/closure.rs | 47 ++--- src/simulator/equation/ode/mod.rs | 70 ++++++-- src/simulator/likelihood/matrix.rs | 50 +++--- src/simulator/likelihood/mod.rs | 12 +- 13 files changed, 528 insertions(+), 102 deletions(-) create mode 100644 benches/likelihood_matrix.rs diff --git a/Cargo.toml b/Cargo.toml index 35f233d1..18698813 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -79,3 +79,7 @@ harness = false [[bench]] name = "runtime_matrix" harness = false + +[[bench]] +name = "likelihood_matrix" +harness = false diff --git a/benches/likelihood_matrix.rs b/benches/likelihood_matrix.rs new file mode 100644 index 00000000..b8f94d54 --- /dev/null +++ b/benches/likelihood_matrix.rs @@ -0,0 +1,247 @@ +use criterion::{ + criterion_group, criterion_main, BenchmarkId, Criterion, SamplingMode, Throughput, +}; +use ndarray::Array2; +use pharmsol::prelude::simulator::{log_likelihood_batch, log_likelihood_matrix}; +use pharmsol::prelude::*; +use pharmsol::{Cache, ResidualErrorModel, ResidualErrorModels, ODE}; +use std::hint::black_box; +use std::time::Duration; + +fn example_equation() -> ODE { + equation::ODE::new( + |x, p, _t, dx, _b, _rateiv, _cov| { + fetch_params!(p, ka, ke, _tlag, _v); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, tlag, _v); + lag! {0=>tlag} + }, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _tlag, v); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_nout(1) +} + +fn example_data(n_subjects: usize) -> Data { + let subjects = (0..n_subjects) + .map(|index| { + let dose = 100.0 + index as f64; + let offset = index as f64 * 0.02; + Subject::builder(format!("subject_{index}")) + .bolus(0.0, dose, 0) + .observation(0.5, 4.2 + offset, 0) + .observation(1.0, 5.9 + offset, 0) + .observation(2.0, 7.4 + offset, 0) + .observation(4.0, 6.8 + offset, 0) + .observation(8.0, 4.7 + offset, 0) + .observation(12.0, 3.3 + offset, 0) + .build() + }) + .collect(); + + Data::new(subjects) +} + +fn example_subject() -> Subject { + example_data(1) + .get_subject("subject_0") + .expect("example subject exists") + .clone() +} + +fn example_params() -> [f64; 4] { + [0.60, 0.09, 0.05, 20.0] +} + +fn support_points(n_support_points: usize) -> Array2 { + Array2::from_shape_fn((n_support_points, 4), |(row, column)| match column { + 0 => 0.55 + row as f64 * 0.002, + 1 => 0.08 + row as f64 * 0.0004, + 2 => 0.05, + 3 => 18.0 + row as f64 * 0.15, + _ => unreachable!(), + }) +} + +fn batch_parameters(n_subjects: usize) -> Array2 { + Array2::from_shape_fn((n_subjects, 4), |(row, column)| match column { + 0 => 0.60 + row as f64 * 0.001, + 1 => 0.09 + row as f64 * 0.0002, + 2 => 0.05, + 3 => 20.0 + row as f64 * 0.05, + _ => unreachable!(), + }) +} + +fn assay_error_models() -> AssayErrorModels { + AssayErrorModels::new() + .add( + 0, + AssayErrorModel::additive(ErrorPoly::new(0.0, 0.1, 0.0, 0.0), 0.0), + ) + .unwrap() +} + +fn residual_error_models() -> ResidualErrorModels { + ResidualErrorModels::new().add(0, ResidualErrorModel::constant(0.2)) +} + +fn criterion_benchmark(c: &mut Criterion) { + let assay_error_models = assay_error_models(); + let residual_error_models = residual_error_models(); + + let subject = example_subject(); + let params = example_params(); + let equation_cold = example_equation().disable_cache(); + let equation_hot = example_equation(); + let _ = equation_hot + .estimate_predictions(&subject, params.as_slice()) + .unwrap(); + let predictions = equation_cold + .estimate_predictions(&subject, params.as_slice()) + .unwrap(); + + let mut breakdown_group = c.benchmark_group("ode/runtime-breakdown"); + breakdown_group.sample_size(10); + breakdown_group.warm_up_time(Duration::from_millis(250)); + breakdown_group.measurement_time(Duration::from_secs(1)); + breakdown_group.bench_function("predict-cold", |b| { + b.iter(|| { + black_box( + equation_cold + .estimate_predictions(&subject, params.as_slice()) + .unwrap(), + ) + }) + }); + breakdown_group.bench_function("predict-hot", |b| { + b.iter(|| { + black_box( + equation_hot + .estimate_predictions(&subject, params.as_slice()) + .unwrap(), + ) + }) + }); + breakdown_group.bench_function("score-only", |b| { + b.iter(|| black_box(predictions.log_likelihood(&assay_error_models)).unwrap()) + }); + breakdown_group.bench_function("loglik-cold", |b| { + b.iter(|| { + black_box( + equation_cold + .estimate_log_likelihood(&subject, params.as_slice(), &assay_error_models) + .unwrap(), + ) + }) + }); + breakdown_group.bench_function("loglik-hot", |b| { + b.iter(|| { + black_box( + equation_hot + .estimate_log_likelihood(&subject, params.as_slice(), &assay_error_models) + .unwrap(), + ) + }) + }); + breakdown_group.finish(); + + let mut matrix_group = c.benchmark_group("likelihood/matrix"); + matrix_group.sampling_mode(SamplingMode::Flat); + matrix_group.sample_size(10); + matrix_group.warm_up_time(Duration::from_millis(250)); + matrix_group.measurement_time(Duration::from_secs(2)); + + for (n_subjects, n_support_points) in [(16usize, 32usize), (64usize, 128usize)] { + let case = (example_data(n_subjects), support_points(n_support_points)); + let equation_cold = example_equation().disable_cache(); + let equation_hot = example_equation(); + let _ = log_likelihood_matrix(&equation_hot, &case.0, &case.1, &assay_error_models, false) + .unwrap(); + + matrix_group.throughput(Throughput::Elements((n_subjects * n_support_points) as u64)); + matrix_group.bench_with_input( + BenchmarkId::new("cold", format!("{n_subjects}x{n_support_points}")), + &case, + |b, (data, theta)| { + b.iter(|| { + black_box(log_likelihood_matrix( + &equation_cold, + data, + theta, + &assay_error_models, + false, + )) + .unwrap() + }) + }, + ); + matrix_group.bench_with_input( + BenchmarkId::new("hot", format!("{n_subjects}x{n_support_points}")), + &case, + |b, (data, theta)| { + b.iter(|| { + black_box(log_likelihood_matrix( + &equation_hot, + data, + theta, + &assay_error_models, + false, + )) + .unwrap() + }) + }, + ); + } + matrix_group.finish(); + + let mut batch_group = c.benchmark_group("likelihood/batch"); + batch_group.sample_size(10); + batch_group.warm_up_time(Duration::from_millis(250)); + batch_group.measurement_time(Duration::from_secs(1)); + + for n_subjects in [16usize, 64usize, 256usize] { + let data = example_data(n_subjects); + let parameters = batch_parameters(n_subjects); + let equation_cold = example_equation().disable_cache(); + let equation_hot = example_equation(); + let _ = log_likelihood_batch(&equation_hot, &data, ¶meters, &residual_error_models) + .unwrap(); + + batch_group.throughput(Throughput::Elements(n_subjects as u64)); + batch_group.bench_with_input(BenchmarkId::new("cold", n_subjects), &data, |b, data| { + b.iter(|| { + black_box(log_likelihood_batch( + &equation_cold, + data, + ¶meters, + &residual_error_models, + )) + .unwrap() + }) + }); + batch_group.bench_with_input(BenchmarkId::new("hot", n_subjects), &data, |b, data| { + b.iter(|| { + black_box(log_likelihood_batch( + &equation_hot, + data, + ¶meters, + &residual_error_models, + )) + .unwrap() + }) + }); + } + batch_group.finish(); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 984ea4db..04a64290 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -173,7 +173,7 @@ impl<'a> AuthoringParser<'a> { let kind = self.determine_kind(module_span)?; if matches!(kind, ModelKind::Analytical) && !self.derivative_statements.is_empty() { return Err(ParseError::new( - "analytical authoring models cannot declare `dx(...)` equations", + "analytical models cannot declare `dx(...)` equations", self.derivative_statements[0].span, )); } @@ -312,7 +312,7 @@ impl<'a> AuthoringParser<'a> { let eq_index = find_top_level_assignment(trimmed).ok_or_else(|| { ParseError::new( - "expected an authoring declaration, equation, or route shorthand", + "expected an declaration, equation, or route shorthand", span, ) })?; @@ -605,7 +605,7 @@ impl<'a> AuthoringParser<'a> { } other => { return Err(ParseError::new( - format!("unsupported authoring equation target `{other}`"), + format!("unsupported equation target `{other}`"), call.callee.span, )) } @@ -632,14 +632,14 @@ impl<'a> AuthoringParser<'a> { && (!self.diffusion_statements.is_empty() || self.particles.is_some()) { return Err(ParseError::new( - "analytical authoring models cannot declare particles or noise equations", + "analytical models cannot declare particles or noise equations", kind_span, )); } if matches!(kind, ModelKind::Ode) && !self.diffusion_statements.is_empty() { return Err(ParseError::new( - "ODE authoring models cannot declare `noise(...)` equations", + "ODE models cannot declare `noise(...)` equations", self.diffusion_statements[0].span, )); } @@ -647,7 +647,7 @@ impl<'a> AuthoringParser<'a> { if matches!(kind, ModelKind::Sde) { if let Some(analytical) = &self.analytical { return Err(ParseError::new( - "SDE authoring models cannot declare an analytical structure", + "SDE models cannot declare an analytical structure", analytical.span, )); } @@ -841,13 +841,13 @@ fn parse_call_head<'a>(src: &'a str, abs_start: usize) -> Result Result { let rest_abs = abs_start + 2 + rest_leading; if !rest.starts_with('(') { return Err(ParseError::new( - "expected `(` after `if` in authoring conditional expression", + "expected `(` after `if` in conditional expression", Span::new(rest_abs, rest_abs + rest.len().min(1)), )); } let close = find_matching_delimiter(rest, '(', ')').ok_or_else(|| { ParseError::new( - "unclosed `(` in authoring conditional expression", + "unclosed `(` in conditional expression", Span::new(rest_abs, rest_abs + rest.len()), ) })?; @@ -1127,7 +1127,7 @@ fn parse_if_rhs(src: &str, abs_start: usize) -> Result { let remaining_abs = rest_abs + close + 1; let else_index = find_top_level_keyword(remaining, "else").ok_or_else(|| { ParseError::new( - "expected `else` in authoring conditional expression", + "expected `else` in conditional expression", Span::new(remaining_abs, remaining_abs + remaining.len()), ) })?; diff --git a/pharmsol-dsl/src/lib.rs b/pharmsol-dsl/src/lib.rs index 6cb674a6..261efabf 100644 --- a/pharmsol-dsl/src/lib.rs +++ b/pharmsol-dsl/src/lib.rs @@ -79,7 +79,9 @@ mod semantic; #[cfg(test)] mod test_fixtures; +/// Canonical prefix for numeric route labels such as `input_1`. pub const NUMERIC_ROUTE_PREFIX: &str = "input_"; +/// Canonical prefix for numeric output labels such as `outeq_1`. pub const NUMERIC_OUTPUT_PREFIX: &str = "outeq_"; pub(crate) const RATE_FUNCTION_NAME: &str = "rate"; diff --git a/src/data/structs.rs b/src/data/structs.rs index d7d123b1..a116d703 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -59,6 +59,10 @@ impl Data { self.subjects.iter().collect() } + pub(crate) fn subjects_slice(&self) -> &[Subject] { + &self.subjects + } + /// Add a subject to the dataset /// /// # Arguments diff --git a/src/dsl/native.rs b/src/dsl/native.rs index e04b7610..7197084e 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -409,10 +409,7 @@ impl SharedNativeModel { return Ok(()); } - Err(PharmsolError::OtherError(format!( - "model `{}` does not declare a {:?} route for input {}", - self.info.name, kind, input - ))) + Err(PharmsolError::UnsupportedInputRouteKind { input, kind }) } fn resolve_input_label( @@ -683,15 +680,12 @@ impl SharedNativeModel { amount: f64, ) -> Result<(), PharmsolError> { self.validate_input_for_kind(input, RouteKind::Bolus)?; - let destination = self - .route_semantics - .bolus_destination(input) - .ok_or_else(|| { - PharmsolError::OtherError(format!( - "model `{}` does not declare a bolus route for input index {}", - self.info.name, input - )) - })?; + let destination = self.route_semantics.bolus_destination(input).ok_or( + PharmsolError::UnsupportedInputRouteKind { + input, + kind: RouteKind::Bolus, + }, + )?; state[destination] += amount; Ok(()) } @@ -814,7 +808,6 @@ impl NativeOdeModel { _ => None, }) .collect::>(); - let infusion_refs = infusions.iter().collect::>(); let session = RefCell::new(self.shared.artifact.start_session()?); let mut route_session = session.borrow_mut(); self.shared.apply_route_properties( @@ -899,20 +892,20 @@ impl NativeOdeModel { }, NalgebraContext, ); + let support_point_vec = support_point.to_vec(); let problem = OdeBuilder::::new() .atol(vec![self.atol]) .rtol(self.rtol) .t0(occasion.initial_time()) .h0(1e-3) - .p(support_point.to_vec()) + .p(support_point_vec.clone()) .build_from_eqn(PMProblem::with_params_v( diffeq, self.shared.info.state_len, self.shared.info.route_len, - support_point.to_vec(), support_vector.clone(), occasion.covariates(), - infusion_refs.as_slice(), + infusions.iter(), initial_state, )?)?; @@ -1874,7 +1867,61 @@ fn apply_analytical_kernel( #[cfg(test)] mod tests { - use super::{canonical_numeric_alias, NUMERIC_OUTPUT_PREFIX, NUMERIC_ROUTE_PREFIX}; + use super::{ + canonical_numeric_alias, KernelSession, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, + RuntimeArtifact, RuntimeBackend, SharedNativeModel, NUMERIC_OUTPUT_PREFIX, + NUMERIC_ROUTE_PREFIX, + }; + use crate::PharmsolError; + use pharmsol_dsl::execution::KernelRole; + use pharmsol_dsl::{ModelKind, RouteKind}; + + #[derive(Debug)] + struct DummyArtifact; + + impl RuntimeArtifact for DummyArtifact { + fn backend(&self) -> RuntimeBackend { + panic!("dummy artifact backend should not be used in tests") + } + + fn has_kernel(&self, _role: KernelRole) -> bool { + false + } + + fn start_session(&self) -> Result, PharmsolError> { + panic!("dummy artifact sessions should not be used in tests") + } + } + + fn bolus_only_shared_model() -> SharedNativeModel { + SharedNativeModel::new( + NativeModelInfo { + name: "bolus_only".to_string(), + kind: ModelKind::Ode, + parameters: Vec::new(), + covariates: Vec::new(), + routes: vec![NativeRouteInfo { + name: "oral".to_string(), + declaration_index: 0, + index: 0, + kind: Some(RouteKind::Bolus), + destination_offset: 0, + inject_input_to_destination: false, + }], + outputs: vec![NativeOutputInfo { + name: "cp".to_string(), + index: 0, + }], + state_len: 1, + derived_len: 0, + output_len: 1, + route_len: 1, + analytical: None, + particles: None, + }, + DummyArtifact, + ) + } #[test] fn canonical_numeric_alias_maps_bare_numeric_labels_to_contextual_prefixes() { @@ -1900,4 +1947,21 @@ mod tests { None ); } + + #[test] + fn validate_input_for_kind_reports_structured_route_kind_error() { + let model = bolus_only_shared_model(); + + let error = model + .validate_input_for_kind(0, RouteKind::Infusion) + .expect_err("bolus-only route should reject infusion usage"); + + assert!(matches!( + error, + PharmsolError::UnsupportedInputRouteKind { + input: 0, + kind: RouteKind::Infusion, + } + )); + } } diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1b2ecb76..7ccb132d 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -477,7 +477,7 @@ mod tests { use crate::PharmsolError; use crate::SubjectBuilderExt; use approx::assert_relative_eq; - use pharmsol_dsl::{DiagnosticPhase, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; + use pharmsol_dsl::{DiagnosticPhase, RouteKind, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; use tempfile::tempdir; const MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL: &str = r#" @@ -706,6 +706,13 @@ out(cp) = central / v ~ continuous() .build() } + fn mismatched_route_kind_subject() -> Subject { + Subject::builder("mismatched-route-kind-runtime") + .infusion(0.0, 120.0, "10", 1.0) + .missing_observation(0.5, "cp") + .build() + } + fn assert_unknown_output_label( model: &CompiledRuntimeModel, subject: &Subject, @@ -738,6 +745,27 @@ out(cp) = central / v ~ continuous() )); } + fn assert_unsupported_input_route_kind( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_input: usize, + expected_kind: RouteKind, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("mismatched route kind should fail"); + + match error { + RuntimeError::Runtime(PharmsolError::UnsupportedInputRouteKind { input, kind }) + if input == expected_input && kind == expected_kind => {} + other => panic!( + "expected UnsupportedInputRouteKind {{ input: {expected_input}, kind: {:?} }}, got {other:?}", + expected_kind + ), + } + } + #[test] fn runtime_backend_matrix_matches_ode_predictions() { let work_dir = tempdir().expect("tempdir"); @@ -806,6 +834,31 @@ out(cp) = central / v ~ continuous() } } + #[test] + fn runtime_backend_matrix_reports_route_kind_mismatch() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let subject = mismatched_route_kind_subject(); + + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "prefixed_numeric_route_runtime", + work_dir.path(), + ); + let expected_input = + compiled_route_input_index(&jit, "input_10").expect("input_10 route index"); + + for model in [&jit, &aot, &wasm] { + assert_unsupported_input_route_kind( + model, + &subject, + &support, + expected_input, + RouteKind::Infusion, + ); + } + } + #[test] fn runtime_backend_matrix_preserves_multi_digit_output_label_order() { let work_dir = tempdir().expect("tempdir"); diff --git a/src/error/mod.rs b/src/error/mod.rs index c8f70b58..1e97aee0 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -1,5 +1,7 @@ use thiserror::Error; +use pharmsol_dsl::RouteKind; + #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] use crate::data::error_model::ErrorModelError; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] @@ -41,6 +43,8 @@ pub enum PharmsolError { UnknownInputLabel { label: String }, #[error("Output label `{label}` could not be resolved to an output")] UnknownOutputLabel { label: String }, + #[error("Input index {input} does not support route kind {kind:?}")] + UnsupportedInputRouteKind { input: usize, kind: RouteKind }, #[error("Input index {input} is out of range (ndrugs = {ndrugs})")] InputOutOfRange { input: usize, ndrugs: usize }, #[error("Output equation {outeq} is out of range (nout = {nout})")] diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 37e0ff29..94ca5ccf 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -210,12 +210,13 @@ pub(crate) trait EquationPriv: EquationTypes { })?; if route.kind() != expected_kind { - return Err(PharmsolError::OtherError(format!( - "input label `{}` is declared as {:?} but used as {:?}", - label, - route.kind(), - expected_kind - ))); + return Err(PharmsolError::UnsupportedInputRouteKind { + input: route.input_index(), + kind: match expected_kind { + RouteKind::Bolus => pharmsol_dsl::RouteKind::Bolus, + RouteKind::Infusion => pharmsol_dsl::RouteKind::Infusion, + }, + }); } return Ok(route.input_index()); diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index 47f2a81e..9f5ab3c1 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -1,9 +1,8 @@ use crate::{Covariates, Infusion, PharmsolError}; use diffsol::{ ConstantOp, LinearOp, MatrixCommon, NalgebraContext, NalgebraMat, NonLinearOp, - NonLinearOpJacobian, OdeEquations, OdeEquationsRef, Op, UnitCallable, Vector, VectorCommon, + NonLinearOpJacobian, OdeEquations, OdeEquationsRef, Op, UnitCallable, Vector, }; -use nalgebra::DVector; use std::{cell::RefCell, cmp::Ordering}; type M = NalgebraMat; type V = ::V; @@ -67,13 +66,18 @@ struct InfusionSchedule { } impl InfusionSchedule { - fn new(ndrugs: usize, infusions: &[&Infusion]) -> Result { - if ndrugs == 0 || infusions.is_empty() { + fn new<'a, I>(ndrugs: usize, infusions: I) -> Result + where + I: IntoIterator, + { + if ndrugs == 0 { return Ok(Self { tracks: Vec::new() }); } let mut per_input: Vec> = vec![Vec::new(); ndrugs]; + let mut saw_infusion = false; for infusion in infusions { + saw_infusion = true; if infusion.duration() <= 0.0 { continue; } @@ -92,6 +96,10 @@ impl InfusionSchedule { per_input[input].push((infusion.time() + infusion.duration(), -rate)); } + if !saw_infusion { + return Ok(Self { tracks: Vec::new() }); + } + let tracks = per_input .into_iter() .enumerate() @@ -319,7 +327,6 @@ where nstates: usize, nparams: usize, init: V, - p: Vec, p_as_v: V, zero_bolus: V, covariates: &'a Covariates, @@ -334,17 +341,19 @@ where /// Creates a new PMProblem with a pre-converted parameter vector. /// This avoids an allocation when the caller already has a V representation. #[allow(clippy::too_many_arguments)] - pub fn with_params_v( + pub fn with_params_v<'b, I>( func: F, nstates: usize, ndrugs: usize, - p: Vec, p_as_v: V, covariates: &'a Covariates, - infusions: &[&'a Infusion], + infusions: I, init: V, - ) -> Result { - let nparams = p.len(); + ) -> Result + where + I: IntoIterator, + { + let nparams = p_as_v.len(); let rateiv_buffer = RefCell::new(V::zeros(ndrugs, NalgebraContext)); let infusion_schedule = InfusionSchedule::new(ndrugs, infusions)?; // Pre-allocate zero bolus vector @@ -355,7 +364,6 @@ where nstates, nparams, init, - p, p_as_v, zero_bolus, covariates, @@ -432,13 +440,10 @@ where } fn get_params(&self, p: &mut V) { - // Avoid unnecessary cloning by directly copying values from self.p - if p.len() == self.p.len() { - for i in 0..self.p.len() { - p[i] = self.p[i]; - } + if p.len() == self.p_as_v.len() { + p.copy_from(&self.p_as_v); } else { - p.copy_from(&DVector::from_vec(self.p.clone()).into()); + *p = self.p_as_v.clone(); } } @@ -455,13 +460,9 @@ where } fn set_params(&mut self, p: &V) { - if self.p.len() == p.len() { - for i in 0..p.len() { - self.p[i] = p[i]; - self.p_as_v[i] = p[i]; - } + if self.p_as_v.len() == p.len() { + self.p_as_v.copy_from(p); } else { - self.p = p.inner().iter().cloned().collect(); self.p_as_v = p.clone(); } } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index f982347f..13f0c2f3 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -514,6 +514,14 @@ impl Equation for ODE { _estimate_likelihood(self, subject, support_point, error_models) } + fn estimate_predictions( + &self, + subject: &Subject, + support_point: &[f64], + ) -> Result { + _subject_predictions(self, subject, support_point) + } + fn estimate_log_likelihood( &self, subject: &Subject, @@ -556,7 +564,8 @@ impl Equation for ODE { let zero_bolus = V::zeros(ndrugs, NalgebraContext); let zero_rateiv = V::zeros(ndrugs, NalgebraContext); let mut bolus_v = V::zeros(ndrugs, NalgebraContext); - let spp_v: V = DVector::from_vec(support_point.to_vec()).into(); + let support_point_vec = support_point.to_vec(); + let spp_v: V = DVector::from_vec(support_point_vec.clone()).into(); // Pre-allocate output vector for observations let mut y_out = V::zeros(self.get_nouteqs(), NalgebraContext); @@ -565,30 +574,25 @@ impl Equation for ODE { for occasion in subject.occasions() { let covariates = occasion.covariates(); let events = self.resolve_occasion_events(occasion, support_point, covariates)?; - let infusions = events - .iter() - .filter_map(|event| match event { - Event::Infusion(infusion) => Some(infusion), - _ => None, - }) - .collect::>(); let problem = OdeBuilder::::new() .atol(vec![self.atol]) .rtol(self.rtol) .t0(occasion.initial_time()) .h0(1e-3) - .p(support_point.to_vec()) + .p(support_point_vec.clone()) .build_from_eqn(PMProblem::with_params_v( move |x, p, t, dx, bolus, rateiv, cov| { (self.diffeq)(x, p, t, dx, bolus, rateiv, cov); }, nstates, ndrugs, - support_point.to_vec(), spp_v.clone(), covariates, - infusions.as_slice(), + events.iter().filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion), + _ => None, + }), self.initial_state(support_point, covariates, occasion.index()), )?)?; @@ -683,6 +687,9 @@ mod tests { use super::*; use crate::{fa, lag, Subject, SubjectBuilderExt}; use approx::assert_relative_eq; + use std::sync::atomic::{AtomicUsize, Ordering}; + + static PREDICTION_CACHE_DIFFEQ_CALLS: AtomicUsize = AtomicUsize::new(0); fn simple_ode() -> ODE { ODE::new( @@ -743,6 +750,19 @@ mod tests { y[0] = x[0]; } + fn counting_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + _b: &V, + _rateiv: &V, + _cov: &Covariates, + ) { + PREDICTION_CACHE_DIFFEQ_CALLS.fetch_add(1, Ordering::SeqCst); + dx[0] = 0.0; + } + #[test] fn handwritten_ode_metadata_exposes_name_lookup() { let ode = simple_ode() @@ -955,4 +975,32 @@ mod tests { assert!(ode.metadata().is_none()); } + + #[test] + fn handwritten_ode_estimate_predictions_uses_prediction_cache() { + PREDICTION_CACHE_DIFFEQ_CALLS.store(0, Ordering::SeqCst); + + let ode = ODE::new(counting_kernel, zero_lag, unit_fa, zero_init, state_output) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1); + let subject = Subject::builder("cached_predictions") + .bolus(0.0, 100.0, 0) + .observation(1.0, 0.0, 0) + .build(); + + let first = ode + .estimate_predictions(&subject, &[]) + .expect("first prediction run should succeed"); + let first_calls = PREDICTION_CACHE_DIFFEQ_CALLS.load(Ordering::SeqCst); + assert!(first_calls > 0); + + let second = ode + .estimate_predictions(&subject, &[]) + .expect("second prediction run should succeed"); + let second_calls = PREDICTION_CACHE_DIFFEQ_CALLS.load(Ordering::SeqCst); + + assert_eq!(first.predictions().len(), second.predictions().len()); + assert_eq!(first_calls, second_calls); + } } diff --git a/src/simulator/likelihood/matrix.rs b/src/simulator/likelihood/matrix.rs index dae2cac4..9512a191 100644 --- a/src/simulator/likelihood/matrix.rs +++ b/src/simulator/likelihood/matrix.rs @@ -47,16 +47,20 @@ pub fn log_likelihood_matrix( error_models: &AssayErrorModels, progress: bool, ) -> Result, PharmsolError> { - let mut log_psi: Array2 = Array2::default((subjects.len(), support_points.nrows()).f()); - - let subjects_vec = subjects.subjects(); + let n_support_points = support_points.nrows(); + let mut log_psi: Array2 = Array2::default((subjects.len(), n_support_points).f()); + let subject_slice = subjects.subjects_slice(); + let support_point_rows = support_points + .axis_iter(Axis(0)) + .map(|row| row.to_vec()) + .collect::>(); let progress_tracker = if progress { - let total = subjects_vec.len() * support_points.nrows(); + let total = subject_slice.len() * n_support_points; println!( "Computing log-likelihood matrix: {} subjects × {} support points...", - subjects_vec.len(), - support_points.nrows() + subject_slice.len(), + n_support_points ); Some(ProgressTracker::new(total)) } else { @@ -68,26 +72,20 @@ pub fn log_likelihood_matrix( .into_par_iter() .enumerate() .try_for_each(|(i, mut row)| { - row.axis_iter_mut(Axis(0)) - .into_par_iter() - .enumerate() - .try_for_each(|(j, mut element)| { - let subject = subjects_vec.get(i).unwrap(); - match equation.estimate_log_likelihood( - subject, - &support_points.row(j).to_vec(), - error_models, - ) { - Ok(log_likelihood) => { - element.fill(log_likelihood); - if let Some(ref tracker) = progress_tracker { - tracker.inc(); - } - } - Err(e) => return Err(e), - }; - Ok(()) - }) + let subject = &subject_slice[i]; + + for (element, support_point) in row.iter_mut().zip(support_point_rows.iter()) { + *element = equation.estimate_log_likelihood( + subject, + support_point.as_slice(), + error_models, + )?; + if let Some(ref tracker) = progress_tracker { + tracker.inc(); + } + } + + Ok(()) }); if let Some(tracker) = progress_tracker { diff --git a/src/simulator/likelihood/mod.rs b/src/simulator/likelihood/mod.rs index 30ad4945..17d07a19 100644 --- a/src/simulator/likelihood/mod.rs +++ b/src/simulator/likelihood/mod.rs @@ -111,8 +111,8 @@ pub fn log_likelihood_batch( parameters: &Array2, residual_error_models: &crate::ResidualErrorModels, ) -> Result, PharmsolError> { - let subjects_vec = subjects.subjects(); - let n_subjects = subjects_vec.len(); + let subject_slice = subjects.subjects_slice(); + let n_subjects = subject_slice.len(); if parameters.nrows() != n_subjects { return Err(PharmsolError::OtherError(format!( @@ -123,10 +123,10 @@ pub fn log_likelihood_batch( } // Parallel computation across subjects - let results: Vec = (0..n_subjects) - .into_par_iter() - .map(|i| { - let subject = &subjects_vec[i]; + let results: Vec = subject_slice + .par_iter() + .enumerate() + .map(|(i, subject)| { let params = parameters.row(i).to_vec(); // Simulate to get predictions