diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py
index da24eb3..8e24d1a 100644
--- a/src/diffwofost/physical_models/crop/leaf_dynamics.py
+++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py
@@ -8,8 +8,8 @@
 from pcse.decorators import prepare_rates
 from pcse.decorators import prepare_states
 from pcse.traitlets import Any
+from pcse.util import Afgen
 from pcse.util import AfgenTrait
-from pcse.util import limit

 DTYPE = torch.float64  # Default data type for tensors in this module

@@ -113,14 +113,19 @@ class WOFOST_Leaf_Dynamics(SimulationObject):
     LAI, TWLV
     """

+    # The following parameters are used to initialize and control the arrays that store information
+    # on the leaf classes during the time integration: leaf area, age, and biomass.
+    START_DATE = None  # Start date of the simulation
+    MAX_DAYS = 300  # Maximum number of days that can be simulated in one run (i.e. array lengths)
+
     class Parameters(ParamTemplate):
         RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
         SPAN = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
         TBASE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
         PERDL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
         TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
-        SLATB = AfgenTrait()  # FIXEME
-        KDIFTB = AfgenTrait()  # FIXEME
+        SLATB = AfgenTrait()  # FIXME
+        KDIFTB = AfgenTrait()  # FIXME

     class StateVariables(StatesTemplate):
         LV = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
@@ -156,6 +161,7 @@ def initialize(self, day, kiosk, parvalues):
         :param parvalues: `ParameterProvider` object providing parameters as key/value pairs
         """
+        self.START_DATE = day
         self.kiosk = kiosk
         # TODO check if parvalues are already torch.nn.Parameters
         self.params = self.Parameters(parvalues)
@@ -171,19 +177,22 @@ def initialize(self, day, kiosk, parvalues):
         DVS = self.kiosk["DVS"]
         params = self.params
+        shape = _get_params_shape(params)

         # Initial leaf biomass
         WLV = (params.TDWI * (1 - FR)) * FL
-        DWLV = torch.tensor(0.0, dtype=DTYPE)
+        DWLV = torch.zeros(shape, dtype=DTYPE)
         TWLV = WLV + DWLV

-        # First leaf class (SLA, age and weight)
-        SLA = torch.tensor([params.SLATB(DVS)], dtype=DTYPE)
-        LVAGE = torch.tensor([0.0], dtype=DTYPE)
-        LV = torch.stack([WLV])
+        # Initialize leaf classes (SLA, age and weight)
+        SLA = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
+        LVAGE = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
+        LV = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
+        SLA[..., 0] = params.SLATB(DVS)
+        LV[..., 0] = WLV

         # Initial values for leaf area
-        LAIEM = LV[0] * SLA[0]
+        LAIEM = LV[..., 0] * SLA[..., 0]
         LASUM = LAIEM
         LAIEXP = LAIEM
         LAIMAX = LAIEM
@@ -236,57 +245,54 @@ def calc_rates(self, day, drv):
         # death due to self shading caused by high LAI
         DVS = self.kiosk["DVS"]
         LAICR = 3.2 / p.KDIFTB(DVS)
-        r.DSLV2 = mask * s.WLV * limit(0.0, 0.03, 0.03 * (s.LAI - LAICR) / LAICR)
+        r.DSLV2 = mask * s.WLV * torch.clamp(0.03 * (s.LAI - LAICR) / LAICR, 0.0, 0.03)

         # Death of leaves due to frost damage as determined by
         # Reduction Factor Frost "RF_FROST"
         if "RF_FROST" in self.kiosk:
             r.DSLV3 = mask * s.WLV * k.RF_FROST
         else:
-            r.DSLV3 = torch.tensor(0.0, dtype=DTYPE)
+            r.DSLV3 = torch.zeros_like(s.WLV, dtype=DTYPE)

         # leaf death equals maximum of water stress, shading and frost
-        r.DSLV = torch.max(torch.stack([r.DSLV1, r.DSLV2, r.DSLV3]))
+        r.DSLV = torch.maximum(torch.maximum(r.DSLV1, r.DSLV2), r.DSLV3)

         # Determine how much leaf biomass classes have to die in states.LV,
         # given a life span > SPAN, these classes will be accumulated
         # in DALV.
         # Note that the actual leaf death is imposed on the array LV during the
         # state integration step.
-        DALV = torch.tensor(0.0, dtype=DTYPE)
-        if p.SPAN.requires_grad:  # replacing hard threshold `if lvage > p.SPAN``
-            sharpness = 1000.0  # FIXEME
-            for lv, lvage in zip(s.LV, s.LVAGE, strict=False):
-                weight = torch.sigmoid((lvage - p.SPAN) * sharpness)
-                DALV = DALV + weight * lv
-        else:
-            for lv, lvage in zip(s.LV, s.LVAGE, strict=False):
-                if lvage > p.SPAN:
-                    DALV = DALV + lv
-
-        r.DALV = DALV
+        tSPAN = _broadcast_to(p.SPAN, s.LVAGE.shape)  # Broadcast to same shape
+        # Using a sigmoid here instead of a conditional statement on the value of
+        # SPAN because the latter would not allow for the gradient to be tracked.
+        sharpness = torch.tensor(1000.0, dtype=DTYPE)  # FIXME
+        weight = torch.sigmoid((s.LVAGE - tSPAN) * sharpness)
+        r.DALV = torch.sum(weight * s.LV, dim=-1)

         # Total death rate leaves
-        r.DRLV = torch.max(r.DSLV, r.DALV)
+        r.DRLV = torch.maximum(r.DSLV, r.DALV)

         # physiologic ageing of leaves per time step
         FYSAGE = (drv.TEMP - p.TBASE) / (35.0 - p.TBASE)
-        r.FYSAGE = mask * torch.max(torch.tensor(0.0, dtype=DTYPE), FYSAGE)
+        r.FYSAGE = mask * torch.clamp(FYSAGE, 0.0)

         # specific leaf area of leaves per time step
         r.SLAT = mask * torch.tensor(p.SLATB(DVS), dtype=DTYPE)

         # leaf area not to exceed exponential growth curve
-        if s.LAIEXP < 6.0:
-            DTEFF = torch.max(torch.tensor(0.0, dtype=DTYPE), drv.TEMP - p.TBASE)
-            r.GLAIEX = s.LAIEXP * p.RGRLAI * DTEFF
-            # source-limited increase in leaf area
-            r.GLASOL = r.GRLV * r.SLAT
-            # sink-limited increase in leaf area
-            GLA = torch.min(r.GLAIEX, r.GLASOL)
-            # adjustment of specific leaf area of youngest leaf class
-            if r.GRLV > 0.0:
-                r.SLAT = GLA / r.GRLV
+        is_lai_exp = s.LAIEXP < 6.0
+        DTEFF = torch.clamp(drv.TEMP - p.TBASE, 0.0)
+        # NOTE: conditional statements do not allow for the gradient to be
+        # tracked through the condition. Thus, the gradient with respect to
+        # parameters that contribute to `is_lai_exp` (e.g. RGRLAI and TBASE)
+        # is expected to be incorrect.
+        r.GLAIEX = torch.where(is_lai_exp, s.LAIEXP * p.RGRLAI * DTEFF, r.GLAIEX)
+        # source-limited increase in leaf area
+        r.GLASOL = torch.where(is_lai_exp, r.GRLV * r.SLAT, r.GLASOL)
+        # sink-limited increase in leaf area
+        GLA = torch.minimum(r.GLAIEX, r.GLASOL)
+        # adjustment of specific leaf area of youngest leaf class
+        r.SLAT = torch.where(is_lai_exp & (r.GRLV > 0.0), GLA / r.GRLV, r.SLAT)

     @prepare_states
     def integrate(self, day, delt=1.0):
@@ -299,44 +305,45 @@ def integrate(self, day, delt=1.0):
         tLV = states.LV.clone()
         tSLA = states.SLA.clone()
         tLVAGE = states.LVAGE.clone()
-        tDRLV = rates.DRLV
-
-        # leaf death is imposed on leaves by removing leave classes from the
-        # right side.
-        for LVweigth in reversed(states.LV):
-            if tDRLV > 0.0:
-                if tDRLV >= LVweigth:  # remove complete leaf class
-                    tDRLV = tDRLV - LVweigth
-                    tLV = tLV[:-1]  # Remove last element
-                    tLVAGE = tLVAGE[:-1]
-                    tSLA = tSLA[:-1]
-                else:  # Decrease value of oldest (rightmost) leave class
-                    tLV[-1] = tLV[-1] - tDRLV
-                    tDRLV = torch.tensor(0.0, dtype=DTYPE)
-            else:
-                break
-
+        tDRLV = _broadcast_to(rates.DRLV, tLV.shape)
+
+        # Leaf death is imposed on the leaves starting from the oldest classes.
+        # Calculate the cumulative sum of weights after leaf death, and
+        # find out which leaf classes are dead (negative weights)
+        weight_cumsum = tLV.cumsum(dim=-1) - tDRLV
+        is_alive = weight_cumsum >= 0
+        # Adjust the value of the oldest living leaf class, i.e. the first class
+        # with a non-negative cumulative weight along the time axis (the last dimension).
+        # Cast argument to int because torch.argmax requires it to be numeric
+        idx_oldest = torch.argmax(is_alive.type(torch.int), dim=-1, keepdim=True)
+        new_biomass = torch.take_along_dim(weight_cumsum, indices=idx_oldest, dim=-1)
+        tLV = torch.scatter(tLV, dim=-1, index=idx_oldest, src=new_biomass)
+        # Zero out all dead leaf classes
+        # NOTE: conditional statements do not allow for the gradient to be
+        # tracked through the condition. Thus, the gradient with respect to
+        # parameters that contribute to `is_alive` is expected to be incorrect.
+        tLV = torch.where(is_alive, tLV, 0.0)

         # Integration of physiological age
-        tLVAGE = torch.tensor([age + rates.FYSAGE for age in tLVAGE], dtype=DTYPE)
+        tLVAGE = tLVAGE + rates.FYSAGE
+        tLVAGE = torch.where(is_alive, tLVAGE, 0.0)
+        tSLA = torch.where(is_alive, tSLA, 0.0)

         # --------- leaf growth ---------
-        # new leaves in class 1
-        tLV = torch.cat((torch.tensor([rates.GRLV], dtype=DTYPE), tLV))
-        tSLA = torch.cat((torch.tensor([rates.SLAT], dtype=DTYPE), tSLA))
-        tLVAGE = torch.cat((torch.tensor([0.0], dtype=DTYPE), tLVAGE))
+        idx = int((day - self.START_DATE).days / delt)
+        tLV[..., idx] = rates.GRLV
+        tSLA[..., idx] = rates.SLAT
+        tLVAGE[..., idx] = 0.0

         # calculation of new leaf area
-        states.LASUM = torch.sum(
-            torch.stack([lv * sla for lv, sla in zip(tLV, tSLA, strict=False)])
-        )
+        states.LASUM = torch.sum(tLV * tSLA, dim=-1)
         states.LAI = self._calc_LAI()
-        states.LAIMAX = torch.max(states.LAI, states.LAIMAX)
+        states.LAIMAX = torch.maximum(states.LAI, states.LAIMAX)

         # exponential growth curve
         states.LAIEXP = states.LAIEXP + rates.GLAIEX

         # Update leaf biomass states
-        states.WLV = torch.sum(tLV)
+        states.WLV = torch.sum(tLV, dim=-1)
         states.DWLV = states.DWLV + rates.DRLV
         states.TWLV = states.WLV + states.DWLV

@@ -345,55 +352,6 @@ def integrate(self, day, delt=1.0):
         self.states.SLA = tSLA
         self.states.LVAGE = tLVAGE

-    @prepare_states
-    def _set_variable_LAI(self, nLAI):  # FIXEME
-        """Updates the value of LAI to to the new value provided as input.
-
-        Related state variables will be updated as well and the increments
-        to all adjusted state variables will be returned as a dict.
-        """
-        states = self.states
-
-        # Store old values of states
-        oWLV = states.WLV
-        oLAI = states.LAI
-        oTWLV = states.TWLV
-        oLASUM = states.LASUM
-
-        # Reduce oLAI for pod and stem area. SAI and PAI will not be adjusted
-        # because this is often only a small component of the total leaf
-        # area. For all current crop files in WOFOST SPA and SSA are zero
-        # anyway
-        SAI = self.kiosk["SAI"]
-        PAI = self.kiosk["PAI"]
-        adj_nLAI = torch.max(nLAI - SAI - PAI, 0.0)
-        adj_oLAI = torch.max(oLAI - SAI - PAI, 0.0)
-
-        # LAI Adjustment factor for leaf biomass LV (rLAI)
-        if adj_oLAI > 0:
-            rLAI = adj_nLAI / adj_oLAI
-            LV = [lv * rLAI for lv in states.LV]
-        # If adj_oLAI == 0 then add the leave biomass directly to the
-        # youngest leave age class (LV[0])
-        else:
-            LV = [nLAI / states.SLA[0]]
-
-        states.LASUM = torch.sum(
-            torch.tensor([lv * sla for lv, sla in zip(LV, states.SLA, strict=False)], dtype=DTYPE)
-        )
-        states.LV = LV
-        states.LAI = self._calc_LAI()
-        states.WLV = torch.sum(states.LV)
-        states.TWLV = states.WLV + states.DWLV
-
-        increments = {
-            "LAI": states.LAI - oLAI,
-            "LAISUM": states.LASUM - oLASUM,
-            "WLV": states.WLV - oWLV,
-            "TWLV": states.TWLV - oTWLV,
-        }
-        return increments
-

 def _exist_required_external_variables(kiosk):
     """Check if all required external variables are available in the kiosk."""
@@ -405,3 +363,39 @@ def _exist_required_external_variables(kiosk):
         f" Ensure that all required variables {required_external_vars_at_init}"
         " are provided."
     )
+
+
+def _get_params_shape(params):
+    """Get the parameters shape.
+
+    Parameters can have an arbitrary number of dimensions, but all parameters that are not zero-
+    dimensional should have the same shape.
+    """
+    shape = ()
+    for parname in params.trait_names():
+        # Skip special traitlets attributes
+        if parname.startswith("trait"):
+            continue
+        param = getattr(params, parname)
+        # Skip Afgen parameters:
+        if isinstance(param, Afgen):
+            continue
+        # Parameters that are not zero dimensional should all have the same shape
+        if param.shape and not shape:
+            shape = param.shape
+        elif param.shape:
+            assert param.shape == shape, (
+                "All parameters should have the same shape (or have no dimensions)"
+            )
+    return shape
+
+
+def _broadcast_to(x, shape):
+    """Create a view of tensor x with the given shape."""
+    if x.dim() == 0:
+        # For 0-d tensors, we simply broadcast to the given shape
+        return torch.broadcast_to(x, shape)
+    # The given shape should match x in all but the last axis, which represents
+    # the dimension along which the time integration is carried out.
+    # We first append an axis to x, then expand to the given shape
+    return x.unsqueeze(-1).expand(shape)
diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py
index dea2a9c..68ddfb7 100644
--- a/src/diffwofost/physical_models/utils.py
+++ b/src/diffwofost/physical_models/utils.py
@@ -286,18 +286,18 @@ def get_test_data(file_path):
     return inputs["ModelResults"], inputs["Precision"]


-def calculate_numerical_grad(get_model_fn, param_name, param_value, output_index):
+def calculate_numerical_grad(get_model_fn, param_name, param_value, out_name):
    """Calculate the numerical gradient of output with respect to a parameter."""
    delta = 1e-6
-    p_plus = param_value.item() + delta
-    p_minus = param_value.item() - delta
+    p_plus = param_value + delta
+    p_minus = param_value - delta

     model = get_model_fn()
-    output = model({param_name: torch.nn.Parameter(torch.tensor(p_plus, dtype=torch.float64))})
-    loss_plus = output[0, :, output_index].sum()
+    output = model({param_name: torch.nn.Parameter(p_plus)})
+    loss_plus = output[out_name].sum(dim=0)

     model = get_model_fn()
-    output = model({param_name: torch.nn.Parameter(torch.tensor(p_minus, dtype=torch.float64))})
-    loss_minus = output[0, :, output_index].sum()
+    output = model({param_name: torch.nn.Parameter(p_minus)})
+    loss_minus = output[out_name].sum(dim=0)

-    return (loss_plus.item() - loss_minus.item()) / (2 * delta)
+    return (loss_plus.data - loss_minus.data) / (2 * delta)
diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py
index 08582a8..eded0c5 100644
--- a/tests/physical_models/crop/test_leaf_dynamics.py
+++ b/tests/physical_models/crop/test_leaf_dynamics.py
@@ -2,8 +2,7 @@ from unittest.mock import patch

 import pytest
 import torch
-import torch.testing
-from numpy.testing import assert_almost_equal
+from numpy.testing import assert_array_almost_equal
 from pcse.engine import Engine
 from pcse.models import Wofost72_PP
 from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics
@@ -61,9 +60,7 @@ def forward(self, params_dict):
         engine.run_till_terminate()
         results = engine.get_output()

-        return torch.stack(
-            [torch.stack([item["LAI"], item["TWLV"]]) for item in results]
-        ).unsqueeze(0)  # shape: [1, time_steps, 2]
+        return {var: torch.stack([item[var] for item in results]) for var in ["LAI", "TWLV"]}


 class TestLeafDynamics:
@@ -122,6 +119,205 @@ def test_leaf_dynamics_with_engine(self):
             config_path,
         )

+    @pytest.mark.parametrize("param", ["TDWI", "SPAN"])
+    def test_leaf_dynamics_with_one_parameter_vector(self, param):
+        # prepare model input
+        test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml"
+        crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"]
+        (
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            external_states,
+        ) = prepare_engine_input(test_data_path, crop_model_params)
+        config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf")
+
+        # Setting a vector (with one value) for the selected parameter
+        repeated = crop_model_params_provider[param].repeat(10)
+        crop_model_params_provider.set_override(param, repeated, check=False)
+
+        engine = EngineTestHelper(
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            config_path,
+            external_states,
+        )
+        engine.run_till_terminate()
+        actual_results = engine.get_output()
+
+        # get expected results from YAML test data
+        expected_results, expected_precision = get_test_data(test_data_path)
+
+        assert len(actual_results) == len(expected_results)
+
+        for reference, model in zip(expected_results, actual_results, strict=False):
+            assert reference["DAY"] == model["day"]
+            assert all(
+                all(abs(reference[var] - model[var]) < precision)
+                for var, precision in expected_precision.items()
+            )
+
+    @pytest.mark.parametrize(
+        "param,delta",
+        [
+            ("TDWI", 0.1),
+            ("SPAN", 5),
+        ],
+    )
+    def test_leaf_dynamics_with_different_parameter_values(self, param, delta):
+        # prepare model input
+        test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml"
+        crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"]
+        (
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            external_states,
+        ) = prepare_engine_input(test_data_path, crop_model_params)
+        config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf")
+
+        # Setting a vector with multiple values for the selected parameter
+        test_value = crop_model_params_provider[param]
+        # We set the value for which test data are available as the last element
+        param_vec = torch.tensor([test_value - delta, test_value + delta, test_value])
+        crop_model_params_provider.set_override(param, param_vec, check=False)
+
+        engine = EngineTestHelper(
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            config_path,
+            external_states,
+        )
+        engine.run_till_terminate()
+        actual_results = engine.get_output()
+
+        # get expected results from YAML test data
+        expected_results, expected_precision = get_test_data(test_data_path)
+
+        assert len(actual_results) == len(expected_results)
+
+        for reference, model in zip(expected_results, actual_results, strict=False):
+            assert reference["DAY"] == model["day"]
+            assert all(
+                # The value for which test data are available is the last element
+                abs(reference[var] - model[var][-1]) < precision
+                for var, precision in expected_precision.items()
+            )
+
+    def test_leaf_dynamics_with_multiple_parameter_vectors(self):
+        # prepare model input
+        test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml"
+        crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"]
+        (
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            external_states,
+        ) = prepare_engine_input(test_data_path, crop_model_params)
+        config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf")
+
+        # Setting a vector (with one value) for the TDWI and SPAN parameters
+        for param in ("TDWI", "SPAN"):
+            repeated = crop_model_params_provider[param].repeat(10)
+            crop_model_params_provider.set_override(param, repeated, check=False)
+
+        engine = EngineTestHelper(
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            config_path,
+            external_states,
+        )
+        engine.run_till_terminate()
+        actual_results = engine.get_output()
+
+        # get expected results from YAML test data
+        expected_results, expected_precision = get_test_data(test_data_path)
+
+        assert len(actual_results) == len(expected_results)
+
+        for reference, model in zip(expected_results, actual_results, strict=False):
+            assert reference["DAY"] == model["day"]
+            assert all(
+                all(abs(reference[var] - model[var]) < precision)
+                for var, precision in expected_precision.items()
+            )
+
+    def test_leaf_dynamics_with_multiple_parameter_arrays(self):
+        # prepare model input
+        test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml"
+        crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"]
+        (
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            external_states,
+        ) = prepare_engine_input(test_data_path, crop_model_params)
+        config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf")
+
+        # Setting an array with an arbitrary shape (and one value) for the
+        # TDWI and SPAN parameters
+        for param in ("TDWI", "SPAN"):
+            repeated = crop_model_params_provider[param].broadcast_to((30, 5))
+            crop_model_params_provider.set_override(param, repeated, check=False)
+
+        engine = EngineTestHelper(
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            config_path,
+            external_states,
+        )
+        engine.run_till_terminate()
+        actual_results = engine.get_output()
+
+        # get expected results from YAML test data
+        expected_results, expected_precision = get_test_data(test_data_path)
+
+        assert len(actual_results) == len(expected_results)
+
+        for reference, model in zip(expected_results, actual_results, strict=False):
+            assert reference["DAY"] == model["day"]
+            assert all(
+                torch.all(abs(reference[var] - model[var]) < precision)
+                for var, precision in expected_precision.items()
+            )
+            assert all(
+                model[var].shape == (30, 5) for var in expected_precision.keys()
+            )  # check the output shapes
+
+    def test_leaf_dynamics_with_incompatible_parameter_vectors(self):
+        # prepare model input
+        test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml"
+        crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"]
+        (
+            crop_model_params_provider,
+            weather_data_provider,
+            agro_management_inputs,
+            external_states,
+        ) = prepare_engine_input(test_data_path, crop_model_params)
+        config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf")
+
+        # Setting a vector (with one value) for the TDWI and SPAN parameters,
+        # but with different lengths
+        crop_model_params_provider.set_override(
+            "TDWI", crop_model_params_provider["TDWI"].repeat(10), check=False
+        )
+        crop_model_params_provider.set_override(
+            "SPAN", crop_model_params_provider["SPAN"].repeat(5), check=False
+        )
+
+        with pytest.raises(AssertionError):
+            EngineTestHelper(
+                crop_model_params_provider,
+                weather_data_provider,
+                agro_management_inputs,
+                config_path,
+                external_states,
+            )
+
     def test_wofost_pp_with_leaf_dynamics(self):
         # prepare model input
         test_data_path = phy_data_folder / "test_potentialproduction_wofost72_01.yaml"
@@ -151,156 +347,144 @@ class TestDiffLeafDynamicsTDWI:
-    def test_gradients_tdwi_lai_leaf_dynamics(self):
+    @pytest.mark.parametrize(
+        "param_value,out_name",
+        [
+            (torch.tensor(0.2, dtype=torch.float64), "LAI"),
+            (torch.tensor(0.2, dtype=torch.float64), "TWLV"),
+            (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "LAI"),
+            (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "TWLV"),
+        ],
+    )
+    def test_gradients_leaf_dynamics(self, param_value, out_name):
         model = get_test_diff_leaf_model()
-        tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32))
+        tdwi = torch.nn.Parameter(param_value)
         output = model({"TDWI": tdwi})
-        lai = output[0, :, 0]
-        loss = lai.sum()
+        loss = output[out_name].sum()

-        # this is ∂loss/∂tdwi without calling loss.backward().
+        # this is ∂loss/∂param without calling loss.backward().
         # this is called forward gradient here because it is calculated without backpropagation.
         grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0]

-        assert grads is not None, "Gradients for TDWI should not be None"
+        assert grads is not None, "Gradients should not be None"

         tdwi.grad = None  # clear any existing gradient
         loss.backward()

-        # this is ∂loss/∂tdwi calculated using backpropagation
+        # this is ∂loss/∂param calculated using backpropagation
         grad_backward = tdwi.grad

-        assert grad_backward is not None, "Backward gradients for TDWI should not be None"
-        assert grad_backward == grads, "Forward and backward gradients for TDWI should match"
-
-    def test_gradients_tdwi_lai_leaf_dynamics_numerical(self):
+        assert grad_backward is not None, "Backward gradients should not be None"
+        assert torch.all(grad_backward == grads), "Forward and backward gradients should match"
+
+    @pytest.mark.parametrize(
+        "param_value,out_name",
+        [
+            (torch.tensor(0.2, dtype=torch.float64), "LAI"),
+            (torch.tensor(0.2, dtype=torch.float64), "TWLV"),
+            (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "LAI"),
+            (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "TWLV"),
+        ],
+    )
+    def test_gradients_leaf_dynamics_numerical(self, param_value, out_name):
         # first check if the numerical gradient is not zero, i.e. the parameter has an effect
-        tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64))
-        output_index = 0  # LAI is at index 0
+        tdwi = torch.nn.Parameter(param_value)
         numerical_grad = calculate_numerical_grad(
-            get_test_diff_leaf_model, "TDWI", tdwi, output_index
-        )  # this is Δloss/Δtdwi
-
-        model = get_test_diff_leaf_model()
-        output = model({"TDWI": tdwi})
-        lai = output[0, :, output_index]
-        loss = lai.sum()
-        # this is ∂loss/∂tdwi, for comparison with numerical gradient
-        grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0]
+            get_test_diff_leaf_model, "TDWI", tdwi.data, out_name
+        )  # this is Δloss/Δparam

-        assert_almost_equal(numerical_grad, grads.item(), decimal=3)
-
-    def test_gradients_tdwi_twlv_leaf_dynamics(self):
-        # prepare model input
+        model = get_test_diff_leaf_model()
-        tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32))
         output = model({"TDWI": tdwi})
-        twlv = output[0, :, 1]
-        loss = twlv.sum()
+        loss = output[out_name].sum()

-        # this is ∂loss/∂tdwi
-        # this is called forward gradient here because it is calculated without backpropagation.
+        # this is ∂loss/∂param, for comparison with numerical gradient
         grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0]

-        assert grads is not None, "Gradients for TDWI should not be None"
-
-        tdwi.grad = None  # clear any existing gradient
-        loss.backward()
-
-        # this is ∂loss/∂tdwi calculated using backpropagation
-        grad_backward = tdwi.grad
-        assert grad_backward is not None, "Backward gradients for TDWI should not be None"
-        assert grad_backward == grads, "Forward and backward gradients for TDWI should match"
-
-    def test_gradients_tdwi_twlv_leaf_dynamics_numerical(self):
-        # first check if the numerical gradient isnot zero i.e. the parameter has an effect
-        tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64))
-        output_index = 1  # TWLV is at index 1
-        numerical_grad = calculate_numerical_grad(
-            get_test_diff_leaf_model, "TDWI", tdwi, output_index
-        )  # this is Δloss/Δtdwi
-
-        model = get_test_diff_leaf_model()
-        output = model({"TDWI": tdwi})
-        twlv = output[0, :, output_index]
-        loss = twlv.sum()
-        # this is ∂loss/∂tdwi, for comparison with numerical gradient
-        grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0]
-
-        assert_almost_equal(numerical_grad, grads.item(), decimal=3)
+        assert_array_almost_equal(numerical_grad, grads.data, decimal=3)


 class TestDiffLeafDynamicsSPAN:
-    def test_gradients_span_lai_leaf_dynamics(self):
-        # prepare model input
+    @pytest.mark.parametrize(
+        "param_value",
+        [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)],
+    )
+    def test_gradients_lai_leaf_dynamics(self, param_value):
         model = get_test_diff_leaf_model()
-        span = torch.nn.Parameter(torch.tensor(30, dtype=torch.float32))
+        span = torch.nn.Parameter(param_value)
         output = model({"SPAN": span})
-        lai = output[0, :, 0]
-        loss = lai.sum()
+        loss = output["LAI"].sum()

-        # this is ∂loss/∂span
+        # this is ∂loss/∂param without calling loss.backward().
         # this is called forward gradient here because it is calculated without backpropagation.
         grads = torch.autograd.grad(loss, span, retain_graph=True)[0]

-        assert grads is not None, "Gradients for SPAN should not be None"
+        assert grads is not None, "Gradients should not be None"

         span.grad = None  # clear any existing gradient
         loss.backward()

-        # this is ∂loss/∂span calculated using backpropagation
+        # this is ∂loss/∂param calculated using backpropagation
         grad_backward = span.grad

-        assert grad_backward is not None, "Backward gradients for SPAN should not be None"
-        assert grad_backward == grads, "Forward and backward gradients for SPAN should match"
+        assert grad_backward is not None, "Backward gradients should not be None"
+        assert torch.all(grad_backward == grads), "Forward and backward gradients should match"

-    def test_gradients_span_lai_leaf_dynamics_numerical(self):
+    @pytest.mark.parametrize(
+        "param_value",
+        [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)],
+    )
+    def test_gradients_lai_leaf_dynamics_numerical(self, param_value):
         # first check if the numerical gradient is not zero, i.e. the parameter has an effect
-        span = torch.nn.Parameter(torch.tensor(30, dtype=torch.float64))
-        output_index = 0  # LAI is at index 0
+        span = torch.nn.Parameter(param_value)
         numerical_grad = calculate_numerical_grad(
-            get_test_diff_leaf_model, "SPAN", span, output_index
-        )  # this is Δloss/Δspan
+            get_test_diff_leaf_model, "SPAN", span.data, "LAI"
+        )  # this is Δloss/Δparam

         model = get_test_diff_leaf_model()
         output = model({"SPAN": span})
-        lai = output[0, :, output_index]
-        loss = lai.sum()
-        # this is ∂loss/∂tdwi, for comparison with numerical gradient
+        loss = output["LAI"].sum()
+
+        # this is ∂loss/∂param, for comparison with numerical gradient
         grads = torch.autograd.grad(loss, span, retain_graph=True)[0]

-        assert_almost_equal(numerical_grad, grads.item(), decimal=3)
+        assert_array_almost_equal(numerical_grad, grads.data, decimal=3)

-    def test_gradients_span_twlv_leaf_dynamics(self):
-        # prepare model input
+    @pytest.mark.parametrize(
+        "param_value",
+        [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)],
+    )
+    def test_gradients_twlv_leaf_dynamics(self, param_value):
         model = get_test_diff_leaf_model()
-        span = torch.nn.Parameter(torch.tensor(30, dtype=torch.float32))
+        span = torch.nn.Parameter(param_value)
         output = model({"SPAN": span})
-        twlv = output[0, :, 1]
-        loss = twlv.sum()
+        loss = output["TWLV"].sum()

-        # this is ∂loss/∂span
+        # this is ∂loss/∂param without calling loss.backward().
         # this is called forward gradient here because it is calculated without backpropagation.
         grads = torch.autograd.grad(loss, span, retain_graph=True)[0]

-        assert grads is not None, "Gradients for SPAN should not be None"
+        assert grads is not None, "Gradients should not be None"

         span.grad = None  # clear any existing gradient
         loss.backward()

-        # this is ∂loss/∂span calculated using backpropagation
+        # this is ∂loss/∂param calculated using backpropagation
         grad_backward = span.grad

-        assert grad_backward is not None, "Backward gradients for SPAN should not be None"
-        assert grad_backward == grads, "Forward and backward gradients for SPAN should match"
+        assert grad_backward is not None, "Backward gradients should not be None"
+        assert torch.all(grad_backward == grads), "Forward and backward gradients should match"

-    def test_gradients_span_twlv_leaf_dynamics_numerical(self):
+    @pytest.mark.parametrize(
+        "param_value",
+        [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)],
+    )
+    def test_gradients_twlv_leaf_dynamics_numerical(self, param_value):
         # first check if the numerical gradient is not zero, i.e. the parameter has an effect
-        span = torch.nn.Parameter(torch.tensor(30, dtype=torch.float64))
-        output_index = 1  # TWLV is at index 1
+        span = torch.nn.Parameter(param_value)
         numerical_grad = calculate_numerical_grad(
-            get_test_diff_leaf_model, "SPAN", span, output_index
-        )  # this is Δloss/Δspan
+            get_test_diff_leaf_model, "SPAN", span.data, "TWLV"
+        )  # this is Δloss/Δparam

         model = get_test_diff_leaf_model()
         output = model({"SPAN": span})
-        twlv = output[0, :, output_index]
-        loss = twlv.sum()
-        # this is ∂loss/∂tdwi, for comparison with numerical gradient
+        loss = output["TWLV"].sum()
+
+        # this is ∂loss/∂param, for comparison with numerical gradient
         grads = torch.autograd.grad(loss, span, retain_graph=True)[0]

-        assert numerical_grad == 0.0
-        assert_almost_equal(numerical_grad, grads.item(), decimal=3)
+        assert_array_almost_equal(grads.data, 0.0)
+        assert_array_almost_equal(numerical_grad, grads.data, decimal=3)
diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py
index 02ed12f..7ae9893 100644
--- a/tests/physical_models/crop/test_root_dynamics.py
+++ b/tests/physical_models/crop/test_root_dynamics.py
@@ -2,7 +2,7 @@ from unittest.mock import patch

 import pytest
 import torch
-from numpy.testing import assert_almost_equal
+from numpy.testing import assert_array_almost_equal
 from pcse.engine import Engine
 from pcse.models import Wofost72_PP
 from diffwofost.physical_models.crop.root_dynamics import WOFOST_Root_Dynamics
@@ -60,9 +60,7 @@ def forward(self, params_dict):
         engine.run_till_terminate()
         results = engine.get_output()

-        return torch.stack([torch.stack([item["RD"], item["TWRT"]]) for item in results]).unsqueeze(
-            0
-        )  # shape: [1, time_steps, 2]
+        return {var: torch.stack([item[var] for item in results]) for var in ["RD", "TWRT"]}


 class TestRootDynamics:
@@ -154,48 +152,17 @@ def test_gradients_tdwi_rd_root_dynamics(self):
         model = get_test_diff_root_model()
         tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32))
         output = model({"TDWI": tdwi})
-        rd = output[0, :, 0]
+        rd = output["RD"]
         loss = rd.sum()

-        # this is ∂loss/∂tdwi without calling loss.backward().
-        # this is called forward gradient here because it is calculated without backpropagation.
-        grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0]
-
-        assert grads is not None, "Gradients for TDWI should not be None"
-
-        tdwi.grad = None  # clear any existing gradient
-        loss.backward()
-
-        # this is ∂loss/∂tdwi calculated using backpropagation
-        grad_backward = tdwi.grad
-
-        assert grad_backward is not None, "Backward gradients for TDWI should not be None"
-        assert grad_backward == grads, "Forward and backward gradients for TDWI should match"
-
-    def test_gradients_tdwi_rd_root_dynamics_numerical(self):
-        tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64))
-        output_index = 0  # Index 0 is "RD"
-        numerical_grad = calculate_numerical_grad(
-            get_test_diff_root_model, "TDWI", tdwi, output_index
-        )
-
-        model = get_test_diff_root_model()
-        output = model({"TDWI": tdwi})
-        rd = output[0, :, output_index]  # Index 0 is "RD"
-        loss = rd.sum()
-
-        # this is ∂loss/∂tdwi, for comparison with numerical gradient
-        grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0]
-
-        # in this test, grads is very small
-        assert_almost_equal(numerical_grad, grads.item(), decimal=3)
+        assert loss.grad_fn is None  # tdwi does not contribute to rd

     def test_gradients_tdwi_twrt_root_dynamics(self):
         # prepare model input
         model = get_test_diff_root_model()
         tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32))
         output = model({"TDWI": tdwi})
-        twrt = output[0, :, 1]  # Index 1 is "TWRT"
+        twrt = output["TWRT"]
         loss = twrt.sum()

         # this is ∂loss/∂tdwi
@@ -215,18 +182,18 @@ def test_gradients_tdwi_twrt_root_dynamics(self):

     def test_gradients_tdwi_twrt_root_dynamics_numerical(self):
         tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64))
-        output_index = 1  # Index 1 is "TWRT"
+        output_name = "TWRT"
         numerical_grad = calculate_numerical_grad(
-            get_test_diff_root_model, "TDWI", tdwi, output_index
+            get_test_diff_root_model, "TDWI", tdwi.data, output_name
         )

         model = get_test_diff_root_model()
         output = model({"TDWI": tdwi})
-        twrt = output[0, :, output_index]
+        twrt = output[output_name]
         loss = twrt.sum()

         # this is ∂loss/∂tdwi, for comparison with numerical gradient
         grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0]

         # in this test, grads is very small
-        assert_almost_equal(numerical_grad, grads.item(), decimal=3)
+        assert_array_almost_equal(numerical_grad, grads.item(), decimal=3)
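
Note (illustrative addendum, not part of the patch): the central differentiability trick in calc_rates above is replacing the hard age threshold `if lvage > p.SPAN` with a steep sigmoid, so that a gradient with respect to SPAN survives. Below is a minimal self-contained sketch of the idea, with made-up numbers and a milder sharpness than the 1000.0 used (and flagged FIXME) in the patch:

import torch

DTYPE = torch.float64
span = torch.nn.Parameter(torch.tensor(30.0, dtype=DTYPE))  # leaf life span [d]
lv = torch.tensor([0.5, 0.8, 1.2], dtype=DTYPE)             # biomass per leaf class
lvage = torch.tensor([10.0, 29.0, 31.0], dtype=DTYPE)       # age per leaf class

# Hard threshold: piecewise-constant in span, so its gradient is zero
# almost everywhere and useless for optimization.
dalv_hard = ((lvage > span).type(DTYPE) * lv).sum()

# Soft threshold: a steep sigmoid approximates the step function while
# keeping a nonzero gradient; `sharpness` trades accuracy for gradient signal.
sharpness = torch.tensor(2.0, dtype=DTYPE)
weight = torch.sigmoid((lvage - span) * sharpness)
dalv_soft = torch.sum(weight * lv)

dalv_soft.backward()
print(dalv_hard.item(), dalv_soft.item(), span.grad.item())

With a sharpness as large as the 1000.0 used in the patch, the sigmoid is numerically a step again and the gradient vanishes away from the threshold, which is presumably why that value is tagged FIXME.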