Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ venv.bak/
*.swp
*.swo
*.orig

# test data
tests/physical_models/test_data/*.yaml
67 changes: 50 additions & 17 deletions src/diffwofost/physical_models/crop/leaf_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class WOFOST_Leaf_Dynamics(SimulationObject):
# The following parameters are used to initialize and control the arrays that store information
# on the leaf classes during the time integration: leaf area, age, and biomass.
START_DATE = None # Start date of the simulation
MAX_DAYS = 300 # Maximum number of days that can be simulated in one run (i.e. array lenghts)
MAX_DAYS = 365 # Maximum number of days that can be simulated in one run (i.e. array lenghts)

class Parameters(ParamTemplate):
RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)])
Expand Down Expand Up @@ -233,68 +233,102 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None:
k = self.kiosk

# If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask
# Make a mask (0 if DVS < 0, 1 if DVS >= 0)
# A mask (0 if DVS < 0, 1 if DVS >= 0)
DVS = torch.as_tensor(k["DVS"], dtype=DTYPE)
mask = (DVS >= 0).to(dtype=DTYPE)
dvs_mask = (DVS >= 0).to(dtype=DTYPE)

# Growth rate leaves
# weight of new leaves
r.GRLV = mask * k.ADMI * k.FL
r.GRLV = dvs_mask * k.ADMI * k.FL

# death of leaves due to water/oxygen stress
r.DSLV1 = mask * s.WLV * (1.0 - k.RFTRA) * p.PERDL
r.DSLV1 = dvs_mask * s.WLV * (1.0 - k.RFTRA) * p.PERDL

# death due to self shading cause by high LAI
DVS = self.kiosk["DVS"]
LAICR = 3.2 / p.KDIFTB(DVS)
r.DSLV2 = mask * s.WLV * torch.clamp(0.03 * (s.LAI - LAICR) / LAICR, 0.0, 0.03)
r.DSLV2 = dvs_mask * s.WLV * torch.clamp(0.03 * (s.LAI - LAICR) / LAICR, 0.0, 0.03)

# Death of leaves due to frost damage as determined by
# Reduction Factor Frost "RF_FROST"
if "RF_FROST" in self.kiosk:
r.DSLV3 = mask * s.WLV * k.RF_FROST
r.DSLV3 = s.WLV * k.RF_FROST
else:
r.DSLV3 = torch.zeros_like(s.WLV, dtype=DTYPE)

r.DSLV3 = dvs_mask * r.DSLV3

# leaf death equals maximum of water stress, shading and frost
r.DSLV = torch.maximum(torch.maximum(r.DSLV1, r.DSLV2), r.DSLV3)
r.DSLV = dvs_mask * r.DSLV

# Determine how much leaf biomass classes have to die in states.LV,
# given the a life span > SPAN, these classes will be accumulated
# in DALV.
# Note that the actual leaf death is imposed on the array LV during the
# state integration step.
tSPAN = _broadcast_to(p.SPAN, s.LVAGE.shape) # Broadcast to same shape

# Using a sigmoid here instead of a conditional statement on the value of
# SPAN because the latter would not allow for the gradient to be tracked.
sharpness = torch.tensor(1000.0, dtype=DTYPE) # FIXME
weight = torch.sigmoid((s.LVAGE - tSPAN) * sharpness)
r.DALV = torch.sum(weight * s.LV, dim=-1)
# the if statement `p.SPAN.requires_grad` to avoid unnecessary
# approximation when SPAN is not a learnable parameter.
# TODO: sharpness can be exposed as a parameter
if p.SPAN.requires_grad:
# 1e-16 is chosen empirically for cases when s.LVAGE - tSPAN is very
# small and mask should be 1
sharpness = torch.tensor(1e-16, dtype=DTYPE)

# 1e-14 is chosen empirically for cases when s.LVAGE - tSPAN is
# equal to zero and mask should be 0.0
epsilon = 1e-14
span_mask = torch.sigmoid((s.LVAGE - tSPAN - epsilon) / sharpness).to(dtype=DTYPE)
else:
span_mask = (s.LVAGE > tSPAN).to(dtype=DTYPE)

r.DALV = torch.sum(span_mask * s.LV, dim=-1)
r.DALV = dvs_mask * r.DALV

# Total death rate leaves
r.DRLV = torch.maximum(r.DSLV, r.DALV)

# physiologic ageing of leaves per time step
FYSAGE = (drv.TEMP - p.TBASE) / (35.0 - p.TBASE)
r.FYSAGE = mask * torch.clamp(FYSAGE, 0.0)
r.FYSAGE = dvs_mask * torch.clamp(FYSAGE, 0.0)

# specific leaf area of leaves per time step
r.SLAT = mask * torch.tensor(p.SLATB(DVS), dtype=DTYPE)
r.SLAT = dvs_mask * torch.tensor(p.SLATB(DVS), dtype=DTYPE)

# leaf area not to exceed exponential growth curve
is_lai_exp = s.LAIEXP < 6.0
DTEFF = torch.clamp(drv.TEMP - p.TBASE, 0.0)

# NOTE: conditional statements do not allow for the gradient to be
# tracked through the condition. Thus, the gradient with respect to
# parameters that contribute to `is_lai_exp` (e.g. RGRLAI and TBASE)
# are expected to be incorrect.
r.GLAIEX = torch.where(is_lai_exp, s.LAIEXP * p.RGRLAI * DTEFF, r.GLAIEX)
r.GLAIEX = torch.where(
dvs_mask.bool(),
torch.where(is_lai_exp, s.LAIEXP * p.RGRLAI * DTEFF, r.GLAIEX),
torch.tensor(0.0, dtype=DTYPE),
)

Comment thread
SarahAlidoost marked this conversation as resolved.
# source-limited increase in leaf area
r.GLASOL = torch.where(is_lai_exp, r.GRLV * r.SLAT, r.GLASOL)
r.GLASOL = torch.where(
Comment thread
fnattino marked this conversation as resolved.
dvs_mask.bool(),
torch.where(is_lai_exp, r.GRLV * r.SLAT, r.GLASOL),
torch.tensor(0.0, dtype=DTYPE),
)

# sink-limited increase in leaf area
GLA = torch.minimum(r.GLAIEX, r.GLASOL)

# adjustment of specific leaf area of youngest leaf class
r.SLAT = torch.where(is_lai_exp & (r.GRLV > 0.0), GLA / r.GRLV, r.SLAT)
r.SLAT = torch.where(
Comment thread
fnattino marked this conversation as resolved.
dvs_mask.bool(),
torch.where(is_lai_exp & (r.GRLV > 0.0), GLA / r.GRLV, r.SLAT),
torch.tensor(0.0, dtype=DTYPE),
)

@prepare_states
def integrate(self, day: datetime.date, delt=1.0) -> None:
Expand Down Expand Up @@ -327,13 +361,12 @@ def integrate(self, day: datetime.date, delt=1.0) -> None:
new_biomass = torch.take_along_dim(weight_cumsum, indices=idx_oldest, dim=-1)
tLV = torch.scatter(tLV, dim=-1, index=idx_oldest, src=new_biomass)

# Integration of physiological age
# Zero out all dead leaf classes
# NOTE: conditional statements do not allow for the gradient to be
# tracked through the condition. Thus, the gradient with respect to
# parameters that contribute to `is_alive` are expected to be incorrect.
tLV = torch.where(is_alive, tLV, 0.0)

# Integration of physiological age
tLVAGE = tLVAGE + rates.FYSAGE
tLVAGE = torch.where(is_alive, tLVAGE, 0.0)
tSLA = torch.where(is_alive, tSLA, 0.0)
Expand Down
67 changes: 10 additions & 57 deletions src/diffwofost/physical_models/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""This file contains code that is required to run the YAML unit tests.

It contains:
- SimulationObjectTestHelper: a simobj that wraps the simulation object to be tested.
- VariableKioskTestHelper: A subclass of the VariableKiosk that can use externally
forced states/rates
- ConfigurationLoaderTestHelper: An subclass of ConfigurationLoader that allows to
Expand All @@ -22,14 +21,12 @@
from pcse.agromanager import AgroManager
from pcse.base import ConfigurationLoader
from pcse.base.parameter_providers import ParameterProvider
from pcse.base.simulationobject import SimulationObject
from pcse.base.variablekiosk import VariableKiosk
from pcse.base.weather import WeatherDataContainer
from pcse.base.weather import WeatherDataProvider
from pcse.engine import BaseEngine
from pcse.engine import Engine
from pcse.timer import Timer
from pcse.traitlets import Instance
from pcse.traitlets import TraitType

DTYPE = torch.float64 # Default data type for tensors in this module
Expand All @@ -44,49 +41,6 @@ def nothing(*args, **kwargs):
pass


class SimulationObjectTestHelper(SimulationObject):
"""This wraps the SimulationObject for testing.

This ensuree that the computations are not carried out before crop emergence
(e.g. DVS >= 0). The latter does not apply for the phenology simobject
itself which simulates emergence. The phenology simobject is recognized
because the variable DVS is not an external variable.
"""

test_class = None
subsimobject = Instance(SimulationObject)

def initialize(self, day, kiosk, parvalues):
"""Initialize the subsimobject."""
self.subsimobject = self.test_class(day, kiosk, parvalues)

def calc_rates(self, day, drv):
"""Calculate the rates of the subsimobject."""
# some simobject do not provide a `calc_rates()` function but are directly callable
# here we check for those cases.
func = self.subsimobject if callable(self.subsimobject) else self.subsimobject.calc_rates
if not self.kiosk.is_external_state("DVS"):
func(day, drv)
else:
if self.kiosk.DVS >= 0:
func(day, drv)
else:
self.subsimobject.zerofy()

def integrate(self, day, delt=1.0):
"""Integrate the states of the subsimobject."""
# If the simobject is callable, we do not need integration so we use the
# `nothing()` function.
func = nothing if callable(self.subsimobject) else self.subsimobject.integrate
if not self.kiosk.is_external_state("DVS"):
func(day, delt)
else:
if self.kiosk.DVS >= 0:
func(day, delt)
else:
self.subsimobject.touch()


class VariableKioskTestHelper(VariableKiosk):
"""Variable Kiosk for testing purposes which allows to use external states."""

Expand Down Expand Up @@ -255,25 +209,24 @@ def __init__(self, yaml_weather):
self._store_WeatherDataContainer(wdc, wdc.DAY)


def prepare_engine_input(file_path, crop_model_params):
def prepare_engine_input(test_data, crop_model_params, dtype=torch.float64):
"""Prepare the inputs for the engine from the YAML file."""
inputs = yaml.safe_load(open(file_path))
agro_management_inputs = inputs["AgroManagement"]
cropd = inputs["ModelParameters"]
agro_management_inputs = test_data["AgroManagement"]
cropd = test_data["ModelParameters"]

weather_data_provider = WeatherDataProviderTestHelper(inputs["WeatherVariables"])
weather_data_provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"])
crop_model_params_provider = ParameterProvider(cropdata=cropd)
external_states = inputs["ExternalStates"]
external_states = test_data["ExternalStates"]

# convert parameters to tensors
crop_model_params_provider.clear_override()
for name in crop_model_params:
value = torch.tensor(crop_model_params_provider[name], dtype=torch.float32)
value = torch.tensor(crop_model_params_provider[name], dtype=dtype)
crop_model_params_provider.set_override(name, value, check=False)

# convert external states to tensors
tensor_external_states = [
{k: v if k == "DAY" else torch.tensor(v, dtype=torch.float32) for k, v in item.items()}
{k: v if k == "DAY" else torch.tensor(v, dtype=dtype) for k, v in item.items()}
for item in external_states
]
return (
Expand All @@ -284,10 +237,10 @@ def prepare_engine_input(file_path, crop_model_params):
)


def get_test_data(file_path):
def get_test_data(test_data_path):
"""Get the test data from the YAML file."""
inputs = yaml.safe_load(open(file_path))
return inputs["ModelResults"], inputs["Precision"]
with open(test_data_path) as f:
return yaml.safe_load(f)


def calculate_numerical_grad(get_model_fn, param_name, param_value, out_name):
Expand Down
38 changes: 38 additions & 0 deletions tests/physical_models/conftest.py
Comment thread
SarahAlidoost marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pathlib import Path
import pytest
import requests

LOCAL_TEST_DIR = Path(__file__).parent / "test_data"
BASE_PCSE_URL = "https://raw.githubusercontent.com/ajwdewit/pcse/refs/heads/master/tests/test_data"

model_names = [
"leafdynamics",
"rootdynamics",
"potentialproduction",
]
FILE_NAMES = [
f"test_{model_name}_wofost72_{i:02d}.yaml" for model_name in model_names for i in range(1, 45)
]


def download_file(file_name, local_test_dir=LOCAL_TEST_DIR, base_url=BASE_PCSE_URL):
"""Download a single file from GitHub raw URL to local test_data folder."""

url = f"{base_url}/{file_name}"
local_test_dir.mkdir(exist_ok=True)
local_path = local_test_dir / file_name

if local_path.exists():
return # Already downloaded

print(f"Downloading {file_name} from {url}...")
response = requests.get(url)
response.raise_for_status() # Raise exception on HTTP error
local_path.write_bytes(response.content)


@pytest.fixture(scope="session", autouse=True)
def download_test_files():
"""Download all required test files before running tests."""
for file_name in FILE_NAMES:
download_file(file_name)
Loading