diff --git a/docs/api_reference.md b/docs/api_reference.md index 6ee0077..6ea2db1 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -12,10 +12,14 @@ hide: ::: diffwofost.physical_models.crop.phenology.DVS_Phenology +::: diffwofost.physical_models.crop.partitioning.DVS_Partitioning + ## **Utility (under development)** ::: diffwofost.physical_models.config.Configuration +::: diffwofost.physical_models.config.ComputeConfig + ::: diffwofost.physical_models.engine.Engine ::: diffwofost.physical_models.utils.EngineTestHelper diff --git a/src/diffwofost/physical_models/crop/partitioning.py b/src/diffwofost/physical_models/crop/partitioning.py new file mode 100644 index 0000000..52fbbc6 --- /dev/null +++ b/src/diffwofost/physical_models/crop/partitioning.py @@ -0,0 +1,356 @@ +from collections import namedtuple +from warnings import warn +import torch +from pcse import exceptions as exc +from pcse.base import ParamTemplate +from pcse.base import SimulationObject +from pcse.base import StatesTemplate +from pcse.decorators import prepare_states +from pcse.traitlets import Any +from pcse.traitlets import Instance +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_params_shape + + +# Template for namedtuple containing partitioning factors +class PartioningFactors(namedtuple("partitioning_factors", "FR FL FS FO")): + pass + + +def _first_tensor_item(x: torch.Tensor) -> float: + """Returns the first element of a tensor as a python float (for logging).""" + if not isinstance(x, torch.Tensor): + x = torch.as_tensor(x) + if x.dim() == 0: + return x.item() + return x.reshape(-1)[0].item() + + +class _BaseDVSPartitioning(SimulationObject): + """Shared implementation for DVS-based partitioning. + + This is intentionally private: it exists to avoid code duplication between + the public partitioning classes. + """ + + params_shape = None # Shape of the parameters tensors + + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + + class Parameters(ParamTemplate): + FRTB = AfgenTrait() + FLTB = AfgenTrait() + FSTB = AfgenTrait() + FOTB = AfgenTrait() + + def __init__(self, parvalues): + super().__init__(parvalues) + + class StateVariables(StatesTemplate): + FR = Any() + FL = Any() + FS = Any() + FO = Any() + PF = Instance(PartioningFactors) + + def __init__(self, kiosk, publish=None, **kwargs): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + if "FR" not in kwargs: + kwargs["FR"] = torch.tensor(-99.0, dtype=dtype, device=device) + if "FL" not in kwargs: + kwargs["FL"] = torch.tensor(-99.0, dtype=dtype, device=device) + if "FS" not in kwargs: + kwargs["FS"] = torch.tensor(-99.0, dtype=dtype, device=device) + if "FO" not in kwargs: + kwargs["FO"] = torch.tensor(-99.0, dtype=dtype, device=device) + + super().__init__(kiosk, publish=publish, **kwargs) + + def _handle_partitioning_error(self, msg: str) -> None: + """Hook for error handling (warn vs raise).""" + warn(msg) + + def _format_partitioning_error(self, checksum, FR, FL, FS, FO) -> str: + cs = _first_tensor_item(checksum) + fr = _first_tensor_item(FR) + fl = _first_tensor_item(FL) + fs = _first_tensor_item(FS) + fo = _first_tensor_item(FO) + msg = f"Error in partitioning!\nChecksum: {cs:.6f}, FR: {fr:.3f}, " + msg += f"FL: {fl:.3f}, FS: {fs:.3f}, FO: {fo:.3f}\n" + return msg + + def _check_partitioning(self): + """Check for partitioning errors.""" + FR = self.states.FR + FL = self.states.FL + FS = self.states.FS + FO = self.states.FO + checksum = FR + (FL + FS + FO) * (1.0 - FR) - 1.0 + if torch.any(torch.abs(checksum) >= 0.0001): + msg = self._format_partitioning_error(checksum, FR, FL, FS, FO) + self.logger.error(msg) + self._handle_partitioning_error(msg) + + def _broadcast_partitioning(self, FR, FL, FS, FO): + FR = _broadcast_to(FR, self.params_shape, dtype=self.dtype, device=self.device) + FL = _broadcast_to(FL, self.params_shape, dtype=self.dtype, device=self.device) + FS = _broadcast_to(FS, self.params_shape, dtype=self.dtype, device=self.device) + FO = _broadcast_to(FO, self.params_shape, dtype=self.dtype, device=self.device) + return FR, FL, FS, FO + + def _set_partitioning_states(self, FR, FL, FS, FO): + self.states.FR = FR + self.states.FL = FL + self.states.FS = FS + self.states.FO = FO + self.states.PF = PartioningFactors(FR, FL, FS, FO) + + def _compute_partitioning_from_tables(self, DVS): + p = self.params + FR = p.FRTB(DVS) + FL = p.FLTB(DVS) + FS = p.FSTB(DVS) + FO = p.FOTB(DVS) + return FR, FL, FS, FO + + def _initialize_from_tables(self, kiosk, parvalues): + self.params = self.Parameters(parvalues) + self.kiosk = kiosk + self.params_shape = _get_params_shape(self.params) + + DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) + FR, FL, FS, FO = self._compute_partitioning_from_tables(DVS) + FR, FL, FS, FO = self._broadcast_partitioning(FR, FL, FS, FO) + + self.states = self.StateVariables( + kiosk, + publish=["FR", "FL", "FS", "FO"], + FR=FR, + FL=FL, + FS=FS, + FO=FO, + PF=PartioningFactors(FR, FL, FS, FO), + ) + self._check_partitioning() + + def _update_from_tables(self): + DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) + FR, FL, FS, FO = self._compute_partitioning_from_tables(DVS) + FR, FL, FS, FO = self._broadcast_partitioning(FR, FL, FS, FO) + self._set_partitioning_states(FR, FL, FS, FO) + self._check_partitioning() + + +class DVS_Partitioning(_BaseDVSPartitioning): + """Class for assimilate partitioning based on development stage (DVS). + + `DVS_Partitioning` calculates the partitioning of the assimilates to roots, + stems, leaves and storage organs using fixed partitioning tables as a + function of crop development stage. The available assimilates are first + split into below-ground and aboveground using the values in FRTB. In a + second stage they are split into leaves (FLTB), stems (FSTB) and storage + organs (FOTB). + + Since the partitioning fractions are derived from the state variable DVS + they are regarded state variables as well. + + **Simulation parameters** (To be provided in cropdata dictionary): + + | Name | Description | Type | Unit | + |------|----------------------------------------------------------------|------|------| + | FRTB | Partitioning to roots as a function of development stage | TCr | - | + | FSTB | Partitioning to stems as a function of development stage | TCr | - | + | FLTB | Partitioning to leaves as a function of development stage | TCr | - | + | FOTB | Partitioning to storage organs as a function of development stage | TCr | - | + + **State variables** + + | Name | Description | Pbl | Unit | + |------|------------------------------------------|-----|------| + | FR | Fraction partitioned to roots | Y | - | + | FS | Fraction partitioned to stems | Y | - | + | FL | Fraction partitioned to leaves | Y | - | + | FO | Fraction partitioned to storage organs | Y | - | + | PF | Partitioning factors packed in tuple | N | - | + + **Rate variables** + + None + + **External dependencies:** + + | Name | Description | Provided by | Unit | + |------|--------------------------|---------------|------| + | DVS | Crop development stage | DVS_Phenology | - | + + **Outputs** + + | Name | Description | Pbl | Unit | + |------|----------------------------------------|-----|------| + | FR | Fraction partitioned to roots | Y | - | + | FL | Fraction partitioned to leaves | Y | - | + | FS | Fraction partitioned to stems | Y | - | + | FO | Fraction partitioned to storage organs | Y | - | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|---------------------------| + | FR | FRTB, DVS | + | FL | FLTB, DVS | + | FS | FSTB, DVS | + | FO | FOTB, DVS | + + *Exceptions raised* + + A PartitioningError is raised if the partitioning coefficients to leaves, + stems and storage organs on a given day do not add up to 1. + """ + + def initialize(self, day, kiosk, parvalues): + """Initialize the DVS_Partitioning simulation object. + + Args: + day: Start date of the simulation. + kiosk (VariableKiosk): Variable kiosk of this PCSE instance. + parvalues (ParameterProvider): Object providing parameters as + key/value pairs. + """ + self._initialize_from_tables(kiosk, parvalues) + + @prepare_states + def integrate(self, day, delt=1.0): + """Update partitioning factors based on development stage (DVS).""" + self._update_from_tables() + + def calc_rates(self, day, drv): + """Return partitioning factors based on current DVS. + + Rate calculation does nothing for partitioning as it is a derived state. + """ + return self.states.PF + + +# This class is used in `wofost81` and has NOT been tested, see #41 +class DVS_Partitioning_N(_BaseDVSPartitioning): + """Class for assimilate partitioning based on development stage (DVS) with N stress. + + `DVS_Partitioning_N` calculates the partitioning of the assimilates to roots, + stems, leaves and storage organs using fixed partitioning tables as a + function of crop development stage. The only difference with the normal + partitioning class is the effect of nitrogen stress on partitioning to + leaves. The available assimilates are first split into below-ground and + aboveground using the values in FRTB. In a second stage they are split into + leaves (FLTB), stems (FSTB) and storage organs (FOTB). + + Since the partitioning fractions are derived from the state variable DVS + they are regarded state variables as well. + + **Simulation parameters** (To be provided in cropdata dictionary): + + | Name | Description | Type | Unit | + |------|----------------------------------------------------------------|------|------| + | FRTB | Partitioning to roots as a function of development stage | TCr | - | + | FSTB | Partitioning to stems as a function of development stage | TCr | - | + | FLTB | Partitioning to leaves as a function of development stage | TCr | - | + | FOTB | Partitioning to storage organs as a function of development stage | TCr | - | + + **State variables** + + | Name | Description | Pbl | Unit | + |------|------------------------------------------|-----|------| + | FR | Fraction partitioned to roots | Y | - | + | FS | Fraction partitioned to stems | Y | - | + | FL | Fraction partitioned to leaves | Y | - | + | FO | Fraction partitioned to storage organs | Y | - | + | PF | Partitioning factors packed in tuple | N | - | + + **Rate variables** + + None + + **External dependencies:** + + | Name | Description | Provided by | Unit | + |-------|------------------------------------------------|--------------------------|------| + | DVS | Crop development stage | DVS_Phenology | - | + | RFTRA | Reduction factor for transpiration (water & | Water & Oxygen dynamics | - | + | | oxygen stress) | | | + + **Outputs** + + | Name | Description | Pbl | Unit | + |------|----------------------------------------|-----|------| + | FR | Fraction partitioned to roots | Y | - | + | FL | Fraction partitioned to leaves | Y | - | + | FS | Fraction partitioned to stems | Y | - | + | FO | Fraction partitioned to storage organs | Y | - | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|----------------------------| + | FR | FRTB, DVS, RFTRA | + | FL | FLTB, DVS | + | FS | FSTB, DVS | + | FO | FOTB, DVS | + + *Exceptions raised* + + A PartitioningError is raised if the partitioning coefficients to leaves, + stems and storage organs on a given day do not add up to 1. + """ + + def _handle_partitioning_error(self, msg: str) -> None: + raise exc.PartitioningError(msg) + + def initialize(self, day, kiosk, parameters): + """Initialize the DVS_Partitioning_N simulation object. + + Args: + day: Start date of the simulation. + kiosk (VariableKiosk): Variable kiosk of this PCSE instance. + parameters (ParameterProvider): Dictionary with WOFOST cropdata + key/value pairs. + """ + self._initialize_from_tables(kiosk, parameters) + + def _calculate_stressed_fr(self, DVS: torch.Tensor, RFTRA: torch.Tensor) -> torch.Tensor: + """Computes the FR partitioning fraction under water/oxygen stress.""" + FRTMOD = torch.max(torch.ones_like(RFTRA), 1.0 / (RFTRA + 0.5)) + return torch.min(torch.full_like(FRTMOD, 0.6), (self.params.FRTB(DVS) * FRTMOD)) + + @prepare_states + def integrate(self, day, delt=1.0): + """Update partitioning factors based on DVS and water/oxygen stress.""" + DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) + RFTRA = torch.as_tensor(self.kiosk["RFTRA"], dtype=self.dtype, device=self.device) + + FR = self._calculate_stressed_fr(DVS, RFTRA) + FL = self.params.FLTB(DVS) + FS = self.params.FSTB(DVS) + FO = self.params.FOTB(DVS) + + FR, FL, FS, FO = self._broadcast_partitioning(FR, FL, FS, FO) + self._set_partitioning_states(FR, FL, FS, FO) + self._check_partitioning() + + def calc_rates(self, day, drv): + """Return partitioning factors based on current DVS. + + Rate calculation does nothing for partitioning as it is a derived state. + """ + return self.states.PF diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index dba659f..08f3ebd 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -12,6 +12,7 @@ "rootdynamics", "potentialproduction", "phenology", + "partitioning", ] FILE_NAMES = [ f"test_{model_name}_wofost72_{i:02d}.yaml" for model_name in model_names for i in range(1, 45) diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index 6002dce..882e6b9 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -84,7 +84,7 @@ class TestLeafDynamics: for i in range(1, 45) # there are 44 test files ] - @pytest.mark.parametrize("test_data_url", leafdynamics_data_urls) # Test subset for GPU + @pytest.mark.parametrize("test_data_url", leafdynamics_data_urls) def test_leaf_dynamics_with_testengine(self, test_data_url, device): """EngineTestHelper and not Engine because it allows to specify `external_states`.""" # prepare model input diff --git a/tests/physical_models/crop/test_partitioning.py b/tests/physical_models/crop/test_partitioning.py new file mode 100644 index 0000000..ae5c6e4 --- /dev/null +++ b/tests/physical_models/crop/test_partitioning.py @@ -0,0 +1,467 @@ +import copy +import warnings +from unittest.mock import patch +import pytest +import torch +from numpy.testing import assert_array_almost_equal +from pcse.models import Wofost72_PP +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.partitioning import DVS_Partitioning +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +partitioning_config = Configuration(CROP=DVS_Partitioning, OUTPUT_VARS=["FR", "FL", "FS", "FO"]) + + +def get_test_diff_partitioning(device: str = "cpu"): + """Build a small wrapper module for differentiable tests.""" + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + return DiffPartitioning( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + partitioning_config, + copy.deepcopy(external_states), + device=device, + ) + + +class DiffPartitioning(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config, + external_states, + device: str = "cpu", + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config = config + self.external_states = external_states + self.device = device + + def forward(self, params_dict: dict[str, torch.Tensor]): + # pass new value of parameters to the model + for name, value in params_dict.items(): + if isinstance(value, torch.Tensor) and value.device.type != self.device: + value = value.to(self.device) + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config, + self.external_states, + device=self.device, + ) + engine.run_till_terminate() + results = engine.get_output() + + output_dict = {} + for var in ["FR", "FL", "FS", "FO"]: + stacked = torch.stack([item[var] for item in results]) + # Keep outputs that have grad_fn in the computation graph + # For outputs without grad_fn, keep them as-is (they don't require gradients) + output_dict[var] = stacked + return output_dict + + +class TestPartitioning: + data_urls = [f"{phy_data_folder}/test_partitioning_wofost72_{i:02d}.yaml" for i in range(1, 45)] + + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + @pytest.mark.parametrize("test_data_url", data_urls) + def test_partitioning_with_testengine(self, test_data_url, device): + """Mirror of leaf dynamics structure: compare against YAML references.""" + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + partitioning_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = ( + test_data["ModelResults"], + test_data.get("Precision", {"FR": 1e-6, "FL": 1e-6, "FS": 1e-6, "FO": 1e-6}), + ) + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + # Range and checksum invariants + for key in ("FR", "FL", "FS", "FO"): + assert torch.all((model[key] >= 0.0) & (model[key] <= 1.0)) + checksum = model["FR"] + (model["FL"] + model["FS"] + model["FO"]) * (1.0 - model["FR"]) + assert torch.allclose(checksum, torch.ones_like(checksum), atol=1e-4) + # Reference checks + assert all( + torch.all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize("param", ["FRTB", "FLTB", "FSTB", "FOTB"]) + def test_partitioning_with_one_parameter_vector(self, param, device): + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[param].repeat(10, 1) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + partitioning_config, + external_states, + device=device, + ) + engine.run_till_terminate() + results = engine.get_output() + + for day in results: + for key in ("FR", "FL", "FS", "FO"): + assert day[key].ndim >= 1 + assert torch.all(torch.isfinite(day[key])) + + def test_partitioning_with_different_parameter_values(self, device): + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, device=device) + + # Setting vectors with multiple values for each table parameter + for param in ("FRTB", "FLTB", "FSTB", "FOTB"): + # AfgenTrait parameters need to have shape (N, M) + base = crop_model_params_provider[param] + # Create two variations of the table + param_vec = torch.stack([base * 0.8, base]) + crop_model_params_provider.set_override(param, param_vec, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + partitioning_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + for day in actual_results: + for key in ("FR", "FL", "FS", "FO"): + assert day[key].shape[0] == 2 + assert torch.all(torch.isfinite(day[key])) + + def test_partitioning_with_multiple_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + + # Setting vectors for all table parameters + for name in ("FRTB", "FLTB", "FSTB", "FOTB"): + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[name].repeat(2, 1) + crop_model_params_provider.set_override(name, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + partitioning_config, + external_states, + device="cpu", + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + for day in actual_results: + for key in ("FR", "FL", "FS", "FO"): + assert day[key].shape[0] == 2 + assert torch.all(torch.isfinite(day[key])) + + def test_partitioning_with_multiple_parameter_arrays(self): + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + + # Repeat AfgenTrait tables to vectorize like in leaf dynamics: shape (30, 5, K) + for param in ("FRTB", "FLTB", "FSTB", "FOTB"): + base = crop_model_params_provider[param] + repeated = base.repeat(30, 5, 1) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + partitioning_config, + external_states, + device="cpu", + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = ( + test_data["ModelResults"], + test_data.get("Precision", {"FR": 1e-6, "FL": 1e-6, "FS": 1e-6, "FO": 1e-6}), + ) + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + torch.all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + assert all( + model[var].shape == (30, 5) for var in expected_precision.keys() + ) # check the output shapes + + def test_partitioning_with_incompatible_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_partitioning_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + + crop_model_params_provider.set_override("FRTB", [[0.0, 0.3, 2.0, 0.1]] * 4, check=False) + crop_model_params_provider.set_override("FLTB", [[0.0, 0.3, 2.0, 0.1]] * 2, check=False) + + with pytest.raises(AssertionError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + partitioning_config, + external_states, + device="cpu", + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls[:1]) + def test_wofost_pp_with_partitioning(self, test_data_url): + # prepare model input + test_data = get_test_data(test_data_url) + crop_model_params = ["FRTB", "FLTB", "FSTB", "FOTB"] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params) + ) + + # get expected results from YAML test data + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.Partitioning", DVS_Partitioning): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + abs(reference[var] - model[var]) < precision + for var, precision in expected_precision.items() + ) + + +class TestDiffPartitioningGradients: + """Gradient tests mirroring leaf dynamics test structure.""" + + param_names = ["FRTB", "FLTB", "FSTB", "FOTB"] + output_names = ["FR", "FL", "FS", "FO"] + + # parameter configurations (value, dtype) + param_configs = { + "single": { + "FRTB": ([[0.0, 0.3, 2.0, 0.1]], torch.float64), + "FLTB": ([[0.0, 0.3, 2.0, 0.1]], torch.float64), + "FSTB": ([[0.0, 0.3, 2.0, 0.1]], torch.float64), + "FOTB": ([[0.0, 0.3, 2.0, 0.1]], torch.float64), + }, + "tensor": { + "FRTB": ( + [[0.0, 0.3, 2.0, 0.1], [0.0, 0.4, 2.0, 0.2], [0.0, 0.2, 2.0, 0.05]], + torch.float64, + ), + "FLTB": ( + [[0.0, 0.25, 2.0, 0.15], [0.0, 0.35, 2.0, 0.25], [0.0, 0.2, 2.0, 0.1]], + torch.float64, + ), + "FSTB": ( + [[0.0, 0.5, 2.0, 0.3], [0.0, 0.4, 2.0, 0.2], [0.0, 0.6, 2.0, 0.35]], + torch.float64, + ), + "FOTB": ( + [[0.0, 0.1, 2.0, 0.05], [0.0, 0.2, 2.0, 0.1], [0.0, 0.15, 2.0, 0.08]], + torch.float64, + ), + }, + } + + # mapping of which outputs should have gradients for each param + gradient_mapping = { + "FRTB": ["FR"], + "FLTB": ["FL"], + "FSTB": ["FS"], + "FOTB": ["FO"], + } + + gradient_params: list[tuple[str, str]] = [] + no_gradient_params: list[tuple[str, str]] = [] + for p in param_names: + for o in output_names: + if o in gradient_mapping.get(p, []): + gradient_params.append((p, o)) + else: + no_gradient_params.append((p, o)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type, device): + model = get_test_diff_partitioning(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output[output_name].sum() + + # For outputs that don't depend on the parameter, gradient will be None + # This is the expected behavior for parameters that shouldn't affect this output + try: + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + except RuntimeError as e: + if "does not require grad and does not have a grad_fn" in str(e): + # Output is independent of parameter - this is expected + return + raise + + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): + model = get_test_diff_partitioning(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output[output_name].sum() + + # For these gradient_params, the output should depend on the parameter + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + assert grads is not None, f"Gradients for {param_name} should not be None" + + param.grad = None + loss.backward() + grad_backward = param.grad + + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), ( + f"Forward and backward gradients for {param_name} should match" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type, device): + """Test that analytical gradients match numerical gradients.""" + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + numerical_grad = calculate_numerical_grad( + lambda: get_test_diff_partitioning(device=device), param_name, param.data, output_name + ) + + model = get_test_diff_partitioning(device=device) + output = model({param_name: param}) + loss = output[output_name].sum() + + # this is ∂loss/∂param, for comparison with numerical gradient + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + # [!] AFGEN uses interval selection (searchsorted) + branching (where) which makes + # the function non-differentiable w.r.t. the x-coordinates of the table. + # This non-differentiable behavior is handled non-consistently by: + # - Autograd will ignore the effect of moving breakpoints + # - finite differences do capture the effect of moving breakpoints + # Therefore, we ignore the x-coordinates and only compare gradients for the y-values. + # Check test_utils.py::TestAfgenEdgeCases::test_x_breakpoint_at_clamp for more details. + + # AFGEN tables are encoded as [x0, y0, x1, y1, ...], so y-values are at odd indices. + numerical_np = numerical_grad.detach().cpu().numpy() + grads_np = grads.detach().cpu().numpy() + assert numerical_np.shape == grads_np.shape + + y_slice = (..., slice(1, None, 2)) + assert_array_almost_equal(numerical_np[y_slice], grads_np[y_slice], decimal=3) + + # Warn if gradient is zero + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' with respect to output" + + f"'{output_name}' is zero: {grads.detach().cpu().numpy()}", + UserWarning, + ) diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index 0f28153..805dc21 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -259,6 +259,61 @@ def test_two_point_table(self): assert afgen(torch.tensor(10.0)) == 20.0 assert torch.isclose(afgen(torch.tensor(5.0)), torch.tensor(15.0, dtype=DTYPE)) + def test_x_breakpoint_at_clamp(self): + """Illustrate why AFGEN x-breakpoint grads can disagree with finite differences. + + AFGEN is piecewise-linear in the query x, but evaluation includes discrete interval + selection (via searchsorted) and boundary clamping (via where). + + At the *exact* upper breakpoint (x_query == x_last), the output is clamped to y_last. + Autograd therefore reports ~0 gradient w.r.t. x_last (the breakpoint), because within + that branch the output does not depend on x_last. + + However, a finite-difference perturbation of x_last changes which branch is taken + for x_query fixed at the boundary, producing a non-zero numerical derivative. + This is the same phenomenon that caused the partitioning numerical-grad test to fail + when comparing all AFGEN table entries. + """ + + # Keep this example deterministic across environments. + old_device = ComputeConfig.get_device() + old_dtype = ComputeConfig.get_dtype() + ComputeConfig.set_device("cpu") + ComputeConfig.set_dtype(torch.float64) + try: + # Table is encoded as [x0, y0, x1, y1]. Use a non-flat y so x-breakpoints matter. + tbl = torch.tensor([0.0, 0.3, 2.0, 0.1], dtype=torch.float64, requires_grad=True) + afgen = Afgen(tbl) + + # Query exactly at the last breakpoint, which triggers the clamp branch. + x_query = torch.tensor(2.0, dtype=torch.float64) + out = afgen(x_query) + + (grad_auto,) = torch.autograd.grad(out, tbl, retain_graph=False) + + # Central finite difference w.r.t. each table entry + delta = 1e-6 + grad_num = torch.zeros_like(tbl) + for i in range(tbl.numel()): + tbl_plus = tbl.detach().clone() + tbl_minus = tbl.detach().clone() + tbl_plus[i] += delta + tbl_minus[i] -= delta + out_plus = Afgen(tbl_plus)(x_query) + out_minus = Afgen(tbl_minus)(x_query) + grad_num[i] = (out_plus - out_minus) / (2 * delta) + + # Gradients w.r.t. y-entries (odd indices) should match closely. + assert torch.allclose(grad_auto[1::2], grad_num[1::2], atol=1e-5, rtol=1e-4) + + # Gradient w.r.t. the last x-breakpoint (index 2) is the illustrative mismatch. + # Autograd sees the clamp branch => ~0; finite differences see branch switching. + assert abs(float(grad_auto[2])) < 1e-8 + assert abs(float(grad_num[2])) > 1e-3 + finally: + ComputeConfig.set_device(old_device) + ComputeConfig.set_dtype(old_dtype) + def test_negative_values(self): """Test with negative x and y values.""" afgen = Afgen([-10, -20, 0, 0, 10, 20])