diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index 0770237..89dbb9e 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -43,6 +43,15 @@ def get_test_data(file_path): return inputs["ModelResults"], inputs["Precision"] +def get_test_diff_leaf_model(): + test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml" + params, wdp, agro, external_states = prepare_engine_input(test_data_path) + config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") + return DiffLeafDynamics(params, wdp, agro, config_path, external_states) + + + + class DiffLeafDynamics(torch.nn.Module): def __init__(self, params, wdp, agro, config_path, external_states): super().__init__() @@ -132,13 +141,7 @@ def test_wofost_pp_with_leaf_dynamics(self): class TestDiffLeafDynamicsTDWI: def test_gradients_tdwi_lai_leaf_dynamics(self): - # prepare model input - test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml" - params, wdp, agro, external_states = prepare_engine_input(test_data_path) - config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") - - # create a model and optimizer - model = DiffLeafDynamics(params, wdp, agro, config_path, external_states) + model = get_test_diff_leaf_model() tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32)) output = model({"TDWI": tdwi}) lai = output[0, :, 0] @@ -159,12 +162,7 @@ def test_gradients_tdwi_lai_leaf_dynamics(self): def test_gradients_tdwi_twlv_leaf_dynamics(self): # prepare model input - test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml" - params, wdp, agro, external_states = prepare_engine_input(test_data_path) - config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") - - # create a model and optimizer - model = DiffLeafDynamics(params, wdp, agro, config_path, external_states) + model = get_test_diff_leaf_model() tdwi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32)) output = model({"TDWI": tdwi}) twlv = output[0, :, 1] @@ -187,12 +185,7 @@ def test_gradients_tdwi_twlv_leaf_dynamics(self): class TestDiffLeafDynamicsSPAN: def test_gradients_span_lai_leaf_dynamics(self): # prepare model input - test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml" - params, wdp, agro, external_states = prepare_engine_input(test_data_path) - config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") - - # create a model and optimizer - model = DiffLeafDynamics(params, wdp, agro, config_path, external_states) + model = get_test_diff_leaf_model() span = torch.nn.Parameter(torch.tensor(30, dtype=torch.float32)) output = model({"SPAN": span}) lai = output[0, :, 0] @@ -213,12 +206,7 @@ def test_gradients_span_lai_leaf_dynamics(self): def test_gradients_span_twlv_leaf_dynamics(self): # prepare model input - test_data_path = phy_data_folder / "test_leafdynamics_wofost72_01.yaml" - params, wdp, agro, external_states = prepare_engine_input(test_data_path) - config_path = str(phy_data_folder / "WOFOST_Leaf_Dynamics.conf") - - # create a model and optimizer - model = DiffLeafDynamics(params, wdp, agro, config_path, external_states) + model = get_test_diff_leaf_model() span = torch.nn.Parameter(torch.tensor(30, dtype=torch.float32)) output = model({"SPAN": span}) twlv = output[0, :, 1]