diff --git a/docs/guides/datasets.md b/docs/guides/datasets.md index 18a1e5b..adbe7da 100644 --- a/docs/guides/datasets.md +++ b/docs/guides/datasets.md @@ -127,6 +127,47 @@ config = mt.configs.MatterTunerConfig( ) ``` +## JSON Dataset +Allows reading atomic structures and properties from JSON files with a specific schema. + +API Reference: {py:class}`mattertune.data.json.JSONDatasetConfig` + +Expected JSON format: +```json +[ + { + "atomic_numbers": [1, 1, 8], + "positions": [[0, 0, 0], [0, 0, 1], [0, 1, 0]], + "cell": [[10, 0, 0], [0, 10, 0], [0, 0, 10]], + "energy": -13.5, + "forces": [[0.1, 0, 0], [-0.1, 0, 0], [0, 0, 0]], + "stress": [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + } +] +``` + +Usage example: +```python +config = mt.configs.MatterTunerConfig( + model=..., + data=mt.configs.AutoSplitDataModuleConfig( + dataset=mt.configs.JSONDatasetConfig( + src="path/to/data.json", + tasks={ + "energy": "energy", + "forces": "forces", + "stress": "stress" + } + ), + train_split=0.8, + batch_size=32 + ), + trainer=... +) +``` + +The `tasks` dictionary maps property names to the corresponding JSON keys in your data file. + Each dataset configuration can be used with either `AutoSplitDataModuleConfig` for automatic train/validation splitting or `ManualSplitDataModuleConfig` for manual split specification. The examples above use `AutoSplitDataModuleConfig` for simplicity. Note that some datasets may require additional dependencies: diff --git a/examples/json_dataset/orb-finetune.py b/examples/json_dataset/orb-finetune.py new file mode 100644 index 0000000..7eb83dc --- /dev/null +++ b/examples/json_dataset/orb-finetune.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import logging + +import nshutils as nu +from lightning.pytorch.strategies import DDPStrategy + +import mattertune.configs as MC +from mattertune import MatterTuner +from mattertune.configs import WandbLoggerConfig + +logging.basicConfig(level=logging.WARNING) +nu.pretty() + + +def main(args_dict: dict): + def hparams(): + hparams = MC.MatterTunerConfig.draft() + + ## Model Hyperparameters + hparams.model = MC.ORBBackboneConfig.draft() + hparams.model.pretrained_model = args_dict["model_name"] + hparams.model.ignore_gpu_batch_transform_error = True + hparams.model.optimizer = MC.AdamWConfig(lr=args_dict["lr"]) + + # Add property + hparams.model.properties = [] + energy = MC.EnergyPropertyConfig( + loss=MC.MAELossConfig(), + loss_coefficient=0.01, + # name=args_dict["task"], + # dtype="float", + ) + hparams.model.properties.append(energy) + forces = MC.ForcesPropertyConfig( + loss=MC.MAELossConfig(), conservative=False, loss_coefficient=50.0 + ) + hparams.model.properties.append(forces) + stress = MC.StressesPropertyConfig( + loss=MC.MAELossConfig(), loss_coefficient=50.0, conservative=False + ) + hparams.model.properties.append(stress) + + ## Data Hyperparameters + hparams.data = MC.AutoSplitDataModuleConfig.draft() + hparams.data.dataset = MC.JSONDatasetConfig.draft() + tasks = { + "energy": args_dict["energy_attr"], + "forces": args_dict["forces_attr"], + "stress": args_dict["stress_attr"], + } + hparams.data.dataset.tasks = tasks + hparams.data.dataset.src = args_dict["data_src"] + hparams.data.train_split = args_dict["train_split"] + hparams.data.validation_split = args_dict["validation_split"] + hparams.data.batch_size = args_dict["batch_size"] + hparams.data.num_workers = 0 + + ## Trainer Hyperparameters + hparams.trainer = MC.TrainerConfig.draft() + hparams.trainer.max_epochs = args_dict["max_epochs"] + hparams.trainer.accelerator = "gpu" + hparams.trainer.devices = args_dict["devices"] + hparams.trainer.gradient_clip_algorithm = "value" + hparams.trainer.gradient_clip_val = 1.0 + hparams.trainer.precision = "bf16" + + # Configure Early Stopping + hparams.trainer.early_stopping = MC.EarlyStoppingConfig( + monitor=f"val/total_loss", patience=200, mode="min" + ) + + # Configure Model Checkpoint + hparams.trainer.checkpoint = MC.ModelCheckpointConfig( + monitor=f"val/total_loss", + dirpath=f"./checkpoints-{args_dict['task']}", + filename="orb-best", + save_top_k=1, + mode="min", + every_n_epochs=10, + ) + + # Configure Logger + hparams.trainer.loggers = [ + WandbLoggerConfig( + project="MatterTune-Examples", + name=f"ORB-Matbench-{args_dict['task']}", + offline=False, + ) + ] + + # Additional trainer settings that need special handling + hparams.trainer.additional_trainer_kwargs = { + "inference_mode": False, + "strategy": DDPStrategy( + static_graph=True, find_unused_parameters=True + ), # Special DDP config + } + + hparams = hparams.finalize(strict=False) + return hparams + + mt_config = hparams() + model, trainer = MatterTuner(mt_config).tune() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--data-src", type=str) + parser.add_argument("--task", type=str) + parser.add_argument("--energy-attr", type=str, default="y") + parser.add_argument("--forces-attr", type=str, default="forces") + parser.add_argument("--stress-attr", type=str, default="stress") + parser.add_argument("--model_name", type=str, default="orb-v2") + parser.add_argument("--train_split", type=float, default=0.9) + parser.add_argument("--validation_split", type=float, default=0.05) + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lr", type=float, default=8.0e-5) + parser.add_argument("--max_epochs", type=int, default=200) + parser.add_argument("--devices", type=int, nargs="+", default=[0]) + args = parser.parse_args() + args_dict = vars(args) + main(args_dict) diff --git a/src/mattertune/.nshconfig.generated.json b/src/mattertune/.nshconfig.generated.json index d25aec9..6593414 100644 --- a/src/mattertune/.nshconfig.generated.json +++ b/src/mattertune/.nshconfig.generated.json @@ -4,39 +4,40 @@ "typed_dicts": null, "json_schemas": { "mattertune.finetune.base.FinetuneModuleBaseConfig": "configs/backbones/jmp/model/FinetuneModuleBaseConfig.schema.json", - "mattertune.loggers.WandbLoggerConfig": "configs/loggers/WandbLoggerConfig.schema.json", "mattertune.loggers.CSVLoggerConfig": "configs/main/CSVLoggerConfig.schema.json", + "mattertune.loggers.WandbLoggerConfig": "configs/loggers/WandbLoggerConfig.schema.json", "mattertune.loggers.TensorBoardLoggerConfig": "configs/loggers/TensorBoardLoggerConfig.schema.json", "mattertune.callbacks.early_stopping.EarlyStoppingConfig": "configs/callbacks/early_stopping/EarlyStoppingConfig.schema.json", "mattertune.main.MatterTunerConfig": "configs/main/MatterTunerConfig.schema.json", - "mattertune.callbacks.model_checkpoint.ModelCheckpointConfig": "configs/callbacks/model_checkpoint/ModelCheckpointConfig.schema.json", "mattertune.main.TrainerConfig": "configs/main/TrainerConfig.schema.json", - "mattertune.normalization.PerAtomReferencingNormalizerConfig": "configs/normalization/PerAtomReferencingNormalizerConfig.schema.json", - "mattertune.normalization.RMSNormalizerConfig": "configs/normalization/RMSNormalizerConfig.schema.json", + "mattertune.callbacks.model_checkpoint.ModelCheckpointConfig": "configs/callbacks/model_checkpoint/ModelCheckpointConfig.schema.json", "mattertune.normalization.MeanStdNormalizerConfig": "configs/normalization/MeanStdNormalizerConfig.schema.json", + "mattertune.normalization.RMSNormalizerConfig": "configs/normalization/RMSNormalizerConfig.schema.json", "mattertune.normalization.NormalizerConfigBase": "configs/normalization/NormalizerConfigBase.schema.json", + "mattertune.normalization.PerAtomReferencingNormalizerConfig": "configs/normalization/PerAtomReferencingNormalizerConfig.schema.json", + "mattertune.finetune.loss.L2MAELossConfig": "configs/finetune/loss/L2MAELossConfig.schema.json", + "mattertune.finetune.loss.MSELossConfig": "configs/finetune/loss/MSELossConfig.schema.json", "mattertune.finetune.loss.MAELossConfig": "configs/finetune/loss/MAELossConfig.schema.json", "mattertune.finetune.loss.HuberLossConfig": "configs/finetune/loss/HuberLossConfig.schema.json", - "mattertune.finetune.loss.MSELossConfig": "configs/finetune/loss/MSELossConfig.schema.json", - "mattertune.finetune.loss.L2MAELossConfig": "configs/finetune/loss/L2MAELossConfig.schema.json", - "mattertune.finetune.lr_scheduler.CosineAnnealingLRConfig": "configs/finetune/lr_scheduler/CosineAnnealingLRConfig.schema.json", - "mattertune.finetune.lr_scheduler.ReduceOnPlateauConfig": "configs/finetune/lr_scheduler/ReduceOnPlateauConfig.schema.json", "mattertune.finetune.lr_scheduler.MultiStepLRConfig": "configs/finetune/lr_scheduler/MultiStepLRConfig.schema.json", - "mattertune.finetune.lr_scheduler.StepLRConfig": "configs/finetune/lr_scheduler/StepLRConfig.schema.json", "mattertune.finetune.lr_scheduler.ExponentialConfig": "configs/finetune/lr_scheduler/ExponentialConfig.schema.json", - "mattertune.finetune.optimizer.SGDConfig": "configs/finetune/optimizer/SGDConfig.schema.json", - "mattertune.finetune.optimizer.AdamConfig": "configs/finetune/optimizer/AdamConfig.schema.json", + "mattertune.finetune.lr_scheduler.StepLRConfig": "configs/finetune/lr_scheduler/StepLRConfig.schema.json", + "mattertune.finetune.lr_scheduler.CosineAnnealingLRConfig": "configs/finetune/lr_scheduler/CosineAnnealingLRConfig.schema.json", + "mattertune.finetune.lr_scheduler.ReduceOnPlateauConfig": "configs/finetune/lr_scheduler/ReduceOnPlateauConfig.schema.json", "mattertune.finetune.optimizer.AdamWConfig": "configs/finetune/optimizer/AdamWConfig.schema.json", - "mattertune.finetune.properties.StressesPropertyConfig": "configs/finetune/properties/StressesPropertyConfig.schema.json", + "mattertune.finetune.optimizer.AdamConfig": "configs/finetune/optimizer/AdamConfig.schema.json", + "mattertune.finetune.optimizer.SGDConfig": "configs/finetune/optimizer/SGDConfig.schema.json", + "mattertune.finetune.properties.PropertyConfigBase": "configs/finetune/properties/PropertyConfigBase.schema.json", + "mattertune.finetune.properties.GraphPropertyConfig": "configs/finetune/properties/GraphPropertyConfig.schema.json", "mattertune.finetune.properties.EnergyPropertyConfig": "configs/finetune/properties/EnergyPropertyConfig.schema.json", + "mattertune.finetune.properties.StressesPropertyConfig": "configs/finetune/properties/StressesPropertyConfig.schema.json", "mattertune.finetune.properties.ForcesPropertyConfig": "configs/finetune/properties/ForcesPropertyConfig.schema.json", - "mattertune.finetune.properties.GraphPropertyConfig": "configs/finetune/properties/GraphPropertyConfig.schema.json", - "mattertune.finetune.properties.PropertyConfigBase": "configs/finetune/properties/PropertyConfigBase.schema.json", - "mattertune.data.mptraj.MPTrajDatasetConfig": "configs/data/mptraj/MPTrajDatasetConfig.schema.json", "mattertune.data.base.DatasetConfigBase": "configs/data/DatasetConfigBase.schema.json", + "mattertune.data.mptraj.MPTrajDatasetConfig": "configs/data/mptraj/MPTrajDatasetConfig.schema.json", "mattertune.data.omat24.OMAT24DatasetConfig": "configs/data/OMAT24DatasetConfig.schema.json", "mattertune.data.xyz.XYZDatasetConfig": "configs/data/XYZDatasetConfig.schema.json", "mattertune.data.matbench.MatbenchDatasetConfig": "configs/data/MatbenchDatasetConfig.schema.json", + "mattertune.data.json_data.JSONDatasetConfig": "configs/data/JSONDatasetConfig.schema.json", "mattertune.data.db.DBDatasetConfig": "configs/data/db/DBDatasetConfig.schema.json", "mattertune.data.mp.MPDatasetConfig": "configs/data/MPDatasetConfig.schema.json", "mattertune.data.datamodule.DataModuleBaseConfig": "configs/data/datamodule/DataModuleBaseConfig.schema.json", @@ -49,8 +50,8 @@ "mattertune.backbones.eqV2.model.FAIRChemAtomsToGraphSystemConfig": "configs/backbones/eqV2/model/FAIRChemAtomsToGraphSystemConfig.schema.json", "mattertune.backbones.orb.model.ORBSystemConfig": "configs/backbones/orb/model/ORBSystemConfig.schema.json", "mattertune.backbones.m3gnet.model.M3GNetGraphComputerConfig": "configs/backbones/m3gnet/M3GNetGraphComputerConfig.schema.json", - "mattertune.backbones.jmp.model.JMPGraphComputerConfig": "configs/backbones/jmp/model/JMPGraphComputerConfig.schema.json", "mattertune.backbones.jmp.model.CutoffsConfig": "configs/backbones/jmp/model/CutoffsConfig.schema.json", + "mattertune.backbones.jmp.model.JMPGraphComputerConfig": "configs/backbones/jmp/model/JMPGraphComputerConfig.schema.json", "mattertune.backbones.jmp.model.MaxNeighborsConfig": "configs/backbones/jmp/model/MaxNeighborsConfig.schema.json" } } \ No newline at end of file diff --git a/src/mattertune/configs/__init__.py b/src/mattertune/configs/__init__.py index 460aed8..600fea6 100644 --- a/src/mattertune/configs/__init__.py +++ b/src/mattertune/configs/__init__.py @@ -19,6 +19,7 @@ ) from mattertune.backbones.orb.model import ORBSystemConfig as ORBSystemConfig from mattertune.data import DatasetConfigBase as DatasetConfigBase +from mattertune.data import JSONDatasetConfig as JSONDatasetConfig from mattertune.data import MatbenchDatasetConfig as MatbenchDatasetConfig from mattertune.data import MPDatasetConfig as MPDatasetConfig from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig @@ -100,6 +101,7 @@ "HuberLossConfig", "JMPBackboneConfig", "JMPGraphComputerConfig", + "JSONDatasetConfig", "L2MAELossConfig", "M3GNetBackboneConfig", "M3GNetGraphComputerConfig", diff --git a/src/mattertune/configs/data/JSONDatasetConfig.schema.json b/src/mattertune/configs/data/JSONDatasetConfig.schema.json new file mode 100644 index 0000000..aca2c7f --- /dev/null +++ b/src/mattertune/configs/data/JSONDatasetConfig.schema.json @@ -0,0 +1,41 @@ +{ + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" +} \ No newline at end of file diff --git a/src/mattertune/configs/data/__init__.py b/src/mattertune/configs/data/__init__.py index 8b924ff..6303a97 100644 --- a/src/mattertune/configs/data/__init__.py +++ b/src/mattertune/configs/data/__init__.py @@ -3,6 +3,7 @@ __codegen__ = True from mattertune.data import DatasetConfigBase as DatasetConfigBase +from mattertune.data import JSONDatasetConfig as JSONDatasetConfig from mattertune.data import MatbenchDatasetConfig as MatbenchDatasetConfig from mattertune.data import MPDatasetConfig as MPDatasetConfig from mattertune.data import OMAT24DatasetConfig as OMAT24DatasetConfig @@ -20,6 +21,7 @@ from . import base as base from . import datamodule as datamodule from . import db as db +from . import json_data as json_data from . import matbench as matbench from . import mp as mp from . import mptraj as mptraj @@ -31,6 +33,7 @@ "DBDatasetConfig", "DataModuleBaseConfig", "DatasetConfigBase", + "JSONDatasetConfig", "MPDatasetConfig", "MPTrajDatasetConfig", "ManualSplitDataModuleConfig", @@ -40,6 +43,7 @@ "base", "datamodule", "db", + "json_data", "matbench", "mp", "mptraj", diff --git a/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json b/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json index 0238e95..0b24b90 100644 --- a/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json +++ b/src/mattertune/configs/data/datamodule/AutoSplitDataModuleConfig.schema.json @@ -82,6 +82,7 @@ "discriminator": { "mapping": { "db": "#/$defs/DBDatasetConfig", + "json": "#/$defs/JSONDatasetConfig", "matbench": "#/$defs/MatbenchDatasetConfig", "mp": "#/$defs/MPDatasetConfig", "mptraj": "#/$defs/MPTrajDatasetConfig", @@ -91,6 +92,9 @@ "propertyName": "type" }, "oneOf": [ + { + "$ref": "#/$defs/JSONDatasetConfig" + }, { "$ref": "#/$defs/MatbenchDatasetConfig" }, @@ -111,6 +115,47 @@ } ] }, + "JSONDatasetConfig": { + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" + }, "MPDatasetConfig": { "description": "Configuration for a dataset stored in the Materials Project database.", "properties": { diff --git a/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json b/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json index 93cd1ca..35978aa 100644 --- a/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json +++ b/src/mattertune/configs/data/datamodule/ManualSplitDataModuleConfig.schema.json @@ -82,6 +82,7 @@ "discriminator": { "mapping": { "db": "#/$defs/DBDatasetConfig", + "json": "#/$defs/JSONDatasetConfig", "matbench": "#/$defs/MatbenchDatasetConfig", "mp": "#/$defs/MPDatasetConfig", "mptraj": "#/$defs/MPTrajDatasetConfig", @@ -91,6 +92,9 @@ "propertyName": "type" }, "oneOf": [ + { + "$ref": "#/$defs/JSONDatasetConfig" + }, { "$ref": "#/$defs/MatbenchDatasetConfig" }, @@ -111,6 +115,47 @@ } ] }, + "JSONDatasetConfig": { + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" + }, "MPDatasetConfig": { "description": "Configuration for a dataset stored in the Materials Project database.", "properties": { diff --git a/src/mattertune/configs/data/json_data/DatasetConfigBase.schema.json b/src/mattertune/configs/data/json_data/DatasetConfigBase.schema.json new file mode 100644 index 0000000..3648f98 --- /dev/null +++ b/src/mattertune/configs/data/json_data/DatasetConfigBase.schema.json @@ -0,0 +1,5 @@ +{ + "properties": {}, + "title": "DatasetConfigBase", + "type": "object" +} \ No newline at end of file diff --git a/src/mattertune/configs/data/json_data/JSONDatasetConfig.schema.json b/src/mattertune/configs/data/json_data/JSONDatasetConfig.schema.json new file mode 100644 index 0000000..aca2c7f --- /dev/null +++ b/src/mattertune/configs/data/json_data/JSONDatasetConfig.schema.json @@ -0,0 +1,41 @@ +{ + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" +} \ No newline at end of file diff --git a/src/mattertune/configs/data/json_data/__init__.py b/src/mattertune/configs/data/json_data/__init__.py new file mode 100644 index 0000000..4009fe6 --- /dev/null +++ b/src/mattertune/configs/data/json_data/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +__codegen__ = True + +from mattertune.data.json_data import DatasetConfigBase as DatasetConfigBase +from mattertune.data.json_data import JSONDatasetConfig as JSONDatasetConfig + +__all__ = [ + "DatasetConfigBase", + "JSONDatasetConfig", +] diff --git a/src/mattertune/configs/main/MatterTunerConfig.schema.json b/src/mattertune/configs/main/MatterTunerConfig.schema.json index f7a72f4..1176227 100644 --- a/src/mattertune/configs/main/MatterTunerConfig.schema.json +++ b/src/mattertune/configs/main/MatterTunerConfig.schema.json @@ -489,6 +489,7 @@ "discriminator": { "mapping": { "db": "#/$defs/DBDatasetConfig", + "json": "#/$defs/JSONDatasetConfig", "matbench": "#/$defs/MatbenchDatasetConfig", "mp": "#/$defs/MPDatasetConfig", "mptraj": "#/$defs/MPTrajDatasetConfig", @@ -498,6 +499,9 @@ "propertyName": "type" }, "oneOf": [ + { + "$ref": "#/$defs/JSONDatasetConfig" + }, { "$ref": "#/$defs/MatbenchDatasetConfig" }, @@ -1027,6 +1031,47 @@ "title": "JMPGraphComputerConfig", "type": "object" }, + "JSONDatasetConfig": { + "properties": { + "type": { + "const": "json", + "default": "json", + "description": "Discriminator for the JSON dataset.", + "enum": [ + "json" + ], + "title": "Type", + "type": "string" + }, + "src": { + "anyOf": [ + { + "type": "string" + }, + { + "format": "path", + "type": "string" + } + ], + "description": "The path to the JSON dataset.", + "title": "Src" + }, + "tasks": { + "additionalProperties": { + "type": "string" + }, + "description": "Attributes in the JSON file that correspond to the tasks to be predicted.", + "title": "Tasks", + "type": "object" + } + }, + "required": [ + "src", + "tasks" + ], + "title": "JSONDatasetConfig", + "type": "object" + }, "L2MAELossConfig": { "properties": { "name": { diff --git a/src/mattertune/data/__init__.py b/src/mattertune/data/__init__.py index 8b49541..0d85166 100644 --- a/src/mattertune/data/__init__.py +++ b/src/mattertune/data/__init__.py @@ -2,6 +2,8 @@ from .base import DatasetConfig as DatasetConfig from .base import DatasetConfigBase as DatasetConfigBase +from .json_data import JSONDataset as JSONDataset +from .json_data import JSONDatasetConfig as JSONDatasetConfig from .matbench import MatbenchDataset as MatbenchDataset from .matbench import MatbenchDatasetConfig as MatbenchDatasetConfig from .mp import MPDataset as MPDataset diff --git a/src/mattertune/data/json_data.py b/src/mattertune/data/json_data.py new file mode 100644 index 0000000..2362450 --- /dev/null +++ b/src/mattertune/data/json_data.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Literal + +import numpy as np +import torch +from ase import Atoms +from ase.calculators.singlepoint import SinglePointCalculator +from torch.utils.data import Dataset +from typing_extensions import override + +from ..registry import data_registry +from .base import DatasetConfigBase + +log = logging.getLogger(__name__) + + +@data_registry.register +class JSONDatasetConfig(DatasetConfigBase): + type: Literal["json"] = "json" + """Discriminator for the JSON dataset.""" + + src: str | Path + """The path to the JSON dataset.""" + + tasks: dict[str, str] + """Attributes in the JSON file that correspond to the tasks to be predicted.""" + + @override + def create_dataset(self): + return JSONDataset(self) + + +class JSONDataset(Dataset[Atoms]): + def __init__(self, config: JSONDatasetConfig): + super().__init__() + self.config = config + + with open(str(self.config.src), "r") as f: + raw_data = json.load(f) + + self.atoms_list = [] + for entry in raw_data: + atoms = Atoms( + numbers=np.array(entry["atomic_numbers"]), + positions=np.array(entry["positions"]), + cell=np.array(entry["cell"]), + pbc=True, + ) + + energy, forces, stress = None, None, None + if "energy" in self.config.tasks: + energy = torch.tensor(entry[self.config.tasks["energy"]]) + if "forces" in self.config.tasks: + forces = torch.tensor(entry[self.config.tasks["forces"]]) + if "stress" in self.config.tasks: + stress = torch.tensor(entry[self.config.tasks["stress"]]) + # ASE requires stress to be of shape (3, 3) or (6,) + # Some datasets store stress with shape (1, 3, 3) + if stress.ndim == 3: + stress = stress.squeeze(0) + + single_point_calc = SinglePointCalculator( + atoms, energy=energy, forces=forces, stress=stress + ) + + atoms.calc = single_point_calc + self.atoms_list.append(atoms) + + log.info(f"Loaded {len(self.atoms_list)} structures from {self.config.src}") + + @override + def __getitem__(self, idx: int) -> Atoms: + return self.atoms_list[idx] + + def __len__(self) -> int: + return len(self.atoms_list)