diff --git a/.gitignore b/.gitignore index e3c6c2e..1706fbf 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,7 @@ venv3 ENV/ env.bak/ venv.bak/ +.vscode/ # vim *.swp diff --git a/docs/api_reference.md b/docs/api_reference.md index ebb4a1b..b7256e4 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -6,15 +6,12 @@ hide: ## **Crop modules** -!!! note - At the moment only two modules of `leaf_dynamics` and `root_dynamics` are - differentiable w.r.t two parameters of `SPAN` and `TDWI`. But the package is under - continuous development. So make sure that you install the latest version. - ::: diffwofost.physical_models.crop.leaf_dynamics.WOFOST_Leaf_Dynamics ::: diffwofost.physical_models.crop.root_dynamics.WOFOST_Root_Dynamics +::: diffwofost.physical_models.crop.phenology.DVS_Phenology + ## **Utility (under development)** ::: diffwofost.physical_models.utils.EngineTestHelper diff --git a/src/diffwofost/__init__.py b/src/diffwofost/__init__.py index bd78a5f..598fa37 100644 --- a/src/diffwofost/__init__.py +++ b/src/diffwofost/__init__.py @@ -3,6 +3,7 @@ import logging from diffwofost.physical_models import utils from diffwofost.physical_models.crop import leaf_dynamics +from diffwofost.physical_models.crop import phenology from diffwofost.physical_models.crop import root_dynamics logging.getLogger(__name__).addHandler(logging.NullHandler()) @@ -14,5 +15,6 @@ __all__ = [ "leaf_dynamics", "root_dynamics", + "phenology", "utils", ] diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 5ec36ba..4363670 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -109,7 +109,8 @@ class WOFOST_Leaf_Dynamics(SimulationObject): | LAI | TDWI, SPAN, RGRLAI, TBASE, KDIFTB, SLATB | | TWLV | TDWI, PERDL | - [!] Notice that the following gradients are zero: + [!NOTE] + Notice that the following gradients are zero: - ∂SPAN/∂LAI - ∂PERDL/∂TWLV - ∂KDIFTB/∂LAI diff --git a/src/diffwofost/physical_models/crop/phenology.py b/src/diffwofost/physical_models/crop/phenology.py new file mode 100644 index 0000000..db7fb2c --- /dev/null +++ b/src/diffwofost/physical_models/crop/phenology.py @@ -0,0 +1,773 @@ +"""Phenological development and vernalisation models for WOFOST. + +This module implements: +- Vernalisation: modification of phenological development due to cold +exposure. +- DVS_Phenology: main phenology progression (DVS scale: 0 emergence, 1 +anthesis, 2 maturity). +""" + +import torch +from pcse import exceptions as exc +from pcse import signals +from pcse.base import ParamTemplate +from pcse.base import RatesTemplate +from pcse.base import SimulationObject +from pcse.base import StatesTemplate +from pcse.decorators import prepare_rates +from pcse.decorators import prepare_states +from pcse.traitlets import Any +from pcse.traitlets import Enum +from pcse.traitlets import Instance +from pcse.util import daylength +from diffwofost.physical_models.utils import AfgenTrait +from diffwofost.physical_models.utils import _broadcast_to +from diffwofost.physical_models.utils import _get_drv +from diffwofost.physical_models.utils import _get_params_shape +from diffwofost.physical_models.utils import _restore_state +from diffwofost.physical_models.utils import _snapshot_state + +DTYPE = torch.float64 # Default data type for tensors in this module +EPS = torch.tensor(1e-8, dtype=DTYPE) # Small epsilon to avoid div by zero + + +class Vernalisation(SimulationObject): + """Modification of phenological development due to vernalisation. + + The vernalization approach here is based on the work of Lenny van + Bussel (2011), which in turn is based on Wang and Engel (1998). The + basic principle is that winter wheat needs a certain number of days + with temperatures within an optimum temperature range to complete + its vernalisation requirement. Until the vernalisation requirement + is fulfilled, the crop development is delayed. + + The rate of vernalization (VERNR) is defined by the temperature + response function VERNRTB. Within the optimal temperature range 1 + day is added to the vernalisation state (VERN). The reduction on the + phenological development is calculated from the base and saturated + vernalisation requirements (VERNBASE and VERNSAT). The reduction + factor (VERNFAC) is scaled linearly between VERNBASE and VERNSAT. + + A critical development stage (VERNDVS) is used to stop the effect of + vernalisation when this DVS is reached. This is done to improve + model stability in order to avoid that Anthesis is never reached + due to a somewhat too high VERNSAT. Nevertheless, a warning is + written to the log file, if this happens. + + * Van Bussel, 2011. From field to globe: Upscaling of crop growth + modelling. Wageningen PhD thesis. http://edepot.wur.nl/180295 + * Wang and Engel, 1998. Simulation of phenological development of + wheat crops. Agric. Systems 58:1 pp 1-24 + + *Simulation parameters* (provide in cropdata dictionary) + + | Name | Description | Type | Unit | + |----------|---------------------------------------------------------------|------|------| + | VERNSAT | Saturated vernalisation requirements | SCr | days | + | VERNBASE | Base vernalisation requirements | SCr | days | + | VERNRTB | Rate of vernalisation as a function of daily mean temperature | TCr | - | + | VERNDVS | Critical development stage after which the effect of | SCr | - | + | | vernalisation is halted | | | + + **State variables** + + | Name | Description | Pbl | Unit | + |---------------|----------------------------------------------------|-----|------| + | VERN | Vernalisation state | N | days | + | DOV | Day when vernalisation requirements are fulfilled. | N | - | + | ISVERNALISED | Flag indicated that vernalisation requirement has been reached | Y | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |---------|------------------------------------------------------------------|-----|------| + | VERNR | Rate of vernalisation | N | - | + | VERNFAC | Reduction factor on development rate due to vernalisation effect.| Y | - | + + **External dependencies:** + + | Name | Description | Provided by | Unit | + |------|--------------------------------------------------------|-------------|------| + | DVS | Development stage (only to test if critical VERNDVS | Phenology | - | + | | for vernalisation reached) | | | + """ + + params_shape = None # Shape of the parameters tensors + + class Parameters(ParamTemplate): + VERNSAT = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Saturated vernalisation requirements + VERNBASE = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Base vernalisation requirements + VERNRTB = AfgenTrait() # Vernalisation temperature response + VERNDVS = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Critical DVS for vernalisation fulfillment + + class RateVariables(RatesTemplate): + VERNR = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) # Rate of vernalisation + VERNFAC = Any( + default_value=torch.tensor(0.0, dtype=DTYPE) + ) # Red. factor for phenol. devel. + + class StateVariables(StatesTemplate): + VERN = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Vernalisation state + DOV = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Day ordinal when vernalisation fulfilled + ISVERNALISED = Any(default_value=torch.tensor(False)) # True when VERNSAT is reached and + # Forced when DVS > VERNDVS + + def initialize(self, day, kiosk, parvalues, dvs_shape=None): + """Initialize the Vernalisation sub-module. + + Args: + day (datetime.date): Simulation start date. + kiosk: Shared PCSE kiosk for inter-module variable exchange. + parvalues: ParameterProvider/dict containing VERNSAT, VERNBASE, + VERNRTB and VERNDVS. + dvs_shape (torch.Size, optional): Shape of the DVS_phenology parameters + + Side Effects: + - Instantiates params, rates and states containers. + - Publishes VERNFAC (rate) and ISVERNALISED (state) to kiosk. + + Initial State: + VERN = 0.0 (no vernalisation accrued), + DOV = None (fulfillment date unknown), + ISVERNALISED = False. + + """ + self.params = self.Parameters(parvalues) + self.params_shape = _get_params_shape(self.params) + if dvs_shape is not None: + if self.params_shape == (): + self.params_shape = dvs_shape + elif self.params_shape != dvs_shape: + raise ValueError( + f"Vernalisation params shape {self.params_shape}" + + " incompatible with dvs_shape {dvs_shape}" + ) + # Explicitly initialize rates + self.rates = self.RateVariables(kiosk, publish=["VERNFAC"]) + self.rates.VERNR = _broadcast_to(self.rates.VERNR, self.params_shape) + self.rates.VERNFAC = _broadcast_to(self.rates.VERNFAC, self.params_shape) + self.kiosk = kiosk + + # Explicitly broadcast all parameters to params_shape + self.params.VERNSAT = _broadcast_to(self.params.VERNSAT, self.params_shape) + self.params.VERNBASE = _broadcast_to(self.params.VERNBASE, self.params_shape) + self.params.VERNDVS = _broadcast_to(self.params.VERNDVS, self.params_shape) + + # Define initial states + self.states = self.StateVariables( + kiosk, + VERN=torch.zeros(self.params_shape, dtype=DTYPE), + DOV=torch.full(self.params_shape, -1.0, dtype=DTYPE), # -1 indicates not yet fulfilled + ISVERNALISED=torch.zeros(self.params_shape, dtype=torch.bool), + publish=["ISVERNALISED"], + ) + # Per-element force flag (False for all elements initially) + self._force_vernalisation = torch.zeros(self.params_shape, dtype=torch.bool) + + @prepare_rates + def calc_rates(self, day, drv): + """Calculate vernalisation rates. + + Args: + day (datetime.date): Current simulation date. + drv: Driver object providing TEMP. + + Logic: + - If not vernalised and DVS < VERNDVS: accumulate VERN via VERNRTB(TEMP) and + compute VERNFAC scaled between VERNBASE and VERNSAT. + - If DVS >= VERNDVS before fulfillment: stop accumulation, set VERNFAC=1, flag forced. + - After fulfillment: VERNR=0, VERNFAC=1. + """ + params = self.params + VERNDVS = params.VERNDVS + VERNSAT = params.VERNSAT + VERNBASE = params.VERNBASE + DVS = self.kiosk["DVS"] + + TEMP = _get_drv(drv.TEMP, self.params_shape) + + # Operate elementwise only on elements not yet vernalised + not_vernalised = ~self.states.ISVERNALISED + vegetative_mask = not_vernalised & (DVS >= 0.0) & (DVS < VERNDVS) + past_threshold_mask = not_vernalised & (DVS >= VERNDVS) + + # VERNR only for vegetative elements + self.rates.VERNR = torch.where( + vegetative_mask, params.VERNRTB(TEMP), torch.zeros(self.params_shape, dtype=DTYPE) + ) + + # compute VERNFAC from current VERN for vegetative elements; others = 1 + safe_den = VERNSAT - VERNBASE + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + r = (self.states.VERN - VERNBASE) / safe_den + vernfac_computed = torch.clamp(r, 0.0, 1.0) + self.rates.VERNFAC = torch.where( + vegetative_mask, vernfac_computed, torch.ones(self.params_shape, dtype=DTYPE) + ) + + # mark per-element force flags for elements that passed VERNDVS but aren't vernalised + if torch.any(past_threshold_mask): + self._force_vernalisation = self._force_vernalisation | past_threshold_mask + + @prepare_states + def integrate(self, day, delt=1.0): + """Advance vernalisation state. + + Args: + day (datetime.date): Current simulation date. + delt (float, optional): Timestep length in days (default 1.0). + + Updates: + - VERN += VERNR + - When VERN >= VERNSAT: sets ISVERNALISED=True and records DOV. + - When critical DVS already passed (forced): sets ISVERNALISED=True + without assigning DOV and logs a warning. + - Otherwise keeps ISVERNALISED False. + + Notes: + VERNFAC is computed in calc_rates and published for use in phenology. + + """ + states = self.states + rates = self.rates + params = self.params + + VERNSAT = params.VERNSAT + # accumulate vernalisation per element + states.VERN = states.VERN + rates.VERNR + + # elements that reached requirement + reached = states.VERN >= VERNSAT + # update ISVERNALISED per-element + states.ISVERNALISED = states.ISVERNALISED | reached + + # set DOV only for newly reached elements + newly_reached_and_no_dov = reached & (states.DOV < 0.0) + if torch.any(newly_reached_and_no_dov): + states.DOV = torch.where( + newly_reached_and_no_dov, + torch.full(self.params_shape, day.toordinal(), dtype=DTYPE), + states.DOV, + ) + self.logger.info(f"Vernalization requirements reached at day {day}.") + + # forced vernalisation per-element + forced_mask = self._force_vernalisation & (~states.ISVERNALISED) + if torch.any(forced_mask): + states.ISVERNALISED = states.ISVERNALISED | forced_mask + self.logger.warning( + "Critical DVS for vernalization (VERNDVS) reached at" + + f" day {day} for some elements; forcing vernalization now." + ) + # clear force bits for those elements + self._force_vernalisation = self._force_vernalisation & (~forced_mask) + + +class DVS_Phenology(SimulationObject): + """Implements the algorithms for phenologic development in WOFOST. + + Phenologic development in WOFOST is expresses using a unitless scale + which takes the values 0 at emergence, 1 at Anthesis (flowering) and + 2 at maturity. This type of phenological development is mainly + representative for cereal crops. All other crops that are simulated + with WOFOST are forced into this scheme as well, although this may + not be appropriate for all crops. For example, for potatoes + development stage 1 represents the start of tuber formation rather + than flowering. + + Phenological development is mainly governed by temperature and can + be modified by the effects of day length and vernalization during + the period before Anthesis. After Anthesis, only temperature + influences the development rate. + + **Simulation parameters** + + | Name | Description | Type | Unit | + |---------|-----------------------------------------------------------|------|------| + | TSUMEM | Temperature sum from sowing to emergence | SCr | |C| day | + | TBASEM | Base temperature for emergence | SCr | |C| | + | TEFFMX | Maximum effective temperature for emergence | SCr | |C| | + | TSUM1 | Temperature sum from emergence to anthesis | SCr | |C| day | + | TSUM2 | Temperature sum from anthesis to maturity | SCr | |C| day | + | IDSL | Switch for development options: temp only (0), +daylength | SCr | - | + | | (1), +vernalization (>=2) | | | + | DLO | Optimal daylength for phenological development | SCr | hr | + | DLC | Critical daylength for phenological development | SCr | hr | + | DVSI | Initial development stage at emergence (may be >0 for | SCr | - | + | | transplanted crops) | | | + | DVSEND | Final development stage | SCr | - | + | DTSMTB | Daily increase in temperature sum as a function of daily | TCr | |C| | + | | mean temperature | | | + + **State variables** + + | Name | Description | Pbl | Unit | + |-------|----------------------------------------------------------|-----|---------| + | DVS | Development stage | Y | - | + | TSUM | Temperature sum | N | |C| day | + | TSUME | Temperature sum for emergence | N | |C| day | + | DOS | Day of sowing | N | - | + | DOE | Day of emergence | N | - | + | DOA | Day of Anthesis | N | - | + | DOM | Day of maturity | N | - | + | DOH | Day of harvest | N | - | + | STAGE | Current stage (`emerging|vegetative|reproductive|mature`) | N | - | + + **Rate variables** + + | Name | Description | Pbl | Unit | + |--------|-----------------------------------------------------|-----|-------| + | DTSUME | Increase in temperature sum for emergence | N | |C| | + | DTSUM | Increase in temperature sum for anthesis or maturity| N | |C| | + | DVR | Development rate | Y | |day-1| | + + **External dependencies:** + + None + + **Signals sent or handled** + + `DVS_Phenology` sends the `crop_finish` signal when maturity is + reached and the `end_type` is 'maturity' or 'earliest'. + + **Gradient mapping (which parameters have a gradient):** + + | Output | Parameters influencing it | + |--------|------------------------------------------| + | DVS | ... | + | TSUM | ... | + + [!NOTE] + Notice that the gradient ∂DVS/∂TEFFMX is zero. + + [!NOTE] + The parameter IDSL it is not differentiable since it is a switch. + """ + + # Placeholder for start/stop types and vernalisation module + vernalisation = Instance(Vernalisation) + + params_shape = None # Shape of the parameters tensors + + class Parameters(ParamTemplate): + TSUMEM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Temp. sum for emergence + TBASEM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Base temp. for emergence + TEFFMX = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Max eff temperature for emergence + TSUM1 = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Temperature sum emergence to anthesis + TSUM2 = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Temperature sum anthesis to maturity + IDSL = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Switch for photoperiod (1) and vernalisation (2) + DLO = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Optimal day length for phenol. development + DLC = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Critical day length for phenol. development + DVSI = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Initial development stage + DVSEND = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Final development stage + DTSMTB = AfgenTrait() # Temperature response function for phenol. + # development. + CROP_START_TYPE = Enum(["sowing", "emergence"]) + CROP_END_TYPE = Enum(["maturity", "harvest", "earliest"]) + + class RateVariables(RatesTemplate): + DTSUME = Any( + default_value=torch.tensor(0.0, dtype=DTYPE) + ) # increase in temperature sum for emergence + DTSUM = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) # increase in temperature sum + DVR = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) # development rate + + class StateVariables(StatesTemplate): + DVS = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Development stage + TSUM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Temperature sum state + TSUME = Any( + default_value=torch.tensor(-99.0, dtype=DTYPE) + ) # Temperature sum for emergence state + # States which register phenological events as day ordinals (tensor of floats) + DOS = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of sowing (ordinal) + DOE = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of emergence (ordinal) + DOA = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of anthesis (ordinal) + DOM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of maturity (ordinal) + DOH = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of harvest (ordinal) + # STAGE as integer tensor: 0=emerging, 1=vegetative, 2=reproductive, 3=mature + STAGE = Any(default_value=torch.tensor(-99, dtype=torch.long)) + + def initialize(self, day, kiosk, parvalues): + """:param day: start date of the simulation + + :param kiosk: variable kiosk of this PCSE instance + :param parvalues: `ParameterProvider` object providing parameters as + key/value pairs + """ + self.params = self.Parameters(parvalues) + self.params_shape = _get_params_shape(self.params) + + # Initialize vernalisation for IDSL>=2 + # It has to be done in advance to get the correct params_shape + IDSL = _broadcast_to(self.params.IDSL, self.params_shape) + if torch.any(IDSL >= 2): + if self.params_shape != (): + self.vernalisation = Vernalisation( + day, kiosk, parvalues, dvs_shape=self.params_shape + ) + else: + self.vernalisation = Vernalisation(day, kiosk, parvalues) + if self.vernalisation.params_shape != self.params_shape: + self.params_shape = self.vernalisation.params_shape + else: + self.vernalisation = None + + # Initialize rates and kiosk + self.rates = self.RateVariables(kiosk) + self.kiosk = kiosk + + self._connect_signal(self._on_CROP_FINISH, signal=signals.crop_finish) + + # Define initial states + DVS, DOS, DOE, STAGE = self._get_initial_stage(day) + DVS = _broadcast_to(DVS, self.params_shape) + + # Initialize all date tensors with -1 (not yet occurred) + DOS = _broadcast_to(DOS, self.params_shape) + DOE = _broadcast_to(DOE, self.params_shape) + DOA = torch.full(self.params_shape, -1.0, dtype=DTYPE) + DOM = torch.full(self.params_shape, -1.0, dtype=DTYPE) + DOH = torch.full(self.params_shape, -1.0, dtype=DTYPE) + STAGE = _broadcast_to(STAGE, self.params_shape) + + # Also ensure TSUM and TSUME are properly shaped + TSUM = torch.zeros(self.params_shape, dtype=DTYPE, requires_grad=True) + TSUME = torch.zeros(self.params_shape, dtype=DTYPE, requires_grad=True) + + self.states = self.StateVariables( + kiosk, + publish="DVS", + TSUM=TSUM, + TSUME=TSUME, + DVS=DVS, + DOS=DOS, + DOE=DOE, + DOA=DOA, + DOM=DOM, + DOH=DOH, + STAGE=STAGE, + ) + + def _get_initial_stage(self, day): + """Determine initial phenological state at simulation start. + + Args: + day (datetime.date): Simulation start day. + + Returns: + tuple: (DVS, DOS, DOE, STAGE) + DVS (Tensor): Initial development stage (-0.1 if sowing start, + or DVSI if emergence start). + DOS (Tensor): Sowing date ordinal (or -1 if not applicable). + DOE (Tensor): Emergence date ordinal (or -1 if not applicable). + STAGE (Tensor): Integer stage code (0=emerging, 1=vegetative). + """ + p = self.params + day_ordinal = torch.tensor(day.toordinal(), dtype=DTYPE) + + # Define initial stage type (emergence/sowing) and fill the + # respective day of sowing/emergence (DOS/DOE) + if p.CROP_START_TYPE == "emergence": + STAGE = torch.tensor(1, dtype=torch.long) # 1 = vegetative + DOE = day_ordinal + DOS = torch.tensor(-1.0, dtype=DTYPE) # Not applicable + DVS = p.DVSI + if not isinstance(DVS, torch.Tensor): + DVS = torch.tensor(DVS, dtype=DTYPE) + + # send signal to indicate crop emergence + self._send_signal(signals.crop_emerged) + + elif p.CROP_START_TYPE == "sowing": + STAGE = torch.tensor(0, dtype=torch.long) # 0 = emerging + DOS = day_ordinal + DOE = torch.tensor(-1.0, dtype=DTYPE) # Not yet occurred + DVS = torch.tensor(-0.1, dtype=DTYPE) + + else: + msg = f"Unknown start type: {p.CROP_START_TYPE}" + raise exc.PCSEError(msg) + + return DVS, DOS, DOE, STAGE + + @prepare_rates + def calc_rates(self, day, drv): + """Compute daily phenological development rates. + + Args: + day (datetime.date): Current simulation date. + drv: Meteorological driver object with at least TEMP and LAT. + + Logic: + 1. Photoperiod reduction (DVRED) if IDSL >= 1 using daylength. + 2. Vernalisation factor (VERNFAC) if IDSL >= 2 and in vegetative stage. + 3. Stage-specific: + - emerging: temperature sum for emergence (DTSUME), DVR via TSUMEM. + - vegetative: temperature sum (DTSUM) scaled by VERNFAC and DVRED. + - reproductive: temperature sum (DTSUM) only temperature-driven. + - mature: all rates zero. + + Sets: + r.DTSUME, r.DTSUM, r.DVR. + + Raises: + PCSEError: If STAGE unrecognized. + + """ + p = self.params + r = self.rates + s = self.states + shape = self.params_shape + + # Day length sensitivity + DAYLP = daylength(day, drv.LAT) + DAYLP_t = _broadcast_to(DAYLP, shape) + # Compute DVRED conditionally based on IDSL >= 1 + safe_den = p.DLO - p.DLC + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + dvred_active = torch.clamp((DAYLP_t - p.DLC) / safe_den, 0.0, 1.0) + DVRED = torch.where(p.IDSL >= 1, dvred_active, torch.ones(shape, dtype=DTYPE)) + + # Vernalisation factor - always compute if module exists + VERNFAC = torch.ones(shape, dtype=DTYPE) + if hasattr(self, "vernalisation") and self.vernalisation is not None: + # Always call calc_rates (it handles stage internally now) + self.vernalisation.calc_rates(day, drv) + # Apply vernalisation only where IDSL >= 2 AND in vegetative stage + is_vegetative = s.STAGE == 1 + VERNFAC = torch.where( + (p.IDSL >= 2) & is_vegetative, self.kiosk["VERNFAC"], torch.ones(shape, dtype=DTYPE) + ) + + TEMP = _get_drv(drv.TEMP, shape) + + # Initialize all rate variables + r.DTSUME = torch.zeros(shape, dtype=DTYPE) + r.DTSUM = torch.zeros(shape, dtype=DTYPE) + r.DVR = torch.zeros(shape, dtype=DTYPE) + + # Compute rates for emerging stage (STAGE == 0) + is_emerging = s.STAGE == 0 + if torch.any(is_emerging): + temp_diff = TEMP - p.TBASEM + # Ensure the maximum effective temperature difference is non-negative + max_diff = torch.clamp(p.TEFFMX - p.TBASEM, min=0.0) + dtsume_emerging = torch.clamp(temp_diff, min=0.0) + dtsume_emerging = torch.minimum(dtsume_emerging, max_diff) + safe_den = p.TSUMEM + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + dvr_emerging = 0.1 * dtsume_emerging / safe_den + + r.DTSUME = torch.where(is_emerging, dtsume_emerging, r.DTSUME) + r.DVR = torch.where(is_emerging, dvr_emerging, r.DVR) + + # Compute rates for vegetative stage (STAGE == 1) + is_vegetative = s.STAGE == 1 + if torch.any(is_vegetative): + dtsum_vegetative = p.DTSMTB(TEMP) * VERNFAC * DVRED + safe_den = p.TSUM1 + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + dvr_vegetative = dtsum_vegetative / safe_den + + r.DTSUM = torch.where(is_vegetative, dtsum_vegetative, r.DTSUM) + r.DVR = torch.where(is_vegetative, dvr_vegetative, r.DVR) + + # Compute rates for reproductive stage (STAGE == 2) + is_reproductive = s.STAGE == 2 + if torch.any(is_reproductive): + dtsum_reproductive = p.DTSMTB(TEMP) + safe_den = p.TSUM2 + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + dvr_reproductive = dtsum_reproductive / safe_den + + r.DTSUM = torch.where(is_reproductive, dtsum_reproductive, r.DTSUM) + r.DVR = torch.where(is_reproductive, dvr_reproductive, r.DVR) + + # Mature stage (STAGE == 3) keeps zeros (already initialized) + + msg = "Finished rate calculation for %s" + self.logger.debug(msg % day) + + @prepare_states + def integrate(self, day, delt=1.0): + """Integrate phenology states and manage stage transitions. + + Args: + day (datetime.date): Current simulation day. + delt (float, optional): Timestep length in days (default 1.0). + + Sequence: + - Integrates vernalisation module if active and in vegetative stage. + - Accumulates TSUME, TSUM, advances DVS by DVR. + - Checks threshold crossings to move through stages: + emerging -> vegetative (DVS >= 0) + vegetative -> reproductive (DVS >= 1) + reproductive -> mature (DVS >= DVSEND) + + Side Effects: + - Emits crop_emerged signal on emergence. + - Emits crop_finish signal at maturity if end type matches. + + Notes: + Caps DVS at stage boundary values. + + Raises: + PCSEError: If STAGE undefined. + + """ + p = self.params + r = self.rates + s = self.states + shape = self.params_shape + + # Integrate vernalisation module + if self.vernalisation: + # Save a copy of state + state_copy = _snapshot_state(self.vernalisation.states) + mask_IDSL = p.IDSL >= 2 + + # Check if any element is in vegetative stage i.e. stage 1 + mask_STAGE = mask_IDSL & (s.STAGE == 1) + self.vernalisation.integrate(day, delt) + state_integrated = _snapshot_state(self.vernalisation.states) + + # Restore original state + _restore_state(self.vernalisation.states, state_copy) + self.vernalisation.touch() + state_touched = _snapshot_state(self.vernalisation.states) + + # Apply the masks + for name in state_copy: + # results of vernalisation module + vernalisation_states = torch.where( + mask_STAGE, state_integrated[name], state_touched[name] + ) + setattr( + self.vernalisation.states, + name, + torch.where(mask_IDSL, vernalisation_states, state_copy[name]), + ) + + # Integrate phenologic states + s.TSUME = s.TSUME + r.DTSUME + s.DVS = s.DVS + r.DVR + s.TSUM = s.TSUM + r.DTSUM + + day_ordinal = torch.tensor(day.toordinal(), dtype=DTYPE) + + # Check transitions for emerging -> vegetative (STAGE 0 -> 1) + is_emerging = s.STAGE == 0 + should_emerge = is_emerging & (s.DVS >= 0.0) + s.STAGE = torch.where(should_emerge, torch.ones(shape, dtype=torch.long), s.STAGE) + s.DOE = torch.where(should_emerge, torch.full(shape, day_ordinal, dtype=DTYPE), s.DOE) + s.DVS = torch.where(should_emerge, torch.clamp(s.DVS, max=0.0), s.DVS) + + # Send signal if any crop emerged (only once per day) + if torch.any(should_emerge): + self._send_signal(signals.crop_emerged) + + # Check transitions for vegetative -> reproductive (STAGE 1 -> 2) + is_vegetative = s.STAGE == 1 + should_flower = is_vegetative & (s.DVS >= 1.0) + s.STAGE = torch.where(should_flower, torch.full(shape, 2, dtype=torch.long), s.STAGE) + s.DOA = torch.where(should_flower, torch.full(shape, day_ordinal, dtype=DTYPE), s.DOA) + s.DVS = torch.where(should_flower, torch.clamp(s.DVS, max=1.0), s.DVS) + + # Check transitions for reproductive -> mature (STAGE 2 -> 3) + is_reproductive = s.STAGE == 2 + should_mature = is_reproductive & (s.DVS >= p.DVSEND) + s.STAGE = torch.where(should_mature, torch.full(shape, 3, dtype=torch.long), s.STAGE) + s.DOM = torch.where(should_mature, torch.full(shape, day_ordinal, dtype=DTYPE), s.DOM) + s.DVS = torch.where(should_mature, torch.minimum(s.DVS, p.DVSEND), s.DVS) + + # Send crop_finish signal if maturity reached for one. + # assumption is that all elements mature simultaneously + # TODO: revisit this when fixing engine for agromanager + if torch.any(should_mature) and p.CROP_END_TYPE in ["maturity", "earliest"]: + self._send_signal( + signal=signals.crop_finish, + day=day, + finish_type="maturity", + crop_delete=True, + ) + + msg = "Finished state integration for %s" + self.logger.debug(msg % day) + + def _on_CROP_FINISH(self, day, finish_type=None): + """Handle external crop finish signal to set harvest date. + + Args: + day (datetime.date): Date provided by finish event. + finish_type (str|None): 'harvest', 'earliest', or other finish reason. + + Behavior: + - If finish_type in ('harvest','earliest'): registers DOH for finalization. + + Notes: + Maturity-driven finish is triggered internally in _next_stage; this + handler captures management-induced harvests. + + """ + if finish_type in ["harvest", "earliest"]: + day_ordinal = torch.tensor(day.toordinal(), dtype=DTYPE) + self._for_finalize["DOH"] = torch.full(self.params_shape, day_ordinal, dtype=DTYPE) + + def get_variable(self, varname): + # TODO: should be removed while fixing #49. this is needed because + # conditions are applied on STAGE in pcse.crop.wofost72.py + """Return the value of the specified state or rate variable. + + :param varname: Name of the variable. + + Note that the `get_variable()` will searches for `varname` exactly + as specified (case sensitive). + """ + if varname == "STAGE": + # Return string representation of current stage + stage_map = { + 0: "emerging", + 1: "vegetative", + 2: "reproductive", + 3: "mature", + } + stage_value = self.states.STAGE + if stage_value.dim() != 0: + stage_id = stage_value.flatten()[0].item() + else: + stage_id = stage_value.item() + return stage_map[stage_id] + + # Search for variable in the current object, then traverse the hierarchy + value = None + if hasattr(self.states, varname): + value = getattr(self.states, varname) + elif hasattr(self.rates, varname): + value = getattr(self.rates, varname) + # Query individual sub-SimObject for existence of variable v + else: + for simobj in self.subSimObjects: + value = simobj.get_variable(varname) + if value is not None: + break + return value diff --git a/src/diffwofost/physical_models/crop/root_dynamics.py b/src/diffwofost/physical_models/crop/root_dynamics.py index a80e4a0..6143609 100644 --- a/src/diffwofost/physical_models/crop/root_dynamics.py +++ b/src/diffwofost/physical_models/crop/root_dynamics.py @@ -89,7 +89,8 @@ class WOFOST_Root_Dynamics(SimulationObject): | RD | RDI, RRI, RDMCR, RDMSOL | | TWRT | TDWI, RDRRTB | - [!] Notice that the gradient ∂TWRT/∂RDRRTB is zero. + [!NOTE] + Notice that the gradient ∂TWRT/∂RDRRTB is zero. **IMPORTANT NOTICE** diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index 3d5750c..5195526 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -28,6 +28,7 @@ from pcse.engine import Engine from pcse.settings import settings from pcse.timer import Timer +from pcse.traitlets import Enum from pcse.traitlets import TraitType DTYPE = torch.float64 # Default data type for tensors in this module @@ -50,7 +51,7 @@ class VariableKioskTestHelper(VariableKiosk): def __init__(self, external_state_list): super().__init__() self.current_externals = {} - if external_state_list is not None: + if external_state_list: self.external_state_list = external_state_list def __call__(self, day): @@ -59,7 +60,7 @@ def __call__(self, day): Returns True if the list of external state/rate variables is exhausted, otherwise False. """ - if self.external_state_list is not None: + if self.external_state_list: current_externals = self.external_state_list.pop(0) forcing_day = current_externals.pop("DAY") msg = "Failure updating VariableKiosk with external states: days are not matching!" @@ -226,13 +227,15 @@ def prepare_engine_input( test_data["WeatherVariables"], meteo_range_checks=meteo_range_checks ) crop_model_params_provider = ParameterProvider(cropdata=cropd) - external_states = test_data["ExternalStates"] + external_states = test_data.get("ExternalStates") or [] # convert parameters to tensors crop_model_params_provider.clear_override() for name in crop_model_params: - value = torch.tensor(crop_model_params_provider[name], dtype=dtype) - crop_model_params_provider.set_override(name, value, check=False) + # if name is missing in the YAML, skip it + if name in crop_model_params_provider: + value = torch.tensor(crop_model_params_provider[name], dtype=dtype) + crop_model_params_provider.set_override(name, value, check=False) # convert external states to tensors tensor_external_states = [ @@ -475,16 +478,27 @@ def __call__(self, x): else: x_val = flat_x[0] # Broadcast first value - # Boundary conditions - if x_val <= x_list[0]: - result = y_list[0] - elif x_val >= x_list[-1]: - result = y_list[-1] - else: - # Find interval and interpolate - i = torch.searchsorted(x_list, x_val, right=False) - 1 - i = torch.clamp(i, 0, len(x_list) - 2) - result = y_list[i] + slopes[i] * (x_val - x_list[i]) + # Ensure contiguous memory layout for searchsorted + x_list_contig = x_list.contiguous() + x_val_contig = ( + x_val.contiguous() + if isinstance(x_val, torch.Tensor) and x_val.dim() > 0 + else x_val + ) + + # Find interval and interpolate using torch.where for differentiability + i = torch.searchsorted(x_list_contig, x_val_contig, right=False) - 1 + i = torch.clamp(i, 0, len(x_list) - 2) + + # Calculate interpolated value + interp_result = y_list[i] + slopes[i] * (x_val - x_list[i]) + + # Apply boundary conditions using torch.where + result = torch.where( + x_val <= x_list[0], + y_list[0], + torch.where(x_val >= x_list[-1], y_list[-1], interp_result), + ) results.append(result) @@ -492,20 +506,25 @@ def __call__(self, x): output = torch.stack(results).reshape(self.batch_shape) return output - # Original scalar logic from pcse - # Clamp to boundaries - if x <= self.x_list[0]: - return self.y_list[0] - if x >= self.x_list[-1]: - return self.y_list[-1] + # Ensure contiguous memory layout for searchsorted + x_list_contig = self.x_list.contiguous() + x_contig = x.contiguous() if isinstance(x, torch.Tensor) and x.dim() > 0 else x # Find interval index using torch.searchsorted for differentiability - i = torch.searchsorted(self.x_list, x, right=False) - 1 + i = torch.searchsorted(x_list_contig, x_contig, right=False) - 1 i = torch.clamp(i, 0, len(self.x_list) - 2) - # Linear interpolation - v = self.y_list[i] + self.slopes[i] * (x - self.x_list[i]) - return v + # Calculate interpolated value + interp_value = self.y_list[i] + self.slopes[i] * (x - self.x_list[i]) + + # Apply boundary conditions using torch.where + result = torch.where( + x <= self.x_list[0], + self.y_list[0], + torch.where(x >= self.x_list[-1], self.y_list[-1], interp_value), + ) + + return result @property def shape(self): @@ -558,6 +577,9 @@ def _get_params_shape(params): if parname.startswith("trait"): continue param = getattr(params, parname) + # Skip Enum and str parameters + if isinstance(param, Enum) or isinstance(param, str): + continue # Parameters that are not zero dimensional should all have the same shape if param.shape and not shape: shape = param.shape @@ -615,3 +637,12 @@ def _broadcast_to(x, shape): # the dimension along which the time integration is carried out. # We first append an axis to x, then expand to the given shape return x.unsqueeze(-1).expand(shape) + + +def _snapshot_state(obj): + return {name: val.clone() for name, val in obj.__dict__.items() if torch.is_tensor(val)} + + +def _restore_state(obj, snapshot): + for name, val in snapshot.items(): + setattr(obj, name, val) diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index fb7320a..3605e93 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -9,6 +9,7 @@ "leafdynamics", "rootdynamics", "potentialproduction", + "phenology", ] FILE_NAMES = [ f"test_{model_name}_wofost72_{i:02d}.yaml" for model_name in model_names for i in range(1, 45) diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index 699979a..3eab5a5 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -204,6 +204,11 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): [ ("TDWI", 0.1), ("SPAN", 5), + ("TBASE", 2.0), + ("PERDL", 0.01), + ("RGRLAI", 0.002), + ("KDIFTB", 0.1), + ("SLATB", 0.0005), ], ) def test_leaf_dynamics_with_different_parameter_values(self, param, delta): @@ -222,7 +227,12 @@ def test_leaf_dynamics_with_different_parameter_values(self, param, delta): # Setting a vector with multiple values for the selected parameter test_value = crop_model_params_provider[param] # We set the value for which test data are available as the last element - param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) + if param in {"KDIFTB", "SLATB"}: + # AfgenTrait parameters need to have shape (N, M) + non_zeros_mask = test_value != 0 + param_vec = torch.stack([test_value + non_zeros_mask * delta, test_value]) + else: + param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) crop_model_params_provider.set_override(param, param_vec, check=False) engine = EngineTestHelper( diff --git a/tests/physical_models/crop/test_phenology.py b/tests/physical_models/crop/test_phenology.py new file mode 100644 index 0000000..ccdbb02 --- /dev/null +++ b/tests/physical_models/crop/test_phenology.py @@ -0,0 +1,722 @@ +import copy +import warnings +from unittest.mock import patch +import pytest +import torch +from pcse.engine import Engine +from pcse.models import Wofost72_PP +from diffwofost.physical_models.crop.phenology import DVS_Phenology +from diffwofost.physical_models.utils import EngineTestHelper +from diffwofost.physical_models.utils import calculate_numerical_grad +from diffwofost.physical_models.utils import get_test_data +from diffwofost.physical_models.utils import prepare_engine_input +from .. import phy_data_folder + +# Ignore deprecation warnings from pcse.base.simulationobject +pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning:pcse.base.simulationobject") + + +def assert_reference_match(reference, model, expected_precision): + assert reference["DAY"] == model["day"] + for var, precision in expected_precision.items(): + # [!] These are not 'State variables' and are not stored in model output + if var in ["VERNFAC", "VERNR"]: + continue + + # for some data tests, both reference and model can have None values + if reference[var] is None and model[var] is None: + continue + assert torch.all( + torch.abs(torch.as_tensor(reference[var]) - torch.as_tensor(model[var])) < precision + ) + + +def get_test_diff_phenology_model(): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + # Phenology-related crop model parameters + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( + prepare_engine_input(test_data, crop_model_params) + ) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + return DiffPhenologyDynamics( + copy.deepcopy(crop_model_params_provider), + weather_data_provider, + agro_management_inputs, + config_path, + copy.deepcopy(external_states), + ) + + +class DiffPhenologyDynamics(torch.nn.Module): + def __init__( + self, + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ): + super().__init__() + self.crop_model_params_provider = crop_model_params_provider + self.weather_data_provider = weather_data_provider + self.agro_management_inputs = agro_management_inputs + self.config_path = config_path + self.external_states = external_states + + def forward(self, params_dict): + # pass new value of parameters to the model + for name, value in params_dict.items(): + self.crop_model_params_provider.set_override(name, value, check=False) + + engine = EngineTestHelper( + self.crop_model_params_provider, + self.weather_data_provider, + self.agro_management_inputs, + self.config_path, + self.external_states, + ) + engine.run_till_terminate() + results = engine.get_output() + + # Collect phenology outputs analogous to leaf dynamics test + return {var: torch.stack([item[var] for item in results]) for var in ["DVS", "TSUM"]} + + +class TestPhenologyDynamics: + phenology_data_urls = [ + f"{phy_data_folder}/test_phenology_wofost72_{i:02d}.yaml" + for i in range(1, 45) # assume 44 test files + ] + wofost72_data_urls = [ + f"{phy_data_folder}/test_potentialproduction_wofost72_{i:02d}.yaml" for i in range(1, 45) + ] + + @pytest.mark.parametrize("test_data_url", phenology_data_urls) + def test_phenology_with_testengine(self, test_data_url): + """EngineTestHelper because it allows to specify `external_states`.""" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert_reference_match(reference, model, expected_precision) + + def test_phenology_with_engine(self): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params) + ) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + Engine( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + ) + + @pytest.mark.parametrize( + "param", + [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + "TEMP", + ], + ) + def test_phenology_with_one_parameter_vector(self, param): + # pick a test case with vernalisation to have all the parameters + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_17.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + if param == "TEMP": + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + elif param == "DTSMTB": + repeated = crop_model_params_provider[param].repeat(10, 1) + crop_model_params_provider.set_override(param, repeated, check=False) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + if param == "TEMP": + with pytest.raises(ValueError): + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + _ = engine.get_output() + else: + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert_reference_match(reference, model, expected_precision) + + @pytest.mark.parametrize( + "param,delta", + [ + ("TSUMEM", 1.0), + ("TBASEM", 1.0), + ("TEFFMX", 1.0), + ("TSUM1", 1.0), + ("TSUM2", 1.0), + ("DVSI", 0.1), + ("DTSMTB", 1.0), + ("VERNSAT", 1.0), + ("VERNBASE", 0.5), + ("VERNDVS", 0.1), + ], + ) + def test_phenology_with_different_parameter_values(self, param, delta): + # we dont test IDSL,DLO, DLC, DVSEND because these paramaters controls the + # simulation duration + # TODO: revisit this choice when Engine is fixed + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_17.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + test_value = crop_model_params_provider[param] + if param == "DTSMTB": + # AfgenTrait parameters need to have shape (N, M) + # DTSMTB is increase in tempearture, so avoid negative values + non_zeros_mask = test_value != 0 + param_vec = torch.stack([test_value + non_zeros_mask * delta, test_value]) + else: + param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) + crop_model_params_provider.set_override(param, param_vec, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + + for reference, model in zip(expected_results, actual_results, strict=False): + # keep original special case using last element + for var, precision in expected_precision.items(): + # [!] These are not 'State variables' and are not stored in model output + if var in ["VERNFAC", "VERNR"]: + continue + + # for some data tests, both reference and model can have None values + if reference[var] is None and model[var] is None: + continue + assert torch.all(torch.abs(reference[var] - model[var][-1]) < precision) + + def test_phenology_with_multiple_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_17.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + for param in crop_model_params: + if param == "DTSMTB": + repeated = crop_model_params_provider[param].repeat(10, 1) + else: + repeated = crop_model_params_provider[param].repeat(10) + crop_model_params_provider.set_override(param, repeated, check=False) + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert_reference_match(reference, model, expected_precision) + + def test_phenology_with_multiple_parameter_arrays(self): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_17.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + for param in ( + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ): + if param == "DTSMTB": + repeated = crop_model_params_provider[param].repeat(30, 5, 1) + else: + repeated = crop_model_params_provider[param].broadcast_to((30, 5)) + crop_model_params_provider.set_override(param, repeated, check=False) + + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones((30, 5), dtype=torch.float64) * wdc.TEMP + + engine = EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + engine.run_till_terminate() + actual_results = engine.get_output() + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + assert len(actual_results) == len(expected_results) + for reference, model in zip(expected_results, actual_results, strict=False): + assert_reference_match(reference, model, expected_precision) + assert all( + model[var].shape == (30, 5) + for var in expected_precision.keys() + if var not in ["VERNFAC", "VERNR"] + ) + + def test_phenology_with_incompatible_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + crop_model_params_provider.set_override( + "TSUM1", crop_model_params_provider["TSUM1"].repeat(10), check=False + ) + crop_model_params_provider.set_override( + "TSUM2", crop_model_params_provider["TSUM2"].repeat(5), check=False + ) + + with pytest.raises(AssertionError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + + def test_phenology_with_incompatible_weather_parameter_vectors(self): + test_data_url = f"{phy_data_folder}/test_phenology_wofost72_01.yaml" + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + ( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + external_states, + ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + config_path = str(phy_data_folder / "WOFOST_Phenology.conf") + + crop_model_params_provider.set_override( + "TSUM1", crop_model_params_provider["TSUM1"].repeat(10), check=False + ) + for (_, _), wdc in weather_data_provider.store.items(): + wdc.TEMP = torch.ones(5, dtype=torch.float64) * wdc.TEMP + + with pytest.raises(ValueError): + EngineTestHelper( + crop_model_params_provider, + weather_data_provider, + agro_management_inputs, + config_path, + external_states, + ) + + @pytest.mark.parametrize("test_data_url", wofost72_data_urls) + def test_wofost_pp_with_phenology(self, test_data_url): + test_data = get_test_data(test_data_url) + crop_model_params = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + "DTSMTB", + "VERNSAT", + "VERNBASE", + "VERNDVS", + ] + (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( + prepare_engine_input(test_data, crop_model_params) + ) + expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + + with patch("pcse.crop.wofost72.Phenology", DVS_Phenology): + model = Wofost72_PP( + crop_model_params_provider, weather_data_provider, agro_management_inputs + ) + model.run_till_terminate() + actual_results = model.get_output() + + assert len(actual_results) == len(expected_results) + for reference, model_day in zip(expected_results, actual_results, strict=False): + assert_reference_match(reference, model_day, expected_precision) + + +class TestDiffPhenologyDynamicsGradients: + """Parametrized tests for gradient calculations in phenology dynamics.""" + + # Check if they contribute to gradients of outputs + param_names = [ + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "DLO", + "DLC", + "DVSEND", + "DTSMTB", + ] + output_names = ["DVS", "TSUM"] + + param_configs = { + "single": { + "TSUMEM": (50.0, torch.float64), + "TBASEM": (0.0, torch.float64), + "TEFFMX": (35.0, torch.float64), + "TSUM1": (500.0, torch.float64), + "TSUM2": (600.0, torch.float64), + "DLO": (0.5, torch.float64), + "DLC": (0.5, torch.float64), + "DVSEND": (2.0, torch.float64), + "DTSMTB": ([0.0, 0.0, 35.0, 35.0, 45.0, 35.0], torch.float64), + "VERNSAT": (15.0, torch.float64), + "VERNBASE": (5.0, torch.float64), + "VERNDVS": (0.5, torch.float64), + }, + "tensor": { + "TSUMEM": ([45.0, 50.0, 55.0], torch.float64), + "TBASEM": ([-2.0, 0.0, 2.0], torch.float64), + "TEFFMX": ([32.0, 35.0, 38.0], torch.float64), + "TSUM1": ([450.0, 500.0, 550.0], torch.float64), + "TSUM2": ([550.0, 600.0, 650.0], torch.float64), + "DLO": ([0.4, 0.5, 0.6], torch.float64), + "DLC": ([0.4, 0.5, 0.6], torch.float64), + "DVSEND": ([1.9, 2.0, 2.1], torch.float64), + "DTSMTB": ( + [ + [0, 0, 15, 8, 30, 18], + [0, 0, 5, 9, 10, 19], + [0, 0, 25, 1, 30, 20], + ], + torch.float64, + ), + "VERNSAT": ([14.0, 15.0, 16.0], torch.float64), + "VERNBASE": ([4.0, 5.0, 6.0], torch.float64), + "VERNDVS": ([0.4, 0.5, 0.6], torch.float64), + }, + } + gradient_mapping = { + "TSUMEM": ["DVS"], + "TBASEM": ["DVS"], + "TEFFMX": ["DVS"], + "TSUM1": ["DVS"], + "TSUM2": ["DVS"], + "DLO": ["DVS"], + "DLC": ["DVS"], + "DVSI": ["DVS", "TSUM"], + "DVSEND": ["DVS"], + "DTSMTB": ["DVS", "TSUM"], + "VERNSAT": ["DVS", "TSUM"], + "VERNBASE": ["DVS", "TSUM"], + "VERNDVS": ["DVS", "TSUM"], + } + + gradient_params = [] + no_gradient_params = [] + for pname in param_names: + for oname in output_names: + if oname in gradient_mapping.get(pname, []): + gradient_params.append((pname, oname)) + else: + no_gradient_params.append((pname, oname)) + + @pytest.mark.parametrize("param_name,output_name", no_gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_no_gradients(self, param_name, output_name, config_type): + model = get_test_diff_phenology_model() + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + output = model({param_name: param}) + loss = output[output_name].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_forward_backward_match(self, param_name, output_name, config_type): + model = get_test_diff_phenology_model() + value, dtype = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + output = model({param_name: param}) + loss = output[output_name].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + assert grads is not None + param.grad = None + loss.backward() + grad_backward = param.grad + assert grad_backward is not None + assert torch.all(grad_backward == grads) + + @pytest.mark.parametrize("param_name,output_name", gradient_params) + @pytest.mark.parametrize("config_type", ["single", "tensor"]) + def test_gradients_numerical(self, param_name, output_name, config_type): + value, _ = self.param_configs[config_type][param_name] + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64)) + numerical_grad = calculate_numerical_grad( + get_test_diff_phenology_model, param_name, param.data, output_name + ) + model = get_test_diff_phenology_model() + output = model({param_name: param}) + loss = output[output_name].sum() + grads = torch.autograd.grad(loss, param, retain_graph=True)[0] + rtol = 0.005 + assert torch.all( + torch.abs(numerical_grad - grads.data) / (torch.abs(grads.data) + 1e-8) < rtol + ) + if torch.all(grads == 0): + warnings.warn( + f"Gradient for par '{param_name}' wrt out '{output_name}' is zero: {grads.data}", + UserWarning, + ) diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index 63af26b..9e2423e 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -185,6 +185,11 @@ def test_root_dynamics_with_one_parameter_vector(self, param): [ ("RDI", 1.0), ("RRI", 0.1), + ("RDMCR", 10.0), + ("RDMSOL", 10.0), + ("TDWI", 0.05), + ("IAIRDU", 0.05), + ("RDRRTB", 0.01), ], ) def test_root_dynamics_with_different_parameter_values(self, param, delta): @@ -203,7 +208,12 @@ def test_root_dynamics_with_different_parameter_values(self, param, delta): # Setting a vector with multiple values for the selected parameter test_value = crop_model_params_provider[param] # We set the value for which test data are available as the last element - param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) + if param == "RDRRTB": + # AfgenTrait parameters need to have shape (N, M) + non_zeros_mask = test_value != 0 + param_vec = torch.stack([test_value + non_zeros_mask * delta, test_value]) + else: + param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) crop_model_params_provider.set_override(param, param_vec, check=False) engine = EngineTestHelper( diff --git a/tests/physical_models/test_data/WOFOST_Phenology.conf b/tests/physical_models/test_data/WOFOST_Phenology.conf new file mode 100644 index 0000000..c47fead --- /dev/null +++ b/tests/physical_models/test_data/WOFOST_Phenology.conf @@ -0,0 +1,33 @@ + +from diffwofost.physical_models.crop.phenology import DVS_Phenology +from pcse.agromanager import AgroManager + +# Module to be used for water balance +SOIL = None + +# Module to be used for the crop simulation itself +CROP = DVS_Phenology + +# Module to use for AgroManagement actions +AGROMANAGEMENT = AgroManager + +# variables to save at OUTPUT signals +# Set to an empty list if you do not want any OUTPUT +OUTPUT_VARS = ["DVR","DVS","TSUM","TSUME", "VERN"] +# interval for OUTPUT signals, either "daily"|"dekadal"|"monthly"|"weekly" +# For daily output you change the number of days between successive +# outputs using OUTPUT_INTERVAL_DAYS. For dekadal and monthly +# output this is ignored. +OUTPUT_INTERVAL = "daily" +OUTPUT_INTERVAL_DAYS = 1 +# Weekday: Monday is 0 and Sunday is 6 +OUTPUT_WEEKDAY = 0 + +# Summary variables to save at CROP_FINISH signals +# Set to an empty list if you do not want any SUMMARY_OUTPUT +SUMMARY_OUTPUT_VARS = ["DVS", "DOS", "DOE", "DOA", + "DOM", "DOH", "DOV"] + +# Summary variables to save at TERMINATE signals +# Set to an empty list if you do not want any TERMINAL_OUTPUT +TERMINAL_OUTPUT_VARS = [] \ No newline at end of file diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index ec009de..cc4503b 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -274,6 +274,56 @@ def test_zero_slope_segment(self): expected = torch.tensor(7.5, dtype=DTYPE) assert torch.isclose(result, expected) + def test_tensor_input_at_boundaries(self): + """Test tensor inputs at boundary conditions with gradients.""" + afgen = Afgen([0, 5, 10, 15]) + + # Test below lower bound with gradient + x_low = torch.tensor(-2.0, dtype=DTYPE, requires_grad=True) + result_low = afgen(x_low) + assert result_low == 5.0 + result_low.backward() + # Gradient should be 0 when clamped at boundary + assert x_low.grad == 0.0 + + # Test above upper bound with gradient + x_high = torch.tensor(15.0, dtype=DTYPE, requires_grad=True) + result_high = afgen(x_high) + assert result_high == 15.0 + result_high.backward() + # Gradient should be 0 when clamped at boundary + assert x_high.grad == 0.0 + + def test_tensor_input_near_boundaries(self): + """Test tensor inputs just inside boundaries maintain gradients.""" + afgen = Afgen([0, 0, 10, 10]) + + # Just above lower bound + x_near_low = torch.tensor(0.1, dtype=DTYPE, requires_grad=True) + result = afgen(x_near_low) + result.backward() + # Should have gradient of 1 (slope of the line) + assert torch.isclose(x_near_low.grad, torch.tensor(1.0, dtype=DTYPE), atol=1e-5) + + # Just below upper bound + x_near_high = torch.tensor(9.9, dtype=DTYPE, requires_grad=True) + result = afgen(x_near_high) + result.backward() + assert torch.isclose(x_near_high.grad, torch.tensor(1.0, dtype=DTYPE), atol=1e-5) + + def test_1d_tensor_batch_input(self): + """Test that we can pass a 1D tensor to evaluate multiple points at once.""" + afgen = Afgen([0, 0, 10, 10]) + + # Process multiple values in a vectorized manner + x_batch = torch.tensor([2.0, 5.0, 8.0], dtype=DTYPE) + results = torch.stack([afgen(x) for x in x_batch]) + + assert results.shape == (3,) + assert torch.isclose(results[0], torch.tensor(2.0, dtype=DTYPE)) + assert torch.isclose(results[1], torch.tensor(5.0, dtype=DTYPE)) + assert torch.isclose(results[2], torch.tensor(8.0, dtype=DTYPE)) + class TestAfgenBatched: """Tests for batched Afgen functionality with multidimensional tensors.""" @@ -542,6 +592,36 @@ def test_backward_compatibility_non_batched(self): assert result.dim() == 0 # Scalar tensor assert torch.isclose(result, torch.tensor(5.0, dtype=DTYPE)) + def test_batched_gradient_at_boundaries(self): + """Test gradients at boundaries for batched tables.""" + tables = torch.tensor( + [ + [0, 0, 10, 10], + [0, 5, 10, 15], + ], + dtype=DTYPE, + ) + + afgen = Afgen(tables) + + # Test at lower boundary + x = torch.tensor(-1.0, dtype=DTYPE, requires_grad=True) + result = afgen(x) + loss = result.sum() + loss.backward() + + # Both tables clamped at lower bound, gradient should be 0 + assert x.grad == 0.0 + + # Test in interpolation region + x2 = torch.tensor(5.0, dtype=DTYPE, requires_grad=True) + result2 = afgen(x2) + loss2 = result2.sum() + loss2.backward() + + # Sum of slopes: 1 + 1 = 2 + assert torch.isclose(x2.grad, torch.tensor(2.0, dtype=DTYPE), atol=1e-5) + class TestGetDrvParam: """Tests for _get_drv function."""