From b2d4f05676a83438f9f813be1ecc26f153d15bac Mon Sep 17 00:00:00 2001 From: qzheng75 Date: Fri, 29 Nov 2024 12:59:08 -0800 Subject: [PATCH 1/5] Load dataset as an ase.Atoms list from a JSON file --- examples/json_dataset/orb-finetune.py | 124 ++++++++++++++++++ src/mattertune/configs/__init__.py | 2 + .../data/JSONDatasetConfig.schema.json | 42 ++++++ src/mattertune/configs/data/__init__.py | 2 + .../data/json/JSONDatasetConfig.schema.json | 32 +++++ src/mattertune/configs/data/json/__init__.py | 6 + src/mattertune/data/__init__.py | 1 + src/mattertune/data/json_data.py | 80 +++++++++++ 8 files changed, 289 insertions(+) create mode 100644 examples/json_dataset/orb-finetune.py create mode 100644 src/mattertune/configs/data/JSONDatasetConfig.schema.json create mode 100644 src/mattertune/configs/data/json/JSONDatasetConfig.schema.json create mode 100644 src/mattertune/configs/data/json/__init__.py create mode 100644 src/mattertune/data/json_data.py diff --git a/examples/json_dataset/orb-finetune.py b/examples/json_dataset/orb-finetune.py new file mode 100644 index 0000000..bd9be5a --- /dev/null +++ b/examples/json_dataset/orb-finetune.py @@ -0,0 +1,124 @@ +from __future__ import annotations + +import logging + +import mattertune.configs as MC +import nshutils as nu +from lightning.pytorch.strategies import DDPStrategy +from mattertune import MatterTuner +from mattertune.configs import WandbLoggerConfig + +logging.basicConfig(level=logging.WARNING) +nu.pretty() + + +def main(args_dict: dict): + def hparams(): + hparams = MC.MatterTunerConfig.draft() + + ## Model Hyperparameters + hparams.model = MC.ORBBackboneConfig.draft() + hparams.model.pretrained_model = args_dict["model_name"] + hparams.model.ignore_gpu_batch_transform_error = True + hparams.model.optimizer = MC.AdamWConfig(lr=args_dict["lr"]) + + # Add property + hparams.model.properties = [] + energy = MC.EnergyPropertyConfig( + loss=MC.MAELossConfig(), + loss_coefficient=0.01, + # name=args_dict["task"], + # 
dtype="float", + ) + hparams.model.properties.append(energy) + forces = MC.ForcesPropertyConfig( + loss=MC.MAELossConfig(), conservative=False, loss_coefficient=50.0 + ) + hparams.model.properties.append(forces) + stress = MC.StressesPropertyConfig( + loss=MC.MAELossConfig(), loss_coefficient=50.0, conservative=False + ) + hparams.model.properties.append(stress) + + ## Data Hyperparameters + hparams.data = MC.AutoSplitDataModuleConfig.draft() + hparams.data.dataset = MC.JSONDatasetConfig.draft() + tasks = { + "energy": args_dict["energy_attr"], + "forces": args_dict["forces_attr"], + "stress": args_dict["stress_attr"], + } + hparams.data.dataset.tasks = tasks + hparams.data.dataset.src = args_dict["data_src"] + hparams.data.train_split = args_dict["train_split"] + hparams.data.validation_split = args_dict["validation_split"] + hparams.data.batch_size = args_dict["batch_size"] + hparams.data.num_workers = 0 + + ## Trainer Hyperparameters + hparams.trainer = MC.TrainerConfig.draft() + hparams.trainer.max_epochs = args_dict["max_epochs"] + hparams.trainer.accelerator = "gpu" + hparams.trainer.devices = args_dict["devices"] + hparams.trainer.gradient_clip_algorithm = "value" + hparams.trainer.gradient_clip_val = 1.0 + hparams.trainer.precision = "bf16" + + # Configure Early Stopping + hparams.trainer.early_stopping = MC.EarlyStoppingConfig( + monitor=f"val/total_loss", patience=200, mode="min" + ) + + # Configure Model Checkpoint + hparams.trainer.checkpoint = MC.ModelCheckpointConfig( + monitor=f"val/total_loss", + dirpath=f"./checkpoints-{args_dict['task']}", + filename="orb-best", + save_top_k=1, + mode="min", + every_n_epochs=10, + ) + + # Configure Logger + # hparams.trainer.loggers = [ + # WandbLoggerConfig( + # project="MatterTune-Examples", + # name=f"ORB-Matbench-{args_dict['task']}", + # offline=False, + # ) + # ] + + # Additional trainer settings that need special handling + hparams.trainer.additional_trainer_kwargs = { + "inference_mode": False, + 
"strategy": DDPStrategy( + static_graph=True, find_unused_parameters=True + ), # Special DDP config + } + + hparams = hparams.finalize(strict=False) + return hparams + + mt_config = hparams() + model, trainer = MatterTuner(mt_config).tune() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--data-src", type=str) + parser.add_argument("--task", type=str) + parser.add_argument("--energy-attr", type=str, default="y") + parser.add_argument("--forces-attr", type=str, default="forces") + parser.add_argument("--stress-attr", type=str, default="stress") + parser.add_argument("--model_name", type=str, default="orb-v2") + parser.add_argument("--train_split", type=float, default=0.9) + parser.add_argument("--validation_split", type=float, default=0.05) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=8.0e-5) + parser.add_argument("--max_epochs", type=int, default=200) + parser.add_argument("--devices", type=int, nargs="+", default=[0]) + args = parser.parse_args() + args_dict = vars(args) + main(args_dict) diff --git a/src/mattertune/configs/__init__.py b/src/mattertune/configs/__init__.py index a5a6277..e6c7a65 100644 --- a/src/mattertune/configs/__init__.py +++ b/src/mattertune/configs/__init__.py @@ -19,6 +19,7 @@ ) from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig from mattertune.data import DatasetConfigBase as DatasetConfigBase +from mattertune.data import JSONDatasetConfig as JSONDatasetConfig from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig from mattertune.data import XYZDatasetConfig as XYZDatasetConfig from mattertune.data.datamodule import ( @@ -129,6 +130,7 @@ "TrainerConfig", "WandbLoggerConfig", "XYZDatasetConfig", + "JSONDatasetConfig", "backbones", "callbacks", "data", diff --git a/src/mattertune/configs/data/JSONDatasetConfig.schema.json 
b/src/mattertune/configs/data/JSONDatasetConfig.schema.json new file mode 100644 index 0000000..dbe93db --- /dev/null +++ b/src/mattertune/configs/data/JSONDatasetConfig.schema.json @@ -0,0 +1,42 @@ +{ + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "type": "object", + "additionalProperties": { + "type": "string" + }, + "description": "The mapping of task names to descriptions.", + "title": "Tasks", + "minProperties": 1 + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" +} \ No newline at end of file diff --git a/src/mattertune/configs/data/__init__.py b/src/mattertune/configs/data/__init__.py index 1496a26..aab6bf8 100644 --- a/src/mattertune/configs/data/__init__.py +++ b/src/mattertune/configs/data/__init__.py @@ -20,6 +20,7 @@ from . import base as base from . import datamodule as datamodule from . import db as db +from . import json as json from . import matbench as matbench from . import mp as mp from . 
import mptraj as mptraj @@ -45,4 +46,5 @@ "mptraj", "omat24", "xyz", + "json", ] diff --git a/src/mattertune/configs/data/json/JSONDatasetConfig.schema.json b/src/mattertune/configs/data/json/JSONDatasetConfig.schema.json new file mode 100644 index 0000000..bd9c384 --- /dev/null +++ b/src/mattertune/configs/data/json/JSONDatasetConfig.schema.json @@ -0,0 +1,32 @@ +{ + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + } + }, + "required": [ + "src" + ], + "title": "JSONDatasetConfig", + "type": "object" +} \ No newline at end of file diff --git a/src/mattertune/configs/data/json/__init__.py b/src/mattertune/configs/data/json/__init__.py new file mode 100644 index 0000000..614a71d --- /dev/null +++ b/src/mattertune/configs/data/json/__init__.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +__codegen__ = True + +from mattertune.data.json_data import DatasetConfigBase as DatasetConfigBase +from mattertune.data.json_data import JSONDatasetConfig as JSONDatasetConfig diff --git a/src/mattertune/data/__init__.py b/src/mattertune/data/__init__.py index 8046180..8dfe325 100644 --- a/src/mattertune/data/__init__.py +++ b/src/mattertune/data/__init__.py @@ -2,6 +2,7 @@ from .base import DatasetConfig as DatasetConfig from .base import DatasetConfigBase as DatasetConfigBase +from .json_data import JSONDatasetConfig as JSONDatasetConfig from .omat24 import OMAT24Dataset as OMAT24Dataset from .omat24 import OMAT24DatasetConfig as OMAT24DatasetConfig from .xyz import XYZDatasetConfig as XYZDatasetConfig diff --git a/src/mattertune/data/json_data.py b/src/mattertune/data/json_data.py new file mode 100644 index 0000000..2362450 --- /dev/null +++ 
b/src/mattertune/data/json_data.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Literal + +import numpy as np +import torch +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator +from torch.utils.data import Dataset +from typing_extensions import override + +from ..registry import data_registry +from .base import DatasetConfigBase + +log = logging.getLogger(__name__) + + +@data_registry.register +class JSONDatasetConfig(DatasetConfigBase): + type: Literal["json"] = "json" + """Discriminator for the JSON dataset.""" + + src: str | Path + """The path to the JSON dataset.""" + + tasks: dict[str, str] + """Attributes in the JSON file that correspond to the tasks to be predicted.""" + + @override + def create_dataset(self): + return JSONDataset(self) + + +class JSONDataset(Dataset[Atoms]): + def __init__(self, config: JSONDatasetConfig): + super().__init__() + self.config = config + + with open(str(self.config.src), "r") as f: + raw_data = json.load(f) + + self.atoms_list = [] + for entry in raw_data: + atoms = Atoms( + numbers=np.array(entry["atomic_numbers"]), + positions=np.array(entry["positions"]), + cell=np.array(entry["cell"]), + pbc=True, + ) + + energy, forces, stress = None, None, None + if "energy" in self.config.tasks: + energy = torch.tensor(entry[self.config.tasks["energy"]]) + if "forces" in self.config.tasks: + forces = torch.tensor(entry[self.config.tasks["forces"]]) + if "stress" in self.config.tasks: + stress = torch.tensor(entry[self.config.tasks["stress"]]) + # ASE requires stress to be of shape (3, 3) or (6,) + # Some datasets store stress with shape (1, 3, 3) + if stress.ndim == 3: + stress = stress.squeeze(0) + + single_point_calc = SinglePointCalculator( + atoms, energy=energy, forces=forces, stress=stress + ) + + atoms.calc = single_point_calc + self.atoms_list.append(atoms) + + log.info(f"Loaded {len(self.atoms_list)} structures 
from {self.config.src}") + + @override + def __getitem__(self, idx: int) -> Atoms: + return self.atoms_list[idx] + + def __len__(self) -> int: + return len(self.atoms_list) From e13fc7cccf970d36facab91d60ee63194cf10308 Mon Sep 17 00:00:00 2001 From: qzheng75 Date: Sat, 30 Nov 2024 19:55:45 -0800 Subject: [PATCH 2/5] Pass ruff tests --- examples/json_dataset/orb-finetune.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/json_dataset/orb-finetune.py b/examples/json_dataset/orb-finetune.py index bd9be5a..7eb83dc 100644 --- a/examples/json_dataset/orb-finetune.py +++ b/examples/json_dataset/orb-finetune.py @@ -2,9 +2,10 @@ import logging -import mattertune.configs as MC import nshutils as nu from lightning.pytorch.strategies import DDPStrategy + +import mattertune.configs as MC from mattertune import MatterTuner from mattertune.configs import WandbLoggerConfig @@ -80,13 +81,13 @@ def hparams(): ) # Configure Logger - # hparams.trainer.loggers = [ - # WandbLoggerConfig( - # project="MatterTune-Examples", - # name=f"ORB-Matbench-{args_dict['task']}", - # offline=False, - # ) - # ] + hparams.trainer.loggers = [ + WandbLoggerConfig( + project="MatterTune-Examples", + name=f"ORB-Matbench-{args_dict['task']}", + offline=False, + ) + ] # Additional trainer settings that need special handling hparams.trainer.additional_trainer_kwargs = { From d731c2ec7dce5d116f69e89a121a22070fc5b759 Mon Sep 17 00:00:00 2001 From: Nima Shoghi Date: Sat, 30 Nov 2024 23:56:49 -0500 Subject: [PATCH 3/5] Export JSONDataset from data module --- src/mattertune/data/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mattertune/data/__init__.py b/src/mattertune/data/__init__.py index 55aa95c..0d85166 100644 --- a/src/mattertune/data/__init__.py +++ b/src/mattertune/data/__init__.py @@ -2,6 +2,7 @@ from .base import DatasetConfig as DatasetConfig from .base import DatasetConfigBase as DatasetConfigBase +from .json_data import JSONDataset as 
JSONDataset from .json_data import JSONDatasetConfig as JSONDatasetConfig from .matbench import MatbenchDataset as MatbenchDataset from .matbench import MatbenchDatasetConfig as MatbenchDatasetConfig From e67a72ec29ea5787962844a36f681f4d8914ba44 Mon Sep 17 00:00:00 2001 From: Nima Shoghi Date: Sat, 30 Nov 2024 23:56:52 -0500 Subject: [PATCH 4/5] Re-run config gen --- src/mattertune/.nshconfig.generated.json | 33 +++++++------- src/mattertune/configs/__init__.py | 2 +- .../data/JSONDatasetConfig.schema.json | 5 +-- src/mattertune/configs/data/__init__.py | 6 ++- .../AutoSplitDataModuleConfig.schema.json | 45 +++++++++++++++++++ .../ManualSplitDataModuleConfig.schema.json | 45 +++++++++++++++++++ .../json_data/DatasetConfigBase.schema.json | 5 +++ .../JSONDatasetConfig.schema.json | 11 ++++- .../data/{json => json_data}/__init__.py | 5 +++ .../main/MatterTunerConfig.schema.json | 45 +++++++++++++++++++ 10 files changed, 179 insertions(+), 23 deletions(-) create mode 100644 src/mattertune/configs/data/json_data/DatasetConfigBase.schema.json rename src/mattertune/configs/data/{json => json_data}/JSONDatasetConfig.schema.json (69%) rename src/mattertune/configs/data/{json => json_data}/__init__.py (76%) diff --git a/src/mattertune/.nshconfig.generated.json b/src/mattertune/.nshconfig.generated.json index d25aec9..6593414 100644 --- a/src/mattertune/.nshconfig.generated.json +++ b/src/mattertune/.nshconfig.generated.json @@ -4,39 +4,40 @@ "typed_dicts": null, "json_schemas": { "mattertune.finetune.base.FinetuneModuleBaseConfig": "configs/backbones/jmp/model/FinetuneModuleBaseConfig.schema.json", - "mattertune.loggers.WandbLoggerConfig": "configs/loggers/WandbLoggerConfig.schema.json", "mattertune.loggers.CSVLoggerConfig": "configs/main/CSVLoggerConfig.schema.json", + "mattertune.loggers.WandbLoggerConfig": "configs/loggers/WandbLoggerConfig.schema.json", "mattertune.loggers.TensorBoardLoggerConfig": "configs/loggers/TensorBoardLoggerConfig.schema.json", 
"mattertune.callbacks.early_stopping.EarlyStoppingConfig": "configs/callbacks/early_stopping/EarlyStoppingConfig.schema.json", "mattertune.main.MatterTunerConfig": "configs/main/MatterTunerConfig.schema.json", - "mattertune.callbacks.model_checkpoint.ModelCheckpointConfig": "configs/callbacks/model_checkpoint/ModelCheckpointConfig.schema.json", "mattertune.main.TrainerConfig": "configs/main/TrainerConfig.schema.json", - "mattertune.normalization.PerAtomReferencingNormalizerConfig": "configs/normalization/PerAtomReferencingNormalizerConfig.schema.json", - "mattertune.normalization.RMSNormalizerConfig": "configs/normalization/RMSNormalizerConfig.schema.json", + "mattertune.callbacks.model_checkpoint.ModelCheckpointConfig": "configs/callbacks/model_checkpoint/ModelCheckpointConfig.schema.json", "mattertune.normalization.MeanStdNormalizerConfig": "configs/normalization/MeanStdNormalizerConfig.schema.json", + "mattertune.normalization.RMSNormalizerConfig": "configs/normalization/RMSNormalizerConfig.schema.json", "mattertune.normalization.NormalizerConfigBase": "configs/normalization/NormalizerConfigBase.schema.json", + "mattertune.normalization.PerAtomReferencingNormalizerConfig": "configs/normalization/PerAtomReferencingNormalizerConfig.schema.json", + "mattertune.finetune.loss.L2MAELossConfig": "configs/finetune/loss/L2MAELossConfig.schema.json", + "mattertune.finetune.loss.MSELossConfig": "configs/finetune/loss/MSELossConfig.schema.json", "mattertune.finetune.loss.MAELossConfig": "configs/finetune/loss/MAELossConfig.schema.json", "mattertune.finetune.loss.HuberLossConfig": "configs/finetune/loss/HuberLossConfig.schema.json", - "mattertune.finetune.loss.MSELossConfig": "configs/finetune/loss/MSELossConfig.schema.json", - "mattertune.finetune.loss.L2MAELossConfig": "configs/finetune/loss/L2MAELossConfig.schema.json", - "mattertune.finetune.lr_scheduler.CosineAnnealingLRConfig": "configs/finetune/lr_scheduler/CosineAnnealingLRConfig.schema.json", - 
"mattertune.finetune.lr_scheduler.ReduceOnPlateauConfig": "configs/finetune/lr_scheduler/ReduceOnPlateauConfig.schema.json", "mattertune.finetune.lr_scheduler.MultiStepLRConfig": "configs/finetune/lr_scheduler/MultiStepLRConfig.schema.json", - "mattertune.finetune.lr_scheduler.StepLRConfig": "configs/finetune/lr_scheduler/StepLRConfig.schema.json", "mattertune.finetune.lr_scheduler.ExponentialConfig": "configs/finetune/lr_scheduler/ExponentialConfig.schema.json", - "mattertune.finetune.optimizer.SGDConfig": "configs/finetune/optimizer/SGDConfig.schema.json", - "mattertune.finetune.optimizer.AdamConfig": "configs/finetune/optimizer/AdamConfig.schema.json", + "mattertune.finetune.lr_scheduler.StepLRConfig": "configs/finetune/lr_scheduler/StepLRConfig.schema.json", + "mattertune.finetune.lr_scheduler.CosineAnnealingLRConfig": "configs/finetune/lr_scheduler/CosineAnnealingLRConfig.schema.json", + "mattertune.finetune.lr_scheduler.ReduceOnPlateauConfig": "configs/finetune/lr_scheduler/ReduceOnPlateauConfig.schema.json", "mattertune.finetune.optimizer.AdamWConfig": "configs/finetune/optimizer/AdamWConfig.schema.json", - "mattertune.finetune.properties.StressesPropertyConfig": "configs/finetune/properties/StressesPropertyConfig.schema.json", + "mattertune.finetune.optimizer.AdamConfig": "configs/finetune/optimizer/AdamConfig.schema.json", + "mattertune.finetune.optimizer.SGDConfig": "configs/finetune/optimizer/SGDConfig.schema.json", + "mattertune.finetune.properties.PropertyConfigBase": "configs/finetune/properties/PropertyConfigBase.schema.json", + "mattertune.finetune.properties.GraphPropertyConfig": "configs/finetune/properties/GraphPropertyConfig.schema.json", "mattertune.finetune.properties.EnergyPropertyConfig": "configs/finetune/properties/EnergyPropertyConfig.schema.json", + "mattertune.finetune.properties.StressesPropertyConfig": "configs/finetune/properties/StressesPropertyConfig.schema.json", "mattertune.finetune.properties.ForcesPropertyConfig": 
"configs/finetune/properties/ForcesPropertyConfig.schema.json", - "mattertune.finetune.properties.GraphPropertyConfig": "configs/finetune/properties/GraphPropertyConfig.schema.json", - "mattertune.finetune.properties.PropertyConfigBase": "configs/finetune/properties/PropertyConfigBase.schema.json", - "mattertune.data.mptraj.MPTrajDatasetConfig": "configs/data/mptraj/MPTrajDatasetConfig.schema.json", "mattertune.data.base.DatasetConfigBase": "configs/data/DatasetConfigBase.schema.json", + "mattertune.data.mptraj.MPTrajDatasetConfig": "configs/data/mptraj/MPTrajDatasetConfig.schema.json", "mattertune.data.omat24.OMAT24DatasetConfig": "configs/data/OMAT24DatasetConfig.schema.json", "mattertune.data.xyz.XYZDatasetConfig": "configs/data/XYZDatasetConfig.schema.json", "mattertune.data.matbench.MatbenchDatasetConfig": "configs/data/MatbenchDatasetConfig.schema.json", + "mattertune.data.json_data.JSONDatasetConfig": "configs/data/JSONDatasetConfig.schema.json", "mattertune.data.db.DBDatasetConfig": "configs/data/db/DBDatasetConfig.schema.json", "mattertune.data.mp.MPDatasetConfig": "configs/data/MPDatasetConfig.schema.json", "mattertune.data.datamodule.DataModuleBaseConfig": "configs/data/datamodule/DataModuleBaseConfig.schema.json", @@ -49,8 +50,8 @@ "mattertune.backbones.eqV2.model.FAIRChemAtomsToGraphSystemConfig": "configs/backbones/eqV2/model/FAIRChemAtomsToGraphSystemConfig.schema.json", "mattertune.backbones.orb.model.ORBSystemConfig": "configs/backbones/orb/model/ORBSystemConfig.schema.json", "mattertune.backbones.m3gnet.model.M3GNetGraphComputerConfig": "configs/backbones/m3gnet/M3GNetGraphComputerConfig.schema.json", - "mattertune.backbones.jmp.model.JMPGraphComputerConfig": "configs/backbones/jmp/model/JMPGraphComputerConfig.schema.json", "mattertune.backbones.jmp.model.CutoffsConfig": "configs/backbones/jmp/model/CutoffsConfig.schema.json", + "mattertune.backbones.jmp.model.JMPGraphComputerConfig": 
"configs/backbones/jmp/model/JMPGraphComputerConfig.schema.json", "mattertune.backbones.jmp.model.MaxNeighborsConfig": "configs/backbones/jmp/model/MaxNeighborsConfig.schema.json" } } \ No newline at end of file diff --git a/src/mattertune/configs/__init__.py b/src/mattertune/configs/__init__.py index 9132fba..600fea6 100644 --- a/src/mattertune/configs/__init__.py +++ b/src/mattertune/configs/__init__.py @@ -101,6 +101,7 @@ "HuberLossConfig", "JMPBackboneConfig", "JMPGraphComputerConfig", + "JSONDatasetConfig", "L2MAELossConfig", "M3GNetBackboneConfig", "M3GNetGraphComputerConfig", @@ -130,7 +131,6 @@ "TrainerConfig", "WandbLoggerConfig", "XYZDatasetConfig", - "JSONDatasetConfig", "backbones", "callbacks", "data", diff --git a/src/mattertune/configs/data/JSONDatasetConfig.schema.json b/src/mattertune/configs/data/JSONDatasetConfig.schema.json index dbe93db..aca2c7f 100644 --- a/src/mattertune/configs/data/JSONDatasetConfig.schema.json +++ b/src/mattertune/configs/data/JSONDatasetConfig.schema.json @@ -24,13 +24,12 @@ "title": "Src" }, "tasks": { - "type": "object", "additionalProperties": { "type": "string" }, - "description": "The mapping of task names to descriptions.", + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", "title": "Tasks", - "minProperties": 1 + "type": "object" } }, "required": [ diff --git a/src/mattertune/configs/data/__init__.py b/src/mattertune/configs/data/__init__.py index 9d5ac33..6303a97 100644 --- a/src/mattertune/configs/data/__init__.py +++ b/src/mattertune/configs/data/__init__.py @@ -3,6 +3,7 @@ __codegen__ = True from mattertune.data import DatasetConfigBase as DatasetConfigBase +from mattertune.data import JSONDatasetConfig as JSONDatasetConfig from mattertune.data import MatbenchDatasetConfig as MatbenchDatasetConfig from mattertune.data import MPDatasetConfig as MPDatasetConfig from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig @@ -20,7 +21,7 @@ from . 
import base as base from . import datamodule as datamodule from . import db as db -from . import json as json +from . import json_data as json_data from . import matbench as matbench from . import mp as mp from . import mptraj as mptraj @@ -32,6 +33,7 @@ "DBDatasetConfig", "DataModuleBaseConfig", "DatasetConfigBase", + "JSONDatasetConfig", "MPDatasetConfig", "MPTrajDatasetConfig", "ManualSplitDataModuleConfig", @@ -41,10 +43,10 @@ "base", "datamodule", "db", + "json_data", "matbench", "mp", "mptraj", "omat24", "xyz", - "json", ] diff --git a/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json b/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json index 0238e95..0b24b90 100644 --- a/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json +++ b/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json @@ -82,6 +82,7 @@ "discriminator": { "mapping": { "db": "#/$defs/DBDatasetConfig", + "json": "#/$defs/JSONDatasetConfig", "matbench": "#/$defs/MatbenchDatasetConfig", "mp": "#/$defs/MPDatasetConfig", "mptraj": "#/$defs/MPTrajDatasetConfig", @@ -91,6 +92,9 @@ "propertyName": "type" }, "oneOf": [ + { + "$ref": "#/$defs/JSONDatasetConfig" + }, { "$ref": "#/$defs/MatbenchDatasetConfig" }, @@ -111,6 +115,47 @@ } ] }, + "JSONDatasetConfig": { + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + 
"type": "object" + }, "MPDatasetConfig": { "description": "Configuration for a dataset stored in the Materials Project database.", "properties": { diff --git a/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json b/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json index 93cd1ca..35978aa 100644 --- a/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json +++ b/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json @@ -82,6 +82,7 @@ "discriminator": { "mapping": { "db": "#/$defs/DBDatasetConfig", + "json": "#/$defs/JSONDatasetConfig", "matbench": "#/$defs/MatbenchDatasetConfig", "mp": "#/$defs/MPDatasetConfig", "mptraj": "#/$defs/MPTrajDatasetConfig", @@ -91,6 +92,9 @@ "propertyName": "type" }, "oneOf": [ + { + "$ref": "#/$defs/JSONDatasetConfig" + }, { "$ref": "#/$defs/MatbenchDatasetConfig" }, @@ -111,6 +115,47 @@ } ] }, + "JSONDatasetConfig": { + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" + }, "MPDatasetConfig": { "description": "Configuration for a dataset stored in the Materials Project database.", "properties": { diff --git a/src/mattertune/configs/data/json_data/DatasetConfigBase.schema.json b/src/mattertune/configs/data/json_data/DatasetConfigBase.schema.json new file mode 100644 index 0000000..3648f98 --- /dev/null +++ 
b/src/mattertune/configs/data/json_data/DatasetConfigBase.schema.json @@ -0,0 +1,5 @@ +{ + "properties": {}, + "title": "DatasetConfigBase", + "type": "object" +} \ No newline at end of file diff --git a/src/mattertune/configs/data/json/JSONDatasetConfig.schema.json b/src/mattertune/configs/data/json_data/JSONDatasetConfig.schema.json similarity index 69% rename from src/mattertune/configs/data/json/JSONDatasetConfig.schema.json rename to src/mattertune/configs/data/json_data/JSONDatasetConfig.schema.json index bd9c384..aca2c7f 100644 --- a/src/mattertune/configs/data/json/JSONDatasetConfig.schema.json +++ b/src/mattertune/configs/data/json_data/JSONDatasetConfig.schema.json @@ -22,10 +22,19 @@ ], "description": "The path to the JSON dataset.", "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" } }, "required": [ - "src" + "src", + "tasks" ], "title": "JSONDatasetConfig", "type": "object" diff --git a/src/mattertune/configs/data/json/__init__.py b/src/mattertune/configs/data/json_data/__init__.py similarity index 76% rename from src/mattertune/configs/data/json/__init__.py rename to src/mattertune/configs/data/json_data/__init__.py index 614a71d..4009fe6 100644 --- a/src/mattertune/configs/data/json/__init__.py +++ b/src/mattertune/configs/data/json_data/__init__.py @@ -4,3 +4,8 @@ from mattertune.data.json_data import DatasetConfigBase as DatasetConfigBase from mattertune.data.json_data import JSONDatasetConfig as JSONDatasetConfig + +__all__ = [ + "DatasetConfigBase", + "JSONDatasetConfig", +] diff --git a/src/mattertune/configs/main/MatterTunerConfig.schema.json b/src/mattertune/configs/main/MatterTunerConfig.schema.json index f7a72f4..1176227 100644 --- a/src/mattertune/configs/main/MatterTunerConfig.schema.json +++ b/src/mattertune/configs/main/MatterTunerConfig.schema.json @@ -489,6 +489,7 @@ 
"discriminator": { "mapping": { "db": "#/$defs/DBDatasetConfig", + "json": "#/$defs/JSONDatasetConfig", "matbench": "#/$defs/MatbenchDatasetConfig", "mp": "#/$defs/MPDatasetConfig", "mptraj": "#/$defs/MPTrajDatasetConfig", @@ -498,6 +499,9 @@ "propertyName": "type" }, "oneOf": [ + { + "$ref": "#/$defs/JSONDatasetConfig" + }, { "$ref": "#/$defs/MatbenchDatasetConfig" }, @@ -1027,6 +1031,47 @@ "title": "JMPGraphComputerConfig", "type": "object" }, + "JSONDatasetConfig": { + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" + }, "L2MAELossConfig": { "properties": { "name": { From 562a0c6fb7fb1b65021fa8d8f100cefed2ce8980 Mon Sep 17 00:00:00 2001 From: Nima Shoghi Date: Sat, 30 Nov 2024 23:58:38 -0500 Subject: [PATCH 5/5] Docs: Add docs for JSON dataset --- docs/guides/datasets.md | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/docs/guides/datasets.md b/docs/guides/datasets.md index 18a1e5b..adbe7da 100644 --- a/docs/guides/datasets.md +++ b/docs/guides/datasets.md @@ -127,6 +127,47 @@ config = mt.configs.MatterTunerConfig( ) ``` +## JSON Dataset +Allows reading atomic structures and properties from JSON files with a specific schema. 
+ +API Reference: {py:class}`mattertune.data.json_data.JSONDatasetConfig` + +Expected JSON format: +```json +[ + { + "atomic_numbers": [1, 1, 8], + "positions": [[0, 0, 0], [0, 0, 1], [0, 1, 0]], + "cell": [[10, 0, 0], [0, 10, 0], [0, 0, 10]], + "energy": -13.5, + "forces": [[0.1, 0, 0], [-0.1, 0, 0], [0, 0, 0]], + "stress": [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + } +] +``` + +Usage example: +```python +config = mt.configs.MatterTunerConfig( + model=..., + data=mt.configs.AutoSplitDataModuleConfig( + dataset=mt.configs.JSONDatasetConfig( + src="path/to/data.json", + tasks={ + "energy": "energy", + "forces": "forces", + "stress": "stress" + } + ), + train_split=0.8, + batch_size=32 + ), + trainer=... +) +``` + +The `tasks` dictionary maps the property names `energy`, `forces`, and `stress` to the corresponding JSON keys in your data file; properties omitted from `tasks` are not loaded, and every entry must also provide `atomic_numbers`, `positions`, and `cell`. + Each dataset configuration can be used with either `AutoSplitDataModuleConfig` for automatic train/validation splitting or `ManualSplitDataModuleConfig` for manual split specification. The examples above use `AutoSplitDataModuleConfig` for simplicity. Note that some datasets may require additional dependencies: