Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ venv3
ENV/
env.bak/
venv.bak/
.vscode/

# vim
*.swp
Expand Down
7 changes: 2 additions & 5 deletions docs/api_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,12 @@ hide:

## **Crop modules**

!!! note
At the moment only two modules of `leaf_dynamics` and `root_dynamics` are
differentiable w.r.t two parameters of `SPAN` and `TDWI`. But the package is under
continuous development. So make sure that you install the latest version.

::: diffwofost.physical_models.crop.leaf_dynamics.WOFOST_Leaf_Dynamics

::: diffwofost.physical_models.crop.root_dynamics.WOFOST_Root_Dynamics

::: diffwofost.physical_models.crop.phenology.DVS_Phenology

## **Utility (under development)**

::: diffwofost.physical_models.utils.EngineTestHelper
2 changes: 2 additions & 0 deletions src/diffwofost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from diffwofost.physical_models import utils
from diffwofost.physical_models.crop import leaf_dynamics
from diffwofost.physical_models.crop import phenology
from diffwofost.physical_models.crop import root_dynamics

logging.getLogger(__name__).addHandler(logging.NullHandler())
Expand All @@ -14,5 +15,6 @@
__all__ = [
"leaf_dynamics",
"root_dynamics",
"phenology",
"utils",
]
3 changes: 2 additions & 1 deletion src/diffwofost/physical_models/crop/leaf_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ class WOFOST_Leaf_Dynamics(SimulationObject):
| LAI | TDWI, SPAN, RGRLAI, TBASE, KDIFTB, SLATB |
| TWLV | TDWI, PERDL |

[!] Notice that the following gradients are zero:
[!NOTE]
Notice that the following gradients are zero:
- ∂SPAN/∂LAI
- ∂PERDL/∂TWLV
- ∂KDIFTB/∂LAI
Expand Down
773 changes: 773 additions & 0 deletions src/diffwofost/physical_models/crop/phenology.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion src/diffwofost/physical_models/crop/root_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ class WOFOST_Root_Dynamics(SimulationObject):
| RD | RDI, RRI, RDMCR, RDMSOL |
| TWRT | TDWI, RDRRTB |

[!] Notice that the gradient ∂TWRT/∂RDRRTB is zero.
[!NOTE]
Notice that the gradient ∂TWRT/∂RDRRTB is zero.

**IMPORTANT NOTICE**

Expand Down
81 changes: 56 additions & 25 deletions src/diffwofost/physical_models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pcse.engine import Engine
from pcse.settings import settings
from pcse.timer import Timer
from pcse.traitlets import Enum
from pcse.traitlets import TraitType

DTYPE = torch.float64 # Default data type for tensors in this module
Expand All @@ -50,7 +51,7 @@ class VariableKioskTestHelper(VariableKiosk):
def __init__(self, external_state_list):
super().__init__()
self.current_externals = {}
if external_state_list is not None:
if external_state_list:
self.external_state_list = external_state_list

def __call__(self, day):
Expand All @@ -59,7 +60,7 @@ def __call__(self, day):
Returns True if the list of external state/rate variables is exhausted,
otherwise False.
"""
if self.external_state_list is not None:
if self.external_state_list:
current_externals = self.external_state_list.pop(0)
forcing_day = current_externals.pop("DAY")
msg = "Failure updating VariableKiosk with external states: days are not matching!"
Expand Down Expand Up @@ -226,13 +227,15 @@ def prepare_engine_input(
test_data["WeatherVariables"], meteo_range_checks=meteo_range_checks
)
crop_model_params_provider = ParameterProvider(cropdata=cropd)
external_states = test_data["ExternalStates"]
external_states = test_data.get("ExternalStates") or []

# convert parameters to tensors
crop_model_params_provider.clear_override()
for name in crop_model_params:
value = torch.tensor(crop_model_params_provider[name], dtype=dtype)
crop_model_params_provider.set_override(name, value, check=False)
# if name is missing in the YAML, skip it
if name in crop_model_params_provider:
value = torch.tensor(crop_model_params_provider[name], dtype=dtype)
crop_model_params_provider.set_override(name, value, check=False)

# convert external states to tensors
tensor_external_states = [
Expand Down Expand Up @@ -475,37 +478,53 @@ def __call__(self, x):
else:
x_val = flat_x[0] # Broadcast first value

# Boundary conditions
if x_val <= x_list[0]:
result = y_list[0]
elif x_val >= x_list[-1]:
result = y_list[-1]
else:
# Find interval and interpolate
i = torch.searchsorted(x_list, x_val, right=False) - 1
i = torch.clamp(i, 0, len(x_list) - 2)
result = y_list[i] + slopes[i] * (x_val - x_list[i])
# Ensure contiguous memory layout for searchsorted
x_list_contig = x_list.contiguous()
x_val_contig = (
x_val.contiguous()
if isinstance(x_val, torch.Tensor) and x_val.dim() > 0
else x_val
)

# Find interval and interpolate using torch.where for differentiability
i = torch.searchsorted(x_list_contig, x_val_contig, right=False) - 1
i = torch.clamp(i, 0, len(x_list) - 2)

# Calculate interpolated value
interp_result = y_list[i] + slopes[i] * (x_val - x_list[i])

# Apply boundary conditions using torch.where
result = torch.where(
x_val <= x_list[0],
y_list[0],
torch.where(x_val >= x_list[-1], y_list[-1], interp_result),
)

results.append(result)

# Reshape to original batch shape
output = torch.stack(results).reshape(self.batch_shape)
return output

# Original scalar logic from pcse
# Clamp to boundaries
if x <= self.x_list[0]:
return self.y_list[0]
if x >= self.x_list[-1]:
return self.y_list[-1]
# Ensure contiguous memory layout for searchsorted
x_list_contig = self.x_list.contiguous()
x_contig = x.contiguous() if isinstance(x, torch.Tensor) and x.dim() > 0 else x

# Find interval index using torch.searchsorted for differentiability
i = torch.searchsorted(self.x_list, x, right=False) - 1
i = torch.searchsorted(x_list_contig, x_contig, right=False) - 1
i = torch.clamp(i, 0, len(self.x_list) - 2)

# Linear interpolation
v = self.y_list[i] + self.slopes[i] * (x - self.x_list[i])
return v
# Calculate interpolated value
interp_value = self.y_list[i] + self.slopes[i] * (x - self.x_list[i])

# Apply boundary conditions using torch.where
result = torch.where(
x <= self.x_list[0],
self.y_list[0],
torch.where(x >= self.x_list[-1], self.y_list[-1], interp_value),
)

return result

@property
def shape(self):
Expand Down Expand Up @@ -558,6 +577,9 @@ def _get_params_shape(params):
if parname.startswith("trait"):
continue
param = getattr(params, parname)
# Skip Enum and str parameters
if isinstance(param, Enum) or isinstance(param, str):
continue
# Parameters that are not zero dimensional should all have the same shape
if param.shape and not shape:
shape = param.shape
Expand Down Expand Up @@ -615,3 +637,12 @@ def _broadcast_to(x, shape):
# the dimension along which the time integration is carried out.
# We first append an axis to x, then expand to the given shape
return x.unsqueeze(-1).expand(shape)


def _snapshot_state(obj):
return {name: val.clone() for name, val in obj.__dict__.items() if torch.is_tensor(val)}


def _restore_state(obj, snapshot):
for name, val in snapshot.items():
setattr(obj, name, val)
1 change: 1 addition & 0 deletions tests/physical_models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"leafdynamics",
"rootdynamics",
"potentialproduction",
"phenology",
]
FILE_NAMES = [
f"test_{model_name}_wofost72_{i:02d}.yaml" for model_name in model_names for i in range(1, 45)
Expand Down
12 changes: 11 additions & 1 deletion tests/physical_models/crop/test_leaf_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ def test_leaf_dynamics_with_one_parameter_vector(self, param):
[
("TDWI", 0.1),
("SPAN", 5),
("TBASE", 2.0),
("PERDL", 0.01),
("RGRLAI", 0.002),
("KDIFTB", 0.1),
("SLATB", 0.0005),
],
)
def test_leaf_dynamics_with_different_parameter_values(self, param, delta):
Expand All @@ -222,7 +227,12 @@ def test_leaf_dynamics_with_different_parameter_values(self, param, delta):
# Setting a vector with multiple values for the selected parameter
test_value = crop_model_params_provider[param]
# We set the value for which test data are available as the last element
param_vec = torch.tensor([test_value - delta, test_value + delta, test_value])
if param in {"KDIFTB", "SLATB"}:
# AfgenTrait parameters need to have shape (N, M)
non_zeros_mask = test_value != 0
param_vec = torch.stack([test_value + non_zeros_mask * delta, test_value])
else:
param_vec = torch.tensor([test_value - delta, test_value + delta, test_value])
crop_model_params_provider.set_override(param, param_vec, check=False)

engine = EngineTestHelper(
Expand Down
Loading