Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
67a20ff
add validation to config
SarahAlidoost May 11, 2026
e5273d3
create a simulationobject class
SarahAlidoost May 11, 2026
5a26782
move overrides function to a new module, and call them in engine
SarahAlidoost May 11, 2026
193954f
fix crop_components in config
SarahAlidoost May 11, 2026
c91f7de
fix engine
SarahAlidoost May 11, 2026
5aea9c0
fix ovverride, support defaults
SarahAlidoost May 11, 2026
1988a5b
remove new class simulation object
SarahAlidoost May 11, 2026
45c0e99
fix linters
SarahAlidoost May 11, 2026
57d50e7
fix wofost test of overrides
SarahAlidoost May 11, 2026
d307f83
add tests for simulationobject new functions
SarahAlidoost May 11, 2026
36d2b9e
add more tests for config
SarahAlidoost May 11, 2026
98e4148
add more tests to engine
SarahAlidoost May 11, 2026
c56ac84
refactor post_init in config
SarahAlidoost May 12, 2026
a49cd57
refactor initialize components
SarahAlidoost May 12, 2026
e9a82fb
fix linter errors and formatting
SarahAlidoost May 12, 2026
f36a589
fix tests after refactoring
SarahAlidoost May 12, 2026
2594629
add more tests for ml_models
SarahAlidoost May 12, 2026
3f89210
remove frozen True from config, add more tests to config
SarahAlidoost May 12, 2026
9771613
rerun hybrid nb
SarahAlidoost May 12, 2026
a0956d0
add init of the ml tests
SarahAlidoost May 12, 2026
2acf4ef
add tests of override
SarahAlidoost May 12, 2026
56b6fc8
fix test engine
SarahAlidoost May 12, 2026
6aaa52c
rename initialize functions in simulationobject
SarahAlidoost May 18, 2026
12c80d5
Merge branch 'main' into fix_111
SarahAlidoost May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 61 additions & 120 deletions docs/notebooks/hybrid_partitioning_wofost72.ipynb

Large diffs are not rendered by default.

79 changes: 79 additions & 0 deletions src/diffwofost/physical_models/base/simulationobject.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from pcse.base import SimulationObject
from diffwofost.physical_models.override import ComponentOverride


def initialize_single_component(
component_spec: ComponentOverride,
day,
kiosk,
parvalues,
shape=None,
):
"""Build one embedded model component from the override definition.

Args:
component_spec: Specification of the component to be initialized.
day: Current simulation day.
kiosk: Variable kiosk shared across crop components.
parvalues: Physical-model parameter provider.
shape: Optional tensor broadcast shape for the component.
component_overrides: Normalized component override mapping.

Returns:
The instantiated simulation component.

The constructor call depends on whether the override provides a
``model``. Default physical components expect the parameter provider as
their third positional argument, whereas ML-backed wrappers typically
expect a model object there instead. This function centralizes that
dispatch so callers only need to describe the override declaratively.

"""
constructor_kwargs = dict(component_spec.kwargs or {})
component_class = component_spec.component_class
if shape is not None:
constructor_kwargs["shape"] = shape

if component_spec.model is not None:
return component_class(day, kiosk, component_spec.model, **constructor_kwargs)

return component_class(day, kiosk, parvalues, **constructor_kwargs)


def initialize_all_components(
simulation_object: SimulationObject,
day,
kiosk,
parvalues,
shape=None,
component_overrides: dict | None = None,
) -> None:
"""Generic crop component initialization for any SimulationObject.

Args:
simulation_object: The SimulationObject for which to initialize the
components.
day: Start date of the simulation.
kiosk: Variable kiosk used to read and publish crop state.
parvalues: Parameter provider containing the physical-model
parameters for the crop.
shape: Target tensor shape for state and rate variables.
component_overrides: Mapping used to replace one or more
internal components (e.g. in WOFOST) at initialization time.
The order of components in the mapping matter because these are
physical models to be initialized one by one, and some
components may depend on previously initialized ones.
"""
for component_name, (attribute_name, default_spec) in simulation_object.COMPONENT_SPECS.items():
if component_overrides is None:
component_spec = ComponentOverride.from_default(default_spec)
else:
component_spec = component_overrides[component_name]
component = initialize_single_component(
component_spec,
day,
kiosk,
parvalues,
shape=shape,
)
setattr(simulation_object, attribute_name, component)
31 changes: 29 additions & 2 deletions src/diffwofost/physical_models/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
Expand Down Expand Up @@ -139,12 +140,13 @@ def reset_to_defaults(cls) -> None:
cls._initialize_defaults()


@dataclass(frozen=True)
@dataclass
class Configuration:
"""Class to store model configuration from a PCSE configuration files."""

CROP: type[SimulationObject]
CROP_COMPONENTS: dict = field(default_factory=dict)
CROP_COMPONENTS: dict | None = None
CROP_NN_MODEL: type[torch.nn.Module] | None = None
SOIL: type[SimulationObject] | None = None
AGROMANAGEMENT: type[AncillaryObject] = AgroManager
OUTPUT_VARS: list = field(default_factory=list)
Expand All @@ -156,6 +158,31 @@ class Configuration:
model_config_file: str | Path | None = None
description: str | None = None

def __post_init__(self):
"""Validate config data based on CROP.initialize signature."""
sig_arguments = inspect.signature(self.CROP.initialize).parameters

# Nullify CROP_NN_MODEL and CROP_COMPONENTS, if not compatible with CROP.initialize
for field_value, sig_key, attr_name in [
(self.CROP_NN_MODEL, "nn_model", "CROP_NN_MODEL"),
(self.CROP_COMPONENTS, "component_overrides", "CROP_COMPONENTS"),
]:
if field_value is not None and sig_key not in sig_arguments:
setattr(self, attr_name, None)

# Validate component overrides have "class" key with non-None value
for component_name, override in (self.CROP_COMPONENTS or {}).items():
self._validate_component_override(component_name, override)

@staticmethod
def _validate_component_override(component_name: str, override) -> None:
if not isinstance(override, dict) or not override:
raise ValueError(f"Component override for '{component_name}' must be a non-empty dict")
if "class" not in override:
raise ValueError(f"Component override '{component_name}' must have a 'class' key")
if override["class"] is None:
raise ValueError(f"Component override '{component_name}' 'class' cannot be None")

@classmethod
def from_pcse_config_file(cls, filename: str | Path) -> Self:
"""Load the model configuration from a PCSE configuration file.
Expand Down
62 changes: 15 additions & 47 deletions src/diffwofost/physical_models/crop/wofost72.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from diffwofost.physical_models.base import TensorParamTemplate
from diffwofost.physical_models.base import TensorRatesTemplate
from diffwofost.physical_models.base import TensorStatesTemplate
from diffwofost.physical_models.base.simulationobject import initialize_all_components
from diffwofost.physical_models.config import ComputeConfig
from diffwofost.physical_models.crop.assimilation import WOFOST72_Assimilation as Assimilation
from diffwofost.physical_models.crop.evapotranspiration import (
Expand All @@ -28,8 +29,6 @@
WOFOST_Storage_Organ_Dynamics as Storage_Organ_Dynamics,
)
from diffwofost.physical_models.traitlets import Tensor
from diffwofost.physical_models.utils import initialize_component
from diffwofost.physical_models.utils import normalize_component_overrides


class Wofost72(SimulationObject):
Expand Down Expand Up @@ -122,7 +121,6 @@ class Wofost72(SimulationObject):
"storage_organ_dynamics": ("so_dynamics", Storage_Organ_Dynamics),
"leaf_dynamics": ("lv_dynamics", Leaf_Dynamics),
}
COMPONENT_OVERRIDE_META_KEYS = frozenset({"class", "model", "kwargs"})

@property
def device(self):
Expand Down Expand Up @@ -180,58 +178,28 @@ def initialize(
shape: Target tensor shape for state and rate variables.
component_overrides: Optional mapping used to replace one or more
internal WOFOST components at construction time.

The ``component_overrides`` mapping must use the canonical component
names from ``COMPONENT_SPECS``, such as ``partitioning``,
``phenology``, ``assimilation``, ``maintenance_respiration``,
``evapotranspiration``, ``root_dynamics``, ``stem_dynamics``,
``storage_organ_dynamics``, and ``leaf_dynamics``.

Each override entry may be one of the following:

- ``None``: keep the default component class with no extra arguments.
- a ``SimulationObject`` subclass: replace only the component class.
- a dict containing reserved keys:
``class`` for the replacement component class and ``model`` for an
optional ML model object.
- any additional keys in that dict are forwarded as keyword arguments
to the component constructor. A nested ``kwargs`` dict is also
accepted for backward-compatible explicit constructor kwargs.

ML-backed overrides are supported by passing a ``model`` object in the
override entry. When no model is provided, the component is constructed
with ``(day, kiosk, parvalues, shape=..., **kwargs)`` so the component
reads crop parameters from the ``ParameterProvider`` as usual. When a
model is provided, the component is constructed with
``(day, kiosk, model, shape=..., **kwargs)`` instead. This allows a
replacement component such as a neural partitioning module to consume a
trained or trainable PyTorch model while the rest of WOFOST remains
unchanged.
The ``component_overrides`` is a dictionary containing:
- "component_class": The class to use for the component.
- "model": The model to use for the component, if specified in the override.
- "kwargs": Any additional keyword arguments to pass to the component
constructor.
"""
self.params = self.Parameters(parvalues, shape=shape)
self.rates = self.RateVariables(
kiosk, publish=["DMI", "ADMI", "REALLOC_LV", "REALLOC_ST", "REALLOC_SO"], shape=shape
)
self.kiosk = kiosk
component_overrides = normalize_component_overrides(
component_overrides, self.COMPONENT_SPECS, self.COMPONENT_OVERRIDE_META_KEYS
)

# Initialize components of the crop
for component_name, (attribute_name, _) in self.COMPONENT_SPECS.items():
setattr(
self,
attribute_name,
initialize_component(
component_name,
self.COMPONENT_SPECS,
day,
kiosk,
parvalues,
shape=shape,
component_overrides=component_overrides,
),
)
# This will add attributes to self for each component, e.g. self.pheno, self.part, etc.
initialize_all_components(
self,
day,
kiosk,
parvalues,
shape=shape,
component_overrides=component_overrides,
)

# Initial total (living+dead) above-ground biomass of the crop
TAGP = self.kiosk.TWLV + self.kiosk.TWST + self.kiosk.TWSO
Expand Down
26 changes: 19 additions & 7 deletions src/diffwofost/physical_models/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from pcse.timer import Timer
from pcse.traitlets import Instance
from diffwofost.physical_models.config import Configuration
from diffwofost.physical_models.override import normalize_components
from diffwofost.physical_models.variablekiosk import VariableKiosk


Expand Down Expand Up @@ -53,6 +54,14 @@ def __init__(
else:
self.mconf = config

# Get the arguments of the CROP class's initialize method
self._components_overrides = None
if self.mconf.CROP_COMPONENTS is not None:
self._components_overrides = normalize_components(
self.mconf.CROP_COMPONENTS,
self.mconf.CROP.COMPONENT_SPECS,
)

def _reset_runtime_state(self):
"""Clear state from a previous simulation run.

Expand Down Expand Up @@ -169,15 +178,18 @@ def _on_CROP_START(
self.parameterprovider.set_active_crop(
crop_name, variety_name, crop_start_type, crop_end_type
)

crop_args = [day, self.kiosk, self.parameterprovider]
crop_kwargs = {"shape": self._shape}

if self.mconf.CROP_NN_MODEL is not None:
# crop_nn_model initialize doesnot accpet parameterprovider
crop_args = [day, self.kiosk, self.mconf.CROP_NN_MODEL]

if self.mconf.CROP_COMPONENTS:
crop_kwargs["component_overrides"] = self.mconf.CROP_COMPONENTS
self.crop = self.mconf.CROP(
day,
self.kiosk,
self.parameterprovider,
**crop_kwargs,
)
crop_kwargs["component_overrides"] = self._components_overrides

self.crop = self.mconf.CROP(*crop_args, **crop_kwargs)

def _finish_cropsimulation(self, day):
"""Finalize and optionally delete the active crop simulation.
Expand Down
81 changes: 81 additions & 0 deletions src/diffwofost/physical_models/override.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from dataclasses import dataclass
from typing import Any
import torch
from pcse.base import SimulationObject


@dataclass(frozen=True)
class ComponentOverride:
"""Representation of a component override."""

component_class: type[SimulationObject] | None = None
model: type[torch.nn.Module] = None
kwargs: dict[str, Any] | None = None

def get_kwargs(self) -> dict[str, Any]:
"""Return the keyword arguments for the component constructor."""
return self.kwargs or {}

@classmethod
def from_default(cls, component_class):
"""Create a ComponentOverride from a default component class."""
return cls(component_class=component_class, model=None, kwargs=None)


def normalize_components(
crop_components: dict | None,
crop_component_specs: dict,
) -> dict[str, ComponentOverride]:
"""Convert user-facing component overrides into ComponentOverride instances.

Args:
crop_components: Raw override mapping from the configuration.
crop_component_specs: Mapping of canonical component names to specs.

Returns:
Dictionary keyed by canonical component names with ComponentOverride values.
containing:
- "component_class": The class to use for the component.
- "model": The model to use for the component, if specified in the override.
- "kwargs": Any additional keyword arguments to pass to the component constructor.

Raises:
KeyError: If an unknown component name is provided.
"""
normalized_overrides = {}

for component_name, override in crop_components.items():
if component_name not in crop_component_specs:
msg = (
f"Unknown crop component override: {component_name}. "
f"Valid components are: {list(crop_component_specs.keys())}"
)
raise KeyError(msg)

if isinstance(override, dict):
override_dict = dict(override)
component_class = override_dict.pop("class")
model = override_dict.pop("model", None)
explicit_kwargs = override_dict.pop("kwargs", {})
constructor_kwargs = {**(explicit_kwargs or {}), **override_dict}
else:
component_class = override
model = None
constructor_kwargs = {}

normalized_overrides[component_name] = ComponentOverride(
component_class=component_class,
model=model,
kwargs=constructor_kwargs or None,
)

# Add defaults for any components not in overrides
for component_name, (_, default_class) in crop_component_specs.items():
if component_name not in normalized_overrides:
normalized_overrides[component_name] = ComponentOverride(
component_class=default_class,
model=None,
kwargs=None,
)

return normalized_overrides
Loading
Loading