Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions docs/guides/datasets.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,47 @@ config = mt.configs.MatterTunerConfig(
)
```

## JSON Dataset
Allows reading atomic structures and properties from JSON files with a specific schema.

API Reference: {py:class}`mattertune.data.json_data.JSONDatasetConfig`

Expected JSON format:
```json
[
{
"atomic_numbers": [1, 1, 8],
"positions": [[0, 0, 0], [0, 0, 1], [0, 1, 0]],
"cell": [[10, 0, 0], [0, 10, 0], [0, 0, 10]],
"energy": -13.5,
"forces": [[0.1, 0, 0], [-0.1, 0, 0], [0, 0, 0]],
"stress": [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
}
]
```

Usage example:
```python
config = mt.configs.MatterTunerConfig(
model=...,
data=mt.configs.AutoSplitDataModuleConfig(
dataset=mt.configs.JSONDatasetConfig(
src="path/to/data.json",
tasks={
"energy": "energy",
"forces": "forces",
"stress": "stress"
}
),
train_split=0.8,
batch_size=32
),
trainer=...
)
```

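The `tasks` dictionary maps each property name (the dictionary key, e.g. `"energy"`) to the JSON key that stores that property in your data file (the dictionary value). The two do not have to match: if your file stores energies under `"y"`, use `{"energy": "y"}`.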

Each dataset configuration can be used with either `AutoSplitDataModuleConfig` for automatic train/validation splitting or `ManualSplitDataModuleConfig` for manual split specification. The examples above use `AutoSplitDataModuleConfig` for simplicity.

Note that some datasets may require additional dependencies:
Expand Down
125 changes: 125 additions & 0 deletions examples/json_dataset/orb-finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

import logging

import nshutils as nu
from lightning.pytorch.strategies import DDPStrategy

import mattertune.configs as MC
from mattertune import MatterTuner
from mattertune.configs import WandbLoggerConfig

# Configure the root logger at WARNING level for this example script.
logging.basicConfig(level=logging.WARNING)
# NOTE(review): presumably enables nshutils' pretty-printing helpers — confirm
# against the nshutils documentation.
nu.pretty()


def main(args_dict: dict) -> None:
    """Fine-tune an ORB backbone on a JSON dataset with MatterTune.

    Builds a ``MatterTunerConfig`` from the given options and runs
    ``MatterTuner(...).tune()``.

    Args:
        args_dict: Parsed command-line options. The keys read here are
            ``model_name``, ``lr``, ``energy_attr``, ``forces_attr``,
            ``stress_attr``, ``data_src``, ``train_split``,
            ``validation_split``, ``batch_size``, ``max_epochs``,
            ``devices`` and ``task``.
    """

    def build_config():
        """Assemble and finalize the MatterTuner configuration."""
        config = MC.MatterTunerConfig.draft()

        ## Model Hyperparameters
        config.model = MC.ORBBackboneConfig.draft()
        config.model.pretrained_model = args_dict["model_name"]
        config.model.ignore_gpu_batch_transform_error = True
        config.model.optimizer = MC.AdamWConfig(lr=args_dict["lr"])

        # Properties to fit: energy, forces and stress, each with an MAE loss
        # and its own loss coefficient.
        energy = MC.EnergyPropertyConfig(
            loss=MC.MAELossConfig(),
            loss_coefficient=0.01,
        )
        forces = MC.ForcesPropertyConfig(
            loss=MC.MAELossConfig(), conservative=False, loss_coefficient=50.0
        )
        stress = MC.StressesPropertyConfig(
            loss=MC.MAELossConfig(), loss_coefficient=50.0, conservative=False
        )
        config.model.properties = [energy, forces, stress]

        ## Data Hyperparameters
        config.data = MC.AutoSplitDataModuleConfig.draft()
        config.data.dataset = MC.JSONDatasetConfig.draft()
        # Map each property name to the JSON key chosen on the command line.
        config.data.dataset.tasks = {
            "energy": args_dict["energy_attr"],
            "forces": args_dict["forces_attr"],
            "stress": args_dict["stress_attr"],
        }
        config.data.dataset.src = args_dict["data_src"]
        config.data.train_split = args_dict["train_split"]
        config.data.validation_split = args_dict["validation_split"]
        config.data.batch_size = args_dict["batch_size"]
        config.data.num_workers = 0

        ## Trainer Hyperparameters
        config.trainer = MC.TrainerConfig.draft()
        config.trainer.max_epochs = args_dict["max_epochs"]
        config.trainer.accelerator = "gpu"
        config.trainer.devices = args_dict["devices"]
        config.trainer.gradient_clip_algorithm = "value"
        config.trainer.gradient_clip_val = 1.0
        config.trainer.precision = "bf16"

        # Configure Early Stopping
        config.trainer.early_stopping = MC.EarlyStoppingConfig(
            monitor="val/total_loss", patience=200, mode="min"
        )

        # Configure Model Checkpoint
        config.trainer.checkpoint = MC.ModelCheckpointConfig(
            monitor="val/total_loss",
            dirpath=f"./checkpoints-{args_dict['task']}",
            filename="orb-best",
            save_top_k=1,
            mode="min",
            every_n_epochs=10,
        )

        # Configure Logger
        # NOTE(review): the run name still says "Matbench" although this
        # example trains on a JSON dataset — looks like a copy-paste from the
        # Matbench example; confirm before renaming.
        config.trainer.loggers = [
            WandbLoggerConfig(
                project="MatterTune-Examples",
                name=f"ORB-Matbench-{args_dict['task']}",
                offline=False,
            )
        ]

        # Additional trainer settings that need special handling
        config.trainer.additional_trainer_kwargs = {
            "inference_mode": False,
            "strategy": DDPStrategy(
                static_graph=True, find_unused_parameters=True
            ),  # Special DDP config
        }

        return config.finalize(strict=False)

    mt_config = build_config()
    # The tuned model and trainer are not used further in this example.
    _model, _trainer = MatterTuner(mt_config).tune()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Fine-tune an ORB model on a JSON dataset with MatterTune."
    )
    # The dataset path has no sensible default; fail fast with a clear argparse
    # error instead of passing ``None`` into the dataset config.
    parser.add_argument(
        "--data-src",
        type=str,
        required=True,
        help="Path to the JSON dataset file.",
    )
    # Only used to label the checkpoint directory and the W&B run; a string
    # default avoids names such as "checkpoints-None".
    parser.add_argument(
        "--task",
        type=str,
        default="json",
        help="Label used in the checkpoint directory and W&B run name.",
    )
    parser.add_argument(
        "--energy-attr",
        type=str,
        default="y",
        help="JSON key that stores the energy.",
    )
    parser.add_argument(
        "--forces-attr",
        type=str,
        default="forces",
        help="JSON key that stores the forces.",
    )
    parser.add_argument(
        "--stress-attr",
        type=str,
        default="stress",
        help="JSON key that stores the stress.",
    )
    # The remaining flags keep their original underscore spelling so existing
    # command lines continue to work.
    parser.add_argument(
        "--model_name",
        type=str,
        default="orb-v2",
        help="Name of the pretrained ORB model.",
    )
    parser.add_argument("--train_split", type=float, default=0.9)
    parser.add_argument("--validation_split", type=float, default=0.05)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=8.0e-5)
    parser.add_argument("--max_epochs", type=int, default=200)
    parser.add_argument(
        "--devices",
        type=int,
        nargs="+",
        default=[0],
        help="One or more device indices passed to the trainer.",
    )
    args = parser.parse_args()
    main(vars(args))
33 changes: 17 additions & 16 deletions src/mattertune/.nshconfig.generated.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,40 @@
"typed_dicts": null,
"json_schemas": {
"mattertune.finetune.base.FinetuneModuleBaseConfig": "configs/backbones/jmp/model/FinetuneModuleBaseConfig.schema.json",
"mattertune.loggers.WandbLoggerConfig": "configs/loggers/WandbLoggerConfig.schema.json",
"mattertune.loggers.CSVLoggerConfig": "configs/main/CSVLoggerConfig.schema.json",
"mattertune.loggers.WandbLoggerConfig": "configs/loggers/WandbLoggerConfig.schema.json",
"mattertune.loggers.TensorBoardLoggerConfig": "configs/loggers/TensorBoardLoggerConfig.schema.json",
"mattertune.callbacks.early_stopping.EarlyStoppingConfig": "configs/callbacks/early_stopping/EarlyStoppingConfig.schema.json",
"mattertune.main.MatterTunerConfig": "configs/main/MatterTunerConfig.schema.json",
"mattertune.callbacks.model_checkpoint.ModelCheckpointConfig": "configs/callbacks/model_checkpoint/ModelCheckpointConfig.schema.json",
"mattertune.main.TrainerConfig": "configs/main/TrainerConfig.schema.json",
"mattertune.normalization.PerAtomReferencingNormalizerConfig": "configs/normalization/PerAtomReferencingNormalizerConfig.schema.json",
"mattertune.normalization.RMSNormalizerConfig": "configs/normalization/RMSNormalizerConfig.schema.json",
"mattertune.callbacks.model_checkpoint.ModelCheckpointConfig": "configs/callbacks/model_checkpoint/ModelCheckpointConfig.schema.json",
"mattertune.normalization.MeanStdNormalizerConfig": "configs/normalization/MeanStdNormalizerConfig.schema.json",
"mattertune.normalization.RMSNormalizerConfig": "configs/normalization/RMSNormalizerConfig.schema.json",
"mattertune.normalization.NormalizerConfigBase": "configs/normalization/NormalizerConfigBase.schema.json",
"mattertune.normalization.PerAtomReferencingNormalizerConfig": "configs/normalization/PerAtomReferencingNormalizerConfig.schema.json",
"mattertune.finetune.loss.L2MAELossConfig": "configs/finetune/loss/L2MAELossConfig.schema.json",
"mattertune.finetune.loss.MSELossConfig": "configs/finetune/loss/MSELossConfig.schema.json",
"mattertune.finetune.loss.MAELossConfig": "configs/finetune/loss/MAELossConfig.schema.json",
"mattertune.finetune.loss.HuberLossConfig": "configs/finetune/loss/HuberLossConfig.schema.json",
"mattertune.finetune.loss.MSELossConfig": "configs/finetune/loss/MSELossConfig.schema.json",
"mattertune.finetune.loss.L2MAELossConfig": "configs/finetune/loss/L2MAELossConfig.schema.json",
"mattertune.finetune.lr_scheduler.CosineAnnealingLRConfig": "configs/finetune/lr_scheduler/CosineAnnealingLRConfig.schema.json",
"mattertune.finetune.lr_scheduler.ReduceOnPlateauConfig": "configs/finetune/lr_scheduler/ReduceOnPlateauConfig.schema.json",
"mattertune.finetune.lr_scheduler.MultiStepLRConfig": "configs/finetune/lr_scheduler/MultiStepLRConfig.schema.json",
"mattertune.finetune.lr_scheduler.StepLRConfig": "configs/finetune/lr_scheduler/StepLRConfig.schema.json",
"mattertune.finetune.lr_scheduler.ExponentialConfig": "configs/finetune/lr_scheduler/ExponentialConfig.schema.json",
"mattertune.finetune.optimizer.SGDConfig": "configs/finetune/optimizer/SGDConfig.schema.json",
"mattertune.finetune.optimizer.AdamConfig": "configs/finetune/optimizer/AdamConfig.schema.json",
"mattertune.finetune.lr_scheduler.StepLRConfig": "configs/finetune/lr_scheduler/StepLRConfig.schema.json",
"mattertune.finetune.lr_scheduler.CosineAnnealingLRConfig": "configs/finetune/lr_scheduler/CosineAnnealingLRConfig.schema.json",
"mattertune.finetune.lr_scheduler.ReduceOnPlateauConfig": "configs/finetune/lr_scheduler/ReduceOnPlateauConfig.schema.json",
"mattertune.finetune.optimizer.AdamWConfig": "configs/finetune/optimizer/AdamWConfig.schema.json",
"mattertune.finetune.properties.StressesPropertyConfig": "configs/finetune/properties/StressesPropertyConfig.schema.json",
"mattertune.finetune.optimizer.AdamConfig": "configs/finetune/optimizer/AdamConfig.schema.json",
"mattertune.finetune.optimizer.SGDConfig": "configs/finetune/optimizer/SGDConfig.schema.json",
"mattertune.finetune.properties.PropertyConfigBase": "configs/finetune/properties/PropertyConfigBase.schema.json",
"mattertune.finetune.properties.GraphPropertyConfig": "configs/finetune/properties/GraphPropertyConfig.schema.json",
"mattertune.finetune.properties.EnergyPropertyConfig": "configs/finetune/properties/EnergyPropertyConfig.schema.json",
"mattertune.finetune.properties.StressesPropertyConfig": "configs/finetune/properties/StressesPropertyConfig.schema.json",
"mattertune.finetune.properties.ForcesPropertyConfig": "configs/finetune/properties/ForcesPropertyConfig.schema.json",
"mattertune.finetune.properties.GraphPropertyConfig": "configs/finetune/properties/GraphPropertyConfig.schema.json",
"mattertune.finetune.properties.PropertyConfigBase": "configs/finetune/properties/PropertyConfigBase.schema.json",
"mattertune.data.mptraj.MPTrajDatasetConfig": "configs/data/mptraj/MPTrajDatasetConfig.schema.json",
"mattertune.data.base.DatasetConfigBase": "configs/data/DatasetConfigBase.schema.json",
"mattertune.data.mptraj.MPTrajDatasetConfig": "configs/data/mptraj/MPTrajDatasetConfig.schema.json",
"mattertune.data.omat24.OMAT24DatasetConfig": "configs/data/OMAT24DatasetConfig.schema.json",
"mattertune.data.xyz.XYZDatasetConfig": "configs/data/XYZDatasetConfig.schema.json",
"mattertune.data.matbench.MatbenchDatasetConfig": "configs/data/MatbenchDatasetConfig.schema.json",
"mattertune.data.json_data.JSONDatasetConfig": "configs/data/JSONDatasetConfig.schema.json",
"mattertune.data.db.DBDatasetConfig": "configs/data/db/DBDatasetConfig.schema.json",
"mattertune.data.mp.MPDatasetConfig": "configs/data/MPDatasetConfig.schema.json",
"mattertune.data.datamodule.DataModuleBaseConfig": "configs/data/datamodule/DataModuleBaseConfig.schema.json",
Expand All @@ -49,8 +50,8 @@
"mattertune.backbones.eqV2.model.FAIRChemAtomsToGraphSystemConfig": "configs/backbones/eqV2/model/FAIRChemAtomsToGraphSystemConfig.schema.json",
"mattertune.backbones.orb.model.ORBSystemConfig": "configs/backbones/orb/model/ORBSystemConfig.schema.json",
"mattertune.backbones.m3gnet.model.M3GNetGraphComputerConfig": "configs/backbones/m3gnet/M3GNetGraphComputerConfig.schema.json",
"mattertune.backbones.jmp.model.JMPGraphComputerConfig": "configs/backbones/jmp/model/JMPGraphComputerConfig.schema.json",
"mattertune.backbones.jmp.model.CutoffsConfig": "configs/backbones/jmp/model/CutoffsConfig.schema.json",
"mattertune.backbones.jmp.model.JMPGraphComputerConfig": "configs/backbones/jmp/model/JMPGraphComputerConfig.schema.json",
"mattertune.backbones.jmp.model.MaxNeighborsConfig": "configs/backbones/jmp/model/MaxNeighborsConfig.schema.json"
}
}
2 changes: 2 additions & 0 deletions src/mattertune/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig
from mattertune.data import DatasetConfigBase as DatasetConfigBase
from mattertune.data import JSONDatasetConfig as JSONDatasetConfig
from mattertune.data import MatbenchDatasetConfig as MatbenchDatasetConfig
from mattertune.data import MPDatasetConfig as MPDatasetConfig
from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig
Expand Down Expand Up @@ -100,6 +101,7 @@
"HuberLossConfig",
"JMPBackboneConfig",
"JMPGraphComputerConfig",
"JSONDatasetConfig",
"L2MAELossConfig",
"M3GNetBackboneConfig",
"M3GNetGraphComputerConfig",
Expand Down
41 changes: 41 additions & 0 deletions src/mattertune/configs/data/JSONDatasetConfig.schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
{
"properties": {
"type": {
"const": "json",
"default": "json",
"description": "Discriminator for the JSON dataset.",
"enum": [
"json"
],
"title": "Type",
"type": "string"
},
"src": {
"anyOf": [
{
"type": "string"
},
{
"format": "path",
"type": "string"
}
],
"description": "The path to the JSON dataset.",
"title": "Src"
},
"tasks": {
"additionalProperties": {
"type": "string"
},
"description": "Attributes in the JSON file that correspond to the tasks to be predicted.",
"title": "Tasks",
"type": "object"
}
},
"required": [
"src",
"tasks"
],
"title": "JSONDatasetConfig",
"type": "object"
}
4 changes: 4 additions & 0 deletions src/mattertune/configs/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
__codegen__ = True

from mattertune.data import DatasetConfigBase as DatasetConfigBase
from mattertune.data import JSONDatasetConfig as JSONDatasetConfig
from mattertune.data import MatbenchDatasetConfig as MatbenchDatasetConfig
from mattertune.data import MPDatasetConfig as MPDatasetConfig
from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig
Expand All @@ -20,6 +21,7 @@
from . import base as base
from . import datamodule as datamodule
from . import db as db
from . import json_data as json_data
from . import matbench as matbench
from . import mp as mp
from . import mptraj as mptraj
Expand All @@ -31,6 +33,7 @@
"DBDatasetConfig",
"DataModuleBaseConfig",
"DatasetConfigBase",
"JSONDatasetConfig",
"MPDatasetConfig",
"MPTrajDatasetConfig",
"ManualSplitDataModuleConfig",
Expand All @@ -40,6 +43,7 @@
"base",
"datamodule",
"db",
"json_data",
"matbench",
"mp",
"mptraj",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
"discriminator": {
"mapping": {
"db": "#/$defs/DBDatasetConfig",
"json": "#/$defs/JSONDatasetConfig",
"matbench": "#/$defs/MatbenchDatasetConfig",
"mp": "#/$defs/MPDatasetConfig",
"mptraj": "#/$defs/MPTrajDatasetConfig",
Expand All @@ -91,6 +92,9 @@
"propertyName": "type"
},
"oneOf": [
{
"$ref": "#/$defs/JSONDatasetConfig"
},
{
"$ref": "#/$defs/MatbenchDatasetConfig"
},
Expand All @@ -111,6 +115,47 @@
}
]
},
"JSONDatasetConfig": {
"properties": {
"type": {
"const": "json",
"default": "json",
"description": "Discriminator for the JSON dataset.",
"enum": [
"json"
],
"title": "Type",
"type": "string"
},
"src": {
"anyOf": [
{
"type": "string"
},
{
"format": "path",
"type": "string"
}
],
"description": "The path to the JSON dataset.",
"title": "Src"
},
"tasks": {
"additionalProperties": {
"type": "string"
},
"description": "Attributes in the JSON file that correspond to the tasks to be predicted.",
"title": "Tasks",
"type": "object"
}
},
"required": [
"src",
"tasks"
],
"title": "JSONDatasetConfig",
"type": "object"
},
"MPDatasetConfig": {
"description": "Configuration for a dataset stored in the Materials Project database.",
"properties": {
Expand Down
Loading