From 395d465b9b29471c0ad508329bb270feb7f8c756 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 9 Dec 2025 13:34:15 +0000 Subject: [PATCH 01/32] Rough approach to patch models with diff files --- src/input.rs | 7 ++ src/input/patch.rs | 155 ++++++++++++++++++++++++++++++++++++++++ src/model/parameters.rs | 4 ++ 3 files changed, 166 insertions(+) create mode 100644 src/input/patch.rs diff --git a/src/input.rs b/src/input.rs index 08358facc..2fb84a867 100644 --- a/src/input.rs +++ b/src/input.rs @@ -24,6 +24,8 @@ mod asset; use asset::read_assets; mod commodity; use commodity::read_commodities; +mod patch; +use patch::patch_model; mod process; use process::read_processes; mod region; @@ -229,6 +231,11 @@ where pub fn load_model>(model_dir: P) -> Result<(Model, AssetPool)> { let model_params = ModelParameters::from_path(&model_dir)?; + // If `model_params` specifies a `base_dir`, patch the base model and load the patched model + if let Some(base_dir) = &model_params.base_model { + return load_model(patch_model(Path::new(base_dir), model_dir.as_ref())?); + } + let time_slice_info = read_time_slice_info(model_dir.as_ref())?; let regions = read_regions(model_dir.as_ref())?; let region_ids = regions.keys().cloned().collect(); diff --git a/src/input/patch.rs b/src/input/patch.rs new file mode 100644 index 000000000..815e6b51b --- /dev/null +++ b/src/input/patch.rs @@ -0,0 +1,155 @@ +//! Code for applying patches/diffs to model input files. +use super::input_err_msg; + +use anyhow::{Context, Result, bail, ensure}; +use std::fs; +use std::path::{Path, PathBuf}; +use tempfile::tempdir; + +/// Structure to hold diffs from a diff f +struct FileDiffs { + /// The header line from the diff file + header_line: String, + /// Lines to delete from the original file + to_delete: Vec, + /// Lines to add to the original file + to_add: Vec, +} + +/// Read diffs from a diff file. +/// +/// Reads a diff file where the first line is a header, and subsequent lines start with "-," for +/// deletions and "+," for additions. +fn read_diffs(file_path: &Path) -> Result { + // Read the entire file as a string + let content = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?; + + // Read header line + // This is saved to ensure that diffs are applied to a base file with the same header + let header_line = content + .lines() + .next() + .expect("Diff file cannot be empty") + .to_string(); + + // Collect additions and deletions + let mut to_delete = Vec::new(); + let mut to_add = Vec::new(); + for line in content.lines().skip(1) { + let line = line.trim(); + if let Some(stripped) = line.strip_prefix("-,") { + to_delete.push(stripped.trim().to_string()); + } else if let Some(stripped) = line.strip_prefix("+,") { + to_add.push(stripped.trim().to_string()); + } else { + bail!( + "Invalid line in diff file {}: {}", + file_path.display(), + line + ); + } + } + + Ok(FileDiffs { + header_line, + to_delete, + to_add, + }) +} + +/// Modify a string by applying diffs: removing lines and adding lines. +fn modify_string_with_diffs(original: &str, diffs: &FileDiffs) -> Result { + let mut modified = original.to_string(); + + // Check that the headers match + let original_header = original + .lines() + .next() + .expect("Original string cannot be empty"); + ensure!( + original_header == diffs.header_line, + "Header line in diff file does not match original string" + ); + + // Apply deletions + for item in &diffs.to_delete { + ensure!( + modified.contains(item), + "Item to delete not found in original string: {item}" + ); + modified = modified.replace(item, ""); + } + + // Apply additions + for item in &diffs.to_add { + modified.push_str(item); + } + + Ok(modified) +} + +pub fn patch_model>(model_dir: P, diffs_dir: P) -> Result { + // Copy contents of model_dir to a teporary directory + let temp_dir = tempdir().context("Failed to create temporary directory")?; + let temp_path = temp_dir.path(); + + for entry in fs::read_dir(model_dir.as_ref()).with_context(|| { + format!( + "Failed to read model directory: {}", + model_dir.as_ref().display() + ) + })? { + let entry = entry?; + let src_path = entry.path(); + + // Only copy files (skip any subdirectories if present) + if src_path.is_file() { + let dst_path = temp_path.join(entry.file_name()); + fs::copy(&src_path, &dst_path) + .with_context(|| format!("Failed to copy file: {}", src_path.display()))?; + } + } + + // Apply each patch from diffs_dir to the corresponding file in the temporary directory + for entry in fs::read_dir(diffs_dir.as_ref()).with_context(|| { + format!( + "Failed to read diffs directory: {}", + diffs_dir.as_ref().display() + ) + })? { + let entry = entry?; + let diff_path = entry.path(); + + // Only process files (skip any subdirectories if present) + if diff_path.is_file() { + let diff_filename = diff_path + .file_name() + .and_then(|name| name.to_str()) + .context("Failed to get diff filename")?; + + // Check that the filename ends with "_diff.csv" + ensure!( + diff_filename.ends_with("_diff.csv"), + "Diff file must end with '_diff.csv': {diff_filename}" + ); + + // Extract the base filename (e.g., "agents" from "agents_diff.csv") + let base_name = &diff_filename[..diff_filename.len() - "_diff.csv".len()]; + let target_filename = format!("{base_name}.csv"); + let file_path = temp_path.join(&target_filename); + + apply_patch_to_file(&file_path, &diff_path)?; + } + } + + // Return the path to the temporary directory + Ok(temp_path.to_path_buf()) +} + +fn apply_patch_to_file(file_path: &Path, diff_path: &Path) -> Result<()> { + let diffs = read_diffs(diff_path)?; + let original = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?; + let modified = modify_string_with_diffs(&original, &diffs)?; + fs::write(file_path, modified)?; + Ok(()) +} diff --git a/src/model/parameters.rs b/src/model/parameters.rs index 0311a5c60..d0c9f2b47 100644 --- a/src/model/parameters.rs +++ b/src/model/parameters.rs @@ -97,6 +97,10 @@ pub struct ModelParameters { /// Number of years an asset can remain unused before being decommissioned #[serde(default = "default_mothball_years")] pub mothball_years: u32, + /// Optional base model directory to use as a starting point, with this model's files applied + /// as patches/diffs. + #[serde(default)] + pub base_model: Option, } /// The strategy used for calculating commodity prices From b9764c290e06c340bbab9a301cbeb21cb47bdeb6 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 9 Dec 2025 14:06:04 +0000 Subject: [PATCH 02/32] Add tests --- src/input/patch.rs | 113 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 104 insertions(+), 9 deletions(-) diff --git a/src/input/patch.rs b/src/input/patch.rs index 815e6b51b..5d55bd21a 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -6,7 +6,8 @@ use std::fs; use std::path::{Path, PathBuf}; use tempfile::tempdir; -/// Structure to hold diffs from a diff f +/// Structure to hold diffs from a diff file +#[derive(Debug)] struct FileDiffs { /// The header line from the diff file header_line: String, @@ -42,11 +43,7 @@ fn read_diffs(file_path: &Path) -> Result { } else if let Some(stripped) = line.strip_prefix("+,") { to_add.push(stripped.trim().to_string()); } else { - bail!( - "Invalid line in diff file {}: {}", - file_path.display(), - line - ); + bail!("Invalid line in diff file: {line}"); } } @@ -68,14 +65,14 @@ fn modify_string_with_diffs(original: &str, diffs: &FileDiffs) -> Result .expect("Original string cannot be empty"); ensure!( original_header == diffs.header_line, - "Header line in diff file does not match original string" + "Header line in diff file does not match original file" ); // Apply deletions for item in &diffs.to_delete { ensure!( modified.contains(item), - "Item to delete not found in original string: {item}" + "Item to delete not found in original file: {item}" ); modified = modified.replace(item, ""); } @@ -147,9 +144,107 @@ pub fn patch_model>(model_dir: P, diffs_dir: P) -> Result Result<()> { - let diffs = read_diffs(diff_path)?; + let diffs = read_diffs(diff_path).with_context(|| input_err_msg(diff_path))?; let original = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?; let modified = modify_string_with_diffs(&original, &diffs)?; fs::write(file_path, modified)?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::fixture::assert_error; + use std::io::Write; + + #[test] + fn test_read_diffs_basic() { + let temp_dir = tempdir().unwrap(); + let diff_file = temp_dir.path().join("test_diff.csv"); + + let content = "header\n-,line_to_delete\n+,line_to_add\n"; + let mut file = fs::File::create(&diff_file).unwrap(); + file.write_all(content.as_bytes()).unwrap(); + + let diffs = read_diffs(&diff_file).unwrap(); + + assert_eq!(diffs.header_line, "header"); + assert_eq!(diffs.to_delete, vec!["line_to_delete"]); + assert_eq!(diffs.to_add, vec!["line_to_add"]); + } + + #[test] + fn test_read_diffs_with_whitespace() { + let temp_dir = tempdir().unwrap(); + let diff_file = temp_dir.path().join("test_diff.csv"); + + let content = "header\n-, item_with_spaces \n+, another_item \n"; + let mut file = fs::File::create(&diff_file).unwrap(); + file.write_all(content.as_bytes()).unwrap(); + + let diffs = read_diffs(&diff_file).unwrap(); + + // Whitespace should be trimmed + assert_eq!(diffs.to_delete, vec!["item_with_spaces"]); + assert_eq!(diffs.to_add, vec!["another_item"]); + } + + #[test] + fn test_read_diffs_invalid_line() { + let temp_dir = tempdir().unwrap(); + let diff_file = temp_dir.path().join("test_diff.csv"); + + let content = "header\ninvalid_line\n"; + let mut file = fs::File::create(&diff_file).unwrap(); + file.write_all(content.as_bytes()).unwrap(); + + let result = read_diffs(&diff_file); + assert_error!(result, "Invalid line in diff file: invalid_line"); + } + + #[test] + fn test_modify_string_with_diffs_basic() { + let original = "header\nline1\nline2\nline3\n"; + let diffs = FileDiffs { + header_line: "header".to_string(), + to_delete: vec!["line2".to_string()], + to_add: vec!["line_new".to_string()], + }; + + let modified = modify_string_with_diffs(original, &diffs).unwrap(); + assert!(!modified.contains("line2")); + assert!(modified.contains("line_new")); + } + + #[test] + fn test_modify_string_with_diffs_mismatched_header() { + let original = "header1\nline1\n"; + let diffs = FileDiffs { + header_line: "header2".to_string(), + to_delete: vec![], + to_add: vec![], + }; + + let result = modify_string_with_diffs(original, &diffs); + assert_error!( + result, + "Header line in diff file does not match original file" + ); + } + + #[test] + fn test_modify_string_with_diffs_missing_item() { + let original = "header\nline1\n"; + let diffs = FileDiffs { + header_line: "header".to_string(), + to_delete: vec!["nonexistent".to_string()], + to_add: vec![], + }; + + let result = modify_string_with_diffs(original, &diffs); + assert_error!( + result, + "Item to delete not found in original file: nonexistent" + ); + } +} From 1504078b03b03146bebb6086ff2060ddfc5eeaa5 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 9 Dec 2025 16:24:22 +0000 Subject: [PATCH 03/32] More robustness and error handling --- src/input.rs | 10 ++- src/input/patch.rs | 175 +++++++++++++++++++++++++++++----------- src/model/parameters.rs | 11 +++ 3 files changed, 148 insertions(+), 48 deletions(-) diff --git a/src/input.rs b/src/input.rs index 2fb84a867..2a41745cc 100644 --- a/src/input.rs +++ b/src/input.rs @@ -233,7 +233,15 @@ pub fn load_model>(model_dir: P) -> Result<(Model, AssetPool)> { // If `model_params` specifies a `base_dir`, patch the base model and load the patched model if let Some(base_dir) = &model_params.base_model { - return load_model(patch_model(Path::new(base_dir), model_dir.as_ref())?); + let patched_model = + patch_model(Path::new(base_dir), model_dir.as_ref()).with_context(|| { + format!( + "Error patching base model at {} with patches from {}", + base_dir, + model_dir.as_ref().display() + ) + })?; + return load_model(patched_model); } let time_slice_info = read_time_slice_info(model_dir.as_ref())?; diff --git a/src/input/patch.rs b/src/input/patch.rs index 5d55bd21a..1588de250 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -2,9 +2,10 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; +use log::info; use std::fs; -use std::path::{Path, PathBuf}; -use tempfile::tempdir; +use std::path::Path; +use tempfile::{TempDir, tempdir}; /// Structure to hold diffs from a diff file #[derive(Debug)] @@ -24,26 +25,29 @@ struct FileDiffs { fn read_diffs(file_path: &Path) -> Result { // Read the entire file as a string let content = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?; + let content = content.trim(); // Read header line // This is saved to ensure that diffs are applied to a base file with the same header - let header_line = content - .lines() - .next() - .expect("Diff file cannot be empty") - .to_string(); + let header_line = match content.lines().next() { + Some(line) => line.trim().to_string(), + None => bail!("Diff file cannot be empty"), + }; // Collect additions and deletions let mut to_delete = Vec::new(); let mut to_add = Vec::new(); for line in content.lines().skip(1) { let line = line.trim(); + if line.is_empty() { + continue; + } if let Some(stripped) = line.strip_prefix("-,") { to_delete.push(stripped.trim().to_string()); } else if let Some(stripped) = line.strip_prefix("+,") { to_add.push(stripped.trim().to_string()); } else { - bail!("Invalid line in diff file: {line}"); + bail!("Invalid row in diff file: {line}. Must start with '-,' or '+,'"); } } @@ -54,25 +58,28 @@ fn read_diffs(file_path: &Path) -> Result { }) } -/// Modify a string by applying diffs: removing lines and adding lines. -fn modify_string_with_diffs(original: &str, diffs: &FileDiffs) -> Result { - let mut modified = original.to_string(); +/// Modify a string representation of a file by applying diffs: removing lines and adding lines. +fn modify_base_with_diffs(base: &str, diffs: &FileDiffs) -> Result { + let base = base.trim(); + let mut modified = base.to_string(); + + // Check that the base string is not empty + ensure!(!base.is_empty(), "Base file is empty"); // Check that the headers match - let original_header = original - .lines() - .next() - .expect("Original string cannot be empty"); + let original_header = base.lines().next().unwrap().trim(); ensure!( original_header == diffs.header_line, - "Header line in diff file does not match original file" + "Header line in diff file does not match base file: expected '{}', found '{}'", + original_header, + diffs.header_line ); // Apply deletions for item in &diffs.to_delete { ensure!( modified.contains(item), - "Item to delete not found in original file: {item}" + "Row to delete not found in base file: {item}" ); modified = modified.replace(item, ""); } @@ -85,17 +92,18 @@ fn modify_string_with_diffs(original: &str, diffs: &FileDiffs) -> Result Ok(modified) } -pub fn patch_model>(model_dir: P, diffs_dir: P) -> Result { - // Copy contents of model_dir to a teporary directory +pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result { + info!( + "Patching model at '{}' with diffs from '{}'", + base_model_dir.as_ref().display(), + diffs_dir.as_ref().display() + ); + + // Copy contents of `base_model_dir` to a temporary directory let temp_dir = tempdir().context("Failed to create temporary directory")?; let temp_path = temp_dir.path(); - for entry in fs::read_dir(model_dir.as_ref()).with_context(|| { - format!( - "Failed to read model directory: {}", - model_dir.as_ref().display() - ) - })? { + for entry in fs::read_dir(base_model_dir.as_ref())? { let entry = entry?; let src_path = entry.path(); @@ -107,13 +115,13 @@ pub fn patch_model>(model_dir: P, diffs_dir: P) -> Result>(model_dir: P, diffs_dir: P) -> Result Result<()> { +fn read_base_file_with_patch(base_file_path: &Path, diff_path: &Path) -> Result { + // Read base file + if !base_file_path.exists() { + bail!( + "Base file for patching does not exist: {}", + base_file_path.display() + ); + } + let base = fs::read_to_string(base_file_path).with_context(|| input_err_msg(base_file_path))?; + + // Read diff file + if !diff_path.exists() { + bail!( + "Diff file for patching does not exist: {}", + diff_path.display() + ); + } let diffs = read_diffs(diff_path).with_context(|| input_err_msg(diff_path))?; - let original = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?; - let modified = modify_string_with_diffs(&original, &diffs)?; - fs::write(file_path, modified)?; + + // Apply diffs to base file + let modified = modify_base_with_diffs(&base, &diffs)?; + Ok(modified) +} + +fn patch_model_toml(base_path: &Path, patch_path: &Path) -> Result<()> { + // Read original TOML file + let base_str = fs::read_to_string(base_path).with_context(|| input_err_msg(base_path))?; + let mut base_data: toml::Value = + toml::from_str(&base_str).with_context(|| input_err_msg(base_path))?; + + // Read patch TOML file + let patch_str = fs::read_to_string(patch_path).with_context(|| input_err_msg(patch_path))?; + let patch_data: toml::Value = + toml::from_str(&patch_str).with_context(|| input_err_msg(patch_path))?; + + // Merge patch into base (only top-level fields allowed) + let base_table = base_data.as_table_mut().expect("Base TOML must be a table"); + let patch_table = patch_data.as_table().expect("Patch TOML must be a table"); + + for (key, patch_val) in patch_table { + // Skip `base_model` field + if key == "base_model" { + continue; + } + + // Overwrite or add the field from the patch + base_table.insert(key.clone(), patch_val.clone()); + } + + // Save modified TOML back to original file + let modified_str = toml::to_string_pretty(&base_data)?; + fs::write(base_path, modified_str)?; + Ok(()) } @@ -178,13 +258,14 @@ mod tests { let temp_dir = tempdir().unwrap(); let diff_file = temp_dir.path().join("test_diff.csv"); - let content = "header\n-, item_with_spaces \n+, another_item \n"; + let content = " header \n-, item_with_spaces \n+, another_item \n"; let mut file = fs::File::create(&diff_file).unwrap(); file.write_all(content.as_bytes()).unwrap(); let diffs = read_diffs(&diff_file).unwrap(); // Whitespace should be trimmed + assert_eq!(diffs.header_line, "header"); assert_eq!(diffs.to_delete, vec!["item_with_spaces"]); assert_eq!(diffs.to_add, vec!["another_item"]); } @@ -211,7 +292,7 @@ mod tests { to_add: vec!["line_new".to_string()], }; - let modified = modify_string_with_diffs(original, &diffs).unwrap(); + let modified = modify_base_with_diffs(original, &diffs).unwrap(); assert!(!modified.contains("line2")); assert!(modified.contains("line_new")); } @@ -225,7 +306,7 @@ mod tests { to_add: vec![], }; - let result = modify_string_with_diffs(original, &diffs); + let result = modify_base_with_diffs(original, &diffs); assert_error!( result, "Header line in diff file does not match original file" @@ -241,7 +322,7 @@ mod tests { to_add: vec![], }; - let result = modify_string_with_diffs(original, &diffs); + let result = modify_base_with_diffs(original, &diffs); assert_error!( result, "Item to delete not found in original file: nonexistent" diff --git a/src/model/parameters.rs b/src/model/parameters.rs index d0c9f2b47..862434fdd 100644 --- a/src/model/parameters.rs +++ b/src/model/parameters.rs @@ -58,6 +58,7 @@ define_param_default!(default_mothball_years, u32, 0); #[derive(Debug, Deserialize, PartialEq)] pub struct ModelParameters { /// Milestone years + #[serde(default)] pub milestone_years: Vec, /// Allow known-broken options to be enabled. #[serde(default, rename = "please_give_me_broken_results")] // Can't use constant here :-( @@ -178,6 +179,16 @@ impl ModelParameters { let file_path = model_dir.as_ref().join(MODEL_PARAMETERS_FILE_NAME); let model_params: ModelParameters = read_toml(&file_path)?; + // If `base_model` is specified, just check that it exists and skip further validation + // as we will do this later on the fully patched model. + if let Some(base_model_path) = &model_params.base_model { + ensure!( + Path::new(base_model_path).is_dir(), + "`base_model` directory not found: {base_model_path}", + ); + return Ok(model_params); + } + // Set flag signalling whether broken model options are allowed or not BROKEN_OPTIONS_ALLOWED .set(model_params.allow_broken_options) From e87fd9d3a861717216a7172b4ef4859858145fa0 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 11 Dec 2025 16:09:32 +0000 Subject: [PATCH 04/32] Use indexset to prevent duplicate rows, fix issues with newline characters --- src/input/patch.rs | 72 ++++++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 25 deletions(-) diff --git a/src/input/patch.rs b/src/input/patch.rs index 1588de250..3f98053fc 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -2,6 +2,7 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; +use indexmap::IndexSet; use log::info; use std::fs; use std::path::Path; @@ -13,9 +14,9 @@ struct FileDiffs { /// The header line from the diff file header_line: String, /// Lines to delete from the original file - to_delete: Vec, + to_delete: IndexSet, /// Lines to add to the original file - to_add: Vec, + to_add: IndexSet, } /// Read diffs from a diff file. @@ -35,22 +36,29 @@ fn read_diffs(file_path: &Path) -> Result { }; // Collect additions and deletions - let mut to_delete = Vec::new(); - let mut to_add = Vec::new(); + let mut to_delete = IndexSet::new(); + let mut to_add = IndexSet::new(); for line in content.lines().skip(1) { let line = line.trim(); if line.is_empty() { continue; } - if let Some(stripped) = line.strip_prefix("-,") { - to_delete.push(stripped.trim().to_string()); - } else if let Some(stripped) = line.strip_prefix("+,") { - to_add.push(stripped.trim().to_string()); + if let Some(body) = line.strip_prefix("-,") { + let v = body.trim().to_string(); + ensure!(to_delete.insert(v.clone()), "Duplicate deletion entry: {v}",); + } else if let Some(body) = line.strip_prefix("+,") { + let v = body.trim().to_string(); + ensure!(to_add.insert(v.clone()), "Duplicate addition entry: {v}"); } else { bail!("Invalid row in diff file: {line}. Must start with '-,' or '+,'"); } } + // Disallow overlap + if let Some(dup) = to_delete.iter().find(|d| to_add.contains(*d)) { + bail!("Line appears in both deletions and additions: {dup}"); + } + Ok(FileDiffs { header_line, to_delete, @@ -60,14 +68,11 @@ fn read_diffs(file_path: &Path) -> Result { /// Modify a string representation of a file by applying diffs: removing lines and adding lines. fn modify_base_with_diffs(base: &str, diffs: &FileDiffs) -> Result { - let base = base.trim(); - let mut modified = base.to_string(); - - // Check that the base string is not empty - ensure!(!base.is_empty(), "Base file is empty"); + ensure!(!base.trim().is_empty(), "Base file is empty"); - // Check that the headers match - let original_header = base.lines().next().unwrap().trim(); + // Split into lines while preserving order; keep header intact + let lines: Vec<&str> = base.lines().collect(); + let original_header = lines.first().unwrap().trim(); ensure!( original_header == diffs.header_line, "Header line in diff file does not match base file: expected '{}', found '{}'", @@ -75,21 +80,38 @@ fn modify_base_with_diffs(base: &str, diffs: &FileDiffs) -> Result { diffs.header_line ); - // Apply deletions - for item in &diffs.to_delete { + // Build a unique set from the body + let mut body = IndexSet::new(); + for &line in lines.iter().skip(1) { + let l = line.strip_suffix('\r').unwrap_or(line); + ensure!(body.insert(l), "Duplicate line found in base file: {l}"); + } + + // Deletions + for d in &diffs.to_delete { ensure!( - modified.contains(item), - "Row to delete not found in base file: {item}" + body.shift_remove(d.as_str()), + "Row to delete not found in base file: {d}" ); - modified = modified.replace(item, ""); } - // Apply additions - for item in &diffs.to_add { - modified.push_str(item); + // Additions + for a in &diffs.to_add { + ensure!( + body.insert(a.as_str()), + "Addition already present in base file: {a}" + ); } - Ok(modified) + // Rebuild + let mut out = String::new(); + out.push_str(original_header); + if !body.is_empty() { + out.push('\n'); + out.push_str(&body.iter().copied().collect::>().join("\n")); + } + + Ok(out) } pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result { @@ -116,7 +138,7 @@ pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result Date: Thu, 11 Dec 2025 18:31:13 +0000 Subject: [PATCH 05/32] Robustness to white space, correct reading of headers --- src/input/patch.rs | 352 ++++++++++++++++++++++++++++++--------------- 1 file changed, 238 insertions(+), 114 deletions(-) diff --git a/src/input/patch.rs b/src/input/patch.rs index 3f98053fc..94d420383 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -2,116 +2,241 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; +use csv::{ReaderBuilder, StringRecord}; use indexmap::IndexSet; use log::info; use std::fs; +use std::hash::{Hash, Hasher}; use std::path::Path; use tempfile::{TempDir, tempdir}; +/// A CSV row with normalized representation for comparison and original format for output. +/// Equality and hashing are based only on the normalized representation. +#[derive(Debug, Clone)] +struct CsvRow { + normalized: String, + original: StringRecord, +} + +impl CsvRow { + /// Create a `CsvRow` from a CSV record, normalizing it automatically + fn from_record(record: &StringRecord) -> Self { + // Normalize: trim fields and join with commas for comparison + let normalized = record.iter().map(str::trim).collect::>().join(","); + // Preserve original format + Self { + normalized, + original: record.clone(), + } + } + + /// Get the normalized representation (for comparison) + fn normalized(&self) -> &str { + &self.normalized + } + + /// Get the original format as a string (for output) + fn original_str(&self) -> String { + self.original.iter().collect::>().join(",") + } +} + +impl PartialEq for CsvRow { + fn eq(&self, other: &Self) -> bool { + self.normalized == other.normalized + } +} + +impl Eq for CsvRow {} + +impl Hash for CsvRow { + fn hash(&self, state: &mut H) { + self.normalized.hash(state); + } +} + /// Structure to hold diffs from a diff file #[derive(Debug)] struct FileDiffs { - /// The header line from the diff file - header_line: String, - /// Lines to delete from the original file - to_delete: IndexSet, - /// Lines to add to the original file - to_add: IndexSet, + /// The header columns from the base file (without the diff column) + base_headers: Vec, + /// Rows to delete (normalized) + to_delete: IndexSet, + /// Rows to add (normalized) + to_add: IndexSet, } /// Read diffs from a diff file. /// -/// Reads a diff file where the first line is a header, and subsequent lines start with "-," for -/// deletions and "+," for additions. +/// The diff file has an extra column on the left with '+' or '-' indicators. +/// The header row also has this extra column which should be ignored when comparing +/// with the base file header. fn read_diffs(file_path: &Path) -> Result { - // Read the entire file as a string let content = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?; - let content = content.trim(); - // Read header line - // This is saved to ensure that diffs are applied to a base file with the same header - let header_line = match content.lines().next() { - Some(line) => line.trim().to_string(), - None => bail!("Diff file cannot be empty"), - }; + let mut reader = ReaderBuilder::new().from_reader(content.as_bytes()); + + // Read header + let diff_header = reader.headers().with_context(|| input_err_msg(file_path))?; + + ensure!(!diff_header.is_empty(), "Diff file header cannot be empty"); + + // Extract base headers (skip first column which is the diff indicator) + // Trim headers for comparison + let base_headers: Vec = diff_header + .iter() + .skip(1) + .map(|s| s.trim().to_string()) + .collect(); + + ensure!( + !base_headers.is_empty(), + "Diff file must have at least one data column (excluding diff column)" + ); // Collect additions and deletions let mut to_delete = IndexSet::new(); let mut to_add = IndexSet::new(); - for line in content.lines().skip(1) { - let line = line.trim(); - if line.is_empty() { - continue; - } - if let Some(body) = line.strip_prefix("-,") { - let v = body.trim().to_string(); - ensure!(to_delete.insert(v.clone()), "Duplicate deletion entry: {v}",); - } else if let Some(body) = line.strip_prefix("+,") { - let v = body.trim().to_string(); - ensure!(to_add.insert(v.clone()), "Duplicate addition entry: {v}"); - } else { - bail!("Invalid row in diff file: {line}. Must start with '-,' or '+,'"); + + for (line_num, result) in reader.records().enumerate() { + let record = result.with_context(|| { + format!("Error reading record at line {} in diff file", line_num + 2) + })?; + + ensure!( + !record.is_empty(), + "Empty row at line {} in diff file", + line_num + 2 + ); + + // First column is the diff indicator + let diff_indicator = record + .get(0) + .context("Missing diff indicator column")? + .trim(); + + // Extract the base row (skip first column) + let base_row_record: StringRecord = record.iter().skip(1).collect(); + + // Create CsvRow (normalizes internally) + let csv_row = CsvRow::from_record(&base_row_record); + + match diff_indicator { + "-" => { + ensure!( + to_delete.insert(csv_row.clone()), + "Duplicate deletion entry at line {}: {}", + line_num + 2, + csv_row.normalized() + ); + } + "+" => { + ensure!( + to_add.insert(csv_row.clone()), + "Duplicate addition entry at line {}: {}", + line_num + 2, + csv_row.normalized() + ); + } + _ => { + bail!( + "Invalid diff indicator at line {}: '{}'. Must be '+' or '-'", + line_num + 2, + diff_indicator + ); + } } } - // Disallow overlap - if let Some(dup) = to_delete.iter().find(|d| to_add.contains(*d)) { - bail!("Line appears in both deletions and additions: {dup}"); + // Disallow overlap between deletions and additions + for del_row in &to_delete { + ensure!( + !to_add.contains(del_row), + "Row appears in both deletions and additions: {}", + del_row.normalized() + ); } Ok(FileDiffs { - header_line, + base_headers, to_delete, to_add, }) } -/// Modify a string representation of a file by applying diffs: removing lines and adding lines. +/// Modify a base CSV file by applying diffs: removing rows and adding rows. +/// Preserves the order of rows from the base file, with new rows appended at the end. fn modify_base_with_diffs(base: &str, diffs: &FileDiffs) -> Result { ensure!(!base.trim().is_empty(), "Base file is empty"); - // Split into lines while preserving order; keep header intact - let lines: Vec<&str> = base.lines().collect(); - let original_header = lines.first().unwrap().trim(); + let mut reader = ReaderBuilder::new().from_reader(base.as_bytes()); + + // Read and validate header + let base_header = reader + .headers() + .context("Failed to read base file header")?; + + // Trim headers for comparison + let base_header_vec: Vec = base_header.iter().map(|s| s.trim().to_string()).collect(); + ensure!( - original_header == diffs.header_line, - "Header line in diff file does not match base file: expected '{}', found '{}'", - original_header, - diffs.header_line + base_header_vec == diffs.base_headers, + "Header mismatch: base file has [{}], diff file expects [{}]", + base_header_vec.join(", "), + diffs.base_headers.join(", ") ); - // Build a unique set from the body - let mut body = IndexSet::new(); - for &line in lines.iter().skip(1) { - let l = line.strip_suffix('\r').unwrap_or(line); - ensure!(body.insert(l), "Duplicate line found in base file: {l}"); - } + // Read all rows from base file, preserving order and checking for duplicates + let mut base_rows = IndexSet::new(); + + for (line_num, result) in reader.records().enumerate() { + let record = result.with_context(|| { + format!("Error reading record at line {} in base file", line_num + 2) + })?; - // Deletions - for d in &diffs.to_delete { + // CsvRow handles normalization internally + let row = CsvRow::from_record(&record); + + // Check for duplicates using IndexSet ensure!( - body.shift_remove(d.as_str()), - "Row to delete not found in base file: {d}" + base_rows.insert(row.clone()), + "Duplicate row at line {} in base file: {}", + line_num + 2, + row.original_str() ); } - // Additions - for a in &diffs.to_add { + // Apply deletions (IndexSet preserves order during iteration) + base_rows.retain(|row| !diffs.to_delete.contains(row)); + + // Apply additions (append to end, checking for duplicates) + for add_row in &diffs.to_add { ensure!( - body.insert(a.as_str()), - "Addition already present in base file: {a}" + base_rows.insert(add_row.clone()), + "Addition already present in base file: {}", + add_row.normalized() ); } - // Rebuild - let mut out = String::new(); - out.push_str(original_header); - if !body.is_empty() { - out.push('\n'); - out.push_str(&body.iter().copied().collect::>().join("\n")); + // Rebuild CSV output + let mut output = String::new(); + + // Write header + output.push_str(&base_header_vec.join(",")); + + // Write rows (IndexSet preserves insertion order) + if !base_rows.is_empty() { + output.push('\n'); + output.push_str( + &base_rows + .iter() + .map(CsvRow::original_str) + .collect::>() + .join("\n"), + ); } - Ok(out) + Ok(output) } pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result { @@ -152,7 +277,7 @@ pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result row5,row6 -> row7,row8 + let lines: Vec<&str> = modified.lines().collect(); + assert_eq!(lines[0], "col1,col2"); + assert_eq!(lines[1], "row1,row2"); + assert_eq!(lines[2], "row5,row6"); + assert_eq!(lines[3], "row7,row8"); + assert!(!modified.contains("row3,row4")); } #[test] - fn test_modify_string_with_diffs_missing_item() { - let original = "header\nline1\n"; + fn test_modify_base_with_diffs_mismatched_header() { + let base = "col1,col2\nrow1,row2\n"; let diffs = FileDiffs { - header_line: "header".to_string(), - to_delete: vec!["nonexistent".to_string()], - to_add: vec![], + base_headers: vec!["col1".to_string(), "col3".to_string()], + to_delete: IndexSet::new(), + to_add: IndexSet::new(), }; - let result = modify_base_with_diffs(original, &diffs); + let result = modify_base_with_diffs(base, &diffs); assert_error!( result, - "Item to delete not found in original file: nonexistent" + "Header mismatch: base file has [col1, col2], diff file expects [col1, col3]" ); } } From a6109adfe434c35027a3af2a2bf967959d22b5c9 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 11 Dec 2025 18:53:40 +0000 Subject: [PATCH 06/32] Simplify the code - don't need to preserve whitespace --- src/input/patch.rs | 181 ++++++++++++++------------------------------- 1 file changed, 56 insertions(+), 125 deletions(-) diff --git a/src/input/patch.rs b/src/input/patch.rs index 94d420383..abea74c85 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -2,102 +2,54 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; -use csv::{ReaderBuilder, StringRecord}; +use csv::{ReaderBuilder, Trim}; use indexmap::IndexSet; use log::info; use std::fs; -use std::hash::{Hash, Hasher}; use std::path::Path; use tempfile::{TempDir, tempdir}; -/// A CSV row with normalized representation for comparison and original format for output. -/// Equality and hashing are based only on the normalized representation. -#[derive(Debug, Clone)] -struct CsvRow { - normalized: String, - original: StringRecord, -} - -impl CsvRow { - /// Create a `CsvRow` from a CSV record, normalizing it automatically - fn from_record(record: &StringRecord) -> Self { - // Normalize: trim fields and join with commas for comparison - let normalized = record.iter().map(str::trim).collect::>().join(","); - // Preserve original format - Self { - normalized, - original: record.clone(), - } - } - - /// Get the normalized representation (for comparison) - fn normalized(&self) -> &str { - &self.normalized - } - - /// Get the original format as a string (for output) - fn original_str(&self) -> String { - self.original.iter().collect::>().join(",") - } -} - -impl PartialEq for CsvRow { - fn eq(&self, other: &Self) -> bool { - self.normalized == other.normalized - } -} - -impl Eq for CsvRow {} - -impl Hash for CsvRow { - fn hash(&self, state: &mut H) { - self.normalized.hash(state); - } -} - /// Structure to hold diffs from a diff file #[derive(Debug)] struct FileDiffs { - /// The header columns from the base file (without the diff column) - base_headers: Vec, - /// Rows to delete (normalized) - to_delete: IndexSet, - /// Rows to add (normalized) - to_add: IndexSet, + /// The column headers from the diff file + headers: Vec, + /// Rows to delete (normalized to remove whitespace) + to_delete: IndexSet, + /// Rows to add (normalized to remove whitespace) + to_add: IndexSet, } /// Read diffs from a diff file. /// -/// The diff file has an extra column on the left with '+' or '-' indicators. -/// The header row also has this extra column which should be ignored when comparing -/// with the base file header. +/// Diff files follow the same format as base files, but with an extra column on the left with +/// '+' or '-' indicators to indicate whether that row should be added to or deleted from the base +/// file. fn read_diffs(file_path: &Path) -> Result { - let content = fs::read_to_string(file_path).with_context(|| input_err_msg(file_path))?; - - let mut reader = ReaderBuilder::new().from_reader(content.as_bytes()); + // Read the diff file from the given path, trimming any whitespace + let mut reader = ReaderBuilder::new() + .trim(Trim::All) + .from_path(file_path) + .with_context(|| input_err_msg(file_path))?; // Read header let diff_header = reader.headers().with_context(|| input_err_msg(file_path))?; - ensure!(!diff_header.is_empty(), "Diff file header cannot be empty"); - // Extract base headers (skip first column which is the diff indicator) - // Trim headers for comparison - let base_headers: Vec = diff_header + // Colect column headers (skip first column which is the diff indicator) + let headers: Vec = diff_header .iter() .skip(1) - .map(|s| s.trim().to_string()) + .map(ToString::to_string) .collect(); - ensure!( - !base_headers.is_empty(), - "Diff file must have at least one data column (excluding diff column)" + !headers.is_empty(), + "Diff file must have at least one data column" ); // Collect additions and deletions let mut to_delete = IndexSet::new(); let mut to_add = IndexSet::new(); - for (line_num, result) in reader.records().enumerate() { let record = result.with_context(|| { format!("Error reading record at line {} in diff file", line_num + 2) @@ -110,32 +62,26 @@ fn read_diffs(file_path: &Path) -> Result { ); // First column is the diff indicator - let diff_indicator = record - .get(0) - .context("Missing diff indicator column")? - .trim(); - - // Extract the base row (skip first column) - let base_row_record: StringRecord = record.iter().skip(1).collect(); + let diff_indicator = record.get(0).context("Missing diff indicator column")?; - // Create CsvRow (normalizes internally) - let csv_row = CsvRow::from_record(&base_row_record); + // Build normalized row string by joining trimmed fields with commas + let row_str = record.iter().skip(1).collect::>().join(","); match diff_indicator { "-" => { ensure!( - to_delete.insert(csv_row.clone()), + to_delete.insert(row_str.clone()), "Duplicate deletion entry at line {}: {}", line_num + 2, - csv_row.normalized() + row_str ); } "+" => { ensure!( - to_add.insert(csv_row.clone()), + to_add.insert(row_str.clone()), "Duplicate addition entry at line {}: {}", line_num + 2, - csv_row.normalized() + row_str ); } _ => { @@ -152,13 +98,12 @@ fn read_diffs(file_path: &Path) -> Result { for del_row in &to_delete { ensure!( !to_add.contains(del_row), - "Row appears in both deletions and additions: {}", - del_row.normalized() + "Row appears in both deletions and additions: {del_row}" ); } Ok(FileDiffs { - base_headers, + headers, to_delete, to_add, }) @@ -167,54 +112,52 @@ fn read_diffs(file_path: &Path) -> Result { /// Modify a base CSV file by applying diffs: removing rows and adding rows. /// Preserves the order of rows from the base file, with new rows appended at the end. fn modify_base_with_diffs(base: &str, diffs: &FileDiffs) -> Result { - ensure!(!base.trim().is_empty(), "Base file is empty"); - - let mut reader = ReaderBuilder::new().from_reader(base.as_bytes()); + // Read base file from string, trimming whitespace + let mut reader = ReaderBuilder::new() + .trim(Trim::All) + .from_reader(base.as_bytes()); // Read and validate header let base_header = reader .headers() .context("Failed to read base file header")?; - // Trim headers for comparison - let base_header_vec: Vec = base_header.iter().map(|s| s.trim().to_string()).collect(); + let base_header_vec: Vec = base_header.iter().map(ToString::to_string).collect(); ensure!( - base_header_vec == diffs.base_headers, + base_header_vec == diffs.headers, "Header mismatch: base file has [{}], diff file expects [{}]", base_header_vec.join(", "), - diffs.base_headers.join(", ") + diffs.headers.join(", ") ); // Read all rows from base file, preserving order and checking for duplicates let mut base_rows = IndexSet::new(); - for (line_num, result) in reader.records().enumerate() { let record = result.with_context(|| { format!("Error reading record at line {} in base file", line_num + 2) })?; - // CsvRow handles normalization internally - let row = CsvRow::from_record(&record); + // Create normalized row string by joining trimmed fields with commas + let row_str = record.iter().collect::>().join(","); - // Check for duplicates using IndexSet + // Check for duplicates ensure!( - base_rows.insert(row.clone()), + base_rows.insert(row_str.clone()), "Duplicate row at line {} in base file: {}", line_num + 2, - row.original_str() + row_str ); } - // Apply deletions (IndexSet preserves order during iteration) + // Apply deletions base_rows.retain(|row| !diffs.to_delete.contains(row)); // Apply additions (append to end, checking for duplicates) for add_row in &diffs.to_add { ensure!( base_rows.insert(add_row.clone()), - "Addition already present in base file: {}", - add_row.normalized() + "Addition already present in base file: {add_row}" ); } @@ -224,16 +167,10 @@ fn modify_base_with_diffs(base: &str, diffs: &FileDiffs) -> Result { // Write header output.push_str(&base_header_vec.join(",")); - // Write rows (IndexSet preserves insertion order) + // Write rows if !base_rows.is_empty() { output.push('\n'); - output.push_str( - &base_rows - .iter() - .map(CsvRow::original_str) - .collect::>() - .join("\n"), - ); + output.push_str(&base_rows.iter().cloned().collect::>().join("\n")); } Ok(output) @@ -262,7 +199,7 @@ pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result Date: Fri, 12 Dec 2025 11:53:27 +0000 Subject: [PATCH 07/32] Store base_filename within Patch --- src/input/patch.rs | 307 ++++++++++++++++++++++++--------------------- 1 file changed, 165 insertions(+), 142 deletions(-) diff --git a/src/input/patch.rs b/src/input/patch.rs index abea74c85..73223e049 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -2,7 +2,7 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; -use csv::{ReaderBuilder, Trim}; +use csv::{Reader, ReaderBuilder, Trim}; use indexmap::IndexSet; use log::info; use std::fs; @@ -11,7 +11,9 @@ use tempfile::{TempDir, tempdir}; /// Structure to hold diffs from a diff file #[derive(Debug)] -struct FileDiffs { +struct Patch { + /// The target base filename that this patch applies to (e.g. "agents.csv") + base_filename: String, /// The column headers from the diff file headers: Vec, /// Rows to delete (normalized to remove whitespace) @@ -20,98 +22,151 @@ struct FileDiffs { to_add: IndexSet, } -/// Read diffs from a diff file. -/// -/// Diff files follow the same format as base files, but with an extra column on the left with -/// '+' or '-' indicators to indicate whether that row should be added to or deleted from the base -/// file. -fn read_diffs(file_path: &Path) -> Result { - // Read the diff file from the given path, trimming any whitespace - let mut reader = ReaderBuilder::new() - .trim(Trim::All) - .from_path(file_path) - .with_context(|| input_err_msg(file_path))?; - - // Read header - let diff_header = reader.headers().with_context(|| input_err_msg(file_path))?; - ensure!(!diff_header.is_empty(), "Diff file header cannot be empty"); - - // Colect column headers (skip first column which is the diff indicator) - let headers: Vec = diff_header - .iter() - .skip(1) - .map(ToString::to_string) - .collect(); - ensure!( - !headers.is_empty(), - "Diff file must have at least one data column" - ); +impl Patch { + /// Read a diff file and construct a `Patch`. + pub fn from_file(file_path: &Path) -> Result { + let file_name = file_path + .file_name() + .and_then(|n| n.to_str()) + .context("Invalid filename encoding")?; + ensure!( + file_name.to_lowercase().ends_with("_diff.csv"), + "Diff file must end with '_diff.csv': {file_name}" + ); + let base_name = &file_name[..file_name.len() - "_diff.csv".len()]; + let base_filename = format!("{base_name}.csv"); - // Collect additions and deletions - let mut to_delete = IndexSet::new(); - let mut to_add = IndexSet::new(); - for (line_num, result) in reader.records().enumerate() { - let record = result.with_context(|| { - format!("Error reading record at line {} in diff file", line_num + 2) - })?; + let reader = ReaderBuilder::new() + .trim(Trim::All) + .from_path(file_path) + .with_context(|| input_err_msg(file_path))?; + + Self::from_reader(reader, base_filename) + } + + /// Read a diff from an in-memory string and construct a `Patch`. + pub fn _from_str(base_filename: &str, file_contents: &str) -> Result { + let reader = ReaderBuilder::new() + .trim(Trim::All) + .from_reader(file_contents.as_bytes()); + + Self::from_reader(reader, base_filename.to_string()) + } + /// Shared helper that parses a CSV `Reader` and constructs a `Patch`. + fn from_reader( + mut reader: Reader, + base_filename: String, + ) -> Result { + // Read header + let diff_header = reader + .headers() + .with_context(|| input_err_msg(Path::new(&base_filename)))?; + ensure!(!diff_header.is_empty(), "Diff file header cannot be empty"); + + // Colect column headers (skip first column which is the diff indicator) + let headers: Vec = diff_header + .iter() + .skip(1) + .map(ToString::to_string) + .collect(); ensure!( - !record.is_empty(), - "Empty row at line {} in diff file", - line_num + 2 + !headers.is_empty(), + "Diff file must have at least one data column" ); - // First column is the diff indicator - let diff_indicator = record.get(0).context("Missing diff indicator column")?; + // Collect additions and deletions + let mut to_delete = IndexSet::new(); + let mut to_add = IndexSet::new(); + for (line_num, result) in reader.records().enumerate() { + let record = result.with_context(|| { + format!("Error reading record at line {} in diff file", line_num + 2) + })?; - // Build normalized row string by joining trimmed fields with commas - let row_str = record.iter().skip(1).collect::>().join(","); + ensure!( + !record.is_empty(), + "Empty row at line {} in diff file", + line_num + 2 + ); - match diff_indicator { - "-" => { - ensure!( - to_delete.insert(row_str.clone()), - "Duplicate deletion entry at line {}: {}", - line_num + 2, - row_str - ); - } - "+" => { - ensure!( - to_add.insert(row_str.clone()), - "Duplicate addition entry at line {}: {}", - line_num + 2, - row_str - ); - } - _ => { - bail!( - "Invalid diff indicator at line {}: '{}'. Must be '+' or '-'", - line_num + 2, - diff_indicator - ); + // First column is the diff indicator + let diff_indicator = record.get(0).context("Missing diff indicator column")?; + + // Build normalized row string by joining trimmed fields with commas + let row_str = record.iter().skip(1).collect::>().join(","); + + match diff_indicator { + "-" => { + ensure!( + to_delete.insert(row_str.clone()), + "Duplicate deletion entry at line {}: {}", + line_num + 2, + row_str + ); + } + "+" => { + ensure!( + to_add.insert(row_str.clone()), + "Duplicate addition entry at line {}: {}", + line_num + 2, + row_str + ); + } + _ => { + bail!( + "Invalid diff indicator at line {}: '{}'. Must be '+' or '-'", + line_num + 2, + diff_indicator + ); + } } } + + // Disallow overlap between deletions and additions + for del_row in &to_delete { + ensure!( + !to_add.contains(del_row), + "Row appears in both deletions and additions: {del_row}" + ); + } + + Ok(Patch { + base_filename, + headers, + to_delete, + to_add, + }) } - // Disallow overlap between deletions and additions - for del_row in &to_delete { - ensure!( - !to_add.contains(del_row), - "Row appears in both deletions and additions: {del_row}" - ); + /// Apply this patch to a base model and return the modified CSV as a string. + fn apply(&self, base_model_dir: &Path) -> Result { + let base_path = base_model_dir.join(&self.base_filename); + + if !base_path.exists() { + bail!( + "Base file for patching does not exist: {}", + base_path.display() + ); + } + + let base = fs::read_to_string(&base_path).with_context(|| input_err_msg(&base_path))?; + let modified = modify_base_with_diffs(&base, self)?; + Ok(modified) } - Ok(FileDiffs { - headers, - to_delete, - to_add, - }) + /// Apply this patch to a base model and save the modified CSV to another directory. + pub fn apply_and_save(&self, base_model_dir: &Path, new_model_dir: &Path) -> Result<()> { + let modified = self.apply(base_model_dir)?; + let new_path = new_model_dir.join(&self.base_filename); + fs::write(&new_path, modified) + .with_context(|| format!("Failed to write patched file: {}", new_path.display()))?; + Ok(()) + } } /// Modify a base CSV file by applying diffs: removing rows and adding rows. /// Preserves the order of rows from the base file, with new rows appended at the end. -fn modify_base_with_diffs(base: &str, diffs: &FileDiffs) -> Result { +fn modify_base_with_diffs(base: &str, diffs: &Patch) -> Result { // Read base file from string, trimming whitespace let mut reader = ReaderBuilder::new() .trim(Trim::All) @@ -204,50 +259,40 @@ pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result>(base_model_dir: P, diffs_dir: P) -> Result Result { - // Read base file - if !base_file_path.exists() { - bail!( - "Base file for patching does not exist: {}", - base_file_path.display() - ); - } - let base = fs::read_to_string(base_file_path).with_context(|| input_err_msg(base_file_path))?; - - // Read diff file - if !diff_path.exists() { - bail!( - "Diff file for patching does not exist: {}", - diff_path.display() - ); - } - let diffs = read_diffs(diff_path).with_context(|| input_err_msg(diff_path))?; - - // Apply diffs to base file - let modified = modify_base_with_diffs(&base, &diffs)?; - Ok(modified) -} - fn patch_model_toml(base_path: &Path, patch_path: &Path) -> Result<()> { // Read original TOML file let base_str = fs::read_to_string(base_path).with_context(|| input_err_msg(base_path))?; @@ -330,7 +351,7 @@ mod tests { let mut file = fs::File::create(&diff_file).unwrap(); file.write_all(content.as_bytes()).unwrap(); - let diffs = read_diffs(&diff_file).unwrap(); + let diffs = Patch::from_file(&diff_file).unwrap(); assert_eq!(diffs.headers, vec!["col1", "col2"]); assert_eq!(diffs.to_delete.len(), 1); @@ -351,7 +372,7 @@ mod tests { let mut file = fs::File::create(&diff_file).unwrap(); file.write_all(content.as_bytes()).unwrap(); - let diffs = read_diffs(&diff_file).unwrap(); + let diffs = Patch::from_file(&diff_file).unwrap(); // Headers should be trimmed assert_eq!(diffs.headers, vec!["col1", "col2"]); @@ -373,8 +394,9 @@ mod tests { let mut to_add = IndexSet::new(); to_add.insert("row7,row8".to_string()); - let diffs = FileDiffs { + let diffs = Patch { headers: vec!["col1".to_string(), "col2".to_string()], + base_filename: "test.csv".to_string(), to_delete, to_add, }; @@ -393,8 +415,9 @@ mod tests { #[test] fn test_modify_base_with_diffs_mismatched_header() { let base = "col1,col2\nrow1,row2\n"; - let diffs = FileDiffs { + let diffs = Patch { headers: vec!["col1".to_string(), "col3".to_string()], + base_filename: "test.csv".to_string(), to_delete: IndexSet::new(), to_add: IndexSet::new(), }; From 3cffe4bd3287e0a872f2ad73a4511fa5c8401602 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 12 Dec 2025 14:27:42 +0000 Subject: [PATCH 08/32] Builder method for making patches in tests --- src/input.rs | 2 +- src/input/patch.rs | 159 +++++++++++++++++++++++++-------------------- 2 files changed, 90 insertions(+), 71 deletions(-) diff --git a/src/input.rs b/src/input.rs index 2a41745cc..2cf1d30c7 100644 --- a/src/input.rs +++ b/src/input.rs @@ -25,7 +25,7 @@ use asset::read_assets; mod commodity; use commodity::read_commodities; mod patch; -use patch::patch_model; +pub use patch::{Patch, patch_model}; mod process; use process::read_processes; mod region; diff --git a/src/input/patch.rs b/src/input/patch.rs index 73223e049..84f290e02 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -2,7 +2,7 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; -use csv::{Reader, ReaderBuilder, Trim}; +use csv::{ReaderBuilder, Trim}; use indexmap::IndexSet; use log::info; use std::fs; @@ -11,11 +11,11 @@ use tempfile::{TempDir, tempdir}; /// Structure to hold diffs from a diff file #[derive(Debug)] -struct Patch { +pub struct Patch { /// The target base filename that this patch applies to (e.g. "agents.csv") base_filename: String, - /// The column headers from the diff file - headers: Vec, + /// The header row + header_row: String, /// Rows to delete (normalized to remove whitespace) to_delete: IndexSet, /// Rows to add (normalized to remove whitespace) @@ -23,8 +23,43 @@ struct Patch { } impl Patch { + /// Create a new empty `Patch` with the given `base_filename` and + /// a comma-joined header string (e.g. "a,b,c"). + pub fn new(base_filename: B, header_row: H) -> Self + where + B: Into, + H: Into, + { + let base_filename = base_filename.into(); + let header_row = header_row.into().trim().to_string(); + + Patch { + base_filename, + header_row, + to_delete: IndexSet::new(), + to_add: IndexSet::new(), + } + } + + /// Add a row to the patch (row is a canonical comma-joined string, e.g. "a,b,c"). + /// Returns `self` for chaining. + pub fn add(&mut self, row: impl Into) -> &mut Self { + let s = row.into().trim().to_string(); + self.to_add.insert(s); + self + } + + /// Mark a row for deletion from the base (row is a canonical comma-joined string). + /// Returns `self` for chaining. + pub fn delete(&mut self, row: impl Into) -> &mut Self { + let s = row.into().trim().to_string(); + self.to_delete.insert(s); + self + } + /// Read a diff file and construct a `Patch`. pub fn from_file(file_path: &Path) -> Result { + // Extract the base filename by removing the `_diff.csv` suffix let file_name = file_path .file_name() .and_then(|n| n.to_str()) @@ -36,44 +71,29 @@ impl Patch { let base_name = &file_name[..file_name.len() - "_diff.csv".len()]; let base_filename = format!("{base_name}.csv"); - let reader = ReaderBuilder::new() + // Read diff CSV file + let mut reader = ReaderBuilder::new() .trim(Trim::All) .from_path(file_path) .with_context(|| input_err_msg(file_path))?; - Self::from_reader(reader, base_filename) - } - - /// Read a diff from an in-memory string and construct a `Patch`. - pub fn _from_str(base_filename: &str, file_contents: &str) -> Result { - let reader = ReaderBuilder::new() - .trim(Trim::All) - .from_reader(file_contents.as_bytes()); - - Self::from_reader(reader, base_filename.to_string()) - } - - /// Shared helper that parses a CSV `Reader` and constructs a `Patch`. - fn from_reader( - mut reader: Reader, - base_filename: String, - ) -> Result { // Read header let diff_header = reader .headers() .with_context(|| input_err_msg(Path::new(&base_filename)))?; ensure!(!diff_header.is_empty(), "Diff file header cannot be empty"); - // Colect column headers (skip first column which is the diff indicator) - let headers: Vec = diff_header + // Collect column headers (skip first column which is the diff indicator) + let headers_vec: Vec = diff_header .iter() .skip(1) .map(ToString::to_string) .collect(); ensure!( - !headers.is_empty(), + !headers_vec.is_empty(), "Diff file must have at least one data column" ); + let header_row = headers_vec.join(","); // Collect additions and deletions let mut to_delete = IndexSet::new(); @@ -99,40 +119,25 @@ impl Patch { "-" => { ensure!( to_delete.insert(row_str.clone()), - "Duplicate deletion entry at line {}: {}", - line_num + 2, - row_str + "Duplicate deletion entry: {row_str}", ); } "+" => { ensure!( to_add.insert(row_str.clone()), - "Duplicate addition entry at line {}: {}", - line_num + 2, - row_str + "Duplicate addition entry: {row_str}", ); } _ => { - bail!( - "Invalid diff indicator at line {}: '{}'. Must be '+' or '-'", - line_num + 2, - diff_indicator - ); + bail!("Invalid diff indicator: '{diff_indicator}'. Must be '+' or '-'"); } } } - // Disallow overlap between deletions and additions - for del_row in &to_delete { - ensure!( - !to_add.contains(del_row), - "Row appears in both deletions and additions: {del_row}" - ); - } - + // Create Patch object Ok(Patch { base_filename, - headers, + header_row, to_delete, to_add, }) @@ -140,17 +145,18 @@ impl Patch { /// Apply this patch to a base model and return the modified CSV as a string. fn apply(&self, base_model_dir: &Path) -> Result { + // Read the base file to string let base_path = base_model_dir.join(&self.base_filename); - if !base_path.exists() { bail!( "Base file for patching does not exist: {}", base_path.display() ); } - let base = fs::read_to_string(&base_path).with_context(|| input_err_msg(&base_path))?; - let modified = modify_base_with_diffs(&base, self)?; + + // Apply the patch + let modified = modify_base_with_patch(&base, self)?; Ok(modified) } @@ -166,7 +172,7 @@ impl Patch { /// Modify a base CSV file by applying diffs: removing rows and adding rows. /// Preserves the order of rows from the base file, with new rows appended at the end. -fn modify_base_with_diffs(base: &str, diffs: &Patch) -> Result { +fn modify_base_with_patch(base: &str, diffs: &Patch) -> Result { // Read base file from string, trimming whitespace let mut reader = ReaderBuilder::new() .trim(Trim::All) @@ -179,11 +185,17 @@ fn modify_base_with_diffs(base: &str, diffs: &Patch) -> Result { let base_header_vec: Vec = base_header.iter().map(ToString::to_string).collect(); + // Compare base header vector with the comma-joined header string stored in the patch + let diffs_header_vec: Vec = diffs + .header_row + .split(',') + .map(|s| s.trim().to_string()) + .collect(); ensure!( - base_header_vec == diffs.headers, + base_header_vec == diffs_header_vec, "Header mismatch: base file has [{}], diff file expects [{}]", base_header_vec.join(", "), - diffs.headers.join(", ") + diffs_header_vec.join(", ") ); // Read all rows from base file, preserving order and checking for duplicates @@ -205,6 +217,14 @@ fn modify_base_with_diffs(base: &str, diffs: &Patch) -> Result { ); } + // Check that there's no overlap between additions and deletions + for del_row in &diffs.to_delete { + ensure!( + !diffs.to_add.contains(del_row), + "Row appears in both deletions and additions: {del_row}" + ); + } + // Apply deletions base_rows.retain(|row| !diffs.to_delete.contains(row)); @@ -231,6 +251,8 @@ fn modify_base_with_diffs(base: &str, diffs: &Patch) -> Result { Ok(output) } +/// Patch a base model directory with diffs from another directory. +/// Returns a `TempDir` containing the patched model. pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result { info!( "Patching model at '{}' with diffs from '{}'", @@ -344,19 +366,18 @@ mod tests { #[test] fn test_read_diffs_basic() { + // Create diff file let temp_dir = tempdir().unwrap(); let diff_file = temp_dir.path().join("test_diff.csv"); - let content = "diff,col1,col2\n-,val1,val2\n+,val3,val4\n"; let mut file = fs::File::create(&diff_file).unwrap(); file.write_all(content.as_bytes()).unwrap(); + // Parse from the file let diffs = Patch::from_file(&diff_file).unwrap(); - - assert_eq!(diffs.headers, vec!["col1", "col2"]); + assert_eq!(diffs.header_row, "col1,col2"); assert_eq!(diffs.to_delete.len(), 1); assert_eq!(diffs.to_add.len(), 1); - let del_row = "val1,val2".to_string(); let add_row = "val3,val4".to_string(); assert!(diffs.to_delete.contains(&del_row)); @@ -365,24 +386,22 @@ mod tests { #[test] fn test_read_diffs_with_whitespace() { + // Create diff file with extra whitespace let temp_dir = tempdir().unwrap(); let diff_file = temp_dir.path().join("test_diff.csv"); - let content = " diff , col1 , col2 \n-, item1 , item2 \n+, another1 , another2 \n"; let mut file = fs::File::create(&diff_file).unwrap(); file.write_all(content.as_bytes()).unwrap(); - let diffs = Patch::from_file(&diff_file).unwrap(); - - // Headers should be trimmed - assert_eq!(diffs.headers, vec!["col1", "col2"]); - // Rows should be normalized (whitespace trimmed) - assert_eq!(diffs.to_delete.len(), 1); - assert_eq!(diffs.to_add.len(), 1); - - // Check that whitespace is normalized + // Parse from the file + let diffs_from_file = Patch::from_file(&diff_file).unwrap(); + assert_eq!(diffs_from_file.header_row, "col1,col2"); + assert_eq!(diffs_from_file.to_delete.len(), 1); + assert_eq!(diffs_from_file.to_add.len(), 1); let del_row = "item1,item2".to_string(); - assert!(diffs.to_delete.contains(&del_row)); + let add_row = "another1,another2".to_string(); + assert!(diffs_from_file.to_delete.contains(&del_row)); + assert!(diffs_from_file.to_add.contains(&add_row)); } #[test] @@ -395,13 +414,13 @@ mod tests { to_add.insert("row7,row8".to_string()); let diffs = Patch { - headers: vec!["col1".to_string(), "col2".to_string()], + header_row: "col1,col2".to_string(), base_filename: "test.csv".to_string(), to_delete, to_add, }; - let modified = modify_base_with_diffs(base, &diffs).unwrap(); + let modified = modify_base_with_patch(base, &diffs).unwrap(); // Should preserve order: row1,row2 -> row5,row6 -> row7,row8 let lines: Vec<&str> = modified.lines().collect(); @@ -416,13 +435,13 @@ mod tests { fn test_modify_base_with_diffs_mismatched_header() { let base = "col1,col2\nrow1,row2\n"; let diffs = Patch { - headers: vec!["col1".to_string(), "col3".to_string()], + header_row: "col1,col3".to_string(), base_filename: "test.csv".to_string(), to_delete: IndexSet::new(), to_add: IndexSet::new(), }; - let result = modify_base_with_diffs(base, &diffs); + let result = modify_base_with_patch(base, &diffs); assert_error!( result, "Header mismatch: base file has [col1, col2], diff file expects [col1, col3]" From 859c288ed08fd8072a4383a740d20d3d022dee35 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 12 Dec 2025 14:43:28 +0000 Subject: [PATCH 09/32] Make the header optional --- src/input/patch.rs | 70 ++++++++++++++++++++++------------------------ 1 file changed, 33 insertions(+), 37 deletions(-) diff --git a/src/input/patch.rs b/src/input/patch.rs index 84f290e02..55bfc013f 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -14,46 +14,41 @@ use tempfile::{TempDir, tempdir}; pub struct Patch { /// The target base filename that this patch applies to (e.g. "agents.csv") base_filename: String, - /// The header row - header_row: String, - /// Rows to delete (normalized to remove whitespace) + /// The header row (optional). If `None`, the header is not checked against base files. + header_row: Option, + /// Rows to delete to_delete: IndexSet, - /// Rows to add (normalized to remove whitespace) + /// Rows to add to_add: IndexSet, } impl Patch { - /// Create a new empty `Patch` with the given `base_filename` and - /// a comma-joined header string (e.g. "a,b,c"). - pub fn new(base_filename: B, header_row: H) -> Self - where - B: Into, - H: Into, - { + /// Create a new empty `Patch` with the given `base_filename`. + pub fn new(base_filename: impl Into) -> Self { let base_filename = base_filename.into(); - let header_row = header_row.into().trim().to_string(); - Patch { base_filename, - header_row, + header_row: None, to_delete: IndexSet::new(), to_add: IndexSet::new(), } } + /// Set the header row for this patch (`header` should be a comma-joined string, e.g. "a,b,c"). + pub fn with_header(mut self, header: impl Into) -> Self { + self.header_row = Some(header.into()); + self + } + /// Add a row to the patch (row is a canonical comma-joined string, e.g. "a,b,c"). - /// Returns `self` for chaining. pub fn add(&mut self, row: impl Into) -> &mut Self { - let s = row.into().trim().to_string(); - self.to_add.insert(s); + self.to_add.insert(row.into()); self } /// Mark a row for deletion from the base (row is a canonical comma-joined string). - /// Returns `self` for chaining. pub fn delete(&mut self, row: impl Into) -> &mut Self { - let s = row.into().trim().to_string(); - self.to_delete.insert(s); + self.to_delete.insert(row.into()); self } @@ -137,7 +132,7 @@ impl Patch { // Create Patch object Ok(Patch { base_filename, - header_row, + header_row: Some(header_row), to_delete, to_add, }) @@ -185,18 +180,19 @@ fn modify_base_with_patch(base: &str, diffs: &Patch) -> Result { let base_header_vec: Vec = base_header.iter().map(ToString::to_string).collect(); - // Compare base header vector with the comma-joined header string stored in the patch - let diffs_header_vec: Vec = diffs - .header_row - .split(',') - .map(|s| s.trim().to_string()) - .collect(); - ensure!( - base_header_vec == diffs_header_vec, - "Header mismatch: base file has [{}], diff file expects [{}]", - base_header_vec.join(", "), - diffs_header_vec.join(", ") - ); + // If the patch contains a header, compare it with the base file header. + if let Some(ref header_row) = diffs.header_row { + let diffs_header_vec: Vec = header_row + .split(',') + .map(|s| s.trim().to_string()) + .collect(); + ensure!( + base_header_vec == diffs_header_vec, + "Header mismatch: base file has [{}], diff file expects [{}]", + base_header_vec.join(", "), + diffs_header_vec.join(", ") + ); + } // Read all rows from base file, preserving order and checking for duplicates let mut base_rows = IndexSet::new(); @@ -375,7 +371,7 @@ mod tests { // Parse from the file let diffs = Patch::from_file(&diff_file).unwrap(); - assert_eq!(diffs.header_row, "col1,col2"); + assert_eq!(diffs.header_row.as_deref(), Some("col1,col2")); assert_eq!(diffs.to_delete.len(), 1); assert_eq!(diffs.to_add.len(), 1); let del_row = "val1,val2".to_string(); @@ -395,7 +391,7 @@ mod tests { // Parse from the file let diffs_from_file = Patch::from_file(&diff_file).unwrap(); - assert_eq!(diffs_from_file.header_row, "col1,col2"); + assert_eq!(diffs_from_file.header_row.as_deref(), Some("col1,col2")); assert_eq!(diffs_from_file.to_delete.len(), 1); assert_eq!(diffs_from_file.to_add.len(), 1); let del_row = "item1,item2".to_string(); @@ -414,7 +410,7 @@ mod tests { to_add.insert("row7,row8".to_string()); let diffs = Patch { - header_row: "col1,col2".to_string(), + header_row: Some("col1,col2".to_string()), base_filename: "test.csv".to_string(), to_delete, to_add, @@ -435,7 +431,7 @@ mod tests { fn test_modify_base_with_diffs_mismatched_header() { let base = "col1,col2\nrow1,row2\n"; let diffs = Patch { - header_row: "col1,col3".to_string(), + header_row: Some("col1,col3".to_string()), base_filename: "test.csv".to_string(), to_delete: IndexSet::new(), to_add: IndexSet::new(), From 2cb08e0315dfa88f75decbfed59733db664b9a28 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 12 Dec 2025 15:18:50 +0000 Subject: [PATCH 10/32] Store fields rather than comma-separated strings --- src/input.rs | 2 +- src/input/patch.rs | 189 ++++++++++++++++++++++++--------------------- 2 files changed, 103 insertions(+), 88 deletions(-) diff --git a/src/input.rs b/src/input.rs index 2cf1d30c7..58078e678 100644 --- a/src/input.rs +++ b/src/input.rs @@ -25,7 +25,7 @@ use asset::read_assets; mod commodity; use commodity::read_commodities; mod patch; -pub use patch::{Patch, patch_model}; +pub use patch::{FilePatch, patch_model}; mod process; use process::read_processes; mod region; diff --git a/src/input/patch.rs b/src/input/patch.rs index 55bfc013f..7c9e1316f 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -2,7 +2,7 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; -use csv::{ReaderBuilder, Trim}; +use csv::{ReaderBuilder, Trim, Writer}; use indexmap::IndexSet; use log::info; use std::fs; @@ -11,22 +11,43 @@ use tempfile::{TempDir, tempdir}; /// Structure to hold diffs from a diff file #[derive(Debug)] -pub struct Patch { +pub struct FilePatch { /// The target base filename that this patch applies to (e.g. "agents.csv") base_filename: String, /// The header row (optional). If `None`, the header is not checked against base files. - header_row: Option, - /// Rows to delete - to_delete: IndexSet, - /// Rows to add - to_add: IndexSet, + header_row: Option>, + /// Rows to delete (each row is a vector of canonicalized fields) + to_delete: IndexSet>, + /// Rows to add (each row is a vector of canonicalized fields) + to_add: IndexSet>, } -impl Patch { +/// Build a canonical comma-joined string from an iterator of field strings. +fn canonicalize_fields(fields: I) -> String +where + I: IntoIterator, + S: AsRef, +{ + fields + .into_iter() + .map(|s| s.as_ref().trim().to_string()) + .collect::>() + .join(",") +} + +/// Build a canonical vector of trimmed strings from an iterator of field strings. +fn canonicalize_vec<'a, I>(fields: I) -> Vec +where + I: IntoIterator, +{ + fields.into_iter().map(|s| s.trim().to_string()).collect() +} + +impl FilePatch { /// Create a new empty `Patch` with the given `base_filename`. pub fn new(base_filename: impl Into) -> Self { let base_filename = base_filename.into(); - Patch { + FilePatch { base_filename, header_row: None, to_delete: IndexSet::new(), @@ -36,24 +57,32 @@ impl Patch { /// Set the header row for this patch (`header` should be a comma-joined string, e.g. "a,b,c"). pub fn with_header(mut self, header: impl Into) -> Self { - self.header_row = Some(header.into()); + let s = header.into(); + let v = s.split(',').map(|s| s.trim().to_string()).collect(); + self.header_row = Some(v); self } /// Add a row to the patch (row is a canonical comma-joined string, e.g. "a,b,c"). - pub fn add(&mut self, row: impl Into) -> &mut Self { - self.to_add.insert(row.into()); + /// This consumes and returns the `Patch` so calls can be chained. + pub fn add_row(mut self, row: impl Into) -> Self { + let s = row.into(); + let v = s.split(',').map(|s| s.trim().to_string()).collect(); + self.to_add.insert(v); self } /// Mark a row for deletion from the base (row is a canonical comma-joined string). - pub fn delete(&mut self, row: impl Into) -> &mut Self { - self.to_delete.insert(row.into()); + /// This consumes and returns the `Patch` so calls can be chained. + pub fn delete_row(mut self, row: impl Into) -> Self { + let s = row.into(); + let v = s.split(',').map(|s| s.trim().to_string()).collect(); + self.to_delete.insert(v); self } /// Read a diff file and construct a `Patch`. - pub fn from_file(file_path: &Path) -> Result { + pub fn from_file(file_path: &Path) -> Result { // Extract the base filename by removing the `_diff.csv` suffix let file_name = file_path .file_name() @@ -88,11 +117,10 @@ impl Patch { !headers_vec.is_empty(), "Diff file must have at least one data column" ); - let header_row = headers_vec.join(","); // Collect additions and deletions - let mut to_delete = IndexSet::new(); - let mut to_add = IndexSet::new(); + let mut to_delete: IndexSet> = IndexSet::new(); + let mut to_add: IndexSet> = IndexSet::new(); for (line_num, result) in reader.records().enumerate() { let record = result.with_context(|| { format!("Error reading record at line {} in diff file", line_num + 2) @@ -107,19 +135,20 @@ impl Patch { // First column is the diff indicator let diff_indicator = record.get(0).context("Missing diff indicator column")?; - // Build normalized row string by joining trimmed fields with commas - let row_str = record.iter().skip(1).collect::>().join(","); + // Build normalized row vector from the csv record + let row_vec = canonicalize_vec(record.iter().skip(1)); + let row_str = canonicalize_fields(&row_vec); match diff_indicator { "-" => { ensure!( - to_delete.insert(row_str.clone()), + to_delete.insert(row_vec.clone()), "Duplicate deletion entry: {row_str}", ); } "+" => { ensure!( - to_add.insert(row_str.clone()), + to_add.insert(row_vec.clone()), "Duplicate addition entry: {row_str}", ); } @@ -130,9 +159,9 @@ impl Patch { } // Create Patch object - Ok(Patch { + Ok(FilePatch { base_filename, - header_row: Some(header_row), + header_row: Some(headers_vec), to_delete, to_add, }) @@ -167,7 +196,7 @@ impl Patch { /// Modify a base CSV file by applying diffs: removing rows and adding rows. /// Preserves the order of rows from the base file, with new rows appended at the end. -fn modify_base_with_patch(base: &str, diffs: &Patch) -> Result { +fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { // Read base file from string, trimming whitespace let mut reader = ReaderBuilder::new() .trim(Trim::All) @@ -181,32 +210,29 @@ fn modify_base_with_patch(base: &str, diffs: &Patch) -> Result { let base_header_vec: Vec = base_header.iter().map(ToString::to_string).collect(); // If the patch contains a header, compare it with the base file header. - if let Some(ref header_row) = diffs.header_row { - let diffs_header_vec: Vec = header_row - .split(',') - .map(|s| s.trim().to_string()) - .collect(); + if let Some(ref header_row_vec) = patch.header_row { ensure!( - base_header_vec == diffs_header_vec, + base_header_vec == *header_row_vec, "Header mismatch: base file has [{}], diff file expects [{}]", base_header_vec.join(", "), - diffs_header_vec.join(", ") + header_row_vec.join(", ") ); } // Read all rows from base file, preserving order and checking for duplicates - let mut base_rows = IndexSet::new(); + let mut base_rows: IndexSet> = IndexSet::new(); for (line_num, result) in reader.records().enumerate() { let record = result.with_context(|| { format!("Error reading record at line {} in base file", line_num + 2) })?; - // Create normalized row string by joining trimmed fields with commas - let row_str = record.iter().collect::>().join(","); + // Create normalized row vector by trimming fields + let row_vec = canonicalize_vec(record.iter()); + let row_str = canonicalize_fields(&row_vec); // Check for duplicates ensure!( - base_rows.insert(row_str.clone()), + base_rows.insert(row_vec.clone()), "Duplicate row at line {} in base file: {}", line_num + 2, row_str @@ -214,36 +240,35 @@ fn modify_base_with_patch(base: &str, diffs: &Patch) -> Result { } // Check that there's no overlap between additions and deletions - for del_row in &diffs.to_delete { + for del_row in &patch.to_delete { ensure!( - !diffs.to_add.contains(del_row), - "Row appears in both deletions and additions: {del_row}" + !patch.to_add.contains(del_row), + "Row appears in both deletions and additions: {}", + canonicalize_fields(del_row) ); } // Apply deletions - base_rows.retain(|row| !diffs.to_delete.contains(row)); + base_rows.retain(|row| !patch.to_delete.contains(row)); // Apply additions (append to end, checking for duplicates) - for add_row in &diffs.to_add { + for add_row in &patch.to_add { ensure!( base_rows.insert(add_row.clone()), - "Addition already present in base file: {add_row}" + "Addition already present in base file: {add_row:?}" ); } - // Rebuild CSV output - let mut output = String::new(); - - // Write header - output.push_str(&base_header_vec.join(",")); - - // Write rows - if !base_rows.is_empty() { - output.push('\n'); - output.push_str(&base_rows.iter().cloned().collect::>().join("\n")); + // Serialize CSV output using csv::Writer to ensure correct quoting/escaping + let mut wtr = Writer::from_writer(vec![]); + wtr.write_record(base_header_vec.iter())?; + for row in &base_rows { + let row_iter = row.iter().map(String::as_str); + wtr.write_record(row_iter)?; } - + wtr.flush()?; + let inner = wtr.into_inner()?; + let output = String::from_utf8(inner)?; Ok(output) } @@ -288,20 +313,16 @@ pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result row5,row6 -> row7,row8 let lines: Vec<&str> = modified.lines().collect(); @@ -430,14 +450,9 @@ mod tests { #[test] fn test_modify_base_with_diffs_mismatched_header() { let base = "col1,col2\nrow1,row2\n"; - let diffs = Patch { - header_row: Some("col1,col3".to_string()), - base_filename: "test.csv".to_string(), - to_delete: IndexSet::new(), - to_add: IndexSet::new(), - }; + let patch = FilePatch::new("test.csv").with_header("col1,col3"); - let result = modify_base_with_patch(base, &diffs); + let result = modify_base_with_patch(base, &patch); assert_error!( result, "Header mismatch: base file has [col1, col2], diff file expects [col1, col3]" From 7427b56e03c4543eaf7d1291f328e2d6d3558b85 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 12 Dec 2025 15:45:46 +0000 Subject: [PATCH 11/32] Allow patching to permanent paths --- src/input.rs | 18 ++++----- src/input/patch.rs | 98 +++++++++++++++++++++++++--------------------- 2 files changed, 60 insertions(+), 56 deletions(-) diff --git a/src/input.rs b/src/input.rs index 58078e678..56082f00f 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,6 +17,7 @@ use std::fmt::{self, Write}; use std::fs; use std::hash::Hash; use std::path::Path; +use tempfile::tempdir; mod agent; use agent::read_agents; @@ -25,7 +26,7 @@ use asset::read_assets; mod commodity; use commodity::read_commodities; mod patch; -pub use patch::{FilePatch, patch_model}; +pub use patch::{FilePatch, patch_model_to_path}; mod process; use process::read_processes; mod region; @@ -231,17 +232,12 @@ where pub fn load_model>(model_dir: P) -> Result<(Model, AssetPool)> { let model_params = ModelParameters::from_path(&model_dir)?; - // If `model_params` specifies a `base_dir`, patch the base model and load the patched model + // If `model_params` specifies a `base_dir`, patch the base model to a temporary directory and + // load the patched model if let Some(base_dir) = &model_params.base_model { - let patched_model = - patch_model(Path::new(base_dir), model_dir.as_ref()).with_context(|| { - format!( - "Error patching base model at {} with patches from {}", - base_dir, - model_dir.as_ref().display() - ) - })?; - return load_model(patched_model); + let temp = tempdir().context("Failed to create temporary directory for model patching")?; + patch_model_to_path(Path::new(base_dir), model_dir.as_ref(), &temp)?; + return load_model(temp.path()); } let time_slice_info = read_time_slice_info(model_dir.as_ref())?; diff --git a/src/input/patch.rs b/src/input/patch.rs index 7c9e1316f..b19890ffa 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -7,7 +7,6 @@ use indexmap::IndexSet; use log::info; use std::fs; use std::path::Path; -use tempfile::{TempDir, tempdir}; /// Structure to hold diffs from a diff file #[derive(Debug)] @@ -272,37 +271,47 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { Ok(output) } -/// Patch a base model directory with diffs from another directory. -/// Returns a `TempDir` containing the patched model. -pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result { +/// Patch a base model directory with diffs from another directory and save to an output directory. +pub fn patch_model_to_path, O: AsRef>( + base_model_dir: P, + diffs_dir: P, + out_dir: O, +) -> Result<()> { info!( "Patching model at '{}' with diffs from '{}'", base_model_dir.as_ref().display(), diffs_dir.as_ref().display() ); + let out_path = out_dir.as_ref(); - // Copy contents of `base_model_dir` to a temporary directory - let temp_dir = tempdir().context("Failed to create temporary directory")?; - let temp_path = temp_dir.path(); - + // Copy all CSV files from the base model directory to the temporary directory + // Some of these will be overwritten by the patched versions later. for entry in fs::read_dir(base_model_dir.as_ref())? { let entry = entry?; let src_path = entry.path(); - - // Only copy files (skip any subdirectories if present) - if src_path.is_file() { - let dst_path = temp_path.join(entry.file_name()); + if src_path.is_file() + && src_path + .extension() + .and_then(|e| e.to_str()) + .is_some_and(|ext| ext.eq_ignore_ascii_case("csv")) + { + let dst_path = out_path.join(entry.file_name()); fs::copy(&src_path, &dst_path) .with_context(|| format!("Failed to copy file: {}", src_path.display()))?; } } - // Patch the new model.toml file un the temporary directory - let base_toml_path = temp_path.join("model.toml"); + // Read and merge `model.toml` from the base model and the diffs directory, then + // write the merged result into the output directory. + let base_toml_src = base_model_dir.as_ref().join("model.toml"); let patch_toml_path = diffs_dir.as_ref().join("model.toml"); - patch_model_toml(&base_toml_path, &patch_toml_path)?; + let merged_toml = read_toml_with_patch(&base_toml_src, &patch_toml_path)?; + let merged_str = toml::to_string_pretty(&merged_toml)?; + fs::write(out_path.join("model.toml"), merged_str)?; - // Read all patch files into memory first + // Read all patch files into memory first. CSV diffs are parsed into `FilePatch` entries; + // any non-CSV files (e.g. README.txt) are copied from the diffs directory into the + // temporary model directory so they override/add to the patched model. let mut patches = Vec::new(); for entry in fs::read_dir(diffs_dir.as_ref())? { let entry = entry?; @@ -312,67 +321,65 @@ pub fn patch_model>(base_model_dir: P, diffs_dir: P) -> Result { + // Read diffs and push to vector (FilePatch::from_file validates `_diff.csv` suffix) + let patch = FilePatch::from_file(&diff_path).with_context(|| { + format!("Failed to read diff file: {}", diff_path.display()) + })?; + patches.push(patch); + } + _ => { + // Copy non-CSV file (e.g., README) into the temporary patched model directory + let dst_path = out_path.join(entry.file_name()); + fs::copy(&diff_path, &dst_path).with_context(|| { + format!("Failed to copy diff asset: {}", diff_path.display()) + })?; + } } - - // Read diffs and push to vector (Patch::from_file validates `_diff.csv` suffix) - let patch = FilePatch::from_file(&diff_path) - .with_context(|| format!("Failed to read diff file: {}", diff_path.display()))?; - patches.push(patch); } // Apply each patch to its corresponding base file and write to temp dir for patch in &patches { patch - .apply_and_save(base_model_dir.as_ref(), temp_path) + .apply_and_save(base_model_dir.as_ref(), out_path) .with_context(|| format!("Failed to apply patch to file: {}", patch.base_filename))?; } info!( - "Patching complete. Patched model saved to temporary path '{}'", - temp_path.display() + "Patching complete. Patched model saved to '{}'", + out_path.display() ); - // Return the temporary directory - Ok(temp_dir) + Ok(()) } -fn patch_model_toml(base_path: &Path, patch_path: &Path) -> Result<()> { - // Read original TOML file +/// Read `base_path` and `patch_path` TOML files, merge top-level fields from the patch +/// into the base +fn read_toml_with_patch(base_path: &Path, patch_path: &Path) -> Result { + // Read base TOML let base_str = fs::read_to_string(base_path).with_context(|| input_err_msg(base_path))?; let mut base_data: toml::Value = toml::from_str(&base_str).with_context(|| input_err_msg(base_path))?; - // Read patch TOML file + // Read patch TOML let patch_str = fs::read_to_string(patch_path).with_context(|| input_err_msg(patch_path))?; let patch_data: toml::Value = toml::from_str(&patch_str).with_context(|| input_err_msg(patch_path))?; - // Merge patch into base (only top-level fields allowed) let base_table = base_data.as_table_mut().expect("Base TOML must be a table"); let patch_table = patch_data.as_table().expect("Patch TOML must be a table"); + // Merge the patch into the base, skipping `base_model`, and prioritizing patch values for (key, patch_val) in patch_table { - // Skip `base_model` field if key == "base_model" { continue; } - - // Overwrite or add the field from the patch base_table.insert(key.clone(), patch_val.clone()); } - // Save modified TOML back to original file - let modified_str = toml::to_string_pretty(&base_data)?; - fs::write(base_path, modified_str)?; - - Ok(()) + Ok(base_data) } #[cfg(test)] @@ -380,6 +387,7 @@ mod tests { use super::*; use crate::fixture::assert_error; use std::io::Write; + use tempfile::tempdir; #[test] fn test_read_diffs_basic() { From 28b26138399957446bdec1b57dfde9d3132bcf41 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 15 Dec 2025 14:06:25 +0000 Subject: [PATCH 12/32] Introduce ModelPatch --- src/input.rs | 9 +- src/input/patch.rs | 305 +++++++++++++++++++++++++-------------------- 2 files changed, 177 insertions(+), 137 deletions(-) diff --git a/src/input.rs b/src/input.rs index 56082f00f..e74609ee5 100644 --- a/src/input.rs +++ b/src/input.rs @@ -17,7 +17,6 @@ use std::fmt::{self, Write}; use std::fs; use std::hash::Hash; use std::path::Path; -use tempfile::tempdir; mod agent; use agent::read_agents; @@ -26,7 +25,7 @@ use asset::read_assets; mod commodity; use commodity::read_commodities; mod patch; -pub use patch::{FilePatch, patch_model_to_path}; +pub use patch::{FilePatch, ModelPatch}; mod process; use process::read_processes; mod region; @@ -234,9 +233,9 @@ pub fn load_model>(model_dir: P) -> Result<(Model, AssetPool)> { // If `model_params` specifies a `base_dir`, patch the base model to a temporary directory and // load the patched model - if let Some(base_dir) = &model_params.base_model { - let temp = tempdir().context("Failed to create temporary directory for model patching")?; - patch_model_to_path(Path::new(base_dir), model_dir.as_ref(), &temp)?; + if model_params.base_model.is_some() { + let patch = ModelPatch::from_path(model_dir.as_ref())?; + let temp = patch.build_to_tempdir()?; return load_model(temp.path()); } diff --git a/src/input/patch.rs b/src/input/patch.rs index b19890ffa..939b1f128 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -4,10 +4,162 @@ use super::input_err_msg; use anyhow::{Context, Result, bail, ensure}; use csv::{ReaderBuilder, Trim, Writer}; use indexmap::IndexSet; -use log::info; use std::fs; use std::path::Path; +/// Structure to hold a set of patches to apply to a base model. +pub struct ModelPatch { + // The base model directory path + base_model_dir: String, + // The list of file patches to apply + file_patches: Vec, + // Optional settings patches (TOML values) + settings_patch: Option, +} + +impl ModelPatch { + /// Create a new empty `ModelPatch` with the given base model directory. + pub fn new(base_model_dir: String) -> Self { + ModelPatch { + base_model_dir, + file_patches: Vec::new(), + settings_patch: None, + } + } + + /// Add a `FilePatch` to this `ModelPatch`. + pub fn with_file_patch(mut self, patch: FilePatch) -> Self { + self.file_patches.push(patch); + self + } + + /// Add a settings patch (TOML table) to this `ModelPatch`. + pub fn with_settings_patch(mut self, patch: toml::value::Table) -> Self { + assert!( + self.settings_patch.is_none(), + "Settings patch already set for this ModelPatch" + ); + assert!( + !patch.contains_key("base_model"), + "Settings patch cannot contain `base_model` field" + ); + self.settings_patch = Some(patch); + self + } + + /// Build a `ModelPatch` from a diffs directory. Expects `model.toml` to be present in the + /// diffs directory and to contain a `base_model` string field that points to the base + /// model directory. Also collects all `*_diff.csv` files in the diffs directory into + /// `FilePatch` entries, and any other top-level fields in `model.toml` become the + /// `settings_patch`. + pub fn from_path(diffs_dir: &Path) -> Result { + // Read model.toml in the diffs directory + let patch_toml_str = fs::read_to_string(diffs_dir.join("model.toml"))?; + let patch_toml_data: toml::Value = toml::from_str(&patch_toml_str)?; + + // Extract `base_model` field from model.toml + // Any additional fields become the settings_patch + let (base_model_dir, settings_patch) = match patch_toml_data { + toml::Value::Table(mut tbl) => { + let base = tbl + .remove("base_model") + .and_then(|v| v.as_str().map(std::string::ToString::to_string)) + .context("Patch model.toml missing required `base_model` field")?; + (base, tbl) + } + _ => bail!("Patch TOML must be a table"), + }; + + // Collect all file patches from `*_diff.csv` files in diffs directory + let mut file_patches = Vec::new(); + for entry in fs::read_dir(diffs_dir)? { + let entry = entry?; + let p = entry.path(); + if !p.is_file() { + continue; + } + if let Some(name) = p.file_name().and_then(|n| n.to_str()) + && name.to_lowercase().ends_with("_diff.csv") + { + let fp = FilePatch::from_file(&p) + .with_context(|| format!("Failed to read diff file: {}", p.display()))?; + file_patches.push(fp); + } + } + + Ok(ModelPatch { + base_model_dir, + file_patches, + settings_patch: Some(settings_patch), + }) + } + + /// Apply this `ModelPatch` into `out_dir` (creating/overwriting files there). + fn build>(&self, out_dir: O) -> Result<()> { + let base_dir = Path::new(&self.base_model_dir); + let out_path = out_dir.as_ref(); + + // Copy all CSV files from the base model into the output directory + // Any files with associated patches will be overwritten later + for entry in fs::read_dir(base_dir)? { + let entry = entry?; + let src_path = entry.path(); + if src_path.is_file() + && src_path + .extension() + .and_then(|e| e.to_str()) + .is_some_and(|ext| ext.eq_ignore_ascii_case("csv")) + { + let dst_path = out_path.join(entry.file_name()); + fs::copy(&src_path, &dst_path) + .with_context(|| format!("Failed to copy file: {}", src_path.display()))?; + } + } + + // Apply settings patch (if any), or copy model.toml from the base model + let base_toml_path = base_dir.join("model.toml"); + let out_toml_path = out_path.join("model.toml"); + if let Some(settings_patch) = &self.settings_patch { + // Start with model.toml from base model + let settings_toml = fs::read_to_string(&base_toml_path)?; + let mut settings_value: toml::Value = toml::from_str(&settings_toml)?; + let merged_table = settings_value + .as_table_mut() + .context("Merged model TOML must be a table")?; + + // Apply settings patch + for (key, patch_val) in settings_patch { + merged_table.insert(key.clone(), patch_val.clone()); + } + + // Save to file + let merged_toml = toml::to_string_pretty(&settings_value)?; + fs::write(&out_toml_path, merged_toml)?; + } else { + // No settings patch; copy base model.toml + fs::copy(&base_toml_path, &out_toml_path)?; + } + + // Apply file patches + for patch in &self.file_patches { + patch + .apply_and_save(base_dir.as_ref(), out_path) + .with_context(|| { + format!("Failed to apply patch to file: {}", patch.base_filename) + })?; + } + + Ok(()) + } + + /// Build the patched model into a temporary directory and return the `TempDir`. + pub fn build_to_tempdir(&self) -> Result { + let temp_dir = tempfile::tempdir()?; + self.build(temp_dir.path())?; + Ok(temp_dir) + } +} + /// Structure to hold diffs from a diff file #[derive(Debug)] pub struct FilePatch { @@ -63,7 +215,6 @@ impl FilePatch { } /// Add a row to the patch (row is a canonical comma-joined string, e.g. "a,b,c"). - /// This consumes and returns the `Patch` so calls can be chained. pub fn add_row(mut self, row: impl Into) -> Self { let s = row.into(); let v = s.split(',').map(|s| s.trim().to_string()).collect(); @@ -72,7 +223,6 @@ impl FilePatch { } /// Mark a row for deletion from the base (row is a canonical comma-joined string). - /// This consumes and returns the `Patch` so calls can be chained. pub fn delete_row(mut self, row: impl Into) -> Self { let s = row.into(); let v = s.split(',').map(|s| s.trim().to_string()).collect(); @@ -184,9 +334,9 @@ impl FilePatch { } /// Apply this patch to a base model and save the modified CSV to another directory. - pub fn apply_and_save(&self, base_model_dir: &Path, new_model_dir: &Path) -> Result<()> { + pub fn apply_and_save(&self, base_model_dir: &Path, out_model_dir: &Path) -> Result<()> { let modified = self.apply(base_model_dir)?; - let new_path = new_model_dir.join(&self.base_filename); + let new_path = out_model_dir.join(&self.base_filename); fs::write(&new_path, modified) .with_context(|| format!("Failed to write patched file: {}", new_path.display()))?; Ok(()) @@ -271,117 +421,6 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { Ok(output) } -/// Patch a base model directory with diffs from another directory and save to an output directory. -pub fn patch_model_to_path, O: AsRef>( - base_model_dir: P, - diffs_dir: P, - out_dir: O, -) -> Result<()> { - info!( - "Patching model at '{}' with diffs from '{}'", - base_model_dir.as_ref().display(), - diffs_dir.as_ref().display() - ); - let out_path = out_dir.as_ref(); - - // Copy all CSV files from the base model directory to the temporary directory - // Some of these will be overwritten by the patched versions later. - for entry in fs::read_dir(base_model_dir.as_ref())? { - let entry = entry?; - let src_path = entry.path(); - if src_path.is_file() - && src_path - .extension() - .and_then(|e| e.to_str()) - .is_some_and(|ext| ext.eq_ignore_ascii_case("csv")) - { - let dst_path = out_path.join(entry.file_name()); - fs::copy(&src_path, &dst_path) - .with_context(|| format!("Failed to copy file: {}", src_path.display()))?; - } - } - - // Read and merge `model.toml` from the base model and the diffs directory, then - // write the merged result into the output directory. - let base_toml_src = base_model_dir.as_ref().join("model.toml"); - let patch_toml_path = diffs_dir.as_ref().join("model.toml"); - let merged_toml = read_toml_with_patch(&base_toml_src, &patch_toml_path)?; - let merged_str = toml::to_string_pretty(&merged_toml)?; - fs::write(out_path.join("model.toml"), merged_str)?; - - // Read all patch files into memory first. CSV diffs are parsed into `FilePatch` entries; - // any non-CSV files (e.g. README.txt) are copied from the diffs directory into the - // temporary model directory so they override/add to the patched model. - let mut patches = Vec::new(); - for entry in fs::read_dir(diffs_dir.as_ref())? { - let entry = entry?; - let diff_path = entry.path(); - - if !diff_path.is_file() { - continue; - } - - // If the file is a CSV, parse it as a diff. Otherwise, copy the file into the temp dir. - match diff_path.extension().and_then(|e| e.to_str()) { - Some(ext) if ext.eq_ignore_ascii_case("csv") => { - // Read diffs and push to vector (FilePatch::from_file validates `_diff.csv` suffix) - let patch = FilePatch::from_file(&diff_path).with_context(|| { - format!("Failed to read diff file: {}", diff_path.display()) - })?; - patches.push(patch); - } - _ => { - // Copy non-CSV file (e.g., README) into the temporary patched model directory - let dst_path = out_path.join(entry.file_name()); - fs::copy(&diff_path, &dst_path).with_context(|| { - format!("Failed to copy diff asset: {}", diff_path.display()) - })?; - } - } - } - - // Apply each patch to its corresponding base file and write to temp dir - for patch in &patches { - patch - .apply_and_save(base_model_dir.as_ref(), out_path) - .with_context(|| format!("Failed to apply patch to file: {}", patch.base_filename))?; - } - - info!( - "Patching complete. Patched model saved to '{}'", - out_path.display() - ); - - Ok(()) -} - -/// Read `base_path` and `patch_path` TOML files, merge top-level fields from the patch -/// into the base -fn read_toml_with_patch(base_path: &Path, patch_path: &Path) -> Result { - // Read base TOML - let base_str = fs::read_to_string(base_path).with_context(|| input_err_msg(base_path))?; - let mut base_data: toml::Value = - toml::from_str(&base_str).with_context(|| input_err_msg(base_path))?; - - // Read patch TOML - let patch_str = fs::read_to_string(patch_path).with_context(|| input_err_msg(patch_path))?; - let patch_data: toml::Value = - toml::from_str(&patch_str).with_context(|| input_err_msg(patch_path))?; - - let base_table = base_data.as_table_mut().expect("Base TOML must be a table"); - let patch_table = patch_data.as_table().expect("Patch TOML must be a table"); - - // Merge the patch into the base, skipping `base_model`, and prioritizing patch values - for (key, patch_val) in patch_table { - if key == "base_model" { - continue; - } - base_table.insert(key.clone(), patch_val.clone()); - } - - Ok(base_data) -} - #[cfg(test)] mod tests { use super::*; @@ -390,7 +429,7 @@ mod tests { use tempfile::tempdir; #[test] - fn test_read_diffs_basic() { + fn test_patch_from_file() { // Create diff file let temp_dir = tempdir().unwrap(); let diff_file = temp_dir.path().join("test_diff.csv"); @@ -399,21 +438,22 @@ mod tests { file.write_all(content.as_bytes()).unwrap(); // Parse from the file - let diffs = FilePatch::from_file(&diff_file).unwrap(); + let patch = FilePatch::from_file(&diff_file).unwrap(); + assert_eq!( - diffs.header_row.as_ref().map(|v| v.join(",")), + patch.header_row.as_ref().map(|v| v.join(",")), Some("col1,col2".to_string()) ); - assert_eq!(diffs.to_delete.len(), 1); - assert_eq!(diffs.to_add.len(), 1); + assert_eq!(patch.to_delete.len(), 1); + assert_eq!(patch.to_add.len(), 1); let del_row = vec!["val1".to_string(), "val2".to_string()]; let add_row = vec!["val3".to_string(), "val4".to_string()]; - assert!(diffs.to_delete.contains(&del_row)); - assert!(diffs.to_add.contains(&add_row)); + assert!(patch.to_delete.contains(&del_row)); + assert!(patch.to_add.contains(&add_row)); } #[test] - fn test_read_diffs_with_whitespace() { + fn test_patch_from_file_whitespace() { // Create diff file with extra whitespace let temp_dir = tempdir().unwrap(); let diff_file = temp_dir.path().join("test_diff.csv"); @@ -422,21 +462,22 @@ mod tests { file.write_all(content.as_bytes()).unwrap(); // Parse from the file - let diffs_from_file = FilePatch::from_file(&diff_file).unwrap(); + let patch = FilePatch::from_file(&diff_file).unwrap(); + assert_eq!( - diffs_from_file.header_row.as_ref().map(|v| v.join(",")), + patch.header_row.as_ref().map(|v| v.join(",")), Some("col1,col2".to_string()) ); - assert_eq!(diffs_from_file.to_delete.len(), 1); - assert_eq!(diffs_from_file.to_add.len(), 1); + assert_eq!(patch.to_delete.len(), 1); + assert_eq!(patch.to_add.len(), 1); let del_row = vec!["item1".to_string(), "item2".to_string()]; let add_row = vec!["another1".to_string(), "another2".to_string()]; - assert!(diffs_from_file.to_delete.contains(&del_row)); - assert!(diffs_from_file.to_add.contains(&add_row)); + assert!(patch.to_delete.contains(&del_row)); + assert!(patch.to_add.contains(&add_row)); } #[test] - fn test_modify_base_with_diffs_preserves_order() { + fn test_modify_base_with_patch() { let base = "col1,col2\nrow1,row2\nrow3,row4\nrow5,row6\n"; let patch = FilePatch::new("test.csv") @@ -456,7 +497,7 @@ mod tests { } #[test] - fn test_modify_base_with_diffs_mismatched_header() { + fn test_modify_base_with_patch_mismatched_header() { let base = "col1,col2\nrow1,row2\n"; let patch = FilePatch::new("test.csv").with_header("col1,col3"); From bae9f85f9a7bd7165048052f5178db837a77f252 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 15 Dec 2025 14:18:22 +0000 Subject: [PATCH 13/32] Path handling fixes --- src/input/patch.rs | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/input/patch.rs b/src/input/patch.rs index 939b1f128..c66a3c800 100644 --- a/src/input/patch.rs +++ b/src/input/patch.rs @@ -5,12 +5,12 @@ use anyhow::{Context, Result, bail, ensure}; use csv::{ReaderBuilder, Trim, Writer}; use indexmap::IndexSet; use std::fs; -use std::path::Path; +use std::path::{Path, PathBuf}; /// Structure to hold a set of patches to apply to a base model. pub struct ModelPatch { // The base model directory path - base_model_dir: String, + base_model_dir: PathBuf, // The list of file patches to apply file_patches: Vec, // Optional settings patches (TOML values) @@ -19,9 +19,9 @@ pub struct ModelPatch { impl ModelPatch { /// Create a new empty `ModelPatch` with the given base model directory. - pub fn new(base_model_dir: String) -> Self { + pub fn new>(base_model_dir: P) -> Self { ModelPatch { - base_model_dir, + base_model_dir: base_model_dir.into(), file_patches: Vec::new(), settings_patch: None, } @@ -63,7 +63,7 @@ impl ModelPatch { toml::Value::Table(mut tbl) => { let base = tbl .remove("base_model") - .and_then(|v| v.as_str().map(std::string::ToString::to_string)) + .and_then(|v| v.as_str().map(PathBuf::from)) .context("Patch model.toml missing required `base_model` field")?; (base, tbl) } @@ -96,7 +96,7 @@ impl ModelPatch { /// Apply this `ModelPatch` into `out_dir` (creating/overwriting files there). fn build>(&self, out_dir: O) -> Result<()> { - let base_dir = Path::new(&self.base_model_dir); + let base_dir = self.base_model_dir.as_path(); let out_path = out_dir.as_ref(); // Copy all CSV files from the base model into the output directory @@ -142,11 +142,9 @@ impl ModelPatch { // Apply file patches for patch in &self.file_patches { - patch - .apply_and_save(base_dir.as_ref(), out_path) - .with_context(|| { - format!("Failed to apply patch to file: {}", patch.base_filename) - })?; + patch.apply_and_save(base_dir, out_path).with_context(|| { + format!("Failed to apply patch to file: {}", patch.base_filename) + })?; } Ok(()) From 670c6107f068b9e0d8496e33b648a25f657c1b2e Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Mon, 15 Dec 2025 14:23:46 +0000 Subject: [PATCH 14/32] Move patch code to standalone module --- src/input.rs | 3 +-- src/lib.rs | 1 + src/{input => }/patch.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename src/{input => }/patch.rs (99%) diff --git a/src/input.rs b/src/input.rs index e74609ee5..c324c82a9 100644 --- a/src/input.rs +++ b/src/input.rs @@ -5,6 +5,7 @@ use crate::graph::validate::validate_commodity_graphs_for_model; use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model}; use crate::id::{HasID, IDLike}; use crate::model::{Model, ModelParameters}; +use crate::patch::ModelPatch; use crate::region::RegionID; use crate::units::UnitType; use anyhow::{Context, Result, bail, ensure}; @@ -24,8 +25,6 @@ mod asset; use asset::read_assets; mod commodity; use commodity::read_commodities; -mod patch; -pub use patch::{FilePatch, ModelPatch}; mod process; use process::read_processes; mod region; diff --git a/src/lib.rs b/src/lib.rs index 0e54f9f8e..dcf5836ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,6 +18,7 @@ pub mod input; pub mod log; pub mod model; pub mod output; +pub mod patch; pub mod process; pub mod region; pub mod settings; diff --git a/src/input/patch.rs b/src/patch.rs similarity index 99% rename from src/input/patch.rs rename to src/patch.rs index c66a3c800..7b26e4191 100644 --- a/src/input/patch.rs +++ b/src/patch.rs @@ -1,5 +1,5 @@ //! Code for applying patches/diffs to model input files. -use super::input_err_msg; +use crate::input::input_err_msg; use anyhow::{Context, Result, bail, ensure}; use csv::{ReaderBuilder, Trim, Writer}; From 23b4cd5040d6382a739b7f2531887c51777f8bdc Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 10:54:42 +0000 Subject: [PATCH 15/32] Add toml tests and integration tests --- src/patch.rs | 214 +++++++++++++++++++++++++++++++++---------------- tests/patch.rs | 75 +++++++++++++++++ 2 files changed, 222 insertions(+), 67 deletions(-) create mode 100644 tests/patch.rs diff --git a/src/patch.rs b/src/patch.rs index 7b26e4191..2e4f2ae84 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -13,8 +13,8 @@ pub struct ModelPatch { base_model_dir: PathBuf, // The list of file patches to apply file_patches: Vec, - // Optional settings patches (TOML values) - settings_patch: Option, + // Optional patch for model.toml (TOML table) + toml_patch: Option, } impl ModelPatch { @@ -23,27 +23,36 @@ impl ModelPatch { ModelPatch { base_model_dir: base_model_dir.into(), file_patches: Vec::new(), - settings_patch: None, + toml_patch: None, } } - /// Add a `FilePatch` to this `ModelPatch`. + /// Add a single `FilePatch` to this `ModelPatch`. pub fn with_file_patch(mut self, patch: FilePatch) -> Self { self.file_patches.push(patch); self } - /// Add a settings patch (TOML table) to this `ModelPatch`. - pub fn with_settings_patch(mut self, patch: toml::value::Table) -> Self { + /// Add multiple `FilePatch` entries to this `ModelPatch`. + pub fn with_file_patches(mut self, patches: I) -> Self + where + I: IntoIterator, + { + self.file_patches.extend(patches); + self + } + + /// Add a TOML patch (TOML table) to this `ModelPatch`. + pub fn with_toml_patch(mut self, patch: toml::value::Table) -> Self { assert!( - self.settings_patch.is_none(), - "Settings patch already set for this ModelPatch" + self.toml_patch.is_none(), + "TOML patch already set for this ModelPatch" ); assert!( !patch.contains_key("base_model"), - "Settings patch cannot contain `base_model` field" + "TOML patch cannot contain `base_model` field" ); - self.settings_patch = Some(patch); + self.toml_patch = Some(patch); self } @@ -51,24 +60,10 @@ impl ModelPatch { /// diffs directory and to contain a `base_model` string field that points to the base /// model directory. Also collects all `*_diff.csv` files in the diffs directory into /// `FilePatch` entries, and any other top-level fields in `model.toml` become the - /// `settings_patch`. + /// `toml_patch`. pub fn from_path(diffs_dir: &Path) -> Result { - // Read model.toml in the diffs directory - let patch_toml_str = fs::read_to_string(diffs_dir.join("model.toml"))?; - let patch_toml_data: toml::Value = toml::from_str(&patch_toml_str)?; - - // Extract `base_model` field from model.toml - // Any additional fields become the settings_patch - let (base_model_dir, settings_patch) = match patch_toml_data { - toml::Value::Table(mut tbl) => { - let base = tbl - .remove("base_model") - .and_then(|v| v.as_str().map(PathBuf::from)) - .context("Patch model.toml missing required `base_model` field")?; - (base, tbl) - } - _ => bail!("Patch TOML must be a table"), - }; + // Parse patch model.toml and extract base_model + toml patch + let (base_model_dir, toml_patch) = read_patch_toml(&diffs_dir.join("model.toml"))?; // Collect all file patches from `*_diff.csv` files in diffs directory let mut file_patches = Vec::new(); @@ -90,7 +85,7 @@ impl ModelPatch { Ok(ModelPatch { base_model_dir, file_patches, - settings_patch: Some(settings_patch), + toml_patch: Some(toml_patch), }) } @@ -116,27 +111,16 @@ impl ModelPatch { } } - // Apply settings patch (if any), or copy model.toml from the base model + // Apply toml patch (if any), or copy model.toml from the base model let base_toml_path = base_dir.join("model.toml"); let out_toml_path = out_path.join("model.toml"); - if let Some(settings_patch) = &self.settings_patch { - // Start with model.toml from base model - let settings_toml = fs::read_to_string(&base_toml_path)?; - let mut settings_value: toml::Value = toml::from_str(&settings_toml)?; - let merged_table = settings_value - .as_table_mut() - .context("Merged model TOML must be a table")?; - - // Apply settings patch - for (key, patch_val) in settings_patch { - merged_table.insert(key.clone(), patch_val.clone()); - } - - // Save to file - let merged_toml = toml::to_string_pretty(&settings_value)?; + if let Some(toml_patch) = &self.toml_patch { + // Start with model.toml from base model and merge via helper + let toml_content = fs::read_to_string(&base_toml_path)?; + let merged_toml = merge_model_toml(&toml_content, toml_patch)?; fs::write(&out_toml_path, merged_toml)?; } else { - // No settings patch; copy base model.toml + // No toml patch; copy base model.toml fs::copy(&base_toml_path, &out_toml_path)?; } @@ -171,27 +155,6 @@ pub struct FilePatch { to_add: IndexSet>, } -/// Build a canonical comma-joined string from an iterator of field strings. -fn canonicalize_fields(fields: I) -> String -where - I: IntoIterator, - S: AsRef, -{ - fields - .into_iter() - .map(|s| s.as_ref().trim().to_string()) - .collect::>() - .join(",") -} - -/// Build a canonical vector of trimmed strings from an iterator of field strings. -fn canonicalize_vec<'a, I>(fields: I) -> Vec -where - I: IntoIterator, -{ - fields.into_iter().map(|s| s.trim().to_string()).collect() -} - impl FilePatch { /// Create a new empty `Patch` with the given `base_filename`. pub fn new(base_filename: impl Into) -> Self { @@ -341,7 +304,61 @@ impl FilePatch { } } -/// Modify a base CSV file by applying diffs: removing rows and adding rows. +/// Read a patch `model.toml` file and return the `base_model` path and remaining +/// table as the toml patch. +fn read_patch_toml(toml_path: &Path) -> Result<(PathBuf, toml::value::Table)> { + let s = fs::read_to_string(toml_path)?; + let val: toml::Value = toml::from_str(&s)?; + match val { + toml::Value::Table(mut tbl) => { + let base = tbl + .remove("base_model") + .and_then(|v| v.as_str().map(PathBuf::from)) + .context("Patch model.toml missing required `base_model` field")?; + Ok((base, tbl)) + } + _ => bail!("Patch TOML must be a table"), + } +} + +/// Merge a TOML patch into a base model TOML string and return the merged TOML. +fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result { + if patch.contains_key("base_model") { + bail!("TOML patch cannot contain `base_model` field"); + } + let mut base_val: toml::Value = toml::from_str(base_toml)?; + let base_tbl = base_val + .as_table_mut() + .context("Base model TOML must be a table")?; + for (k, v) in patch { + base_tbl.insert(k.clone(), v.clone()); + } + let out = toml::to_string_pretty(&base_val)?; + Ok(out) +} + +/// Build a canonical comma-joined string from an iterator of field strings. +fn canonicalize_fields(fields: I) -> String +where + I: IntoIterator, + S: AsRef, +{ + fields + .into_iter() + .map(|s| s.as_ref().trim().to_string()) + .collect::>() + .join(",") +} + +/// Build a canonical vector of trimmed strings from an iterator of field strings. +fn canonicalize_vec<'a, I>(fields: I) -> Vec +where + I: IntoIterator, +{ + fields.into_iter().map(|s| s.trim().to_string()).collect() +} + +/// Modify a string representation of a base CSV file by applying a `FilePatch`. /// Preserves the order of rows from the base file, with new rows appended at the end. fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { // Read base file from string, trimming whitespace @@ -505,4 +522,67 @@ mod tests { "Header mismatch: base file has [col1, col2], diff file expects [col1, col3]" ); } + + #[test] + fn test_read_patch_toml() { + // Create patch TOML file + let td = tempdir().unwrap(); + let p = td.path().join("model.toml"); + let content = r#" + base_model = "some/base/path" + foo = "bar" + "#; + let mut f = fs::File::create(&p).unwrap(); + f.write_all(content.as_bytes()).unwrap(); + + // Read with `read_patch_toml` + // Should return the base_model path and remaining table (without the base_model field) + let (base, tbl) = read_patch_toml(&p).unwrap(); + assert_eq!(base, PathBuf::from("some/base/path")); + assert_eq!(tbl.get("foo").and_then(|v| v.as_str()), Some("bar")); + assert!(!tbl.contains_key("base_model")); + } + + #[test] + fn test_merge_model_toml_basic() { + let base = r#" + title = "base" + [section] + a = 1 + "#; + + // Create a patch table + let mut patch = toml::value::Table::new(); + patch.insert( + "title".to_string(), + toml::Value::String("patched".to_string()), + ); + patch.insert( + "new_field".to_string(), + toml::Value::String("added".to_string()), + ); + + // Apply patch with `merge_model_toml` + // Should overwrite title and add new_field, but keep section.a + let merged = merge_model_toml(base, &patch).unwrap(); + assert!(merged.contains("title = \"patched\"")); + assert!(merged.contains("[section]")); + assert!(merged.contains("new_field = \"added\"")); + } + + #[test] + fn test_merge_rejects_base_model_key() { + let base = r#"title = "base""#; + + // Create a patch table with a base_model key + let mut patch = toml::value::Table::new(); + patch.insert( + "base_model".to_string(), + toml::Value::String("..".to_string()), + ); + + // `merge_model_toml` should return an error + let res = merge_model_toml(base, &patch); + assert!(res.is_err()); + } } diff --git a/tests/patch.rs b/tests/patch.rs new file mode 100644 index 000000000..3e1ac21b9 --- /dev/null +++ b/tests/patch.rs @@ -0,0 +1,75 @@ +//! Integration tests for the `validate` command. +use anyhow::Result; +use muse2::cli::handle_validate_command; +use muse2::model::ModelParameters; +use muse2::patch::{FilePatch, ModelPatch}; +use muse2::settings::Settings; +use std::path::PathBuf; +use tempfile::TempDir; + +/// Patch of the "simple" model with a change to the `assets.csv` file. +fn get_model_dir_file_patch() -> Result { + let base_model_dir = PathBuf::from("examples/simple"); + + // Small change to an asset capacity + let assets_patch = FilePatch::new("assets.csv") + .delete_row("GASDRV,GBR,A0_GEX,4002.26,2020") + .add_row("GASDRV,GBR,A0_GEX,4003.26,2020"); + let model_patch = ModelPatch::new(&base_model_dir).with_file_patch(assets_patch); + + let temp_dir = model_patch.build_to_tempdir()?; + Ok(temp_dir) +} + +/// Patch of the "simple" model with a change to the `model.toml` file. +fn get_model_dir_toml_patch() -> Result { + let base_model_dir = PathBuf::from("examples/simple"); + + // Add an extra milestone year (2050) + let mut toml_patch = toml::value::Table::new(); + toml_patch.insert( + "milestone_years".to_string(), + toml::Value::Array(vec![ + toml::Value::Integer(2020), + toml::Value::Integer(2030), + toml::Value::Integer(2040), + toml::Value::Integer(2050), + ]), + ); + let model_patch = ModelPatch::new(&base_model_dir).with_toml_patch(toml_patch); + let temp_dir = model_patch.build_to_tempdir()?; + Ok(temp_dir) +} + +#[test] +fn test_file_patch_and_validate() { + unsafe { std::env::set_var("MUSE2_LOG_LEVEL", "off") }; + + // Model is patched successfully + let model_dir = get_model_dir_file_patch().unwrap(); + + // The appropriate change has been made + let assets_path = model_dir.path().join("assets.csv"); + let assets_content = std::fs::read_to_string(assets_path).unwrap(); + assert!(!assets_content.contains("GASDRV,GBR,A0_GEX,4002.26,2020")); + assert!(assets_content.contains("GASDRV,GBR,A0_GEX,4003.26,2020")); + + // Validation passes + handle_validate_command(&model_dir.path(), Some(Settings::default())).unwrap(); +} + +#[test] +fn test_toml_patch_and_validate() { + unsafe { std::env::set_var("MUSE2_LOG_LEVEL", "off") }; + + // Model is patched successfully + let model_dir = get_model_dir_toml_patch().unwrap(); + + // The appropriate change has been made + let model_params = ModelParameters::from_path(&model_dir).unwrap(); + assert_eq!(model_params.milestone_years, vec![2020, 2030, 2040, 2050]); + + // Validation should fail + let val = handle_validate_command(&model_dir.path(), Some(Settings::default())); + assert!(val.is_err()); +} From 5117c45e48f59a114cc66e1dbfcf2effbb5013cc Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 11:12:46 +0000 Subject: [PATCH 16/32] Pass toml patch as a string --- src/patch.rs | 12 ++++++++---- tests/patch.rs | 28 +++++++++++----------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/patch.rs b/src/patch.rs index 2e4f2ae84..1edb668ae 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -42,12 +42,16 @@ impl ModelPatch { self } - /// Add a TOML patch (TOML table) to this `ModelPatch`. - pub fn with_toml_patch(mut self, patch: toml::value::Table) -> Self { + /// Add a TOML patch (provided as a string) to this `ModelPatch`. + /// The string will be parsed into a `toml::value::Table`. + pub fn with_toml_patch(mut self, patch_str: impl AsRef) -> Self { assert!( self.toml_patch.is_none(), "TOML patch already set for this ModelPatch" ); + let s = patch_str.as_ref(); + let patch: toml::value::Table = + toml::from_str(s).expect("Failed to parse string passed to with_toml_patch"); assert!( !patch.contains_key("base_model"), "TOML patch cannot contain `base_model` field" @@ -531,7 +535,7 @@ mod tests { let content = r#" base_model = "some/base/path" foo = "bar" - "#; + "#; let mut f = fs::File::create(&p).unwrap(); f.write_all(content.as_bytes()).unwrap(); @@ -549,7 +553,7 @@ mod tests { title = "base" [section] a = 1 - "#; + "#; // Create a patch table let mut patch = toml::value::Table::new(); diff --git a/tests/patch.rs b/tests/patch.rs index 3e1ac21b9..2ee5de186 100644 --- a/tests/patch.rs +++ b/tests/patch.rs @@ -15,30 +15,24 @@ fn get_model_dir_file_patch() -> Result { let assets_patch = FilePatch::new("assets.csv") .delete_row("GASDRV,GBR,A0_GEX,4002.26,2020") .add_row("GASDRV,GBR,A0_GEX,4003.26,2020"); - let model_patch = ModelPatch::new(&base_model_dir).with_file_patch(assets_patch); - let temp_dir = model_patch.build_to_tempdir()?; - Ok(temp_dir) + ModelPatch::new(&base_model_dir) + .with_file_patch(assets_patch) + .build_to_tempdir() } -/// Patch of the "simple" model with a change to the `model.toml` file. +/// Patch of the "simple" model with a change to the `model.toml`. fn get_model_dir_toml_patch() -> Result { let base_model_dir = PathBuf::from("examples/simple"); // Add an extra milestone year (2050) - let mut toml_patch = toml::value::Table::new(); - toml_patch.insert( - "milestone_years".to_string(), - toml::Value::Array(vec![ - toml::Value::Integer(2020), - toml::Value::Integer(2030), - toml::Value::Integer(2040), - toml::Value::Integer(2050), - ]), - ); - let model_patch = ModelPatch::new(&base_model_dir).with_toml_patch(toml_patch); - let temp_dir = model_patch.build_to_tempdir()?; - Ok(temp_dir) + let toml_patch = r#" + milestone_years = [2020, 2030, 2040, 2050] + "#; + + ModelPatch::new(&base_model_dir) + .with_toml_patch(toml_patch) + .build_to_tempdir() } #[test] From 15ed820d0bce05051749b13e200bf97c3680baea Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 11:54:24 +0000 Subject: [PATCH 17/32] Add missing_commodity_patch example --- examples/missing_commodity_patch/README.txt | 6 ++++++ .../agent_commodity_portions_diff.csv | 3 +++ .../missing_commodity_patch/agent_objectives_diff.csv | 3 +++ examples/missing_commodity_patch/agents_diff.csv | 3 +++ examples/missing_commodity_patch/commodities_diff.csv | 3 +++ examples/missing_commodity_patch/model.toml | 1 + .../process_availabilities_diff.csv | 2 ++ examples/missing_commodity_patch/process_flows_diff.csv | 6 ++++++ .../missing_commodity_patch/process_parameters_diff.csv | 4 ++++ examples/missing_commodity_patch/processes_diff.csv | 4 ++++ src/input.rs | 6 ++++++ tests/regression.rs | 6 +++++- tests/regression_missing_commodity_patch.rs | 8 ++++++++ 13 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 examples/missing_commodity_patch/README.txt create mode 100644 examples/missing_commodity_patch/agent_commodity_portions_diff.csv create mode 100644 examples/missing_commodity_patch/agent_objectives_diff.csv create mode 100644 examples/missing_commodity_patch/agents_diff.csv create mode 100644 examples/missing_commodity_patch/commodities_diff.csv create mode 100644 examples/missing_commodity_patch/model.toml create mode 100644 examples/missing_commodity_patch/process_availabilities_diff.csv create mode 100644 examples/missing_commodity_patch/process_flows_diff.csv create mode 100644 examples/missing_commodity_patch/process_parameters_diff.csv create mode 100644 examples/missing_commodity_patch/processes_diff.csv create mode 100644 tests/regression_missing_commodity_patch.rs diff --git a/examples/missing_commodity_patch/README.txt b/examples/missing_commodity_patch/README.txt new file mode 100644 index 000000000..925a37a74 --- /dev/null +++ b/examples/missing_commodity_patch/README.txt @@ -0,0 +1,6 @@ +This model is like the "simple" example, except that it includes a separate biomass supply chain, +which has no capacity in the base year. This is partly done to exemplify how MUSE2 can generate +prices for commodities not utilised in a given year. + +This is identical to the "missing_commodity" example, but written as a patch with "_diff" files and +a `base_model` specified in `model.toml`. diff --git a/examples/missing_commodity_patch/agent_commodity_portions_diff.csv b/examples/missing_commodity_patch/agent_commodity_portions_diff.csv new file mode 100644 index 000000000..97f09fddb --- /dev/null +++ b/examples/missing_commodity_patch/agent_commodity_portions_diff.csv @@ -0,0 +1,3 @@ +,agent_id,commodity_id,years,commodity_portion ++,A0_BPD,BIOPRD,all,1 ++,A0_BPL,BIOPEL,all,1 diff --git a/examples/missing_commodity_patch/agent_objectives_diff.csv b/examples/missing_commodity_patch/agent_objectives_diff.csv new file mode 100644 index 000000000..7bac14c96 --- /dev/null +++ b/examples/missing_commodity_patch/agent_objectives_diff.csv @@ -0,0 +1,3 @@ +,agent_id,years,objective_type,decision_weight,decision_lexico_order ++,A0_BPD,all,lcox,, ++,A0_BPL,all,lcox,, diff --git a/examples/missing_commodity_patch/agents_diff.csv b/examples/missing_commodity_patch/agents_diff.csv new file mode 100644 index 000000000..490862b00 --- /dev/null +++ b/examples/missing_commodity_patch/agents_diff.csv @@ -0,0 +1,3 @@ +,id,description,regions,decision_rule,decision_lexico_tolerance ++,A0_BPD,Biomass producer,all,single, ++,A0_BPL,Biomass pelletiser,all,single, diff --git a/examples/missing_commodity_patch/commodities_diff.csv b/examples/missing_commodity_patch/commodities_diff.csv new file mode 100644 index 000000000..818f65c7c --- /dev/null +++ b/examples/missing_commodity_patch/commodities_diff.csv @@ -0,0 +1,3 @@ +,id,description,type,time_slice_level ++,BIOPRD,Biomass produced,sed,season ++,BIOPEL,Biomass pellets,sed,season diff --git a/examples/missing_commodity_patch/model.toml b/examples/missing_commodity_patch/model.toml new file mode 100644 index 000000000..e0a605b37 --- /dev/null +++ b/examples/missing_commodity_patch/model.toml @@ -0,0 +1 @@ +base_model = "examples/simple" diff --git a/examples/missing_commodity_patch/process_availabilities_diff.csv b/examples/missing_commodity_patch/process_availabilities_diff.csv new file mode 100644 index 000000000..57846d86a --- /dev/null +++ b/examples/missing_commodity_patch/process_availabilities_diff.csv @@ -0,0 +1,2 @@ +,process_id,regions,commission_years,time_slice,limits ++,BIOPLL,all,all,annual,..0.95 diff --git a/examples/missing_commodity_patch/process_flows_diff.csv b/examples/missing_commodity_patch/process_flows_diff.csv new file mode 100644 index 000000000..418fde52d --- /dev/null +++ b/examples/missing_commodity_patch/process_flows_diff.csv @@ -0,0 +1,6 @@ +-,process_id,commodity_id,regions,commission_years,coeff,type,cost ++,BIOPRO,BIOPRD,all,all,1.0,fixed, ++,BIOPLL,BIOPRD,all,all,-1.05,fixed, ++,BIOPLL,BIOPEL,all,all,1.0,fixed, ++,RBIOBL,BIOPEL,all,all,-1.2,fixed, ++,RBIOBL,RSHEAT,all,all,1,fixed, diff --git a/examples/missing_commodity_patch/process_parameters_diff.csv b/examples/missing_commodity_patch/process_parameters_diff.csv new file mode 100644 index 000000000..174922f1d --- /dev/null +++ b/examples/missing_commodity_patch/process_parameters_diff.csv @@ -0,0 +1,4 @@ +,process_id,regions,commission_years,capital_cost,fixed_operating_cost,variable_operating_cost,lifetime,discount_rate ++,BIOPRO,all,all,1.0,0.2,0.25,20,0.09 ++,BIOPLL,all,all,2.0,0.22,0.26,20,0.1 ++,RBIOBL,all,all,60,1.05,0.2,20,0.1 diff --git a/examples/missing_commodity_patch/processes_diff.csv b/examples/missing_commodity_patch/processes_diff.csv new file mode 100644 index 000000000..8df1d14d1 --- /dev/null +++ b/examples/missing_commodity_patch/processes_diff.csv @@ -0,0 +1,4 @@ +,id,description,regions,primary_output,start_year,end_year,capacity_to_activity ++,BIOPRO,Biomass production,all,BIOPRD,2020,2040,1.0 ++,BIOPLL,Biomass pelletiser,all,BIOPEL,2020,2040,1.0 ++,RBIOBL,Biomass boiler,all,RSHEAT,2020,2040,1.0 diff --git a/src/input.rs b/src/input.rs index c324c82a9..6417268a1 100644 --- a/src/input.rs +++ b/src/input.rs @@ -12,6 +12,7 @@ use anyhow::{Context, Result, bail, ensure}; use float_cmp::approx_eq; use indexmap::IndexMap; use itertools::Itertools; +use log::info; use serde::de::{Deserialize, DeserializeOwned, Deserializer}; use std::collections::HashMap; use std::fmt::{self, Write}; @@ -233,8 +234,13 @@ pub fn load_model>(model_dir: P) -> Result<(Model, AssetPool)> { // If `model_params` specifies a `base_dir`, patch the base model to a temporary directory and // load the patched model if model_params.base_model.is_some() { + info!("Patching base model specified in model.toml"); let patch = ModelPatch::from_path(model_dir.as_ref())?; let temp = patch.build_to_tempdir()?; + info!( + "Base model patched to temporary directory at {}", + temp.path().display() + ); return load_model(temp.path()); } diff --git a/tests/regression.rs b/tests/regression.rs index 3bb2a9453..964a77d89 100644 --- a/tests/regression.rs +++ b/tests/regression.rs @@ -47,7 +47,11 @@ fn run_regression_test_debug_opt(example_name: &str, debug_model: bool) { }; handle_example_run_command(example_name, &opts, Some(Settings::default())).unwrap(); - let test_data_dir = PathBuf::from(format!("tests/data/{example_name}")); + // If example_name ends with "_patch", map to the base example name. + let test_data_dir = PathBuf::from(format!( + "tests/data/{}", + example_name.strip_suffix("_patch").unwrap_or(example_name) + )); compare_output_dirs(&output_dir, &test_data_dir, debug_model); } diff --git a/tests/regression_missing_commodity_patch.rs b/tests/regression_missing_commodity_patch.rs new file mode 100644 index 000000000..01e31e5d3 --- /dev/null +++ b/tests/regression_missing_commodity_patch.rs @@ -0,0 +1,8 @@ +//! A regression test for the "missing_commodity_patch" example +mod regression; +use regression::run_regression_test; + +#[test] +fn test_regression_missing_commodity_patch() { + run_regression_test("missing_commodity_patch") +} From 13374a0df32d17ba07556202b394f655fb7801cd Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 11:55:45 +0000 Subject: [PATCH 18/32] Fix error in missing_commodity_patch --- examples/missing_commodity_patch/process_flows_diff.csv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/missing_commodity_patch/process_flows_diff.csv b/examples/missing_commodity_patch/process_flows_diff.csv index 418fde52d..a98e623bf 100644 --- a/examples/missing_commodity_patch/process_flows_diff.csv +++ b/examples/missing_commodity_patch/process_flows_diff.csv @@ -1,4 +1,4 @@ --,process_id,commodity_id,regions,commission_years,coeff,type,cost +,process_id,commodity_id,regions,commission_years,coeff,type,cost +,BIOPRO,BIOPRD,all,all,1.0,fixed, +,BIOPLL,BIOPRD,all,all,-1.05,fixed, +,BIOPLL,BIOPEL,all,all,1.0,fixed, From 47d4496a5f96082bc49b7852690c7eeef1a33fd0 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 12:07:09 +0000 Subject: [PATCH 19/32] Remove diff files interface --- examples/missing_commodity_patch/README.txt | 6 - .../agent_commodity_portions_diff.csv | 3 - .../agent_objectives_diff.csv | 3 - .../missing_commodity_patch/agents_diff.csv | 3 - .../commodities_diff.csv | 3 - examples/missing_commodity_patch/model.toml | 1 - .../process_availabilities_diff.csv | 2 - .../process_flows_diff.csv | 6 - .../process_parameters_diff.csv | 4 - .../processes_diff.csv | 4 - src/input.rs | 15 -- src/model/parameters.rs | 15 -- src/patch.rs | 206 ------------------ tests/regression.rs | 6 +- tests/regression_missing_commodity_patch.rs | 8 - 15 files changed, 1 insertion(+), 284 deletions(-) delete mode 100644 examples/missing_commodity_patch/README.txt delete mode 100644 examples/missing_commodity_patch/agent_commodity_portions_diff.csv delete mode 100644 examples/missing_commodity_patch/agent_objectives_diff.csv delete mode 100644 examples/missing_commodity_patch/agents_diff.csv delete mode 100644 examples/missing_commodity_patch/commodities_diff.csv delete mode 100644 examples/missing_commodity_patch/model.toml delete mode 100644 examples/missing_commodity_patch/process_availabilities_diff.csv delete mode 100644 examples/missing_commodity_patch/process_flows_diff.csv delete mode 100644 examples/missing_commodity_patch/process_parameters_diff.csv delete mode 100644 examples/missing_commodity_patch/processes_diff.csv delete mode 100644 tests/regression_missing_commodity_patch.rs diff --git a/examples/missing_commodity_patch/README.txt b/examples/missing_commodity_patch/README.txt deleted file mode 100644 index 925a37a74..000000000 --- a/examples/missing_commodity_patch/README.txt +++ /dev/null @@ -1,6 +0,0 @@ -This model is like the "simple" example, except that it includes a separate biomass supply chain, -which has no capacity in the base year. This is partly done to exemplify how MUSE2 can generate -prices for commodities not utilised in a given year. - -This is identical to the "missing_commodity" example, but written as a patch with "_diff" files and -a `base_model` specified in `model.toml`. diff --git a/examples/missing_commodity_patch/agent_commodity_portions_diff.csv b/examples/missing_commodity_patch/agent_commodity_portions_diff.csv deleted file mode 100644 index 97f09fddb..000000000 --- a/examples/missing_commodity_patch/agent_commodity_portions_diff.csv +++ /dev/null @@ -1,3 +0,0 @@ -,agent_id,commodity_id,years,commodity_portion -+,A0_BPD,BIOPRD,all,1 -+,A0_BPL,BIOPEL,all,1 diff --git a/examples/missing_commodity_patch/agent_objectives_diff.csv b/examples/missing_commodity_patch/agent_objectives_diff.csv deleted file mode 100644 index 7bac14c96..000000000 --- a/examples/missing_commodity_patch/agent_objectives_diff.csv +++ /dev/null @@ -1,3 +0,0 @@ -,agent_id,years,objective_type,decision_weight,decision_lexico_order -+,A0_BPD,all,lcox,, -+,A0_BPL,all,lcox,, diff --git a/examples/missing_commodity_patch/agents_diff.csv b/examples/missing_commodity_patch/agents_diff.csv deleted file mode 100644 index 490862b00..000000000 --- a/examples/missing_commodity_patch/agents_diff.csv +++ /dev/null @@ -1,3 +0,0 @@ -,id,description,regions,decision_rule,decision_lexico_tolerance -+,A0_BPD,Biomass producer,all,single, -+,A0_BPL,Biomass pelletiser,all,single, diff --git a/examples/missing_commodity_patch/commodities_diff.csv b/examples/missing_commodity_patch/commodities_diff.csv deleted file mode 100644 index 818f65c7c..000000000 --- a/examples/missing_commodity_patch/commodities_diff.csv +++ /dev/null @@ -1,3 +0,0 @@ -,id,description,type,time_slice_level -+,BIOPRD,Biomass produced,sed,season -+,BIOPEL,Biomass pellets,sed,season diff --git a/examples/missing_commodity_patch/model.toml b/examples/missing_commodity_patch/model.toml deleted file mode 100644 index e0a605b37..000000000 --- a/examples/missing_commodity_patch/model.toml +++ /dev/null @@ -1 +0,0 @@ -base_model = "examples/simple" diff --git a/examples/missing_commodity_patch/process_availabilities_diff.csv b/examples/missing_commodity_patch/process_availabilities_diff.csv deleted file mode 100644 index 57846d86a..000000000 --- a/examples/missing_commodity_patch/process_availabilities_diff.csv +++ /dev/null @@ -1,2 +0,0 @@ -,process_id,regions,commission_years,time_slice,limits -+,BIOPLL,all,all,annual,..0.95 diff --git a/examples/missing_commodity_patch/process_flows_diff.csv b/examples/missing_commodity_patch/process_flows_diff.csv deleted file mode 100644 index a98e623bf..000000000 --- a/examples/missing_commodity_patch/process_flows_diff.csv +++ /dev/null @@ -1,6 +0,0 @@ -,process_id,commodity_id,regions,commission_years,coeff,type,cost -+,BIOPRO,BIOPRD,all,all,1.0,fixed, -+,BIOPLL,BIOPRD,all,all,-1.05,fixed, -+,BIOPLL,BIOPEL,all,all,1.0,fixed, -+,RBIOBL,BIOPEL,all,all,-1.2,fixed, -+,RBIOBL,RSHEAT,all,all,1,fixed, diff --git a/examples/missing_commodity_patch/process_parameters_diff.csv b/examples/missing_commodity_patch/process_parameters_diff.csv deleted file mode 100644 index 174922f1d..000000000 --- a/examples/missing_commodity_patch/process_parameters_diff.csv +++ /dev/null @@ -1,4 +0,0 @@ -,process_id,regions,commission_years,capital_cost,fixed_operating_cost,variable_operating_cost,lifetime,discount_rate -+,BIOPRO,all,all,1.0,0.2,0.25,20,0.09 -+,BIOPLL,all,all,2.0,0.22,0.26,20,0.1 -+,RBIOBL,all,all,60,1.05,0.2,20,0.1 diff --git a/examples/missing_commodity_patch/processes_diff.csv b/examples/missing_commodity_patch/processes_diff.csv deleted file mode 100644 index 8df1d14d1..000000000 --- a/examples/missing_commodity_patch/processes_diff.csv +++ /dev/null @@ -1,4 +0,0 @@ -,id,description,regions,primary_output,start_year,end_year,capacity_to_activity -+,BIOPRO,Biomass production,all,BIOPRD,2020,2040,1.0 -+,BIOPLL,Biomass pelletiser,all,BIOPEL,2020,2040,1.0 -+,RBIOBL,Biomass boiler,all,RSHEAT,2020,2040,1.0 diff --git a/src/input.rs b/src/input.rs index 6417268a1..08358facc 100644 --- a/src/input.rs +++ b/src/input.rs @@ -5,14 +5,12 @@ use crate::graph::validate::validate_commodity_graphs_for_model; use crate::graph::{CommoditiesGraph, build_commodity_graphs_for_model}; use crate::id::{HasID, IDLike}; use crate::model::{Model, ModelParameters}; -use crate::patch::ModelPatch; use crate::region::RegionID; use crate::units::UnitType; use anyhow::{Context, Result, bail, ensure}; use float_cmp::approx_eq; use indexmap::IndexMap; use itertools::Itertools; -use log::info; use serde::de::{Deserialize, DeserializeOwned, Deserializer}; use std::collections::HashMap; use std::fmt::{self, Write}; @@ -231,19 +229,6 @@ where pub fn load_model>(model_dir: P) -> Result<(Model, AssetPool)> { let model_params = ModelParameters::from_path(&model_dir)?; - // If `model_params` specifies a `base_dir`, patch the base model to a temporary directory and - // load the patched model - if model_params.base_model.is_some() { - info!("Patching base model specified in model.toml"); - let patch = ModelPatch::from_path(model_dir.as_ref())?; - let temp = patch.build_to_tempdir()?; - info!( - "Base model patched to temporary directory at {}", - temp.path().display() - ); - return load_model(temp.path()); - } - let time_slice_info = read_time_slice_info(model_dir.as_ref())?; let regions = read_regions(model_dir.as_ref())?; let region_ids = regions.keys().cloned().collect(); diff --git a/src/model/parameters.rs b/src/model/parameters.rs index 862434fdd..0311a5c60 100644 --- a/src/model/parameters.rs +++ b/src/model/parameters.rs @@ -58,7 +58,6 @@ define_param_default!(default_mothball_years, u32, 0); #[derive(Debug, Deserialize, PartialEq)] pub struct ModelParameters { /// Milestone years - #[serde(default)] pub milestone_years: Vec, /// Allow known-broken options to be enabled. #[serde(default, rename = "please_give_me_broken_results")] // Can't use constant here :-( @@ -98,10 +97,6 @@ pub struct ModelParameters { /// Number of years an asset can remain unused before being decommissioned #[serde(default = "default_mothball_years")] pub mothball_years: u32, - /// Optional base model directory to use as a starting point, with this model's files applied - /// as patches/diffs. - #[serde(default)] - pub base_model: Option, } /// The strategy used for calculating commodity prices @@ -179,16 +174,6 @@ impl ModelParameters { let file_path = model_dir.as_ref().join(MODEL_PARAMETERS_FILE_NAME); let model_params: ModelParameters = read_toml(&file_path)?; - // If `base_model` is specified, just check that it exists and skip further validation - // as we will do this later on the fully patched model. - if let Some(base_model_path) = &model_params.base_model { - ensure!( - Path::new(base_model_path).is_dir(), - "`base_model` directory not found: {base_model_path}", - ); - return Ok(model_params); - } - // Set flag signalling whether broken model options are allowed or not BROKEN_OPTIONS_ALLOWED .set(model_params.allow_broken_options) diff --git a/src/patch.rs b/src/patch.rs index 1edb668ae..ce7144d10 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -60,39 +60,6 @@ impl ModelPatch { self } - /// Build a `ModelPatch` from a diffs directory. Expects `model.toml` to be present in the - /// diffs directory and to contain a `base_model` string field that points to the base - /// model directory. Also collects all `*_diff.csv` files in the diffs directory into - /// `FilePatch` entries, and any other top-level fields in `model.toml` become the - /// `toml_patch`. - pub fn from_path(diffs_dir: &Path) -> Result { - // Parse patch model.toml and extract base_model + toml patch - let (base_model_dir, toml_patch) = read_patch_toml(&diffs_dir.join("model.toml"))?; - - // Collect all file patches from `*_diff.csv` files in diffs directory - let mut file_patches = Vec::new(); - for entry in fs::read_dir(diffs_dir)? { - let entry = entry?; - let p = entry.path(); - if !p.is_file() { - continue; - } - if let Some(name) = p.file_name().and_then(|n| n.to_str()) - && name.to_lowercase().ends_with("_diff.csv") - { - let fp = FilePatch::from_file(&p) - .with_context(|| format!("Failed to read diff file: {}", p.display()))?; - file_patches.push(fp); - } - } - - Ok(ModelPatch { - base_model_dir, - file_patches, - toml_patch: Some(toml_patch), - }) - } - /// Apply this `ModelPatch` into `out_dir` (creating/overwriting files there). fn build>(&self, out_dir: O) -> Result<()> { let base_dir = self.base_model_dir.as_path(); @@ -195,92 +162,6 @@ impl FilePatch { self } - /// Read a diff file and construct a `Patch`. - pub fn from_file(file_path: &Path) -> Result { - // Extract the base filename by removing the `_diff.csv` suffix - let file_name = file_path - .file_name() - .and_then(|n| n.to_str()) - .context("Invalid filename encoding")?; - ensure!( - file_name.to_lowercase().ends_with("_diff.csv"), - "Diff file must end with '_diff.csv': {file_name}" - ); - let base_name = &file_name[..file_name.len() - "_diff.csv".len()]; - let base_filename = format!("{base_name}.csv"); - - // Read diff CSV file - let mut reader = ReaderBuilder::new() - .trim(Trim::All) - .from_path(file_path) - .with_context(|| input_err_msg(file_path))?; - - // Read header - let diff_header = reader - .headers() - .with_context(|| input_err_msg(Path::new(&base_filename)))?; - ensure!(!diff_header.is_empty(), "Diff file header cannot be empty"); - - // Collect column headers (skip first column which is the diff indicator) - let headers_vec: Vec = diff_header - .iter() - .skip(1) - .map(ToString::to_string) - .collect(); - ensure!( - !headers_vec.is_empty(), - "Diff file must have at least one data column" - ); - - // Collect additions and deletions - let mut to_delete: IndexSet> = IndexSet::new(); - let mut to_add: IndexSet> = IndexSet::new(); - for (line_num, result) in reader.records().enumerate() { - let record = result.with_context(|| { - format!("Error reading record at line {} in diff file", line_num + 2) - })?; - - ensure!( - !record.is_empty(), - "Empty row at line {} in diff file", - line_num + 2 - ); - - // First column is the diff indicator - let diff_indicator = record.get(0).context("Missing diff indicator column")?; - - // Build normalized row vector from the csv record - let row_vec = canonicalize_vec(record.iter().skip(1)); - let row_str = canonicalize_fields(&row_vec); - - match diff_indicator { - "-" => { - ensure!( - to_delete.insert(row_vec.clone()), - "Duplicate deletion entry: {row_str}", - ); - } - "+" => { - ensure!( - to_add.insert(row_vec.clone()), - "Duplicate addition entry: {row_str}", - ); - } - _ => { - bail!("Invalid diff indicator: '{diff_indicator}'. Must be '+' or '-'"); - } - } - } - - // Create Patch object - Ok(FilePatch { - base_filename, - header_row: Some(headers_vec), - to_delete, - to_add, - }) - } - /// Apply this patch to a base model and return the modified CSV as a string. fn apply(&self, base_model_dir: &Path) -> Result { // Read the base file to string @@ -308,23 +189,6 @@ impl FilePatch { } } -/// Read a patch `model.toml` file and return the `base_model` path and remaining -/// table as the toml patch. -fn read_patch_toml(toml_path: &Path) -> Result<(PathBuf, toml::value::Table)> { - let s = fs::read_to_string(toml_path)?; - let val: toml::Value = toml::from_str(&s)?; - match val { - toml::Value::Table(mut tbl) => { - let base = tbl - .remove("base_model") - .and_then(|v| v.as_str().map(PathBuf::from)) - .context("Patch model.toml missing required `base_model` field")?; - Ok((base, tbl)) - } - _ => bail!("Patch TOML must be a table"), - } -} - /// Merge a TOML patch into a base model TOML string and return the merged TOML. fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result { if patch.contains_key("base_model") { @@ -444,56 +308,6 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { mod tests { use super::*; use crate::fixture::assert_error; - use std::io::Write; - use tempfile::tempdir; - - #[test] - fn test_patch_from_file() { - // Create diff file - let temp_dir = tempdir().unwrap(); - let diff_file = temp_dir.path().join("test_diff.csv"); - let content = "diff,col1,col2\n-,val1,val2\n+,val3,val4\n"; - let mut file = fs::File::create(&diff_file).unwrap(); - file.write_all(content.as_bytes()).unwrap(); - - // Parse from the file - let patch = FilePatch::from_file(&diff_file).unwrap(); - - assert_eq!( - patch.header_row.as_ref().map(|v| v.join(",")), - Some("col1,col2".to_string()) - ); - assert_eq!(patch.to_delete.len(), 1); - assert_eq!(patch.to_add.len(), 1); - let del_row = vec!["val1".to_string(), "val2".to_string()]; - let add_row = vec!["val3".to_string(), "val4".to_string()]; - assert!(patch.to_delete.contains(&del_row)); - assert!(patch.to_add.contains(&add_row)); - } - - #[test] - fn test_patch_from_file_whitespace() { - // Create diff file with extra whitespace - let temp_dir = tempdir().unwrap(); - let diff_file = temp_dir.path().join("test_diff.csv"); - let content = " diff , col1 , col2 \n-, item1 , item2 \n+, another1 , another2 \n"; - let mut file = fs::File::create(&diff_file).unwrap(); - file.write_all(content.as_bytes()).unwrap(); - - // Parse from the file - let patch = FilePatch::from_file(&diff_file).unwrap(); - - assert_eq!( - patch.header_row.as_ref().map(|v| v.join(",")), - Some("col1,col2".to_string()) - ); - assert_eq!(patch.to_delete.len(), 1); - assert_eq!(patch.to_add.len(), 1); - let del_row = vec!["item1".to_string(), "item2".to_string()]; - let add_row = vec!["another1".to_string(), "another2".to_string()]; - assert!(patch.to_delete.contains(&del_row)); - assert!(patch.to_add.contains(&add_row)); - } #[test] fn test_modify_base_with_patch() { @@ -527,26 +341,6 @@ mod tests { ); } - #[test] - fn test_read_patch_toml() { - // Create patch TOML file - let td = tempdir().unwrap(); - let p = td.path().join("model.toml"); - let content = r#" - base_model = "some/base/path" - foo = "bar" - "#; - let mut f = fs::File::create(&p).unwrap(); - f.write_all(content.as_bytes()).unwrap(); - - // Read with `read_patch_toml` - // Should return the base_model path and remaining table (without the base_model field) - let (base, tbl) = read_patch_toml(&p).unwrap(); - assert_eq!(base, PathBuf::from("some/base/path")); - assert_eq!(tbl.get("foo").and_then(|v| v.as_str()), Some("bar")); - assert!(!tbl.contains_key("base_model")); - } - #[test] fn test_merge_model_toml_basic() { let base = r#" diff --git a/tests/regression.rs b/tests/regression.rs index 964a77d89..3bb2a9453 100644 --- a/tests/regression.rs +++ b/tests/regression.rs @@ -47,11 +47,7 @@ fn run_regression_test_debug_opt(example_name: &str, debug_model: bool) { }; handle_example_run_command(example_name, &opts, Some(Settings::default())).unwrap(); - // If example_name ends with "_patch", map to the base example name. - let test_data_dir = PathBuf::from(format!( - "tests/data/{}", - example_name.strip_suffix("_patch").unwrap_or(example_name) - )); + let test_data_dir = PathBuf::from(format!("tests/data/{example_name}")); compare_output_dirs(&output_dir, &test_data_dir, debug_model); } diff --git a/tests/regression_missing_commodity_patch.rs b/tests/regression_missing_commodity_patch.rs deleted file mode 100644 index 01e31e5d3..000000000 --- a/tests/regression_missing_commodity_patch.rs +++ /dev/null @@ -1,8 +0,0 @@ -//! A regression test for the "missing_commodity_patch" example -mod regression; -use regression::run_regression_test; - -#[test] -fn test_regression_missing_commodity_patch() { - run_regression_test("missing_commodity_patch") -} From bdc34a6662b66562772c0140268c2e79d26c9b2b Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 13:27:35 +0000 Subject: [PATCH 20/32] Fix test_toml_patch_and_validate --- src/patch.rs | 11 +++++------ tests/patch.rs | 7 ++++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/patch.rs b/src/patch.rs index ce7144d10..f4139fa96 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -1,4 +1,4 @@ -//! Code for applying patches/diffs to model input files. +//! Code for applying patches to model input files. use crate::input::input_err_msg; use anyhow::{Context, Result, bail, ensure}; @@ -113,7 +113,7 @@ impl ModelPatch { } } -/// Structure to hold diffs from a diff file +/// Structure to hold patches to a model csv file. #[derive(Debug)] pub struct FilePatch { /// The target base filename that this patch applies to (e.g. "agents.csv") @@ -234,18 +234,17 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { .trim(Trim::All) .from_reader(base.as_bytes()); - // Read and validate header + // Read header let base_header = reader .headers() .context("Failed to read base file header")?; - let base_header_vec: Vec = base_header.iter().map(ToString::to_string).collect(); // If the patch contains a header, compare it with the base file header. if let Some(ref header_row_vec) = patch.header_row { ensure!( base_header_vec == *header_row_vec, - "Header mismatch: base file has [{}], diff file expects [{}]", + "Header mismatch: base file has [{}], patch has [{}]", base_header_vec.join(", "), header_row_vec.join(", ") ); @@ -337,7 +336,7 @@ mod tests { let result = modify_base_with_patch(base, &patch); assert_error!( result, - "Header mismatch: base file has [col1, col2], diff file expects [col1, col3]" + "Header mismatch: base file has [col1, col2], patch has [col1, col3]" ); } diff --git a/tests/patch.rs b/tests/patch.rs index 2ee5de186..5978ed3fb 100644 --- a/tests/patch.rs +++ b/tests/patch.rs @@ -1,6 +1,7 @@ -//! Integration tests for the `validate` command. +//! Integration tests for the `patch` module. use anyhow::Result; use muse2::cli::handle_validate_command; +use muse2::input::read_toml; use muse2::model::ModelParameters; use muse2::patch::{FilePatch, ModelPatch}; use muse2::settings::Settings; @@ -60,8 +61,8 @@ fn test_toml_patch_and_validate() { let model_dir = get_model_dir_toml_patch().unwrap(); // The appropriate change has been made - let model_params = ModelParameters::from_path(&model_dir).unwrap(); - assert_eq!(model_params.milestone_years, vec![2020, 2030, 2040, 2050]); + let toml: ModelParameters = read_toml(&model_dir.path().join("model.toml")).unwrap(); + assert_eq!(toml.milestone_years, vec![2020, 2030, 2040, 2050]); // Validation should fail let val = handle_validate_command(&model_dir.path(), Some(Settings::default())); From 4be1f4f6796998c1b001b44495c0c9385d2d5b55 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 13:52:36 +0000 Subject: [PATCH 21/32] Simplify error handling a bit --- src/patch.rs | 103 ++++++++++++++++++++------------------------------- 1 file changed, 41 insertions(+), 62 deletions(-) diff --git a/src/patch.rs b/src/patch.rs index f4139fa96..9a2fece75 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -1,13 +1,11 @@ //! Code for applying patches to model input files. -use crate::input::input_err_msg; - -use anyhow::{Context, Result, bail, ensure}; +use anyhow::{Context, Result, ensure}; use csv::{ReaderBuilder, Trim, Writer}; use indexmap::IndexSet; use std::fs; use std::path::{Path, PathBuf}; -/// Structure to hold a set of patches to apply to a base model. +/// Struct to hold a set of patches to apply to a base model. pub struct ModelPatch { // The base model directory path base_model_dir: PathBuf, @@ -54,13 +52,13 @@ impl ModelPatch { toml::from_str(s).expect("Failed to parse string passed to with_toml_patch"); assert!( !patch.contains_key("base_model"), - "TOML patch cannot contain `base_model` field" + "TOML patch must not contain a `base_model` field" ); self.toml_patch = Some(patch); self } - /// Apply this `ModelPatch` into `out_dir` (creating/overwriting files there). + /// Build this `ModelPatch` into `out_dir` (creating/overwriting files there). fn build>(&self, out_dir: O) -> Result<()> { let base_dir = self.base_model_dir.as_path(); let out_path = out_dir.as_ref(); @@ -77,29 +75,24 @@ impl ModelPatch { .is_some_and(|ext| ext.eq_ignore_ascii_case("csv")) { let dst_path = out_path.join(entry.file_name()); - fs::copy(&src_path, &dst_path) - .with_context(|| format!("Failed to copy file: {}", src_path.display()))?; + fs::copy(&src_path, &dst_path)?; } } - // Apply toml patch (if any), or copy model.toml from the base model + // Apply toml patch (if any), or copy model.toml unchanged from the base model let base_toml_path = base_dir.join("model.toml"); let out_toml_path = out_path.join("model.toml"); if let Some(toml_patch) = &self.toml_patch { - // Start with model.toml from base model and merge via helper let toml_content = fs::read_to_string(&base_toml_path)?; let merged_toml = merge_model_toml(&toml_content, toml_patch)?; fs::write(&out_toml_path, merged_toml)?; } else { - // No toml patch; copy base model.toml fs::copy(&base_toml_path, &out_toml_path)?; } // Apply file patches for patch in &self.file_patches { - patch.apply_and_save(base_dir, out_path).with_context(|| { - format!("Failed to apply patch to file: {}", patch.base_filename) - })?; + patch.apply_and_save(base_dir, out_path)?; } Ok(()) @@ -138,15 +131,19 @@ impl FilePatch { } } - /// Set the header row for this patch (`header` should be a comma-joined string, e.g. "a,b,c"). + /// Set the header row for this patch (header should be a comma-joined string, e.g. "a,b,c"). pub fn with_header(mut self, header: impl Into) -> Self { + assert!( + self.header_row.is_none(), + "Header already set for this FilePatch", + ); let s = header.into(); let v = s.split(',').map(|s| s.trim().to_string()).collect(); self.header_row = Some(v); self } - /// Add a row to the patch (row is a canonical comma-joined string, e.g. "a,b,c"). + /// Add a row to the patch (row should be a comma-joined string, e.g. "a,b,c"). pub fn add_row(mut self, row: impl Into) -> Self { let s = row.into(); let v = s.split(',').map(|s| s.trim().to_string()).collect(); @@ -154,7 +151,7 @@ impl FilePatch { self } - /// Mark a row for deletion from the base (row is a canonical comma-joined string). + /// Mark a row for deletion from the base (row should be a comma-joined string, e.g. "a,b,c"). pub fn delete_row(mut self, row: impl Into) -> Self { let s = row.into(); let v = s.split(',').map(|s| s.trim().to_string()).collect(); @@ -166,16 +163,16 @@ impl FilePatch { fn apply(&self, base_model_dir: &Path) -> Result { // Read the base file to string let base_path = base_model_dir.join(&self.base_filename); - if !base_path.exists() { - bail!( - "Base file for patching does not exist: {}", - base_path.display() - ); - } - let base = fs::read_to_string(&base_path).with_context(|| input_err_msg(&base_path))?; + ensure!( + base_path.exists() && base_path.is_file(), + "Base file for patching does not exist: {}", + base_path.display() + ); + let base = fs::read_to_string(&base_path)?; // Apply the patch - let modified = modify_base_with_patch(&base, self)?; + let modified = modify_base_with_patch(&base, self) + .with_context(|| format!("Error applying patch to file: {}", self.base_filename))?; Ok(modified) } @@ -183,49 +180,34 @@ impl FilePatch { pub fn apply_and_save(&self, base_model_dir: &Path, out_model_dir: &Path) -> Result<()> { let modified = self.apply(base_model_dir)?; let new_path = out_model_dir.join(&self.base_filename); - fs::write(&new_path, modified) - .with_context(|| format!("Failed to write patched file: {}", new_path.display()))?; + fs::write(&new_path, modified)?; Ok(()) } } /// Merge a TOML patch into a base model TOML string and return the merged TOML. fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result { - if patch.contains_key("base_model") { - bail!("TOML patch cannot contain `base_model` field"); - } + assert!( + !patch.contains_key("base_model"), + "TOML patch cannot contain a `base_model` field" + ); + + // Parse base TOML into a table let mut base_val: toml::Value = toml::from_str(base_toml)?; let base_tbl = base_val .as_table_mut() .context("Base model TOML must be a table")?; + + // Apply patch entries for (k, v) in patch { base_tbl.insert(k.clone(), v.clone()); } + + // Serialize merged TOML back to string let out = toml::to_string_pretty(&base_val)?; Ok(out) } -/// Build a canonical comma-joined string from an iterator of field strings. -fn canonicalize_fields(fields: I) -> String -where - I: IntoIterator, - S: AsRef, -{ - fields - .into_iter() - .map(|s| s.as_ref().trim().to_string()) - .collect::>() - .join(",") -} - -/// Build a canonical vector of trimmed strings from an iterator of field strings. -fn canonicalize_vec<'a, I>(fields: I) -> Vec -where - I: IntoIterator, -{ - fields.into_iter().map(|s| s.trim().to_string()).collect() -} - /// Modify a string representation of a base CSV file by applying a `FilePatch`. /// Preserves the order of rows from the base file, with new rows appended at the end. fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { @@ -252,21 +234,19 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { // Read all rows from base file, preserving order and checking for duplicates let mut base_rows: IndexSet> = IndexSet::new(); - for (line_num, result) in reader.records().enumerate() { - let record = result.with_context(|| { - format!("Error reading record at line {} in base file", line_num + 2) - })?; + for result in reader.records() { + let record = result?; // Create normalized row vector by trimming fields - let row_vec = canonicalize_vec(record.iter()); - let row_str = canonicalize_fields(&row_vec); + let row_vec = record + .iter() + .map(|s| s.trim().to_string()) + .collect::>(); // Check for duplicates ensure!( base_rows.insert(row_vec.clone()), - "Duplicate row at line {} in base file: {}", - line_num + 2, - row_str + "Duplicate row in base file: {row_vec:?}", ); } @@ -274,8 +254,7 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { for del_row in &patch.to_delete { ensure!( !patch.to_add.contains(del_row), - "Row appears in both deletions and additions: {}", - canonicalize_fields(del_row) + "Row appears in both deletions and additions: {del_row:?}", ); } From d3165a1ff0809682bd3e3de5dfae437cbc540658 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 14:08:44 +0000 Subject: [PATCH 22/32] Fix tests --- src/patch.rs | 47 +++++++++++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/patch.rs b/src/patch.rs index 9a2fece75..b6b46992a 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -187,7 +187,7 @@ impl FilePatch { /// Merge a TOML patch into a base model TOML string and return the merged TOML. fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result { - assert!( + ensure!( !patch.contains_key("base_model"), "TOML patch cannot contain a `base_model` field" ); @@ -289,32 +289,33 @@ mod tests { #[test] fn test_modify_base_with_patch() { - let base = "col1,col2\nrow1,row2\nrow3,row4\nrow5,row6\n"; + let base = "col1,col2\nvalue1,value2\nvalue3,value4\nvalue5,value6\n"; + // Create a patch to delete row3,row4 and add row7,row8 let patch = FilePatch::new("test.csv") .with_header("col1,col2") - .delete_row("row3,row4") - .add_row("row7,row8"); + .delete_row("value3,value4") + .add_row("value7,value8"); let modified = modify_base_with_patch(base, &patch).unwrap(); - // Should preserve order: row1,row2 -> row5,row6 -> row7,row8 let lines: Vec<&str> = modified.lines().collect(); - assert_eq!(lines[0], "col1,col2"); - assert_eq!(lines[1], "row1,row2"); - assert_eq!(lines[2], "row5,row6"); - assert_eq!(lines[3], "row7,row8"); - assert!(!modified.contains("row3,row4")); + assert_eq!(lines[0], "col1,col2"); // header is present + assert_eq!(lines[1], "value1,value2"); // unchanged row + assert_eq!(lines[2], "value5,value6"); // unchanged row + assert_eq!(lines[3], "value7,value8"); // added row + assert!(!modified.contains("value3,value4")); // deleted row } #[test] fn test_modify_base_with_patch_mismatched_header() { - let base = "col1,col2\nrow1,row2\n"; + let base = "col1,col2\nvalue1,value2\n"; + + // Create a patch with a mismatched header let patch = FilePatch::new("test.csv").with_header("col1,col3"); - let result = modify_base_with_patch(base, &patch); assert_error!( - result, + modify_base_with_patch(base, &patch), "Header mismatch: base file has [col1, col2], patch has [col1, col3]" ); } @@ -322,15 +323,15 @@ mod tests { #[test] fn test_merge_model_toml_basic() { let base = r#" - title = "base" + field = "data" [section] a = 1 "#; - // Create a patch table + // Create a TOML patch let mut patch = toml::value::Table::new(); patch.insert( - "title".to_string(), + "field".to_string(), toml::Value::String("patched".to_string()), ); patch.insert( @@ -339,18 +340,18 @@ mod tests { ); // Apply patch with `merge_model_toml` - // Should overwrite title and add new_field, but keep section.a + // Should overwrite field and add new_field, but keep section.a let merged = merge_model_toml(base, &patch).unwrap(); - assert!(merged.contains("title = \"patched\"")); + assert!(merged.contains("field = \"patched\"")); assert!(merged.contains("[section]")); assert!(merged.contains("new_field = \"added\"")); } #[test] fn test_merge_rejects_base_model_key() { - let base = r#"title = "base""#; + let base = r#"field = "data""#; - // Create a patch table with a base_model key + // Create a TOML patch with a base_model key let mut patch = toml::value::Table::new(); patch.insert( "base_model".to_string(), @@ -358,7 +359,9 @@ mod tests { ); // `merge_model_toml` should return an error - let res = merge_model_toml(base, &patch); - assert!(res.is_err()); + assert_error!( + merge_model_toml(base, &patch), + "TOML patch cannot contain a `base_model` field" + ); } } From 84ccf2e9c53136e23b66e1b453252d8576f58a7c Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 15:58:30 +0000 Subject: [PATCH 23/32] Small fixes --- src/patch.rs | 69 +++++++++++++++++++++++++++++----------------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/src/patch.rs b/src/patch.rs index b6b46992a..fc58b5675 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -16,7 +16,7 @@ pub struct ModelPatch { } impl ModelPatch { - /// Create a new empty `ModelPatch` with the given base model directory. + /// Create a new empty `ModelPatch` for a base model at the given directory. pub fn new>(base_model_dir: P) -> Self { ModelPatch { base_model_dir: base_model_dir.into(), @@ -63,6 +63,17 @@ impl ModelPatch { let base_dir = self.base_model_dir.as_path(); let out_path = out_dir.as_ref(); + // Apply toml patch (if any), or copy model.toml unchanged from the base model + let base_toml_path = base_dir.join("model.toml"); + let out_toml_path = out_path.join("model.toml"); + if let Some(toml_patch) = &self.toml_patch { + let toml_content = fs::read_to_string(&base_toml_path)?; + let merged_toml = merge_model_toml(&toml_content, toml_patch)?; + fs::write(&out_toml_path, merged_toml)?; + } else { + fs::copy(&base_toml_path, &out_toml_path)?; + } + // Copy all CSV files from the base model into the output directory // Any files with associated patches will be overwritten later for entry in fs::read_dir(base_dir)? { @@ -79,17 +90,6 @@ impl ModelPatch { } } - // Apply toml patch (if any), or copy model.toml unchanged from the base model - let base_toml_path = base_dir.join("model.toml"); - let out_toml_path = out_path.join("model.toml"); - if let Some(toml_patch) = &self.toml_patch { - let toml_content = fs::read_to_string(&base_toml_path)?; - let merged_toml = merge_model_toml(&toml_content, toml_patch)?; - fs::write(&out_toml_path, merged_toml)?; - } else { - fs::copy(&base_toml_path, &out_toml_path)?; - } - // Apply file patches for patch in &self.file_patches { patch.apply_and_save(base_dir, out_path)?; @@ -106,25 +106,24 @@ impl ModelPatch { } } -/// Structure to hold patches to a model csv file. +/// Structure to hold patches for a model csv file. #[derive(Debug)] pub struct FilePatch { - /// The target base filename that this patch applies to (e.g. "agents.csv") - base_filename: String, + /// The file that this patch applies to (e.g. "agents.csv") + filename: String, /// The header row (optional). If `None`, the header is not checked against base files. header_row: Option>, - /// Rows to delete (each row is a vector of canonicalized fields) + /// Rows to delete (each row is a vector of fields) to_delete: IndexSet>, - /// Rows to add (each row is a vector of canonicalized fields) + /// Rows to add (each row is a vector of fields) to_add: IndexSet>, } impl FilePatch { - /// Create a new empty `Patch` with the given `base_filename`. - pub fn new(base_filename: impl Into) -> Self { - let base_filename = base_filename.into(); + /// Create a new empty `Patch` for the given file. + pub fn new(filename: impl Into) -> Self { FilePatch { - base_filename, + filename: filename.into(), header_row: None, to_delete: IndexSet::new(), to_add: IndexSet::new(), @@ -162,7 +161,7 @@ impl FilePatch { /// Apply this patch to a base model and return the modified CSV as a string. fn apply(&self, base_model_dir: &Path) -> Result { // Read the base file to string - let base_path = base_model_dir.join(&self.base_filename); + let base_path = base_model_dir.join(&self.filename); ensure!( base_path.exists() && base_path.is_file(), "Base file for patching does not exist: {}", @@ -172,24 +171,24 @@ impl FilePatch { // Apply the patch let modified = modify_base_with_patch(&base, self) - .with_context(|| format!("Error applying patch to file: {}", self.base_filename))?; + .with_context(|| format!("Error applying patch to file: {}", self.filename))?; Ok(modified) } /// Apply this patch to a base model and save the modified CSV to another directory. pub fn apply_and_save(&self, base_model_dir: &Path, out_model_dir: &Path) -> Result<()> { let modified = self.apply(base_model_dir)?; - let new_path = out_model_dir.join(&self.base_filename); + let new_path = out_model_dir.join(&self.filename); fs::write(&new_path, modified)?; Ok(()) } } -/// Merge a TOML patch into a base model TOML string and return the merged TOML. +/// Merge a TOML patch into a base TOML string and return the merged TOML. fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result { ensure!( !patch.contains_key("base_model"), - "TOML patch cannot contain a `base_model` field" + "TOML patch must not contain a `base_model` field" ); // Parse base TOML into a table @@ -211,18 +210,18 @@ fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result Result { - // Read base file from string, trimming whitespace + // Read base string, trimming whitespace let mut reader = ReaderBuilder::new() .trim(Trim::All) .from_reader(base.as_bytes()); - // Read header + // Extract header from the base string let base_header = reader .headers() .context("Failed to read base file header")?; let base_header_vec: Vec = base_header.iter().map(ToString::to_string).collect(); - // If the patch contains a header, compare it with the base file header. + // If the patch contains a header, compare it with the base header. if let Some(ref header_row_vec) = patch.header_row { ensure!( base_header_vec == *header_row_vec, @@ -232,7 +231,7 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { ); } - // Read all rows from base file, preserving order and checking for duplicates + // Read all rows from the base, preserving order and checking for duplicates let mut base_rows: IndexSet> = IndexSet::new(); for result in reader.records() { let record = result?; @@ -258,6 +257,14 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { ); } + // Ensure every row requested for deletion actually exists in the base file. + for del_row in &patch.to_delete { + ensure!( + base_rows.contains(del_row), + "Row to delete not present in base file: {del_row:?}" + ); + } + // Apply deletions base_rows.retain(|row| !patch.to_delete.contains(row)); @@ -269,7 +276,7 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { ); } - // Serialize CSV output using csv::Writer to ensure correct quoting/escaping + // Serialize CSV output using csv::Writer let mut wtr = Writer::from_writer(vec![]); wtr.write_record(base_header_vec.iter())?; for row in &base_rows { From f04599f30c60c42183a01c09e82382fa4075914e Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Tue, 16 Dec 2025 16:14:47 +0000 Subject: [PATCH 24/32] Remove checks we don't need --- src/patch.rs | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/src/patch.rs b/src/patch.rs index fc58b5675..9ce6297c5 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -50,10 +50,6 @@ impl ModelPatch { let s = patch_str.as_ref(); let patch: toml::value::Table = toml::from_str(s).expect("Failed to parse string passed to with_toml_patch"); - assert!( - !patch.contains_key("base_model"), - "TOML patch must not contain a `base_model` field" - ); self.toml_patch = Some(patch); self } @@ -186,11 +182,6 @@ impl FilePatch { /// Merge a TOML patch into a base TOML string and return the merged TOML. fn merge_model_toml(base_toml: &str, patch: &toml::value::Table) -> Result { - ensure!( - !patch.contains_key("base_model"), - "TOML patch must not contain a `base_model` field" - ); - // Parse base TOML into a table let mut base_val: toml::Value = toml::from_str(base_toml)?; let base_tbl = base_val @@ -353,22 +344,4 @@ mod tests { assert!(merged.contains("[section]")); assert!(merged.contains("new_field = \"added\"")); } - - #[test] - fn test_merge_rejects_base_model_key() { - let base = r#"field = "data""#; - - // Create a TOML patch with a base_model key - let mut patch = toml::value::Table::new(); - patch.insert( - "base_model".to_string(), - toml::Value::String("..".to_string()), - ); - - // `merge_model_toml` should return an error - assert_error!( - merge_model_toml(base, &patch), - "TOML patch cannot contain a `base_model` field" - ); - } } From a9b88bb20e4cb409b9989c8ca68a27e72bb18783 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 18 Dec 2025 15:38:39 +0000 Subject: [PATCH 25/32] Fix issues with tests --- src/model/parameters.rs | 21 ++++++++++--- src/patch.rs | 46 +++++++++++++++++++++++++++ tests/patch.rs | 70 ----------------------------------------- 3 files changed, 63 insertions(+), 74 deletions(-) delete mode 100644 tests/patch.rs diff --git a/src/model/parameters.rs b/src/model/parameters.rs index 0311a5c60..70d6c9ab6 100644 --- a/src/model/parameters.rs +++ b/src/model/parameters.rs @@ -27,6 +27,22 @@ pub fn broken_model_options_allowed() -> bool { .expect("Broken options flag not set") } +/// Set global flag signalling whether broken model options are allowed +/// +/// Can only be called once; subsequent calls will panic (except in tests, where it can be called +/// multiple times so long as the value is the same). +fn set_broken_model_options_flag(allowed: bool) { + let result = BROKEN_OPTIONS_ALLOWED.set(allowed); + if result.is_err() { + if cfg!(test) { + // Sanity check + assert_eq!(allowed, broken_model_options_allowed()); + } else { + panic!("Attempted to set BROKEN_OPTIONS_ALLOWED twice"); + } + } +} + macro_rules! define_unit_param_default { ($name:ident, $type: ty, $value: expr) => { fn $name() -> $type { @@ -174,10 +190,7 @@ impl ModelParameters { let file_path = model_dir.as_ref().join(MODEL_PARAMETERS_FILE_NAME); let model_params: ModelParameters = read_toml(&file_path)?; - // Set flag signalling whether broken model options are allowed or not - BROKEN_OPTIONS_ALLOWED - .set(model_params.allow_broken_options) - .unwrap(); // Will only fail if there is a race condition, which shouldn't happen + set_broken_model_options_flag(model_params.allow_broken_options); model_params .validate() diff --git a/src/patch.rs b/src/patch.rs index 9ce6297c5..fa864a493 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -284,6 +284,10 @@ fn modify_base_with_patch(base: &str, patch: &FilePatch) -> Result { mod tests { use super::*; use crate::fixture::assert_error; + use crate::input::read_toml; + use crate::model::ModelParameters; + use crate::patch::{FilePatch, ModelPatch}; + use std::path::PathBuf; #[test] fn test_modify_base_with_patch() { @@ -344,4 +348,46 @@ mod tests { assert!(merged.contains("[section]")); assert!(merged.contains("new_field = \"added\"")); } + + #[test] + fn test_file_patch() { + let base_model_dir = PathBuf::from("examples/simple"); + + // Patch with a small change to an asset capacity + let assets_patch = FilePatch::new("assets.csv") + .delete_row("GASDRV,GBR,A0_GEX,4002.26,2020") + .add_row("GASDRV,GBR,A0_GEX,4003.26,2020"); + + // Build patched model into a temporary directory + let model_dir = ModelPatch::new(&base_model_dir) + .with_file_patch(assets_patch) + .build_to_tempdir() + .unwrap(); + + // Check that the appropriate change has been made + let assets_path = model_dir.path().join("assets.csv"); + let assets_content = std::fs::read_to_string(assets_path).unwrap(); + assert!(!assets_content.contains("GASDRV,GBR,A0_GEX,4002.26,2020")); + assert!(assets_content.contains("GASDRV,GBR,A0_GEX,4003.26,2020")); + } + + #[test] + fn test_toml_patch() { + let base_model_dir = PathBuf::from("examples/simple"); + + // Patch to add an extra milestone year (2050) + let toml_patch = r#" + milestone_years = [2020, 2030, 2040, 2050] + "#; + + // Build patched model into a temporary directory + let model_dir = ModelPatch::new(&base_model_dir) + .with_toml_patch(toml_patch) + .build_to_tempdir() + .unwrap(); + + // Check that the appropriate change has been made + let toml: ModelParameters = read_toml(&model_dir.path().join("model.toml")).unwrap(); + assert_eq!(toml.milestone_years, vec![2020, 2030, 2040, 2050]); + } } diff --git a/tests/patch.rs b/tests/patch.rs deleted file mode 100644 index 5978ed3fb..000000000 --- a/tests/patch.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! Integration tests for the `patch` module. -use anyhow::Result; -use muse2::cli::handle_validate_command; -use muse2::input::read_toml; -use muse2::model::ModelParameters; -use muse2::patch::{FilePatch, ModelPatch}; -use muse2::settings::Settings; -use std::path::PathBuf; -use tempfile::TempDir; - -/// Patch of the "simple" model with a change to the `assets.csv` file. -fn get_model_dir_file_patch() -> Result { - let base_model_dir = PathBuf::from("examples/simple"); - - // Small change to an asset capacity - let assets_patch = FilePatch::new("assets.csv") - .delete_row("GASDRV,GBR,A0_GEX,4002.26,2020") - .add_row("GASDRV,GBR,A0_GEX,4003.26,2020"); - - ModelPatch::new(&base_model_dir) - .with_file_patch(assets_patch) - .build_to_tempdir() -} - -/// Patch of the "simple" model with a change to the `model.toml`. -fn get_model_dir_toml_patch() -> Result { - let base_model_dir = PathBuf::from("examples/simple"); - - // Add an extra milestone year (2050) - let toml_patch = r#" - milestone_years = [2020, 2030, 2040, 2050] - "#; - - ModelPatch::new(&base_model_dir) - .with_toml_patch(toml_patch) - .build_to_tempdir() -} - -#[test] -fn test_file_patch_and_validate() { - unsafe { std::env::set_var("MUSE2_LOG_LEVEL", "off") }; - - // Model is patched successfully - let model_dir = get_model_dir_file_patch().unwrap(); - - // The appropriate change has been made - let assets_path = model_dir.path().join("assets.csv"); - let assets_content = std::fs::read_to_string(assets_path).unwrap(); - assert!(!assets_content.contains("GASDRV,GBR,A0_GEX,4002.26,2020")); - assert!(assets_content.contains("GASDRV,GBR,A0_GEX,4003.26,2020")); - - // Validation passes - handle_validate_command(&model_dir.path(), Some(Settings::default())).unwrap(); -} - -#[test] -fn test_toml_patch_and_validate() { - unsafe { std::env::set_var("MUSE2_LOG_LEVEL", "off") }; - - // Model is patched successfully - let model_dir = get_model_dir_toml_patch().unwrap(); - - // The appropriate change has been made - let toml: ModelParameters = read_toml(&model_dir.path().join("model.toml")).unwrap(); - assert_eq!(toml.milestone_years, vec![2020, 2030, 2040, 2050]); - - // Validation should fail - let val = handle_validate_command(&model_dir.path(), Some(Settings::default())); - assert!(val.is_err()); -} From 687617e51bd48c64a9d221f193a996024bd004ab Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 18 Dec 2025 17:02:03 +0000 Subject: [PATCH 26/32] Method renames --- src/patch.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/patch.rs b/src/patch.rs index fa864a493..5dfe26f7c 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -139,7 +139,7 @@ impl FilePatch { } /// Add a row to the patch (row should be a comma-joined string, e.g. "a,b,c"). - pub fn add_row(mut self, row: impl Into) -> Self { + pub fn with_addition(mut self, row: impl Into) -> Self { let s = row.into(); let v = s.split(',').map(|s| s.trim().to_string()).collect(); self.to_add.insert(v); @@ -147,7 +147,7 @@ impl FilePatch { } /// Mark a row for deletion from the base (row should be a comma-joined string, e.g. "a,b,c"). - pub fn delete_row(mut self, row: impl Into) -> Self { + pub fn with_deletion(mut self, row: impl Into) -> Self { let s = row.into(); let v = s.split(',').map(|s| s.trim().to_string()).collect(); self.to_delete.insert(v); @@ -296,8 +296,8 @@ mod tests { // Create a patch to delete row3,row4 and add row7,row8 let patch = FilePatch::new("test.csv") .with_header("col1,col2") - .delete_row("value3,value4") - .add_row("value7,value8"); + .with_deletion("value3,value4") + .with_addition("value7,value8"); let modified = modify_base_with_patch(base, &patch).unwrap(); @@ -355,8 +355,8 @@ mod tests { // Patch with a small change to an asset capacity let assets_patch = FilePatch::new("assets.csv") - .delete_row("GASDRV,GBR,A0_GEX,4002.26,2020") - .add_row("GASDRV,GBR,A0_GEX,4003.26,2020"); + .with_deletion("GASDRV,GBR,A0_GEX,4002.26,2020") + .with_addition("GASDRV,GBR,A0_GEX,4003.26,2020"); // Build patched model into a temporary directory let model_dir = ModelPatch::new(&base_model_dir) From 5d3f7c902e39115c6b6424e72595fda5084f8d10 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 18 Dec 2025 17:12:40 +0000 Subject: [PATCH 27/32] Add macros --- src/fixture.rs | 70 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/src/fixture.rs b/src/fixture.rs index c6d5956b9..07e9de9b8 100644 --- a/src/fixture.rs +++ b/src/fixture.rs @@ -6,6 +6,8 @@ use crate::agent::{ }; use crate::asset::{Asset, AssetPool, AssetRef}; use crate::commodity::{Commodity, CommodityID, CommodityLevyMap, CommodityType, DemandMap}; +use crate::model::parameters::ALLOW_BROKEN_OPTION_NAME; +use crate::patch::{FilePatch, ModelPatch}; use crate::process::{ ActivityLimits, Process, ProcessActivityLimitsMap, ProcessFlow, ProcessFlowsMap, ProcessMap, ProcessParameter, ProcessParameterMap, @@ -19,6 +21,7 @@ use crate::units::{ Activity, ActivityPerCapacity, Capacity, Dimensionless, Flow, MoneyPerActivity, MoneyPerCapacity, MoneyPerCapacityPerYear, MoneyPerFlow, Year, }; +use anyhow::Result; use indexmap::indexmap; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; @@ -38,6 +41,49 @@ macro_rules! assert_error { } pub(crate) use assert_error; +/// Build a patched copy of `examples/simple` to a temporary directory and return the `TempDir`. +/// +/// As well as applying the given file patch, this also sets the allow broken options flag in the +/// model TOML to true. +pub(crate) fn build_patched_simple_tempdir( + file_patches: Vec, +) -> Result { + let toml_patch = format!(r#"{name} = true"#, name = ALLOW_BROKEN_OPTION_NAME); + + ModelPatch::new("examples/simple") + .with_file_patches(file_patches) + .with_toml_patch(&toml_patch) + .build_to_tempdir() +} + +/// Check whether the simple example passes or fails validation after applying a file patch +macro_rules! patch_and_validate_simple { + ($file_patches:expr) => {{ + (|| -> Result<()> { + let tmp = crate::fixture::build_patched_simple_tempdir($file_patches)?; + crate::input::load_model(tmp.path())?; + Ok(()) + })() + }}; +} +pub(crate) use patch_and_validate_simple; + +/// Check whether the simple example runs successfully after applying a file patch +macro_rules! patch_and_run_simple { + ($file_patches:expr) => {{ + (|| -> Result<()> { + let tmp = crate::fixture::build_patched_simple_tempdir($file_patches)?; + let (model, assets) = crate::input::load_model(tmp.path())?; + let output_path = tmp.path().join("output"); + std::fs::create_dir_all(&output_path)?; + + crate::simulation::run(&model, assets, &output_path, false)?; + Ok(()) + })() + }}; +} +pub(crate) use patch_and_run_simple; + #[fixture] pub fn region_id() -> RegionID { "GBR".into() @@ -307,3 +353,27 @@ pub fn appraisal_output(asset: Asset, time_slice: TimeSliceID) -> AppraisalOutpu metric: 4.14, } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn patch_and_validate_simple_smoke() { + let patches = Vec::new(); + assert!(patch_and_validate_simple!(patches).is_ok()); + } + + #[test] + fn patch_and_run_simple_smoke() { + let patches = Vec::new(); + assert!(patch_and_run_simple!(patches).is_ok()); + } + + #[test] + fn test_patch_and_validate_simple_fail() { + let patch = FilePatch::new("commodities.csv") + .with_deletion("RSHEAT,Residential heating,svd,daynight"); + assert!(patch_and_validate_simple!(vec![patch]).is_err()); + } +} From 2b445b0346f904f70147b96e1004fc5ba02a783a Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 18 Dec 2025 17:17:27 +0000 Subject: [PATCH 28/32] Small fixes from self-review --- src/fixture.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/fixture.rs b/src/fixture.rs index 2a72575e9..4026c548e 100644 --- a/src/fixture.rs +++ b/src/fixture.rs @@ -45,7 +45,7 @@ pub(crate) use assert_error; /// Build a patched copy of `examples/simple` to a temporary directory and return the `TempDir`. /// -/// As well as applying the given file patch, this also sets the allow broken options flag in the +/// As well as applying the given file patches, this also sets the allow broken options flag in the /// model TOML to true. pub(crate) fn build_patched_simple_tempdir( file_patches: Vec, @@ -58,7 +58,7 @@ pub(crate) fn build_patched_simple_tempdir( .build_to_tempdir() } -/// Check whether the simple example passes or fails validation after applying a file patch +/// Check whether the simple example passes or fails validation after applying file patches macro_rules! patch_and_validate_simple { ($file_patches:expr) => {{ (|| -> Result<()> { @@ -70,7 +70,7 @@ macro_rules! patch_and_validate_simple { } pub(crate) use patch_and_validate_simple; -/// Check whether the simple example runs successfully after applying a file patch +/// Check whether the simple example runs successfully after applying file patches macro_rules! patch_and_run_simple { ($file_patches:expr) => {{ (|| -> Result<()> { @@ -389,4 +389,11 @@ mod tests { .with_deletion("RSHEAT,Residential heating,svd,daynight"); assert!(patch_and_validate_simple!(vec![patch]).is_err()); } + + #[test] + fn test_patch_and_run_simple_fail() { + let patch = FilePatch::new("commodities.csv") + .with_deletion("RSHEAT,Residential heating,svd,daynight"); + assert!(patch_and_run_simple!(vec![patch]).is_err()); + } } From 6dc1e2926ba0a23f6a4a7e7060a0252ae816c55a Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 18 Dec 2025 17:19:50 +0000 Subject: [PATCH 29/32] build_patched_simple_tempdir should panic upon failure --- src/fixture.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/fixture.rs b/src/fixture.rs index 4026c548e..bd969482f 100644 --- a/src/fixture.rs +++ b/src/fixture.rs @@ -47,22 +47,23 @@ pub(crate) use assert_error; /// /// As well as applying the given file patches, this also sets the allow broken options flag in the /// model TOML to true. -pub(crate) fn build_patched_simple_tempdir( - file_patches: Vec, -) -> Result { +/// +/// If the patched model cannot be built, for whatever reason, this function will panic. +pub(crate) fn build_patched_simple_tempdir(file_patches: Vec) -> tempfile::TempDir { let toml_patch = format!(r#"{name} = true"#, name = ALLOW_BROKEN_OPTION_NAME); ModelPatch::new("examples/simple") .with_file_patches(file_patches) .with_toml_patch(&toml_patch) .build_to_tempdir() + .unwrap() } /// Check whether the simple example passes or fails validation after applying file patches macro_rules! patch_and_validate_simple { ($file_patches:expr) => {{ (|| -> Result<()> { - let tmp = crate::fixture::build_patched_simple_tempdir($file_patches)?; + let tmp = crate::fixture::build_patched_simple_tempdir($file_patches); crate::input::load_model(tmp.path())?; Ok(()) })() @@ -74,7 +75,7 @@ pub(crate) use patch_and_validate_simple; macro_rules! patch_and_run_simple { ($file_patches:expr) => {{ (|| -> Result<()> { - let tmp = crate::fixture::build_patched_simple_tempdir($file_patches)?; + let tmp = crate::fixture::build_patched_simple_tempdir($file_patches); let (model, assets) = crate::input::load_model(tmp.path())?; let output_path = tmp.path().join("output"); std::fs::create_dir_all(&output_path)?; From d6bbf9d9777d4e9b26a9ff1c3bc855594bc6e1ef Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 18 Dec 2025 17:26:24 +0000 Subject: [PATCH 30/32] No longer set broken model options in build_patched_simple_tempdir --- src/fixture.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/fixture.rs b/src/fixture.rs index bd969482f..567f1ba5d 100644 --- a/src/fixture.rs +++ b/src/fixture.rs @@ -45,16 +45,10 @@ pub(crate) use assert_error; /// Build a patched copy of `examples/simple` to a temporary directory and return the `TempDir`. /// -/// As well as applying the given file patches, this also sets the allow broken options flag in the -/// model TOML to true. -/// /// If the patched model cannot be built, for whatever reason, this function will panic. pub(crate) fn build_patched_simple_tempdir(file_patches: Vec) -> tempfile::TempDir { - let toml_patch = format!(r#"{name} = true"#, name = ALLOW_BROKEN_OPTION_NAME); - ModelPatch::new("examples/simple") .with_file_patches(file_patches) - .with_toml_patch(&toml_patch) .build_to_tempdir() .unwrap() } From a87a26e15719ff1ba548520189760855a935a96d Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 18 Dec 2025 17:28:25 +0000 Subject: [PATCH 31/32] Remove import --- src/fixture.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fixture.rs b/src/fixture.rs index 567f1ba5d..191eb4a2f 100644 --- a/src/fixture.rs +++ b/src/fixture.rs @@ -8,7 +8,6 @@ use crate::asset::{Asset, AssetPool, AssetRef}; use crate::commodity::{ Commodity, CommodityID, CommodityLevyMap, CommodityType, DemandMap, PricingStrategy, }; -use crate::model::parameters::ALLOW_BROKEN_OPTION_NAME; use crate::patch::{FilePatch, ModelPatch}; use crate::process::{ ActivityLimits, Process, ProcessActivityLimitsMap, ProcessFlow, ProcessFlowsMap, From 589fbcddbf5727fdc813cd2500a78a253fa90939 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 19 Dec 2025 10:29:56 +0000 Subject: [PATCH 32/32] Add from_example method --- src/fixture.rs | 2 +- src/patch.rs | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/fixture.rs b/src/fixture.rs index 191eb4a2f..19c02c5fe 100644 --- a/src/fixture.rs +++ b/src/fixture.rs @@ -46,7 +46,7 @@ pub(crate) use assert_error; /// /// If the patched model cannot be built, for whatever reason, this function will panic. pub(crate) fn build_patched_simple_tempdir(file_patches: Vec) -> tempfile::TempDir { - ModelPatch::new("examples/simple") + ModelPatch::from_example("simple") .with_file_patches(file_patches) .build_to_tempdir() .unwrap() diff --git a/src/patch.rs b/src/patch.rs index 5dfe26f7c..dc2ff023e 100644 --- a/src/patch.rs +++ b/src/patch.rs @@ -25,6 +25,13 @@ impl ModelPatch { } } + /// Create a new empty `ModelPatch` for an example model + #[cfg(test)] + pub fn from_example(name: &str) -> Self { + let base_model_dir = PathBuf::from("examples").join(name); + ModelPatch::new(base_model_dir) + } + /// Add a single `FilePatch` to this `ModelPatch`. pub fn with_file_patch(mut self, patch: FilePatch) -> Self { self.file_patches.push(patch); @@ -287,7 +294,6 @@ mod tests { use crate::input::read_toml; use crate::model::ModelParameters; use crate::patch::{FilePatch, ModelPatch}; - use std::path::PathBuf; #[test] fn test_modify_base_with_patch() { @@ -351,15 +357,13 @@ mod tests { #[test] fn test_file_patch() { - let base_model_dir = PathBuf::from("examples/simple"); - // Patch with a small change to an asset capacity let assets_patch = FilePatch::new("assets.csv") .with_deletion("GASDRV,GBR,A0_GEX,4002.26,2020") .with_addition("GASDRV,GBR,A0_GEX,4003.26,2020"); // Build patched model into a temporary directory - let model_dir = ModelPatch::new(&base_model_dir) + let model_dir = ModelPatch::from_example("simple") .with_file_patch(assets_patch) .build_to_tempdir() .unwrap(); @@ -373,15 +377,13 @@ mod tests { #[test] fn test_toml_patch() { - let base_model_dir = PathBuf::from("examples/simple"); - // Patch to add an extra milestone year (2050) let toml_patch = r#" milestone_years = [2020, 2030, 2040, 2050] "#; // Build patched model into a temporary directory - let model_dir = ModelPatch::new(&base_model_dir) + let model_dir = ModelPatch::from_example("simple") .with_toml_patch(toml_patch) .build_to_tempdir() .unwrap();