From b45dc179b45f9b1cf9c0ebff2f0732773f6805a1 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Wed, 27 Nov 2024 11:23:01 +0000 Subject: [PATCH 01/16] Enforce that all objectives have an asset dimension --- src/muse/objectives.py | 21 +++++++++++++++------ tests/test_objectives.py | 30 +++++++++++++++++------------- 2 files changed, 32 insertions(+), 19 deletions(-) diff --git a/src/muse/objectives.py b/src/muse/objectives.py index a12530326..519a92720 100644 --- a/src/muse/objectives.py +++ b/src/muse/objectives.py @@ -42,9 +42,8 @@ def comfort( these parameters. Returns: - A DataArray with at least one dimension corresponding to ``replacement``. - Other dimensions can be present, as long as the subsequent decision function knows - how to reduce them. + A DataArray with at least two dimension corresponding to `replacement` and `asset`. + A `timeslice` dimension may also be present. """ __all__ = [ @@ -180,6 +179,8 @@ def decorated_objective(technologies: xr.Dataset, *args, **kwargs) -> xr.DataArr if "replacement" not in result.dims: raise RuntimeError("Objective should return a dimension 'replacement'") + if "asset" not in result.dims: + raise RuntimeError("Objective should return a dimension 'asset'") if "technology" in result.dims: raise RuntimeError("Objective should not return a dimension 'technology'") if "technology" in result.coords: @@ -196,21 +197,25 @@ def decorated_objective(technologies: xr.Dataset, *args, **kwargs) -> xr.DataArr @register_objective def comfort( technologies: xr.Dataset, + demand: xr.DataArray, *args, **kwargs, ) -> xr.DataArray: """Comfort value provided by technologies.""" - return technologies.comfort + result = xr.broadcast(technologies.comfort, demand.asset)[0] + return result @register_objective def efficiency( technologies: xr.Dataset, + demand: xr.DataArray, *args, **kwargs, ) -> xr.DataArray: """Efficiency of the technologies.""" - return technologies.efficiency + result = xr.broadcast(technologies.efficiency, demand.asset)[0] + return result @register_objective(name="capacity") @@ -292,6 +297,7 @@ def fixed_costs( @register_objective def capital_costs( technologies: xr.Dataset, + demand: xr.Dataset, *args, **kwargs, ) -> xr.DataArray: @@ -303,6 +309,7 @@ def capital_costs( simulation for each technology. """ result = technologies.cap_par * (technologies.scaling_size**technologies.cap_exp) + result = xr.broadcast(result, demand.asset)[0] return result @@ -373,10 +380,12 @@ def annual_levelized_cost_of_energy( """ from muse.costs import annual_levelized_cost_of_energy as aLCOE - return filter_input( + result = filter_input( aLCOE(technologies=technologies, prices=prices).max("timeslice"), year=demand.year.item(), ) + result = xr.broadcast(result, demand.asset)[0] + return result @register_objective(name=["LCOE", "LLCOE"]) diff --git a/tests/test_objectives.py b/tests/test_objectives.py index a8030634d..593877f4d 100644 --- a/tests/test_objectives.py +++ b/tests/test_objectives.py @@ -60,11 +60,15 @@ def test_computing_objectives(_technologies, _demand, _prices): from muse.objectives import factory, register_objective @register_objective - def first(technologies, switch=True, *args, **kwargs): - from xarray import full_like + def first(technologies, demand, switch=True, *args, **kwargs): + from xarray import broadcast, full_like value = 1 if switch else 2 - result = full_like(technologies["replacement"], value, dtype=float) + result = full_like( + broadcast(technologies["replacement"], demand["asset"])[0], + value, + dtype=float, + ) return result @register_objective @@ -104,20 +108,20 @@ def second(technologies, demand, assets=None, *args, **kwargs): assert (objectives.second.isel(asset=1) == 5).all() -def test_comfort(_technologies): +def test_comfort(_technologies, _demand): from muse.objectives import comfort _technologies["comfort"] = add_var(_technologies, "replacement") - result = comfort(_technologies) - assert set(result.dims) == {"replacement"} + result = comfort(_technologies, _demand) + assert set(result.dims) == {"replacement", "asset"} -def test_efficiency(_technologies): +def test_efficiency(_technologies, _demand): from muse.objectives import efficiency _technologies["efficiency"] = add_var(_technologies, "replacement") - result = efficiency(_technologies) - assert set(result.dims) == {"replacement"} + result = efficiency(_technologies, _demand) + assert set(result.dims) == {"replacement", "asset"} def test_capacity_to_service_demand(_technologies, _demand): @@ -148,12 +152,12 @@ def test_fixed_costs(_technologies, _demand): assert set(result.dims) == {"replacement", "asset"} -def test_capital_costs(_technologies): +def test_capital_costs(_technologies, _demand): from muse.objectives import capital_costs _technologies["scaling_size"] = add_var(_technologies, "replacement") - result = capital_costs(_technologies) - assert set(result.dims) == {"replacement"} + result = capital_costs(_technologies, _demand) + assert set(result.dims) == {"replacement", "asset"} def test_emission_cost(_technologies, _demand, _prices): @@ -174,7 +178,7 @@ def test_annual_levelized_cost_of_energy(_technologies, _demand, _prices): from muse.objectives import annual_levelized_cost_of_energy result = annual_levelized_cost_of_energy(_technologies, _demand, _prices) - assert set(result.dims) == {"replacement"} + assert set(result.dims) == {"replacement", "asset"} def test_lifetime_levelized_cost_of_energy(_technologies, _demand, _prices): From 2bc4e29815ff974a5bc16fd6a1b7b77b4ecaf96e Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Wed, 27 Nov 2024 11:59:28 +0000 Subject: [PATCH 02/16] Check inputs to objectives --- src/muse/objectives.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/muse/objectives.py b/src/muse/objectives.py index 519a92720..071f39670 100644 --- a/src/muse/objectives.py +++ b/src/muse/objectives.py @@ -167,11 +167,23 @@ def register_objective(function: OBJECTIVE_SIGNATURE): from functools import wraps @wraps(function) - def decorated_objective(technologies: xr.Dataset, *args, **kwargs) -> xr.DataArray: + def decorated_objective( + technologies: xr.Dataset, demand: xr.DataArray, *args, **kwargs + ) -> xr.DataArray: from logging import getLogger - result = function(technologies, *args, **kwargs) + # Check inputs + assert set(demand.dims) == {"asset", "timeslice", "commodity"} + technologies_dims = set(technologies.dims) + assert {"replacement", "commodity"}.issubset( + technologies_dims + ) and technologies_dims <= {"replacement", "commodity", "timeslice"} + + # Calculate objective + result = function(technologies, demand, *args, **kwargs) + result.name = function.__name__ + # Check result dtype = result.values.dtype if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)): msg = f"dtype of objective {function.__name__} is not a number ({dtype})" @@ -187,7 +199,7 @@ def decorated_objective(technologies: xr.Dataset, *args, **kwargs) -> xr.DataArr raise RuntimeError("Objective should not return a coordinate 'technology'") if "year" in result.dims: raise RuntimeError("Objective should not return a dimension 'year'") - result.name = function.__name__ + cache_quantity(**{result.name: result}) return result From aad93730aadf009564dd966f52bfa4e4b1a1a0e2 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Wed, 27 Nov 2024 13:07:03 +0000 Subject: [PATCH 03/16] Temporarily suppress tests --- tests/test_subsector.py | 3 ++- tests/test_trade.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_subsector.py b/tests/test_subsector.py index 3994ece31..f32a2fc09 100644 --- a/tests/test_subsector.py +++ b/tests/test_subsector.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import xarray as xr -from pytest import fixture, raises +from pytest import fixture, mark, raises @fixture @@ -53,6 +53,7 @@ def test_subsector_investing_aggregation(): assert initial.assets.sum() != final.assets.sum() +@mark.xfail # temporary def test_subsector_noninvesting_aggregation(market, model, technologies, tmp_path): """Create some default agents and run subsector. diff --git a/tests/test_trade.py b/tests/test_trade.py index 2398b47d9..be3b14168 100644 --- a/tests/test_trade.py +++ b/tests/test_trade.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr -from pytest import approx, fixture +from pytest import approx, fixture, mark @fixture @@ -136,6 +136,7 @@ def test_power_sector_no_investment(): assert (initial == final).all() +@mark.xfail # temporary def test_power_sector_some_investment(): from muse import examples from muse.utilities import agent_concatenation From 0cef993c971870d9f5450fe2c9da0424439a9c8c Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Wed, 27 Nov 2024 13:08:59 +0000 Subject: [PATCH 04/16] Bump version From e4246e2b4dd22e834b4912c0ef6fd71e8148a231 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:12:20 +0000 Subject: [PATCH 05/16] [pre-commit.ci] pre-commit autoupdate (#579) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [pre-commit.ci] pre-commit autoupdate updates: - [github.com/astral-sh/ruff-pre-commit: v0.7.4 → v0.8.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.7.4...v0.8.0) - [github.com/igorshubovych/markdownlint-cli: v0.42.0 → v0.43.0](https://github.com/igorshubovych/markdownlint-cli/compare/v0.42.0...v0.43.0) * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> From 058d24c5749b1f4ef8acebd6af9068bd3805cf08 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 13:21:28 +0000 Subject: [PATCH 06/16] Revert "Bump version" This reverts commit 3d223e640d3f5bdaaa62b19bf16d3bafe3df7eb6. From 1ebabd40f256de25787fa1916b61ae00ca92f2bf Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 14:30:49 +0000 Subject: [PATCH 07/16] Add check_dimensions function --- src/muse/objectives.py | 23 ++++--------- src/muse/utilities.py | 76 ++++++++++++++++++++++++++--------------- tests/test_utilities.py | 23 ++++++++++++- 3 files changed, 77 insertions(+), 45 deletions(-) diff --git a/src/muse/objectives.py b/src/muse/objectives.py index 071f39670..eda460137 100644 --- a/src/muse/objectives.py +++ b/src/muse/objectives.py @@ -71,7 +71,7 @@ def comfort( from muse.outputs.cache import cache_quantity from muse.registration import registrator from muse.timeslices import broadcast_timeslice, distribute_timeslice, drop_timeslice -from muse.utilities import filter_input +from muse.utilities import check_dimensions, filter_input OBJECTIVE_SIGNATURE = Callable[ [xr.Dataset, xr.DataArray, xr.DataArray, KwArg(Any)], xr.DataArray @@ -173,11 +173,10 @@ def decorated_objective( from logging import getLogger # Check inputs - assert set(demand.dims) == {"asset", "timeslice", "commodity"} - technologies_dims = set(technologies.dims) - assert {"replacement", "commodity"}.issubset( - technologies_dims - ) and technologies_dims <= {"replacement", "commodity", "timeslice"} + check_dimensions(demand, ["asset", "timeslice", "commodity"]) + check_dimensions( + technologies, ["replacement", "commodity"], optional=["timeslice"] + ) # Calculate objective result = function(technologies, demand, *args, **kwargs) @@ -188,17 +187,7 @@ def decorated_objective( if not (np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)): msg = f"dtype of objective {function.__name__} is not a number ({dtype})" getLogger(function.__module__).warning(msg) - - if "replacement" not in result.dims: - raise RuntimeError("Objective should return a dimension 'replacement'") - if "asset" not in result.dims: - raise RuntimeError("Objective should return a dimension 'asset'") - if "technology" in result.dims: - raise RuntimeError("Objective should not return a dimension 'technology'") - if "technology" in result.coords: - raise RuntimeError("Objective should not return a coordinate 'technology'") - if "year" in result.dims: - raise RuntimeError("Objective should not return a dimension 'year'") + check_dimensions(result, ["replacement", "asset"], optional=["timeslice"]) cache_quantity(**{result.name: result}) return result diff --git a/src/muse/utilities.py b/src/muse/utilities.py index 78ef36b7d..4c9f32d13 100644 --- a/src/muse/utilities.py +++ b/src/muse/utilities.py @@ -1,12 +1,12 @@ """Collection of functions and stand-alone algorithms.""" +from __future__ import annotations + from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence from typing import ( Any, Callable, NamedTuple, - Optional, - Union, cast, ) @@ -14,9 +14,7 @@ import xarray as xr -def multiindex_to_coords( - data: Union[xr.Dataset, xr.DataArray], dimension: str = "asset" -): +def multiindex_to_coords(data: xr.Dataset | xr.DataArray, dimension: str = "asset"): """Flattens multi-index dimension into multi-coord dimension.""" from pandas import MultiIndex @@ -33,8 +31,8 @@ def multiindex_to_coords( def coords_to_multiindex( - data: Union[xr.Dataset, xr.DataArray], dimension: str = "asset" -) -> Union[xr.Dataset, xr.DataArray]: + data: xr.Dataset | xr.DataArray, dimension: str = "asset" +) -> xr.Dataset | xr.DataArray: """Creates a multi-index from flattened multiple coords.""" from pandas import MultiIndex @@ -47,11 +45,11 @@ def coords_to_multiindex( def reduce_assets( - assets: Union[xr.DataArray, xr.Dataset, Sequence[Union[xr.Dataset, xr.DataArray]]], - coords: Optional[Union[str, Sequence[str], Iterable[str]]] = None, + assets: xr.DataArray | xr.Dataset | Sequence[xr.Dataset | xr.DataArray], + coords: str | Sequence[str] | Iterable[str] | None = None, dim: str = "asset", - operation: Optional[Callable] = None, -) -> Union[xr.DataArray, xr.Dataset]: + operation: Callable | None = None, +) -> xr.DataArray | xr.Dataset: r"""Combine assets along given asset dimension. This method simplifies combining assets across multiple agents, or combining assets @@ -182,13 +180,13 @@ def operation(x): def broadcast_techs( - technologies: Union[xr.Dataset, xr.DataArray], - template: Union[xr.DataArray, xr.Dataset], + technologies: xr.Dataset | xr.DataArray, + template: xr.DataArray | xr.Dataset, dimension: str = "asset", interpolation: str = "linear", installed_as_year: bool = True, **kwargs, -) -> Union[xr.Dataset, xr.DataArray]: +) -> xr.Dataset | xr.DataArray: """Broadcasts technologies to the shape of template in given dimension. The dimensions of the technologies are fully explicit, in that each concept @@ -246,7 +244,7 @@ def broadcast_techs( return techs.sel(second_sel) -def clean_assets(assets: xr.Dataset, years: Union[int, Sequence[int]]): +def clean_assets(assets: xr.Dataset, years: int | Sequence[int]): """Cleans up and prepares asset for current iteration. - adds current and forecast year by backfilling missing entries @@ -265,11 +263,11 @@ def clean_assets(assets: xr.Dataset, years: Union[int, Sequence[int]]): def filter_input( - dataset: Union[xr.Dataset, xr.DataArray], - year: Optional[Union[int, Iterable[int]]] = None, + dataset: xr.Dataset | xr.DataArray, + year: int | Iterable[int] | None = None, interpolation: str = "linear", **kwargs, -) -> Union[xr.Dataset, xr.DataArray]: +) -> xr.Dataset | xr.DataArray: """Filter inputs, taking care to interpolate years.""" if year is None: setyear: set[int] = set() @@ -300,8 +298,8 @@ def filter_input( def filter_with_template( - data: Union[xr.Dataset, xr.DataArray], - template: Union[xr.DataArray, xr.Dataset], + data: xr.Dataset | xr.DataArray, + template: xr.DataArray | xr.Dataset, asset_dimension: str = "asset", **kwargs, ): @@ -350,7 +348,7 @@ def tupled_dimension(array: np.ndarray, axis: int): def lexical_comparison( objectives: xr.Dataset, binsize: xr.Dataset, - order: Optional[Sequence[Hashable]] = None, + order: Sequence[Hashable] | None = None, bin_last: bool = True, ) -> xr.DataArray: """Lexical comparison over the objectives. @@ -438,7 +436,7 @@ def avoid_repetitions(data: xr.DataArray, dim: str = "year") -> xr.DataArray: return data.year[years] -def nametuple_to_dict(nametup: Union[Mapping, NamedTuple]) -> Mapping: +def nametuple_to_dict(nametup: Mapping | NamedTuple) -> Mapping: """Transforms a nametuple of type GenericDict into an OrderDict.""" from collections import OrderedDict from dataclasses import asdict, is_dataclass @@ -537,11 +535,11 @@ def future_propagation( def agent_concatenation( - data: Mapping[Hashable, Union[xr.DataArray, xr.Dataset]], + data: Mapping[Hashable, xr.DataArray | xr.Dataset], dim: str = "asset", name: str = "agent", fill_value: Any = 0, -) -> Union[xr.DataArray, xr.Dataset]: +) -> xr.DataArray | xr.Dataset: """Concatenates input map along given dimension. Example: @@ -613,10 +611,10 @@ def agent_concatenation( def aggregate_technology_model( - data: Union[xr.DataArray, xr.Dataset], + data: xr.DataArray | xr.Dataset, dim: str = "asset", - drop: Union[str, Sequence[str]] = "installed", -) -> Union[xr.DataArray, xr.Dataset]: + drop: str | Sequence[str] = "installed", +) -> xr.DataArray | xr.Dataset: """Aggregate together assets with the same installation year. The assets of a given agent, region, and technology but different installation year @@ -659,3 +657,27 @@ def aggregate_technology_model( data, [cast(str, u) for u in data.coords if u not in drop and data[u].dims == (dim,)], ) + + +def check_dimensions( + data: xr.DataArray | xr.Dataset, + required: list[str] = [], + optional: list[str] = [], +): + """Check that an array has the required dimensions. + + This will check that all required dimensions are present, and that no other + dimensions are present, apart from those listed as optional. + + Args: + data: DataArray or Dataset to check dimensions of + required: List of dimension names that must be present + optional: List of dimension names that may be present + """ + present = set(data.dims) + missing = set(required) - present + extra = present - set(required + optional) + if missing: + raise ValueError(f"Missing required dimensions: {missing}") + if extra: + raise ValueError(f"Extra dimensions: {extra}") diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 87c8586b8..f05165e8e 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,6 +1,6 @@ import numpy as np import xarray as xr -from pytest import approx, mark +from pytest import approx, mark, raises def make_array(array): @@ -296,3 +296,24 @@ def test_avoid_repetitions(): assert 3 * len(result.year) == 2 * len(assets.year) original = result.interp(year=assets.year, method="linear") assert (original == assets).all() + + +def test_check_dimensions(): + from muse.utilities import check_dimensions + + data = xr.DataArray( + np.random.rand(4, 5), + dims=["dim1", "dim2"], + coords={"dim1": range(4), "dim2": range(5)}, + ) + + # Valid + check_dimensions(data, required=["dim1"], optional=["dim2"]) + + # Missing required + with raises(ValueError): + check_dimensions(data, required=["dim1", "dim3"], optional=["dim2"]) + + # Extra dimension + with raises(ValueError): + check_dimensions(data, required=["dim1"]) From 129434c6f5d46ca57255723524c0af5f304d8d8b Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 16:19:47 +0000 Subject: [PATCH 08/16] Add checks for demand_share dimensions --- src/muse/demand_share.py | 17 ++++++++++++++++- src/muse/objectives.py | 4 +++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/muse/demand_share.py b/src/muse/demand_share.py index db03a2a25..8c60a0301 100644 --- a/src/muse/demand_share.py +++ b/src/muse/demand_share.py @@ -63,6 +63,7 @@ def demand_share( RetrofitAgentInStandardDemandShare, ) from muse.registration import registrator +from muse.utilities import check_dimensions DEMAND_SHARE_SIGNATURE = Callable[ [Sequence[AbstractAgent], xr.Dataset, xr.Dataset, KwArg(Any)], xr.DataArray @@ -102,7 +103,21 @@ def demand_share( keyword_args = copy(keywords) keyword_args.update(**kwargs) - return function(agents, market, technologies, **keyword_args) + + # Check inputs + check_dimensions(market, ["commodity", "year", "timeslice", "region"]) + check_dimensions( + technologies, + ["technology", "year", "region"], + optional=["timeslice", "commodity"], + ) + + # Calculate demand share + result = function(agents, market, technologies, **keyword_args) + + # Check result + check_dimensions(result, ["asset", "timeslice", "commodity"]) + return result return cast(DEMAND_SHARE_SIGNATURE, demand_share) diff --git a/src/muse/objectives.py b/src/muse/objectives.py index eda460137..08a705f87 100644 --- a/src/muse/objectives.py +++ b/src/muse/objectives.py @@ -173,7 +173,9 @@ def decorated_objective( from logging import getLogger # Check inputs - check_dimensions(demand, ["asset", "timeslice", "commodity"]) + check_dimensions( + demand, ["asset", "timeslice", "commodity"], optional=["region"] + ) check_dimensions( technologies, ["replacement", "commodity"], optional=["timeslice"] ) From 3665aa8857ad124ab22e651729fd7cf3178f4028 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 17:33:38 +0000 Subject: [PATCH 09/16] Add dst_region to check --- src/muse/demand_share.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/muse/demand_share.py b/src/muse/demand_share.py index 8c60a0301..a3a4e34d1 100644 --- a/src/muse/demand_share.py +++ b/src/muse/demand_share.py @@ -105,7 +105,11 @@ def demand_share( keyword_args.update(**kwargs) # Check inputs - check_dimensions(market, ["commodity", "year", "timeslice", "region"]) + check_dimensions( + market, + ["commodity", "year", "timeslice", "region"], + optional=["dst_region"], + ) check_dimensions( technologies, ["technology", "year", "region"], From 19c5466ef5e96c2844525c924907c64f594a0384 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 17:41:11 +0000 Subject: [PATCH 10/16] Fix more checks --- src/muse/demand_share.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/muse/demand_share.py b/src/muse/demand_share.py index a3a4e34d1..82dff7e31 100644 --- a/src/muse/demand_share.py +++ b/src/muse/demand_share.py @@ -113,7 +113,7 @@ def demand_share( check_dimensions( technologies, ["technology", "year", "region"], - optional=["timeslice", "commodity"], + optional=["timeslice", "commodity", "dst_region"], ) # Calculate demand share From 4c51ddaaa52223baffcba2d79a672bc7c141e357 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 17:46:25 +0000 Subject: [PATCH 11/16] And more --- src/muse/demand_share.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/muse/demand_share.py b/src/muse/demand_share.py index 82dff7e31..daad7bb03 100644 --- a/src/muse/demand_share.py +++ b/src/muse/demand_share.py @@ -120,7 +120,9 @@ def demand_share( result = function(agents, market, technologies, **keyword_args) # Check result - check_dimensions(result, ["asset", "timeslice", "commodity"]) + check_dimensions( + result, ["timeslice", "commodity"], optional=["asset", "region"] + ) # asset should be required return result return cast(DEMAND_SHARE_SIGNATURE, demand_share) From 9ee0e09435ad96f5f5e77e83936c9e9aa21e863e Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 17:50:18 +0000 Subject: [PATCH 12/16] Fix function default args --- src/muse/utilities.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/muse/utilities.py b/src/muse/utilities.py index 4c9f32d13..4ed8f8769 100644 --- a/src/muse/utilities.py +++ b/src/muse/utilities.py @@ -661,10 +661,10 @@ def aggregate_technology_model( def check_dimensions( data: xr.DataArray | xr.Dataset, - required: list[str] = [], - optional: list[str] = [], + required: list[str] | None = None, + optional: list[str] | None = None, ): - """Check that an array has the required dimensions. + """Ensure that an array has the required dimensions. This will check that all required dimensions are present, and that no other dimensions are present, apart from those listed as optional. @@ -674,6 +674,11 @@ def check_dimensions( required: List of dimension names that must be present optional: List of dimension names that may be present """ + if required is None: + required = [] + if optional is None: + optional = [] + present = set(data.dims) missing = set(required) - present extra = present - set(required + optional) From d0d22dbbbe4d81d2eca9626cb9faa8c7af0d6dbd Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 18:29:50 +0000 Subject: [PATCH 13/16] Remove xfail mark --- tests/test_subsector.py | 3 +-- tests/test_trade.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_subsector.py b/tests/test_subsector.py index f32a2fc09..3994ece31 100644 --- a/tests/test_subsector.py +++ b/tests/test_subsector.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import xarray as xr -from pytest import fixture, mark, raises +from pytest import fixture, raises @fixture @@ -53,7 +53,6 @@ def test_subsector_investing_aggregation(): assert initial.assets.sum() != final.assets.sum() -@mark.xfail # temporary def test_subsector_noninvesting_aggregation(market, model, technologies, tmp_path): """Create some default agents and run subsector. diff --git a/tests/test_trade.py b/tests/test_trade.py index be3b14168..2398b47d9 100644 --- a/tests/test_trade.py +++ b/tests/test_trade.py @@ -3,7 +3,7 @@ import numpy as np import xarray as xr -from pytest import approx, fixture, mark +from pytest import approx, fixture @fixture @@ -136,7 +136,6 @@ def test_power_sector_no_investment(): assert (initial == final).all() -@mark.xfail # temporary def test_power_sector_some_investment(): from muse import examples from muse.utilities import agent_concatenation From 9051ae21c64f12465a20af0e8892072918808eaf Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Thu, 28 Nov 2024 18:54:59 +0000 Subject: [PATCH 14/16] More descriptive comment --- src/muse/demand_share.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/muse/demand_share.py b/src/muse/demand_share.py index daad7bb03..9171a816f 100644 --- a/src/muse/demand_share.py +++ b/src/muse/demand_share.py @@ -122,7 +122,7 @@ def demand_share( # Check result check_dimensions( result, ["timeslice", "commodity"], optional=["asset", "region"] - ) # asset should be required + ) # TODO: asset should be required, but trade model is failing return result return cast(DEMAND_SHARE_SIGNATURE, demand_share) From 8bc99d202fc699ee57acd953e281ca2c362357f2 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 29 Nov 2024 12:01:06 +0000 Subject: [PATCH 15/16] Address reviewer comments --- src/muse/utilities.py | 9 ++------- tests/test_utilities.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/muse/utilities.py b/src/muse/utilities.py index 4ed8f8769..06f45a295 100644 --- a/src/muse/utilities.py +++ b/src/muse/utilities.py @@ -661,8 +661,8 @@ def aggregate_technology_model( def check_dimensions( data: xr.DataArray | xr.Dataset, - required: list[str] | None = None, - optional: list[str] | None = None, + required: Iterable[str] = (), + optional: Iterable[str] = (), ): """Ensure that an array has the required dimensions. @@ -674,11 +674,6 @@ def check_dimensions( required: List of dimension names that must be present optional: List of dimension names that may be present """ - if required is None: - required = [] - if optional is None: - optional = [] - present = set(data.dims) missing = set(required) - present extra = present - set(required + optional) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index f05165e8e..4dd15af4d 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -311,9 +311,9 @@ def test_check_dimensions(): check_dimensions(data, required=["dim1"], optional=["dim2"]) # Missing required - with raises(ValueError): + with raises(ValueError, match="Missing required dimensions"): check_dimensions(data, required=["dim1", "dim3"], optional=["dim2"]) # Extra dimension - with raises(ValueError): + with raises(ValueError, match="Extra dimensions"): check_dimensions(data, required=["dim1"]) From 66b25e5efaceadd184d7266af6c23d278ec18f49 Mon Sep 17 00:00:00 2001 From: Tom Bland Date: Fri, 29 Nov 2024 12:06:31 +0000 Subject: [PATCH 16/16] Allow function to work with any iterable of strings --- src/muse/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/muse/utilities.py b/src/muse/utilities.py index 06f45a295..9459196f9 100644 --- a/src/muse/utilities.py +++ b/src/muse/utilities.py @@ -676,8 +676,8 @@ def check_dimensions( """ present = set(data.dims) missing = set(required) - present - extra = present - set(required + optional) if missing: raise ValueError(f"Missing required dimensions: {missing}") + extra = present - set(required) - set(optional) if extra: raise ValueError(f"Extra dimensions: {extra}")