Skip to content
43 changes: 31 additions & 12 deletions src/diffwofost/physical_models/crop/leaf_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pcse.traitlets import Any
from diffwofost.physical_models.utils import AfgenTrait
from diffwofost.physical_models.utils import _broadcast_to
from diffwofost.physical_models.utils import _get_drv
from diffwofost.physical_models.utils import _get_params_shape

DTYPE = torch.float64 # Default data type for tensors in this module
Expand Down Expand Up @@ -100,21 +101,34 @@ class WOFOST_Leaf_Dynamics(SimulationObject):
|--------|-------------------------------------------------------|------|-------------|
| LAI | Leaf area index, including stem and pod area | Y | - |
| TWLV | Dry weight of total leaves (living + dead) | Y | kg ha⁻¹ |

**Gradient mapping (which parameters have a gradient):**

| Output | Parameters influencing it |
|--------|------------------------------------------|
| LAI | TDWI, SPAN, RGRLAI, TBASE, KDIFTB, SLATB |
| TWLV | TDWI, PERDL |

[!] Notice that the following gradients are zero (these parameters enter only
through non-differentiable conditions):
- ∂LAI/∂SPAN
- ∂TWLV/∂PERDL
- ∂LAI/∂KDIFTB
""" # noqa: E501

# The following parameters are used to initialize and control the arrays that store information
# on the leaf classes during the time integration: leaf area, age, and biomass.
START_DATE = None  # Start date of the simulation; set during initialization
MAX_DAYS = 365  # Maximum number of days that can be simulated in one run (i.e. array lengths)
params_shape = None  # Shape of the parameters tensors (filled in by `initialize` via _get_params_shape)

class Parameters(ParamTemplate):
    # Crop parameters for leaf dynamics. Scalar parameters default to a
    # -99.0 sentinel tensor meaning "not yet supplied"; real values are
    # injected by the parameter provider before the simulation starts.
    RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])  # relative growth rate of LAI — presumably d⁻¹; confirm units
    SPAN = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])  # leaf life span — NOTE: enters via a condition, so ∂LAI/∂SPAN is zero
    TBASE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])  # base temperature for physiologic ageing: FYSAGE = (TEMP - TBASE)/(35 - TBASE)
    PERDL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])  # relative death rate of leaves — influences TWLV (gradient is zero, see class docstring)
    TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])  # initial total dry weight; sets initial leaf biomass WLV = TDWI * (1 - FR) * FL
    SLATB = AfgenTrait()  # specific leaf area as an interpolation table over DVS (evaluated as SLATB(DVS))
    KDIFTB = AfgenTrait()  # interpolation table over DVS — presumably diffuse-light extinction coefficient; confirm

class StateVariables(StatesTemplate):
LV = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
Expand Down Expand Up @@ -172,17 +186,17 @@ def initialize(
DVS = self.kiosk["DVS"]

params = self.params
shape = _get_params_shape(params)
self.params_shape = _get_params_shape(params)

# Initial leaf biomass
WLV = (params.TDWI * (1 - FR)) * FL
DWLV = torch.zeros(shape, dtype=DTYPE)
DWLV = torch.zeros(self.params_shape, dtype=DTYPE)
TWLV = WLV + DWLV

# Initialize leaf classes (SLA, age and weight)
SLA = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
LVAGE = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
LV = torch.zeros((*shape, self.MAX_DAYS), dtype=DTYPE)
SLA = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE)
LVAGE = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE)
LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE)
SLA[..., 0] = params.SLATB(DVS)
LV[..., 0] = WLV

Expand Down Expand Up @@ -292,16 +306,20 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None:
# Total death rate leaves
r.DRLV = torch.maximum(r.DSLV, r.DALV)

# Get the temperature from the drv
TEMP = _get_drv(drv.TEMP, self.params_shape)

# physiologic ageing of leaves per time step
FYSAGE = (drv.TEMP - p.TBASE) / (35.0 - p.TBASE)
TBASE = _broadcast_to(p.TBASE, self.params_shape)
FYSAGE = (TEMP - TBASE) / (35.0 - TBASE)
r.FYSAGE = dvs_mask * torch.clamp(FYSAGE, 0.0)

# specific leaf area of leaves per time step
r.SLAT = dvs_mask * torch.tensor(p.SLATB(DVS), dtype=DTYPE)
r.SLAT = dvs_mask * p.SLATB(DVS)

# leaf area not to exceed exponential growth curve
is_lai_exp = s.LAIEXP < 6.0
DTEFF = torch.clamp(drv.TEMP - p.TBASE, 0.0)
DTEFF = torch.clamp(TEMP - TBASE, 0.0)

# NOTE: conditional statements do not allow for the gradient to be
# tracked through the condition. Thus, the gradient with respect to
Expand All @@ -324,9 +342,10 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None:
GLA = torch.minimum(r.GLAIEX, r.GLASOL)

# adjustment of specific leaf area of youngest leaf class
epsilon = 1e-10 # small value to avoid division by zero
r.SLAT = torch.where(
dvs_mask.bool(),
torch.where(is_lai_exp & (r.GRLV > 0.0), GLA / r.GRLV, r.SLAT),
torch.where(is_lai_exp & (r.GRLV > epsilon), GLA / (r.GRLV + epsilon), r.SLAT),
torch.tensor(0.0, dtype=DTYPE),
)

Expand Down Expand Up @@ -367,7 +386,7 @@ def integrate(self, day: datetime.date, delt=1.0) -> None:
# tracked through the condition. Thus, the gradient with respect to
# parameters that contribute to `is_alive` are expected to be incorrect.
tLV = torch.where(is_alive, tLV, 0.0)
tLVAGE = tLVAGE + rates.FYSAGE
tLVAGE = tLVAGE + rates.FYSAGE.unsqueeze(-1)
tLVAGE = torch.where(is_alive, tLVAGE, 0.0)
tSLA = torch.where(is_alive, tSLA, 0.0)

Expand Down
53 changes: 50 additions & 3 deletions src/diffwofost/physical_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pcse.base.weather import WeatherDataProvider
from pcse.engine import BaseEngine
from pcse.engine import Engine
from pcse.settings import settings
from pcse.timer import Timer
from pcse.traitlets import TraitType

Expand Down Expand Up @@ -200,21 +201,30 @@ def _run(self):
class WeatherDataProviderTestHelper(WeatherDataProvider):
    """It stores the weatherdata contained within the YAML tests."""

    def __init__(self, yaml_weather, meteo_range_checks=True):
        """Build a weather provider from in-memory YAML test records.

        Args:
            yaml_weather: iterable of per-day dicts whose keys match the keyword
                arguments of ``WeatherDataContainer`` (each must include at
                least ``DAY``, used as the storage key).
            meteo_range_checks: value written to the global PCSE settings; pass
                ``False`` to skip per-variable range validation, which is needed
                when weather values are arrays rather than scalars.
        """
        super().__init__()
        # This is a temporary workaround. The `METEO_RANGE_CHECKS` logic in
        # `__setattr__` method in `WeatherDataContainer` is not vector compatible
        # yet. So we can disable it here when creating the `WeatherDataContainer`
        # instances with arrays.
        # NOTE(review): this mutates the *global* PCSE settings object, so it
        # affects every WeatherDataContainer created afterwards, not only the
        # containers built by this helper.
        settings.METEO_RANGE_CHECKS = meteo_range_checks
        for weather in yaml_weather:
            # SNOWDEPTH entries are discarded before building the container —
            # presumably unsupported or unneeded by the tests; confirm.
            if "SNOWDEPTH" in weather:
                weather.pop("SNOWDEPTH")
            wdc = WeatherDataContainer(**weather)
            self._store_WeatherDataContainer(wdc, wdc.DAY)


def prepare_engine_input(test_data, crop_model_params, dtype=torch.float64):
def prepare_engine_input(
test_data, crop_model_params, meteo_range_checks=True, dtype=torch.float64
):
"""Prepare the inputs for the engine from the YAML file."""
agro_management_inputs = test_data["AgroManagement"]
cropd = test_data["ModelParameters"]

weather_data_provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"])
weather_data_provider = WeatherDataProviderTestHelper(
test_data["WeatherVariables"], meteo_range_checks=meteo_range_checks
)
crop_model_params_provider = ParameterProvider(cropdata=cropd)
external_states = test_data["ExternalStates"]

Expand Down Expand Up @@ -539,6 +549,8 @@ def _get_params_shape(params):

Parameters can have arbitrary number of dimensions, but all parameters that are not zero-
dimensional should have the same shape.

This check is fundamental for vectorized operations in the physical models.
"""
shape = ()
for parname in params.trait_names():
Expand All @@ -556,8 +568,43 @@ def _get_params_shape(params):
return shape


def _get_drv(drv_var, expected_shape):
"""Check that the driving variables have the expected shape and fetch them.

Driving variables can be scalars (0-dimensional) or match the expected shape.
Scalars will be broadcast during operations.

[!] This function will be redundant once weathercontainer supports batched variables.

Args:
drv_var: driving variable in WeatherDataContainer
expected_shape: Expected shape tuple for non-scalar variables

Raises:
ValueError: If any variable has incompatible shape

Returns:
torch.Tensor: The validated variable, either as-is or broadcasted to expected shape.
"""
# Check shape: must be scalar (0-d) or match expected_shape
if not isinstance(drv_var, torch.Tensor) or drv_var.dim() == 0:
# Scalar is valid, will be broadcast
return _broadcast_to(drv_var, expected_shape)
elif drv_var.shape == expected_shape:
# Matches expected shape
return drv_var
else:
raise ValueError(
f"Requested weather variable has incompatible shape {drv_var.shape}. "
f"Expected scalar (0-dimensional) or shape {expected_shape}."
)


def _broadcast_to(x, shape):
"""Create a view of tensor X with the given shape."""
# If x is not a tensor, convert it
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, dtype=DTYPE)
# If already the correct shape, return as-is
if x.shape == shape:
return x
Expand Down
Loading