52 changes: 52 additions & 0 deletions tests/physical_models/crop/test_root_dynamics.py
@@ -197,3 +197,55 @@ def test_gradients_tdwi_twrt_root_dynamics_numerical(self):

        # in this test, grads is very small
        assert_array_almost_equal(numerical_grad, grads.item(), decimal=3)


class TestDiffRootDynamicsRDI:
    def test_gradients_rdi_twrt_root_dynamics(self):
        model = get_test_diff_root_model()
        rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32))
        output = model({"RDI": rdi})
        twrt = output["TWRT"]
        loss = twrt.sum()

        # TWRT is expected to be detached from RDI, so the summed loss
        # should carry no autograd graph
        assert loss.grad_fn is None

    def test_gradients_rdi_rd_root_dynamics(self):
        # prepare model input
        model = get_test_diff_root_model()
        rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float32))
        output = model({"RDI": rdi})
        rd = output["RD"]
        loss = rd.sum()

        # this is ∂loss/∂rdi obtained via torch.autograd.grad; it is referred
        # to as the "forward" gradient in this test because it is returned
        # directly rather than accumulated into rdi.grad by loss.backward()
        grads = torch.autograd.grad(loss, rdi, retain_graph=True)[0]

        assert grads is not None, "Gradients for RDI should not be None"

        rdi.grad = None  # clear any existing gradient
        loss.backward()

        # this is ∂loss/∂rdi accumulated into rdi.grad by backpropagation
        grad_backward = rdi.grad

        assert grad_backward is not None, "Backward gradients for RDI should not be None"
        assert grad_backward == grads, "Forward and backward gradients for RDI should match"

    def test_gradients_rdi_rd_root_dynamics_numerical(self):
        rdi = torch.nn.Parameter(torch.tensor(0.2, dtype=torch.float64))
        output_name = "RD"
        numerical_grad = calculate_numerical_grad(
            get_test_diff_root_model, "RDI", rdi.data, output_name
        )

        model = get_test_diff_root_model()
        output = model({"RDI": rdi})
        rd = output[output_name]
        loss = rd.sum()

        # this is ∂loss/∂rdi, for comparison with the numerical gradient
        grads = torch.autograd.grad(loss, rdi, retain_graph=True)[0]

        # in this test, grads is very small
        assert_array_almost_equal(numerical_grad, grads.item(), decimal=3)
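
The numerical test above compares the autograd gradient against a finite-difference estimate produced by `calculate_numerical_grad`, which is defined elsewhere in the test suite and not shown in this diff. As a rough, hypothetical sketch of what such a helper could look like (the name `calculate_numerical_grad_sketch`, the `eps` value, and the assumption that the model factory and parameter dict match the usage in the tests are all illustrative, not the repository's actual implementation), a central-difference estimate of ∂sum(output)/∂param might be written as:

```python
import torch

def calculate_numerical_grad_sketch(model_factory, param_name, value, output_name, eps=1e-6):
    """Central-difference estimate of d(sum(output)) / d(param).

    Hypothetical sketch; the real calculate_numerical_grad used by these
    tests may differ in signature and details.
    """
    def loss_at(v):
        model = model_factory()
        # evaluate the model without building an autograd graph
        with torch.no_grad():
            output = model({param_name: torch.tensor(v, dtype=torch.float64)})
        return output[output_name].sum().item()

    v = float(value)
    return (loss_at(v + eps) - loss_at(v - eps)) / (2.0 * eps)
```

The `assert_array_almost_equal(..., decimal=3)` check then ties this estimate to the autograd result; the step size `eps` would need to suit the scale of the model's outputs for that tolerance to be meaningful.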