From aeaac305891ed8168a60d017e3ba7b4aedf39afe Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Fri, 22 Nov 2024 15:02:18 +0000 Subject: [PATCH 01/11] Use `anyhow` crate for `commodity.rs` --- src/commodity.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/commodity.rs b/src/commodity.rs index 020173b68..b116134dc 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -2,11 +2,11 @@ use crate::demand::{read_demand, Demand}; use crate::input::*; use crate::time_slice::{TimeSliceInfo, TimeSliceLevel, TimeSliceSelection}; +use anyhow::{ensure, Result}; use itertools::Itertools; use serde::Deserialize; use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::{HashMap, HashSet}; -use std::error::Error; use std::ops::RangeInclusive; use std::path::Path; use std::rc::Rc; @@ -82,15 +82,17 @@ impl CommodityCostRaw { region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, year_range: &RangeInclusive, - ) -> Result> { + ) -> Result { let commodity_id = commodity_ids.get_id(&self.commodity_id)?; let region_id = region_ids.get_id(&self.region_id)?; let time_slice = time_slice_info.get_selection(&self.time_slice)?; // Check year is in range - if !year_range.contains(&self.year) { - Err(format!("Year {} is out of range", self.year))?; - } + ensure!( + year_range.contains(&self.year), + "Year {} is out of range", + self.year + ); Ok(CommodityCost { commodity_id, @@ -134,7 +136,7 @@ fn read_commodity_costs_iter( region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, year_range: &RangeInclusive, -) -> Result, Vec>, Box> +) -> Result, Vec>> where I: Iterator, { From 96fd1bedd473a58b89ad7e77394a8b3de5bb0f8d Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Fri, 22 Nov 2024 14:54:42 +0000 Subject: [PATCH 02/11] Disallow non-milestone years for commodity cost entries --- src/commodity.rs | 81 +++++++++++++++++++++++++++--------------------- src/model.rs | 8 ++--- 2 files changed, 47 insertions(+), 42 deletions(-) diff --git a/src/commodity.rs b/src/commodity.rs index b116134dc..34dc8e023 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -2,12 +2,11 @@ use crate::demand::{read_demand, Demand}; use crate::input::*; use crate::time_slice::{TimeSliceInfo, TimeSliceLevel, TimeSliceSelection}; -use anyhow::{ensure, Result}; +use anyhow::Result; use itertools::Itertools; use serde::Deserialize; use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::{HashMap, HashSet}; -use std::ops::RangeInclusive; use std::path::Path; use std::rc::Rc; @@ -81,18 +80,19 @@ impl CommodityCostRaw { commodity_ids: &HashSet>, region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, - year_range: &RangeInclusive, + milestone_years: &[u32], ) -> Result { let commodity_id = commodity_ids.get_id(&self.commodity_id)?; let region_id = region_ids.get_id(&self.region_id)?; let time_slice = time_slice_info.get_selection(&self.time_slice)?; - // Check year is in range - ensure!( - year_range.contains(&self.year), - "Year {} is out of range", - self.year - ); + if milestone_years.binary_search(&self.year).is_err() { + todo!( + "Year {} is not a milestone year. \ + Input of non-milestone years is currently not supported.", + self.year + ); + } Ok(CommodityCost { commodity_id, @@ -135,13 +135,13 @@ fn read_commodity_costs_iter( commodity_ids: &HashSet>, region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, - year_range: &RangeInclusive, + milestone_years: &[u32], ) -> Result, Vec>> where I: Iterator, { iter.map(|cost| { - cost.try_into_commodity_cost(commodity_ids, region_ids, time_slice_info, year_range) + cost.try_into_commodity_cost(commodity_ids, region_ids, time_slice_info, milestone_years) }) // Commodity IDs have already been validated .process_results(|iter| iter.into_id_map(commodity_ids).unwrap()) @@ -155,7 +155,7 @@ where /// * `commodity_ids` - All possible commodity IDs /// * `region_ids` - All possible region IDs /// * `time_slice_info` - Information about time slices -/// * `year_range` - The possible range of milestone years +/// * `milestone_years` - All milestone years /// /// # Returns /// @@ -165,7 +165,7 @@ fn read_commodity_costs( commodity_ids: &HashSet>, region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, - year_range: &RangeInclusive, + milestone_years: &[u32], ) -> HashMap, Vec> { let file_path = model_dir.join(COMMODITY_COSTS_FILE_NAME); read_commodity_costs_iter( @@ -173,7 +173,7 @@ fn read_commodity_costs( commodity_ids, region_ids, time_slice_info, - year_range, + milestone_years, ) .unwrap_input_err(&file_path) } @@ -185,7 +185,7 @@ fn read_commodity_costs( /// * `model_dir` - Folder containing model configuration files /// * `region_ids` - All possible region IDs /// * `time_slice_info` - Information about time slices -/// * `year_range` - The possible range of milestone years +/// * `milestone_years` - All milestone years /// /// # Returns /// @@ -194,7 +194,7 @@ pub fn read_commodities( model_dir: &Path, region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, - year_range: &RangeInclusive, + milestone_years: &[u32], ) -> HashMap, Rc> { let commodities = read_csv_id_file::(&model_dir.join(COMMODITY_FILE_NAME)); let commodity_ids = commodities.keys().cloned().collect(); @@ -203,14 +203,16 @@ pub fn read_commodities( &commodity_ids, region_ids, time_slice_info, - year_range, + milestone_years, ); + + let year_range = *milestone_years.first().unwrap()..=*milestone_years.last().unwrap(); let mut demand = read_demand( model_dir, &commodity_ids, region_ids, time_slice_info, - year_range, + &year_range, ); // Populate Vecs for each Commodity @@ -238,7 +240,7 @@ mod tests { let commodity_ids = ["commodity".into()].into_iter().collect(); let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); let time_slice_info = TimeSliceInfo::default(); - let year_range = 2010..=2020; + let milestone_years = vec![2010, 2020]; // Valid let cost = CommodityCostRaw { @@ -250,7 +252,12 @@ mod tests { value: 5.0, }; assert!(cost - .try_into_commodity_cost(&commodity_ids, ®ion_ids, &time_slice_info, &year_range) + .try_into_commodity_cost( + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years + ) .is_ok()); // Bad commodity @@ -263,7 +270,12 @@ mod tests { value: 5.0, }; assert!(cost - .try_into_commodity_cost(&commodity_ids, ®ion_ids, &time_slice_info, &year_range) + .try_into_commodity_cost( + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years + ) .is_err()); // Bad region @@ -276,7 +288,12 @@ mod tests { value: 5.0, }; assert!(cost - .try_into_commodity_cost(&commodity_ids, ®ion_ids, &time_slice_info, &year_range) + .try_into_commodity_cost( + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years + ) .is_err()); // Bad time slice selection @@ -289,20 +306,12 @@ mod tests { value: 5.0, }; assert!(cost - .try_into_commodity_cost(&commodity_ids, ®ion_ids, &time_slice_info, &year_range) - .is_err()); - - // Bad year - let cost = CommodityCostRaw { - commodity_id: "commodity".into(), - region_id: "GBR".into(), - balance_type: BalanceType::Consumption, - year: 1999, - time_slice: "".into(), - value: 5.0, - }; - assert!(cost - .try_into_commodity_cost(&commodity_ids, ®ion_ids, &time_slice_info, &year_range) + .try_into_commodity_cost( + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years + ) .is_err()); } } diff --git a/src/model.rs b/src/model.rs index 001d14798..01bd2c6a7 100644 --- a/src/model.rs +++ b/src/model.rs @@ -98,12 +98,8 @@ impl Model { let years = &model_file.milestone_years.years; let year_range = *years.first().unwrap()..=*years.last().unwrap(); - let commodities = read_commodities( - model_dir.as_ref(), - ®ion_ids, - &time_slice_info, - &year_range, - ); + let commodities = + read_commodities(model_dir.as_ref(), ®ion_ids, &time_slice_info, years); let processes = read_processes( model_dir.as_ref(), &commodities, From 8652f9c4b29726a0812d856846d456c2f5796ee3 Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Fri, 22 Nov 2024 15:57:39 +0000 Subject: [PATCH 03/11] Add method to iterate over time slices in selection --- src/time_slice.rs | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/src/time_slice.rs b/src/time_slice.rs index 8ff67a5bf..967dafae2 100644 --- a/src/time_slice.rs +++ b/src/time_slice.rs @@ -11,6 +11,7 @@ use serde::Deserialize; use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::{HashMap, HashSet}; use std::fmt::Display; +use std::iter; use std::path::Path; use std::rc::Rc; @@ -108,6 +109,23 @@ impl TimeSliceInfo { Ok(TimeSliceSelection::Season(season)) } } + + /// Iterate over the subset of [`TimeSliceID`] indicated by `selection` + pub fn iter_selection<'a>( + &'a self, + selection: &'a TimeSliceSelection, + ) -> Box + 'a> { + match selection { + TimeSliceSelection::Annual => Box::new(self.fractions.keys().cloned()), + TimeSliceSelection::Season(season) => Box::new( + self.fractions + .keys() + .filter(move |ts| &ts.season == season) + .cloned(), + ), + TimeSliceSelection::Single(ts) => Box::new(iter::once(ts.clone())), + } + } } /// A time slice record retrieved from a CSV file @@ -298,6 +316,41 @@ autumn,evening,0.25" assert_eq!(actual, TimeSliceInfo::default()); } + #[test] + fn test_iter_selection() { + let slices = [ + TimeSliceID { + season: "winter".into(), + time_of_day: "day".into(), + }, + TimeSliceID { + season: "summer".into(), + time_of_day: "night".into(), + }, + ]; + let ts_info = TimeSliceInfo { + seasons: ["winter".into(), "summer".into()].into_iter().collect(), + times_of_day: ["day".into(), "night".into()].into_iter().collect(), + fractions: [(slices[0].clone(), 0.5), (slices[1].clone(), 0.5)] + .into_iter() + .collect(), + }; + + assert_eq!( + HashSet::::from_iter(ts_info.iter_selection(&TimeSliceSelection::Annual)), + HashSet::from_iter(slices.iter().cloned()) + ); + itertools::assert_equal( + ts_info.iter_selection(&TimeSliceSelection::Season("winter".into())), + iter::once(slices[0].clone()), + ); + let ts = ts_info.get_time_slice_id_from_str("summer.night").unwrap(); + itertools::assert_equal( + ts_info.iter_selection(&TimeSliceSelection::Single(ts)), + iter::once(slices[1].clone()), + ); + } + #[test] fn test_check_time_slice_fractions_sum_to_one() { // Single input, valid From f946da8fe84ee2f443b213fbc86121b93f396723 Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Fri, 22 Nov 2024 16:36:30 +0000 Subject: [PATCH 04/11] Return commodity costs as map Closes #153. --- src/commodity.rs | 331 ++++++++++++++++++++++++++++++++--------------- src/process.rs | 4 +- 2 files changed, 227 insertions(+), 108 deletions(-) diff --git a/src/commodity.rs b/src/commodity.rs index 34dc8e023..dc877aaab 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -1,9 +1,8 @@ #![allow(missing_docs)] use crate::demand::{read_demand, Demand}; use crate::input::*; -use crate::time_slice::{TimeSliceInfo, TimeSliceLevel, TimeSliceSelection}; -use anyhow::Result; -use itertools::Itertools; +use crate::time_slice::{TimeSliceID, TimeSliceInfo, TimeSliceLevel}; +use anyhow::{ensure, Result}; use serde::Deserialize; use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::{HashMap, HashSet}; @@ -27,26 +26,43 @@ pub struct Commodity { pub time_slice_level: TimeSliceLevel, #[serde(skip)] - pub costs: Vec, + pub costs: CommodityCostMap, #[serde(skip)] pub demand_by_region: HashMap, Demand>, } define_id_getter! {Commodity} -macro_rules! define_commodity_id_getter { - ($t:ty) => { - impl HasID for $t { - fn get_id(&self) -> &str { - &self.commodity_id - } - } - }; +impl CommodityCostMap { + /// Create a new, empty [`CommodityCostMap`] + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// Retrieve a [`CommodityCost`] from the map + pub fn get( + &self, + region_id: Rc, + year: u32, + time_slice: TimeSliceID, + ) -> Option<&CommodityCost> { + let key = CommodityCostKey { + region_id, + year, + time_slice, + }; + self.0.get(&key) + } } -pub(crate) use define_commodity_id_getter; +impl Default for CommodityCostMap { + /// Create a new, empty [`CommodityCostMap`] + fn default() -> Self { + Self::new() + } +} /// Type of balance for application of cost -#[derive(PartialEq, Debug, DeserializeLabeledStringEnum)] +#[derive(PartialEq, Clone, Debug, DeserializeLabeledStringEnum)] pub enum BalanceType { #[string = "net"] Net, @@ -57,7 +73,7 @@ pub enum BalanceType { } /// Cost parameters for each commodity -#[derive(PartialEq, Debug, Deserialize)] +#[derive(PartialEq, Debug, Deserialize, Clone)] struct CommodityCostRaw { /// Unique identifier for the commodity (e.g. "ELC") pub commodity_id: String, @@ -73,49 +89,26 @@ struct CommodityCostRaw { pub value: f64, } -impl CommodityCostRaw { - /// Convert the raw record type into a validated `CommodityCost` type - fn try_into_commodity_cost( - self, - commodity_ids: &HashSet>, - region_ids: &HashSet>, - time_slice_info: &TimeSliceInfo, - milestone_years: &[u32], - ) -> Result { - let commodity_id = commodity_ids.get_id(&self.commodity_id)?; - let region_id = region_ids.get_id(&self.region_id)?; - let time_slice = time_slice_info.get_selection(&self.time_slice)?; - - if milestone_years.binary_search(&self.year).is_err() { - todo!( - "Year {} is not a milestone year. \ - Input of non-milestone years is currently not supported.", - self.year - ); - } - - Ok(CommodityCost { - commodity_id, - region_id, - balance_type: self.balance_type, - year: self.year, - time_slice, - value: self.value, - }) - } -} - /// Cost parameters for each commodity -#[derive(PartialEq, Debug)] +#[derive(PartialEq, Clone, Debug)] pub struct CommodityCost { - pub commodity_id: Rc, - pub region_id: Rc, + /// Type of balance for application of cost. pub balance_type: BalanceType, - pub year: u32, - pub time_slice: TimeSliceSelection, + /// Cost per unit commodity. For example, if a CO2 price is specified in input data, it can be applied to net CO2 via this value. pub value: f64, } -define_commodity_id_getter! {CommodityCost} + +/// Used for looking up [`CommodityCost`]s in a [`CommodityCostMap`] +#[derive(PartialEq, Eq, Hash, Debug)] +struct CommodityCostKey { + region_id: Rc, + year: u32, + time_slice: TimeSliceID, +} + +/// A data structure for easy lookup of [`CommodityCost`]s +#[derive(PartialEq, Debug)] +pub struct CommodityCostMap(HashMap); /// Commodity balance type #[derive(PartialEq, Debug, DeserializeLabeledStringEnum)] @@ -136,15 +129,53 @@ fn read_commodity_costs_iter( region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], -) -> Result, Vec>> +) -> Result, CommodityCostMap>> where I: Iterator, { - iter.map(|cost| { - cost.try_into_commodity_cost(commodity_ids, region_ids, time_slice_info, milestone_years) - }) - // Commodity IDs have already been validated - .process_results(|iter| iter.into_id_map(commodity_ids).unwrap()) + let mut map = HashMap::new(); + + for cost in iter { + let commodity_id = commodity_ids.get_id(&cost.commodity_id)?; + let region_id = region_ids.get_id(&cost.region_id)?; + let ts_selection = time_slice_info.get_selection(&cost.time_slice)?; + + if milestone_years.binary_search(&cost.year).is_err() { + todo!( + "Year {} is not a milestone year. \ + Input of non-milestone years is currently not supported.", + cost.year + ); + } + + // Get or create CommodityCostMap for this commodity + let map = map + .entry(commodity_id) + .or_insert_with(|| CommodityCostMap(HashMap::with_capacity(1))); + + for time_slice in time_slice_info.iter_selection(&ts_selection) { + let key = CommodityCostKey { + region_id: Rc::clone(®ion_id), + year: cost.year, + time_slice: time_slice.clone(), + }; + let value = CommodityCost { + balance_type: cost.balance_type.clone(), + value: cost.value, + }; + + ensure!( + map.0.insert(key, value).is_none(), + "Commodity cost entry covered by more than one time slice \ + (region: {}, year: {}, time slice: {})", + region_id, + cost.year, + time_slice + ); + } + } + + Ok(map) } /// Read costs associated with each commodity from commodity costs CSV file. @@ -166,7 +197,7 @@ fn read_commodity_costs( region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], -) -> HashMap, Vec> { +) -> HashMap, CommodityCostMap> { let file_path = model_dir.join(COMMODITY_COSTS_FILE_NAME); read_commodity_costs_iter( read_csv::(&file_path), @@ -234,84 +265,172 @@ pub fn read_commodities( #[cfg(test)] mod tests { use super::*; + use std::iter; #[test] - fn test_try_into_commodity_cost() { + fn test_commodity_cost_map_get() { + let ts = TimeSliceID { + season: "winter".into(), + time_of_day: "day".into(), + }; + let key = CommodityCostKey { + region_id: "GBR".into(), + year: 2010, + time_slice: ts.clone(), + }; + let value = CommodityCost { + balance_type: BalanceType::Consumption, + value: 0.5, + }; + let map = CommodityCostMap(HashMap::from_iter([(key, value.clone())])); + assert_eq!(map.get("GBR".into(), 2010, ts).unwrap(), &value); + } + + #[test] + fn test_read_commodity_costs_iter() { let commodity_ids = ["commodity".into()].into_iter().collect(); let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); - let time_slice_info = TimeSliceInfo::default(); - let milestone_years = vec![2010, 2020]; + let slices = [ + TimeSliceID { + season: "winter".into(), + time_of_day: "day".into(), + }, + TimeSliceID { + season: "summer".into(), + time_of_day: "night".into(), + }, + ]; + let time_slice_info = TimeSliceInfo { + seasons: ["winter".into(), "summer".into()].into_iter().collect(), + times_of_day: ["day".into(), "night".into()].into_iter().collect(), + fractions: [(slices[0].clone(), 0.5), (slices[1].clone(), 0.5)] + .into_iter() + .collect(), + }; + let time_slice = time_slice_info + .get_time_slice_id_from_str("winter.day") + .unwrap(); + let milestone_years = [2010]; // Valid - let cost = CommodityCostRaw { + let cost1 = CommodityCostRaw { commodity_id: "commodity".into(), region_id: "GBR".into(), balance_type: BalanceType::Consumption, year: 2010, - time_slice: "".into(), - value: 5.0, + time_slice: "winter.day".into(), + value: 0.5, + }; + let cost2 = CommodityCostRaw { + commodity_id: "commodity".into(), + region_id: "FRA".into(), + balance_type: BalanceType::Production, + year: 2010, + time_slice: "winter.day".into(), + value: 0.5, + }; + let key1 = CommodityCostKey { + region_id: "GBR".into(), + year: cost1.year, + time_slice: time_slice.clone(), + }; + let value1 = CommodityCost { + balance_type: cost1.balance_type.clone(), + value: cost1.value, + }; + let key2 = CommodityCostKey { + region_id: "FRA".into(), + year: cost2.year, + time_slice: time_slice.clone(), }; - assert!(cost - .try_into_commodity_cost( + let value2 = CommodityCost { + balance_type: cost2.balance_type.clone(), + value: cost2.value, + }; + let map = CommodityCostMap(HashMap::from_iter([(key1, value1), (key2, value2)])); + let expected = HashMap::from_iter([("commodity".into(), map)]); + assert_eq!( + read_commodity_costs_iter( + [cost1.clone(), cost2].into_iter(), &commodity_ids, ®ion_ids, &time_slice_info, - &milestone_years + &milestone_years, ) - .is_ok()); + .unwrap(), + expected + ); + + // Invalid: Overlapping time slices + let cost2 = CommodityCostRaw { + commodity_id: "commodity".into(), + region_id: "GBR".into(), + balance_type: BalanceType::Production, + year: 2010, + time_slice: "winter".into(), // NB: Covers all winter + value: 0.5, + }; + assert!(read_commodity_costs_iter( + [cost1.clone(), cost2].into_iter(), + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years, + ) + .is_err()); - // Bad commodity + // Invalid: Bad commodity let cost = CommodityCostRaw { commodity_id: "commodity2".into(), region_id: "GBR".into(), - balance_type: BalanceType::Consumption, + balance_type: BalanceType::Production, year: 2010, - time_slice: "".into(), - value: 5.0, + time_slice: "winter.day".into(), + value: 0.5, }; - assert!(cost - .try_into_commodity_cost( - &commodity_ids, - ®ion_ids, - &time_slice_info, - &milestone_years - ) - .is_err()); + assert!(read_commodity_costs_iter( + iter::once(cost), + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years, + ) + .is_err()); - // Bad region + // Invalid: Bad region let cost = CommodityCostRaw { commodity_id: "commodity".into(), region_id: "USA".into(), - balance_type: BalanceType::Consumption, + balance_type: BalanceType::Production, year: 2010, - time_slice: "".into(), - value: 5.0, + time_slice: "winter.day".into(), + value: 0.5, }; - assert!(cost - .try_into_commodity_cost( - &commodity_ids, - ®ion_ids, - &time_slice_info, - &milestone_years - ) - .is_err()); + assert!(read_commodity_costs_iter( + iter::once(cost), + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years, + ) + .is_err()); - // Bad time slice selection + // Invalid: Bad time slice selection let cost = CommodityCostRaw { commodity_id: "commodity".into(), region_id: "GBR".into(), - balance_type: BalanceType::Consumption, + balance_type: BalanceType::Production, year: 2010, - time_slice: "spring".into(), - value: 5.0, + time_slice: "summer.evening".into(), + value: 0.5, }; - assert!(cost - .try_into_commodity_cost( - &commodity_ids, - ®ion_ids, - &time_slice_info, - &milestone_years - ) - .is_err()); + assert!(read_commodity_costs_iter( + iter::once(cost), + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years, + ) + .is_err()); } } diff --git a/src/process.rs b/src/process.rs index 1a573f01b..065fdcece 100644 --- a/src/process.rs +++ b/src/process.rs @@ -438,7 +438,7 @@ pub fn read_processes( #[cfg(test)] mod tests { - use crate::commodity::CommodityType; + use crate::commodity::{CommodityCostMap, CommodityType}; use crate::time_slice::TimeSliceLevel; use super::*; @@ -733,7 +733,7 @@ mod tests { description: "Some description".into(), kind: CommodityType::InputCommodity, time_slice_level: TimeSliceLevel::Annual, - costs: vec![], + costs: CommodityCostMap::new(), demand_by_region: HashMap::new(), }; From 7559188be714271d70491d744aee11c39904a20f Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Fri, 22 Nov 2024 17:17:22 +0000 Subject: [PATCH 05/11] Add test that non-milestone years fail --- src/commodity.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/commodity.rs b/src/commodity.rs index dc877aaab..225366637 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -433,4 +433,29 @@ mod tests { ) .is_err()); } + + #[test] + #[should_panic] + fn test_read_commodity_costs_iter_non_milestone_year() { + let commodity_ids = ["commodity".into()].into_iter().collect(); + let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); + let time_slice_info = TimeSliceInfo::default(); + let milestone_years = [2010, 2020]; + + let cost = CommodityCostRaw { + commodity_id: "commodity".into(), + region_id: "GBR".into(), + balance_type: BalanceType::Consumption, + year: 2011, // NB: Non-milestone year + time_slice: "all-year.all-day".into(), + value: 0.5, + }; + let _ = read_commodity_costs_iter( + iter::once(cost), + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years, + ); + } } From 3c026d4e8742e2f4cec4519621abc87e72ad03de Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Fri, 22 Nov 2024 17:24:28 +0000 Subject: [PATCH 06/11] Require that all milestone years are covered for each commodity + region combo --- examples/simple/commodity_costs.csv | 1 + src/commodity.rs | 60 ++++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/examples/simple/commodity_costs.csv b/examples/simple/commodity_costs.csv index 5464f9cbd..b0a613ffc 100644 --- a/examples/simple/commodity_costs.csv +++ b/examples/simple/commodity_costs.csv @@ -1,2 +1,3 @@ commodity_id,region_id,balance_type,year,time_slice,value CO2EMT,GBR,net,2020,annual,0.04 +CO2EMT,GBR,net,2100,annual,0.04 diff --git a/src/commodity.rs b/src/commodity.rs index 225366637..bb6f948a8 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -135,6 +135,11 @@ where { let mut map = HashMap::new(); + // Keep track of milestone years used for each commodity + region combo. If a user provides an + // entry with a given commodity + region combo for one milestone year, they must also provide + // entries for all the other milestone years. + let mut used_milestone_years = HashMap::new(); + for cost in iter { let commodity_id = commodity_ids.get_id(&cost.commodity_id)?; let region_id = region_ids.get_id(&cost.region_id)?; @@ -150,7 +155,7 @@ where // Get or create CommodityCostMap for this commodity let map = map - .entry(commodity_id) + .entry(commodity_id.clone()) .or_insert_with(|| CommodityCostMap(HashMap::with_capacity(1))); for time_slice in time_slice_info.iter_selection(&ts_selection) { @@ -173,6 +178,23 @@ where time_slice ); } + + // Keep track of milestone years used for each commodity + region combo + used_milestone_years + .entry((commodity_id, region_id)) + .or_insert_with(|| HashSet::with_capacity(1)) + .insert(cost.year); + } + + let milestone_years = HashSet::from_iter(milestone_years.iter().cloned()); + for ((commodity_id, region_id), years) in used_milestone_years.iter() { + if years != &milestone_years { + todo!( + "Commodity costs missing for some milestone years (commodity: {}, region: {})", + commodity_id, + region_id + ); + } } Ok(map) @@ -437,16 +459,50 @@ mod tests { #[test] #[should_panic] fn test_read_commodity_costs_iter_non_milestone_year() { + let commodity_ids = ["commodity".into()].into_iter().collect(); + let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); + let time_slice_info = TimeSliceInfo::default(); + let milestone_years = [2010]; + + let cost1 = CommodityCostRaw { + commodity_id: "commodity".into(), + region_id: "GBR".into(), + balance_type: BalanceType::Consumption, + year: 2010, + time_slice: "all-year.all-day".into(), + value: 0.5, + }; + let cost2 = CommodityCostRaw { + commodity_id: "commodity".into(), + region_id: "GBR".into(), + balance_type: BalanceType::Consumption, + year: 2011, // NB: Non-milestone year + time_slice: "all-year.all-day".into(), + value: 0.5, + }; + let _ = read_commodity_costs_iter( + [cost1, cost2].into_iter(), + &commodity_ids, + ®ion_ids, + &time_slice_info, + &milestone_years, + ); + } + + #[test] + #[should_panic] + fn test_read_commodity_costs_iter_missing_milestone_year() { let commodity_ids = ["commodity".into()].into_iter().collect(); let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); let time_slice_info = TimeSliceInfo::default(); let milestone_years = [2010, 2020]; + // NB: Milestone year 2020 is not covered let cost = CommodityCostRaw { commodity_id: "commodity".into(), region_id: "GBR".into(), balance_type: BalanceType::Consumption, - year: 2011, // NB: Non-milestone year + year: 2010, time_slice: "all-year.all-day".into(), value: 0.5, }; From 399f65312f1e83f3070376faeac30c9df05070c0 Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Mon, 25 Nov 2024 09:10:22 +0000 Subject: [PATCH 07/11] Replace use of `todo!` with regular errors I don't think panicking in these cases really adds anything. --- src/commodity.rs | 66 +++++++++++++++--------------------------------- 1 file changed, 21 insertions(+), 45 deletions(-) diff --git a/src/commodity.rs b/src/commodity.rs index bb6f948a8..a12a8e3c0 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -145,13 +145,12 @@ where let region_id = region_ids.get_id(&cost.region_id)?; let ts_selection = time_slice_info.get_selection(&cost.time_slice)?; - if milestone_years.binary_search(&cost.year).is_err() { - todo!( - "Year {} is not a milestone year. \ + ensure!( + milestone_years.binary_search(&cost.year).is_ok(), + "Year {} is not a milestone year. \ Input of non-milestone years is currently not supported.", - cost.year - ); - } + cost.year + ); // Get or create CommodityCostMap for this commodity let map = map @@ -188,13 +187,12 @@ where let milestone_years = HashSet::from_iter(milestone_years.iter().cloned()); for ((commodity_id, region_id), years) in used_milestone_years.iter() { - if years != &milestone_years { - todo!( - "Commodity costs missing for some milestone years (commodity: {}, region: {})", - commodity_id, - region_id - ); - } + ensure!( + years == &milestone_years, + "Commodity costs missing for some milestone years (commodity: {}, region: {})", + commodity_id, + region_id + ); } Ok(map) @@ -454,64 +452,42 @@ mod tests { &milestone_years, ) .is_err()); - } - #[test] - #[should_panic] - fn test_read_commodity_costs_iter_non_milestone_year() { - let commodity_ids = ["commodity".into()].into_iter().collect(); - let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); - let time_slice_info = TimeSliceInfo::default(); - let milestone_years = [2010]; - - let cost1 = CommodityCostRaw { - commodity_id: "commodity".into(), - region_id: "GBR".into(), - balance_type: BalanceType::Consumption, - year: 2010, - time_slice: "all-year.all-day".into(), - value: 0.5, - }; + // Invalid: non-milestone year let cost2 = CommodityCostRaw { commodity_id: "commodity".into(), region_id: "GBR".into(), balance_type: BalanceType::Consumption, year: 2011, // NB: Non-milestone year - time_slice: "all-year.all-day".into(), + time_slice: "winter.day".into(), value: 0.5, }; - let _ = read_commodity_costs_iter( + assert!(read_commodity_costs_iter( [cost1, cost2].into_iter(), &commodity_ids, ®ion_ids, &time_slice_info, &milestone_years, - ); - } + ) + .is_err()); - #[test] - #[should_panic] - fn test_read_commodity_costs_iter_missing_milestone_year() { - let commodity_ids = ["commodity".into()].into_iter().collect(); - let region_ids = ["GBR".into(), "FRA".into()].into_iter().collect(); - let time_slice_info = TimeSliceInfo::default(); + // Invalid: Milestone year 2020 is not covered let milestone_years = [2010, 2020]; - - // NB: Milestone year 2020 is not covered let cost = CommodityCostRaw { commodity_id: "commodity".into(), region_id: "GBR".into(), balance_type: BalanceType::Consumption, year: 2010, - time_slice: "all-year.all-day".into(), + time_slice: "winter.day".into(), value: 0.5, }; - let _ = read_commodity_costs_iter( + assert!(read_commodity_costs_iter( iter::once(cost), &commodity_ids, ®ion_ids, &time_slice_info, &milestone_years, - ); + ) + .is_err()); } } From 2ae199d775a2034d493d7244ca559740bf817acd Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Tue, 26 Nov 2024 15:14:47 +0000 Subject: [PATCH 08/11] Move `CommodityCostMap`'s `impl` block to be beside its definition --- src/commodity.rs | 53 +++++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 30 deletions(-) diff --git a/src/commodity.rs b/src/commodity.rs index a12a8e3c0..542fc4cd3 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -32,35 +32,6 @@ pub struct Commodity { } define_id_getter! {Commodity} -impl CommodityCostMap { - /// Create a new, empty [`CommodityCostMap`] - pub fn new() -> Self { - Self(HashMap::new()) - } - - /// Retrieve a [`CommodityCost`] from the map - pub fn get( - &self, - region_id: Rc, - year: u32, - time_slice: TimeSliceID, - ) -> Option<&CommodityCost> { - let key = CommodityCostKey { - region_id, - year, - time_slice, - }; - self.0.get(&key) - } -} - -impl Default for CommodityCostMap { - /// Create a new, empty [`CommodityCostMap`] - fn default() -> Self { - Self::new() - } -} - /// Type of balance for application of cost #[derive(PartialEq, Clone, Debug, DeserializeLabeledStringEnum)] pub enum BalanceType { @@ -107,9 +78,31 @@ struct CommodityCostKey { } /// A data structure for easy lookup of [`CommodityCost`]s -#[derive(PartialEq, Debug)] +#[derive(PartialEq, Debug, Default)] pub struct CommodityCostMap(HashMap); +impl CommodityCostMap { + /// Create a new, empty [`CommodityCostMap`] + pub fn new() -> Self { + Self(HashMap::new()) + } + + /// Retrieve a [`CommodityCost`] from the map + pub fn get( + &self, + region_id: Rc, + year: u32, + time_slice: TimeSliceID, + ) -> Option<&CommodityCost> { + let key = CommodityCostKey { + region_id, + year, + time_slice, + }; + self.0.get(&key) + } +} + /// Commodity balance type #[derive(PartialEq, Debug, DeserializeLabeledStringEnum)] pub enum CommodityType { From f2939b2914ef1c3e75933edc0694f1cab477a76d Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Tue, 26 Nov 2024 11:28:40 +0000 Subject: [PATCH 09/11] Add method to iterate over all time slices + return by reference --- src/time_slice.rs | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/time_slice.rs b/src/time_slice.rs index 967dafae2..d4d08a5bc 100644 --- a/src/time_slice.rs +++ b/src/time_slice.rs @@ -110,20 +110,28 @@ impl TimeSliceInfo { } } - /// Iterate over the subset of [`TimeSliceID`] indicated by `selection` + /// Iterate over all [`TimeSliceID`]s. + /// + /// The order will be consistent each time this is called, but not every time the program is + /// run. + pub fn iter(&self) -> impl Iterator { + self.fractions.keys() + } + + /// Iterate over the subset of [`TimeSliceID`] indicated by `selection`. + /// + /// The order will be consistent each time this is called, but not every time the program is + /// run. pub fn iter_selection<'a>( &'a self, selection: &'a TimeSliceSelection, - ) -> Box + 'a> { + ) -> Box + 'a> { match selection { - TimeSliceSelection::Annual => Box::new(self.fractions.keys().cloned()), - TimeSliceSelection::Season(season) => Box::new( - self.fractions - .keys() - .filter(move |ts| &ts.season == season) - .cloned(), - ), - TimeSliceSelection::Single(ts) => Box::new(iter::once(ts.clone())), + TimeSliceSelection::Annual => Box::new(self.iter()), + TimeSliceSelection::Season(season) => { + Box::new(self.iter().filter(move |ts| ts.season == *season)) + } + TimeSliceSelection::Single(ts) => Box::new(iter::once(ts)), } } } @@ -337,17 +345,17 @@ autumn,evening,0.25" }; assert_eq!( - HashSet::::from_iter(ts_info.iter_selection(&TimeSliceSelection::Annual)), - HashSet::from_iter(slices.iter().cloned()) + HashSet::<&TimeSliceID>::from_iter(ts_info.iter_selection(&TimeSliceSelection::Annual)), + HashSet::from_iter(slices.iter()) ); itertools::assert_equal( ts_info.iter_selection(&TimeSliceSelection::Season("winter".into())), - iter::once(slices[0].clone()), + iter::once(&slices[0]), ); let ts = ts_info.get_time_slice_id_from_str("summer.night").unwrap(); itertools::assert_equal( ts_info.iter_selection(&TimeSliceSelection::Single(ts)), - iter::once(slices[1].clone()), + iter::once(&slices[1]), ); } From 812f28e6c3f848c8b72692fe53a7200f197dfe43 Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Thu, 28 Nov 2024 15:22:07 +0000 Subject: [PATCH 10/11] Return `Result` from `read_commodity_costs` too This makes it consistent with the other error handling in this file. --- src/commodity.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/commodity.rs b/src/commodity.rs index 44c080fbc..a18a9dbc0 100644 --- a/src/commodity.rs +++ b/src/commodity.rs @@ -2,7 +2,7 @@ use crate::demand::{read_demand, Demand}; use crate::input::*; use crate::time_slice::{TimeSliceID, TimeSliceInfo, TimeSliceLevel}; -use anyhow::{ensure, Result}; +use anyhow::{ensure, Context, Result}; use serde::Deserialize; use serde_string_enum::DeserializeLabeledStringEnum; use std::collections::{HashMap, HashSet}; @@ -210,7 +210,7 @@ fn read_commodity_costs( region_ids: &HashSet>, time_slice_info: &TimeSliceInfo, milestone_years: &[u32], -) -> HashMap, CommodityCostMap> { +) -> Result, CommodityCostMap>> { let file_path = model_dir.join(COMMODITY_COSTS_FILE_NAME); read_commodity_costs_iter( read_csv::(&file_path), @@ -219,7 +219,7 @@ fn read_commodity_costs( time_slice_info, milestone_years, ) - .unwrap_input_err(&file_path) + .context("Error reading commodity costs") } /// Read commodity data from the specified model directory. @@ -248,7 +248,7 @@ pub fn read_commodities( region_ids, time_slice_info, milestone_years, - ); + )?; let year_range = *milestone_years.first().unwrap()..=*milestone_years.last().unwrap(); let mut demand = read_demand( From f6de9226d6b15d088b85661db5fb500346a3472c Mon Sep 17 00:00:00 2001 From: Alex Dewar Date: Thu, 28 Nov 2024 16:15:12 +0000 Subject: [PATCH 11/11] Fix warning about elided lifetime --- src/time_slice.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/time_slice.rs b/src/time_slice.rs index 8ba87bea5..03a30ba33 100644 --- a/src/time_slice.rs +++ b/src/time_slice.rs @@ -127,7 +127,7 @@ impl TimeSliceInfo { pub fn iter_selection<'a>( &'a self, selection: &'a TimeSliceSelection, - ) -> Box + 'a> { + ) -> Box + 'a> { match selection { TimeSliceSelection::Annual => Box::new(self.iter()), TimeSliceSelection::Season(season) => {