diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 62f3d32..5ec36ba 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -14,6 +14,7 @@ from pcse.traitlets import Any from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_drv from diffwofost.physical_models.utils import _get_params_shape DTYPE = torch.float64 # Default data type for tensors in this module @@ -100,12 +101,25 @@ class WOFOST_Leaf_Dynamics(SimulationObject): |--------|-------------------------------------------------------|------|-------------| | LAI | Leaf area index, including stem and pod area | Y | - | | TWLV | Dry weight of total leaves (living + dead) | Y | kg ha⁻¹ | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|------------------------------------------| + | LAI | TDWI, SPAN, RGRLAI, TBASE, KDIFTB, SLATB | + | TWLV | TDWI, PERDL | + + [!] Notice that the following gradients are zero: + - ∂SPAN/∂LAI + - ∂PERDL/∂TWLV + - ∂KDIFTB/∂LAI """ # noqa: E501 # The following parameters are used to initialize and control the arrays that store information # on the leaf classes during the time integration: leaf area, age, and biomass. START_DATE = None # Start date of the simulation MAX_DAYS = 365 # Maximum number of days that can be simulated in one run (i.e. array lenghts) + params_shape = None # Shape of the parameters tensors class Parameters(ParamTemplate): RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) @@ -113,8 +127,8 @@ class Parameters(ParamTemplate): TBASE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) PERDL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - SLATB = AfgenTrait() # FIXME - KDIFTB = AfgenTrait() # FIXME + SLATB = AfgenTrait() + KDIFTB = AfgenTrait() class StateVariables(StatesTemplate): LV = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) @@ -172,17 +186,17 @@ def initialize( DVS = self.kiosk["DVS"] params = self.params - shape = _get_params_shape(params) + self.params_shape = _get_params_shape(params) # Initial leaf biomass WLV = (params.TDWI * (1 - FR)) * FL - DWLV = torch.zeros(shape, dtype=DTYPE) + DWLV = torch.zeros(self.params_shape, dtype=DTYPE) TWLV = WLV + DWLV # Initialize leaf classes (SLA, age and weight) - SLA = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE) - LVAGE = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE) - LV = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE) + SLA = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) + LVAGE = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) + LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) SLA[..., 0] = params.SLATB(DVS) LV[..., 0] = WLV @@ -292,16 +306,20 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # Total death rate leaves r.DRLV = torch.maximum(r.DSLV, r.DALV) + # Get the temperature from the drv + TEMP = _get_drv(drv.TEMP, self.params_shape) + # physiologic ageing of leaves per time step - FYSAGE = (drv.TEMP - p.TBASE) / (35.0 - p.TBASE) + TBASE = _broadcast_to(p.TBASE, self.params_shape) + FYSAGE = (TEMP - TBASE) / (35.0 - TBASE) r.FYSAGE = dvs_mask * torch.clamp(FYSAGE, 0.0) # specific leaf area of leaves per time step - r.SLAT = dvs_mask * torch.tensor(p.SLATB(DVS), dtype=DTYPE) + r.SLAT = dvs_mask * p.SLATB(DVS) # leaf area not to exceed exponential growth curve is_lai_exp = s.LAIEXP < 6.0 - DTEFF = torch.clamp(drv.TEMP - p.TBASE, 0.0) + DTEFF = torch.clamp(TEMP - TBASE, 0.0) # NOTE: conditional statements do not allow for the gradient to be # tracked through the condition. Thus, the gradient with respect to @@ -324,9 +342,10 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: GLA = torch.minimum(r.GLAIEX, r.GLASOL) # adjustment of specific leaf area of youngest leaf class + epsilon = 1e-10 # small value to avoid division by zero r.SLAT = torch.where( dvs_mask.bool(), - torch.where(is_lai_exp & (r.GRLV > 0.0), GLA / r.GRLV, r.SLAT), + torch.where(is_lai_exp & (r.GRLV > epsilon), GLA / (r.GRLV + epsilon), r.SLAT), torch.tensor(0.0, dtype=DTYPE), ) @@ -367,7 +386,7 @@ def integrate(self, day: datetime.date, delt=1.0) -> None: # tracked through the condition. Thus, the gradient with respect to # parameters that contribute to `is_alive` are expected to be incorrect. tLV = torch.where(is_alive, tLV, 0.0) - tLVAGE = tLVAGE + rates.FYSAGE + tLVAGE = tLVAGE + rates.FYSAGE.unsqueeze(-1) tLVAGE = torch.where(is_alive, tLVAGE, 0.0) tSLA = torch.where(is_alive, tSLA, 0.0) diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index 635bc10..3d5750c 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -26,6 +26,7 @@ from pcse.base.weather import WeatherDataProvider from pcse.engine import BaseEngine from pcse.engine import Engine +from pcse.settings import settings from pcse.timer import Timer from pcse.traitlets import TraitType @@ -200,8 +201,13 @@ def _run(self): class WeatherDataProviderTestHelper(WeatherDataProvider): """It stores the weatherdata contained within the YAML tests.""" - def __init__(self, yaml_weather): + def __init__(self, yaml_weather, meteo_range_checks=True): super().__init__() + # This is a temporary workaround. The `METEO_RANGE_CHECKS` logic in + # `__setattr__` method in `WeatherDataContainer` is not vector compatible + # yet. So we can disable it here when creating the `WeatherDataContainer` + # instances with arrays. + settings.METEO_RANGE_CHECKS = meteo_range_checks for weather in yaml_weather: if "SNOWDEPTH" in weather: weather.pop("SNOWDEPTH") @@ -209,12 +215,16 @@ def __init__(self, yaml_weather): self._store_WeatherDataContainer(wdc, wdc.DAY) -def prepare_engine_input(test_data, crop_model_params, dtype=torch.float64): +def prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=True, dtype=torch.float64 +): """Prepare the inputs for the engine from the YAML file.""" agro_management_inputs = test_data["AgroManagement"] cropd = test_data["ModelParameters"] - weather_data_provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) + weather_data_provider = WeatherDataProviderTestHelper( + test_data["WeatherVariables"], meteo_range_checks=meteo_range_checks + ) crop_model_params_provider = ParameterProvider(cropdata=cropd) external_states = test_data["ExternalStates"] @@ -539,6 +549,8 @@ def _get_params_shape(params): Parameters can have arbitrary number of dimensions, but all parameters that are not zero- dimensional should have the same shape. + + This check if fundamental for vectorized operations in the physical models. """ shape = () for parname in params.trait_names(): @@ -556,8 +568,43 @@ def _get_params_shape(params): return shape +def _get_drv(drv_var, expected_shape): + """Check that the driving variables have the expected shape and fetch them. + + Driving variables can be scalars (0-dimensional) or match the expected shape. + Scalars will be broadcast during operations. + + [!] This function will be redundant once weathercontainer supports batched variables. + + Args: + drv_var: driving variable in WeatherDataContainer + expected_shape: Expected shape tuple for non-scalar variables + + Raises: + ValueError: If any variable has incompatible shape + + Returns: + torch.Tensor: The validated variable, either as-is or broadcasted to expected shape. + """ + # Check shape: must be scalar (0-d) or match expected_shape + if not isinstance(drv_var, torch.Tensor) or drv_var.dim() == 0: + # Scalar is valid, will be broadcast + return _broadcast_to(drv_var, expected_shape) + elif drv_var.shape == expected_shape: + # Matches expected shape + return drv_var + else: + raise ValueError( + f"Requested weather variable has incompatible shape {drv_var.shape}. " + f"Expected scalar (0-dimensional) or shape {expected_shape}." + ) + + def _broadcast_to(x, shape): """Create a view of tensor X with the given shape.""" + # If x is not a tensor, convert it + if not isinstance(x, torch.Tensor): + x = torch.tensor(x, dtype=DTYPE) # If already the correct shape, return as-is if x.shape == shape: return x diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index a967829..699979a 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -1,4 +1,5 @@ import copy +import warnings from unittest.mock import patch import pytest import torch @@ -133,45 +134,70 @@ def test_leaf_dynamics_with_engine(self): config_path, ) - @pytest.mark.parametrize("param", ["TDWI", "SPAN"]) + @pytest.mark.parametrize( + "param", ["TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB", "TEMP"] + ) def test_leaf_dynamics_with_one_parameter_vector(self, param): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") # Setting a vector (with one value) for the selected parameter - repeated = crop_model_params_provider[param].repeat(10) - crop_model_params_provider.set_override(param, repeated, check=False) + if param == "TEMP": + # Vectorize weather variable + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + elif param in ["KDIFTB", "SLATB"]: + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[param].repeat(10, 1) + crop_model_params_provider.set_override(param, repeated, check=False) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) - engine = EngineTestHelper( - crop_model_params_provider, - weather_data_provider, - agro_management_inputs, - config_path, - external_states, - ) - engine.run_till_terminate() - actual_results = engine.get_output() + if param == "TEMP": + # Expect error due to incompatible shapes + # (By defaults parameters are not reshaped following weather variables) + with pytest.raises(ValueError): + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + else: + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() - # get expected results from YAML test data - expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + # get expected results from YAML test data + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] - assert len(actual_results) == len(expected_results) + assert len(actual_results) == len(expected_results) - for reference, model in zip(expected_results, actual_results, strict=False): - assert reference["DAY"] == model["day"] - assert all( - all(abs(reference[var] - model[var]) < precision) - for var, precision in expected_precision.items() - ) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) @pytest.mark.parametrize( "param,delta", @@ -184,7 +210,7 @@ def test_leaf_dynamics_with_different_parameter_values(self, param, delta): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -226,7 +252,7 @@ def test_leaf_dynamics_with_multiple_parameter_vectors(self): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -236,8 +262,12 @@ def test_leaf_dynamics_with_multiple_parameter_vectors(self): config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") # Setting a vector (with one value) for the TDWI and SPAN parameters - for param in ("TDWI", "SPAN"): - repeated = crop_model_params_provider[param].repeat(10) + for param in ("TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"): + if param in ("KDIFTB", "SLATB"): + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) crop_model_params_provider.set_override(param, repeated, check=False) engine = EngineTestHelper( @@ -266,21 +296,27 @@ def test_leaf_dynamics_with_multiple_parameter_arrays(self): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") - # Setting an array with arbitrary shape (and one value) for the - # TDWI and SPAN parameters - for param in ("TDWI", "SPAN"): - repeated = crop_model_params_provider[param].broadcast_to((30, 5)) + # Setting an array with arbitrary shape (and one value) + for param in ("RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"): + if param in ("KDIFTB", "SLATB"): + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[param].repeat(30, 5, 1) + else: + repeated = crop_model_params_provider[param].broadcast_to((30, 5)) crop_model_params_provider.set_override(param, repeated, check=False) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + engine = EngineTestHelper( crop_model_params_provider, weather_data_provider, @@ -310,7 +346,7 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -337,11 +373,40 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): external_states, ) + def test_leaf_dynamics_with_incompatible_weather_parameter_vectors(self): + # prepare model input + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") + + # Setting vectors with incompatible shapes: TDWI and TEMP + crop_model_params_provider.set_override( + "TDWI", crop_model_params_provider["TDWI"].repeat(10), check=False + ) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(5, dtype=torch.float64) * wdc.TEMP + + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) def test_wofost_pp_with_leaf_dynamics(self, test_data_url): # prepare model input test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( prepare_engine_input(test_data, crop_model_params) ) @@ -370,7 +435,7 @@ def test_leaf_dynamics_with_sigmoid_approx(self, test_data_url): """Test if sigmoid approximation gives same results as leaf dynamics.""" # prepare model input test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -405,144 +470,133 @@ def test_leaf_dynamics_with_sigmoid_approx(self, test_data_url): ) -class TestDiffLeafDynamicsTDWI: - @pytest.mark.parametrize( - "param_value,out_name", - [ - (torch.tensor(0.2, dtype=torch.float64), "LAI"), - (torch.tensor(0.2, dtype=torch.float64), "TWLV"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "LAI"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "TWLV"), - ], - ) - def test_gradients_leaf_dynamics(self, param_value, out_name): - model = get_test_diff_leaf_model() - tdwi = torch.nn.Parameter(param_value) - output = model({"TDWI": tdwi}) - loss = output[out_name].sum() - - # this is ∂loss/∂param without calling loss.backward(). - # this is called forward gradient here because it is calculated without backpropagation. - grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0] - assert grads is not None, "Gradients should not be None" - - tdwi.grad = None # clear any existing gradient - loss.backward() - # this is ∂loss/∂param calculated using backpropagation - grad_backward = tdwi.grad - assert grad_backward is not None, "Backward gradients should not be None" - assert torch.all(grad_backward == grads), "Forward and backward gradients should match" - - @pytest.mark.parametrize( - "param_value,out_name", - [ - (torch.tensor(0.2, dtype=torch.float64), "LAI"), - (torch.tensor(0.2, dtype=torch.float64), "TWLV"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "LAI"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "TWLV"), - ], - ) - def test_gradients_leaf_dynamics_numerical(self, param_value, out_name): - # first check if the numerical gradient isnot zero i.e. the parameter has an effect - tdwi = torch.nn.Parameter(param_value) - numerical_grad = calculate_numerical_grad( - get_test_diff_leaf_model, "TDWI", tdwi.data, out_name - ) # this is Δloss/Δparam - +class TestDiffLeafDynamicsGradients: + """Parametrized tests for gradient calculations in leaf dynamics.""" + + # Define parameters and outputs + param_names = ["TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"] + output_names = ["LAI", "TWLV"] + + # Define parameter configurations (value, dtype) + param_configs = { + "single": { + "TDWI": (0.2, torch.float64), + "SPAN": (30, torch.float64), + "RGRLAI": (0.016, torch.float64), + "TBASE": (3.0, torch.float64), + "PERDL": (0.03, torch.float64), + "KDIFTB": ([[0.0, 0.6, 2.0, 0.6]], torch.float64), + "SLATB": ([[0.0, 0.002, 2.0, 0.002]], torch.float64), + }, + "tensor": { + "TDWI": ([0.1, 0.2, 0.3], torch.float64), + "SPAN": ([25, 30, 35], torch.float64), + "RGRLAI": ([-10, 0.08, 1], torch.float64), + "TBASE": ([-5, 0, 10.0], torch.float64), + "PERDL": ([-10, 0.1, 15], torch.float64), + "KDIFTB": ( + [[0.0, 0.5, 10.0, 1.0], [0.0, 0.6, 12.0, 1.2], [0.0, 0.4, 8.0, 0.8]], + torch.float64, + ), + "SLATB": ( + [ + [0.0, 0.002031, 0.5, 0.002031, 2.0, 0.002031], + [0.0, 0.0025, 0.6, 0.0025, 2.5, 0.0025], + [0.0, 0.0015, 0.4, 0.0015, 1.5, 0.0015], + ], + torch.float64, + ), + }, + } + + # Define which parameter-output pairs should have gradients + # Format: {param_name: [list of outputs that should have gradients]} + gradient_mapping = { + "TDWI": ["LAI", "TWLV"], + "SPAN": ["LAI"], + "RGRLAI": ["LAI"], + "TBASE": ["LAI"], + "PERDL": ["TWLV"], + "KDIFTB": ["LAI"], + "SLATB": ["LAI"], + } + + # Generate all combinations + gradient_params = [] + no_gradient_params = [] + for param_name in param_names: + for output_name in output_names: + if output_name in gradient_mapping.get(param_name, []): + gradient_params.append((param_name, output_name)) + else: + no_gradient_params.append((param_name, output_name)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type): + """Test cases where parameters should not have gradients for specific outputs.""" model = get_test_diff_leaf_model() - output = model({"TDWI": tdwi}) - loss = output[out_name].sum() - - # this is ∂loss/∂param, for comparison with numerical gradient - grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0] - - assert_array_almost_equal(numerical_grad, grads.data, decimal=3) - + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + output = model({param_name: param}) + loss = output[output_name].sum() + + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) -class TestDiffLeafDynamicsSPAN: - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_lai_leaf_dynamics(self, param_value): + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type): + """Test that forward and backward gradients match for parameter-output pairs.""" model = get_test_diff_leaf_model() - span = torch.nn.Parameter(param_value) - output = model({"SPAN": span}) - loss = output["LAI"].sum() + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + output = model({param_name: param}) + loss = output[output_name].sum() - # this is ∂loss/∂param without calling loss.backward(). + # this is ∂loss/∂param # this is called forward gradient here because it is calculated without backpropagation. - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] - assert grads is not None, "Gradients should not be None" - - span.grad = None # clear any existing gradient - loss.backward() - # this is ∂loss/∂param calculated using backpropagation - grad_backward = span.grad - - assert grad_backward is not None, "Backward gradients should not be None" - assert torch.all(grad_backward == grads), "Forward and backward gradients should match" - - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_lai_leaf_dynamics_numerical(self, param_value): - # first check if the numerical gradient isnot zero i.e. the parameter has an effect - span = torch.nn.Parameter(param_value) - numerical_grad = calculate_numerical_grad( - get_test_diff_leaf_model, "SPAN", span.data, "LAI" - ) # this is Δloss/Δparam - - model = get_test_diff_leaf_model() - output = model({"SPAN": span}) - loss = output["LAI"].sum() - - # this is ∂loss/∂param, for comparison with numerical gradient - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] - - assert_array_almost_equal(numerical_grad, grads.data, decimal=3) - - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_twlv_leaf_dynamics(self, param_value): - model = get_test_diff_leaf_model() - span = torch.nn.Parameter(param_value) - output = model({"SPAN": span}) - loss = output["TWLV"].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] - # this is ∂loss/∂param without calling loss.backward(). - # this is called forward gradient here because it is calculated without backpropagation. - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] - assert grads is not None, "Gradients should not be None" + assert grads is not None, f"Gradients for {param_name} should not be None" - span.grad = None # clear any existing gradient + param.grad = None # clear any existing gradient loss.backward() + # this is ∂loss/∂param calculated using backpropagation - grad_backward = span.grad + grad_backward = param.grad - assert grad_backward is not None, "Backward gradients should not be None" - assert torch.all(grad_backward == grads), "Forward and backward gradients should match" + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), ( + f"Forward and backward gradients for {param_name} should match" + ) - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_leaf_twlv_dynamics_numerical(self, param_value): - # first check if the numerical gradient isnot zero i.e. the parameter has an effect - span = torch.nn.Parameter(param_value) + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type): + """Test that analytical gradients match numerical gradients.""" + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64)) numerical_grad = calculate_numerical_grad( - get_test_diff_leaf_model, "SPAN", span.data, "TWLV" - ) # this is Δloss/Δparam + get_test_diff_leaf_model, param_name, param.data, output_name + ) model = get_test_diff_leaf_model() - output = model({"SPAN": span}) - loss = output["TWLV"].sum() + output = model({param_name: param}) + loss = output[output_name].sum() # this is ∂loss/∂param, for comparison with numerical gradient - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] - assert_array_almost_equal(grads.data, 0.0) assert_array_almost_equal(numerical_grad, grads.data, decimal=3) + + # Warn if gradient is zero (but this shouldn't happen for gradient_params) + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' with respect to output " + + f"'{output_name}' is zero: {grads.data}", + UserWarning, + ) diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index e86aa0a..ec009de 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -5,6 +5,10 @@ from diffwofost.physical_models.utils import DTYPE from diffwofost.physical_models.utils import Afgen from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import WeatherDataProviderTestHelper +from diffwofost.physical_models.utils import _get_drv +from diffwofost.physical_models.utils import get_test_data +from . import phy_data_folder class TestAfgen: @@ -537,3 +541,50 @@ def test_backward_compatibility_non_batched(self): assert isinstance(result, torch.Tensor) assert result.dim() == 0 # Scalar tensor assert torch.isclose(result, torch.tensor(5.0, dtype=DTYPE)) + + +class TestGetDrvParam: + """Tests for _get_drv function.""" + + def test_float_broadcast(self): + expected_shape = (3, 2) + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) + wdc = provider(provider.first_date) + scalar = wdc.TEMP + out = _get_drv(scalar, expected_shape) + assert out.shape == expected_shape + assert torch.allclose(out, torch.full(expected_shape, scalar, dtype=DTYPE)) + + def test_scalar_broadcast(self): + expected_shape = (3, 2) + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) + wdc = provider(provider.first_date) + scalar = torch.tensor(wdc.IRRAD, dtype=DTYPE) # 0-d tensor + out = _get_drv(scalar, expected_shape) + assert out.shape == expected_shape + assert torch.allclose(out, torch.full(expected_shape, scalar.item(), dtype=DTYPE)) + + def test_matching_shape_pass_through(self): + expected_shape = (3, 2) + base_val = torch.tensor(12.34, dtype=DTYPE) + var = torch.ones(expected_shape, dtype=DTYPE) * base_val + out = _get_drv(var, expected_shape) + assert out.shape == expected_shape + # Should be the same object (no copy) + assert out.data_ptr() == var.data_ptr() + + def test_wrong_shape_raises(self): + expected_shape = (3, 2) + wrong = torch.ones(2, 3, dtype=DTYPE) + with pytest.raises(ValueError, match="incompatible shape"): + _get_drv(wrong, expected_shape) + + def test_one_dim_shape_raises(self): + expected_shape = (3, 2) + one_dim = torch.ones(3, dtype=DTYPE) + with pytest.raises(ValueError, match="incompatible shape"): + _get_drv(one_dim, expected_shape)