Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ categories = ["science", "simulation", "command-line-utilities"]
csv = "1.3.0"
env_logger = "0.11.3"
float-cmp = "0.9.0"
itertools = "0.13.0"
log = "0.4.22"
log-panics = "2.1.0"
serde = {version = "1.0.202", features = ["derive"]}
serde = {version = "1.0.202", features = ["derive", "rc"]}
serde_string_enum = "0.2.1"
tempfile = "3.10.1"
toml = "0.8.13"
Expand Down
231 changes: 213 additions & 18 deletions src/input.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
//! Common routines for handling input data.
use itertools::Itertools;
use serde::de::{Deserialize, DeserializeOwned, Deserializer};
use serde_string_enum::{DeserializeLabeledStringEnum, SerializeLabeledStringEnum};
use std::collections::HashMap;
use std::collections::HashSet;
use std::error::Error;
use std::fs;
use std::path::Path;
use std::rc::Rc;

/// Read a series of type `T`s from a CSV file.
///
/// Deserialisation is lazy: records are parsed one at a time as the returned
/// iterator is advanced.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
pub fn read_csv<'a, T: DeserializeOwned + 'a>(file_path: &'a Path) -> impl Iterator<Item = T> + 'a {
    // Open the file, panicking with an input-file error message if that fails.
    let reader = csv::Reader::from_path(file_path).unwrap_input_err(file_path);

    // Deserialise each record lazily, panicking on the first malformed row.
    reader.into_deserialize().unwrap_input_err(file_path)
}

/// Read a series of type `T`s from a CSV file into a `Vec<T>`.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
pub fn read_csv_as_vec<T: DeserializeOwned>(file_path: &Path) -> Vec<T> {
let mut reader = csv::Reader::from_path(file_path).unwrap_input_err(file_path);

let mut vec = Vec::new();
for result in reader.deserialize() {
let d: T = result.unwrap_input_err(file_path);
vec.push(d)
}
let vec: Vec<T> = read_csv(file_path).collect();

if vec.is_empty() {
input_panic(file_path, "CSV file cannot be empty");
Expand Down Expand Up @@ -79,6 +89,125 @@ impl<T, E: Error> UnwrapInputError<T> for Result<T, E> {
}
}

/// Extension trait for unwrapping every item of an iterator of `Result`s, reporting errors as
/// input-file errors for the given path.
pub trait UnwrapInputErrorIter<T> {
/// Maps an `Iterator` of `Result`s with an arbitrary `Error` type to an `Iterator<Item = T>`
fn unwrap_input_err(self, file_path: &Path) -> impl Iterator<Item = T>;
}

impl<T, E, I> UnwrapInputErrorIter<T> for I
where
    I: Iterator<Item = Result<T, E>>,
    E: Error,
{
    /// Unwrap each `Result` in turn; the first `Err` aborts with an input-file error message.
    fn unwrap_input_err(self, file_path: &Path) -> impl Iterator<Item = T> {
        self.map(|result| result.unwrap_input_err(file_path))
    }
}

/// Indicates that the struct has an ID field
///
/// For types whose ID field is literally named `id`, the implementation can be generated with the
/// `define_id_getter!` macro.
pub trait HasID {
/// Get a string representation of the struct's ID
fn get_id(&self) -> &str;
}

/// Implement the `HasID` trait for the given type, assuming it has a field called `id`
///
/// The `id` field must be of a type for which `&self.id` coerces to `&str` (e.g. `String` or
/// `Rc<str>`).
macro_rules! define_id_getter {
($t:ty) => {
impl HasID for $t {
fn get_id(&self) -> &str {
&self.id
}
}
};
}

// Re-exported so that other modules in this crate can invoke the macro.
pub(crate) use define_id_getter;

pub trait IDCollection {
/// Get the ID after checking that it exists in this collection.
///
/// Returns a copy of the `Rc<str>` in `self` or panics on error.
///
/// # Arguments
///
/// * `file_path` - Path to the input file (used in the error message)
/// * `id` - The ID to look up
fn get_id_checked(&self, file_path: &Path, id: &str) -> Rc<str>;
}

impl IDCollection for HashSet<Rc<str>> {
    fn get_id_checked(&self, file_path: &Path, id: &str) -> Rc<str> {
        // `input_panic` diverges, so the `else` branch never falls through.
        let Some(found) = self.get(id) else {
            input_panic(file_path, &format!("Unknown ID {id} found"))
        };
        Rc::clone(found)
    }
}

/// Read a CSV file of items with IDs.
///
/// This is like `read_csv_grouped_by_id`, with the difference that it is to be used on the "main"
/// CSV file for a record type, so it assumes that all IDs encountered are valid.
///
/// # Arguments
///
/// * `file_path` - Path to the CSV file
///
/// # Panics
///
/// Panics if the file cannot be read, if the same ID appears more than once, or if the file
/// contains no records.
pub fn read_csv_id_file<T>(file_path: &Path) -> HashMap<Rc<str>, T>
where
    T: HasID + DeserializeOwned,
{
    let mut map = HashMap::new();
    for record in read_csv::<T>(file_path) {
        let id: Rc<str> = record.get_id().into();

        // `insert` returns the previously stored value, if any, so a single hash lookup both
        // stores the record and detects duplicates (the original code looked up the key twice).
        if map.insert(Rc::clone(&id), record).is_some() {
            input_panic(file_path, &format!("Duplicate ID found: {id}"));
        }
    }
    if map.is_empty() {
        input_panic(file_path, "CSV file is empty");
    }

    map
}

/// Conversion of a sequence of items with IDs into a `HashMap` grouped by ID.
pub trait IntoIDMap<T> {
/// Group the items by ID, validating each ID against `ids` (panics on unknown IDs).
fn into_id_map(self, file_path: &Path, ids: &HashSet<Rc<str>>) -> HashMap<Rc<str>, Vec<T>>;
}

impl<T, I> IntoIDMap<T> for I
where
    I: Iterator<Item = T>,
    T: HasID,
{
    /// Convert the specified iterator into a `HashMap` of the items grouped by ID.
    ///
    /// # Arguments
    ///
    /// `file_path` - The path to the CSV file this relates to
    /// `ids` - The set of valid IDs to check against.
    fn into_id_map(self, file_path: &Path, ids: &HashSet<Rc<str>>) -> HashMap<Rc<str>, Vec<T>> {
        // Every item's ID is validated (panicking on unknown IDs) before being used as a key.
        let grouped = self.into_group_map_by(|item| ids.get_id_checked(file_path, item.get_id()));
        if grouped.is_empty() {
            input_panic(file_path, "CSV file is empty");
        }

        grouped
    }
}

/// Read a CSV file, grouping the entries by ID
///
/// # Arguments
///
/// * `file_path` - Path to CSV file
/// * `ids` - All possible IDs that will be encountered
///
/// # Returns
///
/// A `HashMap` mapping each ID to the `Vec` of records carrying that ID.
pub fn read_csv_grouped_by_id<T>(
    file_path: &Path,
    ids: &HashSet<Rc<str>>,
) -> HashMap<Rc<str>, Vec<T>>
where
    T: HasID + DeserializeOwned,
{
    // Deserialise lazily, then group the records by their (validated) ID.
    let records = read_csv::<T>(file_path);
    records.into_id_map(file_path, ids)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -92,8 +221,14 @@ mod tests {

#[derive(Debug, PartialEq, Deserialize)]
struct Record {
a: u32,
b: String,
id: String,
value: u32,
}

impl HasID for Record {
fn get_id(&self) -> &str {
&self.id
}
}

/// Create an example CSV file in dir_path
Expand All @@ -108,18 +243,18 @@ mod tests {
#[test]
fn test_read_csv_as_vec() {
let dir = tempdir().unwrap();
let file_path = create_csv_file(dir.path(), "a,b\n1,hello\n2,world\n");
let file_path = create_csv_file(dir.path(), "id,value\nhello,1\nworld,2\n");
let records: Vec<Record> = read_csv_as_vec(&file_path);
assert_eq!(
records,
&[
Record {
a: 1,
b: "hello".to_string()
id: "hello".to_string(),
value: 1,
},
Record {
a: 2,
b: "world".to_string()
id: "world".to_string(),
value: 2,
}
]
);
Expand All @@ -130,7 +265,7 @@ mod tests {
#[should_panic]
fn test_read_csv_as_vec_empty() {
// A CSV file with a header but no data rows must trigger the empty-file panic.
let dir = tempdir().unwrap();
let file_path = create_csv_file(dir.path(), "id,value\n");
read_csv_as_vec::<Record>(&file_path);
}

Expand All @@ -140,14 +275,14 @@ mod tests {
let file_path = dir.path().join("test.toml");
{
let mut file = File::create(&file_path).unwrap();
writeln!(file, "a = 1\nb = \"hello\"").unwrap();
writeln!(file, "id = \"hello\"\nvalue = 1").unwrap();
}

assert_eq!(
read_toml::<Record>(&file_path),
Record {
a: 1,
b: "hello".to_string()
id: "hello".to_string(),
value: 1,
}
);
}
Expand All @@ -171,4 +306,64 @@ mod tests {
assert!(deserialise_f64(f64::NAN).is_err());
assert!(deserialise_f64(f64::INFINITY).is_err());
}

/// Build the set of IDs considered valid in these tests.
fn create_ids() -> HashSet<Rc<str>> {
    ["A", "B"].into_iter().map(Rc::from).collect()
}

#[test]
fn test_read_csv_grouped_by_id() {
    let dir = tempdir().unwrap();
    let file_path = dir.path().join("data.csv");
    {
        // Scope the handle so the file is flushed and closed before we read it back.
        // (`File::create` accepts `&PathBuf` directly; the original `&Path` cast was redundant.)
        let mut file = File::create(&file_path).unwrap();
        writeln!(file, "id,value\nA,1\nB,2\nA,3").unwrap();
    }

    // Records sharing an ID ("A") must be grouped into the same vector, in file order.
    let expected = HashMap::from([
        (
            "A".into(),
            vec![
                Record {
                    id: "A".to_string(),
                    value: 1,
                },
                Record {
                    id: "A".to_string(),
                    value: 3,
                },
            ],
        ),
        (
            "B".into(),
            vec![Record {
                id: "B".to_string(),
                value: 2,
            }],
        ),
    ]);
    let process_ids = create_ids();
    // NB: the original rebuilt `file_path` here with a second `join`; reuse the existing one.
    let map = read_csv_grouped_by_id::<Record>(&file_path, &process_ids);
    assert_eq!(expected, map);
}

#[test]
#[should_panic]
fn test_read_csv_grouped_by_id_duplicate() {
    let dir = tempdir().unwrap();
    let file_path = dir.path().join("data.csv");
    {
        let mut file = File::create(&file_path).unwrap();

        // NB: ID "C" isn't valid.
        // The header must be `id,value` to match `Record`; with the previous `process_id,value`
        // header, deserialisation panicked before the unknown-ID check was ever exercised, so the
        // test passed for the wrong reason.
        writeln!(file, "id,value\nA,1\nB,2\nC,3").unwrap();
    }

    // Check that it fails if a non-existent ID is provided
    let process_ids = create_ids();
    read_csv_grouped_by_id::<Record>(&file_path, &process_ids);
}
}
4 changes: 3 additions & 1 deletion src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ use crate::region::{read_regions, Region};
use crate::time_slice::{read_time_slices, TimeSlice};
use log::warn;
use serde::Deserialize;
use std::collections::HashMap;
use std::path::Path;
use std::rc::Rc;

const MODEL_FILE_NAME: &str = "model.toml";

/// Model definition
pub struct Model {
pub milestone_years: Vec<u32>,
pub processes: Vec<Process>,
pub processes: HashMap<Rc<str>, Process>,
pub time_slices: Vec<TimeSlice>,
pub demand_data: Vec<Demand>,
pub regions: Vec<Region>,
Expand Down
Loading