diff --git a/.gitignore b/.gitignore index 1706fbf..984fdbf 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,7 @@ share/python-wheels/ # jupyter notebook .ipynb_checkpoints +docs/notebooks/test* # Unit test / coverage reports htmlcov/ diff --git a/pyproject.toml b/pyproject.toml index 51301e8..e167d54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,9 @@ Issues = "https://github.com/WUR-AI/diffwofost/issues" [tool.pytest.ini_options] testpaths = ["tests"] +filterwarnings = [ + "ignore::DeprecationWarning:pcse.base.simulationobject", +] [tool.coverage.run] diff --git a/src/diffwofost/physical_models/config.py b/src/diffwofost/physical_models/config.py index 3a36c76..ed348bb 100644 --- a/src/diffwofost/physical_models/config.py +++ b/src/diffwofost/physical_models/config.py @@ -3,11 +3,146 @@ from pathlib import Path from typing import Self import pcse +import torch from pcse.agromanager import AgroManager from pcse.base import AncillaryObject from pcse.base import SimulationObject +class ComputeConfig: + """Central configuration for device and dtype settings. + + This class provides a centralized way to control PyTorch device and dtype + settings across all simulation objects in diffWOFOST. Instead of setting + device and dtype individually for each class, use this central configuration + to apply settings globally. + + **Default Behavior:** + + - **Device**: Automatically defaults to 'cuda' if available, otherwise 'cpu' + - **Dtype**: Defaults to torch.float64 + + **Basic Usage:** + + >>> from diffwofost.physical_models.config import ComputeConfig + >>> import torch + >>> + >>> # Set device to CPU + >>> ComputeConfig.set_device('cpu') + >>> + >>> # Or use a torch.device object + >>> ComputeConfig.set_device(torch.device('cuda')) + >>> + >>> # Set dtype to float32 + >>> ComputeConfig.set_dtype(torch.float32) + >>> + >>> # Get current settings + >>> device = ComputeConfig.get_device() # Returns: torch.device('cpu') + >>> dtype = ComputeConfig.get_dtype() # Returns: torch.float32 + + **Using with Simulation Objects:** + + All simulation objects (e.g., WOFOST_Leaf_Dynamics, WOFOST_Phenology) + automatically use the settings from ComputeConfig. No changes needed to + instantiation code: + + >>> from diffwofost.physical_models.config import ComputeConfig + >>> from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics + >>> + >>> # Set global compute settings + >>> ComputeConfig.set_device('cuda') + >>> ComputeConfig.set_dtype(torch.float32) + >>> + >>> # Instantiate objects - they automatically use global settings + >>> leaf_dynamics = WOFOST_Leaf_Dynamics() + + **Switching Between Devices:** + + Useful for switching between GPU training and CPU evaluation: + + >>> # Train on GPU + >>> ComputeConfig.set_device('cuda') + >>> ComputeConfig.set_dtype(torch.float32) + >>> # ... run training ... + >>> + >>> # Evaluate on CPU + >>> ComputeConfig.set_device('cpu') + >>> ComputeConfig.set_dtype(torch.float64) + >>> # ... run evaluation ... + + **Resetting to Defaults:** + + >>> ComputeConfig.reset_to_defaults() + + """ + + _device: torch.device = None + _dtype: torch.dtype = None + + @classmethod + def _initialize_defaults(cls): + """Initialize default device and dtype if not already set.""" + if cls._device is None: + cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if cls._dtype is None: + cls._dtype = torch.float64 + + @classmethod + def get_device(cls) -> torch.device: + """Get the current device setting. + + Returns: + torch.device: The current device (cuda or cpu) + """ + cls._initialize_defaults() + return cls._device + + @classmethod + def set_device(cls, device: str | torch.device) -> None: + """Set the device to use for tensor operations. + + Args: + device (str | torch.device): Device to use ('cuda', 'cpu', or torch.device object) + + Example: + >>> ComputeConfig.set_device('cuda') + >>> ComputeConfig.set_device(torch.device('cpu')) + """ + if isinstance(device, str): + cls._device = torch.device(device) + else: + cls._device = device + + @classmethod + def get_dtype(cls) -> torch.dtype: + """Get the current dtype setting. + + Returns: + torch.dtype: The current dtype (e.g., torch.float32, torch.float64) + """ + cls._initialize_defaults() + return cls._dtype + + @classmethod + def set_dtype(cls, dtype: torch.dtype) -> None: + """Set the dtype to use for tensor creation. + + Args: + dtype (torch.dtype): PyTorch dtype (torch.float32, torch.float64, etc.) + + Example: + >>> ComputeConfig.set_dtype(torch.float32) + """ + cls._dtype = dtype + + @classmethod + def reset_to_defaults(cls) -> None: + """Reset device and dtype to their default values.""" + cls._device = None + cls._dtype = None + cls._initialize_defaults() + + @dataclass(frozen=True) class Configuration: """Class to store model configuration from a PCSE configuration files.""" diff --git a/src/diffwofost/physical_models/crop/leaf_dynamics.py b/src/diffwofost/physical_models/crop/leaf_dynamics.py index 4363670..75d68b6 100644 --- a/src/diffwofost/physical_models/crop/leaf_dynamics.py +++ b/src/diffwofost/physical_models/crop/leaf_dynamics.py @@ -12,13 +12,12 @@ from pcse.decorators import prepare_rates from pcse.decorators import prepare_states from pcse.traitlets import Any +from diffwofost.physical_models.config import ComputeConfig from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_drv from diffwofost.physical_models.utils import _get_params_shape -DTYPE = torch.float64 # Default data type for tensors in this module - class WOFOST_Leaf_Dynamics(SimulationObject): """Leaf dynamics for the WOFOST crop model. @@ -122,40 +121,118 @@ class WOFOST_Leaf_Dynamics(SimulationObject): MAX_DAYS = 365 # Maximum number of days that can be simulated in one run (i.e. array lenghts) params_shape = None # Shape of the parameters tensors + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + class Parameters(ParamTemplate): - RGRLAI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - SPAN = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - TBASE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - PERDL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) + RGRLAI = Any() + SPAN = Any() + TBASE = Any() + PERDL = Any() + TDWI = Any() SLATB = AfgenTrait() KDIFTB = AfgenTrait() + def __init__(self, parvalues): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.RGRLAI = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.SPAN = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.TBASE = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.PERDL = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.TDWI = [torch.tensor(-99.0, dtype=dtype, device=device)] + + # Call parent init + super().__init__(parvalues) + class StateVariables(StatesTemplate): - LV = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - SLA = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - LVAGE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - LAIEM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) - LASUM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) - LAIEXP = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) - LAIMAX = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) - LAI = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) - WLV = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) - DWLV = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) - TWLV = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) + LV = Any() + SLA = Any() + LVAGE = Any() + LAIEM = Any() + LASUM = Any() + LAIEXP = Any() + LAIMAX = Any() + LAI = Any() + WLV = Any() + DWLV = Any() + TWLV = Any() + + def __init__(self, kiosk, publish=None, **kwargs): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + if "LV" not in kwargs: + self.LV = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "SLA" not in kwargs: + self.SLA = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "LVAGE" not in kwargs: + self.LVAGE = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "LAIEM" not in kwargs: + self.LAIEM = torch.tensor(-99.0, dtype=dtype, device=device) + if "LASUM" not in kwargs: + self.LASUM = torch.tensor(-99.0, dtype=dtype, device=device) + if "LAIEXP" not in kwargs: + self.LAIEXP = torch.tensor(-99.0, dtype=dtype, device=device) + if "LAIMAX" not in kwargs: + self.LAIMAX = torch.tensor(-99.0, dtype=dtype, device=device) + if "LAI" not in kwargs: + self.LAI = torch.tensor(-99.0, dtype=dtype, device=device) + if "WLV" not in kwargs: + self.WLV = torch.tensor(-99.0, dtype=dtype, device=device) + if "DWLV" not in kwargs: + self.DWLV = torch.tensor(-99.0, dtype=dtype, device=device) + if "TWLV" not in kwargs: + self.TWLV = torch.tensor(-99.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish, **kwargs) class RateVariables(RatesTemplate): - GRLV = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - DSLV1 = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - DSLV2 = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - DSLV3 = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - DSLV = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - DALV = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - DRLV = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - SLAT = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - FYSAGE = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - GLAIEX = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - GLASOL = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) + GRLV = Any() + DSLV1 = Any() + DSLV2 = Any() + DSLV3 = Any() + DSLV = Any() + DALV = Any() + DRLV = Any() + SLAT = Any() + FYSAGE = Any() + GLAIEX = Any() + GLASOL = Any() + + def __init__(self, kiosk): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.GRLV = torch.tensor(0.0, dtype=dtype, device=device) + self.DSLV1 = torch.tensor(0.0, dtype=dtype, device=device) + self.DSLV2 = torch.tensor(0.0, dtype=dtype, device=device) + self.DSLV3 = torch.tensor(0.0, dtype=dtype, device=device) + self.DSLV = torch.tensor(0.0, dtype=dtype, device=device) + self.DALV = torch.tensor(0.0, dtype=dtype, device=device) + self.DRLV = torch.tensor(0.0, dtype=dtype, device=device) + self.SLAT = torch.tensor(0.0, dtype=dtype, device=device) + self.FYSAGE = torch.tensor(0.0, dtype=dtype, device=device) + self.GLAIEX = torch.tensor(0.0, dtype=dtype, device=device) + self.GLASOL = torch.tensor(0.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(kiosk) def initialize( self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider @@ -177,28 +254,38 @@ def initialize( self.params = self.Parameters(parvalues) self.rates = self.RateVariables(kiosk) + # Create scalar constants once at the beginning to avoid recreating them + self._zero = torch.tensor(0.0, dtype=self.dtype, device=self.device) + self._epsilon = torch.tensor(1e-12, dtype=self.dtype, device=self.device) + self._sigmoid_sharpness = torch.tensor(1e-16, dtype=self.dtype, device=self.device) + self._sigmoid_epsilon = torch.tensor(1e-14, dtype=self.dtype, device=self.device) + # CALCULATE INITIAL STATE VARIABLES # check for required external variables _exist_required_external_variables(self.kiosk) # TODO check if external variables are already torch tensors - FL = self.kiosk["FL"] - FR = self.kiosk["FR"] - DVS = self.kiosk["DVS"] + # Get kiosk values and ensure they are on the correct device + FL = torch.as_tensor(self.kiosk["FL"], dtype=self.dtype, device=self.device) + FR = torch.as_tensor(self.kiosk["FR"], dtype=self.dtype, device=self.device) + DVS = torch.as_tensor(self.kiosk["DVS"], dtype=self.dtype, device=self.device) params = self.params self.params_shape = _get_params_shape(params) # Initial leaf biomass - WLV = (params.TDWI * (1 - FR)) * FL - DWLV = torch.zeros(self.params_shape, dtype=DTYPE) + TDWI = _broadcast_to(params.TDWI, self.params_shape, dtype=self.dtype, device=self.device) + WLV = (TDWI * (1 - FR)) * FL + DWLV = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) TWLV = WLV + DWLV # Initialize leaf classes (SLA, age and weight) - SLA = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) - LVAGE = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) - LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=DTYPE) - SLA[..., 0] = params.SLATB(DVS) + SLA = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device) + LVAGE = torch.zeros( + (*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device + ) + LV = torch.zeros((*self.params_shape, self.MAX_DAYS), dtype=self.dtype, device=self.device) + SLA[..., 0] = params.SLATB(DVS).to(dtype=self.dtype, device=self.device) LV[..., 0] = WLV # Initial values for leaf area @@ -206,7 +293,9 @@ def initialize( LASUM = LAIEM LAIEXP = LAIEM LAIMAX = LAIEM - LAI = LASUM + self.kiosk["SAI"] + self.kiosk["PAI"] + SAI = torch.as_tensor(self.kiosk["SAI"], dtype=self.dtype, device=self.device) + PAI = torch.as_tensor(self.kiosk["PAI"], dtype=self.dtype, device=self.device) + LAI = LASUM + SAI + PAI # Initialize StateVariables object self.states = self.StateVariables( @@ -249,19 +338,24 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask # A mask (0 if DVS < 0, 1 if DVS >= 0) - DVS = torch.as_tensor(k["DVS"], dtype=DTYPE) - dvs_mask = (DVS >= 0).to(dtype=DTYPE) + DVS = torch.as_tensor(k["DVS"], dtype=self.dtype, device=self.device) + dvs_mask = (DVS >= 0).to(dtype=self.dtype).to(device=self.device) # Growth rate leaves # weight of new leaves r.GRLV = dvs_mask * k.ADMI * k.FL # death of leaves due to water/oxygen stress - r.DSLV1 = dvs_mask * s.WLV * (1.0 - k.RFTRA) * p.PERDL + RFTRA = _broadcast_to(k.RFTRA, self.params_shape, dtype=self.dtype, device=self.device) + PERDL = _broadcast_to(p.PERDL, self.params_shape, dtype=self.dtype, device=self.device) + r.DSLV1 = dvs_mask * s.WLV * (1.0 - RFTRA) * PERDL # death due to self shading cause by high LAI - DVS = self.kiosk["DVS"] - LAICR = 3.2 / p.KDIFTB(DVS) + DVS = _broadcast_to( + self.kiosk["DVS"], self.params_shape, dtype=self.dtype, device=self.device + ) + KDIFTB = p.KDIFTB.to(device=self.device, dtype=self.dtype) + LAICR = 3.2 / KDIFTB(DVS) r.DSLV2 = dvs_mask * s.WLV * torch.clamp(0.03 * (s.LAI - LAICR) / LAICR, 0.0, 0.03) # Death of leaves due to frost damage as determined by @@ -269,7 +363,7 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: if "RF_FROST" in self.kiosk: r.DSLV3 = s.WLV * k.RF_FROST else: - r.DSLV3 = torch.zeros_like(s.WLV, dtype=DTYPE) + r.DSLV3 = torch.zeros_like(s.WLV, dtype=self.dtype) r.DSLV3 = dvs_mask * r.DSLV3 @@ -282,7 +376,9 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # in DALV. # Note that the actual leaf death is imposed on the array LV during the # state integration step. - tSPAN = _broadcast_to(p.SPAN, s.LVAGE.shape) # Broadcast to same shape + tSPAN = _broadcast_to( + p.SPAN, s.LVAGE.shape, dtype=self.dtype, device=self.device + ) # Broadcast to same shape # Using a sigmoid here instead of a conditional statement on the value of # SPAN because the latter would not allow for the gradient to be tracked. @@ -292,14 +388,13 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: if p.SPAN.requires_grad: # 1e-16 is chosen empirically for cases when s.LVAGE - tSPAN is very # small and mask should be 1 - sharpness = torch.tensor(1e-16, dtype=DTYPE) - # 1e-14 is chosen empirically for cases when s.LVAGE - tSPAN is # equal to zero and mask should be 0.0 - epsilon = 1e-14 - span_mask = torch.sigmoid((s.LVAGE - tSPAN - epsilon) / sharpness).to(dtype=DTYPE) + span_mask = torch.sigmoid( + (s.LVAGE - tSPAN - self._sigmoid_epsilon) / self._sigmoid_sharpness + ).to(dtype=self.dtype) else: - span_mask = (s.LVAGE > tSPAN).to(dtype=DTYPE) + span_mask = (s.LVAGE > tSPAN).to(dtype=self.dtype) r.DALV = torch.sum(span_mask * s.LV, dim=-1) r.DALV = dvs_mask * r.DALV @@ -308,15 +403,16 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: r.DRLV = torch.maximum(r.DSLV, r.DALV) # Get the temperature from the drv - TEMP = _get_drv(drv.TEMP, self.params_shape) + TEMP = _get_drv(drv.TEMP, self.params_shape, self.dtype, self.device) # physiologic ageing of leaves per time step - TBASE = _broadcast_to(p.TBASE, self.params_shape) + TBASE = _broadcast_to(p.TBASE, self.params_shape, dtype=self.dtype, device=self.device) FYSAGE = (TEMP - TBASE) / (35.0 - TBASE) r.FYSAGE = dvs_mask * torch.clamp(FYSAGE, 0.0) # specific leaf area of leaves per time step - r.SLAT = dvs_mask * p.SLATB(DVS) + SLATB = p.SLATB.to(device=self.device, dtype=self.dtype) + r.SLAT = dvs_mask * SLATB(DVS) # leaf area not to exceed exponential growth curve is_lai_exp = s.LAIEXP < 6.0 @@ -326,28 +422,30 @@ def calc_rates(self, day: datetime.date, drv: WeatherDataContainer) -> None: # tracked through the condition. Thus, the gradient with respect to # parameters that contribute to `is_lai_exp` (e.g. RGRLAI and TBASE) # are expected to be incorrect. + RGRLAI = _broadcast_to(p.RGRLAI, self.params_shape, dtype=self.dtype, device=self.device) r.GLAIEX = torch.where( dvs_mask.bool(), - torch.where(is_lai_exp, s.LAIEXP * p.RGRLAI * DTEFF, r.GLAIEX), - torch.tensor(0.0, dtype=DTYPE), + torch.where(is_lai_exp, s.LAIEXP * RGRLAI * DTEFF, r.GLAIEX), + self._zero, ) # source-limited increase in leaf area r.GLASOL = torch.where( dvs_mask.bool(), torch.where(is_lai_exp, r.GRLV * r.SLAT, r.GLASOL), - torch.tensor(0.0, dtype=DTYPE), + self._zero, ) # sink-limited increase in leaf area GLA = torch.minimum(r.GLAIEX, r.GLASOL) # adjustment of specific leaf area of youngest leaf class - epsilon = 1e-10 # small value to avoid division by zero r.SLAT = torch.where( dvs_mask.bool(), - torch.where(is_lai_exp & (r.GRLV > epsilon), GLA / (r.GRLV + epsilon), r.SLAT), - torch.tensor(0.0, dtype=DTYPE), + torch.where( + is_lai_exp & (r.GRLV > self._epsilon), GLA / (r.GRLV + self._epsilon), r.SLAT + ), + self._zero, ) @prepare_states @@ -366,7 +464,7 @@ def integrate(self, day: datetime.date, delt=1.0) -> None: tLV = states.LV.clone() tSLA = states.SLA.clone() tLVAGE = states.LVAGE.clone() - tDRLV = _broadcast_to(rates.DRLV, tLV.shape) + tDRLV = _broadcast_to(rates.DRLV, tLV.shape, dtype=self.dtype, device=self.device) # Leaf death is imposed on leaves from the oldest ones. # Calculate the cumulative sum of weights after leaf death, and @@ -377,7 +475,9 @@ def integrate(self, day: datetime.date, delt=1.0) -> None: # Adjust value of oldest leaf class, i.e. the first non-zero # weight along the time axis (the last dimension). # Cast argument to int because torch.argmax requires it to be numeric - idx_oldest = torch.argmax(is_alive.type(torch.int), dim=-1, keepdim=True) + idx_oldest = torch.argmax(is_alive.type(torch.int), dim=-1, keepdim=True).to( + device=self.device + ) new_biomass = torch.take_along_dim(weight_cumsum, indices=idx_oldest, dim=-1) tLV = torch.scatter(tLV, dim=-1, index=idx_oldest, src=new_biomass) diff --git a/src/diffwofost/physical_models/crop/phenology.py b/src/diffwofost/physical_models/crop/phenology.py index db7fb2c..bc30b64 100644 --- a/src/diffwofost/physical_models/crop/phenology.py +++ b/src/diffwofost/physical_models/crop/phenology.py @@ -20,6 +20,7 @@ from pcse.traitlets import Enum from pcse.traitlets import Instance from pcse.util import daylength +from diffwofost.physical_models.config import ComputeConfig from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_drv @@ -27,9 +28,6 @@ from diffwofost.physical_models.utils import _restore_state from diffwofost.physical_models.utils import _snapshot_state -DTYPE = torch.float64 # Default data type for tensors in this module -EPS = torch.tensor(1e-8, dtype=DTYPE) # Small epsilon to avoid div by zero - class Vernalisation(SimulationObject): """Modification of phenological development due to vernalisation. @@ -94,31 +92,72 @@ class Vernalisation(SimulationObject): params_shape = None # Shape of the parameters tensors + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + class Parameters(ParamTemplate): - VERNSAT = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Saturated vernalisation requirements - VERNBASE = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Base vernalisation requirements - VERNRTB = AfgenTrait() # Vernalisation temperature response - VERNDVS = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Critical DVS for vernalisation fulfillment + VERNSAT = Any() + VERNBASE = Any() + VERNRTB = AfgenTrait() + VERNDVS = Any() + + def __init__(self, parvalues): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values using the ComputeConfig dtype and device + self.VERNSAT = torch.tensor(-99.0, dtype=dtype, device=device) + self.VERNBASE = torch.tensor(-99.0, dtype=dtype, device=device) + self.VERNDVS = torch.tensor(-99.0, dtype=dtype, device=device) + self.VERNRTB = self.VERNRTB.to(device=device, dtype=dtype) + + # Call parent init + super().__init__(parvalues) class RateVariables(RatesTemplate): - VERNR = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) # Rate of vernalisation - VERNFAC = Any( - default_value=torch.tensor(0.0, dtype=DTYPE) - ) # Red. factor for phenol. devel. + VERNR = Any() + VERNFAC = Any() + + def __init__(self, kiosk, publish=None): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values using the ComputeConfig dtype and device + self.VERNR = torch.tensor(0.0, dtype=dtype, device=device) + self.VERNFAC = torch.tensor(0.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish) class StateVariables(StatesTemplate): - VERN = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Vernalisation state - DOV = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Day ordinal when vernalisation fulfilled - ISVERNALISED = Any(default_value=torch.tensor(False)) # True when VERNSAT is reached and - # Forced when DVS > VERNDVS + VERN = Any() + DOV = Any() + ISVERNALISED = Any() + + def __init__(self, kiosk, publish=None, **kwargs): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values using the ComputeConfig dtype and device if not in kwargs + if "VERN" not in kwargs: + self.VERN = torch.tensor(-99.0, dtype=dtype, device=device) + if "DOV" not in kwargs: + self.DOV = torch.tensor(-99.0, dtype=dtype, device=device) + if "ISVERNALISED" not in kwargs: + self.ISVERNALISED = torch.tensor(False, dtype=torch.bool, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish, **kwargs) def initialize(self, day, kiosk, parvalues, dvs_shape=None): """Initialize the Vernalisation sub-module. @@ -142,6 +181,9 @@ def initialize(self, day, kiosk, parvalues, dvs_shape=None): """ self.params = self.Parameters(parvalues) self.params_shape = _get_params_shape(self.params) + + # Small epsilon tensor reused in multiple safe divisions. + self._epsilon = torch.tensor(1e-8, dtype=self.dtype, device=self.device) if dvs_shape is not None: if self.params_shape == (): self.params_shape = dvs_shape @@ -150,27 +192,46 @@ def initialize(self, day, kiosk, parvalues, dvs_shape=None): f"Vernalisation params shape {self.params_shape}" + " incompatible with dvs_shape {dvs_shape}" ) + + # Common constant tensors (same shape/dtype/device as this module). + self._ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) + self._zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) # Explicitly initialize rates self.rates = self.RateVariables(kiosk, publish=["VERNFAC"]) - self.rates.VERNR = _broadcast_to(self.rates.VERNR, self.params_shape) - self.rates.VERNFAC = _broadcast_to(self.rates.VERNFAC, self.params_shape) + self.rates.VERNR = _broadcast_to( + self.rates.VERNR, self.params_shape, dtype=self.dtype, device=self.device + ) + self.rates.VERNFAC = _broadcast_to( + self.rates.VERNFAC, self.params_shape, dtype=self.dtype, device=self.device + ) self.kiosk = kiosk # Explicitly broadcast all parameters to params_shape - self.params.VERNSAT = _broadcast_to(self.params.VERNSAT, self.params_shape) - self.params.VERNBASE = _broadcast_to(self.params.VERNBASE, self.params_shape) - self.params.VERNDVS = _broadcast_to(self.params.VERNDVS, self.params_shape) + self.params.VERNSAT = _broadcast_to( + self.params.VERNSAT, self.params_shape, dtype=self.dtype, device=self.device + ) + self.params.VERNBASE = _broadcast_to( + self.params.VERNBASE, self.params_shape, dtype=self.dtype, device=self.device + ) + self.params.VERNDVS = _broadcast_to( + self.params.VERNDVS, self.params_shape, dtype=self.dtype, device=self.device + ) + self.params.VERNRTB = self.params.VERNRTB.to(device=self.device, dtype=self.dtype) # Define initial states self.states = self.StateVariables( kiosk, - VERN=torch.zeros(self.params_shape, dtype=DTYPE), - DOV=torch.full(self.params_shape, -1.0, dtype=DTYPE), # -1 indicates not yet fulfilled - ISVERNALISED=torch.zeros(self.params_shape, dtype=torch.bool), + VERN=torch.zeros(self.params_shape, dtype=self.dtype, device=self.device), + DOV=torch.full( + self.params_shape, -1.0, dtype=self.dtype, device=self.device + ), # -1 indicates not yet fulfilled + ISVERNALISED=torch.zeros(self.params_shape, dtype=torch.bool, device=self.device), publish=["ISVERNALISED"], ) # Per-element force flag (False for all elements initially) - self._force_vernalisation = torch.zeros(self.params_shape, dtype=torch.bool) + self._force_vernalisation = torch.zeros( + self.params_shape, dtype=torch.bool, device=self.device + ) @prepare_rates def calc_rates(self, day, drv): @@ -192,7 +253,7 @@ def calc_rates(self, day, drv): VERNBASE = params.VERNBASE DVS = self.kiosk["DVS"] - TEMP = _get_drv(drv.TEMP, self.params_shape) + TEMP = _get_drv(drv.TEMP, self.params_shape, self.dtype, self.device) # Operate elementwise only on elements not yet vernalised not_vernalised = ~self.states.ISVERNALISED @@ -201,16 +262,20 @@ def calc_rates(self, day, drv): # VERNR only for vegetative elements self.rates.VERNR = torch.where( - vegetative_mask, params.VERNRTB(TEMP), torch.zeros(self.params_shape, dtype=DTYPE) + vegetative_mask, + params.VERNRTB(TEMP), + self._zeros, ) # compute VERNFAC from current VERN for vegetative elements; others = 1 safe_den = VERNSAT - VERNBASE - safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), self._epsilon) r = (self.states.VERN - VERNBASE) / safe_den vernfac_computed = torch.clamp(r, 0.0, 1.0) self.rates.VERNFAC = torch.where( - vegetative_mask, vernfac_computed, torch.ones(self.params_shape, dtype=DTYPE) + vegetative_mask, + vernfac_computed, + self._ones, ) # mark per-element force flags for elements that passed VERNDVS but aren't vernalised @@ -254,7 +319,9 @@ def integrate(self, day, delt=1.0): if torch.any(newly_reached_and_no_dov): states.DOV = torch.where( newly_reached_and_no_dov, - torch.full(self.params_shape, day.toordinal(), dtype=DTYPE), + torch.full( + self.params_shape, day.toordinal(), dtype=self.dtype, device=self.device + ), states.DOV, ) self.logger.info(f"Vernalization requirements reached at day {day}.") @@ -357,55 +424,137 @@ class DVS_Phenology(SimulationObject): params_shape = None # Shape of the parameters tensors + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + class Parameters(ParamTemplate): - TSUMEM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Temp. sum for emergence - TBASEM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Base temp. for emergence - TEFFMX = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Max eff temperature for emergence - TSUM1 = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Temperature sum emergence to anthesis - TSUM2 = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Temperature sum anthesis to maturity - IDSL = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Switch for photoperiod (1) and vernalisation (2) - DLO = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Optimal day length for phenol. development - DLC = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Critical day length for phenol. development - DVSI = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Initial development stage - DVSEND = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Final development stage - DTSMTB = AfgenTrait() # Temperature response function for phenol. - # development. + TSUMEM = Any() + TBASEM = Any() + TEFFMX = Any() + TSUM1 = Any() + TSUM2 = Any() + IDSL = Any() + DLO = Any() + DLC = Any() + DVSI = Any() + DVSEND = Any() + DTSMTB = AfgenTrait() CROP_START_TYPE = Enum(["sowing", "emergence"]) CROP_END_TYPE = Enum(["maturity", "harvest", "earliest"]) + def __init__(self, parvalues): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values using the ComputeConfig dtype and device + self.TSUMEM = torch.tensor(-99.0, dtype=dtype, device=device) + self.TBASEM = torch.tensor(-99.0, dtype=dtype, device=device) + self.TEFFMX = torch.tensor(-99.0, dtype=dtype, device=device) + self.TSUM1 = torch.tensor(-99.0, dtype=dtype, device=device) + self.TSUM2 = torch.tensor(-99.0, dtype=dtype, device=device) + self.IDSL = torch.tensor(-99.0, dtype=dtype, device=device) + self.DLO = torch.tensor(-99.0, dtype=dtype, device=device) + self.DLC = torch.tensor(-99.0, dtype=dtype, device=device) + self.DVSI = torch.tensor(-99.0, dtype=dtype, device=device) + self.DVSEND = torch.tensor(-99.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(parvalues) + class RateVariables(RatesTemplate): - DTSUME = Any( - default_value=torch.tensor(0.0, dtype=DTYPE) - ) # increase in temperature sum for emergence - DTSUM = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) # increase in temperature sum - DVR = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) # development rate + DTSUME = Any() + DTSUM = Any() + DVR = Any() + + def __init__(self, kiosk, publish=None): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.DTSUME = torch.tensor(0.0, dtype=dtype, device=device) + self.DTSUM = torch.tensor(0.0, dtype=dtype, device=device) + self.DVR = torch.tensor(0.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish) class StateVariables(StatesTemplate): - DVS = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Development stage - TSUM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Temperature sum state - TSUME = Any( - default_value=torch.tensor(-99.0, dtype=DTYPE) - ) # Temperature sum for emergence state - # States which register phenological events as day ordinals (tensor of floats) - DOS = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of sowing (ordinal) - DOE = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of emergence (ordinal) - DOA = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of anthesis (ordinal) - DOM = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of maturity (ordinal) - DOH = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) # Day of harvest (ordinal) - # STAGE as integer tensor: 0=emerging, 1=vegetative, 2=reproductive, 3=mature - STAGE = Any(default_value=torch.tensor(-99, dtype=torch.long)) + DVS = Any() + TSUM = Any() + TSUME = Any() + DOS = Any() + DOE = Any() + DOA = Any() + DOM = Any() + DOH = Any() + STAGE = Any() + + def __init__(self, kiosk, publish=None, **kwargs): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + if "DVS" not in kwargs: + self.DVS = torch.tensor(-99.0, dtype=dtype, device=device) + if "TSUM" not in kwargs: + self.TSUM = torch.tensor(-99.0, dtype=dtype, device=device) + if "TSUME" not in kwargs: + self.TSUME = torch.tensor(-99.0, dtype=dtype, device=device) + if "DOS" not in kwargs: + self.DOS = torch.tensor(-99.0, dtype=dtype, device=device) + if "DOE" not in kwargs: + self.DOE = torch.tensor(-99.0, dtype=dtype, device=device) + if "DOA" not in kwargs: + self.DOA = torch.tensor(-99.0, dtype=dtype, device=device) + if "DOM" not in kwargs: + self.DOM = torch.tensor(-99.0, dtype=dtype, device=device) + if "DOH" not in kwargs: + self.DOH = torch.tensor(-99.0, dtype=dtype, device=device) + if "STAGE" not in kwargs: + self.STAGE = torch.tensor(-99, dtype=torch.long, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish, **kwargs) + + def _cast_and_broadcast_params(self): + """Cast and broadcast all parameters to params_shape with correct dtype/device. + + This ensures all parameters have consistent shape, dtype, and device. + Necessary if Vernalisation changes the params_shape during initialization. + """ + p = self.params + # Broadcast numeric parameters to the final params_shape and ensure dtype/device. + for name in ( + "TSUMEM", + "TBASEM", + "TEFFMX", + "TSUM1", + "TSUM2", + "IDSL", + "DLO", + "DLC", + "DVSI", + "DVSEND", + ): + setattr( + p, + name, + _broadcast_to(getattr(p, name), self.params_shape, self.dtype, self.device), + ) + + # Move AFGEN table buffers, if present. + if hasattr(p, "DTSMTB") and hasattr(p.DTSMTB, "to"): + p.DTSMTB.to(device=self.device, dtype=self.dtype) def initialize(self, day, kiosk, parvalues): """:param day: start date of the simulation @@ -419,7 +568,10 @@ def initialize(self, day, kiosk, parvalues): # Initialize vernalisation for IDSL>=2 # It has to be done in advance to get the correct params_shape - IDSL = _broadcast_to(self.params.IDSL, self.params_shape) + IDSL = _broadcast_to( + self.params.IDSL, self.params_shape, dtype=self.dtype, device=self.device + ) + self.params.IDSL = IDSL if torch.any(IDSL >= 2): if self.params_shape != (): self.vernalisation = Vernalisation( @@ -432,6 +584,14 @@ def initialize(self, day, kiosk, parvalues): else: self.vernalisation = None + # After Vernalisation initialization the final params_shape may have changed. + self._cast_and_broadcast_params() + + # Create scalar constants once at the beginning to avoid recreating them + self._ones = torch.ones(self.params_shape, dtype=self.dtype, device=self.device) + self._zeros = torch.zeros(self.params_shape, dtype=self.dtype, device=self.device) + self._epsilon = torch.tensor(1e-8, dtype=self.dtype, device=self.device) + # Initialize rates and kiosk self.rates = self.RateVariables(kiosk) self.kiosk = kiosk @@ -440,19 +600,23 @@ def initialize(self, day, kiosk, parvalues): # Define initial states DVS, DOS, DOE, STAGE = self._get_initial_stage(day) - DVS = _broadcast_to(DVS, self.params_shape) + DVS = _broadcast_to(DVS, self.params_shape, dtype=self.dtype, device=self.device) # Initialize all date tensors with -1 (not yet occurred) - DOS = _broadcast_to(DOS, self.params_shape) - DOE = _broadcast_to(DOE, self.params_shape) - DOA = torch.full(self.params_shape, -1.0, dtype=DTYPE) - DOM = torch.full(self.params_shape, -1.0, dtype=DTYPE) - DOH = torch.full(self.params_shape, -1.0, dtype=DTYPE) - STAGE = _broadcast_to(STAGE, self.params_shape) + DOS = _broadcast_to(DOS, self.params_shape, dtype=self.dtype, device=self.device) + DOE = _broadcast_to(DOE, self.params_shape, dtype=self.dtype, device=self.device) + DOA = torch.full(self.params_shape, -1.0, dtype=self.dtype, device=self.device) + DOM = torch.full(self.params_shape, -1.0, dtype=self.dtype, device=self.device) + DOH = torch.full(self.params_shape, -1.0, dtype=self.dtype, device=self.device) + STAGE = _broadcast_to(STAGE, self.params_shape, dtype=self.dtype, device=self.device) # Also ensure TSUM and TSUME are properly shaped - TSUM = torch.zeros(self.params_shape, dtype=DTYPE, requires_grad=True) - TSUME = torch.zeros(self.params_shape, dtype=DTYPE, requires_grad=True) + TSUM = torch.zeros( + self.params_shape, dtype=self.dtype, device=self.device, requires_grad=True + ) + TSUME = torch.zeros( + self.params_shape, dtype=self.dtype, device=self.device, requires_grad=True + ) self.states = self.StateVariables( kiosk, @@ -483,26 +647,26 @@ def _get_initial_stage(self, day): STAGE (Tensor): Integer stage code (0=emerging, 1=vegetative). """ p = self.params - day_ordinal = torch.tensor(day.toordinal(), dtype=DTYPE) + day_ordinal = torch.tensor(day.toordinal(), dtype=self.dtype, device=self.device) # Define initial stage type (emergence/sowing) and fill the # respective day of sowing/emergence (DOS/DOE) if p.CROP_START_TYPE == "emergence": - STAGE = torch.tensor(1, dtype=torch.long) # 1 = vegetative + STAGE = torch.tensor(1, dtype=torch.long, device=self.device) # 1 = vegetative DOE = day_ordinal - DOS = torch.tensor(-1.0, dtype=DTYPE) # Not applicable + DOS = torch.tensor(-1.0, dtype=self.dtype, device=self.device) # Not applicable DVS = p.DVSI if not isinstance(DVS, torch.Tensor): - DVS = torch.tensor(DVS, dtype=DTYPE) + DVS = torch.tensor(DVS, dtype=self.dtype, device=self.device) # send signal to indicate crop emergence self._send_signal(signals.crop_emerged) elif p.CROP_START_TYPE == "sowing": - STAGE = torch.tensor(0, dtype=torch.long) # 0 = emerging + STAGE = torch.tensor(0, dtype=torch.long, device=self.device) # 0 = emerging DOS = day_ordinal - DOE = torch.tensor(-1.0, dtype=DTYPE) # Not yet occurred - DVS = torch.tensor(-0.1, dtype=DTYPE) + DOE = torch.tensor(-1.0, dtype=self.dtype, device=self.device) # Not yet occurred + DVS = torch.tensor(-0.1, dtype=self.dtype, device=self.device) else: msg = f"Unknown start type: {p.CROP_START_TYPE}" @@ -541,30 +705,32 @@ def calc_rates(self, day, drv): # Day length sensitivity DAYLP = daylength(day, drv.LAT) - DAYLP_t = _broadcast_to(DAYLP, shape) + DAYLP_t = _broadcast_to(DAYLP, shape, dtype=self.dtype, device=self.device) # Compute DVRED conditionally based on IDSL >= 1 safe_den = p.DLO - p.DLC - safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), self._epsilon) dvred_active = torch.clamp((DAYLP_t - p.DLC) / safe_den, 0.0, 1.0) - DVRED = torch.where(p.IDSL >= 1, dvred_active, torch.ones(shape, dtype=DTYPE)) + DVRED = torch.where(p.IDSL >= 1, dvred_active, self._ones) # Vernalisation factor - always compute if module exists - VERNFAC = torch.ones(shape, dtype=DTYPE) + VERNFAC = self._ones if hasattr(self, "vernalisation") and self.vernalisation is not None: # Always call calc_rates (it handles stage internally now) self.vernalisation.calc_rates(day, drv) # Apply vernalisation only where IDSL >= 2 AND in vegetative stage is_vegetative = s.STAGE == 1 VERNFAC = torch.where( - (p.IDSL >= 2) & is_vegetative, self.kiosk["VERNFAC"], torch.ones(shape, dtype=DTYPE) + (p.IDSL >= 2) & is_vegetative, + self.kiosk["VERNFAC"], + self._ones, ) - TEMP = _get_drv(drv.TEMP, shape) + TEMP = _get_drv(drv.TEMP, shape, self.dtype, self.device) # Initialize all rate variables - r.DTSUME = torch.zeros(shape, dtype=DTYPE) - r.DTSUM = torch.zeros(shape, dtype=DTYPE) - r.DVR = torch.zeros(shape, dtype=DTYPE) + r.DTSUME = self._zeros + r.DTSUM = self._zeros + r.DVR = self._zeros # Compute rates for emerging stage (STAGE == 0) is_emerging = s.STAGE == 0 @@ -575,7 +741,7 @@ def calc_rates(self, day, drv): dtsume_emerging = torch.clamp(temp_diff, min=0.0) dtsume_emerging = torch.minimum(dtsume_emerging, max_diff) safe_den = p.TSUMEM - safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), self._epsilon) dvr_emerging = 0.1 * dtsume_emerging / safe_den r.DTSUME = torch.where(is_emerging, dtsume_emerging, r.DTSUME) @@ -586,7 +752,7 @@ def calc_rates(self, day, drv): if torch.any(is_vegetative): dtsum_vegetative = p.DTSMTB(TEMP) * VERNFAC * DVRED safe_den = p.TSUM1 - safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), self._epsilon) dvr_vegetative = dtsum_vegetative / safe_den r.DTSUM = torch.where(is_vegetative, dtsum_vegetative, r.DTSUM) @@ -597,7 +763,7 @@ def calc_rates(self, day, drv): if torch.any(is_reproductive): dtsum_reproductive = p.DTSMTB(TEMP) safe_den = p.TSUM2 - safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), EPS) + safe_den = safe_den.sign() * torch.maximum(torch.abs(safe_den), self._epsilon) dvr_reproductive = dtsum_reproductive / safe_den r.DTSUM = torch.where(is_reproductive, dtsum_reproductive, r.DTSUM) @@ -673,13 +839,19 @@ def integrate(self, day, delt=1.0): s.DVS = s.DVS + r.DVR s.TSUM = s.TSUM + r.DTSUM - day_ordinal = torch.tensor(day.toordinal(), dtype=DTYPE) + day_ordinal = torch.tensor(day.toordinal(), dtype=self.dtype, device=self.device) # Check transitions for emerging -> vegetative (STAGE 0 -> 1) is_emerging = s.STAGE == 0 should_emerge = is_emerging & (s.DVS >= 0.0) - s.STAGE = torch.where(should_emerge, torch.ones(shape, dtype=torch.long), s.STAGE) - s.DOE = torch.where(should_emerge, torch.full(shape, day_ordinal, dtype=DTYPE), s.DOE) + s.STAGE = torch.where( + should_emerge, torch.ones(shape, dtype=torch.long, device=self.device), s.STAGE + ) + s.DOE = torch.where( + should_emerge, + torch.full(shape, day_ordinal, dtype=self.dtype, device=self.device), + s.DOE, + ) s.DVS = torch.where(should_emerge, torch.clamp(s.DVS, max=0.0), s.DVS) # Send signal if any crop emerged (only once per day) @@ -689,15 +861,27 @@ def integrate(self, day, delt=1.0): # Check transitions for vegetative -> reproductive (STAGE 1 -> 2) is_vegetative = s.STAGE == 1 should_flower = is_vegetative & (s.DVS >= 1.0) - s.STAGE = torch.where(should_flower, torch.full(shape, 2, dtype=torch.long), s.STAGE) - s.DOA = torch.where(should_flower, torch.full(shape, day_ordinal, dtype=DTYPE), s.DOA) + s.STAGE = torch.where( + should_flower, torch.full(shape, 2, dtype=torch.long, device=self.device), s.STAGE + ) + s.DOA = torch.where( + should_flower, + torch.full(shape, day_ordinal, dtype=self.dtype, device=self.device), + s.DOA, + ) s.DVS = torch.where(should_flower, torch.clamp(s.DVS, max=1.0), s.DVS) # Check transitions for reproductive -> mature (STAGE 2 -> 3) is_reproductive = s.STAGE == 2 should_mature = is_reproductive & (s.DVS >= p.DVSEND) - s.STAGE = torch.where(should_mature, torch.full(shape, 3, dtype=torch.long), s.STAGE) - s.DOM = torch.where(should_mature, torch.full(shape, day_ordinal, dtype=DTYPE), s.DOM) + s.STAGE = torch.where( + should_mature, torch.full(shape, 3, dtype=torch.long, device=self.device), s.STAGE + ) + s.DOM = torch.where( + should_mature, + torch.full(shape, day_ordinal, dtype=self.dtype, device=self.device), + s.DOM, + ) s.DVS = torch.where(should_mature, torch.minimum(s.DVS, p.DVSEND), s.DVS) # Send crop_finish signal if maturity reached for one. @@ -730,8 +914,10 @@ def _on_CROP_FINISH(self, day, finish_type=None): """ if finish_type in ["harvest", "earliest"]: - day_ordinal = torch.tensor(day.toordinal(), dtype=DTYPE) - self._for_finalize["DOH"] = torch.full(self.params_shape, day_ordinal, dtype=DTYPE) + day_ordinal = torch.tensor(day.toordinal(), dtype=self.dtype, device=self.device) + self._for_finalize["DOH"] = torch.full( + self.params_shape, day_ordinal, dtype=self.dtype, device=self.device + ) def get_variable(self, varname): # TODO: should be removed while fixing #49. this is needed because diff --git a/src/diffwofost/physical_models/crop/root_dynamics.py b/src/diffwofost/physical_models/crop/root_dynamics.py index 6143609..b38604f 100644 --- a/src/diffwofost/physical_models/crop/root_dynamics.py +++ b/src/diffwofost/physical_models/crop/root_dynamics.py @@ -1,4 +1,4 @@ -from datetime import datetime +import datetime import torch from pcse.base import ParamTemplate from pcse.base import RatesTemplate @@ -10,12 +10,11 @@ from pcse.decorators import prepare_rates from pcse.decorators import prepare_states from pcse.traitlets import Any +from diffwofost.physical_models.config import ComputeConfig from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import _broadcast_to from diffwofost.physical_models.utils import _get_params_shape -DTYPE = torch.float64 # Default data type for tensors in this module - class WOFOST_Root_Dynamics(SimulationObject): """Root biomass dynamics and rooting depth. @@ -117,27 +116,89 @@ class WOFOST_Root_Dynamics(SimulationObject): better and more biophysical approach to root development in WOFOST. """ # noqa: E501 + params_shape = None # Shape of the parameters tensors + + @property + def device(self): + """Get device from ComputeConfig.""" + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + return ComputeConfig.get_dtype() + class Parameters(ParamTemplate): - RDI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - RRI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - RDMCR = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - RDMSOL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - IAIRDU = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) + RDI = Any() + RRI = Any() + RDMCR = Any() + RDMSOL = Any() + TDWI = Any() + IAIRDU = Any() RDRRTB = AfgenTrait() + def __init__(self, parvalues): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.RDI = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.RRI = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.RDMCR = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.RDMSOL = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.TDWI = [torch.tensor(-99.0, dtype=dtype, device=device)] + self.IAIRDU = [torch.tensor(-99.0, dtype=dtype, device=device)] + + # Call parent init + super().__init__(parvalues) + class RateVariables(RatesTemplate): - RR = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - GRRT = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - DRRT = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) - GWRT = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) + RR = Any() + GRRT = Any() + DRRT = Any() + GWRT = Any() + + def __init__(self, kiosk, publish=None): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + self.RR = torch.tensor(0.0, dtype=dtype, device=device) + self.GRRT = torch.tensor(0.0, dtype=dtype, device=device) + self.DRRT = torch.tensor(0.0, dtype=dtype, device=device) + self.GWRT = torch.tensor(0.0, dtype=dtype, device=device) + + # Call parent init + super().__init__(kiosk, publish=publish) class StateVariables(StatesTemplate): - RD = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - RDM = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - WRT = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - DWRT = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) - TWRT = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) + RD = Any() + RDM = Any() + WRT = Any() + DWRT = Any() + TWRT = Any() + + def __init__(self, kiosk, publish=None, **kwargs): + # Get dtype and device from ComputeConfig + dtype = ComputeConfig.get_dtype() + device = ComputeConfig.get_device() + + # Set default values + if "RD" not in kwargs: + self.RD = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "RDM" not in kwargs: + self.RDM = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "WRT" not in kwargs: + self.WRT = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "DWRT" not in kwargs: + self.DWRT = [torch.tensor(-99.0, dtype=dtype, device=device)] + if "TWRT" not in kwargs: + self.TWRT = [torch.tensor(-99.0, dtype=dtype, device=device)] + + # Call parent init + super().__init__(kiosk, publish=publish, **kwargs) def initialize( self, day: datetime.date, kiosk: VariableKiosk, parvalues: ParameterProvider @@ -153,22 +214,29 @@ def initialize( all parameter sets (crop, soil, site) as key/value. The values are arrays or scalars. See PCSE documentation for details. """ + self.kiosk = kiosk self.params = self.Parameters(parvalues) self.rates = self.RateVariables(kiosk, publish=["DRRT", "GRRT"]) - self.kiosk = kiosk # INITIAL STATES params = self.params - shape = _get_params_shape(params) + self.params_shape = _get_params_shape(params) + shape = self.params_shape # Initial root depth states - rdmax = torch.max(params.RDI, torch.min(params.RDMCR, params.RDMSOL)) - RDM = _broadcast_to(rdmax, shape) - RD = _broadcast_to(params.RDI, shape) + RDI = _broadcast_to(params.RDI, shape, dtype=self.dtype, device=self.device) + RDMCR = _broadcast_to(params.RDMCR, shape, dtype=self.dtype, device=self.device) + RDMSOL = _broadcast_to(params.RDMSOL, shape, dtype=self.dtype, device=self.device) + + rdmax = torch.maximum(RDI, torch.minimum(RDMCR, RDMSOL)) + RDM = rdmax + RD = RDI # Initial root biomass states - WRT = _broadcast_to(params.TDWI * self.kiosk.FR, shape) - DWRT = torch.zeros_like(WRT) if shape else torch.zeros((), dtype=DTYPE) + TDWI = _broadcast_to(params.TDWI, shape, dtype=self.dtype, device=self.device) + FR = _broadcast_to(self.kiosk["FR"], shape, dtype=self.dtype, device=self.device) + WRT = TDWI * FR + DWRT = torch.zeros(shape, dtype=self.dtype, device=self.device) TWRT = WRT + DWRT self.states = self.StateVariables( @@ -190,23 +258,27 @@ def calc_rates(self, day: datetime.date = None, drv: WeatherDataContainer = None s = self.states k = self.kiosk - # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask + # If DVS < 0, the crop has not yet emerged, so we zerofy the rates using mask. # Make a mask (0 if DVS < 0, 1 if DVS >= 0) - DVS = torch.as_tensor(k["DVS"], dtype=DTYPE) - dvs_mask = (DVS >= 0).to(dtype=DTYPE) + DVS = _broadcast_to(k["DVS"], self.params_shape, dtype=self.dtype, device=self.device) + dvs_mask = (DVS >= 0).to(dtype=self.dtype) # Increase in root biomass - r.GRRT = dvs_mask * k.FR * k.DMI - r.DRRT = dvs_mask * s.WRT * p.RDRRTB(k.DVS) + FR = _broadcast_to(k["FR"], self.params_shape, dtype=self.dtype, device=self.device) + DMI = _broadcast_to(k["DMI"], self.params_shape, dtype=self.dtype, device=self.device) + RDRRTB = p.RDRRTB.to(device=self.device, dtype=self.dtype) + + r.GRRT = dvs_mask * FR * DMI + r.DRRT = dvs_mask * s.WRT * RDRRTB(DVS) r.GWRT = r.GRRT - r.DRRT # Increase in root depth - r.RR = dvs_mask * torch.min((s.RDM - s.RD), p.RRI) + RRI = _broadcast_to(p.RRI, self.params_shape, dtype=self.dtype, device=self.device) + r.RR = dvs_mask * torch.minimum((s.RDM - s.RD), RRI) # Do not let the roots growth if partioning to the roots # (variable FR) is zero. - FR = torch.as_tensor(k["FR"], dtype=DTYPE) - mask = (FR > 0.0).to(dtype=DTYPE) + mask = (FR > 0.0).to(dtype=self.dtype) r.RR = r.RR * mask * dvs_mask @prepare_states diff --git a/src/diffwofost/physical_models/utils.py b/src/diffwofost/physical_models/utils.py index ecd6157..4324587 100644 --- a/src/diffwofost/physical_models/utils.py +++ b/src/diffwofost/physical_models/utils.py @@ -28,8 +28,6 @@ from .config import Configuration from .engine import Engine -DTYPE = torch.float64 # Default data type for tensors in this module - logging.disable(logging.CRITICAL) @@ -98,6 +96,8 @@ def __init__( agromanagement, config, external_states=None, + device=None, + dtype=None, ): BaseEngine.__init__(self) @@ -109,6 +109,12 @@ def __init__( self.parameterprovider = parameterprovider + # Configure device and dtype on crop module class if it supports them + if hasattr(self.mconf.CROP, "device") and device is not None: + self.mconf.CROP.device = device + if hasattr(self.mconf.CROP, "dtype") and dtype is not None: + self.mconf.CROP.dtype = dtype + # Variable kiosk for registering and publishing variables self.kiosk = VariableKioskTestHelper(external_states) @@ -196,7 +202,7 @@ def __init__(self, yaml_weather, meteo_range_checks=True): def prepare_engine_input( - test_data, crop_model_params, meteo_range_checks=True, dtype=torch.float64 + test_data, crop_model_params, meteo_range_checks=True, dtype=torch.float64, device="cpu" ): """Prepare the inputs for the engine from the YAML file.""" agro_management_inputs = test_data["AgroManagement"] @@ -213,7 +219,7 @@ def prepare_engine_input( for name in crop_model_params: # if name is missing in the YAML, skip it if name in crop_model_params_provider: - value = torch.tensor(crop_model_params_provider[name], dtype=dtype) + value = torch.tensor(crop_model_params_provider[name], dtype=dtype, device=device) crop_model_params_provider.set_override(name, value, check=False) # convert external states to tensors @@ -274,6 +280,20 @@ class Afgen: Now supports batched tables (tensor of lists) for vectorized operations. """ + @property + def device(self): + """Get device from ComputeConfig.""" + from diffwofost.physical_models.config import ComputeConfig + + return ComputeConfig.get_device() + + @property + def dtype(self): + """Get dtype from ComputeConfig.""" + from diffwofost.physical_models.config import ComputeConfig + + return ComputeConfig.get_dtype() + def _check_x_ascending(self, tbl_xy): """Checks that the x values are strictly ascending. @@ -290,71 +310,37 @@ def _check_x_ascending(self, tbl_xy): Raises: ValueError: If x values are not strictly ascending. """ - # Handle batched tables (>1D tensors) - if tbl_xy.dim() > 1: - batch_shape = tbl_xy.shape[:-1] - table_len = tbl_xy.shape[-1] - - # Flatten batch dimensions for processing - flat_tables = tbl_xy.reshape(-1, table_len) - num_tables = flat_tables.shape[0] - valid_counts = [] - for idx in range(num_tables): - table = flat_tables[idx] - x_list = table[0::2] - y_list = table[1::2] - n = len(x_list) - - # Find trailing (0, 0) pairs to truncate - valid_n = n - for i in range(n - 1, 0, -1): - if x_list[i] == 0 and y_list[i] == 0: - valid_n = i - else: - break - - # Check if x range is strictly ascending - valid_x_list = x_list[:valid_n] - for i in range(1, len(valid_x_list)): - if valid_x_list[i] <= valid_x_list[i - 1]: - msg = ( - "X values for AFGEN input list" - + " not strictly ascending: {x_list.tolist()}" - ) - raise ValueError(msg) - - valid_counts.append(valid_n) - - return torch.tensor(valid_counts).reshape(batch_shape) - - # Original 1D logic from pcse - x_list = tbl_xy[0::2] - y_list = tbl_xy[1::2] - n = len(x_list) - - # Find trailing (0, 0) pairs to truncate - valid_n = n - for i in range(n - 1, 0, -1): - if x_list[i] == 0 and y_list[i] == 0: - valid_n = i - else: - break + def _valid_n_and_check(x_list: torch.Tensor, y_list: torch.Tensor) -> int: + # Truncate trailing (0,0) pairs. If all pairs are (0,0), keep first pair. + nonzero = ~(x_list.eq(0) & y_list.eq(0)) + last_valid = int(nonzero.nonzero()[-1].item()) if bool(nonzero.any()) else 0 + valid_n = last_valid + 1 - # Check only the valid (non-trailing-zero) portion - valid_x_list = x_list[:valid_n] + x_valid = x_list[:valid_n] + if x_valid.numel() > 1 and not bool(torch.all(torch.diff(x_valid) > 0)): + raise ValueError( + f"X values for AFGEN input list not strictly ascending: {x_list.tolist()}" + ) + return valid_n - # Check if x range is strictly ascending - for i in range(1, len(valid_x_list)): - if valid_x_list[i] <= valid_x_list[i - 1]: - msg = f"X values for AFGEN input list not strictly ascending: {x_list.tolist()}" - raise ValueError(msg) + if tbl_xy.dim() > 1: + batch_shape = tbl_xy.shape[:-1] + table_len = tbl_xy.shape[-1] + flat = tbl_xy.reshape(-1, table_len) + counts = [_valid_n_and_check(t[0::2], t[1::2]) for t in flat] + return torch.tensor(counts, device=tbl_xy.device).reshape(batch_shape) + valid_n = _valid_n_and_check(tbl_xy[0::2], tbl_xy[1::2]) return list(range(valid_n)) def __init__(self, tbl_xy): # Convert to tensor if needed - tbl_xy = torch.as_tensor(tbl_xy, dtype=DTYPE) + tbl_xy = torch.as_tensor(tbl_xy, dtype=self.dtype, device=self.device) + # If the table was provided as ints, promote to float so interpolation + # doesn't truncate query points (e.g. 2.5 -> 2) and autograd works. + if not tbl_xy.is_floating_point(): + tbl_xy = tbl_xy.to(dtype=self.dtype) # Detect if we have batched tables (>1D) self.is_batched = tbl_xy.dim() > 1 @@ -363,63 +349,60 @@ def __init__(self, tbl_xy): self.batch_shape = tbl_xy.shape[:-1] table_len = tbl_xy.shape[-1] - # Store the full batched tables + # Keep the full batched tables for debugging/inspection self.tbl_xy = tbl_xy - # Get valid counts for each table + # Validate and compute how many (x,y) pairs are valid per table valid_counts = self._check_x_ascending(tbl_xy) self.valid_counts = valid_counts - # Extract x and y for all tables flat_tables = tbl_xy.reshape(-1, table_len) + flat_valid = valid_counts.reshape(-1).to(device=self.device) num_tables = flat_tables.shape[0] - - x_list_batch = [] - y_list_batch = [] - slopes_batch = [] + max_n = int(flat_valid.max().item()) if num_tables > 0 else 0 + + # Store padded tensors so we can vectorize __call__. + pad_x = torch.finfo(tbl_xy.dtype).max + x_flat = torch.full((num_tables, max_n), pad_x, dtype=self.dtype, device=self.device) + y_flat = torch.zeros((num_tables, max_n), dtype=self.dtype, device=self.device) + slopes_flat = torch.zeros( + (num_tables, max(0, max_n - 1)), dtype=self.dtype, device=self.device + ) for idx in range(num_tables): + n = int(flat_valid[idx].item()) table = flat_tables[idx] - valid_n = valid_counts.flatten()[idx].item() - - x_indices = torch.tensor([2 * i for i in range(valid_n)]) - y_indices = torch.tensor([2 * i + 1 for i in range(valid_n)]) - - x_vals = table[x_indices] - y_vals = table[y_indices] - - # Calculate slopes - if len(x_vals) > 1: - slopes = (y_vals[1:] - y_vals[:-1]) / (x_vals[1:] - x_vals[:-1]) - else: - slopes = torch.tensor([], dtype=DTYPE) - - x_list_batch.append(x_vals) - y_list_batch.append(y_vals) - slopes_batch.append(slopes) - - # Store as lists - don't reshape, just keep the flat structure - self.x_list_batch = x_list_batch - self.y_list_batch = y_list_batch - self.slopes_batch = slopes_batch + x_vals = table[0::2][:n] + y_vals = table[1::2][:n] + + x_flat[idx, :n] = x_vals + y_flat[idx, :n] = y_vals + if n < max_n: + y_flat[idx, n:] = y_vals[-1] + if n > 1: + slopes_flat[idx, : n - 1] = (y_vals[1:] - y_vals[:-1]) / ( + x_vals[1:] - x_vals[:-1] + ) + + self._x_flat = x_flat + self._y_flat = y_flat + self._slopes_flat = slopes_flat + self._valid_counts_flat = flat_valid else: # Original 1D logic from pcse self.batch_shape = None indices = self._check_x_ascending(tbl_xy) + valid_n = len(indices) - # Extract x and y values using indices - x_indices = torch.tensor([2 * i for i in indices]) - y_indices = torch.tensor([2 * i + 1 for i in indices]) - self.x_list = tbl_xy[x_indices] - self.y_list = tbl_xy[y_indices] - - # Calculate slopes - x1 = self.x_list[:-1] - x2 = self.x_list[1:] - y1 = self.y_list[:-1] - y2 = self.y_list[1:] - self.slopes = (y2 - y1) / (x2 - x1) + self.x_list = tbl_xy[0::2][:valid_n] + self.y_list = tbl_xy[1::2][:valid_n] + if valid_n > 1: + self.slopes = (self.y_list[1:] - self.y_list[:-1]) / ( + self.x_list[1:] - self.x_list[:-1] + ) + else: + self.slopes = torch.tensor([], dtype=self.dtype, device=self.device) def __call__(self, x): """Returns the interpolated value at abscissa x. @@ -431,59 +414,47 @@ def __call__(self, x): Returns: torch.Tensor: The interpolated value, preserving batch dimensions. """ - x = torch.as_tensor(x, dtype=DTYPE) - if self.is_batched: - # Ensure x has compatible shape for broadcasting - # x can be scalar or have batch dimensions - - # Flatten batch dimensions for processing + x = torch.as_tensor(x, dtype=self._x_flat.dtype, device=self._x_flat.device) flat_x = x.reshape(-1) if x.dim() > 0 else x.unsqueeze(0) - num_queries = flat_x.shape[0] if flat_x.dim() > 0 else 1 - - results = [] - - # Process each table - for idx in range(len(self.x_list_batch)): - x_list = self.x_list_batch[idx] - y_list = self.y_list_batch[idx] - slopes = self.slopes_batch[idx] - - # Get the query value (broadcast if needed) - if num_queries == 1: - x_val = flat_x[0] if flat_x.dim() > 0 else flat_x - elif idx < num_queries: - x_val = flat_x[idx] - else: - x_val = flat_x[0] # Broadcast first value - - # Ensure contiguous memory layout for searchsorted - x_list_contig = x_list.contiguous() - x_val_contig = ( - x_val.contiguous() - if isinstance(x_val, torch.Tensor) and x_val.dim() > 0 - else x_val - ) - - # Find interval and interpolate using torch.where for differentiability - i = torch.searchsorted(x_list_contig, x_val_contig, right=False) - 1 - i = torch.clamp(i, 0, len(x_list) - 2) - - # Calculate interpolated value - interp_result = y_list[i] + slopes[i] * (x_val - x_list[i]) + num_tables = self._x_flat.shape[0] - # Apply boundary conditions using torch.where - result = torch.where( - x_val <= x_list[0], - y_list[0], - torch.where(x_val >= x_list[-1], y_list[-1], interp_result), - ) - - results.append(result) + if flat_x.numel() == 1: + x_vals = flat_x.expand(num_tables) + elif flat_x.numel() == num_tables: + x_vals = flat_x + else: + x_vals = flat_x[0].expand(num_tables) + + # Find interval index per table + # Ensure contiguous query tensor to avoid internal copies in searchsorted + x_query = x_vals.unsqueeze(1).contiguous() + i = torch.searchsorted(self._x_flat, x_query, right=False) - 1 + i = i.squeeze(1) + upper = torch.clamp(self._valid_counts_flat - 2, min=0) + i = torch.clamp(i, min=0) + i = torch.minimum(i, upper) + + idx = i.unsqueeze(1) + x_i = self._x_flat.gather(1, idx).squeeze(1) + y_i = self._y_flat.gather(1, idx).squeeze(1) + slope_i = self._slopes_flat.gather(1, idx).squeeze(1) + interp = y_i + slope_i * (x_vals - x_i) + + x0 = self._x_flat[:, 0] + y0 = self._y_flat[:, 0] + last_idx = (self._valid_counts_flat - 1).to(dtype=torch.long).unsqueeze(1) + x_last = self._x_flat.gather(1, last_idx).squeeze(1) + y_last = self._y_flat.gather(1, last_idx).squeeze(1) + + out = torch.where( + x_vals <= x0, + y0, + torch.where(x_vals >= x_last, y_last, interp), + ) + return out.reshape(self.batch_shape) - # Reshape to original batch shape - output = torch.stack(results).reshape(self.batch_shape) - return output + x = torch.as_tensor(x, dtype=self.x_list.dtype, device=self.x_list.device) # Ensure contiguous memory layout for searchsorted x_list_contig = self.x_list.contiguous() @@ -505,6 +476,38 @@ def __call__(self, x): return result + def to(self, device=None, dtype=None): + """Move internal tensors to a different device/dtype (PyTorch-style). + + This is an in-place operation and returns ``self`` for chaining. + """ + if device is None and dtype is None: + return self + + for name in ( + "tbl_xy", + "x_list", + "y_list", + "slopes", + "_x_flat", + "_y_flat", + "_slopes_flat", + "valid_counts", + "_valid_counts_flat", + ): + if not hasattr(self, name): + continue + t = getattr(self, name) + if not isinstance(t, torch.Tensor): + continue + # Keep integer tensors as integers; only move device for them. + if t.is_floating_point(): + setattr(self, name, t.to(device=device, dtype=dtype)) + else: + setattr(self, name, t.to(device=device)) + + return self + @property def shape(self): """Returns the shape of the Afgen table.""" @@ -569,7 +572,7 @@ def _get_params_shape(params): return shape -def _get_drv(drv_var, expected_shape): +def _get_drv(drv_var, expected_shape, dtype, device=None): """Check that the driving variables have the expected shape and fetch them. Driving variables can be scalars (0-dimensional) or match the expected shape. @@ -580,6 +583,8 @@ def _get_drv(drv_var, expected_shape): Args: drv_var: driving variable in WeatherDataContainer expected_shape: Expected shape tuple for non-scalar variables + dtype: dtype for the tensor + device: Optional device for the tensor Raises: ValueError: If any variable has incompatible shape @@ -590,9 +595,13 @@ def _get_drv(drv_var, expected_shape): # Check shape: must be scalar (0-d) or match expected_shape if not isinstance(drv_var, torch.Tensor) or drv_var.dim() == 0: # Scalar is valid, will be broadcast - return _broadcast_to(drv_var, expected_shape) + return _broadcast_to(drv_var, expected_shape, dtype, device) elif drv_var.shape == expected_shape: # Matches expected shape + if dtype is not None: + drv_var = drv_var.to(dtype=dtype) + if device is not None: + drv_var = drv_var.to(device=device) return drv_var else: raise ValueError( @@ -601,11 +610,23 @@ def _get_drv(drv_var, expected_shape): ) -def _broadcast_to(x, shape): - """Create a view of tensor X with the given shape.""" +def _broadcast_to(x, shape, dtype, device=None): + """Create a view of tensor X with the given shape. + + Args: + x: The tensor or value to broadcast + shape: The target shape + dtype: dtype for the tensor + device: Optional device for the tensor + """ # If x is not a tensor, convert it if not isinstance(x, torch.Tensor): - x = torch.tensor(x, dtype=DTYPE) + x = torch.tensor(x, dtype=dtype) + # Ensure correct dtype and device + if dtype is not None: + x = x.to(dtype=dtype) + if device is not None: + x = x.to(device=device) # If already the correct shape, return as-is if x.shape == shape: return x diff --git a/tests/physical_models/conftest.py b/tests/physical_models/conftest.py index 3605e93..dba659f 100644 --- a/tests/physical_models/conftest.py +++ b/tests/physical_models/conftest.py @@ -1,6 +1,8 @@ from pathlib import Path import pytest import requests +import torch +from diffwofost.physical_models.config import ComputeConfig LOCAL_TEST_DIR = Path(__file__).parent / "test_data" BASE_PCSE_URL = "https://raw.githubusercontent.com/ajwdewit/pcse/refs/heads/master/tests/test_data" @@ -37,3 +39,24 @@ def download_test_files(): """Download all required test files before running tests.""" for file_name in FILE_NAMES: download_file(file_name) + + +@pytest.fixture(params=["cpu", "cuda"]) +def device(request): + """Parametrize tests over CPU and GPU devices. + + Sets the global ComputeConfig to use the specified device. + Skips CUDA runs when CUDA isn't available. + """ + + device_name = request.param + if device_name == "cuda" and not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + # Set the global ComputeConfig to use the specified device + ComputeConfig.set_device(device_name) + + yield device_name + + # Reset to defaults after the test + ComputeConfig.reset_to_defaults() diff --git a/tests/physical_models/crop/test_leaf_dynamics.py b/tests/physical_models/crop/test_leaf_dynamics.py index d1dd78b..6002dce 100644 --- a/tests/physical_models/crop/test_leaf_dynamics.py +++ b/tests/physical_models/crop/test_leaf_dynamics.py @@ -13,21 +13,18 @@ from diffwofost.physical_models.utils import prepare_engine_input from .. import phy_data_folder -# Ignore deprecation warnings from pcse.base.simulationobject -pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning:pcse.base.simulationobject") - leaf_dynamics_config = Configuration( CROP=WOFOST_Leaf_Dynamics, OUTPUT_VARS=["LAI", "TWLV"], ) -def get_test_diff_leaf_model(): +def get_test_diff_leaf_model(device: str = "cpu"): test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) crop_model_params = ["SPAN", "TDWI", "TBASE", "PERDL", "RGRLAI"] (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( - prepare_engine_input(test_data, crop_model_params) + prepare_engine_input(test_data, crop_model_params, device=device) ) return DiffLeafDynamics( copy.deepcopy(crop_model_params_provider), @@ -35,6 +32,7 @@ def get_test_diff_leaf_model(): agro_management_inputs, leaf_dynamics_config, copy.deepcopy(external_states), + device=device, ) @@ -46,6 +44,7 @@ def __init__( agro_management_inputs, config, external_states, + device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider @@ -53,6 +52,7 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.device = device def forward(self, params_dict): # pass new value of parameters to the model @@ -65,6 +65,7 @@ def forward(self, params_dict): self.agro_management_inputs, self.config, self.external_states, + device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -83,8 +84,8 @@ class TestLeafDynamics: for i in range(1, 45) # there are 44 test files ] - @pytest.mark.parametrize("test_data_url", leafdynamics_data_urls) - def test_leaf_dynamics_with_testengine(self, test_data_url): + @pytest.mark.parametrize("test_data_url", leafdynamics_data_urls) # Test subset for GPU + def test_leaf_dynamics_with_testengine(self, test_data_url, device): """EngineTestHelper and not Engine because it allows to specify `external_states`.""" # prepare model input test_data = get_test_data(test_data_url) @@ -102,6 +103,7 @@ def test_leaf_dynamics_with_testengine(self, test_data_url): agro_management_inputs, leaf_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -112,15 +114,20 @@ def test_leaf_dynamics_with_testengine(self, test_data_url): assert len(actual_results) == len(expected_results) for reference, model in zip(expected_results, actual_results, strict=False): assert reference["DAY"] == model["day"] + # Verify output is on the correct device + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + # Move to CPU for comparison if needed + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} assert all( - abs(reference[var] - model[var]) < precision + abs(reference[var] - model_cpu[var]) < precision for var, precision in expected_precision.items() ) @pytest.mark.parametrize( "param", ["TDWI", "SPAN", "RGRLAI", "TBASE", "PERDL", "KDIFTB", "SLATB", "TEMP"] ) - def test_leaf_dynamics_with_one_parameter_vector(self, param): + def test_leaf_dynamics_with_one_parameter_vector(self, param, device): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -130,7 +137,9 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): weather_data_provider, agro_management_inputs, external_states, - ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) # Setting a vector (with one value) for the selected parameter if param == "TEMP": @@ -155,6 +164,7 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): agro_management_inputs, leaf_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -165,6 +175,7 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): agro_management_inputs, leaf_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -176,8 +187,15 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): for reference, model in zip(expected_results, actual_results, strict=False): assert reference["DAY"] == model["day"] + # Verify output is on the correct device + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + # Move to CPU for comparison + model_cpu = { + k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items() + } assert all( - all(abs(reference[var] - model[var]) < precision) + all(abs(reference[var] - model_cpu[var]) < precision) for var, precision in expected_precision.items() ) @@ -193,7 +211,7 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param): ("SLATB", 0.0005), ], ) - def test_leaf_dynamics_with_different_parameter_values(self, param, delta): + def test_leaf_dynamics_with_different_parameter_values(self, param, delta, device): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -222,6 +240,7 @@ def test_leaf_dynamics_with_different_parameter_values(self, param, delta): agro_management_inputs, leaf_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -233,13 +252,18 @@ def test_leaf_dynamics_with_different_parameter_values(self, param, delta): for reference, model in zip(expected_results, actual_results, strict=False): assert reference["DAY"] == model["day"] + # Verify output is on the correct device + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + # Move to CPU for comparison + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} assert all( # The value for which test data are available is the last element - abs(reference[var] - model[var][-1]) < precision + abs(reference[var] - model_cpu[var][-1]) < precision for var, precision in expected_precision.items() ) - def test_leaf_dynamics_with_multiple_parameter_vectors(self): + def test_leaf_dynamics_with_multiple_parameter_vectors(self, device): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -266,6 +290,7 @@ def test_leaf_dynamics_with_multiple_parameter_vectors(self): agro_management_inputs, leaf_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -282,7 +307,7 @@ def test_leaf_dynamics_with_multiple_parameter_vectors(self): for var, precision in expected_precision.items() ) - def test_leaf_dynamics_with_multiple_parameter_arrays(self): + def test_leaf_dynamics_with_multiple_parameter_arrays(self, device): # prepare model input test_data_url = f"{phy_data_folder}/test_leafdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -312,6 +337,7 @@ def test_leaf_dynamics_with_multiple_parameter_arrays(self): agro_management_inputs, leaf_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -359,6 +385,7 @@ def test_leaf_dynamics_with_incompatible_parameter_vectors(self): agro_management_inputs, leaf_dynamics_config, external_states, + device="cpu", ) def test_leaf_dynamics_with_incompatible_weather_parameter_vectors(self): @@ -387,6 +414,7 @@ def test_leaf_dynamics_with_incompatible_weather_parameter_vectors(self): agro_management_inputs, leaf_dynamics_config, external_states, + device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) @@ -439,6 +467,7 @@ def test_leaf_dynamics_with_sigmoid_approx(self, test_data_url): agro_management_inputs, leaf_dynamics_config, external_states, + device="cpu", ) engine.run_till_terminate() actual_results = engine.get_output() @@ -518,11 +547,11 @@ class TestDiffLeafDynamicsGradients: @pytest.mark.parametrize("param_name,output_name", no_gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_no_gradients(self, param_name, output_name, config_type): + def test_no_gradients(self, param_name, output_name, config_type, device): """Test cases where parameters should not have gradients for specific outputs.""" - model = get_test_diff_leaf_model() + model = get_test_diff_leaf_model(device=device) value, dtype = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) loss = output[output_name].sum() @@ -534,11 +563,11 @@ def test_no_gradients(self, param_name, output_name, config_type): @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_gradients_forward_backward_match(self, param_name, output_name, config_type): + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): """Test that forward and backward gradients match for parameter-output pairs.""" - model = get_test_diff_leaf_model() + model = get_test_diff_leaf_model(device=device) value, dtype = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) loss = output[output_name].sum() @@ -561,22 +590,24 @@ def test_gradients_forward_backward_match(self, param_name, output_name, config_ @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_gradients_numerical(self, param_name, output_name, config_type): + def test_gradients_numerical(self, param_name, output_name, config_type, device): """Test that analytical gradients match numerical gradients.""" value, _ = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64)) + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) numerical_grad = calculate_numerical_grad( - get_test_diff_leaf_model, param_name, param.data, output_name + lambda: get_test_diff_leaf_model(device=device), param_name, param.data, output_name ) - model = get_test_diff_leaf_model() + model = get_test_diff_leaf_model(device=device) output = model({param_name: param}) loss = output[output_name].sum() # this is ∂loss/∂param, for comparison with numerical gradient grads = torch.autograd.grad(loss, param, retain_graph=True)[0] - assert_array_almost_equal(numerical_grad, grads.data, decimal=3) + assert_array_almost_equal( + numerical_grad.detach().cpu().numpy(), grads.detach().cpu().numpy(), decimal=3 + ) # Warn if gradient is zero (but this shouldn't happen for gradient_params) if torch.all(grads == 0): diff --git a/tests/physical_models/crop/test_phenology.py b/tests/physical_models/crop/test_phenology.py index 6e16cf0..78ed856 100644 --- a/tests/physical_models/crop/test_phenology.py +++ b/tests/physical_models/crop/test_phenology.py @@ -12,9 +12,6 @@ from diffwofost.physical_models.utils import prepare_engine_input from .. import phy_data_folder -# Ignore deprecation warnings from pcse.base.simulationobject -pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning:pcse.base.simulationobject") - phenology_config = Configuration( CROP=DVS_Phenology, OUTPUT_VARS=["DVR", "DVS", "TSUM", "TSUME", "VERN"], @@ -31,12 +28,17 @@ def assert_reference_match(reference, model, expected_precision): # for some data tests, both reference and model can have None values if reference[var] is None and model[var] is None: continue - assert torch.all( - torch.abs(torch.as_tensor(reference[var]) - torch.as_tensor(model[var])) < precision + ref_t = torch.as_tensor(reference[var]) + model_v = model[var] + model_t = ( + model_v.detach().cpu() + if isinstance(model_v, torch.Tensor) + else torch.as_tensor(model_v) ) + assert torch.all(torch.abs(ref_t - model_t) < precision) -def get_test_diff_phenology_model(): +def get_test_diff_phenology_model(device: str = "cpu"): test_data_url = f"{phy_data_folder}/test_phenology_wofost72_01.yaml" test_data = get_test_data(test_data_url) # Phenology-related crop model parameters @@ -52,18 +54,20 @@ def get_test_diff_phenology_model(): "DVSI", "DVSEND", "DTSMTB", + "VERNRTB", "VERNSAT", "VERNBASE", "VERNDVS", ] (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( - prepare_engine_input(test_data, crop_model_params) + prepare_engine_input(test_data, crop_model_params, device=device) ) return DiffPhenologyDynamics( copy.deepcopy(crop_model_params_provider), weather_data_provider, agro_management_inputs, phenology_config, + device=device, ) @@ -74,12 +78,14 @@ def __init__( weather_data_provider, agro_management_inputs, config, + device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider self.weather_data_provider = weather_data_provider self.agro_management_inputs = agro_management_inputs self.config = config + self.device = device def forward(self, params_dict): # pass new value of parameters to the model @@ -91,6 +97,7 @@ def forward(self, params_dict): self.weather_data_provider, self.agro_management_inputs, self.config, + device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -109,7 +116,7 @@ class TestPhenologyDynamics: ] @pytest.mark.parametrize("test_data_url", phenology_data_urls) - def test_phenology_with_testengine(self, test_data_url): + def test_phenology_with_testengine(self, test_data_url, device): test_data = get_test_data(test_data_url) crop_model_params = [ "TSUMEM", @@ -123,6 +130,7 @@ def test_phenology_with_testengine(self, test_data_url): "DVSI", "DVSEND", "DTSMTB", + "VERNRTB", "VERNSAT", "VERNBASE", "VERNDVS", @@ -132,13 +140,14 @@ def test_phenology_with_testengine(self, test_data_url): weather_data_provider, agro_management_inputs, _, - ) = prepare_engine_input(test_data, crop_model_params) + ) = prepare_engine_input(test_data, crop_model_params, device=device) engine = EngineTestHelper( crop_model_params_provider, weather_data_provider, agro_management_inputs, phenology_config, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -147,7 +156,12 @@ def test_phenology_with_testengine(self, test_data_url): assert len(actual_results) == len(expected_results) for reference, model in zip(expected_results, actual_results, strict=False): - assert_reference_match(reference, model, expected_precision) + for var in expected_precision.keys(): + value = model.get(var) + if isinstance(value, torch.Tensor): + assert value.device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} + assert_reference_match(reference, model_cpu, expected_precision) @pytest.mark.parametrize( "param", @@ -169,7 +183,7 @@ def test_phenology_with_testengine(self, test_data_url): "TEMP", ], ) - def test_phenology_with_one_parameter_vector(self, param): + def test_phenology_with_one_parameter_vector(self, param, device): # pick a test case with vernalisation to have all the parameters test_data_url = f"{phy_data_folder}/test_phenology_wofost72_17.yaml" test_data = get_test_data(test_data_url) @@ -185,6 +199,7 @@ def test_phenology_with_one_parameter_vector(self, param): "DVSI", "DVSEND", "DTSMTB", + "VERNRTB", "VERNSAT", "VERNBASE", "VERNDVS", @@ -194,11 +209,17 @@ def test_phenology_with_one_parameter_vector(self, param): weather_data_provider, agro_management_inputs, _, - ) = prepare_engine_input(test_data, crop_model_params, meteo_range_checks=False) + ) = prepare_engine_input( + test_data, crop_model_params, meteo_range_checks=False, device=device + ) if param == "TEMP": + if device == "cuda": + pytest.skip("Weather parameter vector tests are CPU-only") for (_, _), wdc in weather_data_provider.store.items(): - wdc.TEMP = torch.ones(10, dtype=torch.float64) * wdc.TEMP + wdc.TEMP = torch.ones(10, dtype=torch.float64, device=device) * torch.as_tensor( + wdc.TEMP, dtype=torch.float64, device=device + ) elif param == "DTSMTB": repeated = crop_model_params_provider[param].repeat(10, 1) crop_model_params_provider.set_override(param, repeated, check=False) @@ -213,6 +234,7 @@ def test_phenology_with_one_parameter_vector(self, param): weather_data_provider, agro_management_inputs, phenology_config, + device=device, ) engine.run_till_terminate() _ = engine.get_output() @@ -222,6 +244,7 @@ def test_phenology_with_one_parameter_vector(self, param): weather_data_provider, agro_management_inputs, phenology_config, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -246,7 +269,7 @@ def test_phenology_with_one_parameter_vector(self, param): ("VERNDVS", 0.1), ], ) - def test_phenology_with_different_parameter_values(self, param, delta): + def test_phenology_with_different_parameter_values(self, param, delta, device): # we dont test IDSL,DLO, DLC, DVSEND because these paramaters controls the # simulation duration # TODO: revisit this choice when Engine is fixed @@ -264,6 +287,7 @@ def test_phenology_with_different_parameter_values(self, param, delta): "DVSI", "DVSEND", "DTSMTB", + "VERNRTB", "VERNSAT", "VERNBASE", "VERNDVS", @@ -273,7 +297,7 @@ def test_phenology_with_different_parameter_values(self, param, delta): weather_data_provider, agro_management_inputs, _, - ) = prepare_engine_input(test_data, crop_model_params) + ) = prepare_engine_input(test_data, crop_model_params, device=device) test_value = crop_model_params_provider[param] if param == "DTSMTB": @@ -282,7 +306,7 @@ def test_phenology_with_different_parameter_values(self, param, delta): non_zeros_mask = test_value != 0 param_vec = torch.stack([test_value + non_zeros_mask * delta, test_value]) else: - param_vec = torch.tensor([test_value - delta, test_value + delta, test_value]) + param_vec = torch.stack([test_value - delta, test_value + delta, test_value]) crop_model_params_provider.set_override(param, param_vec, check=False) engine = EngineTestHelper( @@ -290,6 +314,7 @@ def test_phenology_with_different_parameter_values(self, param, delta): weather_data_provider, agro_management_inputs, phenology_config, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -309,7 +334,7 @@ def test_phenology_with_different_parameter_values(self, param, delta): continue assert torch.all(torch.abs(reference[var] - model[var][-1]) < precision) - def test_phenology_with_multiple_parameter_vectors(self): + def test_phenology_with_multiple_parameter_vectors(self, device): test_data_url = f"{phy_data_folder}/test_phenology_wofost72_17.yaml" test_data = get_test_data(test_data_url) crop_model_params = [ @@ -347,6 +372,7 @@ def test_phenology_with_multiple_parameter_vectors(self): weather_data_provider, agro_management_inputs, phenology_config, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -356,7 +382,7 @@ def test_phenology_with_multiple_parameter_vectors(self): for reference, model in zip(expected_results, actual_results, strict=False): assert_reference_match(reference, model, expected_precision) - def test_phenology_with_multiple_parameter_arrays(self): + def test_phenology_with_multiple_parameter_arrays(self, device): test_data_url = f"{phy_data_folder}/test_phenology_wofost72_17.yaml" test_data = get_test_data(test_data_url) crop_model_params = [ @@ -412,6 +438,7 @@ def test_phenology_with_multiple_parameter_arrays(self): weather_data_provider, agro_management_inputs, phenology_config, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -465,6 +492,7 @@ def test_phenology_with_incompatible_parameter_vectors(self): weather_data_provider, agro_management_inputs, phenology_config, + device="cpu", ) def test_phenology_with_incompatible_weather_parameter_vectors(self): @@ -505,10 +533,11 @@ def test_phenology_with_incompatible_weather_parameter_vectors(self): weather_data_provider, agro_management_inputs, phenology_config, + device="cpu", ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) - def test_wofost_pp_with_phenology(self, test_data_url): + def test_wofost_pp_with_phenology(self, test_data_url, monkeypatch): test_data = get_test_data(test_data_url) crop_model_params = [ "TSUMEM", @@ -527,10 +556,14 @@ def test_wofost_pp_with_phenology(self, test_data_url): "VERNDVS", ] (crop_model_params_provider, weather_data_provider, agro_management_inputs, _) = ( - prepare_engine_input(test_data, crop_model_params) + prepare_engine_input(test_data, crop_model_params, device="cpu") ) expected_results, expected_precision = test_data["ModelResults"], test_data["Precision"] + # Keep this integration test on CPU. + monkeypatch.setattr(DVS_Phenology, "device", "cpu") + monkeypatch.setattr(DVS_Phenology, "dtype", torch.float64) + with patch("pcse.crop.wofost72.Phenology", DVS_Phenology): model = Wofost72_PP( crop_model_params_provider, weather_data_provider, agro_management_inputs @@ -624,12 +657,14 @@ class TestDiffPhenologyDynamicsGradients: @pytest.mark.parametrize("param_name,output_name", no_gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_no_gradients(self, param_name, output_name, config_type): - model = get_test_diff_phenology_model() + def test_no_gradients(self, param_name, output_name, config_type, device): + model = get_test_diff_phenology_model(device=device) value, dtype = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) loss = output[output_name].sum() + if not loss.requires_grad: + return grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] if grads is not None: assert torch.all((grads == 0) | torch.isnan(grads)), ( @@ -638,10 +673,10 @@ def test_no_gradients(self, param_name, output_name, config_type): @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_gradients_forward_backward_match(self, param_name, output_name, config_type): - model = get_test_diff_phenology_model() + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): + model = get_test_diff_phenology_model(device=device) value, dtype = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) loss = output[output_name].sum() grads = torch.autograd.grad(loss, param, retain_graph=True)[0] @@ -650,17 +685,20 @@ def test_gradients_forward_backward_match(self, param_name, output_name, config_ loss.backward() grad_backward = param.grad assert grad_backward is not None - assert torch.all(grad_backward == grads) + assert torch.allclose(grad_backward, grads) @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_gradients_numerical(self, param_name, output_name, config_type): + def test_gradients_numerical(self, param_name, output_name, config_type, device): value, _ = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64)) + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) numerical_grad = calculate_numerical_grad( - get_test_diff_phenology_model, param_name, param.data, output_name + lambda: get_test_diff_phenology_model(device=device), + param_name, + param.data, + output_name, ) - model = get_test_diff_phenology_model() + model = get_test_diff_phenology_model(device=device) output = model({param_name: param}) loss = output[output_name].sum() grads = torch.autograd.grad(loss, param, retain_graph=True)[0] @@ -670,6 +708,9 @@ def test_gradients_numerical(self, param_name, output_name, config_type): ) if torch.all(grads == 0): warnings.warn( - f"Gradient for par '{param_name}' wrt out '{output_name}' is zero: {grads.data}", + ( + f"Gradient for par '{param_name}' wrt out '{output_name}' is zero: " + f"{grads.data.detach().cpu().numpy()}" + ), UserWarning, ) diff --git a/tests/physical_models/crop/test_root_dynamics.py b/tests/physical_models/crop/test_root_dynamics.py index b262b3a..69b7c50 100644 --- a/tests/physical_models/crop/test_root_dynamics.py +++ b/tests/physical_models/crop/test_root_dynamics.py @@ -13,21 +13,18 @@ from diffwofost.physical_models.utils import prepare_engine_input from .. import phy_data_folder -# Ignore deprecation warnings from pcse.base.simulationobject -pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning:pcse.base.simulationobject") - root_dynamics_config = Configuration( CROP=WOFOST_Root_Dynamics, OUTPUT_VARS=["RD", "TWRT"], ) -def get_test_diff_root_model(): +def get_test_diff_root_model(device: str = "cpu"): test_data_url = f"{phy_data_folder}/test_rootdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) crop_model_params = ["RDI", "RRI", "RDMCR", "RDMSOL", "TDWI", "IAIRDU"] (crop_model_params_provider, weather_data_provider, agro_management_inputs, external_states) = ( - prepare_engine_input(test_data, crop_model_params) + prepare_engine_input(test_data, crop_model_params, device=device) ) return DiffRootDynamics( copy.deepcopy(crop_model_params_provider), @@ -35,6 +32,7 @@ def get_test_diff_root_model(): agro_management_inputs, root_dynamics_config, copy.deepcopy(external_states), + device=device, ) @@ -46,6 +44,7 @@ def __init__( agro_management_inputs, config, external_states, + device: str = "cpu", ): super().__init__() self.crop_model_params_provider = crop_model_params_provider @@ -53,6 +52,7 @@ def __init__( self.agro_management_inputs = agro_management_inputs self.config = config self.external_states = external_states + self.device = device def forward(self, params_dict): # pass new value of parameters to the model @@ -65,6 +65,7 @@ def forward(self, params_dict): self.agro_management_inputs, self.config, self.external_states, + device=self.device, ) engine.run_till_terminate() results = engine.get_output() @@ -84,7 +85,7 @@ class TestRootDynamics: ] @pytest.mark.parametrize("test_data_url", rootdynamics_data_urls) - def test_root_dynamics_with_testengine(self, test_data_url): + def test_root_dynamics_with_testengine(self, test_data_url, device): """EngineTestHelper and not Engine because it allows to specify `external_states`.""" # prepare model input test_data = get_test_data(test_data_url) @@ -102,6 +103,7 @@ def test_root_dynamics_with_testengine(self, test_data_url): agro_management_inputs, root_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -113,13 +115,16 @@ def test_root_dynamics_with_testengine(self, test_data_url): for reference, model in zip(expected_results, actual_results, strict=False): assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} assert all( - abs(reference[var] - model[var]) < precision + abs(reference[var] - model_cpu[var]) < precision for var, precision in expected_precision.items() ) @pytest.mark.parametrize("param", ["RDI", "RRI", "RDMCR", "RDMSOL", "TDWI", "IAIRDU", "RDRRTB"]) - def test_root_dynamics_with_one_parameter_vector(self, param): + def test_root_dynamics_with_one_parameter_vector(self, param, device): # prepare model input test_data_url = phy_data_folder / "test_rootdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -146,6 +151,7 @@ def test_root_dynamics_with_one_parameter_vector(self, param): agro_management_inputs, root_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -157,8 +163,11 @@ def test_root_dynamics_with_one_parameter_vector(self, param): for reference, model in zip(expected_results, actual_results, strict=False): assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} assert all( - all(abs(reference[var] - model[var]) < precision) + all(abs(reference[var] - model_cpu[var]) < precision) for var, precision in expected_precision.items() ) @@ -174,7 +183,7 @@ def test_root_dynamics_with_one_parameter_vector(self, param): ("RDRRTB", 0.01), ], ) - def test_root_dynamics_with_different_parameter_values(self, param, delta): + def test_root_dynamics_with_different_parameter_values(self, param, delta, device): # prepare model input test_data_url = phy_data_folder / "test_rootdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -203,6 +212,7 @@ def test_root_dynamics_with_different_parameter_values(self, param, delta): agro_management_inputs, root_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -214,13 +224,16 @@ def test_root_dynamics_with_different_parameter_values(self, param, delta): for reference, model in zip(expected_results, actual_results, strict=False): assert reference["DAY"] == model["day"] + for var in expected_precision.keys(): + assert model[var].device.type == device, f"{var} should be on {device}" + model_cpu = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in model.items()} assert all( # The value for which test data are available is the last element - abs(reference[var] - model[var][-1]) < precision + abs(reference[var] - model_cpu[var][-1]) < precision for var, precision in expected_precision.items() ) - def test_root_dynamics_with_multiple_parameter_vectors(self): + def test_root_dynamics_with_multiple_parameter_vectors(self, device): # prepare model input test_data_url = phy_data_folder / "test_rootdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -248,6 +261,7 @@ def test_root_dynamics_with_multiple_parameter_vectors(self): agro_management_inputs, root_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -264,7 +278,7 @@ def test_root_dynamics_with_multiple_parameter_vectors(self): for var, precision in expected_precision.items() ) - def test_root_dynamics_with_multiple_parameter_arrays(self): + def test_root_dynamics_with_multiple_parameter_arrays(self, device): # prepare model input test_data_url = phy_data_folder / "test_rootdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -290,6 +304,7 @@ def test_root_dynamics_with_multiple_parameter_arrays(self): agro_management_inputs, root_dynamics_config, external_states, + device=device, ) engine.run_till_terminate() actual_results = engine.get_output() @@ -309,7 +324,7 @@ def test_root_dynamics_with_multiple_parameter_arrays(self): model[var].shape == (30, 5) for var in expected_precision.keys() ) # check the output shapes - def test_root_dynamics_with_incompatible_parameter_vectors(self): + def test_root_dynamics_with_incompatible_parameter_vectors(self, device): # prepare model input test_data_url = phy_data_folder / "test_rootdynamics_wofost72_01.yaml" test_data = get_test_data(test_data_url) @@ -337,6 +352,7 @@ def test_root_dynamics_with_incompatible_parameter_vectors(self): agro_management_inputs, root_dynamics_config, external_states, + device=device, ) @pytest.mark.parametrize("test_data_url", wofost72_data_urls) @@ -408,35 +424,61 @@ class TestDiffRootDynamicsGradients: "RDRRTB": ["TWRT"], } - # Generate all combinations + # Generate all combinations and no_graph_mapping gradient_params = [] no_gradient_params = [] + no_graph_mapping = {} for param_name in param_names: + no_graph_outputs = [] for output_name in output_names: if output_name in gradient_mapping.get(param_name, []): gradient_params.append((param_name, output_name)) else: no_gradient_params.append((param_name, output_name)) + no_graph_outputs.append(output_name) + if no_graph_outputs: + no_graph_mapping[param_name] = no_graph_outputs @pytest.mark.parametrize("param_name,output_name", no_gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_no_gradients(self, param_name, output_name, config_type): + def test_no_gradients(self, param_name, output_name, config_type, device): """Test cases where parameters should not have gradients for specific outputs.""" - model = get_test_diff_root_model() + model = get_test_diff_root_model(device=device) value, dtype = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) loss = output[output_name].sum() - assert loss.grad_fn is None + # Check if this parameter-output pair should have no graph at all + should_have_no_graph = output_name in self.no_graph_mapping.get(param_name, []) + + if not loss.requires_grad: + # If there is no graph, assert that this is expected + assert should_have_no_graph, ( + f"Expected graph for {param_name} w.r.t. {output_name}, " + f"but loss.requires_grad is False" + ) + return + + # If we get here, there should be a graph + assert not should_have_no_graph, ( + f"Expected no graph for {param_name} w.r.t. {output_name}, " + f"but loss.requires_grad is True" + ) + + grads = torch.autograd.grad(loss, param, retain_graph=True, allow_unused=True)[0] + if grads is not None: + assert torch.all((grads == 0) | torch.isnan(grads)), ( + f"Gradient for {param_name} w.r.t. {output_name} should be zero or NaN" + ) @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_gradients_forward_backward_match(self, param_name, output_name, config_type): + def test_gradients_forward_backward_match(self, param_name, output_name, config_type, device): """Test that forward and backward gradients match for parameter-output pairs.""" - model = get_test_diff_root_model() + model = get_test_diff_root_model(device=device) value, dtype = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=dtype)) + param = torch.nn.Parameter(torch.tensor(value, dtype=dtype, device=device)) output = model({param_name: param}) loss = output[output_name].sum() @@ -459,27 +501,29 @@ def test_gradients_forward_backward_match(self, param_name, output_name, config_ @pytest.mark.parametrize("param_name,output_name", gradient_params) @pytest.mark.parametrize("config_type", ["single", "tensor"]) - def test_gradients_numerical(self, param_name, output_name, config_type): + def test_gradients_numerical(self, param_name, output_name, config_type, device): """Test that analytical gradients match numerical gradients.""" value, _ = self.param_configs[config_type][param_name] - param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64)) + param = torch.nn.Parameter(torch.tensor(value, dtype=torch.float64, device=device)) numerical_grad = calculate_numerical_grad( - get_test_diff_root_model, param_name, param.data, output_name + lambda: get_test_diff_root_model(device=device), param_name, param.data, output_name ) - model = get_test_diff_root_model() + model = get_test_diff_root_model(device=device) output = model({param_name: param}) loss = output[output_name].sum() # this is ∂loss/∂param, for comparison with numerical gradient grads = torch.autograd.grad(loss, param, retain_graph=True)[0] - assert_array_almost_equal(numerical_grad, grads.detach().numpy(), decimal=3) + assert_array_almost_equal( + numerical_grad.detach().cpu().numpy(), grads.detach().cpu().numpy(), decimal=3 + ) # Warn if gradient is zero if torch.all(grads == 0): warnings.warn( f"Gradient for parameter '{param_name}' with respect to output" - + f"'{output_name}' is zero: {grads.detach().numpy()}", + + f"'{output_name}' is zero: {grads.detach().cpu().numpy()}", UserWarning, ) diff --git a/tests/physical_models/test_config.py b/tests/physical_models/test_config.py index b7efe2d..397d1b4 100644 --- a/tests/physical_models/test_config.py +++ b/tests/physical_models/test_config.py @@ -1,6 +1,8 @@ +import torch from pcse.agromanager import AgroManager from pcse.crop.phenology import DVS_Phenology from pcse.soil.classic_waterbalance import WaterbalancePP +from diffwofost.physical_models.config import ComputeConfig from diffwofost.physical_models.config import Configuration from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics from . import phy_data_folder @@ -54,3 +56,41 @@ def test_output_variables_can_be_updated(self): assert config.OUTPUT_VARS == ["DVS", "LAI"] assert config.SUMMARY_OUTPUT_VARS == ["LAI"] assert config.TERMINAL_OUTPUT_VARS == ["DVS"] + + +class TestComputeConfig: + def test_default_device_is_cuda_or_cpu(self): + ComputeConfig.reset_to_defaults() + device = ComputeConfig.get_device() + assert device.type in ["cpu", "cuda"] + + def test_default_dtype_is_float64(self): + ComputeConfig.reset_to_defaults() + dtype = ComputeConfig.get_dtype() + assert dtype == torch.float64 + + def test_set_device_with_string(self): + ComputeConfig.set_device("cpu") + device = ComputeConfig.get_device() + assert device.type == "cpu" + + def test_set_device_with_torch_device(self): + ComputeConfig.set_device(torch.device("cpu")) + device = ComputeConfig.get_device() + assert device.type == "cpu" + + def test_set_dtype(self): + ComputeConfig.set_dtype(torch.float32) + dtype = ComputeConfig.get_dtype() + assert dtype == torch.float32 + + def test_reset_to_defaults(self): + ComputeConfig.set_device("cpu") + ComputeConfig.set_dtype(torch.float32) + ComputeConfig.reset_to_defaults() + + device = ComputeConfig.get_device() + dtype = ComputeConfig.get_dtype() + + assert device.type in ["cpu", "cuda"] + assert dtype == torch.float64 diff --git a/tests/physical_models/test_utils.py b/tests/physical_models/test_utils.py index cc4503b..0f28153 100644 --- a/tests/physical_models/test_utils.py +++ b/tests/physical_models/test_utils.py @@ -2,7 +2,7 @@ import pytest import torch -from diffwofost.physical_models.utils import DTYPE +from diffwofost.physical_models.config import ComputeConfig from diffwofost.physical_models.utils import Afgen from diffwofost.physical_models.utils import AfgenTrait from diffwofost.physical_models.utils import WeatherDataProviderTestHelper @@ -10,6 +10,8 @@ from diffwofost.physical_models.utils import get_test_data from . import phy_data_folder +DTYPE = ComputeConfig.get_dtype() + class TestAfgen: """Tests for the Afgen class.""" @@ -162,6 +164,25 @@ def test_complex_table(self): expected = torch.tensor(1.25, dtype=DTYPE) # Linear interpolation assert torch.isclose(result, expected) + def test_to_moves_dtype_and_device(self): + afgen = Afgen([0, 0, 10, 10]) + returned = afgen.to(dtype=torch.float64) + assert returned is afgen + out = afgen(torch.tensor(5.0)) + assert out.dtype == torch.float64 + + # Batched tables + tbl = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 20.0]], dtype=torch.float32) + afgen_batched = Afgen(tbl) + afgen_batched.to(dtype=torch.float64) + out_batched = afgen_batched(torch.tensor([5.0, 5.0])) + assert out_batched.dtype == torch.float64 + + if torch.cuda.is_available(): + afgen_cuda = Afgen([0, 0, 10, 10]).to(device="cuda") + out_cuda = afgen_cuda(torch.tensor(5.0, device="cuda")) + assert out_cuda.device.type == "cuda" + class TestAfgenTrait: """Tests for the AfgenTrait class.""" @@ -633,7 +654,7 @@ def test_float_broadcast(self): provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) wdc = provider(provider.first_date) scalar = wdc.TEMP - out = _get_drv(scalar, expected_shape) + out = _get_drv(scalar, expected_shape, dtype=DTYPE) assert out.shape == expected_shape assert torch.allclose(out, torch.full(expected_shape, scalar, dtype=DTYPE)) @@ -644,7 +665,7 @@ def test_scalar_broadcast(self): provider = WeatherDataProviderTestHelper(test_data["WeatherVariables"]) wdc = provider(provider.first_date) scalar = torch.tensor(wdc.IRRAD, dtype=DTYPE) # 0-d tensor - out = _get_drv(scalar, expected_shape) + out = _get_drv(scalar, expected_shape, dtype=DTYPE) assert out.shape == expected_shape assert torch.allclose(out, torch.full(expected_shape, scalar.item(), dtype=DTYPE)) @@ -652,7 +673,7 @@ def test_matching_shape_pass_through(self): expected_shape = (3, 2) base_val = torch.tensor(12.34, dtype=DTYPE) var = torch.ones(expected_shape, dtype=DTYPE) * base_val - out = _get_drv(var, expected_shape) + out = _get_drv(var, expected_shape, dtype=DTYPE) assert out.shape == expected_shape # Should be the same object (no copy) assert out.data_ptr() == var.data_ptr() @@ -661,10 +682,10 @@ def test_wrong_shape_raises(self): expected_shape = (3, 2) wrong = torch.ones(2, 3, dtype=DTYPE) with pytest.raises(ValueError, match="incompatible shape"): - _get_drv(wrong, expected_shape) + _get_drv(wrong, expected_shape, dtype=DTYPE) def test_one_dim_shape_raises(self): expected_shape = (3, 2) one_dim = torch.ones(3, dtype=DTYPE) with pytest.raises(ValueError, match="incompatible shape"): - _get_drv(one_dim, expected_shape) + _get_drv(one_dim, expected_shape, dtype=DTYPE)