Merged
Changes from all commits
33 commits
995c73e
vectorize integration for single parameters
fnattino Oct 1, 2025
6128341
apply ruff formatting
fnattino Oct 1, 2025
3ea31bf
fix typo
fnattino Oct 14, 2025
c1d771d
adapt to any shape of parameters
fnattino Oct 14, 2025
69320a2
add ruff formatting
fnattino Oct 14, 2025
4241d4d
fix in order to deal with arbitrary shapes together with 0-d params
fnattino Oct 14, 2025
ecfe10c
remove BMI-related function
fnattino Oct 15, 2025
4768a26
fix bug in dimensions handling
fnattino Oct 17, 2025
a9b9baa
add tests
fnattino Oct 17, 2025
5d28d3f
Merge branch 'main' into 19-vectorize-leaf-dynamics-fn
fnattino Oct 17, 2025
af9dd7c
Update tests/physical_models/crop/test_leaf_dynamics.py
fnattino Oct 22, 2025
fd755d7
add comments to the class attributes
fnattino Oct 22, 2025
bc275db
add comments on trackability of gradients
fnattino Oct 22, 2025
8b9b941
apply ruff format
fnattino Oct 22, 2025
4c78770
bug fix on oldest leaf class identification
fnattino Oct 23, 2025
c8c2609
fix initialization of variables
fnattino Oct 23, 2025
919bb7a
refactor function
fnattino Oct 23, 2025
af72abd
add parametrization to tests
fnattino Oct 23, 2025
2752563
fix expected error
fnattino Oct 23, 2025
8679bc3
adapt function to calculate gradients numerically
fnattino Oct 23, 2025
eb4e084
adapt root dynamics test to fit changes in leaf dynamics
fnattino Oct 23, 2025
7148a5d
fix comments to the more generic parameter
fnattino Oct 23, 2025
d7b311b
apply ruff format
fnattino Oct 23, 2025
c952e0f
unpack some of the parametrized tests
fnattino Oct 24, 2025
74f0b17
re-add check on gradient being zero for TWLV w.r.t. SPAN
fnattino Oct 24, 2025
85a5b55
fix tests on gradient of RD w.r.t. TDWI
fnattino Oct 24, 2025
8aaa76f
lint fix of import statements
fnattino Oct 24, 2025
6abf44e
add tolerance to test
fnattino Oct 24, 2025
053f62f
linting fix
fnattino Oct 24, 2025
c03fb06
replace comparison to use autogradient
fnattino Oct 24, 2025
234a17c
linting fix
fnattino Oct 24, 2025
de555c2
remove test on numerical gradient of TWRT w.r.t. RD (is zero)
fnattino Oct 27, 2025
0bf92eb
remove unused import statement
fnattino Oct 27, 2025
218 changes: 106 additions & 112 deletions src/diffwofost/physical_models/crop/leaf_dynamics.py
@@ -8,8 +8,8 @@
from pcse.decorators import prepare_rates
from pcse.decorators import prepare_states
from pcse.traitlets import Any
from pcse.util import Afgen
from pcse.util import AfgenTrait
from pcse.util import limit

DTYPE = torch.float64 # Default data type for tensors in this module

@@ -113,14 +113,19 @@ class WOFOST_Leaf_Dynamics(SimulationObject):
LAI, TWLV
"""

# The following parameters are used to initialize and control the arrays that store information
# on the leaf classes during the time integration: leaf area, age, and biomass.
START_DATE = None # Start date of the simulation
MAX_DAYS = 300 # Maximum number of days that can be simulated in one run (i.e. array lengths)

class Parameters(ParamTemplate):
RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
SPAN = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
TBASE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
PERDL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
SLATB = AfgenTrait() # FIXEME
KDIFTB = AfgenTrait() # FIXEME
SLATB = AfgenTrait() # FIXME
KDIFTB = AfgenTrait() # FIXME

class StateVariables(StatesTemplate):
LV = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
Expand Down Expand Up @@ -156,6 +161,7 @@ def initialize(self, day, kiosk, parvalues):
:param parvalues: `ParameterProvider` object providing parameters as
key/value pairs
"""
self.START_DATE = day
self.kiosk = kiosk
# TODO check if parvalues are already torch.nn.Parameters
self.params = self.Parameters(parvalues)
@@ -171,19 +177,22 @@ def initialize(self, day, kiosk, parvalues):
DVS = self.kiosk["DVS"]

params = self.params
shape = _get_params_shape(params)

# Initial leaf biomass
WLV = (params.TDWI * (1 - FR)) * FL
DWLV = torch.tensor(0.0, dtype=DTYPE)
DWLV = torch.zeros(shape, dtype=DTYPE)
TWLV = WLV + DWLV

# First leaf class (SLA, age and weight)
SLA = torch.tensor([params.SLATB(DVS)], dtype=DTYPE)
LVAGE = torch.tensor([0.0], dtype=DTYPE)
LV = torch.stack([WLV])
# Initialize leaf classes (SLA, age and weight)
SLA = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
LVAGE = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
LV = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
SLA[..., 0] = params.SLATB(DVS)
LV[..., 0] = WLV

# Initial values for leaf area
LAIEM = LV[0] * SLA[0]
LAIEM = LV[..., 0] * SLA[..., 0]
LASUM = LAIEM
LAIEXP = LAIEM
LAIMAX = LAIEM
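The new `initialize` pre-allocates fixed-size `(*shape, MAX_DAYS)` buffers and writes the first leaf class into day slot 0. A minimal sketch of that pattern, with a hypothetical batch shape and made-up values:

```python
import torch

DTYPE = torch.float64
MAX_DAYS = 300  # fixed buffer length, mirroring the class attribute above

shape = (4,)  # hypothetical parameter batch shape
WLV = torch.full(shape, 0.12, dtype=DTYPE)     # hypothetical initial leaf biomass
SLA0 = torch.full(shape, 0.0015, dtype=DTYPE)  # hypothetical specific leaf area

# One slot per simulated day; slot 0 holds the first (oldest) leaf class.
LV = torch.zeros((*shape, MAX_DAYS), dtype=DTYPE)
SLA = torch.zeros((*shape, MAX_DAYS), dtype=DTYPE)
LV[..., 0] = WLV
SLA[..., 0] = SLA0

LAIEM = LV[..., 0] * SLA[..., 0]  # leaf area at emergence, one value per batch element
print(LAIEM.shape)  # torch.Size([4])
```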
@@ -236,57 +245,54 @@ def calc_rates(self, day, drv):
# death due to self-shading caused by high LAI
DVS = self.kiosk["DVS"]
LAICR = 3.2 / p.KDIFTB(DVS)
r.DSLV2 = mask * s.WLV * limit(0.0, 0.03, 0.03 * (s.LAI - LAICR) / LAICR)
r.DSLV2 = mask * s.WLV * torch.clamp(0.03 * (s.LAI - LAICR) / LAICR, 0.0, 0.03)

# Death of leaves due to frost damage as determined by
# Reduction Factor Frost "RF_FROST"
if "RF_FROST" in self.kiosk:
r.DSLV3 = mask * s.WLV * k.RF_FROST
else:
r.DSLV3 = torch.tensor(0.0, dtype=DTYPE)
r.DSLV3 = torch.zeros_like(s.WLV, dtype=DTYPE)

# leaf death equals maximum of water stress, shading and frost
r.DSLV = torch.max(torch.stack([r.DSLV1, r.DSLV2, r.DSLV3]))
r.DSLV = torch.maximum(torch.maximum(r.DSLV1, r.DSLV2), r.DSLV3)
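This swap matters once the rates carry a batch dimension: `torch.max(torch.stack([...]))` reduces over all elements to a single scalar, while `torch.maximum` takes the element-wise maximum and preserves the batch shape. A small illustration with made-up rates:

```python
import torch

dslv1 = torch.tensor([0.1, 0.4])
dslv2 = torch.tensor([0.3, 0.2])
dslv3 = torch.tensor([0.2, 0.5])

print(torch.max(torch.stack([dslv1, dslv2, dslv3])))      # tensor(0.5000) -- one scalar
print(torch.maximum(torch.maximum(dslv1, dslv2), dslv3))  # tensor([0.3000, 0.5000])
```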

# Determine how much leaf biomass classes have to die in states.LV,
# given a life span > SPAN, these classes will be accumulated
# in DALV.
# Note that the actual leaf death is imposed on the array LV during the
# state integration step.
DALV = torch.tensor(0.0, dtype=DTYPE)
if p.SPAN.requires_grad: # replacing hard threshold `if lvage > p.SPAN``
sharpness = 1000.0 # FIXEME
for lv, lvage in zip(s.LV, s.LVAGE, strict=False):
weight = torch.sigmoid((lvage - p.SPAN) * sharpness)
DALV = DALV + weight * lv
else:
for lv, lvage in zip(s.LV, s.LVAGE, strict=False):
if lvage > p.SPAN:
DALV = DALV + lv

r.DALV = DALV
tSPAN = _broadcast_to(p.SPAN, s.LVAGE.shape) # Broadcast to same shape
# Using a sigmoid here instead of a conditional statement on the value of
# SPAN because the latter would not allow for the gradient to be tracked.
sharpness = torch.tensor(1000.0, dtype=DTYPE) # FIXME
weight = torch.sigmoid((s.LVAGE - tSPAN) * sharpness)
r.DALV = torch.sum(weight * s.LV, dim=-1)
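A minimal sketch of why the sigmoid is needed: a hard `lvage > SPAN` comparison is a step function, so autograd records no dependence on SPAN at all, while the soft threshold keeps it differentiable (the ages and biomass below are made up):

```python
import torch

span = torch.tensor(25.0, dtype=torch.float64, requires_grad=True)
lvage = torch.tensor([10.0, 24.999, 25.001, 40.0], dtype=torch.float64)
lv = torch.ones_like(lvage)  # hypothetical biomass per leaf class

# Hard threshold: the boolean comparison detaches SPAN from the graph.
dalv_hard = (lv * (lvage > span)).sum()
print(dalv_hard.requires_grad)  # False

# Soft threshold: sigmoid((age - span) * sharpness) is differentiable in span.
sharpness = 1000.0
dalv_soft = (lv * torch.sigmoid((lvage - span) * sharpness)).sum()
dalv_soft.backward()
print(span.grad)  # non-zero, dominated by classes near the SPAN transition
```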

# Total death rate leaves
r.DRLV = torch.max(r.DSLV, r.DALV)
r.DRLV = torch.maximum(r.DSLV, r.DALV)

# physiologic ageing of leaves per time step
FYSAGE = (drv.TEMP - p.TBASE) / (35.0 - p.TBASE)
r.FYSAGE = mask * torch.max(torch.tensor(0.0, dtype=DTYPE), FYSAGE)
r.FYSAGE = mask * torch.clamp(FYSAGE, 0.0)

# specific leaf area of leaves per time step
r.SLAT = mask * torch.tensor(p.SLATB(DVS), dtype=DTYPE)

# leaf area not to exceed exponential growth curve
if s.LAIEXP < 6.0:
DTEFF = torch.max(torch.tensor(0.0, dtype=DTYPE), drv.TEMP - p.TBASE)
r.GLAIEX = s.LAIEXP * p.RGRLAI * DTEFF
# source-limited increase in leaf area
r.GLASOL = r.GRLV * r.SLAT
# sink-limited increase in leaf area
GLA = torch.min(r.GLAIEX, r.GLASOL)
# adjustment of specific leaf area of youngest leaf class
if r.GRLV > 0.0:
r.SLAT = GLA / r.GRLV
is_lai_exp = s.LAIEXP < 6.0
DTEFF = torch.clamp(drv.TEMP - p.TBASE, 0.0)
# NOTE: conditional statements do not allow for the gradient to be
# tracked through the condition. Thus, the gradient with respect to
# parameters that contribute to `is_lai_exp` (e.g. RGRLAI and TBASE)
# are expected to be incorrect.
r.GLAIEX = torch.where(is_lai_exp, s.LAIEXP * p.RGRLAI * DTEFF, r.GLAIEX)
# source-limited increase in leaf area
r.GLASOL = torch.where(is_lai_exp, r.GRLV * r.SLAT, r.GLASOL)
# sink-limited increase in leaf area
GLA = torch.minimum(r.GLAIEX, r.GLASOL)
# adjustment of specific leaf area of youngest leaf class
r.SLAT = torch.where(is_lai_exp & (r.GRLV > 0.0), GLA / r.GRLV, r.SLAT)
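A small sketch of the caveat in the NOTE above: `torch.where` propagates gradients through the selected branch values, but the boolean condition itself is a constant to autograd (all numbers here are made up):

```python
import torch

tbase = torch.tensor(2.0, dtype=torch.float64, requires_grad=True)
rgrlai = torch.tensor(0.01, dtype=torch.float64, requires_grad=True)
temp = torch.tensor(10.0, dtype=torch.float64)
laiexp = torch.tensor(3.0, dtype=torch.float64)

is_lai_exp = laiexp < 6.0  # detached boolean: autograd treats it as a constant
dteff = torch.clamp(temp - tbase, 0.0)
glaiex = torch.where(is_lai_exp, laiexp * rgrlai * dteff, torch.zeros_like(dteff))

glaiex.backward()
print(tbase.grad, rgrlai.grad)  # gradients flow through the branch values...
# ...but any effect of TBASE/RGRLAI on *which* branch is taken is dropped,
# hence the expected inaccuracy noted above.
```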

@prepare_states
def integrate(self, day, delt=1.0):
@@ -299,44 +305,45 @@ def integrate(self, day, delt=1.0):
tLV = states.LV.clone()
tSLA = states.SLA.clone()
tLVAGE = states.LVAGE.clone()
tDRLV = rates.DRLV

# leaf death is imposed on leaves by removing leave classes from the
# right side.
for LVweigth in reversed(states.LV):
if tDRLV > 0.0:
if tDRLV >= LVweigth: # remove complete leaf class
tDRLV = tDRLV - LVweigth
tLV = tLV[:-1] # Remove last element
tLVAGE = tLVAGE[:-1]
tSLA = tSLA[:-1]
else: # Decrease value of oldest (rightmost) leave class
tLV[-1] = tLV[-1] - tDRLV
tDRLV = torch.tensor(0.0, dtype=DTYPE)
else:
break

tDRLV = _broadcast_to(rates.DRLV, tLV.shape)

# Leaf death is imposed on leaves from the oldest ones.
# Calculate the cumulative sum of weights after leaf death, and
# find out which leaf classes are dead (negative weights)
weight_cumsum = tLV.cumsum(dim=-1) - tDRLV
is_alive = weight_cumsum >= 0
# Adjust value of oldest leaf class, i.e. the first non-zero
# weight along the time axis (the last dimension).
# Cast argument to int because torch.argmax requires it to be numeric
idx_oldest = torch.argmax(is_alive.type(torch.int), dim=-1, keepdim=True)
new_biomass = torch.take_along_dim(weight_cumsum, indices=idx_oldest, dim=-1)
tLV = torch.scatter(tLV, dim=-1, index=idx_oldest, src=new_biomass)
# Zero out all dead leaf classes
# NOTE: conditional statements do not allow for the gradient to be
# tracked through the condition. Thus, the gradient with respect to
# parameters that contribute to `is_alive` are expected to be incorrect.
tLV = torch.where(is_alive, tLV, 0.0)
# Integration of physiological age
tLVAGE = torch.tensor([age + rates.FYSAGE for age in tLVAGE], dtype=DTYPE)
tLVAGE = tLVAGE + rates.FYSAGE
tLVAGE = torch.where(is_alive, tLVAGE, 0.0)
tSLA = torch.where(is_alive, tSLA, 0.0)
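A worked toy example of the cumulative-sum trick above, on made-up numbers (index 0 is the oldest class): removing `DRLV = 1.2` should consume the whole first class (0.5) and part of the second.

```python
import torch

lv = torch.tensor([0.5, 1.0, 2.0], dtype=torch.float64)  # oldest ... youngest
drlv = torch.tensor(1.2, dtype=torch.float64)

weight_cumsum = lv.cumsum(dim=-1) - drlv  # [-0.7, 0.3, 2.3]
is_alive = weight_cumsum >= 0             # [False, True, True]
idx_oldest = torch.argmax(is_alive.int(), dim=-1, keepdim=True)  # first surviving class
new_biomass = torch.take_along_dim(weight_cumsum, indices=idx_oldest, dim=-1)
lv = torch.scatter(lv, dim=-1, index=idx_oldest, src=new_biomass)
lv = torch.where(is_alive, lv, 0.0)
print(lv)  # tensor([0.0000, 0.3000, 2.0000], dtype=torch.float64)
```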

# --------- leaf growth ---------
# new leaves in class 1
tLV = torch.cat((torch.tensor([rates.GRLV], dtype=DTYPE), tLV))
tSLA = torch.cat((torch.tensor([rates.SLAT], dtype=DTYPE), tSLA))
tLVAGE = torch.cat((torch.tensor([0.0], dtype=DTYPE), tLVAGE))
idx = int((day - self.START_DATE).days / delt)
tLV[..., idx] = rates.GRLV
tSLA[..., idx] = rates.SLAT
tLVAGE[..., idx] = 0.0
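The day slot that receives today's new leaves is derived from the elapsed time since `START_DATE`; a quick sketch of that indexing with a hypothetical start date:

```python
import datetime

start_date = datetime.date(2025, 1, 1)  # hypothetical simulation start
day = datetime.date(2025, 1, 11)
delt = 1.0
idx = int((day - start_date).days / delt)
print(idx)  # 10 -> tLV[..., 10] receives today's growth rate GRLV
```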

# calculation of new leaf area
states.LASUM = torch.sum(
torch.stack([lv * sla for lv, sla in zip(tLV, tSLA, strict=False)])
)
states.LASUM = torch.sum(tLV * tSLA, dim=-1)
states.LAI = self._calc_LAI()
states.LAIMAX = torch.max(states.LAI, states.LAIMAX)
states.LAIMAX = torch.maximum(states.LAI, states.LAIMAX)

# exponential growth curve
states.LAIEXP = states.LAIEXP + rates.GLAIEX

# Update leaf biomass states
states.WLV = torch.sum(tLV)
states.WLV = torch.sum(tLV, dim=-1)
states.DWLV = states.DWLV + rates.DRLV
states.TWLV = states.WLV + states.DWLV

@@ -345,55 +352,6 @@ def integrate(self, day, delt=1.0):
self.states.SLA = tSLA
self.states.LVAGE = tLVAGE

@prepare_states
def _set_variable_LAI(self, nLAI): # FIXEME
"""Updates the value of LAI to to the new value provided as input.

Related state variables will be updated as well and the increments
to all adjusted state variables will be returned as a dict.
"""
states = self.states

# Store old values of states
oWLV = states.WLV
oLAI = states.LAI
oTWLV = states.TWLV
oLASUM = states.LASUM

# Reduce oLAI for pod and stem area. SAI and PAI will not be adjusted
# because this is often only a small component of the total leaf
# area. For all current crop files in WOFOST SPA and SSA are zero
# anyway
SAI = self.kiosk["SAI"]
PAI = self.kiosk["PAI"]
adj_nLAI = torch.max(nLAI - SAI - PAI, 0.0)
adj_oLAI = torch.max(oLAI - SAI - PAI, 0.0)

# LAI Adjustment factor for leaf biomass LV (rLAI)
if adj_oLAI > 0:
rLAI = adj_nLAI / adj_oLAI
LV = [lv * rLAI for lv in states.LV]
# If adj_oLAI == 0 then add the leave biomass directly to the
# youngest leave age class (LV[0])
else:
LV = [nLAI / states.SLA[0]]

states.LASUM = torch.sum(
torch.tensor([lv * sla for lv, sla in zip(LV, states.SLA, strict=False)], dtype=DTYPE)
)
states.LV = LV
states.LAI = self._calc_LAI()
states.WLV = torch.sum(states.LV)
states.TWLV = states.WLV + states.DWLV

increments = {
"LAI": states.LAI - oLAI,
"LAISUM": states.LASUM - oLASUM,
"WLV": states.WLV - oWLV,
"TWLV": states.TWLV - oTWLV,
}
return increments


def _exist_required_external_variables(kiosk):
"""Check if all required external variables are available in the kiosk."""
@@ -405,3 +363,39 @@
f" Ensure that all required variables {required_external_vars_at_init}"
" are provided."
)


def _get_params_shape(params):
"""Get the parameters shape.

Parameters can have an arbitrary number of dimensions, but all parameters that are not zero-
dimensional should have the same shape.
"""
shape = ()
for parname in params.trait_names():
# Skip special traitlets attributes
if parname.startswith("trait"):
continue
param = getattr(params, parname)
# Skip Afgen parameters:
if isinstance(param, Afgen):
continue
# Parameters that are not zero dimensional should all have the same shape
if param.shape and not shape:
shape = param.shape
elif param.shape:
assert param.shape == shape, (
"All parameters should have the same shape (or have no dimensions)"
)
return shape
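A hypothetical illustration of the shape contract this helper enforces: 0-d and batched parameters may be mixed, but all batched ones must share one shape. The fake container below is not part of the code base, and it assumes `pcse` is installed (as the module imports require):

```python
import torch
from pcse.util import Afgen

class FakeParams:  # hypothetical stand-in for a ParamTemplate instance
    def trait_names(self):
        return ["SPAN", "TDWI", "SLATB"]

params = FakeParams()
params.SPAN = torch.tensor(25.0)                # 0-d: does not define the shape
params.TDWI = torch.tensor([50.0, 60.0, 70.0])  # batched: defines shape (3,)
params.SLATB = Afgen([0.0, 0.002, 2.0, 0.002])  # Afgen table: skipped

print(_get_params_shape(params))  # torch.Size([3])
```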


def _broadcast_to(x, shape):
"""Create a view of tensor X with the given shape."""
if x.dim() == 0:
# For 0-d tensors, we simply broadcast to the given shape
return torch.broadcast_to(x, shape)
# The given shape should match x in all but the last axis, which represents
# the dimension along which the time integration is carried out.
# We first append an axis to x, then expand to the given shape
return x.unsqueeze(-1).expand(shape)
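A short usage sketch of `_broadcast_to`, assuming the helper above is in scope: 0-d parameters are broadcast directly, batched ones get a trailing time axis (both results are views, so no data is copied):

```python
import torch

span = torch.tensor(25.0, dtype=torch.float64)          # 0-d parameter
tdwi = torch.tensor([50.0, 60.0], dtype=torch.float64)  # batched parameter, shape (2,)
lvage = torch.zeros((2, 300), dtype=torch.float64)      # (batch, MAX_DAYS) state

print(_broadcast_to(span, lvage.shape).shape)  # torch.Size([2, 300]) via broadcast_to
print(_broadcast_to(tdwi, lvage.shape).shape)  # torch.Size([2, 300]) via unsqueeze+expand
```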
16 changes: 8 additions & 8 deletions src/diffwofost/physical_models/utils.py
@@ -286,18 +286,18 @@ def get_test_data(file_path):
return inputs["ModelResults"], inputs["Precision"]


def calculate_numerical_grad(get_model_fn, param_name, param_value, output_index):
def calculate_numerical_grad(get_model_fn, param_name, param_value, out_name):
"""Calculate the numerical gradient of output with respect to a parameter."""
delta = 1e-6
p_plus = param_value.item() + delta
p_minus = param_value.item() - delta
p_plus = param_value + delta
p_minus = param_value - delta

model = get_model_fn()
output = model({param_name: torch.nn.Parameter(torch.tensor(p_plus, dtype=torch.float64))})
loss_plus = output[0, :, output_index].sum()
output = model({param_name: torch.nn.Parameter(p_plus)})
loss_plus = output[out_name].sum(dim=0)

model = get_model_fn()
output = model({param_name: torch.nn.Parameter(torch.tensor(p_minus, dtype=torch.float64))})
loss_minus = output[0, :, output_index].sum()
output = model({param_name: torch.nn.Parameter(p_minus)})
loss_minus = output[out_name].sum(dim=0)

return (loss_plus.item() - loss_minus.item()) / (2 * delta)
return (loss_plus.data - loss_minus.data) / (2 * delta)
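A toy self-check of the reworked helper against autograd, with a hypothetical `toy_model` standing in for the crop model (assumes `calculate_numerical_grad` above is in scope):

```python
import torch

def toy_model(params):
    # Pretend "SPAN" drives an output time series named "TWLV".
    span = params["SPAN"]
    t = torch.arange(5, dtype=torch.float64)
    return {"TWLV": span * t + span**2}

def get_model_fn():
    return toy_model

span = torch.tensor(25.0, dtype=torch.float64)
num_grad = calculate_numerical_grad(get_model_fn, "SPAN", span, "TWLV")

p = torch.nn.Parameter(span.clone())
toy_model({"SPAN": p})["TWLV"].sum().backward()
print(num_grad, p.grad)  # central difference and autograd should agree (~260.0)
```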