From e756e79c41e4abe250d8826d6f5b6f179f0dd7df Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 30 Oct 2025 12:56:46 +0100 Subject: [PATCH 1/3] Test RDI differentiability --- .../crop/test_root_dynamics.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index 7ae9893..b3c4cf9 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -197,3 +197,57 @@ def test_gradients_tdwi_twrt_root_dynamics_numerical(self): # in this test, grads is very small assert_array_almost_equal(numerical_grad, grads.item(), decimal=3) + + +class TestDiffRootDynamicsRDI: + def test_gradients_rdi_twrt_root_dynamics(self): + model = get_test_diff_root_model() + rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32)) + output = model({"RDI": rdi}) + twrt = output["TWRT"] + loss = twrt.sum() + + assert loss.grad_fn is None + + def test_gradients_rdi_rd_root_dynamics(self): + # prepare model input + model = get_test_diff_root_model() + rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32)) + output = model({"RDI": rdi}) + #twrt = output["TWRT"] + #loss = twrt.sum() + rd = output["RD"] + loss = rd.sum() + + # this is ∂loss/∂rdi + # this is called forward gradient here because it is calculated without backpropagation. + grads = torch.autograd.grad(loss, rdi, retain_graph=True)[0] + + assert grads is not None, "Gradients for RDI should not be None" + + rdi.grad = None # clear any existing gradient + loss.backward() + + # this is ∂loss/∂rdi calculated using backpropagation + grad_backward = rdi.grad + + assert grad_backward is not None, "Backward gradients for RDI should not be None" + assert grad_backward == grads, "Forward and backward gradients for RDI should match" + + def test_gradients_rdi_rd_root_dynamics_numerical(self): + rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64)) + output_name = "RD" + numerical_grad = calculate_numerical_grad( + get_test_diff_root_model, "RDI", rdi.data, output_name + ) + + model = get_test_diff_root_model() + output = model({"RDI": rdi}) + rd = output[output_name] + loss = rd.sum() + + # this is ∂loss/∂rdi, for comparison with numerical gradient + grads = torch.autograd.grad(loss, rdi, retain_graph=True)[0] + + # in this test, grads is very small + assert_array_almost_equal(numerical_grad, grads.item(), decimal=3) From 64f29f099088ef0e268d55022f3ffc72dd32e703 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 30 Oct 2025 13:21:46 +0100 Subject: [PATCH 2/3] Ruff lint --- tests/physical_models/crop/test_root_dynamics.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index b3c4cf9..0e31367 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -214,8 +214,6 @@ def test_gradients_rdi_rd_root_dynamics(self): model = get_test_diff_root_model() rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32)) output = model({"RDI": rdi}) - #twrt = output["TWRT"] - #loss = twrt.sum() rd = output["RD"] loss = rd.sum() From a02dda9b3cf35ec45d09096a88c40c7e043ec551 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 30 Oct 2025 13:29:03 +0100 Subject: [PATCH 3/3] Ruff lint --- tests/physical_models/crop/test_root_dynamics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index 0e31367..25fce00 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -234,7 +234,7 @@ def test_gradients_rdi_rd_root_dynamics(self): def test_gradients_rdi_rd_root_dynamics_numerical(self): rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64)) - output_name = "RD" + output_name = "RD" numerical_grad = calculate_numerical_grad( get_test_diff_root_model, "RDI", rdi.data, output_name )