From 51bccbe9f8c1bd90752ad70067a04ba0125ea97b Mon Sep 17 00:00:00 2001 From: SCiarella Date: Thu, 15 Jan 2026 16:43:09 +0100 Subject: [PATCH 01/45] Diffvectorize --- docs/api_reference.md | 2 + .../crop/evapotranspiration.py | 754 ++++++++++++++++++ src/diffwofost/physical_models/utils.py | 13 + tests/physical_models/conftest.py | 1 + .../crop/test_evapotranspiration.py | 703 ++++++++++++++++ 5 files changed, 1473 insertions(+) create mode 100644 src/diffwofost/physical_models/crop/evapotranspiration.py create mode 100644 tests/physical_models/crop/test_evapotranspiration.py diff --git a/docs/api_reference.md b/docs/api_reference.md index ed8303e5..a5960cc3 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -16,6 +16,8 @@ hide: ::: diffwofost.physical_models.crop.partitioning.DVS_Partitioning +::: diffwofost.physical_models.crop.evapotranspiration.Evapotranspiration + ## **Utility (under development)** ::: diffwofost.physical_models.config.Configuration diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py new file mode 100644 index 00000000..c546f790 --- /dev/null +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -0,0 +1,754 @@ +import datetime +import torch +from pcse.base import ParamTemplate +from pcse.base import RatesTemplate +from pcse.base import SimulationObject +from pcse.base import StatesTemplate +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.base.weather import WeatherDataContainer +from pcse.decorators import prepare_rates +from pcse.decorators import prepare_states +from pcse.traitlets import Any +from pcse.traitlets import Bool +from pcse.traitlets import Instance +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_drv +from diffwofost.physical_models.utils import _get_params_shape + + +def _clamp(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor: + return torch.clamp(x, min=lo, max=hi) + + +def _as_tensor(x, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + if isinstance(x, torch.Tensor): + t = x + if dtype is not None: + t = t.to(dtype=dtype) + if device is not None: + t = t.to(device=device) + return t + return torch.tensor(x, dtype=dtype, device=device) + + +def SWEAF(ET0: torch.Tensor, DEPNR: torch.Tensor) -> torch.Tensor: + """Soil Water Easily Available Fraction (SWEAF). + + SWEAF is a function of the potential evapotranspiration rate for a closed + canopy (cm day⁻¹) and the crop dependency number (1..5). + """ + A = 0.76 + B = 1.5 + sweaf = 1.0 / (A + B * ET0) - (5.0 - DEPNR) * 0.10 + correction = (ET0 - 0.6) / (DEPNR * (DEPNR + 3.0)) + # NOTE: PCSE applies `correction` only when `DEPNR < 3` (hard switch), which + # is non-differentiable at `DEPNR==3` and causes numerical vs autograd + # gradient mismatches when treating DEPNR as a continuous tensor. + # + # To keep regression behaviour intact we preserve exact values at the + # discrete DEPNR values used in the YAML fixtures (2.0/3.0/3.5/4.5): + # - DEPNR <= 2: full correction + # - DEPNR >= 3: no correction + # and smoothly transition (C1) between 2 and 3 using a cubic smoothstep. + t = DEPNR - 2.0 + s = 3.0 * t**2 - 2.0 * t**3 # smoothstep on [0,1] + taper_mid = 1.0 - s + taper = torch.where( + DEPNR <= 2.0, + torch.ones_like(DEPNR), + torch.where(DEPNR >= 3.0, torch.zeros_like(DEPNR), taper_mid), + ) + sweaf = sweaf + correction * taper + return _clamp(sweaf, 0.10, 0.95) + + +class EvapotranspirationWrapper(SimulationObject): + """Selects the evapotranspiration implementation. + + Selection logic: + - If `soil_profile` is present in parameters: use the layered CO2-aware module. + - Else if `CO2TRATB` is present: use the non-layered CO2 module. + - Else: use the non-layered (no CO2) module. + """ + + etmodule = Instance(SimulationObject) + + def initialize( + self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + ) -> None: + """Select and initialize the evapotranspiration implementation.""" + if "soil_profile" in parvalues: + self.etmodule = EvapotranspirationCO2Layered(day, kiosk, parvalues) + elif "CO2TRATB" in parvalues: + self.etmodule = EvapotranspirationCO2(day, kiosk, parvalues) + else: + self.etmodule = Evapotranspiration(day, kiosk, parvalues) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + return self.etmodule.calc_rates(day, drv) + + def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + return self.calc_rates(day, drv) + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + return self.etmodule.integrate(day, delt) + + +class _BaseEvapotranspiration(SimulationObject): + """Shared base class for evapotranspiration implementations.""" + + params_shape = None + + @property + def device(self): + return ComputeConfig.get_device() + + @property + def dtype(self): + return ComputeConfig.get_dtype() + + class RateVariables(RatesTemplate): + EVWMX = Any() + EVSMX = Any() + TRAMX = Any() + TRA = Any() + TRALY = Any() + IDOS = Bool(False) + IDWS = Bool(False) + RFWS = Any() + RFOS = Any() + RFTRA = Any() + + def __init__(self, kiosk, publish=None): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + self.EVWMX = torch.tensor(0.0, dtype=dtype, device=device) + self.EVSMX = torch.tensor(0.0, dtype=dtype, device=device) + self.TRAMX = torch.tensor(0.0, dtype=dtype, device=device) + self.TRA = torch.tensor(0.0, dtype=dtype, device=device) + self.TRALY = torch.tensor(0.0, dtype=dtype, device=device) + self.RFWS = torch.tensor(0.0, dtype=dtype, device=device) + self.RFOS = torch.tensor(0.0, dtype=dtype, device=device) + self.RFTRA = torch.tensor(0.0, dtype=dtype, device=device) + super().__init__(kiosk, publish=publish) + + class StateVariables(StatesTemplate): + IDOST = Any() + IDWST = Any() + + def __init__(self, kiosk, publish=None, **kwargs): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + if "IDOST" not in kwargs: + kwargs["IDOST"] = torch.tensor(0.0, dtype=dtype, device=device) + if "IDWST" not in kwargs: + kwargs["IDWST"] = torch.tensor(0.0, dtype=dtype, device=device) + super().__init__(kiosk, publish=publish, **kwargs) + + def _initialize_base( + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + *, + publish_rates: list[str], + ) -> None: + """Shared initialization for evapotranspiration modules.""" + self.kiosk = kiosk + self.params = self.Parameters(parvalues) + self.params_shape = _get_params_shape(self.params) + self.rates = self.RateVariables(kiosk, publish=publish_rates) + self.states = self.StateVariables(kiosk, publish=["IDOST", "IDWST"]) + self._epsilon = torch.tensor(1e-12, dtype=self.dtype, device=self.device) + + def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + return self.calc_rates(day, drv) + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Accumulate stress-day counters.""" + rfws_stress = (self.rates.RFWS < 1.0).to(dtype=self.dtype) + rfos_stress = (self.rates.RFOS < 1.0).to(dtype=self.dtype) + self.states.IDWST = self.states.IDWST + rfws_stress + self.states.IDOST = self.states.IDOST + rfos_stress + + +class _BaseEvapotranspirationNonLayered(_BaseEvapotranspiration): + """Shared implementation for non-layered evapotranspiration.""" + + def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + return torch.ones_like(et0) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + p = self.params + r = self.rates + k = self.kiosk + + dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) + lai = _broadcast_to(k["LAI"], self.params_shape, dtype=self.dtype, device=self.device) + sm = _broadcast_to(k["SM"], self.params_shape, dtype=self.dtype, device=self.device) + + et0 = _get_drv(drv.ET0, self.params_shape, dtype=self.dtype, device=self.device) + e0 = _get_drv(drv.E0, self.params_shape, dtype=self.dtype, device=self.device) + es0 = _get_drv(drv.ES0, self.params_shape, dtype=self.dtype, device=self.device) + rf_tramx_co2 = self._rf_tramx_co2(drv, et0) + + pre_emergence = dvs < 0.0 + if bool(torch.all(pre_emergence)): + zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) + r.EVWMX = zeros + r.EVSMX = zeros + r.TRAMX = zeros + r.TRA = zeros + r.TRALY = zeros + r.RFWS = ones + r.RFOS = ones + r.RFTRA = ones + r.IDWS = False + r.IDOS = False + return r.TRA, r.TRAMX + + kglob = 0.75 * p.KDIFTB(dvs) + et0_crop = torch.clamp(p.CFET * et0, min=0.0) + ekl = torch.exp(-kglob * lai) + + r.EVWMX = e0 * ekl + r.EVSMX = torch.clamp(es0 * ekl, min=0.0) + r.TRAMX = et0_crop * (1.0 - ekl) * rf_tramx_co2 + + swdep = SWEAF(et0_crop, p.DEPNR) + smcr = (1.0 - swdep) * (p.SMFCF - p.SMW) + p.SMW + + denom = torch.where((smcr - p.SMW).abs() > self._epsilon, (smcr - p.SMW), self._epsilon) + r.RFWS = _clamp((sm - p.SMW) / denom, 0.0, 1.0) + + # Oxygen-stress reduction factor (RFOS) + r.RFOS = torch.ones_like(r.RFWS) + iairdu = _broadcast_to(p.IAIRDU, self.params_shape, dtype=self.dtype, device=self.device) + iox = _broadcast_to(p.IOX, self.params_shape, dtype=self.dtype, device=self.device) + mask_ox = (iairdu == 0) & (iox == 1) + + if "DSOS" in k: + dsos = _broadcast_to(k["DSOS"], self.params_shape, dtype=self.dtype, device=self.device) + else: + dsos = torch.zeros_like(r.RFWS) + + crairc = _broadcast_to(p.CRAIRC, self.params_shape, dtype=self.dtype, device=self.device) + sm0 = _broadcast_to(p.SM0, self.params_shape, dtype=self.dtype, device=self.device) + denom_ox = torch.where(crairc.abs() > self._epsilon, crairc, self._epsilon) + rfosmx = _clamp((sm0 - sm) / denom_ox, 0.0, 1.0) + rfos = rfosmx + (1.0 - torch.clamp(dsos, max=4.0) / 4.0) * (1.0 - rfosmx) + r.RFOS = torch.where(mask_ox, rfos, r.RFOS) + + r.RFTRA = r.RFOS * r.RFWS + r.TRA = r.TRAMX * r.RFTRA + r.TRALY = r.TRA + + if bool(torch.any(pre_emergence)): + zeros = torch.zeros_like(r.TRA) + ones = torch.ones_like(r.RFTRA) + r.EVWMX = torch.where(pre_emergence, zeros, r.EVWMX) + r.EVSMX = torch.where(pre_emergence, zeros, r.EVSMX) + r.TRAMX = torch.where(pre_emergence, zeros, r.TRAMX) + r.TRA = torch.where(pre_emergence, zeros, r.TRA) + r.TRALY = torch.where(pre_emergence, zeros, r.TRALY) + r.RFWS = torch.where(pre_emergence, ones, r.RFWS) + r.RFOS = torch.where(pre_emergence, ones, r.RFOS) + r.RFTRA = torch.where(pre_emergence, ones, r.RFTRA) + + r.IDWS = bool(torch.any(r.RFWS < 1.0)) + r.IDOS = bool(torch.any(r.RFOS < 1.0)) + return r.TRA, r.TRAMX + + +class Evapotranspiration(_BaseEvapotranspirationNonLayered): + """Potential evaporation and crop transpiration (no CO2 effect). + + **Simulation parameters** + + | Name | Description | Type | Unit | + |--------|---------------------------------------------------------|------|------| + | CFET | Correction factor for potential transpiration rate | SCr | - | + | DEPNR | Crop dependency number (drought sensitivity, 1..5) | SCr | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | + | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | + | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | + | CRAIRC | Critical air content for root aeration | SSo | - | + | SM0 | Soil porosity | SSo | - | + | SMW | Volumetric soil moisture at wilting point | SSo | - | + | SMFCF | Volumetric soil moisture at field capacity | SSo | - | + + **State variables** + + | Name | Description | Pbl | Unit | + |-------|------------------------------------|-----|------| + | IDWST | Number of days with water stress | N | - | + | IDOST | Number of days with oxygen stress | N | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy | Y | cm day⁻¹ | + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | RFWS | Reduction factor for water stress | N | - | + | RFOS | Reduction factor for oxygen stress | N | - | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|----------------------------------|---------------|------| + | DVS | Crop development stage | Phenology | - | + | LAI | Leaf area index | Leaf dynamics | - | + | SM | Volumetric soil moisture content | Waterbalance | - | + """ + + class Parameters(ParamTemplate): + CFET = Any() + DEPNR = Any() + KDIFTB = AfgenTrait() + IAIRDU = Any() + IOX = Any() + CRAIRC = Any() + SM0 = Any() + SMW = Any() + SMFCF = Any() + + def __init__(self, parvalues): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) + self.DEPNR = torch.tensor(-99.0, dtype=dtype, device=device) + self.IAIRDU = torch.tensor(-99.0, dtype=dtype, device=device) + self.IOX = torch.tensor(-99.0, dtype=dtype, device=device) + self.CRAIRC = torch.tensor(-99.0, dtype=dtype, device=device) + self.SM0 = torch.tensor(-99.0, dtype=dtype, device=device) + self.SMW = torch.tensor(-99.0, dtype=dtype, device=device) + self.SMFCF = torch.tensor(-99.0, dtype=dtype, device=device) + super().__init__(parvalues) + + def initialize( + self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + ) -> None: + self._initialize_base( + day, + kiosk, + parvalues, + publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "RFTRA"], + ) + + +class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): + """Potential evaporation and crop transpiration with CO2 effect on TRAMX. + + **Simulation parameters** + + | Name | Description | Type | Unit | + |----------|--------------------------------------------------------|------|------| + | CFET | Correction factor for potential transpiration rate | SCr | - | + | DEPNR | Crop dependency number (drought sensitivity, 1..5) | SCr | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | + | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | + | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | + | CRAIRC | Critical air content for root aeration | SSo | - | + | SM0 | Soil porosity | SSo | - | + | SMW | Volumetric soil moisture at wilting point | SSo | - | + | SMFCF | Volumetric soil moisture at field capacity | SSo | - | + | CO2 | Atmospheric CO2 concentration (used if not in drivers) | SCr | ppm | + | CO2TRATB | Reduction factor for TRAMX as function of CO2 | TCr | - | + + **State variables** + + | Name | Description | Pbl | Unit | + |-------|------------------------------------|-----|------| + | IDWST | Number of days with water stress | N | - | + | IDOST | Number of days with oxygen stress | N | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy (CO2-adjusted) | Y | cm day⁻¹ | + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | RFWS | Reduction factor for water stress | N | - | + | RFOS | Reduction factor for oxygen stress | N | - | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|----------------------------------|---------------|------| + | DVS | Crop development stage | Phenology | - | + | LAI | Leaf area index | Leaf dynamics | - | + | SM | Volumetric soil moisture content | Waterbalance | - | + """ + + class Parameters(ParamTemplate): + CFET = Any() + DEPNR = Any() + KDIFTB = AfgenTrait() + IAIRDU = Any() + IOX = Any() + CRAIRC = Any() + SM0 = Any() + SMW = Any() + SMFCF = Any() + CO2 = Any() + CO2TRATB = AfgenTrait() + + def __init__(self, parvalues): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) + self.DEPNR = torch.tensor(-99.0, dtype=dtype, device=device) + self.IAIRDU = torch.tensor(-99.0, dtype=dtype, device=device) + self.IOX = torch.tensor(-99.0, dtype=dtype, device=device) + self.CRAIRC = torch.tensor(-99.0, dtype=dtype, device=device) + self.SM0 = torch.tensor(-99.0, dtype=dtype, device=device) + self.SMW = torch.tensor(-99.0, dtype=dtype, device=device) + self.SMFCF = torch.tensor(-99.0, dtype=dtype, device=device) + self.CO2 = torch.tensor(-99.0, dtype=dtype, device=device) + super().__init__(parvalues) + + def initialize( + self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + ) -> None: + self._initialize_base( + day, + kiosk, + parvalues, + publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "TRALY", "RFTRA"], + ) + + def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + p = self.params + + if hasattr(drv, "CO2") and drv.CO2 is not None: + co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) + else: + co2 = _broadcast_to(p.CO2, self.params_shape, dtype=self.dtype, device=self.device) + return p.CO2TRATB(co2) + + +class EvapotranspirationCO2Layered(_BaseEvapotranspiration): + """Layered-soil evapotranspiration with CO2 effect on TRAMX. + + This implementation expects a layered soil water balance. + + **Simulation parameters** + + | Name | Description | Type | Unit | + |----------|--------------------------------------------------------|------|------| + | CFET | Correction factor for potential transpiration rate | SCr | - | + | DEPNR | Crop dependency number (drought sensitivity, 1..5) | SCr | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | + | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | + | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | + | CO2 | Atmospheric CO2 concentration (used if not in drivers) | SCr | ppm | + | CO2TRATB | Reduction factor for TRAMX as function of CO2 | TCr | - | + + Layer-specific soil parameters (SMW, SMFCF, SM0, CRAIRC, Thickness) are + taken from `soil_profile` entries. + + **State variables** + + | Name | Description | Pbl | Unit | + |-------|------------------------------------|-----|------| + | IDWST | Number of days with water stress | N | - | + | IDOST | Number of days with oxygen stress | N | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy (CO2-adjusted) | Y | cm day⁻¹ | + | TRA | Actual canopy transpiration (sum over layers) | Y | cm day⁻¹ | + | TRALY | Transpiration per soil layer | Y | cm day⁻¹ | + | RFWS | Water-stress reduction per layer | N | - | + | RFOS | Oxygen-stress reduction per layer | N | - | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|----------------------------------|---------------|------| + | DVS | Crop development stage | Phenology | - | + | LAI | Leaf area index | Leaf dynamics | - | + | RD | Rooting depth | Root dynamics | cm | + | SM | Soil moisture per layer | Waterbalance | - | + """ + + soil_profile = Any() + + class Parameters(ParamTemplate): + CFET = Any() + DEPNR = Any() + KDIFTB = AfgenTrait() + IAIRDU = Any() + IOX = Any() + CO2 = Any() + CO2TRATB = AfgenTrait() + + def __init__(self, parvalues): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) + self.DEPNR = torch.tensor(-99.0, dtype=dtype, device=device) + self.IAIRDU = torch.tensor(-99.0, dtype=dtype, device=device) + self.IOX = torch.tensor(-99.0, dtype=dtype, device=device) + self.CO2 = torch.tensor(-99.0, dtype=dtype, device=device) + super().__init__(parvalues) + + class RateVariables(RatesTemplate): + EVWMX = Any() + EVSMX = Any() + TRAMX = Any() + TRA = Any() + TRALY = Any() + IDOS = Bool(False) + IDWS = Bool(False) + RFWS = Any() + RFOS = Any() + RFTRALY = Any() + RFTRA = Any() + + def __init__(self, kiosk, publish=None): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + self.EVWMX = torch.tensor(0.0, dtype=dtype, device=device) + self.EVSMX = torch.tensor(0.0, dtype=dtype, device=device) + self.TRAMX = torch.tensor(0.0, dtype=dtype, device=device) + self.TRA = torch.tensor(0.0, dtype=dtype, device=device) + self.TRALY = torch.tensor(0.0, dtype=dtype, device=device) + self.RFWS = torch.tensor(0.0, dtype=dtype, device=device) + self.RFOS = torch.tensor(0.0, dtype=dtype, device=device) + self.RFTRALY = torch.tensor(0.0, dtype=dtype, device=device) + self.RFTRA = torch.tensor(0.0, dtype=dtype, device=device) + super().__init__(kiosk, publish=publish) + + class StateVariables(StatesTemplate): + IDOST = Any() + IDWST = Any() + + def __init__(self, kiosk, publish=None, **kwargs): + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + if "IDOST" not in kwargs: + kwargs["IDOST"] = torch.tensor(0.0, dtype=dtype, device=device) + if "IDWST" not in kwargs: + kwargs["IDWST"] = torch.tensor(0.0, dtype=dtype, device=device) + super().__init__(kiosk, publish=publish, **kwargs) + + def initialize( + self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + ) -> None: + self.soil_profile = parvalues["soil_profile"] + self._initialize_base( + day, + kiosk, + parvalues, + publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "TRALY", "RFTRA"], + ) + # Internal DSOS tracker for layered oxygen-stress response (vectorized). + self._dsos = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + + def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + p = self.params + if hasattr(drv, "CO2") and drv.CO2 is not None: + co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) + else: + co2 = _broadcast_to(p.CO2, self.params_shape, dtype=self.dtype, device=self.device) + return p.CO2TRATB(co2) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + p = self.params + r = self.rates + k = self.kiosk + + dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) + lai = _broadcast_to(k["LAI"], self.params_shape, dtype=self.dtype, device=self.device) + rd = _broadcast_to(k["RD"], self.params_shape, dtype=self.dtype, device=self.device) + + pre_emergence = dvs < 0.0 + n_layers = len(self.soil_profile) + + et0 = _get_drv(drv.ET0, self.params_shape, dtype=self.dtype, device=self.device) + e0 = _get_drv(drv.E0, self.params_shape, dtype=self.dtype, device=self.device) + es0 = _get_drv(drv.ES0, self.params_shape, dtype=self.dtype, device=self.device) + + rf_tramx_co2 = self._rf_tramx_co2(drv, et0) + + if bool(torch.all(pre_emergence)): + zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) + r.EVWMX = zeros + r.EVSMX = zeros + r.TRAMX = zeros + r.TRA = zeros + r.TRALY = torch.zeros( + (n_layers,) + self.params_shape, dtype=self.dtype, device=self.device + ) + r.RFWS = torch.ones( + (n_layers,) + self.params_shape, dtype=self.dtype, device=self.device + ) + r.RFOS = torch.ones( + (n_layers,) + self.params_shape, dtype=self.dtype, device=self.device + ) + r.RFTRA = ones + r.IDWS = False + r.IDOS = False + return r.TRA, r.TRAMX + + et0_crop = torch.clamp(p.CFET * et0, min=0.0) + kglob = 0.75 * p.KDIFTB(dvs) + ekl = torch.exp(-kglob * lai) + r.EVWMX = e0 * ekl + r.EVSMX = torch.clamp(es0 * ekl, min=0.0) + r.TRAMX = et0_crop * (1.0 - ekl) * rf_tramx_co2 + + swdep = SWEAF(et0_crop, p.DEPNR) + + # Layered soil moisture can be provided as: + # - torch.Tensor with shape (n_layers, *params_shape) + # - list/tuple of length n_layers, each element scalar or tensor + sm_layers = k["SM"] + if isinstance(sm_layers, torch.Tensor): + sm_layers_t = sm_layers.to(dtype=self.dtype, device=self.device) + elif isinstance(sm_layers, (list, tuple)): + if len(sm_layers) != n_layers: + raise ValueError( + f"Layered evapotranspiration expects SM with {n_layers} layers, got {len(sm_layers)}." + ) + sm_layers_t = torch.stack( + [ + _broadcast_to(sm_i, self.params_shape, dtype=self.dtype, device=self.device) + for sm_i in sm_layers + ], + dim=0, + ) + else: + sm_layers_t = torch.as_tensor(sm_layers, dtype=self.dtype, device=self.device) + if sm_layers_t.dim() == 1: + # Interpret as per-layer scalars + if sm_layers_t.shape[0] != n_layers: + raise ValueError( + f"Layered evapotranspiration expects SM with {n_layers} layers, got {sm_layers_t.shape[0]}." + ) + sm_layers_t = torch.stack( + [ + _broadcast_to( + sm_layers_t[i], self.params_shape, dtype=self.dtype, device=self.device + ) + for i in range(n_layers) + ], + dim=0, + ) + + if sm_layers_t.shape[0] != n_layers: + raise ValueError( + f"Layered evapotranspiration expects SM first dim to be {n_layers}, got {sm_layers_t.shape[0]}." + ) + + rfws_list = [] + rfos_list = [] + traly_list = [] + + depth = 0.0 + for i, layer in enumerate(self.soil_profile): + sm_i = _broadcast_to( + sm_layers_t[i], self.params_shape, dtype=self.dtype, device=self.device + ) + layer_smw = _as_tensor(layer.SMW, dtype=self.dtype, device=self.device) + layer_smfcf = _as_tensor(layer.SMFCF, dtype=self.dtype, device=self.device) + + smcr = (1.0 - swdep) * (layer_smfcf - layer_smw) + layer_smw + denom = torch.where( + (smcr - layer_smw).abs() > self._epsilon, (smcr - layer_smw), self._epsilon + ) + rfws_i = _clamp((sm_i - layer_smw) / denom, 0.0, 1.0) + + rfos_i = torch.ones_like(rfws_i) + iairdu = _broadcast_to( + p.IAIRDU, self.params_shape, dtype=self.dtype, device=self.device + ) + iox = _broadcast_to(p.IOX, self.params_shape, dtype=self.dtype, device=self.device) + if bool(torch.any((iairdu == 0) & (iox == 1))): + layer_sm0 = _as_tensor(layer.SM0, dtype=self.dtype, device=self.device) + layer_crairc = _as_tensor(layer.CRAIRC, dtype=self.dtype, device=self.device) + smair = layer_sm0 - layer_crairc + self._dsos = torch.where( + sm_i >= smair, + torch.clamp(self._dsos + 1.0, max=4.0), + torch.zeros_like(self._dsos), + ) + denom_ox = torch.where( + layer_crairc.abs() > self._epsilon, layer_crairc, self._epsilon + ) + rfosmx = _clamp((layer_sm0 - sm_i) / denom_ox, 0.0, 1.0) + rfos_i = rfosmx + (1.0 - torch.clamp(self._dsos, max=4.0) / 4.0) * (1.0 - rfosmx) + + thickness = float(layer.Thickness) + depth_lo = _as_tensor(depth, dtype=self.dtype, device=self.device) + depth_hi = _as_tensor(depth + thickness, dtype=self.dtype, device=self.device) + root_len = torch.clamp(torch.minimum(rd, depth_hi) - depth_lo, min=0.0) + root_fraction = torch.where( + rd > self._epsilon, root_len / rd, torch.zeros_like(root_len) + ) + rftra_i = rfos_i * rfws_i + traly_i = r.TRAMX * rftra_i * root_fraction + + rfws_list.append(rfws_i) + rfos_list.append(rfos_i) + traly_list.append(traly_i) + depth += thickness + + r.RFWS = torch.stack(rfws_list, dim=0) + r.RFOS = torch.stack(rfos_list, dim=0) + r.TRALY = torch.stack(traly_list, dim=0) + r.TRA = r.TRALY.sum(dim=0) + r.RFTRA = torch.where(r.TRAMX > self._epsilon, r.TRA / r.TRAMX, torch.ones_like(r.TRA)) + + if bool(torch.any(pre_emergence)): + zeros = torch.zeros_like(r.TRA) + ones = torch.ones_like(r.RFTRA) + r.EVWMX = torch.where(pre_emergence, zeros, r.EVWMX) + r.EVSMX = torch.where(pre_emergence, zeros, r.EVSMX) + r.TRAMX = torch.where(pre_emergence, zeros, r.TRAMX) + r.TRA = torch.where(pre_emergence, zeros, r.TRA) + r.RFTRA = torch.where(pre_emergence, ones, r.RFTRA) + + pre_layers = pre_emergence.unsqueeze(0).expand_as(r.RFWS) + ones_layers = torch.ones_like(r.RFWS) + zeros_layers = torch.zeros_like(r.TRALY) + r.RFWS = torch.where(pre_layers, ones_layers, r.RFWS) + r.RFOS = torch.where(pre_layers, ones_layers, r.RFOS) + r.TRALY = torch.where(pre_layers, zeros_layers, r.TRALY) + + r.IDWS = bool(torch.any(r.RFWS < 1.0)) + r.IDOS = bool(torch.any(r.RFOS < 1.0)) + return r.TRA, r.TRAMX + + def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + return self.calc_rates(day, drv) + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + rfws_stress = (self.rates.RFWS < 1.0).any(dim=0).to(dtype=self.dtype) + rfos_stress = (self.rates.RFOS < 1.0).any(dim=0).to(dtype=self.dtype) + self.states.IDWST = self.states.IDWST + rfws_stress + self.states.IDOST = self.states.IDOST + rfos_stress diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index f3229c7f..6f0fa09a 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -211,6 +211,19 @@ def prepare_engine_input( weather_data_provider = WeatherDataProviderTestHelper( test_data["WeatherVariables"], meteo_range_checks=meteo_range_checks ) + + # The PCSE WeatherDataContainer stores required variables as Python floats. + # Some of our tests rely on weather inputs being torch.Tensors (e.g. to + # broadcast/batch weather variables). We only do this conversion when + # METEO_RANGE_CHECKS is disabled because the PCSE range checks assume + # scalar floats. + if not meteo_range_checks: + for (_, _), wdc in weather_data_provider.store.items(): + for varname in ("IRRAD", "TMIN", "TMAX", "VAP", "RAIN", "WIND", "E0", "ES0", "ET0"): + if hasattr(wdc, varname): + value = getattr(wdc, varname) + if not isinstance(value, torch.Tensor): + setattr(wdc, varname, torch.tensor(value, dtype=dtype, device=device)) crop_model_params_provider = ParameterProvider(cropdata=cropd) external_states = test_data.get("ExternalStates") or [] diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index 862c62e1..00aa7e08 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -14,6 +14,7 @@ "phenology", "partitioning", "assimilation", + "transpiration", ] FILE_NAMES = [ f"test_{model_name}_wofost72_{i:02d}.yaml" for model_name in model_names for i in range(1, 45) diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py new file mode 100644 index 00000000..44d22636 --- /dev/null +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -0,0 +1,703 @@ +import copy +import datetime +import warnings +from types import SimpleNamespace +from unittest.mock import patch +import pytest +import torch +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.models import Wofost72_PP +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.evapotranspiration import Evapotranspiration +from diffwofost.physical_models.crop.evapotranspiration import EvapotranspirationCO2 +from diffwofost.physical_models.crop.evapotranspiration import EvapotranspirationCO2Layered +from diffwofost.physical_models.crop.evapotranspiration import EvapotranspirationWrapper +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +evapotranspiration_config = Configuration( + CROP=EvapotranspirationWrapper, + OUTPUT_VARS=["EVSMX", "EVWMX", "TRAMX", "TRA"], +) + + +def get_test_diff_evapotranspiration_model(device: str = "cpu"): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( + prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False, device=device) + ) + return DiffEvapotranspiration( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + copy.deepcopy(external_states), + device=device, + ) + + +class DiffEvapotranspiration(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config, + external_states, + device: str = "cpu", + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config = config + self.external_states = external_states + self.device = device + + def forward(self, params_dict: dict[str, torch.Tensor]): + for name, value in params_dict.items(): + if isinstance(value, torch.Tensor) and value.device.type != self.device: + value = value.to(self.device) + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config, + self.external_states, + device=self.device, + ) + engine.run_till_terminate() + results = engine.get_output() + return { + var: torch.stack([item[var] for item in results]) + for var in ["EVSMX", "EVWMX", "TRAMX", "TRA"] + } + + +class TestEvapotranspiration: + transpiration_data_urls = [ + f"{phy_data_folder}/test_transpiration_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + @pytest.mark.parametrize("test_data_url", transpiration_data_urls) + def test_evapotranspiration_with_testengine(self, test_data_url, device): + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + abs(reference[var] - model_cpu[var]) < precision + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize( + "param", + [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + "ET0", + ], + ) + def test_evapotranspiration_with_one_parameter_vector(self, param, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + if param == "ET0": + for (_, _), wdc in weather_data_provider.store.items(): + wdc.ET0 = torch.ones(10, dtype=torch.float64, device=wdc.ET0.device) * wdc.ET0 + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device=device, + ) + return + + if param == "KDIFTB": + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + all(abs(reference[var] - model_cpu[var]) < precision) + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize( + "param,delta", + [ + ("CFET", 0.1), + ("DEPNR", 1.0), + ("KDIFTB", 0.05), + ("SMW", 0.01), + ("SMFCF", 0.01), + ("SM0", 0.01), + ], + ) + def test_evapotranspiration_with_different_parameter_values(self, param, delta, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + test_value = crop_model_params_provider[param] + if param == "KDIFTB": + non_zeros_mask = test_value != 0 + param_vec = torch.stack([test_value + non_zeros_mask * delta, test_value]) + else: + param_vec = torch.tensor( + [test_value - delta, test_value + delta, test_value], device=device + ) + crop_model_params_provider.set_override(param, param_vec, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + abs(reference[var] - model_cpu[var][-1]) < precision + for var, precision in expected_precision.items() + ) + + def test_evapotranspiration_with_multiple_parameter_vectors(self, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + for param in crop_model_params: + if param == "KDIFTB": + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + + def test_evapotranspiration_with_multiple_parameter_arrays(self, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + # Use an arbitrary batched shape and keep weather vars consistent. + batch_shape = (30, 5) + for param in ("CFET", "DEPNR", "KDIFTB"): + if param == "KDIFTB": + repeated = crop_model_params_provider[param].repeat(*batch_shape, 1) + else: + repeated = crop_model_params_provider[param].broadcast_to(batch_shape) + crop_model_params_provider.set_override(param, repeated, check=False) + + for (_, _), wdc in weather_data_provider.store.items(): + wdc.ET0 = torch.ones(batch_shape, dtype=torch.float64, device=wdc.ET0.device) * wdc.ET0 + wdc.E0 = torch.ones(batch_shape, dtype=torch.float64, device=wdc.E0.device) * wdc.E0 + wdc.ES0 = torch.ones(batch_shape, dtype=torch.float64, device=wdc.ES0.device) * wdc.ES0 + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device=device, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + torch.all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + assert all(model[var].shape == batch_shape for var in expected_precision.keys()) + + def test_evapotranspiration_with_incompatible_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + + crop_model_params_provider.set_override( + "CFET", crop_model_params_provider["CFET"].repeat(10), check=False + ) + crop_model_params_provider.set_override( + "DEPNR", crop_model_params_provider["DEPNR"].repeat(5), check=False + ) + + with pytest.raises(AssertionError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device="cpu", + ) + + def test_evapotranspiration_with_incompatible_weather_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + + crop_model_params_provider.set_override( + "CFET", crop_model_params_provider["CFET"].repeat(10), check=False + ) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.ET0 = torch.ones(5, dtype=torch.float64, device=wdc.ET0.device) * wdc.ET0 + + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + device="cpu", + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls[:1]) + def test_wofost_pp_with_evapotranspiration(self, test_data_url): + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + ) + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.Evapotranspiration", EvapotranspirationWrapper): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + abs(reference[var] - model[var]) < precision + for var, precision in expected_precision.items() + ) + + +def _minimal_parvalues(device: str, *, include_co2: bool = False, include_layers: bool = False): + dtype = torch.float64 + pars: dict[str, object] = { + "CFET": torch.tensor(1.0, dtype=dtype, device=device), + "DEPNR": torch.tensor(2.0, dtype=dtype, device=device), + "KDIFTB": torch.tensor([0.0, 0.69, 2.0, 0.69], dtype=dtype, device=device), + "IAIRDU": torch.tensor(0.0, dtype=dtype, device=device), + "IOX": torch.tensor(0.0, dtype=dtype, device=device), + "CRAIRC": torch.tensor(0.06, dtype=dtype, device=device), + "SM0": torch.tensor(0.40, dtype=dtype, device=device), + "SMW": torch.tensor(0.15, dtype=dtype, device=device), + "SMFCF": torch.tensor(0.29, dtype=dtype, device=device), + } + + if include_co2: + pars.update( + { + "CO2": torch.tensor(700.0, dtype=dtype, device=device), + "CO2TRATB": torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=dtype, device=device), + } + ) + + if include_layers: + soil_profile = [ + SimpleNamespace(SMW=0.15, SMFCF=0.29, SM0=0.40, CRAIRC=0.06, Thickness=10.0), + SimpleNamespace(SMW=0.16, SMFCF=0.30, SM0=0.41, CRAIRC=0.06, Thickness=20.0), + ] + pars["soil_profile"] = soil_profile + + return ParameterProvider(cropdata=pars) + + +class TestEvapotranspirationVariants: + def test_wrapper_selects_base(self, device): + parvalues = _minimal_parvalues(device) + kiosk = VariableKiosk() + wrapper = EvapotranspirationWrapper(datetime.date(2000, 1, 1), kiosk, parvalues) + assert isinstance(wrapper.etmodule, Evapotranspiration) + + def test_wrapper_selects_co2(self, device): + parvalues = _minimal_parvalues(device, include_co2=True) + kiosk = VariableKiosk() + wrapper = EvapotranspirationWrapper(datetime.date(2000, 1, 1), kiosk, parvalues) + assert isinstance(wrapper.etmodule, EvapotranspirationCO2) + + def test_wrapper_selects_layered(self, device): + parvalues = _minimal_parvalues(device, include_co2=True, include_layers=True) + kiosk = VariableKiosk() + wrapper = EvapotranspirationWrapper(datetime.date(2000, 1, 1), kiosk, parvalues) + assert isinstance(wrapper.etmodule, EvapotranspirationCO2Layered) + + def test_co2_reduces_tramx(self, device): + def _kiosk_with_states(): + kiosk = VariableKiosk() + oid = 0 + for name in ("DVS", "LAI", "SM"): + kiosk.register_variable(oid, name, type="S", publish=True) + kiosk.set_variable(oid, "DVS", torch.tensor(1.0, dtype=torch.float64, device=device)) + kiosk.set_variable(oid, "LAI", torch.tensor(3.0, dtype=torch.float64, device=device)) + kiosk.set_variable(oid, "SM", torch.tensor(0.25, dtype=torch.float64, device=device)) + return kiosk + + drv = SimpleNamespace( + ET0=torch.tensor(0.5, dtype=torch.float64, device=device), + E0=torch.tensor(0.6, dtype=torch.float64, device=device), + ES0=torch.tensor(0.55, dtype=torch.float64, device=device), + CO2=torch.tensor(700.0, dtype=torch.float64, device=device), + ) + + p_base = _minimal_parvalues(device) + kiosk_base = _kiosk_with_states() + base = Evapotranspiration(datetime.date(2000, 1, 1), kiosk_base, p_base) + base.calc_rates(datetime.date(2000, 1, 2), drv) + tramx_base = base.rates.TRAMX + + p_co2 = _minimal_parvalues(device, include_co2=True) + kiosk_co2 = _kiosk_with_states() + co2 = EvapotranspirationCO2(datetime.date(2000, 1, 1), kiosk_co2, p_co2) + co2.calc_rates(datetime.date(2000, 1, 2), drv) + tramx_co2 = co2.rates.TRAMX + + assert torch.all(tramx_co2 <= tramx_base) + + +class TestDiffEvapotranspirationGradients: + param_names = ["CFET", "DEPNR", "KDIFTB"] + output_names = ["EVWMX", "EVSMX", "TRAMX", "TRA"] + + param_configs = { + "single": { + "CFET": (1.0, torch.float64), + "DEPNR": (2.0, torch.float64), + "KDIFTB": ([[0.0, 0.69, 2.0, 0.69]], torch.float64), + }, + "tensor": { + "CFET": ([0.8, 1.0, 1.2], torch.float64), + "DEPNR": ([1.0, 2.0, 3.0], torch.float64), + "KDIFTB": ( + [[0.0, 0.60, 2.0, 0.60], [0.0, 0.69, 2.0, 0.69], [0.0, 0.78, 2.0, 0.78]], + torch.float64, + ), + }, + } + + gradient_mapping = { + "CFET": ["TRAMX", "TRA"], + "DEPNR": ["TRA"], + "KDIFTB": ["EVWMX", "EVSMX", "TRAMX", "TRA"], + } + + gradient_params = [] + no_gradient_params = [] + for param_name in param_names: + for output_name in output_names: + if output_name in gradient_mapping.get(param_name, []): + gradient_params.append((param_name, output_name)) + else: + no_gradient_params.append((param_name, output_name)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type, device): + model = get_test_diff_evapotranspiration_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output[output_name].sum() + + if not loss.requires_grad: + grads = None + else: + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): + model = get_test_diff_evapotranspiration_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + + output = model({param_name: param}) + loss = output[output_name].sum() + + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + assert grads is not None, f"Gradients for {param_name} should not be None" + + param.grad = None + loss.backward() + grad_backward = param.grad + + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), "Forward and backward gradients should match" + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type, device): + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + + numerical_grad = calculate_numerical_grad( + lambda: get_test_diff_evapotranspiration_model(device=device), + param_name, + param, + output_name, + ) + + model = get_test_diff_evapotranspiration_model(device=device) + output = model({param_name: param}) + loss = output[output_name].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + torch.testing.assert_close( + numerical_grad.detach().cpu(), + grads.detach().cpu(), + rtol=1e-3, + atol=1e-3, + ) + + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}' w.r.t '{output_name}' is zero: {grads.data}", + UserWarning, + ) From 8da9fbeeb0117267b9fe2a51a9cb93d9d6cb0164 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 19 Jan 2026 11:52:36 +0100 Subject: [PATCH 02/45] Clean --- .../crop/evapotranspiration.py | 51 ++++++++++++++++--- .../crop/test_evapotranspiration.py | 3 +- 2 files changed, 47 insertions(+), 7 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index c546f790..26549988 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -20,10 +20,12 @@ def _clamp(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor: + """Clamp tensor values to the range [lo, hi].""" return torch.clamp(x, min=lo, max=hi) def _as_tensor(x, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: + """Convert input to a tensor with specified dtype and device.""" if isinstance(x, torch.Tensor): t = x if dtype is not None: @@ -79,7 +81,11 @@ class EvapotranspirationWrapper(SimulationObject): def initialize( self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider ) -> None: - """Select and initialize the evapotranspiration implementation.""" + """Select and initialize the evapotranspiration implementation. + + Chooses between layered CO2-aware, non-layered CO2, or standard evapotranspiration + based on available parameters. + """ if "soil_profile" in parvalues: self.etmodule = EvapotranspirationCO2Layered(day, kiosk, parvalues) elif "CO2TRATB" in parvalues: @@ -89,13 +95,16 @@ def initialize( @prepare_rates def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Delegate rate calculation to the selected evapotranspiration module.""" return self.etmodule.calc_rates(day, drv) def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Callable interface for rate calculation.""" return self.calc_rates(day, drv) @prepare_states def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Delegate state integration to the selected evapotranspiration module.""" return self.etmodule.integrate(day, delt) @@ -106,10 +115,12 @@ class _BaseEvapotranspiration(SimulationObject): @property def device(self): + """Get the compute device (CPU or CUDA) from global configuration.""" return ComputeConfig.get_device() @property def dtype(self): + """Get the default data type (float32/float64) from global configuration.""" return ComputeConfig.get_dtype() class RateVariables(RatesTemplate): @@ -158,7 +169,11 @@ def _initialize_base( *, publish_rates: list[str], ) -> None: - """Shared initialization for evapotranspiration modules.""" + """Shared initialization for evapotranspiration modules. + + Sets up parameters, rate and state variables, and numerical epsilon for all + evapotranspiration implementations. + """ self.kiosk = kiosk self.params = self.Parameters(parvalues) self.params_shape = _get_params_shape(self.params) @@ -167,11 +182,12 @@ def _initialize_base( self._epsilon = torch.tensor(1e-12, dtype=self.dtype, device=self.device) def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Callable interface for rate calculation.""" return self.calc_rates(day, drv) @prepare_states def integrate(self, day: datetime.date = None, delt=1.0) -> None: - """Accumulate stress-day counters.""" + """Accumulate stress-day counters for water and oxygen stress.""" rfws_stress = (self.rates.RFWS < 1.0).to(dtype=self.dtype) rfos_stress = (self.rates.RFOS < 1.0).to(dtype=self.dtype) self.states.IDWST = self.states.IDWST + rfws_stress @@ -182,6 +198,7 @@ class _BaseEvapotranspirationNonLayered(_BaseEvapotranspiration): """Shared implementation for non-layered evapotranspiration.""" def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + """Return CO2 reduction factor for TRAMX (no CO2 effect in base implementation).""" return torch.ones_like(et0) @prepare_rates @@ -340,6 +357,7 @@ def __init__(self, parvalues): def initialize( self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider ) -> None: + """Initialize the standard evapotranspiration module (no CO2 effects).""" self._initialize_base( day, kiosk, @@ -409,6 +427,7 @@ class Parameters(ParamTemplate): CO2TRATB = AfgenTrait() def __init__(self, parvalues): + """Initialize CO2-aware parameters with default placeholder values before loading.""" dtype = ComputeConfig.get_dtype() device = ComputeConfig.get_device() self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) @@ -425,6 +444,7 @@ def __init__(self, parvalues): def initialize( self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider ) -> None: + """Initialize the CO2-aware evapotranspiration module.""" self._initialize_base( day, kiosk, @@ -433,6 +453,7 @@ def initialize( ) def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + """Calculate CO2 reduction factor for TRAMX based on atmospheric CO2 concentration.""" p = self.params if hasattr(drv, "CO2") and drv.CO2 is not None: @@ -504,6 +525,7 @@ class Parameters(ParamTemplate): CO2TRATB = AfgenTrait() def __init__(self, parvalues): + """Initialize layered CO2-aware parameters with default placeholder values.""" dtype = ComputeConfig.get_dtype() device = ComputeConfig.get_device() self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) @@ -527,6 +549,7 @@ class RateVariables(RatesTemplate): RFTRA = Any() def __init__(self, kiosk, publish=None): + """Initialize rate variables including per-layer transpiration and stress factors.""" dtype = ComputeConfig.get_dtype() device = ComputeConfig.get_device() self.EVWMX = torch.tensor(0.0, dtype=dtype, device=device) @@ -545,6 +568,7 @@ class StateVariables(StatesTemplate): IDWST = Any() def __init__(self, kiosk, publish=None, **kwargs): + """Initialize state variables for layered stress-day counters.""" dtype = ComputeConfig.get_dtype() device = ComputeConfig.get_device() if "IDOST" not in kwargs: @@ -556,6 +580,10 @@ def __init__(self, kiosk, publish=None, **kwargs): def initialize( self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider ) -> None: + """Initialize the layered-soil CO2-aware evapotranspiration module. + + Sets up layer-specific soil parameters and internal oxygen stress tracking. + """ self.soil_profile = parvalues["soil_profile"] self._initialize_base( day, @@ -567,6 +595,7 @@ def initialize( self._dsos = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + """Calculate CO2 reduction factor for TRAMX using CO2 from driver or parameters.""" p = self.params if hasattr(drv, "CO2") and drv.CO2 is not None: co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) @@ -576,6 +605,11 @@ def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.T @prepare_rates def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Calculate daily evapotranspiration rates per soil layer with CO2 effects. + + Computes transpiration and stress factors for each soil layer based on root + distribution and layer-specific soil moisture conditions. + """ p = self.params r = self.rates k = self.kiosk @@ -632,7 +666,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None elif isinstance(sm_layers, (list, tuple)): if len(sm_layers) != n_layers: raise ValueError( - f"Layered evapotranspiration expects SM with {n_layers} layers, got {len(sm_layers)}." + "Layered evapotranspiration expects SM with " + + f"{n_layers} layers, got {len(sm_layers)}." ) sm_layers_t = torch.stack( [ @@ -647,7 +682,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # Interpret as per-layer scalars if sm_layers_t.shape[0] != n_layers: raise ValueError( - f"Layered evapotranspiration expects SM with {n_layers} layers, got {sm_layers_t.shape[0]}." + "Layered evapotranspiration expects SM with " + + f"{n_layers} layers, got {sm_layers_t.shape[0]}." ) sm_layers_t = torch.stack( [ @@ -661,7 +697,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None if sm_layers_t.shape[0] != n_layers: raise ValueError( - f"Layered evapotranspiration expects SM first dim to be {n_layers}, got {sm_layers_t.shape[0]}." + "Layered evapotranspiration expects SM first dim to be " + + f"{n_layers}, got {sm_layers_t.shape[0]}." ) rfws_list = [] @@ -744,10 +781,12 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None return r.TRA, r.TRAMX def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Callable interface for rate calculation.""" return self.calc_rates(day, drv) @prepare_states def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Accumulate stress-day counters based on any layer experiencing stress.""" rfws_stress = (self.rates.RFWS < 1.0).any(dim=0).to(dtype=self.dtype) rfos_stress = (self.rates.RFOS < 1.0).any(dim=0).to(dtype=self.dtype) self.states.IDWST = self.states.IDWST + rfws_stress diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py index 44d22636..3834ebb8 100644 --- a/tests/physical_models/crop/test_evapotranspiration.py +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -698,6 +698,7 @@ def test_gradients_numerical(self, param_name, output_name, config_type, device) if torch.all(grads == 0): warnings.warn( - f"Gradient for parameter '{param_name}' w.r.t '{output_name}' is zero: {grads.data}", + f"Gradient for parameter '{param_name}'" + + f" w.r.t '{output_name}' is zero: {grads.data}", UserWarning, ) From ae5016cfe2ecd5893940a6215cf4ac1f44d94744 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 19 Jan 2026 12:45:24 +0100 Subject: [PATCH 03/45] Up coverage --- .../crop/test_evapotranspiration.py | 99 +++++++++++++++++-- 1 file changed, 89 insertions(+), 10 deletions(-) diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py index 3834ebb8..915c37ed 100644 --- a/tests/physical_models/crop/test_evapotranspiration.py +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -25,6 +25,55 @@ ) +def _augment_params_for_variant(crop_model_params_provider, variant: str): + """Augment parameters to enable specific evapotranspiration variant. + + Args: + crop_model_params_provider: Base parameter provider + variant: One of 'base', 'co2', or 'layered' + """ + if variant == "base": + # No augmentation needed + return + elif variant == "co2": + # Add CO2 parameters to enable EvapotranspirationCO2 + crop_model_params_provider.set_override( + "CO2", torch.tensor(360.0, dtype=torch.float64), check=False + ) + crop_model_params_provider.set_override( + "CO2TRATB", torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64), check=False + ) + elif variant == "layered": + # Add CO2 and soil_profile to enable EvapotranspirationCO2Layered + crop_model_params_provider.set_override( + "CO2", torch.tensor(360.0, dtype=torch.float64), check=False + ) + crop_model_params_provider.set_override( + "CO2TRATB", torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64), check=False + ) + # Create a simple two-layer soil profile using existing soil parameters + smw = crop_model_params_provider["SMW"] + smfcf = crop_model_params_provider["SMFCF"] + sm0 = crop_model_params_provider["SM0"] + crairc = crop_model_params_provider["CRAIRC"] + + # Convert to Python scalars if they are tensors + smw_val = float(smw.item() if isinstance(smw, torch.Tensor) else smw) + smfcf_val = float(smfcf.item() if isinstance(smfcf, torch.Tensor) else smfcf) + sm0_val = float(sm0.item() if isinstance(sm0, torch.Tensor) else sm0) + crairc_val = float(crairc.item() if isinstance(crairc, torch.Tensor) else crairc) + + soil_profile = [ + SimpleNamespace( + SMW=smw_val, SMFCF=smfcf_val, SM0=sm0_val, CRAIRC=crairc_val, Thickness=10.0 + ), + SimpleNamespace( + SMW=smw_val, SMFCF=smfcf_val, SM0=sm0_val, CRAIRC=crairc_val, Thickness=20.0 + ), + ] + crop_model_params_provider.set_override("soil_profile", soil_profile, check=False) + + def get_test_diff_evapotranspiration_model(device: str = "cpu"): test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -102,7 +151,8 @@ class TestEvapotranspiration: ] @pytest.mark.parametrize("test_data_url", transpiration_data_urls) - def test_evapotranspiration_with_testengine(self, test_data_url, device): + @pytest.mark.parametrize("variant", ["base", "co2", "layered"]) + def test_evapotranspiration_with_testengine(self, test_data_url, variant, device): test_data = get_test_data(test_data_url) crop_model_params = [ "CFET", @@ -124,6 +174,20 @@ def test_evapotranspiration_with_testengine(self, test_data_url, device): test_data, crop_model_params, meteo_range_checks=False, device=device ) + # Augment parameters based on variant to test different implementations + _augment_params_for_variant(crop_model_params_provider, variant) + + # For layered variant, also need to augment external_states with SM as a list and RD + if variant == "layered": + # Convert SM to a 2-layer list structure for each state dict + for state_dict in external_states: + if "SM" in state_dict: + sm_val = state_dict["SM"] + state_dict["SM"] = [sm_val, sm_val] + # Add RD (rooting depth) if not present - use a simple default of 30 cm + if "RD" not in state_dict: + state_dict["RD"] = torch.tensor(30.0, dtype=torch.float64, device=device) + engine = EngineTestHelper( crop_model_params_provider, weather_data_provider, @@ -137,15 +201,30 @@ def test_evapotranspiration_with_testengine(self, test_data_url, device): expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] assert len(actual_results) == len(expected_results) - for reference, model in zip(expected_results, actual_results, strict=False): - assert reference["DAY"] == model["day"] - for var in expected_precision.keys(): - assert model[var].device.type == device, f"{var} should be on {device}" - model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} - assert all( - abs(reference[var] - model_cpu[var]) < precision - for var, precision in expected_precision.items() - ) + + # For layered and CO2 variants, we just verify they run without errors + # (to achieve coverage) but don't check exact values since they use different + # implementations that produce different results + if variant in ("co2", "layered"): + # Just verify we got results with the correct structure + for model in actual_results: + assert "day" in model + for var in expected_precision.keys(): + assert var in model + assert model[var].device.type == device, f"{var} should be on {device}" + else: + # For base variant, check exact values against reference + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = { + k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items() + } + assert all( + abs(reference[var] - model_cpu[var]) < precision + for var, precision in expected_precision.items() + ) @pytest.mark.parametrize( "param", From 2298b1519494fddeb2080796946a239a290b2b93 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Tue, 17 Feb 2026 12:06:26 +0100 Subject: [PATCH 04/45] Fix module --- .../crop/evapotranspiration.py | 501 ++++++++---------- .../crop/test_evapotranspiration.py | 37 +- 2 files changed, 232 insertions(+), 306 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 26549988..62c19235 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -1,9 +1,6 @@ import datetime import torch -from pcse.base import ParamTemplate -from pcse.base import RatesTemplate from pcse.base import SimulationObject -from pcse.base import StatesTemplate from pcse.base.parameter_providers import ParameterProvider from pcse.base.variablekiosk import VariableKiosk from pcse.base.weather import WeatherDataContainer @@ -12,28 +9,14 @@ from pcse.traitlets import Any from pcse.traitlets import Bool from pcse.traitlets import Instance +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate +from diffwofost.physical_models.base import TensorStatesTemplate from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_drv -from diffwofost.physical_models.utils import _get_params_shape - - -def _clamp(x: torch.Tensor, lo: float, hi: float) -> torch.Tensor: - """Clamp tensor values to the range [lo, hi].""" - return torch.clamp(x, min=lo, max=hi) - - -def _as_tensor(x, *, dtype: torch.dtype, device: torch.device) -> torch.Tensor: - """Convert input to a tensor with specified dtype and device.""" - if isinstance(x, torch.Tensor): - t = x - if dtype is not None: - t = t.to(dtype=dtype) - if device is not None: - t = t.to(device=device) - return t - return torch.tensor(x, dtype=dtype, device=device) def SWEAF(ET0: torch.Tensor, DEPNR: torch.Tensor) -> torch.Tensor: @@ -64,7 +47,7 @@ def SWEAF(ET0: torch.Tensor, DEPNR: torch.Tensor) -> torch.Tensor: torch.where(DEPNR >= 3.0, torch.zeros_like(DEPNR), taper_mid), ) sweaf = sweaf + correction * taper - return _clamp(sweaf, 0.10, 0.95) + return torch.clamp(sweaf, min=0.10, max=0.95) class EvapotranspirationWrapper(SimulationObject): @@ -79,7 +62,11 @@ class EvapotranspirationWrapper(SimulationObject): etmodule = Instance(SimulationObject) def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, ) -> None: """Select and initialize the evapotranspiration implementation. @@ -87,11 +74,11 @@ def initialize( based on available parameters. """ if "soil_profile" in parvalues: - self.etmodule = EvapotranspirationCO2Layered(day, kiosk, parvalues) + self.etmodule = EvapotranspirationCO2Layered(day, kiosk, parvalues, shape=shape) elif "CO2TRATB" in parvalues: - self.etmodule = EvapotranspirationCO2(day, kiosk, parvalues) + self.etmodule = EvapotranspirationCO2(day, kiosk, parvalues, shape=shape) else: - self.etmodule = Evapotranspiration(day, kiosk, parvalues) + self.etmodule = Evapotranspiration(day, kiosk, parvalues, shape=shape) @prepare_rates def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): @@ -115,51 +102,29 @@ class _BaseEvapotranspiration(SimulationObject): @property def device(self): - """Get the compute device (CPU or CUDA) from global configuration.""" - return ComputeConfig.get_device() + """Get device from ComputeConfig.""" + return getattr(self, "_device", ComputeConfig.get_device()) @property def dtype(self): - """Get the default data type (float32/float64) from global configuration.""" - return ComputeConfig.get_dtype() - - class RateVariables(RatesTemplate): - EVWMX = Any() - EVSMX = Any() - TRAMX = Any() - TRA = Any() - TRALY = Any() + """Get dtype from ComputeConfig.""" + return getattr(self, "_dtype", ComputeConfig.get_dtype()) + + class RateVariables(TensorRatesTemplate): + EVWMX = Tensor(0.0) + EVSMX = Tensor(0.0) + TRAMX = Tensor(0.0) + TRA = Tensor(0.0) + TRALY = Tensor(0.0) IDOS = Bool(False) IDWS = Bool(False) - RFWS = Any() - RFOS = Any() - RFTRA = Any() - - def __init__(self, kiosk, publish=None): - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - self.EVWMX = torch.tensor(0.0, dtype=dtype, device=device) - self.EVSMX = torch.tensor(0.0, dtype=dtype, device=device) - self.TRAMX = torch.tensor(0.0, dtype=dtype, device=device) - self.TRA = torch.tensor(0.0, dtype=dtype, device=device) - self.TRALY = torch.tensor(0.0, dtype=dtype, device=device) - self.RFWS = torch.tensor(0.0, dtype=dtype, device=device) - self.RFOS = torch.tensor(0.0, dtype=dtype, device=device) - self.RFTRA = torch.tensor(0.0, dtype=dtype, device=device) - super().__init__(kiosk, publish=publish) - - class StateVariables(StatesTemplate): - IDOST = Any() - IDWST = Any() - - def __init__(self, kiosk, publish=None, **kwargs): - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - if "IDOST" not in kwargs: - kwargs["IDOST"] = torch.tensor(0.0, dtype=dtype, device=device) - if "IDWST" not in kwargs: - kwargs["IDWST"] = torch.tensor(0.0, dtype=dtype, device=device) - super().__init__(kiosk, publish=publish, **kwargs) + RFWS = Tensor(0.0) + RFOS = Tensor(0.0) + RFTRA = Tensor(0.0) + + class StateVariables(TensorStatesTemplate): + IDOST = Tensor(-99.0) + IDWST = Tensor(-99.0) def _initialize_base( self, @@ -168,6 +133,7 @@ def _initialize_base( parvalues: ParameterProvider, *, publish_rates: list[str], + shape: tuple | None = None, ) -> None: """Shared initialization for evapotranspiration modules. @@ -176,9 +142,13 @@ def _initialize_base( """ self.kiosk = kiosk self.params = self.Parameters(parvalues) - self.params_shape = _get_params_shape(self.params) - self.rates = self.RateVariables(kiosk, publish=publish_rates) - self.states = self.StateVariables(kiosk, publish=["IDOST", "IDWST"]) + if shape is None: + shape = self.params.shape + self.params_shape = shape + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + self.rates = self.RateVariables(kiosk, publish=publish_rates, shape=shape) + self.states = self.StateVariables(kiosk, shape=shape, IDOST=-999, IDWST=-999) self._epsilon = torch.tensor(1e-12, dtype=self.dtype, device=self.device) def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): @@ -207,9 +177,12 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r = self.rates k = self.kiosk + lai = k["LAI"] + sm = k["SM"] + # [!] DVS needs to be broadcasted explicetly because it is used + # in torch.where and the kiosk does not format it correctly + # TODO see #22 dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) - lai = _broadcast_to(k["LAI"], self.params_shape, dtype=self.dtype, device=self.device) - sm = _broadcast_to(k["SM"], self.params_shape, dtype=self.dtype, device=self.device) et0 = _get_drv(drv.ET0, self.params_shape, dtype=self.dtype, device=self.device) e0 = _get_drv(drv.E0, self.params_shape, dtype=self.dtype, device=self.device) @@ -218,16 +191,16 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None pre_emergence = dvs < 0.0 if bool(torch.all(pre_emergence)): - zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) - ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) - r.EVWMX = zeros - r.EVSMX = zeros - r.TRAMX = zeros - r.TRA = zeros - r.TRALY = zeros - r.RFWS = ones - r.RFOS = ones - r.RFTRA = ones + _z = torch.zeros_like(et0) + _o = torch.ones_like(et0) + r.EVWMX = _z + r.EVSMX = _z + r.TRAMX = _z + r.TRA = _z + r.TRALY = _z + r.RFWS = _o + r.RFOS = _o + r.RFTRA = _o r.IDWS = False r.IDOS = False return r.TRA, r.TRAMX @@ -244,23 +217,23 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None smcr = (1.0 - swdep) * (p.SMFCF - p.SMW) + p.SMW denom = torch.where((smcr - p.SMW).abs() > self._epsilon, (smcr - p.SMW), self._epsilon) - r.RFWS = _clamp((sm - p.SMW) / denom, 0.0, 1.0) + r.RFWS = torch.clamp((sm - p.SMW) / denom, min=0.0, max=1.0) # Oxygen-stress reduction factor (RFOS) - r.RFOS = torch.ones_like(r.RFWS) - iairdu = _broadcast_to(p.IAIRDU, self.params_shape, dtype=self.dtype, device=self.device) - iox = _broadcast_to(p.IOX, self.params_shape, dtype=self.dtype, device=self.device) + r.RFOS = torch.ones_like(et0) + iairdu = p.IAIRDU + iox = p.IOX mask_ox = (iairdu == 0) & (iox == 1) if "DSOS" in k: - dsos = _broadcast_to(k["DSOS"], self.params_shape, dtype=self.dtype, device=self.device) + dsos = k["DSOS"] else: - dsos = torch.zeros_like(r.RFWS) + dsos = torch.zeros_like(dvs) - crairc = _broadcast_to(p.CRAIRC, self.params_shape, dtype=self.dtype, device=self.device) - sm0 = _broadcast_to(p.SM0, self.params_shape, dtype=self.dtype, device=self.device) + crairc = p.CRAIRC + sm0 = p.SM0 denom_ox = torch.where(crairc.abs() > self._epsilon, crairc, self._epsilon) - rfosmx = _clamp((sm0 - sm) / denom_ox, 0.0, 1.0) + rfosmx = torch.clamp((sm0 - sm) / denom_ox, min=0.0, max=1.0) rfos = rfosmx + (1.0 - torch.clamp(dsos, max=4.0) / 4.0) * (1.0 - rfosmx) r.RFOS = torch.where(mask_ox, rfos, r.RFOS) @@ -269,16 +242,14 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.TRALY = r.TRA if bool(torch.any(pre_emergence)): - zeros = torch.zeros_like(r.TRA) - ones = torch.ones_like(r.RFTRA) - r.EVWMX = torch.where(pre_emergence, zeros, r.EVWMX) - r.EVSMX = torch.where(pre_emergence, zeros, r.EVSMX) - r.TRAMX = torch.where(pre_emergence, zeros, r.TRAMX) - r.TRA = torch.where(pre_emergence, zeros, r.TRA) - r.TRALY = torch.where(pre_emergence, zeros, r.TRALY) - r.RFWS = torch.where(pre_emergence, ones, r.RFWS) - r.RFOS = torch.where(pre_emergence, ones, r.RFOS) - r.RFTRA = torch.where(pre_emergence, ones, r.RFTRA) + r.EVWMX = torch.where(pre_emergence, 0.0, r.EVWMX) + r.EVSMX = torch.where(pre_emergence, 0.0, r.EVSMX) + r.TRAMX = torch.where(pre_emergence, 0.0, r.TRAMX) + r.TRA = torch.where(pre_emergence, 0.0, r.TRA) + r.TRALY = torch.where(pre_emergence, 0.0, r.TRALY) + r.RFWS = torch.where(pre_emergence, 1.0, r.RFWS) + r.RFOS = torch.where(pre_emergence, 1.0, r.RFOS) + r.RFTRA = torch.where(pre_emergence, 1.0, r.RFTRA) r.IDWS = bool(torch.any(r.RFWS < 1.0)) r.IDOS = bool(torch.any(r.RFOS < 1.0)) @@ -330,32 +301,23 @@ class Evapotranspiration(_BaseEvapotranspirationNonLayered): | SM | Volumetric soil moisture content | Waterbalance | - | """ - class Parameters(ParamTemplate): - CFET = Any() - DEPNR = Any() + class Parameters(TensorParamTemplate): + CFET = Tensor(-99.0) + DEPNR = Tensor(-99.0) KDIFTB = AfgenTrait() - IAIRDU = Any() - IOX = Any() - CRAIRC = Any() - SM0 = Any() - SMW = Any() - SMFCF = Any() - - def __init__(self, parvalues): - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) - self.DEPNR = torch.tensor(-99.0, dtype=dtype, device=device) - self.IAIRDU = torch.tensor(-99.0, dtype=dtype, device=device) - self.IOX = torch.tensor(-99.0, dtype=dtype, device=device) - self.CRAIRC = torch.tensor(-99.0, dtype=dtype, device=device) - self.SM0 = torch.tensor(-99.0, dtype=dtype, device=device) - self.SMW = torch.tensor(-99.0, dtype=dtype, device=device) - self.SMFCF = torch.tensor(-99.0, dtype=dtype, device=device) - super().__init__(parvalues) + IAIRDU = Tensor(-99.0) + IOX = Tensor(-99.0) + CRAIRC = Tensor(-99.0) + SM0 = Tensor(-99.0) + SMW = Tensor(-99.0) + SMFCF = Tensor(-99.0) def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, ) -> None: """Initialize the standard evapotranspiration module (no CO2 effects).""" self._initialize_base( @@ -363,6 +325,7 @@ def initialize( kiosk, parvalues, publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "RFTRA"], + shape=shape, ) @@ -413,36 +376,25 @@ class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): | SM | Volumetric soil moisture content | Waterbalance | - | """ - class Parameters(ParamTemplate): - CFET = Any() - DEPNR = Any() + class Parameters(TensorParamTemplate): + CFET = Tensor(-99.0) + DEPNR = Tensor(-99.0) KDIFTB = AfgenTrait() - IAIRDU = Any() - IOX = Any() - CRAIRC = Any() - SM0 = Any() - SMW = Any() - SMFCF = Any() - CO2 = Any() + IAIRDU = Tensor(-99.0) + IOX = Tensor(-99.0) + CRAIRC = Tensor(-99.0) + SM0 = Tensor(-99.0) + SMW = Tensor(-99.0) + SMFCF = Tensor(-99.0) + CO2 = Tensor(-99.0) CO2TRATB = AfgenTrait() - def __init__(self, parvalues): - """Initialize CO2-aware parameters with default placeholder values before loading.""" - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) - self.DEPNR = torch.tensor(-99.0, dtype=dtype, device=device) - self.IAIRDU = torch.tensor(-99.0, dtype=dtype, device=device) - self.IOX = torch.tensor(-99.0, dtype=dtype, device=device) - self.CRAIRC = torch.tensor(-99.0, dtype=dtype, device=device) - self.SM0 = torch.tensor(-99.0, dtype=dtype, device=device) - self.SMW = torch.tensor(-99.0, dtype=dtype, device=device) - self.SMFCF = torch.tensor(-99.0, dtype=dtype, device=device) - self.CO2 = torch.tensor(-99.0, dtype=dtype, device=device) - super().__init__(parvalues) - def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, ) -> None: """Initialize the CO2-aware evapotranspiration module.""" self._initialize_base( @@ -450,6 +402,7 @@ def initialize( kiosk, parvalues, publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "TRALY", "RFTRA"], + shape=shape, ) def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: @@ -459,7 +412,7 @@ def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.T if hasattr(drv, "CO2") and drv.CO2 is not None: co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) else: - co2 = _broadcast_to(p.CO2, self.params_shape, dtype=self.dtype, device=self.device) + co2 = p.CO2 return p.CO2TRATB(co2) @@ -515,70 +468,38 @@ class EvapotranspirationCO2Layered(_BaseEvapotranspiration): soil_profile = Any() - class Parameters(ParamTemplate): - CFET = Any() - DEPNR = Any() + class Parameters(TensorParamTemplate): + CFET = Tensor(-99.0) + DEPNR = Tensor(-99.0) KDIFTB = AfgenTrait() - IAIRDU = Any() - IOX = Any() - CO2 = Any() + IAIRDU = Tensor(-99.0) + IOX = Tensor(-99.0) + CO2 = Tensor(-99.0) CO2TRATB = AfgenTrait() - def __init__(self, parvalues): - """Initialize layered CO2-aware parameters with default placeholder values.""" - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - self.CFET = torch.tensor(-99.0, dtype=dtype, device=device) - self.DEPNR = torch.tensor(-99.0, dtype=dtype, device=device) - self.IAIRDU = torch.tensor(-99.0, dtype=dtype, device=device) - self.IOX = torch.tensor(-99.0, dtype=dtype, device=device) - self.CO2 = torch.tensor(-99.0, dtype=dtype, device=device) - super().__init__(parvalues) - - class RateVariables(RatesTemplate): - EVWMX = Any() - EVSMX = Any() - TRAMX = Any() - TRA = Any() - TRALY = Any() + class RateVariables(TensorRatesTemplate): + EVWMX = Tensor(0) + EVSMX = Tensor(0) + TRAMX = Tensor(0) + TRA = Tensor(0) + TRALY = Tensor(0) IDOS = Bool(False) IDWS = Bool(False) - RFWS = Any() - RFOS = Any() - RFTRALY = Any() - RFTRA = Any() - - def __init__(self, kiosk, publish=None): - """Initialize rate variables including per-layer transpiration and stress factors.""" - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - self.EVWMX = torch.tensor(0.0, dtype=dtype, device=device) - self.EVSMX = torch.tensor(0.0, dtype=dtype, device=device) - self.TRAMX = torch.tensor(0.0, dtype=dtype, device=device) - self.TRA = torch.tensor(0.0, dtype=dtype, device=device) - self.TRALY = torch.tensor(0.0, dtype=dtype, device=device) - self.RFWS = torch.tensor(0.0, dtype=dtype, device=device) - self.RFOS = torch.tensor(0.0, dtype=dtype, device=device) - self.RFTRALY = torch.tensor(0.0, dtype=dtype, device=device) - self.RFTRA = torch.tensor(0.0, dtype=dtype, device=device) - super().__init__(kiosk, publish=publish) - - class StateVariables(StatesTemplate): - IDOST = Any() - IDWST = Any() - - def __init__(self, kiosk, publish=None, **kwargs): - """Initialize state variables for layered stress-day counters.""" - dtype = ComputeConfig.get_dtype() - device = ComputeConfig.get_device() - if "IDOST" not in kwargs: - kwargs["IDOST"] = torch.tensor(0.0, dtype=dtype, device=device) - if "IDWST" not in kwargs: - kwargs["IDWST"] = torch.tensor(0.0, dtype=dtype, device=device) - super().__init__(kiosk, publish=publish, **kwargs) + RFWS = Tensor(0) + RFOS = Tensor(0) + RFTRALY = Tensor(0) + RFTRA = Tensor(0) + + class StateVariables(TensorStatesTemplate): + IDOST = Tensor(-99.0) + IDWST = Tensor(-99.0) def initialize( - self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, ) -> None: """Initialize the layered-soil CO2-aware evapotranspiration module. @@ -590,7 +511,31 @@ def initialize( kiosk, parvalues, publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "TRALY", "RFTRA"], + shape=shape, + ) + + # Pre-stack layer soil properties as tensors (avoids repeated + # torch.as_tensor conversions on every calc_rates call). + n_layers = len(self.soil_profile) + self._n_layers = n_layers + self._layer_smw = torch.tensor( + [layer.SMW for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + self._layer_smfcf = torch.tensor( + [layer.SMFCF for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + self._layer_sm0 = torch.tensor( + [layer.SM0 for layer in self.soil_profile], dtype=self.dtype, device=self.device ) + self._layer_crairc = torch.tensor( + [layer.CRAIRC for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + thicknesses = torch.tensor( + [layer.Thickness for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + self._layer_depth_hi = torch.cumsum(thicknesses, dim=0) + self._layer_depth_lo = self._layer_depth_hi - thicknesses + # Internal DSOS tracker for layered oxygen-stress response (vectorized). self._dsos = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) @@ -600,7 +545,7 @@ def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.T if hasattr(drv, "CO2") and drv.CO2 is not None: co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) else: - co2 = _broadcast_to(p.CO2, self.params_shape, dtype=self.dtype, device=self.device) + co2 = p.CO2 return p.CO2TRATB(co2) @prepare_rates @@ -614,12 +559,12 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r = self.rates k = self.kiosk - dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) - lai = _broadcast_to(k["LAI"], self.params_shape, dtype=self.dtype, device=self.device) - rd = _broadcast_to(k["RD"], self.params_shape, dtype=self.dtype, device=self.device) + dvs = k["DVS"] + lai = k["LAI"] + rd = k["RD"] pre_emergence = dvs < 0.0 - n_layers = len(self.soil_profile) + n_layers = self._n_layers et0 = _get_drv(drv.ET0, self.params_shape, dtype=self.dtype, device=self.device) e0 = _get_drv(drv.E0, self.params_shape, dtype=self.dtype, device=self.device) @@ -628,22 +573,17 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None rf_tramx_co2 = self._rf_tramx_co2(drv, et0) if bool(torch.all(pre_emergence)): - zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) - ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) - r.EVWMX = zeros - r.EVSMX = zeros - r.TRAMX = zeros - r.TRA = zeros - r.TRALY = torch.zeros( - (n_layers,) + self.params_shape, dtype=self.dtype, device=self.device - ) - r.RFWS = torch.ones( - (n_layers,) + self.params_shape, dtype=self.dtype, device=self.device - ) - r.RFOS = torch.ones( - (n_layers,) + self.params_shape, dtype=self.dtype, device=self.device - ) - r.RFTRA = ones + _z = torch.zeros_like(et0) + _o = torch.ones_like(et0) + _layered_shape = (n_layers,) + self.params_shape + r.EVWMX = _z + r.EVSMX = _z + r.TRAMX = _z + r.TRA = _z + r.TRALY = torch.zeros(_layered_shape, dtype=self.dtype, device=self.device) + r.RFWS = torch.ones(_layered_shape, dtype=self.dtype, device=self.device) + r.RFOS = torch.ones(_layered_shape, dtype=self.dtype, device=self.device) + r.RFTRA = _o r.IDWS = False r.IDOS = False return r.TRA, r.TRAMX @@ -701,80 +641,63 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None + f"{n_layers}, got {sm_layers_t.shape[0]}." ) - rfws_list = [] - rfos_list = [] - traly_list = [] + # Reshape pre-stacked layer properties for broadcasting against + # (n_layers, *params_shape) tensors: (n_layers,) → (n_layers, 1, 1, ...) + ndim = len(self.params_shape) + expand = (-1,) + (1,) * ndim + layer_smw = self._layer_smw.view(expand) + layer_smfcf = self._layer_smfcf.view(expand) + depth_lo = self._layer_depth_lo.view(expand) + depth_hi = self._layer_depth_hi.view(expand) + + # Vectorised RFWS across all layers: (n_layers, *params_shape) + smcr = (1.0 - swdep) * (layer_smfcf - layer_smw) + layer_smw + denom = torch.where( + (smcr - layer_smw).abs() > self._epsilon, smcr - layer_smw, self._epsilon + ) + r.RFWS = torch.clamp((sm_layers_t - layer_smw) / denom, min=0.0, max=1.0) - depth = 0.0 - for i, layer in enumerate(self.soil_profile): - sm_i = _broadcast_to( - sm_layers_t[i], self.params_shape, dtype=self.dtype, device=self.device - ) - layer_smw = _as_tensor(layer.SMW, dtype=self.dtype, device=self.device) - layer_smfcf = _as_tensor(layer.SMFCF, dtype=self.dtype, device=self.device) + # Vectorised root fraction across all layers: (n_layers, *params_shape) + root_len = torch.clamp(torch.minimum(rd, depth_hi) - depth_lo, min=0.0) + root_fraction = torch.where(rd > self._epsilon, root_len / rd, 0.0) - smcr = (1.0 - swdep) * (layer_smfcf - layer_smw) + layer_smw - denom = torch.where( - (smcr - layer_smw).abs() > self._epsilon, (smcr - layer_smw), self._epsilon - ) - rfws_i = _clamp((sm_i - layer_smw) / denom, 0.0, 1.0) - - rfos_i = torch.ones_like(rfws_i) - iairdu = _broadcast_to( - p.IAIRDU, self.params_shape, dtype=self.dtype, device=self.device - ) - iox = _broadcast_to(p.IOX, self.params_shape, dtype=self.dtype, device=self.device) - if bool(torch.any((iairdu == 0) & (iox == 1))): - layer_sm0 = _as_tensor(layer.SM0, dtype=self.dtype, device=self.device) - layer_crairc = _as_tensor(layer.CRAIRC, dtype=self.dtype, device=self.device) - smair = layer_sm0 - layer_crairc + # Oxygen-stress reduction factor (sequential across layers due to + # temporal _dsos accumulator that feeds forward between layers). + r.RFOS = torch.ones_like(r.RFWS) + mask_ox = (p.IAIRDU == 0) & (p.IOX == 1) + if bool(torch.any(mask_ox)): + layer_sm0 = self._layer_sm0.view(expand) + layer_crairc = self._layer_crairc.view(expand) + for i in range(n_layers): + smair = layer_sm0[i] - layer_crairc[i] self._dsos = torch.where( - sm_i >= smair, + sm_layers_t[i] >= smair, torch.clamp(self._dsos + 1.0, max=4.0), - torch.zeros_like(self._dsos), + 0.0, ) denom_ox = torch.where( - layer_crairc.abs() > self._epsilon, layer_crairc, self._epsilon + layer_crairc[i].abs() > self._epsilon, layer_crairc[i], self._epsilon ) - rfosmx = _clamp((layer_sm0 - sm_i) / denom_ox, 0.0, 1.0) - rfos_i = rfosmx + (1.0 - torch.clamp(self._dsos, max=4.0) / 4.0) * (1.0 - rfosmx) - - thickness = float(layer.Thickness) - depth_lo = _as_tensor(depth, dtype=self.dtype, device=self.device) - depth_hi = _as_tensor(depth + thickness, dtype=self.dtype, device=self.device) - root_len = torch.clamp(torch.minimum(rd, depth_hi) - depth_lo, min=0.0) - root_fraction = torch.where( - rd > self._epsilon, root_len / rd, torch.zeros_like(root_len) - ) - rftra_i = rfos_i * rfws_i - traly_i = r.TRAMX * rftra_i * root_fraction - - rfws_list.append(rfws_i) - rfos_list.append(rfos_i) - traly_list.append(traly_i) - depth += thickness + rfosmx = torch.clamp((layer_sm0[i] - sm_layers_t[i]) / denom_ox, min=0.0, max=1.0) + r.RFOS[i] = rfosmx + (1.0 - torch.clamp(self._dsos, max=4.0) / 4.0) * (1.0 - rfosmx) - r.RFWS = torch.stack(rfws_list, dim=0) - r.RFOS = torch.stack(rfos_list, dim=0) - r.TRALY = torch.stack(traly_list, dim=0) + # Transpiration per layer + rftra = r.RFOS * r.RFWS + r.TRALY = r.TRAMX * rftra * root_fraction r.TRA = r.TRALY.sum(dim=0) - r.RFTRA = torch.where(r.TRAMX > self._epsilon, r.TRA / r.TRAMX, torch.ones_like(r.TRA)) + r.RFTRA = torch.where(r.TRAMX > self._epsilon, r.TRA / r.TRAMX, 1.0) if bool(torch.any(pre_emergence)): - zeros = torch.zeros_like(r.TRA) - ones = torch.ones_like(r.RFTRA) - r.EVWMX = torch.where(pre_emergence, zeros, r.EVWMX) - r.EVSMX = torch.where(pre_emergence, zeros, r.EVSMX) - r.TRAMX = torch.where(pre_emergence, zeros, r.TRAMX) - r.TRA = torch.where(pre_emergence, zeros, r.TRA) - r.RFTRA = torch.where(pre_emergence, ones, r.RFTRA) + r.EVWMX = torch.where(pre_emergence, 0.0, r.EVWMX) + r.EVSMX = torch.where(pre_emergence, 0.0, r.EVSMX) + r.TRAMX = torch.where(pre_emergence, 0.0, r.TRAMX) + r.TRA = torch.where(pre_emergence, 0.0, r.TRA) + r.RFTRA = torch.where(pre_emergence, 1.0, r.RFTRA) pre_layers = pre_emergence.unsqueeze(0).expand_as(r.RFWS) - ones_layers = torch.ones_like(r.RFWS) - zeros_layers = torch.zeros_like(r.TRALY) - r.RFWS = torch.where(pre_layers, ones_layers, r.RFWS) - r.RFOS = torch.where(pre_layers, ones_layers, r.RFOS) - r.TRALY = torch.where(pre_layers, zeros_layers, r.TRALY) + r.RFWS = torch.where(pre_layers, 1.0, r.RFWS) + r.RFOS = torch.where(pre_layers, 1.0, r.RFOS) + r.TRALY = torch.where(pre_layers, 0.0, r.TRALY) r.IDWS = bool(torch.any(r.RFWS < 1.0)) r.IDOS = bool(torch.any(r.RFOS < 1.0)) diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py index 915c37ed..6efda8bd 100644 --- a/tests/physical_models/crop/test_evapotranspiration.py +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -25,12 +25,13 @@ ) -def _augment_params_for_variant(crop_model_params_provider, variant: str): +def _augment_params_for_variant(crop_model_params_provider, variant: str, device: str): """Augment parameters to enable specific evapotranspiration variant. Args: crop_model_params_provider: Base parameter provider variant: One of 'base', 'co2', or 'layered' + device: The device to create tensors on """ if variant == "base": # No augmentation needed @@ -38,18 +39,22 @@ def _augment_params_for_variant(crop_model_params_provider, variant: str): elif variant == "co2": # Add CO2 parameters to enable EvapotranspirationCO2 crop_model_params_provider.set_override( - "CO2", torch.tensor(360.0, dtype=torch.float64), check=False + "CO2", torch.tensor(360.0, dtype=torch.float64, device=device), check=False ) crop_model_params_provider.set_override( - "CO2TRATB", torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64), check=False + "CO2TRATB", + torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64, device=device), + check=False, ) elif variant == "layered": # Add CO2 and soil_profile to enable EvapotranspirationCO2Layered crop_model_params_provider.set_override( - "CO2", torch.tensor(360.0, dtype=torch.float64), check=False + "CO2", torch.tensor(360.0, dtype=torch.float64, device=device), check=False ) crop_model_params_provider.set_override( - "CO2TRATB", torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64), check=False + "CO2TRATB", + torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64, device=device), + check=False, ) # Create a simple two-layer soil profile using existing soil parameters smw = crop_model_params_provider["SMW"] @@ -131,7 +136,6 @@ def forward(self, params_dict: dict[str, torch.Tensor]): self.agro_management_inputs, self.config, self.external_states, - device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -175,7 +179,7 @@ def test_evapotranspiration_with_testengine(self, test_data_url, variant, device ) # Augment parameters based on variant to test different implementations - _augment_params_for_variant(crop_model_params_provider, variant) + _augment_params_for_variant(crop_model_params_provider, variant, device) # For layered variant, also need to augment external_states with SM as a list and RD if variant == "layered": @@ -194,7 +198,6 @@ def test_evapotranspiration_with_testengine(self, test_data_url, variant, device agro_management_inputs, evapotranspiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -274,7 +277,6 @@ def test_evapotranspiration_with_one_parameter_vector(self, param, device): agro_management_inputs, evapotranspiration_config, external_states, - device=device, ) return @@ -290,7 +292,6 @@ def test_evapotranspiration_with_one_parameter_vector(self, param, device): agro_management_inputs, evapotranspiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -357,7 +358,6 @@ def test_evapotranspiration_with_different_parameter_values(self, param, delta, agro_management_inputs, evapotranspiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -410,7 +410,6 @@ def test_evapotranspiration_with_multiple_parameter_vectors(self, device): agro_management_inputs, evapotranspiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -467,7 +466,6 @@ def test_evapotranspiration_with_multiple_parameter_arrays(self, device): agro_management_inputs, evapotranspiration_config, external_states, - device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -510,14 +508,13 @@ def test_evapotranspiration_with_incompatible_parameter_vectors(self): "DEPNR", crop_model_params_provider["DEPNR"].repeat(5), check=False ) - with pytest.raises(AssertionError): + with pytest.raises(ValueError, match="Non-matching shapes found in parameter provider!"): EngineTestHelper( crop_model_params_provider, weather_data_provider, agro_management_inputs, evapotranspiration_config, external_states, - device="cpu", ) def test_evapotranspiration_with_incompatible_weather_parameter_vectors(self): @@ -554,10 +551,9 @@ def test_evapotranspiration_with_incompatible_weather_parameter_vectors(self): agro_management_inputs, evapotranspiration_config, external_states, - device="cpu", ) - @pytest.mark.parametrize("test_data_url", wofost72_data_urls[:1]) + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) def test_wofost_pp_with_evapotranspiration(self, test_data_url): test_data = get_test_data(test_data_url) crop_model_params = [ @@ -587,6 +583,13 @@ def test_wofost_pp_with_evapotranspiration(self, test_data_url): assert len(actual_results) == len(expected_results) for reference, model in zip(expected_results, actual_results, strict=False): assert reference["DAY"] == model["day"] + for var, precision in expected_precision.items(): + if abs(reference[var] - model[var]) >= precision: + print( + f"Mismatch for {var} on day {model['day']}: expected {reference[var]}," + + f" got {model[var]}, diff {abs(reference[var] - model[var])}" + + f", precision {precision}" + ) assert all( abs(reference[var] - model[var]) < precision for var, precision in expected_precision.items() From ac0032d737aa8cc4bb8004d4c7f95a0a1e59f972 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Tue, 17 Feb 2026 12:09:38 +0100 Subject: [PATCH 05/45] Comments --- src/diffwofost/physical_models/crop/evapotranspiration.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 62c19235..6d2d8880 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -514,8 +514,7 @@ def initialize( shape=shape, ) - # Pre-stack layer soil properties as tensors (avoids repeated - # torch.as_tensor conversions on every calc_rates call). + # Pre-stack layer soil properties as tensors n_layers = len(self.soil_profile) self._n_layers = n_layers self._layer_smw = torch.tensor( @@ -536,7 +535,7 @@ def initialize( self._layer_depth_hi = torch.cumsum(thicknesses, dim=0) self._layer_depth_lo = self._layer_depth_hi - thicknesses - # Internal DSOS tracker for layered oxygen-stress response (vectorized). + # Internal DSOS tracker for layered oxygen-stress response self._dsos = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: From 81a0722306f1c179db1ecc15f6cce1b944496235 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 09:58:33 +0100 Subject: [PATCH 06/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/evapotranspiration.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 6d2d8880..167f8a93 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -22,8 +22,19 @@ def SWEAF(ET0: torch.Tensor, DEPNR: torch.Tensor) -> torch.Tensor: """Soil Water Easily Available Fraction (SWEAF). - SWEAF is a function of the potential evapotranspiration rate for a closed - canopy (cm day⁻¹) and the crop dependency number (1..5). + The fraction of easily available soil water between field capacity and + wilting point is a function of the potential evapotranspiration rate (for a + closed canopy) in cm/day, ET0, and the crop group number, DEPNR (from 1 + (=drought-sensitive) to 5 (=drought-resistent)). The function SWEAF + describes this relationship given in tabular form by Doorenbos & Kassam + (1979) and by Van Keulen & Wolf (1986; p.108, table 20) + http://edepot.wur.nl/168025. + + Args: + ET0: The evapotranpiration from a reference crop. + DEPNR: The crop dependency number. + Returns: + SWEAF value between 0.10 and 0.95. """ A = 0.76 B = 1.5 From 4dacae1724d21a672c322fc2525546abe85417a5 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:15:38 +0100 Subject: [PATCH 07/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/evapotranspiration.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 167f8a93..ea73ddf7 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -83,6 +83,16 @@ def initialize( Chooses between layered CO2-aware, non-layered CO2, or standard evapotranspiration based on available parameters. + + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. """ if "soil_profile" in parvalues: self.etmodule = EvapotranspirationCO2Layered(day, kiosk, parvalues, shape=shape) From 7479da8aa4032edaf3875af096b20281a424a97d Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:18:22 +0100 Subject: [PATCH 08/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/evapotranspiration.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index ea73ddf7..9bcd4cc9 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -103,7 +103,14 @@ def initialize( @prepare_rates def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): - """Delegate rate calculation to the selected evapotranspiration module.""" + """Delegate rate calculation to the selected evapotranspiration module. + + Args: + day (datetime.date, optional): The current date of the simulation. + drv (WeatherDataContainer, optional): A dictionary-like container holding + weather data elements as key/value. The values are + arrays or scalars. See PCSE documentation for details. + """ return self.etmodule.calc_rates(day, drv) def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): From 5464c14299889739b594df27145c95840ce6c431 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:21:03 +0100 Subject: [PATCH 09/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 9bcd4cc9..c899fd40 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -119,7 +119,12 @@ def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): @prepare_states def integrate(self, day: datetime.date = None, delt=1.0) -> None: - """Delegate state integration to the selected evapotranspiration module.""" + """Delegate state integration to the selected evapotranspiration module. + + Args: + day (datetime.date, optional): The current date of the simulation. + delt (float, optional): The time step for integration. Defaults to 1.0. + """ return self.etmodule.integrate(day, delt) From 21a1fb5f6139fd034da29b1cb6faa5abd2c96248 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:23:06 +0100 Subject: [PATCH 10/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index c899fd40..dce90d84 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -290,7 +290,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None class Evapotranspiration(_BaseEvapotranspirationNonLayered): - """Potential evaporation and crop transpiration (no CO2 effect). + """Calculation of potential evaporation (water and soil) rates and actual + crop transpiration rate. **Simulation parameters** From 8072a866d819c6a83ef5b79d6f61432ef3e8e56f Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:23:29 +0100 Subject: [PATCH 11/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index dce90d84..16ea6af9 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -298,7 +298,7 @@ class Evapotranspiration(_BaseEvapotranspirationNonLayered): | Name | Description | Type | Unit | |--------|---------------------------------------------------------|------|------| | CFET | Correction factor for potential transpiration rate | SCr | - | - | DEPNR | Crop dependency number (drought sensitivity, 1..5) | SCr | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | SCr | - | | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | From 7a1b6a7b3ab4816488e3ce01396d5a100853e894 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:24:24 +0100 Subject: [PATCH 12/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 16ea6af9..baa35e32 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -322,6 +322,8 @@ class Evapotranspiration(_BaseEvapotranspirationNonLayered): | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | | TRAMX | Max transpiration rate from canopy | Y | cm day⁻¹ | | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | IDOS | Indicates oxygen stress on this day (True|False) | N | - | + | IDWS | Indicates water stress on this day (True|False) | N | - | | RFWS | Reduction factor for water stress | N | - | | RFOS | Reduction factor for oxygen stress | N | - | | RFTRA | Combined reduction factor for transpiration | Y | - | From 8ae48033c99493bb8fdb87f4ecf38ba5711f4e06 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:24:45 +0100 Subject: [PATCH 13/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/evapotranspiration.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index baa35e32..d126d6f6 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -355,7 +355,18 @@ def initialize( parvalues: ParameterProvider, shape: tuple | None = None, ) -> None: - """Initialize the standard evapotranspiration module (no CO2 effects).""" + """Initialize the standard evapotranspiration module (no CO2 effects). + + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. + """ self._initialize_base( day, kiosk, From 4e1bccaef8a92c0c84d5d9e27f5cf2fc8087d4af Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:30:46 +0100 Subject: [PATCH 14/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index d126d6f6..aed72672 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -239,6 +239,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None return r.TRA, r.TRAMX kglob = 0.75 * p.KDIFTB(dvs) + # crop specific correction on potential transpiration rate et0_crop = torch.clamp(p.CFET * et0, min=0.0) ekl = torch.exp(-kglob * lai) From c15f34d7fe5631c745cc3af32048a370041c3e5b Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:31:02 +0100 Subject: [PATCH 15/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index aed72672..c771e925 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -241,6 +241,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None kglob = 0.75 * p.KDIFTB(dvs) # crop specific correction on potential transpiration rate et0_crop = torch.clamp(p.CFET * et0, min=0.0) + # maximum evaporation and transpiration rates ekl = torch.exp(-kglob * lai) r.EVWMX = e0 * ekl From bee0bfa828c840410c4f835766897a5bf0bcd23c Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:31:21 +0100 Subject: [PATCH 16/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index c771e925..20d9c352 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -248,6 +248,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.EVSMX = torch.clamp(es0 * ekl, min=0.0) r.TRAMX = et0_crop * (1.0 - ekl) * rf_tramx_co2 + # Critical soil moisture swdep = SWEAF(et0_crop, p.DEPNR) smcr = (1.0 - swdep) * (p.SMFCF - p.SMW) + p.SMW From 7f1d861acdd58c0109242db374f0b796ba3fac21 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:31:44 +0100 Subject: [PATCH 17/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 20d9c352..0ba50a01 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -252,6 +252,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None swdep = SWEAF(et0_crop, p.DEPNR) smcr = (1.0 - swdep) * (p.SMFCF - p.SMW) + p.SMW + # Reduction factor for transpiration in case of water shortage (RFWS) denom = torch.where((smcr - p.SMW).abs() > self._epsilon, (smcr - p.SMW), self._epsilon) r.RFWS = torch.clamp((sm - p.SMW) / denom, min=0.0, max=1.0) From 04368667a1c76d619068e93aa3ad52c78b123abc Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:33:32 +0100 Subject: [PATCH 18/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 0ba50a01..a2fbaa1c 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -256,7 +256,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None denom = torch.where((smcr - p.SMW).abs() > self._epsilon, (smcr - p.SMW), self._epsilon) r.RFWS = torch.clamp((sm - p.SMW) / denom, min=0.0, max=1.0) - # Oxygen-stress reduction factor (RFOS) + # reduction in transpiration in case of oxygen shortage (RFOS) + # for non-rice crops, and possibly deficient land drainage r.RFOS = torch.ones_like(et0) iairdu = p.IAIRDU iox = p.IOX From 9ff15d35f703b11f67ed7ca319e37a8362bca13f Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:33:51 +0100 Subject: [PATCH 19/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index a2fbaa1c..c3542b2b 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -272,6 +272,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None sm0 = p.SM0 denom_ox = torch.where(crairc.abs() > self._epsilon, crairc, self._epsilon) rfosmx = torch.clamp((sm0 - sm) / denom_ox, min=0.0, max=1.0) + # maximum reduction reached after 4 days rfos = rfosmx + (1.0 - torch.clamp(dsos, max=4.0) / 4.0) * (1.0 - rfosmx) r.RFOS = torch.where(mask_ox, rfos, r.RFOS) From 649bb70885402c147a626d79e928e3cffee55546 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:34:06 +0100 Subject: [PATCH 20/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index c3542b2b..7a3d2264 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -276,6 +276,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None rfos = rfosmx + (1.0 - torch.clamp(dsos, max=4.0) / 4.0) * (1.0 - rfosmx) r.RFOS = torch.where(mask_ox, rfos, r.RFOS) + # Transpiration rate multiplied with reduction factors for oxygen and water r.RFTRA = r.RFOS * r.RFWS r.TRA = r.TRAMX * r.RFTRA r.TRALY = r.TRA From da78b0afa65eaec51f7124f2d8eb8257d9ca41c5 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:34:32 +0100 Subject: [PATCH 21/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 7a3d2264..cc40901e 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -291,6 +291,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.RFOS = torch.where(pre_emergence, 1.0, r.RFOS) r.RFTRA = torch.where(pre_emergence, 1.0, r.RFTRA) + # Counting stress days r.IDWS = bool(torch.any(r.RFWS < 1.0)) r.IDOS = bool(torch.any(r.RFOS < 1.0)) return r.TRA, r.TRAMX From fd859ae5a9391412c14cb7c28b295bf812ba6841 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:35:03 +0100 Subject: [PATCH 22/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index cc40901e..fc610a0c 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -392,7 +392,7 @@ class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): | Name | Description | Type | Unit | |----------|--------------------------------------------------------|------|------| | CFET | Correction factor for potential transpiration rate | SCr | - | - | DEPNR | Crop dependency number (drought sensitivity, 1..5) | SCr | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | SCr | - | | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | From 10354b11f68b8d0014548c174de8eb98a9992ea6 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:35:22 +0100 Subject: [PATCH 23/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index fc610a0c..88b7dabb 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -417,7 +417,9 @@ class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | | TRAMX | Max transpiration rate from canopy (CO2-adjusted) | Y | cm day⁻¹ | - | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | | Y | cm day⁻¹ | + | IDOS | Indicates oxygen stress on this day (True|False) | N | - | + | IDWS | Indicates water stress on this day (True|False) | N | - | | RFWS | Reduction factor for water stress | N | - | | RFOS | Reduction factor for oxygen stress | N | - | | RFTRA | Combined reduction factor for transpiration | Y | - | From 0a30b67cc1efaa7e9ba2401b7fc986f4c03f14d6 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:35:42 +0100 Subject: [PATCH 24/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/evapotranspiration.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 88b7dabb..262ea5a7 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -453,7 +453,17 @@ def initialize( parvalues: ParameterProvider, shape: tuple | None = None, ) -> None: - """Initialize the CO2-aware evapotranspiration module.""" + """Initialize the CO2-aware evapotranspiration module. + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. + """ self._initialize_base( day, kiosk, From 9ff16fc61b41336cc7b26d8738349912c0ea2996 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 10:36:43 +0100 Subject: [PATCH 25/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 262ea5a7..ba12eb6f 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -493,7 +493,7 @@ class EvapotranspirationCO2Layered(_BaseEvapotranspiration): | Name | Description | Type | Unit | |----------|--------------------------------------------------------|------|------| | CFET | Correction factor for potential transpiration rate | SCr | - | - | DEPNR | Crop dependency number (drought sensitivity, 1..5) | SCr | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | SCr | - | | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | From cd077982668a43da117b24581d497c15f480b41c Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 23 Feb 2026 10:43:06 +0100 Subject: [PATCH 26/45] Use STE+sigmoid --- .../crop/evapotranspiration.py | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index ba12eb6f..9c9a0ab3 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -33,6 +33,7 @@ def SWEAF(ET0: torch.Tensor, DEPNR: torch.Tensor) -> torch.Tensor: Args: ET0: The evapotranpiration from a reference crop. DEPNR: The crop dependency number. + Returns: SWEAF value between 0.10 and 0.95. """ @@ -41,23 +42,25 @@ def SWEAF(ET0: torch.Tensor, DEPNR: torch.Tensor) -> torch.Tensor: sweaf = 1.0 / (A + B * ET0) - (5.0 - DEPNR) * 0.10 correction = (ET0 - 0.6) / (DEPNR * (DEPNR + 3.0)) # NOTE: PCSE applies `correction` only when `DEPNR < 3` (hard switch), which - # is non-differentiable at `DEPNR==3` and causes numerical vs autograd - # gradient mismatches when treating DEPNR as a continuous tensor. + # is non-differentiable at `DEPNR==3`. # - # To keep regression behaviour intact we preserve exact values at the - # discrete DEPNR values used in the YAML fixtures (2.0/3.0/3.5/4.5): - # - DEPNR <= 2: full correction - # - DEPNR >= 3: no correction - # and smoothly transition (C1) between 2 and 3 using a cubic smoothstep. - t = DEPNR - 2.0 - s = 3.0 * t**2 - 2.0 * t**3 # smoothstep on [0,1] - taper_mid = 1.0 - s - taper = torch.where( - DEPNR <= 2.0, - torch.ones_like(DEPNR), - torch.where(DEPNR >= 3.0, torch.zeros_like(DEPNR), taper_mid), - ) - sweaf = sweaf + correction * taper + # We use a Straight-Through Estimator (STE): the forward pass applies the + # exact hard switch (identical to PCSE), while the backward pass routes + # gradients through a smooth sigmoid surrogate so that autograd works. + # TODO: sharpness can be exposed as a parameter + _sigmoid_sharpness = 1000.0 + _sigmoid_epsilon = 1e-14 + + # soft mask using sigmoid + soft_mask = torch.sigmoid((3.0 - DEPNR - _sigmoid_epsilon) / _sigmoid_sharpness) + + # original hard mask + hard_mask = (DEPNR < 3.0).to(dtype=DEPNR.dtype) + + # STE method: during forward pass the hard_mask is used, during + # backpropagation the gradient is computed only through soft_mask. + correction_mask = hard_mask.detach() + soft_mask - soft_mask.detach() + sweaf = sweaf + correction * correction_mask return torch.clamp(sweaf, min=0.10, max=0.95) @@ -96,7 +99,7 @@ def initialize( """ if "soil_profile" in parvalues: self.etmodule = EvapotranspirationCO2Layered(day, kiosk, parvalues, shape=shape) - elif "CO2TRATB" in parvalues: + elif "CO2" in parvalues and "CO2TRATB" in parvalues: self.etmodule = EvapotranspirationCO2(day, kiosk, parvalues, shape=shape) else: self.etmodule = Evapotranspiration(day, kiosk, parvalues, shape=shape) @@ -248,7 +251,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.EVSMX = torch.clamp(es0 * ekl, min=0.0) r.TRAMX = et0_crop * (1.0 - ekl) * rf_tramx_co2 - # Critical soil moisture + # Critical soil moisture swdep = SWEAF(et0_crop, p.DEPNR) smcr = (1.0 - swdep) * (p.SMFCF - p.SMW) + p.SMW @@ -298,8 +301,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None class Evapotranspiration(_BaseEvapotranspirationNonLayered): - """Calculation of potential evaporation (water and soil) rates and actual - crop transpiration rate. + """Potential evaporation (water and soil) rates and crop transpiration rate. **Simulation parameters** @@ -392,7 +394,7 @@ class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): | Name | Description | Type | Unit | |----------|--------------------------------------------------------|------|------| | CFET | Correction factor for potential transpiration rate | SCr | - | - | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | SCr | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress | SCr | - | | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | @@ -417,7 +419,7 @@ class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | | TRAMX | Max transpiration rate from canopy (CO2-adjusted) | Y | cm day⁻¹ | - | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | | Y | cm day⁻¹ | + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | | IDOS | Indicates oxygen stress on this day (True|False) | N | - | | IDWS | Indicates water stress on this day (True|False) | N | - | | RFWS | Reduction factor for water stress | N | - | @@ -454,6 +456,7 @@ def initialize( shape: tuple | None = None, ) -> None: """Initialize the CO2-aware evapotranspiration module. + Args: day (datetime.date): The starting date of the simulation. kiosk (VariableKiosk): A container for registering and publishing From e890e26b56fc671aa22502697c9be2434d67fbee Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 23 Feb 2026 11:15:21 +0100 Subject: [PATCH 27/45] Use dvs_mask --- .../crop/evapotranspiration.py | 38 +++++-------------- .../crop/test_evapotranspiration.py | 2 +- 2 files changed, 10 insertions(+), 30 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 9c9a0ab3..938ef91c 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -225,21 +225,9 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None es0 = _get_drv(drv.ES0, self.params_shape, dtype=self.dtype, device=self.device) rf_tramx_co2 = self._rf_tramx_co2(drv, et0) - pre_emergence = dvs < 0.0 - if bool(torch.all(pre_emergence)): - _z = torch.zeros_like(et0) - _o = torch.ones_like(et0) - r.EVWMX = _z - r.EVSMX = _z - r.TRAMX = _z - r.TRA = _z - r.TRALY = _z - r.RFWS = _o - r.RFOS = _o - r.RFTRA = _o - r.IDWS = False - r.IDOS = False - return r.TRA, r.TRAMX + # If DVS < 0, the crop has not yet emerged, so we zero the rates using a mask + # A mask (1 if DVS >= 0, 0 if DVS < 0) + dvs_mask = (dvs >= 0.0).to(dtype=self.dtype) kglob = 0.75 * p.KDIFTB(dvs) # crop specific correction on potential transpiration rate @@ -247,9 +235,9 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # maximum evaporation and transpiration rates ekl = torch.exp(-kglob * lai) - r.EVWMX = e0 * ekl - r.EVSMX = torch.clamp(es0 * ekl, min=0.0) - r.TRAMX = et0_crop * (1.0 - ekl) * rf_tramx_co2 + r.EVWMX = dvs_mask * e0 * ekl + r.EVSMX = dvs_mask * torch.clamp(es0 * ekl, min=0.0) + r.TRAMX = dvs_mask * et0_crop * (1.0 - ekl) * rf_tramx_co2 # Critical soil moisture swdep = SWEAF(et0_crop, p.DEPNR) @@ -257,7 +245,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # Reduction factor for transpiration in case of water shortage (RFWS) denom = torch.where((smcr - p.SMW).abs() > self._epsilon, (smcr - p.SMW), self._epsilon) - r.RFWS = torch.clamp((sm - p.SMW) / denom, min=0.0, max=1.0) + r.RFWS = dvs_mask * torch.clamp((sm - p.SMW) / denom, min=0.0, max=1.0) + (1.0 - dvs_mask) # reduction in transpiration in case of oxygen shortage (RFOS) # for non-rice crops, and possibly deficient land drainage @@ -278,22 +266,14 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # maximum reduction reached after 4 days rfos = rfosmx + (1.0 - torch.clamp(dsos, max=4.0) / 4.0) * (1.0 - rfosmx) r.RFOS = torch.where(mask_ox, rfos, r.RFOS) + # Pre-emergence: RFOS = 1.0 + r.RFOS = dvs_mask * r.RFOS + (1.0 - dvs_mask) # Transpiration rate multiplied with reduction factors for oxygen and water r.RFTRA = r.RFOS * r.RFWS r.TRA = r.TRAMX * r.RFTRA r.TRALY = r.TRA - if bool(torch.any(pre_emergence)): - r.EVWMX = torch.where(pre_emergence, 0.0, r.EVWMX) - r.EVSMX = torch.where(pre_emergence, 0.0, r.EVSMX) - r.TRAMX = torch.where(pre_emergence, 0.0, r.TRAMX) - r.TRA = torch.where(pre_emergence, 0.0, r.TRA) - r.TRALY = torch.where(pre_emergence, 0.0, r.TRALY) - r.RFWS = torch.where(pre_emergence, 1.0, r.RFWS) - r.RFOS = torch.where(pre_emergence, 1.0, r.RFOS) - r.RFTRA = torch.where(pre_emergence, 1.0, r.RFTRA) - # Counting stress days r.IDWS = bool(torch.any(r.RFWS < 1.0)) r.IDOS = bool(torch.any(r.RFOS < 1.0)) diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py index 6efda8bd..56111451 100644 --- a/tests/physical_models/crop/test_evapotranspiration.py +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -692,7 +692,7 @@ class TestDiffEvapotranspirationGradients: }, "tensor": { "CFET": ([0.8, 1.0, 1.2], torch.float64), - "DEPNR": ([1.0, 2.0, 3.0], torch.float64), + "DEPNR": ([1.0, 2.0, 4.0], torch.float64), "KDIFTB": ( [[0.0, 0.60, 2.0, 0.60], [0.0, 0.69, 2.0, 0.69], [0.0, 0.78, 2.0, 0.78]], torch.float64, From 0e33304805e3a3de313c436bd6be3d492827fc1b Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 23 Feb 2026 11:24:23 +0100 Subject: [PATCH 28/45] remove TRALY from base --- src/diffwofost/physical_models/crop/evapotranspiration.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 938ef91c..e84d8a91 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -151,7 +151,6 @@ class RateVariables(TensorRatesTemplate): EVSMX = Tensor(0.0) TRAMX = Tensor(0.0) TRA = Tensor(0.0) - TRALY = Tensor(0.0) IDOS = Bool(False) IDWS = Bool(False) RFWS = Tensor(0.0) @@ -272,7 +271,6 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # Transpiration rate multiplied with reduction factors for oxygen and water r.RFTRA = r.RFOS * r.RFWS r.TRA = r.TRAMX * r.RFTRA - r.TRALY = r.TRA # Counting stress days r.IDWS = bool(torch.any(r.RFWS < 1.0)) @@ -451,7 +449,7 @@ def initialize( day, kiosk, parvalues, - publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "TRALY", "RFTRA"], + publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "RFTRA"], shape=shape, ) From ce78e0366aa95535953e424d5ad25b7bccc990b4 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:41:16 +0100 Subject: [PATCH 29/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index e84d8a91..ccfe557b 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -499,6 +499,8 @@ class EvapotranspirationCO2Layered(_BaseEvapotranspiration): | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | | TRAMX | Max transpiration rate from canopy (CO2-adjusted) | Y | cm day⁻¹ | | TRA | Actual canopy transpiration (sum over layers) | Y | cm day⁻¹ | + | IDOS | Indicates oxygen stress on this day (True|False) | N | - | + | IDWS | Indicates water stress on this day (True|False) | N | - | | TRALY | Transpiration per soil layer | Y | cm day⁻¹ | | RFWS | Water-stress reduction per layer | N | - | | RFOS | Oxygen-stress reduction per layer | N | - | From 5c5c3f19c13c459599b3a0969d0e15a50a6b8ae6 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 23 Feb 2026 12:42:05 +0100 Subject: [PATCH 30/45] Add finalize method --- .../crop/evapotranspiration.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index ccfe557b..45730858 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -130,6 +130,11 @@ def integrate(self, day: datetime.date = None, delt=1.0) -> None: """ return self.etmodule.integrate(day, delt) + @prepare_states + def finalize(self, day: datetime.date) -> None: + """Delegate finalization to the selected evapotranspiration module.""" + self.etmodule.finalize(day) + class _BaseEvapotranspiration(SimulationObject): """Shared base class for evapotranspiration implementations.""" @@ -186,6 +191,10 @@ def _initialize_base( self.states = self.StateVariables(kiosk, shape=shape, IDOST=-999, IDWST=-999) self._epsilon = torch.tensor(1e-12, dtype=self.dtype, device=self.device) + # Private accumulators for stress-day counters (written to states in finalize) + self._IDWST = torch.zeros(shape, dtype=self.dtype, device=self.device) + self._IDOST = torch.zeros(shape, dtype=self.dtype, device=self.device) + def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): """Callable interface for rate calculation.""" return self.calc_rates(day, drv) @@ -195,8 +204,15 @@ def integrate(self, day: datetime.date = None, delt=1.0) -> None: """Accumulate stress-day counters for water and oxygen stress.""" rfws_stress = (self.rates.RFWS < 1.0).to(dtype=self.dtype) rfos_stress = (self.rates.RFOS < 1.0).to(dtype=self.dtype) - self.states.IDWST = self.states.IDWST + rfws_stress - self.states.IDOST = self.states.IDOST + rfos_stress + self._IDWST = self._IDWST + rfws_stress + self._IDOST = self._IDOST + rfos_stress + + @prepare_states + def finalize(self, day: datetime.date) -> None: + """Finalize the evapotranspiration simulation.""" + self.states.IDWST = self._IDWST + self.states.IDOST = self._IDOST + SimulationObject.finalize(self, day) class _BaseEvapotranspirationNonLayered(_BaseEvapotranspiration): @@ -761,5 +777,5 @@ def integrate(self, day: datetime.date = None, delt=1.0) -> None: """Accumulate stress-day counters based on any layer experiencing stress.""" rfws_stress = (self.rates.RFWS < 1.0).any(dim=0).to(dtype=self.dtype) rfos_stress = (self.rates.RFOS < 1.0).any(dim=0).to(dtype=self.dtype) - self.states.IDWST = self.states.IDWST + rfws_stress - self.states.IDOST = self.states.IDOST + rfos_stress + self._IDWST = self._IDWST + rfws_stress + self._IDOST = self._IDOST + rfos_stress From b03effb3e21ecf8eacfc679793711635211a3280 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:42:36 +0100 Subject: [PATCH 31/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/evapotranspiration.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 45730858..57a5c0a2 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -570,6 +570,15 @@ def initialize( """Initialize the layered-soil CO2-aware evapotranspiration module. Sets up layer-specific soil parameters and internal oxygen stress tracking. + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. """ self.soil_profile = parvalues["soil_profile"] self._initialize_base( From 3036a7741a4cf1be20fe6e669a35c2264f98653a Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:48:16 +0100 Subject: [PATCH 32/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 57a5c0a2..846704ab 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -644,6 +644,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None e0 = _get_drv(drv.E0, self.params_shape, dtype=self.dtype, device=self.device) es0 = _get_drv(drv.ES0, self.params_shape, dtype=self.dtype, device=self.device) + # reduction factor for CO2 on TRAMX rf_tramx_co2 = self._rf_tramx_co2(drv, et0) if bool(torch.all(pre_emergence)): From a739166193f5d8c6620814ed743f39b3999929e9 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:49:03 +0100 Subject: [PATCH 33/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 846704ab..fdb46ce5 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -663,6 +663,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.IDOS = False return r.TRA, r.TRAMX + # crop specific correction on potential transpiration rate et0_crop = torch.clamp(p.CFET * et0, min=0.0) kglob = 0.75 * p.KDIFTB(dvs) ekl = torch.exp(-kglob * lai) From 147f9a8e1c16e86d792c256e8796c272f20c6306 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:50:03 +0100 Subject: [PATCH 34/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index fdb46ce5..32f09cdf 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -665,6 +665,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # crop specific correction on potential transpiration rate et0_crop = torch.clamp(p.CFET * et0, min=0.0) + # maximum evaporation and transpiration rates kglob = 0.75 * p.KDIFTB(dvs) ekl = torch.exp(-kglob * lai) r.EVWMX = e0 * ekl From 16f1acf67a5acae604090a5dbc09144ff8aea8e1 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:51:08 +0100 Subject: [PATCH 35/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 32f09cdf..495129ce 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -672,6 +672,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.EVSMX = torch.clamp(es0 * ekl, min=0.0) r.TRAMX = et0_crop * (1.0 - ekl) * rf_tramx_co2 + # Critical soil moisture swdep = SWEAF(et0_crop, p.DEPNR) # Layered soil moisture can be provided as: From 65325a58e5f87eed958a402509e88c31fa3b4663 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:51:26 +0100 Subject: [PATCH 36/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 495129ce..a5049a11 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -733,6 +733,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None denom = torch.where( (smcr - layer_smw).abs() > self._epsilon, smcr - layer_smw, self._epsilon ) + # Reduction factor for transpiration in case of water shortage (RFWS) r.RFWS = torch.clamp((sm_layers_t - layer_smw) / denom, min=0.0, max=1.0) # Vectorised root fraction across all layers: (n_layers, *params_shape) From c07977ab55f00dd24c87fb5fe034a2e43a76d072 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:52:02 +0100 Subject: [PATCH 37/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index a5049a11..6620976f 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -742,6 +742,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # Oxygen-stress reduction factor (sequential across layers due to # temporal _dsos accumulator that feeds forward between layers). + # reduction in transpiration in case of oxygen shortage (RFOS) + # for non-rice crops, and possibly deficient land drainage r.RFOS = torch.ones_like(r.RFWS) mask_ox = (p.IAIRDU == 0) & (p.IOX == 1) if bool(torch.any(mask_ox)): From 3b6393c2a1c33d31fedc97850864d0b2429e7c0b Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Mon, 23 Feb 2026 12:52:32 +0100 Subject: [PATCH 38/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 6620976f..06e31ef4 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -780,6 +780,7 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.RFOS = torch.where(pre_layers, 1.0, r.RFOS) r.TRALY = torch.where(pre_layers, 0.0, r.TRALY) + # Counting stress days r.IDWS = bool(torch.any(r.RFWS < 1.0)) r.IDOS = bool(torch.any(r.RFOS < 1.0)) return r.TRA, r.TRAMX From 5eaf23746c4569cccbdb65eb1579dd7bcaef7099 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 23 Feb 2026 12:55:14 +0100 Subject: [PATCH 39/45] Fix --- .../crop/evapotranspiration.py | 62 +++++++------------ 1 file changed, 21 insertions(+), 41 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 06e31ef4..5359bdbe 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -471,13 +471,11 @@ def initialize( def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: """Calculate CO2 reduction factor for TRAMX based on atmospheric CO2 concentration.""" - p = self.params - if hasattr(drv, "CO2") and drv.CO2 is not None: co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) else: - co2 = p.CO2 - return p.CO2TRATB(co2) + co2 = self.params.CO2 + return self.params.CO2TRATB(co2) class EvapotranspirationCO2Layered(_BaseEvapotranspiration): @@ -570,6 +568,7 @@ def initialize( """Initialize the layered-soil CO2-aware evapotranspiration module. Sets up layer-specific soil parameters and internal oxygen stress tracking. + Args: day (datetime.date): The starting date of the simulation. kiosk (VariableKiosk): A container for registering and publishing @@ -578,7 +577,7 @@ def initialize( parvalues (ParameterProvider): A dictionary-like container holding all parameter sets (crop, soil, site) as key/value. The values are arrays or scalars. See PCSE documentation for details. - shape (tuple | torch.Size | None): Target shape for the state and rate variables. + shape (tuple | torch.Size | None): Target shape for the states and rates. """ self.soil_profile = parvalues["soil_profile"] self._initialize_base( @@ -615,12 +614,11 @@ def initialize( def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: """Calculate CO2 reduction factor for TRAMX using CO2 from driver or parameters.""" - p = self.params if hasattr(drv, "CO2") and drv.CO2 is not None: co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) else: - co2 = p.CO2 - return p.CO2TRATB(co2) + co2 = self.params.CO2 + return self.params.CO2TRATB(co2) @prepare_rates def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): @@ -637,7 +635,6 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None lai = k["LAI"] rd = k["RD"] - pre_emergence = dvs < 0.0 n_layers = self._n_layers et0 = _get_drv(drv.ET0, self.params_shape, dtype=self.dtype, device=self.device) @@ -647,30 +644,20 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # reduction factor for CO2 on TRAMX rf_tramx_co2 = self._rf_tramx_co2(drv, et0) - if bool(torch.all(pre_emergence)): - _z = torch.zeros_like(et0) - _o = torch.ones_like(et0) - _layered_shape = (n_layers,) + self.params_shape - r.EVWMX = _z - r.EVSMX = _z - r.TRAMX = _z - r.TRA = _z - r.TRALY = torch.zeros(_layered_shape, dtype=self.dtype, device=self.device) - r.RFWS = torch.ones(_layered_shape, dtype=self.dtype, device=self.device) - r.RFOS = torch.ones(_layered_shape, dtype=self.dtype, device=self.device) - r.RFTRA = _o - r.IDWS = False - r.IDOS = False - return r.TRA, r.TRAMX + # If DVS < 0, the crop has not yet emerged, so we zero the rates using a mask + # A mask (1 if DVS >= 0, 0 if DVS < 0) + dvs_mask = (dvs >= 0.0).to(dtype=self.dtype) + # Layered mask: (n_layers, *params_shape) + dvs_mask_layers = dvs_mask.unsqueeze(0).expand(n_layers, *self.params_shape) # crop specific correction on potential transpiration rate et0_crop = torch.clamp(p.CFET * et0, min=0.0) # maximum evaporation and transpiration rates kglob = 0.75 * p.KDIFTB(dvs) ekl = torch.exp(-kglob * lai) - r.EVWMX = e0 * ekl - r.EVSMX = torch.clamp(es0 * ekl, min=0.0) - r.TRAMX = et0_crop * (1.0 - ekl) * rf_tramx_co2 + r.EVWMX = dvs_mask * e0 * ekl + r.EVSMX = dvs_mask * torch.clamp(es0 * ekl, min=0.0) + r.TRAMX = dvs_mask * et0_crop * (1.0 - ekl) * rf_tramx_co2 # Critical soil moisture swdep = SWEAF(et0_crop, p.DEPNR) @@ -734,7 +721,9 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None (smcr - layer_smw).abs() > self._epsilon, smcr - layer_smw, self._epsilon ) # Reduction factor for transpiration in case of water shortage (RFWS) - r.RFWS = torch.clamp((sm_layers_t - layer_smw) / denom, min=0.0, max=1.0) + r.RFWS = dvs_mask_layers * torch.clamp( + (sm_layers_t - layer_smw) / denom, min=0.0, max=1.0 + ) + (1.0 - dvs_mask_layers) # Vectorised root fraction across all layers: (n_layers, *params_shape) root_len = torch.clamp(torch.minimum(rd, depth_hi) - depth_lo, min=0.0) @@ -742,8 +731,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None # Oxygen-stress reduction factor (sequential across layers due to # temporal _dsos accumulator that feeds forward between layers). - # reduction in transpiration in case of oxygen shortage (RFOS) - # for non-rice crops, and possibly deficient land drainage + # reduction in transpiration in case of oxygen shortage (RFOS) + # for non-rice crops, and possibly deficient land drainage r.RFOS = torch.ones_like(r.RFWS) mask_ox = (p.IAIRDU == 0) & (p.IOX == 1) if bool(torch.any(mask_ox)): @@ -768,17 +757,8 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None r.TRA = r.TRALY.sum(dim=0) r.RFTRA = torch.where(r.TRAMX > self._epsilon, r.TRA / r.TRAMX, 1.0) - if bool(torch.any(pre_emergence)): - r.EVWMX = torch.where(pre_emergence, 0.0, r.EVWMX) - r.EVSMX = torch.where(pre_emergence, 0.0, r.EVSMX) - r.TRAMX = torch.where(pre_emergence, 0.0, r.TRAMX) - r.TRA = torch.where(pre_emergence, 0.0, r.TRA) - r.RFTRA = torch.where(pre_emergence, 1.0, r.RFTRA) - - pre_layers = pre_emergence.unsqueeze(0).expand_as(r.RFWS) - r.RFWS = torch.where(pre_layers, 1.0, r.RFWS) - r.RFOS = torch.where(pre_layers, 1.0, r.RFOS) - r.TRALY = torch.where(pre_layers, 0.0, r.TRALY) + # Pre-emergence: RFOS = 1.0 + r.RFOS = dvs_mask_layers * r.RFOS + (1.0 - dvs_mask_layers) # Counting stress days r.IDWS = bool(torch.any(r.RFWS < 1.0)) From 3aa7c4c49d8625b898496ea819fde88de6a7f911 Mon Sep 17 00:00:00 2001 From: SCiarella Date: Mon, 23 Feb 2026 13:16:47 +0100 Subject: [PATCH 40/45] Add missing tests --- .../crop/evapotranspiration.py | 22 +++++++++++++++++ .../crop/test_evapotranspiration.py | 24 ++++++++++++++----- 2 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 5359bdbe..32fd217a 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -339,6 +339,28 @@ class Evapotranspiration(_BaseEvapotranspirationNonLayered): | DVS | Crop development stage | Phenology | - | | LAI | Leaf area index | Leaf dynamics | - | | SM | Volumetric soil moisture content | Waterbalance | - | + + + **Outputs** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy | Y | cm day⁻¹ | + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|----------------------------------------------------| + | EVWMX | KDIFTB | + | EVSMX | KDIFTB | + | TRAMX | CFET, KDIFTB | + | TRA | CFET, KDIFTB, DEPNR, SMFCF, SMW, CRAIRC, SM0 | + | RFTRA | CFET, DEPNR, SMFCF, SMW, CRAIRC, SM0 | + """ class Parameters(TensorParamTemplate): diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py index 56111451..a49c75ac 100644 --- a/tests/physical_models/crop/test_evapotranspiration.py +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -21,7 +21,7 @@ evapotranspiration_config = Configuration( CROP=EvapotranspirationWrapper, - OUTPUT_VARS=["EVSMX", "EVWMX", "TRAMX", "TRA"], + OUTPUT_VARS=["EVSMX", "EVWMX", "TRAMX", "TRA", "RFTRA"], ) @@ -141,7 +141,7 @@ def forward(self, params_dict: dict[str, torch.Tensor]): results = engine.get_output() return { var: torch.stack([item[var] for item in results]) - for var in ["EVSMX", "EVWMX", "TRAMX", "TRA"] + for var in ["EVSMX", "EVWMX", "TRAMX", "TRA", "RFTRA"] } @@ -681,14 +681,18 @@ def _kiosk_with_states(): class TestDiffEvapotranspirationGradients: - param_names = ["CFET", "DEPNR", "KDIFTB"] - output_names = ["EVWMX", "EVSMX", "TRAMX", "TRA"] + param_names = ["CFET", "DEPNR", "KDIFTB", "SMW", "SMFCF", "SM0", "CRAIRC"] + output_names = ["EVWMX", "EVSMX", "TRAMX", "TRA", "RFTRA"] param_configs = { "single": { "CFET": (1.0, torch.float64), "DEPNR": (2.0, torch.float64), "KDIFTB": ([[0.0, 0.69, 2.0, 0.69]], torch.float64), + "SMW": (0.15, torch.float64), + "SMFCF": (0.29, torch.float64), + "SM0": (0.40, torch.float64), + "CRAIRC": (0.06, torch.float64), }, "tensor": { "CFET": ([0.8, 1.0, 1.2], torch.float64), @@ -697,13 +701,21 @@ class TestDiffEvapotranspirationGradients: [[0.0, 0.60, 2.0, 0.60], [0.0, 0.69, 2.0, 0.69], [0.0, 0.78, 2.0, 0.78]], torch.float64, ), + "SMW": ([0.12, 0.15, 0.18], torch.float64), + "SMFCF": ([0.26, 0.29, 0.32], torch.float64), + "SM0": ([0.37, 0.40, 0.43], torch.float64), + "CRAIRC": ([0.04, 0.06, 0.08], torch.float64), }, } gradient_mapping = { - "CFET": ["TRAMX", "TRA"], - "DEPNR": ["TRA"], + "CFET": ["TRAMX", "TRA", "RFTRA"], + "DEPNR": ["TRA", "RFTRA"], "KDIFTB": ["EVWMX", "EVSMX", "TRAMX", "TRA"], + "SMW": ["TRA", "RFTRA"], + "SMFCF": ["TRA", "RFTRA"], + "SM0": ["TRA", "RFTRA"], + "CRAIRC": ["TRA", "RFTRA"], } gradient_params = [] From cbc337a652e9dc92e6a2f09dac9c984f9f686b76 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:24:53 +0100 Subject: [PATCH 41/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 32fd217a..fa9933d3 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -519,6 +519,14 @@ class EvapotranspirationCO2Layered(_BaseEvapotranspiration): Layer-specific soil parameters (SMW, SMFCF, SM0, CRAIRC, Thickness) are taken from `soil_profile` entries. + **Soil parameters** + + | Name | Description | Type | Unit | + |----------|--------------------------------------------------------|------|------| + | SMW | Volumetric soil moisture content at wilting point | S | - | + | SMFCF | Volumetric soil moisture content at field capacity | S | - | + | SM0 | Soil porosity | S | - | + | CRAIRC | Critical air content for root aeration | S | - | **State variables** From ae0bf1ba35a7d5270f245dc335bf5413d7d3a5e2 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:25:06 +0100 Subject: [PATCH 42/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../physical_models/crop/evapotranspiration.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index fa9933d3..f4227a63 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -509,13 +509,13 @@ class EvapotranspirationCO2Layered(_BaseEvapotranspiration): | Name | Description | Type | Unit | |----------|--------------------------------------------------------|------|------| - | CFET | Correction factor for potential transpiration rate | SCr | - | - | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | SCr | - | - | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | - | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | - | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | - | CO2 | Atmospheric CO2 concentration (used if not in drivers) | SCr | ppm | - | CO2TRATB | Reduction factor for TRAMX as function of CO2 | TCr | - | + | CFET | Correction factor for potential transpiration rate | S | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | S | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | T | - | + | IAIRDU | Switch airducts on (1) or off (0) | S | - | + | IOX | Switch oxygen stress on (1) or off (0) | S | - | + | CO2 | Atmospheric CO2 concentration (used if not in drivers) | S | ppm | + | CO2TRATB | Reduction factor for TRAMX as function of CO2 | T | - | Layer-specific soil parameters (SMW, SMFCF, SM0, CRAIRC, Thickness) are taken from `soil_profile` entries. From b627a278c6da984b609e02e45de3fa69c8ef1488 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:25:16 +0100 Subject: [PATCH 43/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index f4227a63..9d89b127 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -546,8 +546,8 @@ class EvapotranspirationCO2Layered(_BaseEvapotranspiration): | IDOS | Indicates oxygen stress on this day (True|False) | N | - | | IDWS | Indicates water stress on this day (True|False) | N | - | | TRALY | Transpiration per soil layer | Y | cm day⁻¹ | - | RFWS | Water-stress reduction per layer | N | - | - | RFOS | Oxygen-stress reduction per layer | N | - | + | RFWS | Water-stress reduction per layer | Y | - | + | RFOS | Oxygen-stress reduction per layer | Y | - | | RFTRA | Combined reduction factor for transpiration | Y | - | **External dependencies** From 0c88073da9d1093e4be56014f8c97069dd00f44b Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:25:28 +0100 Subject: [PATCH 44/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- .../crop/evapotranspiration.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 9d89b127..161fb5fb 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -409,17 +409,17 @@ class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): | Name | Description | Type | Unit | |----------|--------------------------------------------------------|------|------| - | CFET | Correction factor for potential transpiration rate | SCr | - | - | DEPNR | Dependency number for crop sensitivity to soil moisture stress | SCr | - | - | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | - | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | - | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | - | CRAIRC | Critical air content for root aeration | SSo | - | - | SM0 | Soil porosity | SSo | - | - | SMW | Volumetric soil moisture at wilting point | SSo | - | - | SMFCF | Volumetric soil moisture at field capacity | SSo | - | - | CO2 | Atmospheric CO2 concentration (used if not in drivers) | SCr | ppm | - | CO2TRATB | Reduction factor for TRAMX as function of CO2 | TCr | - | + | CFET | Correction factor for potential transpiration rate | S | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress | S | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | T | - | + | IAIRDU | Switch airducts on (1) or off (0) | S | - | + | IOX | Switch oxygen stress on (1) or off (0) | S | - | + | CRAIRC | Critical air content for root aeration | S | - | + | SM0 | Soil porosity | S | - | + | SMW | Volumetric soil moisture at wilting point | S | - | + | SMFCF | Volumetric soil moisture at field capacity | S | - | + | CO2 | Atmospheric CO2 concentration (used if not in drivers) | S | ppm | + | CO2TRATB | Reduction factor for TRAMX as function of CO2 | T | - | **State variables** From 99e148cbbac1491058eed44a2e0ece39e267e314 Mon Sep 17 00:00:00 2001 From: SCiarella <58949181+SCiarella@users.noreply.github.com> Date: Tue, 24 Feb 2026 09:25:43 +0100 Subject: [PATCH 45/45] Update src/diffwofost/physical_models/crop/evapotranspiration.py Co-authored-by: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> --- src/diffwofost/physical_models/crop/evapotranspiration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py index 161fb5fb..812cea7b 100644 --- a/src/diffwofost/physical_models/crop/evapotranspiration.py +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -438,8 +438,8 @@ class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | | IDOS | Indicates oxygen stress on this day (True|False) | N | - | | IDWS | Indicates water stress on this day (True|False) | N | - | - | RFWS | Reduction factor for water stress | N | - | - | RFOS | Reduction factor for oxygen stress | N | - | + | RFWS | Reduction factor for water stress | Y | - | + | RFOS | Reduction factor for oxygen stress | Y | - | | RFTRA | Combined reduction factor for transpiration | Y | - | **External dependencies**