Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pharmsol-dsl/src/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub struct TypedConstant {
pub span: Span,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CovariateInterpolation {
Linear,
Locf,
Expand Down
10 changes: 8 additions & 2 deletions src/dsl/aot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,18 @@ pub fn load_aot_model(path: impl AsRef<Path>) -> Result<CompiledNativeModel, Aot
};

Ok(match info.kind {
ModelKind::Ode => CompiledNativeModel::Ode(super::NativeOdeModel::new(info, artifact)),
ModelKind::Ode => CompiledNativeModel::Ode(
super::NativeOdeModel::new(info, artifact)
.map_err(|error| AotError::Load(error.to_string()))?,
),
ModelKind::Analytical => CompiledNativeModel::Analytical(
super::NativeAnalyticalModel::new(info, artifact)
.map_err(|error| AotError::Load(error.to_string()))?,
),
ModelKind::Sde => CompiledNativeModel::Sde(super::NativeSdeModel::new(info, artifact)),
ModelKind::Sde => CompiledNativeModel::Sde(
super::NativeSdeModel::new(info, artifact)
.map_err(|error| AotError::Load(error.to_string()))?,
),
})
}

Expand Down
23 changes: 20 additions & 3 deletions src/dsl/compiled_backend_abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,11 @@ fn kernel_output_len(info: &NativeModelInfo, role: KernelRole) -> usize {

#[cfg(test)]
mod tests {
use super::super::model_info::{NativeCovariateInfo, NativeOutputInfo, NativeRouteInfo};
use super::super::model_info::{
NativeCovariateInfo, NativeOutputInfo, NativeRouteInfo, NativeStateInfo,
};
use super::*;
use pharmsol_dsl::ModelKind;
use pharmsol_dsl::{ModelKind, RouteKind};

#[test]
fn compiled_backend_symbol_names_are_frozen() {
Expand Down Expand Up @@ -322,13 +324,27 @@ mod tests {
covariates: vec![NativeCovariateInfo {
name: "wt".to_string(),
index: 0,
interpolation: None,
}],
states: vec![
NativeStateInfo {
name: "depot".to_string(),
offset: 0,
},
NativeStateInfo {
name: "central".to_string(),
offset: 1,
},
],
routes: vec![NativeRouteInfo {
name: "iv".to_string(),
declaration_index: 0,
index: 0,
kind: None,
kind: Some(RouteKind::Infusion),
destination_offset: 1,
destination_name: "central".to_string(),
has_lag: false,
has_bioavailability: false,
inject_input_to_destination: true,
}],
outputs: vec![NativeOutputInfo {
Expand Down Expand Up @@ -367,6 +383,7 @@ mod tests {
parameters: vec![],
derived: vec!["ke_i".to_string(), "v_i".to_string(), "cl_i".to_string()],
covariates: vec![],
states: vec![],
routes: vec![],
outputs: vec![],
state_len: 2,
Expand Down
10 changes: 6 additions & 4 deletions src/dsl/jit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1284,10 +1284,11 @@ pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result<JitOdeModel, J
Some(model.span),
));
}
Ok(JitOdeModel::new(
JitOdeModel::new(
NativeModelInfo::from_execution_model(model),
compile_execution_artifact(model)?,
))
)
.map_err(|error| JitCompileError::new(error.to_string(), Some(model.span)))
}

/// Compile an analytical execution model to the native in-process JIT backend.
Expand Down Expand Up @@ -1321,10 +1322,11 @@ pub fn compile_sde_model_to_jit(model: &ExecutionModel) -> Result<JitSdeModel, J
Some(model.span),
));
}
Ok(JitSdeModel::new(
JitSdeModel::new(
NativeModelInfo::from_execution_model(model),
compile_execution_artifact(model)?,
))
)
.map_err(|error| JitCompileError::new(error.to_string(), Some(model.span)))
}

#[cfg(test)]
Expand Down
2 changes: 1 addition & 1 deletion src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ pub use runtime::{
compile_execution_model_to_runtime, compile_module_source_to_runtime, load_runtime_artifact,
CompiledRuntimeModel, RuntimeAnalyticalModel, RuntimeArtifactFormat, RuntimeCompilationTarget,
RuntimeCovariateInfo, RuntimeError, RuntimeModelInfo, RuntimeOdeModel, RuntimeOutputInfo,
RuntimePredictions, RuntimeRouteInfo, RuntimeSdeModel,
RuntimePredictions, RuntimeRouteInfo, RuntimeSdeModel, RuntimeStateInfo,
};
#[cfg(all(
feature = "dsl-wasm",
Expand Down
75 changes: 72 additions & 3 deletions src/dsl/model_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use pharmsol_dsl::execution::{
ExecutionExpr, ExecutionExprKind, ExecutionLoad, ExecutionModel, ExecutionStmt,
ExecutionStmtKind, KernelImplementation, KernelRole,
};
use pharmsol_dsl::{AnalyticalKernel, ModelKind, RouteKind};
use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind, RouteKind};

/// Public metadata extracted from a compiled backend model.
///
Expand All @@ -26,6 +26,8 @@ pub struct NativeModelInfo {
pub derived: Vec<String>,
/// Declared covariates and their dense runtime indices.
pub covariates: Vec<NativeCovariateInfo>,
/// Declared states together with their dense runtime offsets.
pub states: Vec<NativeStateInfo>,
Comment thread
Siel marked this conversation as resolved.
/// Declared routes together with declaration-order and dense runtime indices.
pub routes: Vec<NativeRouteInfo>,
/// Declared outputs and their dense runtime indices.
Expand All @@ -51,6 +53,17 @@ pub struct NativeCovariateInfo {
pub name: String,
/// Dense runtime covariate index.
pub index: usize,
/// Optional interpolation policy declared for this covariate.
pub interpolation: Option<CovariateInterpolation>,
}

/// Metadata for one compiled state.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct NativeStateInfo {
/// Public state name.
pub name: String,
/// Dense runtime state offset.
pub offset: usize,
}

/// Metadata for one compiled route.
Expand All @@ -59,15 +72,19 @@ pub struct NativeRouteInfo {
/// Public route label.
pub name: String,
/// Route position in declaration order.
#[serde(default)]
pub declaration_index: usize,
/// Dense runtime route-input index.
pub index: usize,
/// Coarse route kind when declared in metadata.
#[serde(default)]
pub kind: Option<RouteKind>,
/// Dense destination state offset used by compiled kernels.
pub destination_offset: usize,
/// Public destination state name.
pub destination_name: String,
/// Whether this route declares lag handling.
pub has_lag: bool,
/// Whether this route declares bioavailability handling.
pub has_bioavailability: bool,
/// Whether the compiled backend injects the route input into the destination
/// state automatically when the model does not read the route input
/// explicitly.
Expand Down Expand Up @@ -109,6 +126,16 @@ impl NativeModelInfo {
.map(|covariate| NativeCovariateInfo {
name: covariate.name.clone(),
index: covariate.index,
interpolation: covariate.interpolation,
})
.collect(),
states: model
.metadata
.states
.iter()
.map(|state| NativeStateInfo {
name: state.name.clone(),
offset: state.offset,
})
.collect(),
routes: model
Expand All @@ -121,6 +148,9 @@ impl NativeModelInfo {
index: route.index,
kind: route.kind,
destination_offset: route.destination.state_offset,
destination_name: route.destination.state_name.clone(),
has_lag: route.has_lag,
has_bioavailability: route.has_bioavailability,
inject_input_to_destination: !explicit_route_input_usage
.get(route.declaration_index)
.copied()
Expand Down Expand Up @@ -315,6 +345,45 @@ out(cp) = central / v ~ continuous()
assert!(!info.routes[1].inject_input_to_destination);
}

#[test]
fn native_model_info_preserves_state_covariate_and_route_metadata() {
let info = load_model_info(
r#"
name = metadata_surface
kind = ode

params = ke, v
covariates = wt@linear
states = depot, central
outputs = cp

bolus(oral) -> depot
infusion(iv) -> central
lag(oral) = 1.0
fa(oral) = 0.8

dx(depot) = -ke * depot
dx(central) = ke * depot - rate(iv)

out(cp) = central / v
"#,
);

assert_eq!(info.states.len(), 2);
assert_eq!(info.states[0].name, "depot");
assert_eq!(info.states[1].name, "central");
assert_eq!(
info.covariates[0].interpolation,
Some(CovariateInterpolation::Linear)
);
assert_eq!(info.routes[0].destination_name, "depot");
assert!(info.routes[0].has_lag);
assert!(info.routes[0].has_bioavailability);
assert_eq!(info.routes[1].destination_name, "central");
assert!(!info.routes[1].has_lag);
assert!(!info.routes[1].has_bioavailability);
}

#[test]
fn native_model_info_preserves_canonical_numeric_channel_names() {
let info = load_model_info(
Expand Down
Loading
Loading