diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index 7ae9893..25fce00 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -197,3 +197,55 @@ def test_gradients_tdwi_twrt_root_dynamics_numerical(self): # in this test, grads is very small assert_array_almost_equal(numerical_grad, grads.item(), decimal=3) + + +class TestDiffRootDynamicsRDI: + def test_gradients_rdi_twrt_root_dynamics(self): + model = get_test_diff_root_model() + rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32)) + output = model({"RDI": rdi}) + twrt = output["TWRT"] + loss = twrt.sum() + + assert loss.grad_fn is None + + def test_gradients_rdi_rd_root_dynamics(self): + # prepare model input + model = get_test_diff_root_model() + rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32)) + output = model({"RDI": rdi}) + rd = output["RD"] + loss = rd.sum() + + # this is ∂loss/∂rdi + # this is called forward gradient here because it is calculated without backpropagation. + grads = torch.autograd.grad(loss, rdi, retain_graph=True)[0] + + assert grads is not None, "Gradients for RDI should not be None" + + rdi.grad = None # clear any existing gradient + loss.backward() + + # this is ∂loss/∂rdi calculated using backpropagation + grad_backward = rdi.grad + + assert grad_backward is not None, "Backward gradients for RDI should not be None" + assert grad_backward == grads, "Forward and backward gradients for RDI should match" + + def test_gradients_rdi_rd_root_dynamics_numerical(self): + rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64)) + output_name = "RD" + numerical_grad = calculate_numerical_grad( + get_test_diff_root_model, "RDI", rdi.data, output_name + ) + + model = get_test_diff_root_model() + output = model({"RDI": rdi}) + rd = output[output_name] + loss = rd.sum() + + # this is ∂loss/∂rdi, for comparison with numerical gradient + grads = torch.autograd.grad(loss, rdi, retain_graph=True)[0] + + # in this test, grads is very small + assert_array_almost_equal(numerical_grad, grads.item(), decimal=3)