From 3a384a8c0bb71c17c6395e18fe3b2c414e09fefb Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 17 Nov 2025 14:40:17 +0100 Subject: [PATCH 01/13] Make leaf fully differentiable Vectorize leaf parameter Update leaf tests --- .../physical_models/crop/leaf_dynamics.py | 42 ++- .../crop/test_leaf_dynamics.py | 279 +++++++++--------- 2 files changed, 171 insertions(+), 150 deletions(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 62f3d32..f5abec2 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -100,6 +100,18 @@ class WOFOST_Leaf_Dynamics(SimulationObject): |--------|-------------------------------------------------------|------|-------------| | LAI | Leaf area index, including stem and pod area | Y | - | | TWLV | Dry weight of total leaves (living + dead) | Y | kg ha⁻¹ | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|------------------------------------------| + | LAI | TDWI, SPAN, RGRLAI, TBASE, KDIFTB, SLATB | + | TWLV | TDWI, PERDL | + + [!] Notice that the following gradient are zero: + - ∂SPAN/∂LAI + - ∂PERDL/∂TWLV + - ∂KDIFTB/∂LAI """ # noqa: E501 # The following parameters are used to initialize and control the arrays that store information @@ -113,8 +125,9 @@ class Parameters(ParamTemplate): TBASE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) PERDL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - SLATB = AfgenTrait() # FIXME - KDIFTB = AfgenTrait() # FIXME + SLATB = AfgenTrait() + KDIFTB = AfgenTrait() + shape = None # shape of the parameters (set during initialization) class StateVariables(StatesTemplate): LV = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) @@ -172,17 +185,17 @@ def initialize( DVS = self.kiosk["DVS"] params = self.params - shape = _get_params_shape(params) + params.shape = _get_params_shape(params) # Initial leaf biomass WLV = (params.TDWI * (1 - FR)) * FL - DWLV = torch.zeros(shape, dtype=DTYPE) + DWLV = torch.zeros(params.shape, dtype=DTYPE) TWLV = WLV + DWLV # Initialize leaf classes (SLA, age and weight) - SLA = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE) - LVAGE = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE) - LV = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE) + SLA = torch.zeros((*params.shape, self.MAX_DAYS), dtype=DTYPE) + LVAGE = torch.zeros((*params.shape, self.MAX_DAYS), dtype=DTYPE) + LV = torch.zeros((*params.shape, self.MAX_DAYS), dtype=DTYPE) SLA[..., 0] = params.SLATB(DVS) LV[..., 0] = WLV @@ -292,16 +305,20 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # Total death rate leaves r.DRLV = torch.maximum(r.DSLV, r.DALV) + # Get the temperature from the drv + TEMP = _broadcast_to(torch.as_tensor(drv.TEMP, dtype=DTYPE), p.shape) + # physiologic ageing of leaves per time step - FYSAGE = (drv.TEMP - p.TBASE) / (35.0 - p.TBASE) + TBASE = _broadcast_to(p.TBASE, p.shape) + FYSAGE = (TEMP - TBASE) / (35.0 - TBASE) r.FYSAGE = dvs_mask * torch.clamp(FYSAGE, 0.0) # specific leaf area of leaves per time step - r.SLAT = dvs_mask * torch.tensor(p.SLATB(DVS), dtype=DTYPE) + r.SLAT = dvs_mask * p.SLATB(DVS) # leaf area not to exceed exponential growth curve is_lai_exp = s.LAIEXP < 6.0 - DTEFF = torch.clamp(drv.TEMP - p.TBASE, 0.0) + DTEFF = torch.clamp(TEMP - TBASE, 0.0) # NOTE: conditional statements do not allow for the gradient to be # tracked through the condition. Thus, the gradient with respect to @@ -324,9 +341,10 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: GLA = torch.minimum(r.GLAIEX, r.GLASOL) # adjustment of specific leaf area of youngest leaf class + epsilon = 1e-10 # small value to avoid division by zero r.SLAT = torch.where( dvs_mask.bool(), - torch.where(is_lai_exp & (r.GRLV > 0.0), GLA / r.GRLV, r.SLAT), + torch.where(is_lai_exp & (r.GRLV > epsilon), GLA / (r.GRLV + epsilon), r.SLAT), torch.tensor(0.0, dtype=DTYPE), ) @@ -367,7 +385,7 @@ def integrate(self, day: datetime.date, delt=1.0) -> None: # tracked through the condition. Thus, the gradient with respect to # parameters that contribute to `is_alive` are expected to be incorrect. tLV = torch.where(is_alive, tLV, 0.0) - tLVAGE = tLVAGE + rates.FYSAGE + tLVAGE = tLVAGE + rates.FYSAGE.unsqueeze(-1) tLVAGE = torch.where(is_alive, tLVAGE, 0.0) tSLA = torch.where(is_alive, tSLA, 0.0) diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index a967829..ed1716d 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -1,4 +1,5 @@ import copy +import warnings from unittest.mock import patch import pytest import torch @@ -133,12 +134,14 @@ def test_leaf_dynamics_with_engine(self): config_path, ) - @pytest.mark.parametrize("param", ["TDWI", "SPAN"]) + @pytest.mark.parametrize( + "param", ["TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"] + ) def test_leaf_dynamics_with_one_parameter_vector(self, param): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -148,7 +151,11 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") # Setting a vector (with one value) for the selected parameter - repeated = crop_model_params_provider[param].repeat(10) + if param in ["KDIFTB", "SLATB"]: + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) crop_model_params_provider.set_override(param, repeated, check=False) engine = EngineTestHelper( @@ -184,7 +191,7 @@ def test_leaf_dynamics_with_different_parameter_values(self, param, delta): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -226,7 +233,7 @@ def test_leaf_dynamics_with_multiple_parameter_vectors(self): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -236,8 +243,12 @@ def test_leaf_dynamics_with_multiple_parameter_vectors(self): config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") # Setting a vector (with one value) for the TDWI and SPAN parameters - for param in ("TDWI", "SPAN"): - repeated = crop_model_params_provider[param].repeat(10) + for param in ("TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"): + if param in ("KDIFTB", "SLATB"): + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) crop_model_params_provider.set_override(param, repeated, check=False) engine = EngineTestHelper( @@ -266,7 +277,7 @@ def test_leaf_dynamics_with_multiple_parameter_arrays(self): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -275,10 +286,13 @@ def test_leaf_dynamics_with_multiple_parameter_arrays(self): ) = prepare_engine_input(test_data, crop_model_params) config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") - # Setting an array with arbitrary shape (and one value) for the - # TDWI and SPAN parameters - for param in ("TDWI", "SPAN"): - repeated = crop_model_params_provider[param].broadcast_to((30, 5)) + # Setting an array with arbitrary shape (and one value) + for param in ("RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"): + if param in ("KDIFTB", "SLATB"): + # AfgenTrait parameters need to have shape (N, M) + repeated = crop_model_params_provider[param].repeat(30, 5, 1) + else: + repeated = crop_model_params_provider[param].broadcast_to((30, 5)) crop_model_params_provider.set_override(param, repeated, check=False) engine = EngineTestHelper( @@ -310,7 +324,7 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -341,7 +355,7 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): def test_wofost_pp_with_leaf_dynamics(self, test_data_url): # prepare model input test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( prepare_engine_input(test_data, crop_model_params) ) @@ -370,7 +384,7 @@ def test_leaf_dynamics_with_sigmoid_approx(self, test_data_url): """Test if sigmoid approximation gives same results as leaf dynamics.""" # prepare model input test_data = get_test_data(test_data_url) - crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] ( crop_model_params_provider, weather_data_provider, @@ -405,144 +419,133 @@ def test_leaf_dynamics_with_sigmoid_approx(self, test_data_url): ) -class TestDiffLeafDynamicsTDWI: - @pytest.mark.parametrize( - "param_value,out_name", - [ - (torch.tensor(0.2, dtype=torch.float64), "LAI"), - (torch.tensor(0.2, dtype=torch.float64), "TWLV"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "LAI"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "TWLV"), - ], - ) - def test_gradients_leaf_dynamics(self, param_value, out_name): +class TestDiffLeafDynamicsGradients: + """Parametrized tests for gradient calculations in leaf dynamics.""" + + # Define parameters and outputs + param_names = ["TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"] + output_names = ["LAI", "TWLV"] + + # Define parameter configurations (value, dtype) + param_configs = { + "single": { + "TDWI": (0.2, torch.float64), + "SPAN": (30, torch.float64), + "RGRLAI": (0.016, torch.float64), + "TBASE": (3.0, torch.float64), + "PERDL": (0.03, torch.float64), + "KDIFTB": ([[0.0, 0.6, 2.0, 0.6]], torch.float64), + "SLATB": ([[0.0, 0.002, 2.0, 0.002]], torch.float64), + }, + "tensor": { + "TDWI": ([0.1, 0.2, 0.3], torch.float64), + "SPAN": ([25, 30, 35], torch.float64), + "RGRLAI": ([-10, 0.08, 1], torch.float64), + "TBASE": ([-5, 0, 10.0], torch.float64), + "PERDL": ([-10, 0.1, 15], torch.float64), + "KDIFTB": ( + [[0.0, 0.5, 10.0, 1.0], [0.0, 0.6, 12.0, 1.2], [0.0, 0.4, 8.0, 0.8]], + torch.float64, + ), + "SLATB": ( + [ + [0.0, 0.002031, 0.5, 0.002031, 2.0, 0.002031], + [0.0, 0.0025, 0.6, 0.0025, 2.5, 0.0025], + [0.0, 0.0015, 0.4, 0.0015, 1.5, 0.0015], + ], + torch.float64, + ), + }, + } + + # Define which parameter-output pairs should have gradients + # Format: {param_name: [list of outputs that should have gradients]} + gradient_mapping = { + "TDWI": ["LAI", "TWLV"], + "SPAN": ["LAI"], + "RGRLAI": ["LAI"], + "TBASE": ["LAI"], + "PERDL": ["TWLV"], + "KDIFTB": ["LAI"], + "SLATB": ["LAI"], + } + + # Generate all combinations + gradient_params = [] + no_gradient_params = [] + for param_name in param_names: + for output_name in output_names: + if output_name in gradient_mapping.get(param_name, []): + gradient_params.append((param_name, output_name)) + else: + no_gradient_params.append((param_name, output_name)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type): + """Test cases where parameters should not have gradients for specific outputs.""" model = get_test_diff_leaf_model() - tdwi = torch.nn.Parameter(param_value) - output = model({"TDWI": tdwi}) - loss = output[out_name].sum() - - # this is ∂loss/∂param without calling loss.backward(). - # this is called forward gradient here because it is calculated without backpropagation. - grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0] - assert grads is not None, "Gradients should not be None" - - tdwi.grad = None # clear any existing gradient - loss.backward() - # this is ∂loss/∂param calculated using backpropagation - grad_backward = tdwi.grad - assert grad_backward is not None, "Backward gradients should not be None" - assert torch.all(grad_backward == grads), "Forward and backward gradients should match" - - @pytest.mark.parametrize( - "param_value,out_name", - [ - (torch.tensor(0.2, dtype=torch.float64), "LAI"), - (torch.tensor(0.2, dtype=torch.float64), "TWLV"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "LAI"), - (torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64), "TWLV"), - ], - ) - def test_gradients_leaf_dynamics_numerical(self, param_value, out_name): - # first check if the numerical gradient isnot zero i.e. the parameter has an effect - tdwi = torch.nn.Parameter(param_value) - numerical_grad = calculate_numerical_grad( - get_test_diff_leaf_model, "TDWI", tdwi.data, out_name - ) # this is Δloss/Δparam - - model = get_test_diff_leaf_model() - output = model({"TDWI": tdwi}) - loss = output[out_name].sum() - - # this is ∂loss/∂param, for comparison with numerical gradient - grads = torch.autograd.grad(loss, tdwi, retain_graph=True)[0] - - assert_array_almost_equal(numerical_grad, grads.data, decimal=3) - + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + output = model({param_name: param}) + loss = output[output_name].sum() + + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) -class TestDiffLeafDynamicsSPAN: - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_lai_leaf_dynamics(self, param_value): + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type): + """Test that forward and backward gradients match for parameter-output pairs.""" model = get_test_diff_leaf_model() - span = torch.nn.Parameter(param_value) - output = model({"SPAN": span}) - loss = output["LAI"].sum() + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + output = model({param_name: param}) + loss = output[output_name].sum() - # this is ∂loss/∂param without calling loss.backward(). + # this is ∂loss/∂param # this is called forward gradient here because it is calculated without backpropagation. - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] - assert grads is not None, "Gradients should not be None" - - span.grad = None # clear any existing gradient - loss.backward() - # this is ∂loss/∂param calculated using backpropagation - grad_backward = span.grad + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] - assert grad_backward is not None, "Backward gradients should not be None" - assert torch.all(grad_backward == grads), "Forward and backward gradients should match" - - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_lai_leaf_dynamics_numerical(self, param_value): - # first check if the numerical gradient isnot zero i.e. the parameter has an effect - span = torch.nn.Parameter(param_value) - numerical_grad = calculate_numerical_grad( - get_test_diff_leaf_model, "SPAN", span.data, "LAI" - ) # this is Δloss/Δparam - - model = get_test_diff_leaf_model() - output = model({"SPAN": span}) - loss = output["LAI"].sum() - - # this is ∂loss/∂param, for comparison with numerical gradient - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] + assert grads is not None, f"Gradients for {param_name} should not be None" - assert_array_almost_equal(numerical_grad, grads.data, decimal=3) - - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_twlv_leaf_dynamics(self, param_value): - model = get_test_diff_leaf_model() - span = torch.nn.Parameter(param_value) - output = model({"SPAN": span}) - loss = output["TWLV"].sum() - - # this is ∂loss/∂param without calling loss.backward(). - # this is called forward gradient here because it is calculated without backpropagation. - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] - assert grads is not None, "Gradients should not be None" - - span.grad = None # clear any existing gradient + param.grad = None # clear any existing gradient loss.backward() + # this is ∂loss/∂param calculated using backpropagation - grad_backward = span.grad + grad_backward = param.grad - assert grad_backward is not None, "Backward gradients should not be None" - assert torch.all(grad_backward == grads), "Forward and backward gradients should match" + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), ( + f"Forward and backward gradients for {param_name} should match" + ) - @pytest.mark.parametrize( - "param_value", - [torch.tensor(30, dtype=torch.float64), torch.tensor([25, 30, 35], dtype=torch.float64)], - ) - def test_gradients_leaf_twlv_dynamics_numerical(self, param_value): - # first check if the numerical gradient isnot zero i.e. the parameter has an effect - span = torch.nn.Parameter(param_value) + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type): + """Test that analytical gradients match numerical gradients.""" + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64)) numerical_grad = calculate_numerical_grad( - get_test_diff_leaf_model, "SPAN", span.data, "TWLV" - ) # this is Δloss/Δparam + get_test_diff_leaf_model, param_name, param.data, output_name + ) model = get_test_diff_leaf_model() - output = model({"SPAN": span}) - loss = output["TWLV"].sum() + output = model({param_name: param}) + loss = output[output_name].sum() # this is ∂loss/∂param, for comparison with numerical gradient - grads = torch.autograd.grad(loss, span, retain_graph=True)[0] + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] - assert_array_almost_equal(grads.data, 0.0) assert_array_almost_equal(numerical_grad, grads.data, decimal=3) + + # Warn if gradient is zero (but this shouldn't happen for gradient_params) + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' with respect to output " + + f"'{output_name}' is zero: {grads.data}", + UserWarning, + ) From 7418c52f3c8aff6cfa41560b4b57919173920b46 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Tue, 18 Nov 2025 11:37:44 +0100 Subject: [PATCH 02/13] Make drv vector --- .../physical_models/crop/leaf_dynamics.py | 44 ++++++ src/diffwofost/physical_models/utils.py | 79 +++++++++- .../crop/test_leaf_dynamics.py | 42 +++++- tests/physical_models/test_utils.py | 137 ++++++++++++++++++ 4 files changed, 295 insertions(+), 7 deletions(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index f5abec2..2605d47 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -186,6 +186,7 @@ def initialize( params = self.params params.shape = _get_params_shape(params) + self._shape_finalized = False # Track if shape has been updated with drv data # Initial leaf biomass WLV = (params.TDWI * (1 - FR)) * FL @@ -245,6 +246,49 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: p = self.params k = self.kiosk + # Update shape if needed based on drv.TEMP + if not self._shape_finalized: + TEMP = torch.as_tensor(drv.TEMP, dtype=DTYPE) + # Broadcast to get the final shape + temp_shape = TEMP.shape if TEMP.ndim > 0 else () + if temp_shape != p.shape: + old_shape = p.shape + # Validate that shapes can be broadcasted + try: + # Update to the broadcasted shape + p.shape = torch.broadcast_shapes(p.shape, temp_shape) + except RuntimeError as e: + raise ValueError( + f"Parameter shape {p.shape} and weather driver shape {temp_shape} " + f"have incompatible shape and cannot be broadcasted together. " + f"Original error: {e}" + ) from e + + # Reshape state variables to match new parameter shape + # For time-series tensors (LV, SLA, LVAGE), we need to add the batch dimension + if len(old_shape) == 0: + # Old shape was scalar, new shape has batch dimension + s.LV = s.LV.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) + s.SLA = s.SLA.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) + s.LVAGE = s.LVAGE.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) + else: + # Both have batch dimensions, use standard broadcast + s.LV = s.LV.expand(*p.shape, self.MAX_DAYS) + s.SLA = s.SLA.expand(*p.shape, self.MAX_DAYS) + s.LVAGE = s.LVAGE.expand(*p.shape, self.MAX_DAYS) + + # For scalar state variables, use _broadcast_to + s.LAIEM = _broadcast_to(s.LAIEM, p.shape) + s.LASUM = _broadcast_to(s.LASUM, p.shape) + s.LAIEXP = _broadcast_to(s.LAIEXP, p.shape) + s.LAIMAX = _broadcast_to(s.LAIMAX, p.shape) + s.LAI = _broadcast_to(s.LAI, p.shape) + s.WLV = _broadcast_to(s.WLV, p.shape) + s.DWLV = _broadcast_to(s.DWLV, p.shape) + s.TWLV = _broadcast_to(s.TWLV, p.shape) + + self._shape_finalized = True + # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask # A mask (0 if DVS < 0, 1 if DVS >= 0) DVS = torch.as_tensor(k["DVS"], dtype=DTYPE) diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index 635bc10..4d604bc 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -26,6 +26,7 @@ from pcse.base.weather import WeatherDataProvider from pcse.engine import BaseEngine from pcse.engine import Engine +from pcse.settings import settings from pcse.timer import Timer from pcse.traitlets import TraitType @@ -200,8 +201,9 @@ def _run(self): class WeatherDataProviderTestHelper(WeatherDataProvider): """It stores the weatherdata contained within the YAML tests.""" - def __init__(self, yaml_weather): + def __init__(self, yaml_weather, meteo_range_checks=True): super().__init__() + settings.METEO_RANGE_CHECKS = meteo_range_checks for weather in yaml_weather: if "SNOWDEPTH" in weather: weather.pop("SNOWDEPTH") @@ -209,12 +211,16 @@ def __init__(self, yaml_weather): self._store_WeatherDataContainer(wdc, wdc.DAY) -def prepare_engine_input(test_data, crop_model_params, dtype=torch.float64): +def prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=True, dtype=torch.float64 +): """Prepare the inputs for the engine from the YAML file.""" agro_management_inputs = test_data["AgroManagement"] cropd = test_data["ModelParameters"] - weather_data_provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) + weather_data_provider = WeatherDataProviderTestHelper( + test_data["WeatherVariables"], meteo_range_checks=meteo_range_checks + ) crop_model_params_provider = ParameterProvider(cropdata=cropd) external_states = test_data["ExternalStates"] @@ -539,6 +545,8 @@ def _get_params_shape(params): Parameters can have arbitrary number of dimensions, but all parameters that are not zero- dimensional should have the same shape. + + This check if fundamental for vectorized operations in the physical models. """ shape = () for parname in params.trait_names(): @@ -556,6 +564,71 @@ def _get_params_shape(params): return shape +def _check_drv_shape(drv, expected_shape): + """Check that the driving variables have the expected shape. + + Driving variables can be scalars (0-dimensional) or match the expected shape. + Scalars will be broadcast during operations. + + Args: + drv: WeatherDataContainer with driving variables + expected_shape: Expected shape tuple for non-scalar variables + + Raises: + ValueError: If any variable has incompatible shape + """ + # Define compulsory and optional weather variables + compulsory_vars = [ + "LAT", + "LON", + "ELEV", + "DAY", + "IRRAD", + "TMIN", + "TMAX", + "VAP", + "RAIN", + "WIND", + "E0", + "ES0", + "ET0", + ] + optional_vars = ["TEMP", "SNOWDEPTH"] + all_vars = compulsory_vars + optional_vars + + for var_name in all_vars: + # Skip if variable doesn't exist (only matters for optional vars) + if not hasattr(drv, var_name): + if var_name in compulsory_vars: + raise ValueError( + f"Compulsory variable '{var_name}' missing from WeatherDataContainer" + ) + continue + + var_value = getattr(drv, var_name) + + # Skip DAY as it's a date object, not a tensor + if var_name == "DAY": + continue + + # Convert to tensor if needed for shape checking + if not isinstance(var_value, torch.Tensor): + var_value = torch.as_tensor(var_value, dtype=DTYPE) + + # Check shape: must be scalar (0-d) or match expected_shape + if var_value.dim() == 0: + # Scalar is valid, will be broadcast + continue + elif var_value.shape == expected_shape: + # Matches expected shape + continue + else: + raise ValueError( + f"Weather variable '{var_name}' has incompatible shape {var_value.shape}. " + f"Expected scalar (0-dimensional) or shape {expected_shape}." + ) + + def _broadcast_to(x, shape): """Create a view of tensor X with the given shape.""" # If already the correct shape, return as-is diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index ed1716d..12eaf0b 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -135,7 +135,7 @@ def test_leaf_dynamics_with_engine(self): ) @pytest.mark.parametrize( - "param", ["TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB"] + "param", ["TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB", "TEMP"] ) def test_leaf_dynamics_with_one_parameter_vector(self, param): # prepare model input @@ -147,16 +147,21 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") # Setting a vector (with one value) for the selected parameter - if param in ["KDIFTB", "SLATB"]: + if param == "TEMP": + # Vectorize weather variable + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + elif param in ["KDIFTB", "SLATB"]: # AfgenTrait parameters need to have shape (N, M) repeated = crop_model_params_provider[param].repeat(10, 1) + crop_model_params_provider.set_override(param, repeated, check=False) else: repeated = crop_model_params_provider[param].repeat(10) - crop_model_params_provider.set_override(param, repeated, check=False) + crop_model_params_provider.set_override(param, repeated, check=False) engine = EngineTestHelper( crop_model_params_provider, @@ -351,6 +356,35 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): external_states, ) + def test_leaf_dynamics_with_incompatible_weather_parameter_vectors(self): + # prepare model input + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI", "KDIFTB", "SLATB"] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") + + # Setting vectors with incompatible shapes: TDWI and TEMP + crop_model_params_provider.set_override( + "TDWI", crop_model_params_provider["TDWI"].repeat(10), check=False + ) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(5, dtype=torch.float64) * wdc.TEMP + + with pytest.raises(ValueError, match="incompatible shape"): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) def test_wofost_pp_with_leaf_dynamics(self, test_data_url): # prepare model input diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index e86aa0a..b93bedc 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -5,6 +5,9 @@ from diffwofost.physical_models.utils import DTYPE from diffwofost.physical_models.utils import Afgen from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import WeatherDataProviderTestHelper +from diffwofost.physical_models.utils import get_test_data +from . import phy_data_folder class TestAfgen: @@ -537,3 +540,137 @@ def test_backward_compatibility_non_batched(self): assert isinstance(result, torch.Tensor) assert result.dim() == 0 # Scalar tensor assert torch.isclose(result, torch.tensor(5.0, dtype=DTYPE)) + + +class TestCheckDrvShape: + """Tests for _check_drv_shape function.""" + + def test_scalar_drv(self): + """Test that drv accepts scalar variables for any expected shape.""" + from diffwofost.physical_models.utils import _check_drv_shape + + expected_shape = (3, 2) # Example batch shape + + # Read weather data from test YAML + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + weather_data_provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) + + # Get weather data for a specific date using the provider's __call__ method + first_date = weather_data_provider.first_date + wdc = weather_data_provider(first_date) + + # Check that the weather data container has valid shape + # (all scalar values should pass for any expected_shape) + _check_drv_shape(wdc, expected_shape) + + # If no exception was raised, test passes + assert True + + def test_single_vector_variable(self): + """Test that each weather variable can individually be vectorized to expected_shape.""" + from diffwofost.physical_models.utils import _check_drv_shape + + expected_shape = (3, 2) # Example batch shape + + # Read weather data from test YAML + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + weather_data_provider = WeatherDataProviderTestHelper( + test_data["WeatherVariables"], meteo_range_checks=False + ) + + # List of all weather variables to test + vector_vars = [ + "LAT", + "LON", + "ELEV", + "IRRAD", + "TMIN", + "TMAX", + "VAP", + "RAIN", + "WIND", + "E0", + "ES0", + "ET0", + "TEMP", + ] + + first_date = weather_data_provider.first_date + + # Test each variable individually + for var_name in vector_vars: + wdc = weather_data_provider(first_date) + + # Vectorize only this one variable + original_value = getattr(wdc, var_name) + setattr(wdc, var_name, torch.ones(expected_shape, dtype=DTYPE) * original_value) + + # Should not raise an exception (scalar + one vectorized variable is valid) + _check_drv_shape(wdc, expected_shape) + + def test_all_vector_variables(self): + """Test that all weather variables can be vectorized together to expected_shape.""" + from diffwofost.physical_models.utils import _check_drv_shape + + expected_shape = (3, 2) # Example batch shape + + # Read weather data from test YAML + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + weather_data_provider = WeatherDataProviderTestHelper( + test_data["WeatherVariables"], meteo_range_checks=False + ) + + # Get weather data for a specific date + first_date = weather_data_provider.first_date + wdc = weather_data_provider(first_date) + + # Vectorize all weather variables to the same expected_shape + wdc.LAT = torch.ones(expected_shape, dtype=DTYPE) * wdc.LAT + wdc.LON = torch.ones(expected_shape, dtype=DTYPE) * wdc.LON + wdc.ELEV = torch.ones(expected_shape, dtype=DTYPE) * wdc.ELEV + wdc.IRRAD = torch.ones(expected_shape, dtype=DTYPE) * wdc.IRRAD + wdc.TMIN = torch.ones(expected_shape, dtype=DTYPE) * wdc.TMIN + wdc.TMAX = torch.ones(expected_shape, dtype=DTYPE) * wdc.TMAX + wdc.VAP = torch.ones(expected_shape, dtype=DTYPE) * wdc.VAP + wdc.RAIN = torch.ones(expected_shape, dtype=DTYPE) * wdc.RAIN + wdc.WIND = torch.ones(expected_shape, dtype=DTYPE) * wdc.WIND + wdc.E0 = torch.ones(expected_shape, dtype=DTYPE) * wdc.E0 + wdc.ES0 = torch.ones(expected_shape, dtype=DTYPE) * wdc.ES0 + wdc.ET0 = torch.ones(expected_shape, dtype=DTYPE) * wdc.ET0 + wdc.TEMP = torch.ones(expected_shape, dtype=DTYPE) * wdc.TEMP + + # Should not raise an exception (all same shape) + _check_drv_shape(wdc, expected_shape) + + # If no exception was raised, test passes + assert True + + def test_mixed_shapes_raises_error(self): + """Test that two variables with different non-scalar shapes raise ValueError.""" + from diffwofost.physical_models.utils import _check_drv_shape + + expected_shape = (3, 2) + wrong_shape = (2, 3) + + # Read weather data from test YAML + test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + weather_data_provider = WeatherDataProviderTestHelper( + test_data["WeatherVariables"], meteo_range_checks=False + ) + + first_date = weather_data_provider.first_date + wdc = weather_data_provider(first_date) + + # Set IRRAD to expected_shape + wdc.IRRAD = torch.ones(expected_shape, dtype=DTYPE) * wdc.IRRAD + + # Set TMIN to a different wrong_shape + wdc.TMIN = torch.ones(wrong_shape, dtype=DTYPE) * wdc.TMIN + + # Should raise ValueError because shapes don't match + with pytest.raises(ValueError, match="incompatible shape"): + _check_drv_shape(wdc, expected_shape) From f3dcb80635d500f3e9d28665ab80685684842063 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Tue, 18 Nov 2025 11:41:35 +0100 Subject: [PATCH 03/13] Add drv shape check --- src/diffwofost/physical_models/crop/leaf_dynamics.py | 4 ++++ src/diffwofost/physical_models/utils.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 2605d47..9e8dcdc 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -14,6 +14,7 @@ from pcse.traitlets import Any from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _check_drv_shape from diffwofost.physical_models.utils import _get_params_shape DTYPE = torch.float64 # Default data type for tensors in this module @@ -288,6 +289,9 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: s.TWLV = _broadcast_to(s.TWLV, p.shape) self._shape_finalized = True + # Finally check if drv shape is consistent + # [!] once weathercontainer supports batched variables, this check will be redundant + _check_drv_shape(drv, p.shape) # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask # A mask (0 if DVS < 0, 1 if DVS >= 0) diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index 4d604bc..f6ed93c 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -570,6 +570,8 @@ def _check_drv_shape(drv, expected_shape): Driving variables can be scalars (0-dimensional) or match the expected shape. Scalars will be broadcast during operations. + [!] This function will be redundant once weathercontainer supports batched variables. + Args: drv: WeatherDataContainer with driving variables expected_shape: Expected shape tuple for non-scalar variables From da2948026c6fc38903961ca4896923f2198d2dbc Mon Sep 17 00:00:00 2001 From: SCiarella Date: Tue, 18 Nov 2025 12:08:37 +0100 Subject: [PATCH 04/13] Refactor for sonar --- .../physical_models/crop/leaf_dynamics.py | 43 ++++++++++--------- tests/physical_models/test_utils.py | 6 --- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 9e8dcdc..33f3283 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -232,31 +232,17 @@ def _calc_LAI(self): total_LAI = self.states.LASUM + SAI + PAI return total_LAI - @prepare_rates - def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: - """Calculate the rates of change for the leaf dynamics. - - Args: - day (datetime.date, optional): The current date of the simulation. - drv (WeatherDataContainer, optional): A dictionary-like container holding - weather data elements as key/value. The values are - arrays or scalars. See PCSE documentation for details. - """ - r = self.rates - s = self.states + def _ensure_shape_finalized(self, drv: WeatherDataContainer) -> None: + """Ensure that the parameter and state variable shapes are finalized.""" p = self.params - k = self.kiosk + s = self.states - # Update shape if needed based on drv.TEMP if not self._shape_finalized: TEMP = torch.as_tensor(drv.TEMP, dtype=DTYPE) - # Broadcast to get the final shape temp_shape = TEMP.shape if TEMP.ndim > 0 else () if temp_shape != p.shape: old_shape = p.shape - # Validate that shapes can be broadcasted try: - # Update to the broadcasted shape p.shape = torch.broadcast_shapes(p.shape, temp_shape) except RuntimeError as e: raise ValueError( @@ -266,14 +252,11 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: ) from e # Reshape state variables to match new parameter shape - # For time-series tensors (LV, SLA, LVAGE), we need to add the batch dimension if len(old_shape) == 0: - # Old shape was scalar, new shape has batch dimension s.LV = s.LV.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) s.SLA = s.SLA.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) s.LVAGE = s.LVAGE.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) else: - # Both have batch dimensions, use standard broadcast s.LV = s.LV.expand(*p.shape, self.MAX_DAYS) s.SLA = s.SLA.expand(*p.shape, self.MAX_DAYS) s.LVAGE = s.LVAGE.expand(*p.shape, self.MAX_DAYS) @@ -289,10 +272,28 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: s.TWLV = _broadcast_to(s.TWLV, p.shape) self._shape_finalized = True + # Finally check if drv shape is consistent - # [!] once weathercontainer supports batched variables, this check will be redundant _check_drv_shape(drv, p.shape) + @prepare_rates + def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: + """Calculate the rates of change for the leaf dynamics. + + Args: + day (datetime.date, optional): The current date of the simulation. + drv (WeatherDataContainer, optional): A dictionary-like container holding + weather data elements as key/value. The values are + arrays or scalars. See PCSE documentation for details. + """ + r = self.rates + s = self.states + p = self.params + k = self.kiosk + + # Update shape if needed based on drv.TEMP + self._ensure_shape_finalized(drv) + # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask # A mask (0 if DVS < 0, 1 if DVS >= 0) DVS = torch.as_tensor(k["DVS"], dtype=DTYPE) diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index b93bedc..5937404 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -564,9 +564,6 @@ def test_scalar_drv(self): # (all scalar values should pass for any expected_shape) _check_drv_shape(wdc, expected_shape) - # If no exception was raised, test passes - assert True - def test_single_vector_variable(self): """Test that each weather variable can individually be vectorized to expected_shape.""" from diffwofost.physical_models.utils import _check_drv_shape @@ -645,9 +642,6 @@ def test_all_vector_variables(self): # Should not raise an exception (all same shape) _check_drv_shape(wdc, expected_shape) - # If no exception was raised, test passes - assert True - def test_mixed_shapes_raises_error(self): """Test that two variables with different non-scalar shapes raise ValueError.""" from diffwofost.physical_models.utils import _check_drv_shape From 203fd64bf1176e0f616998350b8099a0357f0f58 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 11:28:31 +0100 Subject: [PATCH 05/13] Remove _ensure_shape function Store shape outside of param class --- .../physical_models/crop/leaf_dynamics.py | 73 ++++--------------- .../crop/test_leaf_dynamics.py | 52 ++++++++----- 2 files changed, 49 insertions(+), 76 deletions(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 33f3283..9c318db 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -14,7 +14,6 @@ from pcse.traitlets import Any from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to -from diffwofost.physical_models.utils import _check_drv_shape from diffwofost.physical_models.utils import _get_params_shape DTYPE = torch.float64 # Default data type for tensors in this module @@ -109,7 +108,7 @@ class WOFOST_Leaf_Dynamics(SimulationObject): | LAI | TDWI, SPAN, RGRLAI, TBASE, KDIFTB, SLATB | | TWLV | TDWI, PERDL | - [!] Notice that the following gradient are zero: + [!] Notice that the following gradients are zero: - ∂SPAN/∂LAI - ∂PERDL/∂TWLV - ∂KDIFTB/∂LAI @@ -119,6 +118,7 @@ class WOFOST_Leaf_Dynamics(SimulationObject): # on the leaf classes during the time integration: leaf area, age, and biomass. START_DATE = None # Start date of the simulation MAX_DAYS = 365 # Maximum number of days that can be simulated in one run (i.e. array lenghts) + params_shape = None # Shape of the parameters and the drv data class Parameters(ParamTemplate): RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) @@ -128,7 +128,6 @@ class Parameters(ParamTemplate): TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) SLATB = AfgenTrait() KDIFTB = AfgenTrait() - shape = None # shape of the parameters (set during initialization) class StateVariables(StatesTemplate): LV = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) @@ -186,18 +185,18 @@ def initialize( DVS = self.kiosk["DVS"] params = self.params - params.shape = _get_params_shape(params) + self.params_shape = _get_params_shape(params) self._shape_finalized = False # Track if shape has been updated with drv data # Initial leaf biomass WLV = (params.TDWI * (1 - FR)) * FL - DWLV = torch.zeros(params.shape, dtype=DTYPE) + DWLV = torch.zeros(self.params_shape, dtype=DTYPE) TWLV = WLV + DWLV # Initialize leaf classes (SLA, age and weight) - SLA = torch.zeros((*params.shape, self.MAX_DAYS), dtype=DTYPE) - LVAGE = torch.zeros((*params.shape, self.MAX_DAYS), dtype=DTYPE) - LV = torch.zeros((*params.shape, self.MAX_DAYS), dtype=DTYPE) + SLA = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) + LVAGE = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) + LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) SLA[..., 0] = params.SLATB(DVS) LV[..., 0] = WLV @@ -232,50 +231,6 @@ def _calc_LAI(self): total_LAI = self.states.LASUM + SAI + PAI return total_LAI - def _ensure_shape_finalized(self, drv: WeatherDataContainer) -> None: - """Ensure that the parameter and state variable shapes are finalized.""" - p = self.params - s = self.states - - if not self._shape_finalized: - TEMP = torch.as_tensor(drv.TEMP, dtype=DTYPE) - temp_shape = TEMP.shape if TEMP.ndim > 0 else () - if temp_shape != p.shape: - old_shape = p.shape - try: - p.shape = torch.broadcast_shapes(p.shape, temp_shape) - except RuntimeError as e: - raise ValueError( - f"Parameter shape {p.shape} and weather driver shape {temp_shape} " - f"have incompatible shape and cannot be broadcasted together. " - f"Original error: {e}" - ) from e - - # Reshape state variables to match new parameter shape - if len(old_shape) == 0: - s.LV = s.LV.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) - s.SLA = s.SLA.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) - s.LVAGE = s.LVAGE.unsqueeze(0).expand(*p.shape, self.MAX_DAYS) - else: - s.LV = s.LV.expand(*p.shape, self.MAX_DAYS) - s.SLA = s.SLA.expand(*p.shape, self.MAX_DAYS) - s.LVAGE = s.LVAGE.expand(*p.shape, self.MAX_DAYS) - - # For scalar state variables, use _broadcast_to - s.LAIEM = _broadcast_to(s.LAIEM, p.shape) - s.LASUM = _broadcast_to(s.LASUM, p.shape) - s.LAIEXP = _broadcast_to(s.LAIEXP, p.shape) - s.LAIMAX = _broadcast_to(s.LAIMAX, p.shape) - s.LAI = _broadcast_to(s.LAI, p.shape) - s.WLV = _broadcast_to(s.WLV, p.shape) - s.DWLV = _broadcast_to(s.DWLV, p.shape) - s.TWLV = _broadcast_to(s.TWLV, p.shape) - - self._shape_finalized = True - - # Finally check if drv shape is consistent - _check_drv_shape(drv, p.shape) - @prepare_rates def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: """Calculate the rates of change for the leaf dynamics. @@ -291,9 +246,6 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: p = self.params k = self.kiosk - # Update shape if needed based on drv.TEMP - self._ensure_shape_finalized(drv) - # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask # A mask (0 if DVS < 0, 1 if DVS >= 0) DVS = torch.as_tensor(k["DVS"], dtype=DTYPE) @@ -355,10 +307,17 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: r.DRLV = torch.maximum(r.DSLV, r.DALV) # Get the temperature from the drv - TEMP = _broadcast_to(torch.as_tensor(drv.TEMP, dtype=DTYPE), p.shape) + if isinstance(drv.TEMP, torch.Tensor): + if drv.TEMP.shape != self.params_shape: + raise ValueError( + f"Weather driver TEMP shape {drv.TEMP.shape} is not compatible with " + f"parameter shape {self.params_shape}." + ) + else: + TEMP = _broadcast_to(torch.as_tensor(drv.TEMP, dtype=DTYPE), self.params_shape) # physiologic ageing of leaves per time step - TBASE = _broadcast_to(p.TBASE, p.shape) + TBASE = _broadcast_to(p.TBASE, self.params_shape) FYSAGE = (TEMP - TBASE) / (35.0 - TBASE) r.FYSAGE = dvs_mask * torch.clamp(FYSAGE, 0.0) diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index 12eaf0b..aa94cbf 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -163,27 +163,41 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): repeated = crop_model_params_provider[param].repeat(10) crop_model_params_provider.set_override(param, repeated, check=False) - engine = EngineTestHelper( - crop_model_params_provider, - weather_data_provider, - agro_management_inputs, - config_path, - external_states, - ) - engine.run_till_terminate() - actual_results = engine.get_output() + if param == "TEMP": + # Expect error due to incompatible shapes + # (By defaults parameters are not reshaped following weather variables) + with pytest.raises(ValueError): + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + else: + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() - # get expected results from YAML test data - expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + # get expected results from YAML test data + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] - assert len(actual_results) == len(expected_results) + assert len(actual_results) == len(expected_results) - for reference, model in zip(expected_results, actual_results, strict=False): - assert reference["DAY"] == model["day"] - assert all( - all(abs(reference[var] - model[var]) < precision) - for var, precision in expected_precision.items() - ) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) @pytest.mark.parametrize( "param,delta", @@ -376,7 +390,7 @@ def test_leaf_dynamics_with_incompatible_weather_parameter_vectors(self): for (_, _), wdc in weather_data_provider.store.items(): wdc.TEMP = torch.ones(5, dtype=torch.float64) * wdc.TEMP - with pytest.raises(ValueError, match="incompatible shape"): + with pytest.raises(ValueError): EngineTestHelper( crop_model_params_provider, weather_data_provider, From 4479182f0d63450a8512e2fd913e5bdc53c6fb29 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 11:33:40 +0100 Subject: [PATCH 06/13] Add disclaimer on meteo_range_check --- src/diffwofost/physical_models/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index f6ed93c..4c7a976 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -203,6 +203,10 @@ class WeatherDataProviderTestHelper(WeatherDataProvider): def __init__(self, yaml_weather, meteo_range_checks=True): super().__init__() + # This is a temporary workaround. The `METEO_RANGE_CHECKS` logic in + # `__setattr__` method in `WeatherDataContainer` is not vector compatible + # yet. So we can disable it here when creating the `WeatherDataContainer` + # instances with arrays. settings.METEO_RANGE_CHECKS = meteo_range_checks for weather in yaml_weather: if "SNOWDEPTH" in weather: From b6a11627940d57f17261d0f06d6e9825898fa258 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 12:04:16 +0100 Subject: [PATCH 07/13] Reduce scope of _check_drv_shape --- .../physical_models/crop/leaf_dynamics.py | 10 +- src/diffwofost/physical_models/utils.py | 78 +++------ tests/physical_models/test_utils.py | 156 +++++------------- 3 files changed, 61 insertions(+), 183 deletions(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 9c318db..90924fd 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -14,6 +14,7 @@ from pcse.traitlets import Any from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_drv_var from diffwofost.physical_models.utils import _get_params_shape DTYPE = torch.float64 # Default data type for tensors in this module @@ -307,14 +308,7 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: r.DRLV = torch.maximum(r.DSLV, r.DALV) # Get the temperature from the drv - if isinstance(drv.TEMP, torch.Tensor): - if drv.TEMP.shape != self.params_shape: - raise ValueError( - f"Weather driver TEMP shape {drv.TEMP.shape} is not compatible with " - f"parameter shape {self.params_shape}." - ) - else: - TEMP = _broadcast_to(torch.as_tensor(drv.TEMP, dtype=DTYPE), self.params_shape) + TEMP = _get_drv_var(drv.TEMP, self.params_shape) # physiologic ageing of leaves per time step TBASE = _broadcast_to(p.TBASE, self.params_shape) diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index 4c7a976..e1334e7 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -203,10 +203,6 @@ class WeatherDataProviderTestHelper(WeatherDataProvider): def __init__(self, yaml_weather, meteo_range_checks=True): super().__init__() - # This is a temporary workaround. The `METEO_RANGE_CHECKS` logic in - # `__setattr__` method in `WeatherDataContainer` is not vector compatible - # yet. So we can disable it here when creating the `WeatherDataContainer` - # instances with arrays. settings.METEO_RANGE_CHECKS = meteo_range_checks for weather in yaml_weather: if "SNOWDEPTH" in weather: @@ -568,8 +564,8 @@ def _get_params_shape(params): return shape -def _check_drv_shape(drv, expected_shape): - """Check that the driving variables have the expected shape. +def _get_drv_var(drv_var, expected_shape): + """Check that the driving variables have the expected shape and fetch them. Driving variables can be scalars (0-dimensional) or match the expected shape. Scalars will be broadcast during operations. @@ -577,66 +573,34 @@ def _check_drv_shape(drv, expected_shape): [!] This function will be redundant once weathercontainer supports batched variables. Args: - drv: WeatherDataContainer with driving variables + drv_var: driving variable in WeatherDataContainer expected_shape: Expected shape tuple for non-scalar variables Raises: ValueError: If any variable has incompatible shape - """ - # Define compulsory and optional weather variables - compulsory_vars = [ - "LAT", - "LON", - "ELEV", - "DAY", - "IRRAD", - "TMIN", - "TMAX", - "VAP", - "RAIN", - "WIND", - "E0", - "ES0", - "ET0", - ] - optional_vars = ["TEMP", "SNOWDEPTH"] - all_vars = compulsory_vars + optional_vars - - for var_name in all_vars: - # Skip if variable doesn't exist (only matters for optional vars) - if not hasattr(drv, var_name): - if var_name in compulsory_vars: - raise ValueError( - f"Compulsory variable '{var_name}' missing from WeatherDataContainer" - ) - continue - - var_value = getattr(drv, var_name) - - # Skip DAY as it's a date object, not a tensor - if var_name == "DAY": - continue - # Convert to tensor if needed for shape checking - if not isinstance(var_value, torch.Tensor): - var_value = torch.as_tensor(var_value, dtype=DTYPE) - - # Check shape: must be scalar (0-d) or match expected_shape - if var_value.dim() == 0: - # Scalar is valid, will be broadcast - continue - elif var_value.shape == expected_shape: - # Matches expected shape - continue - else: - raise ValueError( - f"Weather variable '{var_name}' has incompatible shape {var_value.shape}. " - f"Expected scalar (0-dimensional) or shape {expected_shape}." - ) + Returns: + torch.Tensor: The validated variable, either as-is or broadcasted to expected shape. + """ + # Check shape: must be scalar (0-d) or match expected_shape + if not isinstance(drv_var, torch.Tensor) or drv_var.dim() == 0: + # Scalar is valid, will be broadcast + return _broadcast_to(drv_var, expected_shape) + elif drv_var.shape == expected_shape: + # Matches expected shape + return drv_var + else: + raise ValueError( + f"Requested weather variable has incompatible shape {drv_var.shape}. " + f"Expected scalar (0-dimensional) or shape {expected_shape}." + ) def _broadcast_to(x, shape): """Create a view of tensor X with the given shape.""" + # If x is not a tensor, convert it + if not isinstance(x, torch.Tensor): + x = torch.tensor(x, dtype=DTYPE) # If already the correct shape, return as-is if x.shape == shape: return x diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index 5937404..09f61b3 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -6,6 +6,7 @@ from diffwofost.physical_models.utils import Afgen from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import WeatherDataProviderTestHelper +from diffwofost.physical_models.utils import _get_drv_var from diffwofost.physical_models.utils import get_test_data from . import phy_data_folder @@ -542,129 +543,48 @@ def test_backward_compatibility_non_batched(self): assert torch.isclose(result, torch.tensor(5.0, dtype=DTYPE)) -class TestCheckDrvShape: - """Tests for _check_drv_shape function.""" +class TestGetDrvParam: + """Tests for _get_drv_var function.""" - def test_scalar_drv(self): - """Test that drv accepts scalar variables for any expected shape.""" - from diffwofost.physical_models.utils import _check_drv_shape - - expected_shape = (3, 2) # Example batch shape - - # Read weather data from test YAML - test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" - test_data = get_test_data(test_data_url) - weather_data_provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) - - # Get weather data for a specific date using the provider's __call__ method - first_date = weather_data_provider.first_date - wdc = weather_data_provider(first_date) - - # Check that the weather data container has valid shape - # (all scalar values should pass for any expected_shape) - _check_drv_shape(wdc, expected_shape) - - def test_single_vector_variable(self): - """Test that each weather variable can individually be vectorized to expected_shape.""" - from diffwofost.physical_models.utils import _check_drv_shape - - expected_shape = (3, 2) # Example batch shape - - # Read weather data from test YAML - test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" - test_data = get_test_data(test_data_url) - weather_data_provider = WeatherDataProviderTestHelper( - test_data["WeatherVariables"], meteo_range_checks=False - ) - - # List of all weather variables to test - vector_vars = [ - "LAT", - "LON", - "ELEV", - "IRRAD", - "TMIN", - "TMAX", - "VAP", - "RAIN", - "WIND", - "E0", - "ES0", - "ET0", - "TEMP", - ] - - first_date = weather_data_provider.first_date - - # Test each variable individually - for var_name in vector_vars: - wdc = weather_data_provider(first_date) - - # Vectorize only this one variable - original_value = getattr(wdc, var_name) - setattr(wdc, var_name, torch.ones(expected_shape, dtype=DTYPE) * original_value) - - # Should not raise an exception (scalar + one vectorized variable is valid) - _check_drv_shape(wdc, expected_shape) - - def test_all_vector_variables(self): - """Test that all weather variables can be vectorized together to expected_shape.""" - from diffwofost.physical_models.utils import _check_drv_shape - - expected_shape = (3, 2) # Example batch shape - - # Read weather data from test YAML + def test_float_broadcast(self): + expected_shape = (3, 2) test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - weather_data_provider = WeatherDataProviderTestHelper( - test_data["WeatherVariables"], meteo_range_checks=False - ) - - # Get weather data for a specific date - first_date = weather_data_provider.first_date - wdc = weather_data_provider(first_date) - - # Vectorize all weather variables to the same expected_shape - wdc.LAT = torch.ones(expected_shape, dtype=DTYPE) * wdc.LAT - wdc.LON = torch.ones(expected_shape, dtype=DTYPE) * wdc.LON - wdc.ELEV = torch.ones(expected_shape, dtype=DTYPE) * wdc.ELEV - wdc.IRRAD = torch.ones(expected_shape, dtype=DTYPE) * wdc.IRRAD - wdc.TMIN = torch.ones(expected_shape, dtype=DTYPE) * wdc.TMIN - wdc.TMAX = torch.ones(expected_shape, dtype=DTYPE) * wdc.TMAX - wdc.VAP = torch.ones(expected_shape, dtype=DTYPE) * wdc.VAP - wdc.RAIN = torch.ones(expected_shape, dtype=DTYPE) * wdc.RAIN - wdc.WIND = torch.ones(expected_shape, dtype=DTYPE) * wdc.WIND - wdc.E0 = torch.ones(expected_shape, dtype=DTYPE) * wdc.E0 - wdc.ES0 = torch.ones(expected_shape, dtype=DTYPE) * wdc.ES0 - wdc.ET0 = torch.ones(expected_shape, dtype=DTYPE) * wdc.ET0 - wdc.TEMP = torch.ones(expected_shape, dtype=DTYPE) * wdc.TEMP - - # Should not raise an exception (all same shape) - _check_drv_shape(wdc, expected_shape) - - def test_mixed_shapes_raises_error(self): - """Test that two variables with different non-scalar shapes raise ValueError.""" - from diffwofost.physical_models.utils import _check_drv_shape - + provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) + wdc = provider(provider.first_date) + scalar = wdc.TEMP + out = _get_drv_var(scalar, expected_shape) + assert out.shape == expected_shape + assert torch.allclose(out, torch.full(expected_shape, scalar, dtype=torch.float32)) + + def test_scalar_broadcast(self): expected_shape = (3, 2) - wrong_shape = (2, 3) - - # Read weather data from test YAML test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) - weather_data_provider = WeatherDataProviderTestHelper( - test_data["WeatherVariables"], meteo_range_checks=False - ) - - first_date = weather_data_provider.first_date - wdc = weather_data_provider(first_date) - - # Set IRRAD to expected_shape - wdc.IRRAD = torch.ones(expected_shape, dtype=DTYPE) * wdc.IRRAD - - # Set TMIN to a different wrong_shape - wdc.TMIN = torch.ones(wrong_shape, dtype=DTYPE) * wdc.TMIN + provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) + wdc = provider(provider.first_date) + scalar = torch.tensor(wdc.IRRAD, dtype=DTYPE) # 0-d tensor + out = _get_drv_var(scalar, expected_shape) + assert out.shape == expected_shape + assert torch.allclose(out, torch.full(expected_shape, scalar.item(), dtype=DTYPE)) + + def test_matching_shape_pass_through(self): + expected_shape = (3, 2) + base_val = torch.tensor(12.34, dtype=DTYPE) + var = torch.ones(expected_shape, dtype=DTYPE) * base_val + out = _get_drv_var(var, expected_shape) + assert out.shape == expected_shape + # Should be the same object (no copy) + assert out.data_ptr() == var.data_ptr() + + def test_wrong_shape_raises(self): + expected_shape = (3, 2) + wrong = torch.ones(2, 3, dtype=DTYPE) + with pytest.raises(ValueError, match="incompatible shape"): + _get_drv_var(wrong, expected_shape) - # Should raise ValueError because shapes don't match + def test_one_dim_shape_raises(self): + expected_shape = (3, 2) + one_dim = torch.ones(3, dtype=DTYPE) with pytest.raises(ValueError, match="incompatible shape"): - _check_drv_shape(wdc, expected_shape) + _get_drv_var(one_dim, expected_shape) From 7f1b76bc2f283b94c9a9068ce44aa24e78782c04 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 12:06:16 +0100 Subject: [PATCH 08/13] Lint --- .../physical_models/crop/leaf_dynamics.py | 4 ++-- src/diffwofost/physical_models/utils.py | 2 +- tests/physical_models/test_utils.py | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 90924fd..c93597e 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -14,7 +14,7 @@ from pcse.traitlets import Any from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to -from diffwofost.physical_models.utils import _get_drv_var +from diffwofost.physical_models.utils import _get_drv from diffwofost.physical_models.utils import _get_params_shape DTYPE = torch.float64 # Default data type for tensors in this module @@ -308,7 +308,7 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: r.DRLV = torch.maximum(r.DSLV, r.DALV) # Get the temperature from the drv - TEMP = _get_drv_var(drv.TEMP, self.params_shape) + TEMP = _get_drv(drv.TEMP, self.params_shape) # physiologic ageing of leaves per time step TBASE = _broadcast_to(p.TBASE, self.params_shape) diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index e1334e7..d19f1dd 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -564,7 +564,7 @@ def _get_params_shape(params): return shape -def _get_drv_var(drv_var, expected_shape): +def _get_drv(drv_var, expected_shape): """Check that the driving variables have the expected shape and fetch them. Driving variables can be scalars (0-dimensional) or match the expected shape. diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index 09f61b3..41915e8 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -6,7 +6,7 @@ from diffwofost.physical_models.utils import Afgen from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import WeatherDataProviderTestHelper -from diffwofost.physical_models.utils import _get_drv_var +from diffwofost.physical_models.utils import _get_drv from diffwofost.physical_models.utils import get_test_data from . import phy_data_folder @@ -544,7 +544,7 @@ def test_backward_compatibility_non_batched(self): class TestGetDrvParam: - """Tests for _get_drv_var function.""" + """Tests for _get_drv function.""" def test_float_broadcast(self): expected_shape = (3, 2) @@ -553,7 +553,7 @@ def test_float_broadcast(self): provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) wdc = provider(provider.first_date) scalar = wdc.TEMP - out = _get_drv_var(scalar, expected_shape) + out = _get_drv(scalar, expected_shape) assert out.shape == expected_shape assert torch.allclose(out, torch.full(expected_shape, scalar, dtype=torch.float32)) @@ -564,7 +564,7 @@ def test_scalar_broadcast(self): provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) wdc = provider(provider.first_date) scalar = torch.tensor(wdc.IRRAD, dtype=DTYPE) # 0-d tensor - out = _get_drv_var(scalar, expected_shape) + out = _get_drv(scalar, expected_shape) assert out.shape == expected_shape assert torch.allclose(out, torch.full(expected_shape, scalar.item(), dtype=DTYPE)) @@ -572,7 +572,7 @@ def test_matching_shape_pass_through(self): expected_shape = (3, 2) base_val = torch.tensor(12.34, dtype=DTYPE) var = torch.ones(expected_shape, dtype=DTYPE) * base_val - out = _get_drv_var(var, expected_shape) + out = _get_drv(var, expected_shape) assert out.shape == expected_shape # Should be the same object (no copy) assert out.data_ptr() == var.data_ptr() @@ -581,10 +581,10 @@ def test_wrong_shape_raises(self): expected_shape = (3, 2) wrong = torch.ones(2, 3, dtype=DTYPE) with pytest.raises(ValueError, match="incompatible shape"): - _get_drv_var(wrong, expected_shape) + _get_drv(wrong, expected_shape) def test_one_dim_shape_raises(self): expected_shape = (3, 2) one_dim = torch.ones(3, dtype=DTYPE) with pytest.raises(ValueError, match="incompatible shape"): - _get_drv_var(one_dim, expected_shape) + _get_drv(one_dim, expected_shape) From d76c6f83dd846c26ca699baa1dc7641b4e5f473f Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 12:08:26 +0100 Subject: [PATCH 09/13] Fix tests --- tests/physical_models/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index 41915e8..ec009de 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -555,7 +555,7 @@ def test_float_broadcast(self): scalar = wdc.TEMP out = _get_drv(scalar, expected_shape) assert out.shape == expected_shape - assert torch.allclose(out, torch.full(expected_shape, scalar, dtype=torch.float32)) + assert torch.allclose(out, torch.full(expected_shape, scalar, dtype=DTYPE)) def test_scalar_broadcast(self): expected_shape = (3, 2) From 5890a81271ea2928899d0563333b9379833970cc Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 14:16:37 +0100 Subject: [PATCH 10/13] Lint --- src/diffwofost/physical_models/crop/leaf_dynamics.py | 1 - src/diffwofost/physical_models/utils.py | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index c93597e..5fc6611 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -187,7 +187,6 @@ def initialize( params = self.params self.params_shape = _get_params_shape(params) - self._shape_finalized = False # Track if shape has been updated with drv data # Initial leaf biomass WLV = (params.TDWI * (1 - FR)) * FL diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index d19f1dd..3d5750c 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -203,6 +203,10 @@ class WeatherDataProviderTestHelper(WeatherDataProvider): def __init__(self, yaml_weather, meteo_range_checks=True): super().__init__() + # This is a temporary workaround. The `METEO_RANGE_CHECKS` logic in + # `__setattr__` method in `WeatherDataContainer` is not vector compatible + # yet. So we can disable it here when creating the `WeatherDataContainer` + # instances with arrays. settings.METEO_RANGE_CHECKS = meteo_range_checks for weather in yaml_weather: if "SNOWDEPTH" in weather: From ce77613281c29dda967e03d45865bf707cec8dfe Mon Sep 17 00:00:00 2001 From: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:55:55 +0100 Subject: [PATCH 11/13] Apply suggestion from @SarahAlidoost --- src/diffwofost/physical_models/crop/leaf_dynamics.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 5fc6611..731d939 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -119,7 +119,6 @@ class WOFOST_Leaf_Dynamics(SimulationObject): # on the leaf classes during the time integration: leaf area, age, and biomass. START_DATE = None # Start date of the simulation MAX_DAYS = 365 # Maximum number of days that can be simulated in one run (i.e. array lenghts) - params_shape = None # Shape of the parameters and the drv data class Parameters(ParamTemplate): RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) From 89ad3c2769e78d6acde6a1a53830be480587011a Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 15:07:03 +0100 Subject: [PATCH 12/13] Test TEMP 2dTensor --- tests/physical_models/crop/test_leaf_dynamics.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index aa94cbf..699979a 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -302,7 +302,7 @@ def test_leaf_dynamics_with_multiple_parameter_arrays(self): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params) + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") # Setting an array with arbitrary shape (and one value) @@ -314,6 +314,9 @@ def test_leaf_dynamics_with_multiple_parameter_arrays(self): repeated = crop_model_params_provider[param].broadcast_to((30, 5)) crop_model_params_provider.set_override(param, repeated, check=False) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + engine = EngineTestHelper( crop_model_params_provider, weather_data_provider, From b80a8ac17ed04ad3a87ded63f24fe352218fefa8 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 20 Nov 2025 15:09:51 +0100 Subject: [PATCH 13/13] Declare shape --- src/diffwofost/physical_models/crop/leaf_dynamics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 731d939..5ec36ba 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -119,6 +119,7 @@ class WOFOST_Leaf_Dynamics(SimulationObject): # on the leaf classes during the time integration: leaf area, age, and biomass. START_DATE = None # Start date of the simulation MAX_DAYS = 365 # Maximum number of days that can be simulated in one run (i.e. array lenghts) + params_shape = None # Shape of the parameters tensors class Parameters(ParamTemplate): RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])