diff --git a/docs/api_reference.md b/docs/api_reference.md index 2b3460d..b67bb58 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -19,6 +19,8 @@ hide: ::: diffwofost.physical_models.crop.respiration.WOFOST_Maintenance_Respiration +::: diffwofost.physical_models.crop.evapotranspiration.Evapotranspiration + ## **Utility (under development)** ::: diffwofost.physical_models.config.Configuration diff --git a/src/diffwofost/physical_models/crop/evapotranspiration.py b/src/diffwofost/physical_models/crop/evapotranspiration.py new file mode 100644 index 0000000..812cea7 --- /dev/null +++ b/src/diffwofost/physical_models/crop/evapotranspiration.py @@ -0,0 +1,808 @@ +import datetime +import torch +from pcse.base import SimulationObject +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.base.weather import WeatherDataContainer +from pcse.decorators import prepare_rates +from pcse.decorators import prepare_states +from pcse.traitlets import Any +from pcse.traitlets import Bool +from pcse.traitlets import Instance +from diffwofost.physical_models.base import TensorParamTemplate +from diffwofost.physical_models.base import TensorRatesTemplate +from diffwofost.physical_models.base import TensorStatesTemplate +from diffwofost.physical_models.config import ComputeConfig +from diffwofost.physical_models.traitlets import Tensor +from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_drv + + +def SWEAF(ET0: torch.Tensor, DEPNR: torch.Tensor) -> torch.Tensor: + """Soil Water Easily Available Fraction (SWEAF). + + The fraction of easily available soil water between field capacity and + wilting point is a function of the potential evapotranspiration rate (for a + closed canopy) in cm/day, ET0, and the crop group number, DEPNR (from 1 + (=drought-sensitive) to 5 (=drought-resistent)). The function SWEAF + describes this relationship given in tabular form by Doorenbos & Kassam + (1979) and by Van Keulen & Wolf (1986; p.108, table 20) + http://edepot.wur.nl/168025. + + Args: + ET0: The evapotranpiration from a reference crop. + DEPNR: The crop dependency number. + + Returns: + SWEAF value between 0.10 and 0.95. + """ + A = 0.76 + B = 1.5 + sweaf = 1.0 / (A + B * ET0) - (5.0 - DEPNR) * 0.10 + correction = (ET0 - 0.6) / (DEPNR * (DEPNR + 3.0)) + # NOTE: PCSE applies `correction` only when `DEPNR < 3` (hard switch), which + # is non-differentiable at `DEPNR==3`. + # + # We use a Straight-Through Estimator (STE): the forward pass applies the + # exact hard switch (identical to PCSE), while the backward pass routes + # gradients through a smooth sigmoid surrogate so that autograd works. + # TODO: sharpness can be exposed as a parameter + _sigmoid_sharpness = 1000.0 + _sigmoid_epsilon = 1e-14 + + # soft mask using sigmoid + soft_mask = torch.sigmoid((3.0 - DEPNR - _sigmoid_epsilon) / _sigmoid_sharpness) + + # original hard mask + hard_mask = (DEPNR < 3.0).to(dtype=DEPNR.dtype) + + # STE method: during forward pass the hard_mask is used, during + # backpropagation the gradient is computed only through soft_mask. + correction_mask = hard_mask.detach() + soft_mask - soft_mask.detach() + sweaf = sweaf + correction * correction_mask + return torch.clamp(sweaf, min=0.10, max=0.95) + + +class EvapotranspirationWrapper(SimulationObject): + """Selects the evapotranspiration implementation. + + Selection logic: + - If `soil_profile` is present in parameters: use the layered CO2-aware module. + - Else if `CO2TRATB` is present: use the non-layered CO2 module. + - Else: use the non-layered (no CO2) module. + """ + + etmodule = Instance(SimulationObject) + + def initialize( + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, + ) -> None: + """Select and initialize the evapotranspiration implementation. + + Chooses between layered CO2-aware, non-layered CO2, or standard evapotranspiration + based on available parameters. + + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. + """ + if "soil_profile" in parvalues: + self.etmodule = EvapotranspirationCO2Layered(day, kiosk, parvalues, shape=shape) + elif "CO2" in parvalues and "CO2TRATB" in parvalues: + self.etmodule = EvapotranspirationCO2(day, kiosk, parvalues, shape=shape) + else: + self.etmodule = Evapotranspiration(day, kiosk, parvalues, shape=shape) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Delegate rate calculation to the selected evapotranspiration module. + + Args: + day (datetime.date, optional): The current date of the simulation. + drv (WeatherDataContainer, optional): A dictionary-like container holding + weather data elements as key/value. The values are + arrays or scalars. See PCSE documentation for details. + """ + return self.etmodule.calc_rates(day, drv) + + def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Callable interface for rate calculation.""" + return self.calc_rates(day, drv) + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Delegate state integration to the selected evapotranspiration module. + + Args: + day (datetime.date, optional): The current date of the simulation. + delt (float, optional): The time step for integration. Defaults to 1.0. + """ + return self.etmodule.integrate(day, delt) + + @prepare_states + def finalize(self, day: datetime.date) -> None: + """Delegate finalization to the selected evapotranspiration module.""" + self.etmodule.finalize(day) + + +class _BaseEvapotranspiration(SimulationObject): + """Shared base class for evapotranspiration implementations.""" + + params_shape = None + + @property + def device(self): + """Get device from ComputeConfig.""" + return getattr(self, "_device", ComputeConfig.get_device()) + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return getattr(self, "_dtype", ComputeConfig.get_dtype()) + + class RateVariables(TensorRatesTemplate): + EVWMX = Tensor(0.0) + EVSMX = Tensor(0.0) + TRAMX = Tensor(0.0) + TRA = Tensor(0.0) + IDOS = Bool(False) + IDWS = Bool(False) + RFWS = Tensor(0.0) + RFOS = Tensor(0.0) + RFTRA = Tensor(0.0) + + class StateVariables(TensorStatesTemplate): + IDOST = Tensor(-99.0) + IDWST = Tensor(-99.0) + + def _initialize_base( + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + *, + publish_rates: list[str], + shape: tuple | None = None, + ) -> None: + """Shared initialization for evapotranspiration modules. + + Sets up parameters, rate and state variables, and numerical epsilon for all + evapotranspiration implementations. + """ + self.kiosk = kiosk + self.params = self.Parameters(parvalues) + if shape is None: + shape = self.params.shape + self.params_shape = shape + self._device = ComputeConfig.get_device() + self._dtype = ComputeConfig.get_dtype() + self.rates = self.RateVariables(kiosk, publish=publish_rates, shape=shape) + self.states = self.StateVariables(kiosk, shape=shape, IDOST=-999, IDWST=-999) + self._epsilon = torch.tensor(1e-12, dtype=self.dtype, device=self.device) + + # Private accumulators for stress-day counters (written to states in finalize) + self._IDWST = torch.zeros(shape, dtype=self.dtype, device=self.device) + self._IDOST = torch.zeros(shape, dtype=self.dtype, device=self.device) + + def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Callable interface for rate calculation.""" + return self.calc_rates(day, drv) + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Accumulate stress-day counters for water and oxygen stress.""" + rfws_stress = (self.rates.RFWS < 1.0).to(dtype=self.dtype) + rfos_stress = (self.rates.RFOS < 1.0).to(dtype=self.dtype) + self._IDWST = self._IDWST + rfws_stress + self._IDOST = self._IDOST + rfos_stress + + @prepare_states + def finalize(self, day: datetime.date) -> None: + """Finalize the evapotranspiration simulation.""" + self.states.IDWST = self._IDWST + self.states.IDOST = self._IDOST + SimulationObject.finalize(self, day) + + +class _BaseEvapotranspirationNonLayered(_BaseEvapotranspiration): + """Shared implementation for non-layered evapotranspiration.""" + + def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + """Return CO2 reduction factor for TRAMX (no CO2 effect in base implementation).""" + return torch.ones_like(et0) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + p = self.params + r = self.rates + k = self.kiosk + + lai = k["LAI"] + sm = k["SM"] + # [!] DVS needs to be broadcasted explicetly because it is used + # in torch.where and the kiosk does not format it correctly + # TODO see #22 + dvs = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) + + et0 = _get_drv(drv.ET0, self.params_shape, dtype=self.dtype, device=self.device) + e0 = _get_drv(drv.E0, self.params_shape, dtype=self.dtype, device=self.device) + es0 = _get_drv(drv.ES0, self.params_shape, dtype=self.dtype, device=self.device) + rf_tramx_co2 = self._rf_tramx_co2(drv, et0) + + # If DVS < 0, the crop has not yet emerged, so we zero the rates using a mask + # A mask (1 if DVS >= 0, 0 if DVS < 0) + dvs_mask = (dvs >= 0.0).to(dtype=self.dtype) + + kglob = 0.75 * p.KDIFTB(dvs) + # crop specific correction on potential transpiration rate + et0_crop = torch.clamp(p.CFET * et0, min=0.0) + # maximum evaporation and transpiration rates + ekl = torch.exp(-kglob * lai) + + r.EVWMX = dvs_mask * e0 * ekl + r.EVSMX = dvs_mask * torch.clamp(es0 * ekl, min=0.0) + r.TRAMX = dvs_mask * et0_crop * (1.0 - ekl) * rf_tramx_co2 + + # Critical soil moisture + swdep = SWEAF(et0_crop, p.DEPNR) + smcr = (1.0 - swdep) * (p.SMFCF - p.SMW) + p.SMW + + # Reduction factor for transpiration in case of water shortage (RFWS) + denom = torch.where((smcr - p.SMW).abs() > self._epsilon, (smcr - p.SMW), self._epsilon) + r.RFWS = dvs_mask * torch.clamp((sm - p.SMW) / denom, min=0.0, max=1.0) + (1.0 - dvs_mask) + + # reduction in transpiration in case of oxygen shortage (RFOS) + # for non-rice crops, and possibly deficient land drainage + r.RFOS = torch.ones_like(et0) + iairdu = p.IAIRDU + iox = p.IOX + mask_ox = (iairdu == 0) & (iox == 1) + + if "DSOS" in k: + dsos = k["DSOS"] + else: + dsos = torch.zeros_like(dvs) + + crairc = p.CRAIRC + sm0 = p.SM0 + denom_ox = torch.where(crairc.abs() > self._epsilon, crairc, self._epsilon) + rfosmx = torch.clamp((sm0 - sm) / denom_ox, min=0.0, max=1.0) + # maximum reduction reached after 4 days + rfos = rfosmx + (1.0 - torch.clamp(dsos, max=4.0) / 4.0) * (1.0 - rfosmx) + r.RFOS = torch.where(mask_ox, rfos, r.RFOS) + # Pre-emergence: RFOS = 1.0 + r.RFOS = dvs_mask * r.RFOS + (1.0 - dvs_mask) + + # Transpiration rate multiplied with reduction factors for oxygen and water + r.RFTRA = r.RFOS * r.RFWS + r.TRA = r.TRAMX * r.RFTRA + + # Counting stress days + r.IDWS = bool(torch.any(r.RFWS < 1.0)) + r.IDOS = bool(torch.any(r.RFOS < 1.0)) + return r.TRA, r.TRAMX + + +class Evapotranspiration(_BaseEvapotranspirationNonLayered): + """Potential evaporation (water and soil) rates and crop transpiration rate. + + **Simulation parameters** + + | Name | Description | Type | Unit | + |--------|---------------------------------------------------------|------|------| + | CFET | Correction factor for potential transpiration rate | SCr | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | SCr | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | TCr | - | + | IAIRDU | Switch airducts on (1) or off (0) | SCr | - | + | IOX | Switch oxygen stress on (1) or off (0) | SCr | - | + | CRAIRC | Critical air content for root aeration | SSo | - | + | SM0 | Soil porosity | SSo | - | + | SMW | Volumetric soil moisture at wilting point | SSo | - | + | SMFCF | Volumetric soil moisture at field capacity | SSo | - | + + **State variables** + + | Name | Description | Pbl | Unit | + |-------|------------------------------------|-----|------| + | IDWST | Number of days with water stress | N | - | + | IDOST | Number of days with oxygen stress | N | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy | Y | cm day⁻¹ | + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | IDOS | Indicates oxygen stress on this day (True|False) | N | - | + | IDWS | Indicates water stress on this day (True|False) | N | - | + | RFWS | Reduction factor for water stress | N | - | + | RFOS | Reduction factor for oxygen stress | N | - | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|----------------------------------|---------------|------| + | DVS | Crop development stage | Phenology | - | + | LAI | Leaf area index | Leaf dynamics | - | + | SM | Volumetric soil moisture content | Waterbalance | - | + + + **Outputs** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy | Y | cm day⁻¹ | + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|----------------------------------------------------| + | EVWMX | KDIFTB | + | EVSMX | KDIFTB | + | TRAMX | CFET, KDIFTB | + | TRA | CFET, KDIFTB, DEPNR, SMFCF, SMW, CRAIRC, SM0 | + | RFTRA | CFET, DEPNR, SMFCF, SMW, CRAIRC, SM0 | + + """ + + class Parameters(TensorParamTemplate): + CFET = Tensor(-99.0) + DEPNR = Tensor(-99.0) + KDIFTB = AfgenTrait() + IAIRDU = Tensor(-99.0) + IOX = Tensor(-99.0) + CRAIRC = Tensor(-99.0) + SM0 = Tensor(-99.0) + SMW = Tensor(-99.0) + SMFCF = Tensor(-99.0) + + def initialize( + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, + ) -> None: + """Initialize the standard evapotranspiration module (no CO2 effects). + + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. + """ + self._initialize_base( + day, + kiosk, + parvalues, + publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "RFTRA"], + shape=shape, + ) + + +class EvapotranspirationCO2(_BaseEvapotranspirationNonLayered): + """Potential evaporation and crop transpiration with CO2 effect on TRAMX. + + **Simulation parameters** + + | Name | Description | Type | Unit | + |----------|--------------------------------------------------------|------|------| + | CFET | Correction factor for potential transpiration rate | S | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress | S | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | T | - | + | IAIRDU | Switch airducts on (1) or off (0) | S | - | + | IOX | Switch oxygen stress on (1) or off (0) | S | - | + | CRAIRC | Critical air content for root aeration | S | - | + | SM0 | Soil porosity | S | - | + | SMW | Volumetric soil moisture at wilting point | S | - | + | SMFCF | Volumetric soil moisture at field capacity | S | - | + | CO2 | Atmospheric CO2 concentration (used if not in drivers) | S | ppm | + | CO2TRATB | Reduction factor for TRAMX as function of CO2 | T | - | + + **State variables** + + | Name | Description | Pbl | Unit | + |-------|------------------------------------|-----|------| + | IDWST | Number of days with water stress | N | - | + | IDOST | Number of days with oxygen stress | N | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy (CO2-adjusted) | Y | cm day⁻¹ | + | TRA | Actual transpiration rate from canopy | Y | cm day⁻¹ | + | IDOS | Indicates oxygen stress on this day (True|False) | N | - | + | IDWS | Indicates water stress on this day (True|False) | N | - | + | RFWS | Reduction factor for water stress | Y | - | + | RFOS | Reduction factor for oxygen stress | Y | - | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|----------------------------------|---------------|------| + | DVS | Crop development stage | Phenology | - | + | LAI | Leaf area index | Leaf dynamics | - | + | SM | Volumetric soil moisture content | Waterbalance | - | + """ + + class Parameters(TensorParamTemplate): + CFET = Tensor(-99.0) + DEPNR = Tensor(-99.0) + KDIFTB = AfgenTrait() + IAIRDU = Tensor(-99.0) + IOX = Tensor(-99.0) + CRAIRC = Tensor(-99.0) + SM0 = Tensor(-99.0) + SMW = Tensor(-99.0) + SMFCF = Tensor(-99.0) + CO2 = Tensor(-99.0) + CO2TRATB = AfgenTrait() + + def initialize( + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, + ) -> None: + """Initialize the CO2-aware evapotranspiration module. + + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the state and rate variables. + """ + self._initialize_base( + day, + kiosk, + parvalues, + publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "RFTRA"], + shape=shape, + ) + + def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + """Calculate CO2 reduction factor for TRAMX based on atmospheric CO2 concentration.""" + if hasattr(drv, "CO2") and drv.CO2 is not None: + co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) + else: + co2 = self.params.CO2 + return self.params.CO2TRATB(co2) + + +class EvapotranspirationCO2Layered(_BaseEvapotranspiration): + """Layered-soil evapotranspiration with CO2 effect on TRAMX. + + This implementation expects a layered soil water balance. + + **Simulation parameters** + + | Name | Description | Type | Unit | + |----------|--------------------------------------------------------|------|------| + | CFET | Correction factor for potential transpiration rate | S | - | + | DEPNR | Dependency number for crop sensitivity to soil moisture stress. | S | - | + | KDIFTB | Extinction coefficient for diffuse visible light vs DVS | T | - | + | IAIRDU | Switch airducts on (1) or off (0) | S | - | + | IOX | Switch oxygen stress on (1) or off (0) | S | - | + | CO2 | Atmospheric CO2 concentration (used if not in drivers) | S | ppm | + | CO2TRATB | Reduction factor for TRAMX as function of CO2 | T | - | + + Layer-specific soil parameters (SMW, SMFCF, SM0, CRAIRC, Thickness) are + taken from `soil_profile` entries. + **Soil parameters** + + | Name | Description | Type | Unit | + |----------|--------------------------------------------------------|------|------| + | SMW | Volumetric soil moisture content at wilting point | S | - | + | SMFCF | Volumetric soil moisture content at field capacity | S | - | + | SM0 | Soil porosity | S | - | + | CRAIRC | Critical air content for root aeration | S | - | + + **State variables** + + | Name | Description | Pbl | Unit | + |-------|------------------------------------|-----|------| + | IDWST | Number of days with water stress | N | - | + | IDOST | Number of days with oxygen stress | N | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |-------|-----------------------------------------------------|-----|-----------| + | EVWMX | Max evaporation rate from open water surface | Y | cm day⁻¹ | + | EVSMX | Max evaporation rate from wet soil surface | Y | cm day⁻¹ | + | TRAMX | Max transpiration rate from canopy (CO2-adjusted) | Y | cm day⁻¹ | + | TRA | Actual canopy transpiration (sum over layers) | Y | cm day⁻¹ | + | IDOS | Indicates oxygen stress on this day (True|False) | N | - | + | IDWS | Indicates water stress on this day (True|False) | N | - | + | TRALY | Transpiration per soil layer | Y | cm day⁻¹ | + | RFWS | Water-stress reduction per layer | Y | - | + | RFOS | Oxygen-stress reduction per layer | Y | - | + | RFTRA | Combined reduction factor for transpiration | Y | - | + + **External dependencies** + + | Name | Description | Provided by | Unit | + |------|----------------------------------|---------------|------| + | DVS | Crop development stage | Phenology | - | + | LAI | Leaf area index | Leaf dynamics | - | + | RD | Rooting depth | Root dynamics | cm | + | SM | Soil moisture per layer | Waterbalance | - | + """ + + soil_profile = Any() + + class Parameters(TensorParamTemplate): + CFET = Tensor(-99.0) + DEPNR = Tensor(-99.0) + KDIFTB = AfgenTrait() + IAIRDU = Tensor(-99.0) + IOX = Tensor(-99.0) + CO2 = Tensor(-99.0) + CO2TRATB = AfgenTrait() + + class RateVariables(TensorRatesTemplate): + EVWMX = Tensor(0) + EVSMX = Tensor(0) + TRAMX = Tensor(0) + TRA = Tensor(0) + TRALY = Tensor(0) + IDOS = Bool(False) + IDWS = Bool(False) + RFWS = Tensor(0) + RFOS = Tensor(0) + RFTRALY = Tensor(0) + RFTRA = Tensor(0) + + class StateVariables(TensorStatesTemplate): + IDOST = Tensor(-99.0) + IDWST = Tensor(-99.0) + + def initialize( + self, + day: datetime.date, + kiosk: VariableKiosk, + parvalues: ParameterProvider, + shape: tuple | None = None, + ) -> None: + """Initialize the layered-soil CO2-aware evapotranspiration module. + + Sets up layer-specific soil parameters and internal oxygen stress tracking. + + Args: + day (datetime.date): The starting date of the simulation. + kiosk (VariableKiosk): A container for registering and publishing + (internal and external) state variables. See PCSE documentation for + details. + parvalues (ParameterProvider): A dictionary-like container holding + all parameter sets (crop, soil, site) as key/value. The values are + arrays or scalars. See PCSE documentation for details. + shape (tuple | torch.Size | None): Target shape for the states and rates. + """ + self.soil_profile = parvalues["soil_profile"] + self._initialize_base( + day, + kiosk, + parvalues, + publish_rates=["EVWMX", "EVSMX", "TRAMX", "TRA", "TRALY", "RFTRA"], + shape=shape, + ) + + # Pre-stack layer soil properties as tensors + n_layers = len(self.soil_profile) + self._n_layers = n_layers + self._layer_smw = torch.tensor( + [layer.SMW for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + self._layer_smfcf = torch.tensor( + [layer.SMFCF for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + self._layer_sm0 = torch.tensor( + [layer.SM0 for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + self._layer_crairc = torch.tensor( + [layer.CRAIRC for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + thicknesses = torch.tensor( + [layer.Thickness for layer in self.soil_profile], dtype=self.dtype, device=self.device + ) + self._layer_depth_hi = torch.cumsum(thicknesses, dim=0) + self._layer_depth_lo = self._layer_depth_hi - thicknesses + + # Internal DSOS tracker for layered oxygen-stress response + self._dsos = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + + def _rf_tramx_co2(self, drv: WeatherDataContainer, et0: torch.Tensor) -> torch.Tensor: + """Calculate CO2 reduction factor for TRAMX using CO2 from driver or parameters.""" + if hasattr(drv, "CO2") and drv.CO2 is not None: + co2 = _get_drv(drv.CO2, self.params_shape, dtype=self.dtype, device=self.device) + else: + co2 = self.params.CO2 + return self.params.CO2TRATB(co2) + + @prepare_rates + def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Calculate daily evapotranspiration rates per soil layer with CO2 effects. + + Computes transpiration and stress factors for each soil layer based on root + distribution and layer-specific soil moisture conditions. + """ + p = self.params + r = self.rates + k = self.kiosk + + dvs = k["DVS"] + lai = k["LAI"] + rd = k["RD"] + + n_layers = self._n_layers + + et0 = _get_drv(drv.ET0, self.params_shape, dtype=self.dtype, device=self.device) + e0 = _get_drv(drv.E0, self.params_shape, dtype=self.dtype, device=self.device) + es0 = _get_drv(drv.ES0, self.params_shape, dtype=self.dtype, device=self.device) + + # reduction factor for CO2 on TRAMX + rf_tramx_co2 = self._rf_tramx_co2(drv, et0) + + # If DVS < 0, the crop has not yet emerged, so we zero the rates using a mask + # A mask (1 if DVS >= 0, 0 if DVS < 0) + dvs_mask = (dvs >= 0.0).to(dtype=self.dtype) + # Layered mask: (n_layers, *params_shape) + dvs_mask_layers = dvs_mask.unsqueeze(0).expand(n_layers, *self.params_shape) + + # crop specific correction on potential transpiration rate + et0_crop = torch.clamp(p.CFET * et0, min=0.0) + # maximum evaporation and transpiration rates + kglob = 0.75 * p.KDIFTB(dvs) + ekl = torch.exp(-kglob * lai) + r.EVWMX = dvs_mask * e0 * ekl + r.EVSMX = dvs_mask * torch.clamp(es0 * ekl, min=0.0) + r.TRAMX = dvs_mask * et0_crop * (1.0 - ekl) * rf_tramx_co2 + + # Critical soil moisture + swdep = SWEAF(et0_crop, p.DEPNR) + + # Layered soil moisture can be provided as: + # - torch.Tensor with shape (n_layers, *params_shape) + # - list/tuple of length n_layers, each element scalar or tensor + sm_layers = k["SM"] + if isinstance(sm_layers, torch.Tensor): + sm_layers_t = sm_layers.to(dtype=self.dtype, device=self.device) + elif isinstance(sm_layers, (list, tuple)): + if len(sm_layers) != n_layers: + raise ValueError( + "Layered evapotranspiration expects SM with " + + f"{n_layers} layers, got {len(sm_layers)}." + ) + sm_layers_t = torch.stack( + [ + _broadcast_to(sm_i, self.params_shape, dtype=self.dtype, device=self.device) + for sm_i in sm_layers + ], + dim=0, + ) + else: + sm_layers_t = torch.as_tensor(sm_layers, dtype=self.dtype, device=self.device) + if sm_layers_t.dim() == 1: + # Interpret as per-layer scalars + if sm_layers_t.shape[0] != n_layers: + raise ValueError( + "Layered evapotranspiration expects SM with " + + f"{n_layers} layers, got {sm_layers_t.shape[0]}." + ) + sm_layers_t = torch.stack( + [ + _broadcast_to( + sm_layers_t[i], self.params_shape, dtype=self.dtype, device=self.device + ) + for i in range(n_layers) + ], + dim=0, + ) + + if sm_layers_t.shape[0] != n_layers: + raise ValueError( + "Layered evapotranspiration expects SM first dim to be " + + f"{n_layers}, got {sm_layers_t.shape[0]}." + ) + + # Reshape pre-stacked layer properties for broadcasting against + # (n_layers, *params_shape) tensors: (n_layers,) → (n_layers, 1, 1, ...) + ndim = len(self.params_shape) + expand = (-1,) + (1,) * ndim + layer_smw = self._layer_smw.view(expand) + layer_smfcf = self._layer_smfcf.view(expand) + depth_lo = self._layer_depth_lo.view(expand) + depth_hi = self._layer_depth_hi.view(expand) + + # Vectorised RFWS across all layers: (n_layers, *params_shape) + smcr = (1.0 - swdep) * (layer_smfcf - layer_smw) + layer_smw + denom = torch.where( + (smcr - layer_smw).abs() > self._epsilon, smcr - layer_smw, self._epsilon + ) + # Reduction factor for transpiration in case of water shortage (RFWS) + r.RFWS = dvs_mask_layers * torch.clamp( + (sm_layers_t - layer_smw) / denom, min=0.0, max=1.0 + ) + (1.0 - dvs_mask_layers) + + # Vectorised root fraction across all layers: (n_layers, *params_shape) + root_len = torch.clamp(torch.minimum(rd, depth_hi) - depth_lo, min=0.0) + root_fraction = torch.where(rd > self._epsilon, root_len / rd, 0.0) + + # Oxygen-stress reduction factor (sequential across layers due to + # temporal _dsos accumulator that feeds forward between layers). + # reduction in transpiration in case of oxygen shortage (RFOS) + # for non-rice crops, and possibly deficient land drainage + r.RFOS = torch.ones_like(r.RFWS) + mask_ox = (p.IAIRDU == 0) & (p.IOX == 1) + if bool(torch.any(mask_ox)): + layer_sm0 = self._layer_sm0.view(expand) + layer_crairc = self._layer_crairc.view(expand) + for i in range(n_layers): + smair = layer_sm0[i] - layer_crairc[i] + self._dsos = torch.where( + sm_layers_t[i] >= smair, + torch.clamp(self._dsos + 1.0, max=4.0), + 0.0, + ) + denom_ox = torch.where( + layer_crairc[i].abs() > self._epsilon, layer_crairc[i], self._epsilon + ) + rfosmx = torch.clamp((layer_sm0[i] - sm_layers_t[i]) / denom_ox, min=0.0, max=1.0) + r.RFOS[i] = rfosmx + (1.0 - torch.clamp(self._dsos, max=4.0) / 4.0) * (1.0 - rfosmx) + + # Transpiration per layer + rftra = r.RFOS * r.RFWS + r.TRALY = r.TRAMX * rftra * root_fraction + r.TRA = r.TRALY.sum(dim=0) + r.RFTRA = torch.where(r.TRAMX > self._epsilon, r.TRA / r.TRAMX, 1.0) + + # Pre-emergence: RFOS = 1.0 + r.RFOS = dvs_mask_layers * r.RFOS + (1.0 - dvs_mask_layers) + + # Counting stress days + r.IDWS = bool(torch.any(r.RFWS < 1.0)) + r.IDOS = bool(torch.any(r.RFOS < 1.0)) + return r.TRA, r.TRAMX + + def __call__(self, day: datetime.date = None, drv: WeatherDataContainer = None): + """Callable interface for rate calculation.""" + return self.calc_rates(day, drv) + + @prepare_states + def integrate(self, day: datetime.date = None, delt=1.0) -> None: + """Accumulate stress-day counters based on any layer experiencing stress.""" + rfws_stress = (self.rates.RFWS < 1.0).any(dim=0).to(dtype=self.dtype) + rfos_stress = (self.rates.RFOS < 1.0).any(dim=0).to(dtype=self.dtype) + self._IDWST = self._IDWST + rfws_stress + self._IDOST = self._IDOST + rfos_stress diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index 5f07619..f7d0558 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -211,6 +211,19 @@ def prepare_engine_input( weather_data_provider = WeatherDataProviderTestHelper( test_data["WeatherVariables"], meteo_range_checks=meteo_range_checks ) + + # The PCSE WeatherDataContainer stores required variables as Python floats. + # Some of our tests rely on weather inputs being torch.Tensors (e.g. to + # broadcast/batch weather variables). We only do this conversion when + # METEO_RANGE_CHECKS is disabled because the PCSE range checks assume + # scalar floats. + if not meteo_range_checks: + for (_, _), wdc in weather_data_provider.store.items(): + for varname in ("IRRAD", "TMIN", "TMAX", "VAP", "RAIN", "WIND", "E0", "ES0", "ET0"): + if hasattr(wdc, varname): + value = getattr(wdc, varname) + if not isinstance(value, torch.Tensor): + setattr(wdc, varname, torch.tensor(value, dtype=dtype, device=device)) crop_model_params_provider = ParameterProvider(cropdata=cropd) external_states = test_data.get("ExternalStates") or [] diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index 6fc73ec..3e4c394 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -14,6 +14,7 @@ "phenology", "partitioning", "assimilation", + "transpiration", "respiration", ] FILE_NAMES = [ diff --git a/tests/physical_models/crop/test_evapotranspiration.py b/tests/physical_models/crop/test_evapotranspiration.py new file mode 100644 index 0000000..a49c75a --- /dev/null +++ b/tests/physical_models/crop/test_evapotranspiration.py @@ -0,0 +1,798 @@ +import copy +import datetime +import warnings +from types import SimpleNamespace +from unittest.mock import patch +import pytest +import torch +from pcse.base.parameter_providers import ParameterProvider +from pcse.base.variablekiosk import VariableKiosk +from pcse.models import Wofost72_PP +from diffwofost.physical_models.config import Configuration +from diffwofost.physical_models.crop.evapotranspiration import Evapotranspiration +from diffwofost.physical_models.crop.evapotranspiration import EvapotranspirationCO2 +from diffwofost.physical_models.crop.evapotranspiration import EvapotranspirationCO2Layered +from diffwofost.physical_models.crop.evapotranspiration import EvapotranspirationWrapper +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +evapotranspiration_config = Configuration( + CROP=EvapotranspirationWrapper, + OUTPUT_VARS=["EVSMX", "EVWMX", "TRAMX", "TRA", "RFTRA"], +) + + +def _augment_params_for_variant(crop_model_params_provider, variant: str, device: str): + """Augment parameters to enable specific evapotranspiration variant. + + Args: + crop_model_params_provider: Base parameter provider + variant: One of 'base', 'co2', or 'layered' + device: The device to create tensors on + """ + if variant == "base": + # No augmentation needed + return + elif variant == "co2": + # Add CO2 parameters to enable EvapotranspirationCO2 + crop_model_params_provider.set_override( + "CO2", torch.tensor(360.0, dtype=torch.float64, device=device), check=False + ) + crop_model_params_provider.set_override( + "CO2TRATB", + torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64, device=device), + check=False, + ) + elif variant == "layered": + # Add CO2 and soil_profile to enable EvapotranspirationCO2Layered + crop_model_params_provider.set_override( + "CO2", torch.tensor(360.0, dtype=torch.float64, device=device), check=False + ) + crop_model_params_provider.set_override( + "CO2TRATB", + torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=torch.float64, device=device), + check=False, + ) + # Create a simple two-layer soil profile using existing soil parameters + smw = crop_model_params_provider["SMW"] + smfcf = crop_model_params_provider["SMFCF"] + sm0 = crop_model_params_provider["SM0"] + crairc = crop_model_params_provider["CRAIRC"] + + # Convert to Python scalars if they are tensors + smw_val = float(smw.item() if isinstance(smw, torch.Tensor) else smw) + smfcf_val = float(smfcf.item() if isinstance(smfcf, torch.Tensor) else smfcf) + sm0_val = float(sm0.item() if isinstance(sm0, torch.Tensor) else sm0) + crairc_val = float(crairc.item() if isinstance(crairc, torch.Tensor) else crairc) + + soil_profile = [ + SimpleNamespace( + SMW=smw_val, SMFCF=smfcf_val, SM0=sm0_val, CRAIRC=crairc_val, Thickness=10.0 + ), + SimpleNamespace( + SMW=smw_val, SMFCF=smfcf_val, SM0=sm0_val, CRAIRC=crairc_val, Thickness=20.0 + ), + ] + crop_model_params_provider.set_override("soil_profile", soil_profile, check=False) + + +def get_test_diff_evapotranspiration_model(device: str = "cpu"): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( + prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False, device=device) + ) + return DiffEvapotranspiration( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + copy.deepcopy(external_states), + device=device, + ) + + +class DiffEvapotranspiration(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config, + external_states, + device: str = "cpu", + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config = config + self.external_states = external_states + self.device = device + + def forward(self, params_dict: dict[str, torch.Tensor]): + for name, value in params_dict.items(): + if isinstance(value, torch.Tensor) and value.device.type != self.device: + value = value.to(self.device) + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config, + self.external_states, + ) + engine.run_till_terminate() + results = engine.get_output() + return { + var: torch.stack([item[var] for item in results]) + for var in ["EVSMX", "EVWMX", "TRAMX", "TRA", "RFTRA"] + } + + +class TestEvapotranspiration: + transpiration_data_urls = [ + f"{phy_data_folder}/test_transpiration_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + @pytest.mark.parametrize("test_data_url", transpiration_data_urls) + @pytest.mark.parametrize("variant", ["base", "co2", "layered"]) + def test_evapotranspiration_with_testengine(self, test_data_url, variant, device): + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + # Augment parameters based on variant to test different implementations + _augment_params_for_variant(crop_model_params_provider, variant, device) + + # For layered variant, also need to augment external_states with SM as a list and RD + if variant == "layered": + # Convert SM to a 2-layer list structure for each state dict + for state_dict in external_states: + if "SM" in state_dict: + sm_val = state_dict["SM"] + state_dict["SM"] = [sm_val, sm_val] + # Add RD (rooting depth) if not present - use a simple default of 30 cm + if "RD" not in state_dict: + state_dict["RD"] = torch.tensor(30.0, dtype=torch.float64, device=device) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + + # For layered and CO2 variants, we just verify they run without errors + # (to achieve coverage) but don't check exact values since they use different + # implementations that produce different results + if variant in ("co2", "layered"): + # Just verify we got results with the correct structure + for model in actual_results: + assert "day" in model + for var in expected_precision.keys(): + assert var in model + assert model[var].device.type == device, f"{var} should be on {device}" + else: + # For base variant, check exact values against reference + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = { + k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items() + } + assert all( + abs(reference[var] - model_cpu[var]) < precision + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize( + "param", + [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + "ET0", + ], + ) + def test_evapotranspiration_with_one_parameter_vector(self, param, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + if param == "ET0": + for (_, _), wdc in weather_data_provider.store.items(): + wdc.ET0 = torch.ones(10, dtype=torch.float64, device=wdc.ET0.device) * wdc.ET0 + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + return + + if param == "KDIFTB": + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + all(abs(reference[var] - model_cpu[var]) < precision) + for var, precision in expected_precision.items() + ) + + @pytest.mark.parametrize( + "param,delta", + [ + ("CFET", 0.1), + ("DEPNR", 1.0), + ("KDIFTB", 0.05), + ("SMW", 0.01), + ("SMFCF", 0.01), + ("SM0", 0.01), + ], + ) + def test_evapotranspiration_with_different_parameter_values(self, param, delta, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + test_value = crop_model_params_provider[param] + if param == "KDIFTB": + non_zeros_mask = test_value != 0 + param_vec = torch.stack([test_value + non_zeros_mask * delta, test_value]) + else: + param_vec = torch.tensor( + [test_value - delta, test_value + delta, test_value], device=device + ) + crop_model_params_provider.set_override(param, param_vec, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert all( + abs(reference[var] - model_cpu[var][-1]) < precision + for var, precision in expected_precision.items() + ) + + def test_evapotranspiration_with_multiple_parameter_vectors(self, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + for param in crop_model_params: + if param == "KDIFTB": + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + + def test_evapotranspiration_with_multiple_parameter_arrays(self, device): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) + + # Use an arbitrary batched shape and keep weather vars consistent. + batch_shape = (30, 5) + for param in ("CFET", "DEPNR", "KDIFTB"): + if param == "KDIFTB": + repeated = crop_model_params_provider[param].repeat(*batch_shape, 1) + else: + repeated = crop_model_params_provider[param].broadcast_to(batch_shape) + crop_model_params_provider.set_override(param, repeated, check=False) + + for (_, _), wdc in weather_data_provider.store.items(): + wdc.ET0 = torch.ones(batch_shape, dtype=torch.float64, device=wdc.ET0.device) * wdc.ET0 + wdc.E0 = torch.ones(batch_shape, dtype=torch.float64, device=wdc.E0.device) * wdc.E0 + wdc.ES0 = torch.ones(batch_shape, dtype=torch.float64, device=wdc.ES0.device) * wdc.ES0 + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + assert all( + torch.all(abs(reference[var] - model[var]) < precision) + for var, precision in expected_precision.items() + ) + assert all(model[var].shape == batch_shape for var in expected_precision.keys()) + + def test_evapotranspiration_with_incompatible_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + + crop_model_params_provider.set_override( + "CFET", crop_model_params_provider["CFET"].repeat(10), check=False + ) + crop_model_params_provider.set_override( + "DEPNR", crop_model_params_provider["DEPNR"].repeat(5), check=False + ) + + with pytest.raises(ValueError, match="Non-matching shapes found in parameter provider!"): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + + def test_evapotranspiration_with_incompatible_weather_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_transpiration_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + + crop_model_params_provider.set_override( + "CFET", crop_model_params_provider["CFET"].repeat(10), check=False + ) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.ET0 = torch.ones(5, dtype=torch.float64, device=wdc.ET0.device) * wdc.ET0 + + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + evapotranspiration_config, + external_states, + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) + def test_wofost_pp_with_evapotranspiration(self, test_data_url): + test_data = get_test_data(test_data_url) + crop_model_params = [ + "CFET", + "DEPNR", + "KDIFTB", + "IAIRDU", + "IOX", + "CRAIRC", + "SM0", + "SMW", + "SMFCF", + ] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + ) + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.Evapotranspiration", EvapotranspirationWrapper): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert reference["DAY"] == model["day"] + for var, precision in expected_precision.items(): + if abs(reference[var] - model[var]) >= precision: + print( + f"Mismatch for {var} on day {model['day']}: expected {reference[var]}," + + f" got {model[var]}, diff {abs(reference[var] - model[var])}" + + f", precision {precision}" + ) + assert all( + abs(reference[var] - model[var]) < precision + for var, precision in expected_precision.items() + ) + + +def _minimal_parvalues(device: str, *, include_co2: bool = False, include_layers: bool = False): + dtype = torch.float64 + pars: dict[str, object] = { + "CFET": torch.tensor(1.0, dtype=dtype, device=device), + "DEPNR": torch.tensor(2.0, dtype=dtype, device=device), + "KDIFTB": torch.tensor([0.0, 0.69, 2.0, 0.69], dtype=dtype, device=device), + "IAIRDU": torch.tensor(0.0, dtype=dtype, device=device), + "IOX": torch.tensor(0.0, dtype=dtype, device=device), + "CRAIRC": torch.tensor(0.06, dtype=dtype, device=device), + "SM0": torch.tensor(0.40, dtype=dtype, device=device), + "SMW": torch.tensor(0.15, dtype=dtype, device=device), + "SMFCF": torch.tensor(0.29, dtype=dtype, device=device), + } + + if include_co2: + pars.update( + { + "CO2": torch.tensor(700.0, dtype=dtype, device=device), + "CO2TRATB": torch.tensor([0.0, 1.0, 1000.0, 0.5], dtype=dtype, device=device), + } + ) + + if include_layers: + soil_profile = [ + SimpleNamespace(SMW=0.15, SMFCF=0.29, SM0=0.40, CRAIRC=0.06, Thickness=10.0), + SimpleNamespace(SMW=0.16, SMFCF=0.30, SM0=0.41, CRAIRC=0.06, Thickness=20.0), + ] + pars["soil_profile"] = soil_profile + + return ParameterProvider(cropdata=pars) + + +class TestEvapotranspirationVariants: + def test_wrapper_selects_base(self, device): + parvalues = _minimal_parvalues(device) + kiosk = VariableKiosk() + wrapper = EvapotranspirationWrapper(datetime.date(2000, 1, 1), kiosk, parvalues) + assert isinstance(wrapper.etmodule, Evapotranspiration) + + def test_wrapper_selects_co2(self, device): + parvalues = _minimal_parvalues(device, include_co2=True) + kiosk = VariableKiosk() + wrapper = EvapotranspirationWrapper(datetime.date(2000, 1, 1), kiosk, parvalues) + assert isinstance(wrapper.etmodule, EvapotranspirationCO2) + + def test_wrapper_selects_layered(self, device): + parvalues = _minimal_parvalues(device, include_co2=True, include_layers=True) + kiosk = VariableKiosk() + wrapper = EvapotranspirationWrapper(datetime.date(2000, 1, 1), kiosk, parvalues) + assert isinstance(wrapper.etmodule, EvapotranspirationCO2Layered) + + def test_co2_reduces_tramx(self, device): + def _kiosk_with_states(): + kiosk = VariableKiosk() + oid = 0 + for name in ("DVS", "LAI", "SM"): + kiosk.register_variable(oid, name, type="S", publish=True) + kiosk.set_variable(oid, "DVS", torch.tensor(1.0, dtype=torch.float64, device=device)) + kiosk.set_variable(oid, "LAI", torch.tensor(3.0, dtype=torch.float64, device=device)) + kiosk.set_variable(oid, "SM", torch.tensor(0.25, dtype=torch.float64, device=device)) + return kiosk + + drv = SimpleNamespace( + ET0=torch.tensor(0.5, dtype=torch.float64, device=device), + E0=torch.tensor(0.6, dtype=torch.float64, device=device), + ES0=torch.tensor(0.55, dtype=torch.float64, device=device), + CO2=torch.tensor(700.0, dtype=torch.float64, device=device), + ) + + p_base = _minimal_parvalues(device) + kiosk_base = _kiosk_with_states() + base = Evapotranspiration(datetime.date(2000, 1, 1), kiosk_base, p_base) + base.calc_rates(datetime.date(2000, 1, 2), drv) + tramx_base = base.rates.TRAMX + + p_co2 = _minimal_parvalues(device, include_co2=True) + kiosk_co2 = _kiosk_with_states() + co2 = EvapotranspirationCO2(datetime.date(2000, 1, 1), kiosk_co2, p_co2) + co2.calc_rates(datetime.date(2000, 1, 2), drv) + tramx_co2 = co2.rates.TRAMX + + assert torch.all(tramx_co2 <= tramx_base) + + +class TestDiffEvapotranspirationGradients: + param_names = ["CFET", "DEPNR", "KDIFTB", "SMW", "SMFCF", "SM0", "CRAIRC"] + output_names = ["EVWMX", "EVSMX", "TRAMX", "TRA", "RFTRA"] + + param_configs = { + "single": { + "CFET": (1.0, torch.float64), + "DEPNR": (2.0, torch.float64), + "KDIFTB": ([[0.0, 0.69, 2.0, 0.69]], torch.float64), + "SMW": (0.15, torch.float64), + "SMFCF": (0.29, torch.float64), + "SM0": (0.40, torch.float64), + "CRAIRC": (0.06, torch.float64), + }, + "tensor": { + "CFET": ([0.8, 1.0, 1.2], torch.float64), + "DEPNR": ([1.0, 2.0, 4.0], torch.float64), + "KDIFTB": ( + [[0.0, 0.60, 2.0, 0.60], [0.0, 0.69, 2.0, 0.69], [0.0, 0.78, 2.0, 0.78]], + torch.float64, + ), + "SMW": ([0.12, 0.15, 0.18], torch.float64), + "SMFCF": ([0.26, 0.29, 0.32], torch.float64), + "SM0": ([0.37, 0.40, 0.43], torch.float64), + "CRAIRC": ([0.04, 0.06, 0.08], torch.float64), + }, + } + + gradient_mapping = { + "CFET": ["TRAMX", "TRA", "RFTRA"], + "DEPNR": ["TRA", "RFTRA"], + "KDIFTB": ["EVWMX", "EVSMX", "TRAMX", "TRA"], + "SMW": ["TRA", "RFTRA"], + "SMFCF": ["TRA", "RFTRA"], + "SM0": ["TRA", "RFTRA"], + "CRAIRC": ["TRA", "RFTRA"], + } + + gradient_params = [] + no_gradient_params = [] + for param_name in param_names: + for output_name in output_names: + if output_name in gradient_mapping.get(param_name, []): + gradient_params.append((param_name, output_name)) + else: + no_gradient_params.append((param_name, output_name)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type, device): + model = get_test_diff_evapotranspiration_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + output = model({param_name: param}) + loss = output[output_name].sum() + + if not loss.requires_grad: + grads = None + else: + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): + model = get_test_diff_evapotranspiration_model(device=device) + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) + + output = model({param_name: param}) + loss = output[output_name].sum() + + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + assert grads is not None, f"Gradients for {param_name} should not be None" + + param.grad = None + loss.backward() + grad_backward = param.grad + + assert grad_backward is not None, f"Backward gradients for {param_name} should not be None" + assert torch.all(grad_backward == grads), "Forward and backward gradients should match" + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type, device): + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) + + numerical_grad = calculate_numerical_grad( + lambda: get_test_diff_evapotranspiration_model(device=device), + param_name, + param, + output_name, + ) + + model = get_test_diff_evapotranspiration_model(device=device) + output = model({param_name: param}) + loss = output[output_name].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + + torch.testing.assert_close( + numerical_grad.detach().cpu(), + grads.detach().cpu(), + rtol=1e-3, + atol=1e-3, + ) + + if torch.all(grads == 0): + warnings.warn( + f"Gradient for parameter '{param_name}'" + + f" w.r.t '{output_name}' is zero: {grads.data}", + UserWarning, + )