Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ share/python-wheels/

# jupyter notebook
.ipynb_checkpoints
docs/notebooks/test*

# Unit test / coverage reports
htmlcov/
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Issues = "https://github.com/WUR-AI/diffwofost/issues"

[tool.pytest.ini_options]
testpaths = ["tests"]
filterwarnings = [
Comment thread
SCiarella marked this conversation as resolved.
"ignore::DeprecationWarning:pcse.base.simulationobject",
]


[tool.coverage.run]
Expand Down
135 changes: 135 additions & 0 deletions src/diffwofost/physical_models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,146 @@
from pathlib import Path
from typing import Self
import pcse
import torch
from pcse.agromanager import AgroManager
from pcse.base import AncillaryObject
from pcse.base import SimulationObject


class ComputeConfig:
"""Central configuration for device and dtype settings.

This class provides a centralized way to control PyTorch device and dtype
settings across all simulation objects in diffWOFOST. Instead of setting
device and dtype individually for each class, use this central configuration
to apply settings globally.

**Default Behavior:**

- **Device**: Automatically defaults to 'cuda' if available, otherwise 'cpu'
- **Dtype**: Defaults to torch.float64

**Basic Usage:**

>>> from diffwofost.physical_models.config import ComputeConfig
>>> import torch
>>>
>>> # Set device to CPU
>>> ComputeConfig.set_device('cpu')
>>>
>>> # Or use a torch.device object
>>> ComputeConfig.set_device(torch.device('cuda'))
>>>
>>> # Set dtype to float32
>>> ComputeConfig.set_dtype(torch.float32)
>>>
>>> # Get current settings
>>> device = ComputeConfig.get_device() # Returns: torch.device('cpu')
>>> dtype = ComputeConfig.get_dtype() # Returns: torch.float32

**Using with Simulation Objects:**

All simulation objects (e.g., WOFOST_Leaf_Dynamics, WOFOST_Phenology)
automatically use the settings from ComputeConfig. No changes needed to
instantiation code:

>>> from diffwofost.physical_models.config import ComputeConfig
>>> from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics
>>>
>>> # Set global compute settings
>>> ComputeConfig.set_device('cuda')
>>> ComputeConfig.set_dtype(torch.float32)
>>>
>>> # Instantiate objects - they automatically use global settings
>>> leaf_dynamics = WOFOST_Leaf_Dynamics()

**Switching Between Devices:**

Useful for switching between GPU training and CPU evaluation:

>>> # Train on GPU
>>> ComputeConfig.set_device('cuda')
>>> ComputeConfig.set_dtype(torch.float32)
>>> # ... run training ...
>>>
>>> # Evaluate on CPU
>>> ComputeConfig.set_device('cpu')
>>> ComputeConfig.set_dtype(torch.float64)
>>> # ... run evaluation ...

**Resetting to Defaults:**

>>> ComputeConfig.reset_to_defaults()

"""

_device: torch.device = None
_dtype: torch.dtype = None

@classmethod
def _initialize_defaults(cls):
"""Initialize default device and dtype if not already set."""
if cls._device is None:
cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if cls._dtype is None:
cls._dtype = torch.float64

@classmethod
def get_device(cls) -> torch.device:
"""Get the current device setting.

Returns:
torch.device: The current device (cuda or cpu)
"""
cls._initialize_defaults()
return cls._device

@classmethod
def set_device(cls, device: str | torch.device) -> None:
"""Set the device to use for tensor operations.

Args:
device (str | torch.device): Device to use ('cuda', 'cpu', or torch.device object)

Example:
>>> ComputeConfig.set_device('cuda')
>>> ComputeConfig.set_device(torch.device('cpu'))
"""
if isinstance(device, str):
cls._device = torch.device(device)
else:
cls._device = device

@classmethod
def get_dtype(cls) -> torch.dtype:
"""Get the current dtype setting.

Returns:
torch.dtype: The current dtype (e.g., torch.float32, torch.float64)
"""
cls._initialize_defaults()
return cls._dtype

@classmethod
def set_dtype(cls, dtype: torch.dtype) -> None:
"""Set the dtype to use for tensor creation.

Args:
dtype (torch.dtype): PyTorch dtype (torch.float32, torch.float64, etc.)

Example:
>>> ComputeConfig.set_dtype(torch.float32)
"""
cls._dtype = dtype

@classmethod
def reset_to_defaults(cls) -> None:
"""Reset device and dtype to their default values."""
cls._device = None
cls._dtype = None
cls._initialize_defaults()


@dataclass(frozen=True)
class Configuration:
"""Class to store model configuration from a PCSE configuration files."""
Expand Down
Loading