diff --git a/.gitignore b/.gitignore index d17a81ab..7e4c54c8 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +core.python.* # C extensions *.so @@ -163,18 +164,23 @@ dmypy.json # mac .DS_Store -# data -data/** - # config ./config/* # results -results/** +**/*/results/ server/ -main.py - +# tests +testing/* test*.py test*.ipynb + +checkpoints/ + +# misc +.flake8 +.pylintrc +**/wandb/ +*.out diff --git a/README.md b/README.md index cab93d00..8f9445c1 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,9 @@ ``` conda-merge env.common.yaml env.gpu.yaml > env.yaml conda env create -f env.yaml + conda activate matdeeplearn ``` - + 2. CPU-only machines: 1. M1 Macs (see https://github.com/pyg-team/pytorch_geometric/issues/4549): @@ -27,6 +28,7 @@ ``` conda-merge env.common.yaml env.cpu.yaml > env.yaml conda env create -f env.yaml + conda activate matdeeplearn ``` 3. Install package: diff --git a/configs/config.yml b/configs/config.yml index 6573ccc4..2d327d8e 100644 --- a/configs/config.yml +++ b/configs/config.yml @@ -1,25 +1,20 @@ - trainer: property task: # run_mode: train identifier: "my_train_job" - reprocess: False - - parallel: True + # seed=0 means random initialization seed: 0 - #seed=0 means random initalization - - + # Defaults to run directory if not specified + # save_dir: "." + # checkpoint_dir: "." write_output: True parallel: True #Training print out frequency (print per n number of epochs) verbosity: 5 - - model: name: CGCNN load_model: False @@ -27,7 +22,7 @@ model: model_path: "my_model.pth" edge_steps: 50 self_loop: True - #model attributes + # model attributes dim1: 100 dim2: 150 pre_fc_count: 1 @@ -42,8 +37,9 @@ model: optim: max_epochs: 250 + max_checkpoint_epochs: 0 lr: 0.002 - #Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper + # Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper loss: loss_type: "TorchLossWrapper" loss_args: {"loss_fn": "l1_loss"} @@ -57,33 +53,31 @@ optim: scheduler_args: {"mode":"min", "factor":0.8, "patience":10, "min_lr":0.00001, "threshold":0.0002} dataset: - processed: False # if False, need to preprocessor data and generate .pt file - # Whether to use "inmemory" or "large" format for pytorch-geometric dataset. 
Reccomend inmemory unless the dataset is too large - # dataset_type: "inmemory" - #Path to data files + processed: False + # Path to data files src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_npj/raw/" - #Path to target file within data_path + # Path to target file within data_path target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_npj/targets.csv" - #Path to save processed data.pt file + # Path to save processed data.pt file pt_path: "/global/homes/s/shuyijia/datasets/MP_data_npj/" - #Format of data files (limit to those supported by ASE) + transforms: + - name: GetY + args: + index: 0 + otf: False # Optional parameter, default is False + # Format of data files (limit to those supported by ASE) data_format: "json" - #Method of obtaining atom idctionary: available:(onehot) + # Method of obtaining atom dictionary: available:(onehot) node_representation: "onehot" additional_attributes: [] - #Print out processing info + # Print out processing info verbose: True - - #Loading dataset params - #Index of target column in targets.csv - target_index: 0 - - #graph specific settings + # Index of target column in targets.csv + # graph specific settings cutoff_radius : 8.0 n_neighbors : 12 edge_steps : 50 - - #Ratios for train/val/test split out of a total of 1 + # Ratios for train/val/test split out of a total of 1 train_ratio: 0.8 val_ratio: 0.05 test_ratio: 0.15 diff --git a/configs/examples/config_alignn.yml b/configs/examples/config_alignn.yml new file mode 100644 index 00000000..a86e44ca --- /dev/null +++ b/configs/examples/config_alignn.yml @@ -0,0 +1,92 @@ +trainer: property + +task: + identifier: "alignn_train_100" + reprocess: False + parallel: True + seed: 0 + save_dir: "." + checkpoint_dir: "." + write_output: True + parallel: True + # Training print out frequency (print per n number of epochs) + verbosity: 1 + + +model: + name: ALIGNN + load_model: False + save_model: True + model_path: "alignn_model.pth" + alignn_layers: 4 + gcn_layers: 4 + atom_input_features: 114 + edge_input_features: 50 + triplet_input_features: 40 + embedding_features: 64 + hidden_features: 256 + output_features: 1 + min_edge_distance: 0.0 + max_edge_distance: 8.0 + link: "identity" + +optim: + max_epochs: 100 + lr: 0.001 + # Either custom or from torch.nn.functional library. 
If from torch, loss_type is TorchLossWrapper + loss: + loss_type: "TorchLossWrapper" + loss_args: {"loss_fn": "mse_loss"} + + batch_size: 64 + + optimizer: + optimizer_type: "AdamW" + optimizer_args: {"weight_decay": 0.00001} + scheduler: + scheduler_type: "OneCycleLR" + # Look further into steps per epoch, for now hardcoded calculation from paper + scheduler_args: {"max_lr": 0.001, "epochs": 300, "steps_per_epoch": 1} + +dataset: + processed: False + # Path to data files + # src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_69K/raw/" + src: "/storage/home/hhive1/sbaskaran31/scratch/MP_data_69K/raw/" + # Path to target file within data_path + # target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_69K/targets.csv" + target_path: "/storage/home/hhive1/sbaskaran31/scratch/MP_data_69K/targets.csv" + # Path to save processed data.pt file (a directory path not filepath) + # pt_path: "/global/cfs/projectdirs/m3641/Sidharth/datasets/MP_data_69K/" + pt_path: "/storage/home/hhive1/sbaskaran31/scratch/MP_data_69K/" + transforms: + - name: GetY + args: + index: 0 + otf: False + - name: NumNodeTransform + args: + otf: False + - name: LineGraphMod + args: + otf: False + - name: ToFloat + args: + otf: False + # Format of data files (limit to those supported by ASE) + data_format: "json" + # Method of obtaining atom dictionary: available:(onehot) + node_representation: "onehot" + additional_attributes: [] + # Print out processing info + verbose: True + # Loading dataset params + # Index of target column in targets.csv + # graph specific settings + cutoff_radius : 8.0 + n_neighbors : 12 + edge_steps : 50 + # Ratios for train/val/test split out of a total of 1 + train_ratio: 0.8 + val_ratio: 0.05 + test_ratio: 0.15 diff --git a/configs/examples/config_graphite.yml b/configs/examples/config_graphite.yml new file mode 100644 index 00000000..37115ba9 --- /dev/null +++ b/configs/examples/config_graphite.yml @@ -0,0 +1,98 @@ + +trainer: property + +task: + # run_mode: train + identifier: "alignn_train_100" + + reprocess: False + + + parallel: True + seed: 0 + #seed=0 means random initialization + + + write_output: True + parallel: True + #Training print out frequency (print per n number of epochs) + verbosity: 1 + + + +model: + name: ALIGNN_GRAPHITE + load_model: False + save_model: True + model_path: "alignn_graphite_model.pth" + num_interactions: 4 + num_species: 3 + cutoff: 3.0 + dim: 64 + # min_angle: float = 0.0, + # max_angle: float = torch.acos(torch.zeros(1)).item() * 2, + link: "identity" + +optim: + max_epochs: 103 + lr: 0.001 + #Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper + loss: + loss_type: "TorchLossWrapper" + loss_args: {"loss_fn": "mse_loss"} + + batch_size: 64 + + optimizer: + optimizer_type: "AdamW" + optimizer_args: {"weight_decay": 0.00001} + scheduler: + scheduler_type: "OneCycleLR" + # Look further into steps per epoch, for now hardcoded calculation from paper + scheduler_args: {"max_lr": 0.001, "epochs": 300, "steps_per_epoch": 1} + +dataset: + processed: True # if False, need to preprocess data and generate .pt file + # Whether to use "inmemory" or "large" format for pytorch-geometric dataset. 
Recommend inmemory unless the dataset is too large + # dataset_type: "inmemory" + #Path to data files + src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_69K/raw/" + #Path to target file within data_path + target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_69K/targets.csv" + #Path to save processed data.pt file (a directory path not filepath) + pt_path: "/global/cfs/projectdirs/m3641/Sidharth/datasets/MP_data_69K/" + transforms: + - name: GetY + args: + index: 0 + otf: False + - name: NumNodeTransform + args: + otf: False + - name: LineGraphMod + args: + otf: False + - name: ToFloat + args: + otf: False + #Format of data files (limit to those supported by ASE) + data_format: "json" + #Method of obtaining atom dictionary: available:(onehot) + node_representation: "onehot" + additional_attributes: [] + #Print out processing info + verbose: True + + #Loading dataset params + #Index of target column in targets.csv + target_index: 0 + + #graph specific settings + cutoff_radius : 8.0 + n_neighbors : 12 + edge_steps : 50 + + #Ratios for train/val/test split out of a total of 1 + train_ratio: 0.8 + val_ratio: 0.05 + test_ratio: 0.15 diff --git a/env.common.yaml b/env.common.yaml index f1c9ee5e..88683366 100644 --- a/env.common.yaml +++ b/env.common.yaml @@ -11,6 +11,6 @@ dependencies: - pre-commit - numpy - scipy - - ase=3.21.* + - ase==3.21.* - black - pandas diff --git a/matdeeplearn/common/config/build_config.py b/matdeeplearn/common/config/build_config.py index 91054b8a..afa8c6d7 100644 --- a/matdeeplearn/common/config/build_config.py +++ b/matdeeplearn/common/config/build_config.py @@ -87,9 +87,9 @@ def create_dict_from_args(args: list, sep: str = "."): def build_config(args, args_override): # Open provided config file - assert os.path.exists(args.config_path), ( - "Config file not found in " + args.config_path - ) + assert os.path.exists( + args.config_path + ), f"Config file not found in {str(args.config_path)}" with open(args.config_path, "r") as ymlfile: config = yaml.load(ymlfile, Loader=yaml.FullLoader) diff --git a/matdeeplearn/common/data.py b/matdeeplearn/common/data.py index 0d8cb53c..848eaa31 100644 --- a/matdeeplearn/common/data.py +++ b/matdeeplearn/common/data.py @@ -1,11 +1,13 @@ import warnings +from typing import List import torch from torch.utils.data import random_split from torch_geometric.loader import DataLoader +from torch_geometric.transforms import Compose +from matdeeplearn.common.registry import registry from matdeeplearn.preprocessor.datasets import LargeStructureDataset, StructureDataset -from matdeeplearn.preprocessor.transforms import GetY # train test split @@ -58,7 +60,9 @@ def dataset_split( def get_dataset( - data_path, target_index: int = 0, transform_type="GetY", large_dataset=False + data_path, + transform_list: List[dict] = [], + large_dataset=False, ): """ get dataset according to data_path @@ -71,21 +75,24 @@ def get_dataset( data_path: str path to the folder containing data.pt file - target_index: int - the index to select the target values - this is needed because in our target.csv, there might be - multiple columns of target values available for that - particular dataset, thus we need to index one column for - the current run/experiment - - transform_type: transformation function/class to be applied + transform_list: transformation functions/classes to be applied """ + # Ensure GetY exists to prevent downstream model errors + assert "GetY" in [ + tf["name"] for tf in transform_list + ], "The target 
transform GetY is required in config." + + transforms = [] # set transform method - if transform_type == "GetY": - T = GetY - else: - raise ValueError("No such transform found for {transform}") + for transform in transform_list: + if transform.get("otf", False): + transforms.append( + registry.get_transform_class( + transform["name"], + **({} if transform["args"] is None else transform["args"]) + ) + ) # check if large dataset is needed if large_dataset: @@ -93,17 +100,15 @@ else: Dataset = StructureDataset - transform = T(index=target_index) + # fall back to None when no on-the-fly transforms are configured + composition = Compose(transforms) if len(transforms) > 1 else (transforms[0] if transforms else None) - return Dataset(data_path, processed_data_path="", transform=transform) + dataset = Dataset(data_path, processed_data_path="", transform=composition) + + return dataset def get_dataloader( - dataset, - batch_size: int, - num_workers: int = 0, - sampler=None, - shuffle=True + dataset, batch_size: int, num_workers: int = 0, sampler=None, shuffle=True ): """ Returns a single dataloader for a given dataset diff --git a/matdeeplearn/common/registry.py b/matdeeplearn/common/registry.py index c77f8b41..33380eed 100644 --- a/matdeeplearn/common/registry.py +++ b/matdeeplearn/common/registry.py @@ -15,6 +15,7 @@ - Register a model: ``@registry.register_model`` """ import importlib +from typing import Callable def _get_absolute_mapping(name: str): @@ -52,6 +53,7 @@ class Registry: "trainer_name_mapping": {}, "loss_name_mapping": {}, "state": {}, + "transforms": {}, } @classmethod @@ -211,6 +213,16 @@ def register(cls, name, obj): current[path[-1]] = obj + @classmethod + def register_transform(cls, transform_name: str): + """Registers a transform function for bookkeeping.""" + + def wrap_func(transform: Callable): + cls.mapping["transforms"][transform_name] = transform + return transform + + return wrap_func + @classmethod + def __import_error(cls, name: str, mapping_name: str): kind = mapping_name[: -len("_name_mapping")] @@ -275,6 +287,10 @@ def get_trainer_class(cls, name): def get_loss_class(cls, name): return cls.get_class(name, "loss_name_mapping") + @classmethod + def get_transform_class(cls, name, **kwargs): + return cls.get_class(name, "transforms")(**kwargs) + @classmethod + def get(cls, name, default=None, no_warning=False): r"""Get an item from registry with key 'name' diff --git a/matdeeplearn/models/alignn.py b/matdeeplearn/models/alignn.py new file mode 100644 index 00000000..0dcc3766 --- /dev/null +++ b/matdeeplearn/models/alignn.py @@ -0,0 +1,408 @@ +from typing import Literal, Optional + +import numpy as np +import torch +from torch.nn import BatchNorm1d, Linear, Sequential, Sigmoid, SiLU +from torch.nn import functional as F +from torch_geometric.data import Data +from torch_geometric.nn import MessagePassing, global_mean_pool +from torch_scatter import scatter + +from matdeeplearn.common.registry import registry +from matdeeplearn.models.base_model import BaseModel + + +@registry.register_model("ALIGNN") +class ALIGNN(BaseModel): + def __init__( + self, + alignn_layers: int = 4, + gcn_layers: int = 4, + atom_input_features: int = 114, + edge_input_features: int = 50, + triplet_input_features: int = 40, + embedding_features: int = 64, + hidden_features: int = 256, + output_features: int = 1, + min_edge_distance: float = 0.0, + max_edge_distance: float = 8.0, + min_angle: float = 0.0, + max_angle: float = torch.acos(torch.zeros(1)).item() * 2, + link: Literal["identity", "log", "logit"] = "identity", + ) -> None: + super().__init__() + + # utilizing data 
object + # atom_input_features = data.num_features + # edge_input_features = data.num_edge_features + + self.atom_embedding = EmbeddingLayer(atom_input_features, hidden_features) + + self.edge_embedding = torch.nn.Sequential( + RBFExpansion( + vmin=min_edge_distance, vmax=max_edge_distance, bins=edge_input_features + ), + EmbeddingLayer(edge_input_features, embedding_features), + EmbeddingLayer(embedding_features, hidden_features), + ) + + self.angle_embedding = torch.nn.Sequential( + RBFExpansion(vmin=min_angle, vmax=max_angle, bins=triplet_input_features), + EmbeddingLayer(triplet_input_features, embedding_features), + EmbeddingLayer(embedding_features, hidden_features), + ) + + # layer to perform M ALIGNNConv updates on the graph + self.alignn_layers = torch.nn.ModuleList( + [ALIGNNConv(hidden_features, hidden_features) for _ in range(alignn_layers)] + ) + + # layer to perform N EdgeGatedConv updates on the graph + self.gcn_layers = torch.nn.ModuleList( + [ + EdgeGatedGraphConv(hidden_features, hidden_features) + for _ in range(gcn_layers) + ] + ) + + # prediction task + self.fc = Linear(hidden_features, output_features) + + # linking which is performed post-readout + self.link = None + self.link_name = link + if link == "identity": + self.link = lambda x: x + elif link == "log": + self.link = torch.exp + avg_gap = 0.7 + self.fc.bias.data = torch.tensor(np.log(avg_gap), dtype=torch.float) + elif link == "logit": + self.link = torch.sigmoid + + @property + def target_attr(self): + return "y" + + def forward(self, g: Data): + # initial node features + node_feats = self.atom_embedding(g.x) + # initial bond features + edge_attr = self.edge_embedding(g.edge_attr) + # initial angle/triplet features + triplet_feats = self.angle_embedding(g.edge_attr_lg) + + # ALIGNN updates + for alignn_layer in self.alignn_layers: + node_feats, edge_attr, triplet_feats = alignn_layer( + g, # required for correct edge and triplet indexing + node_feats, + edge_attr, + triplet_feats, + ) + + # GCN updates + for gcn_layer in self.gcn_layers: + node_feats, edge_attr = gcn_layer( + node_feats, + edge_attr, + g.edge_index, + ) + + # readout + h = global_mean_pool(node_feats, g.batch) + out = self.fc(h) + + if self.link: + out = self.link(out) + + return torch.squeeze(out, -1) + + +class ALIGNNConv(torch.nn.Module): + """ + Implementation of the ALIGNN layer composed of EdgeGatedGraphConv steps + """ + + def __init__(self, input_features, output_features) -> None: + super().__init__() + + # Sequential EdgeGatedCONV layers + # Overall mapping is input_features -> output_features + self.edge_update = EdgeGatedGraphConv(output_features, output_features) + self.node_update = EdgeGatedGraphConv(input_features, output_features) + + def forward( + self, + g: Data, + node_feats: torch.Tensor, + edge_attr: torch.Tensor, + triplet_feats: torch.Tensor, + ) -> torch.Tensor: + # Perform sequential edge and node updates + + message, triplet_feats = self.edge_update( + edge_attr, triplet_feats, g.edge_index_lg + ) + + node_feats, edge_attr = self.node_update(node_feats, message, g.edge_index) + + # Return updated node, edge, and triplet embeddings + return node_feats, edge_attr, triplet_feats + + +class EdgeGatedGraphConv(MessagePassing): + """ + Message-passing based implementation of EGGConv + """ + + def __init__( + self, input_features: int, output_features: int, residual: bool = True, eps=1e-6 + ) -> None: + super().__init__() + + self.W_src = Linear(input_features, output_features) + self.W_dst = Linear(input_features, 
output_features) + # Operates on h_i + self.W_ai = Linear(input_features, output_features) + # Operates on h_j + self.W_bj = Linear(input_features, output_features) + # Operates on e_ij + self.W_cij = Linear(output_features, output_features) + + self.bn_nodes = BatchNorm1d(output_features) + self.bn_edges = BatchNorm1d(output_features) + + self.act = SiLU() + self.sigmoid = Sigmoid() + self.residual = residual + self.eps = eps + + def forward( + self, + node_feats: torch.Tensor, + edge_attr: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + i, j = edge_index + # Node update routine + sigma = self.sigmoid(edge_attr) + sigma_sum = scatter(src=sigma, index=i, dim=0) + # Accessing at index i allows for shape matching and correct aggregate division + e_ij_hat = sigma / (sigma_sum[i] + self.eps) + + dest_aggr = self.propagate(edge_index, x=node_feats, e_ij_hat=e_ij_hat) + + new_node_feats = node_feats + self.act( + self.bn_nodes(self.W_src(node_feats) + dest_aggr) + ) + + # Edge update routine + new_edge_attr = edge_attr + self.act( + self.bn_edges( + self.W_ai(node_feats[i]) + + self.W_bj(node_feats[j]) + + self.W_cij(edge_attr) + ) + ) + + return new_node_feats, new_edge_attr + + def message(self, x_j, e_ij_hat): + return e_ij_hat * self.W_dst(x_j) + + +class EdgeGatedGraphConvNoMP(torch.nn.Module): + """ + Implementation of the EdgeGatedGraphConv layer + """ + + def __init__( + self, input_features: int, output_features: int, residual: bool = True + ) -> None: + super().__init__() + + # define the edge and node models for creating new embeddings + self.edge_model = EdgeModel(input_features, output_features, residual) + self.node_model = NodeModel(input_features, output_features, residual) + + def forward( + self, + node_feats: torch.Tensor, + edge_attr: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + # compute new edge features + + row, col = edge_index + new_edge_attr = self.edge_model(node_feats[row], node_feats[col], edge_attr) + + # compute new node features (which are based on previously updated edge features) + new_node_feats = self.node_model(node_feats, edge_index, new_edge_attr) + + return new_node_feats, new_edge_attr + + +class EdgeModel(torch.nn.Module): + """ + Abstraction to perform an update on the edge attributes + e_ij_new = f(e_ij, h_i, h_j) + """ + + def __init__(self, input_features, output_features, residual=True): + super().__init__() + + # source node attributes + self.src_gate = Linear(input_features, output_features) + # dest node attributes + self.dest_gate = Linear(input_features, output_features) + # weights for edge attributes + self.edge_gate = Linear(input_features, output_features) + + self.batch_norm = BatchNorm1d(output_features) + self.residual = residual + + def forward(self, src, dest, edge_attr, u=None, batch=None): + # src and dest are the nodes connecting the edge + + new_feats = F.silu( + self.batch_norm( + self.src_gate(src) + self.dest_gate(dest) + self.edge_gate(edge_attr) + ) + ) + + if self.residual: + new_feats = edge_attr + new_feats + return new_feats + + +class NodeModel(torch.nn.Module): + """ + Abstraction to perform an update on the node attributes + h_i_new = f(h_i, \\sum e_ij_hat * (Wdst * h_j)) + """ + + def __init__( + self, input_features, output_features, residual=True, eps=1e-6 + ) -> None: + super().__init__() + # Operates on h_i + self.src_update = Linear(input_features, output_features) + # Define message passing routines + self.node_aggr = NodeAggregation(input_features, output_features) + self.edge_aggr = 
EdgeAggregation() + + self.batch_norm = BatchNorm1d(output_features) + self.act = SiLU() + self.sigmoid = Sigmoid() + + self.residual = residual + self.eps = eps + + def forward(self, x, edge_index, edge_attr, u=None, batch=None): + # compute sigmoid-aggregate of new edge features + node_aggregate = self.node_aggr(x, edge_index, edge_attr) + + edge_aggregate = self.edge_aggr(edge_index, edge_attr) + + dest_aggr = node_aggregate / (edge_aggregate + self.eps) + + # compute new node features + new_feats = self.act(self.batch_norm(self.src_update(x) + dest_aggr)) + + if self.residual: + new_feats = x + new_feats + + return new_feats + + +class NodeAggregation(MessagePassing): + """ + Used to compute an aggregation of transformed node attributes and edge attributes + \\sumj e_ij_hat * (Wdst * h_j) + """ + + def __init__(self, in_channels, out_channels): + super().__init__(aggr="add") + # Define the MLP that operates on each neighboring node + self.dst_update = Linear(in_channels, out_channels) + + def forward(self, x, edge_index, edge_attr): + out = self.propagate(edge_index, x=x, edge_attr=edge_attr) + return out + + def message(self, x_j, edge_attr): + update = self.dst_update(x_j) + # element-wise multiplication as dest update matches shapes + return torch.sigmoid(edge_attr) * update + + +class EdgeAggregation(MessagePassing): + """ + Used to compute the aggregation of edge attributes (sigmoid transform) with respect to neighboring nodes + Message passing still occurs with respect to bond graph + \\sumk \\sigma(e_ik) + """ + + def __init__(self): + super().__init__(aggr="add") + + def forward(self, edge_index, edge_attr): + out = self.propagate(edge_index, edge_attr=edge_attr) + # print(out.element_size() * out.nelement(), out.nelement()) + return out + + def message(self, edge_attr): + return torch.sigmoid(edge_attr) + + +class EmbeddingLayer(torch.nn.Module): + """ + Custom layer which performs nonlinear transform on embeddings + """ + + def __init__(self, input_features, output_features) -> None: + super().__init__() + + self.mlp = Sequential( + Linear(input_features, output_features), + BatchNorm1d(output_features), + torch.nn.SiLU(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class RBFExpansion(torch.nn.Module): + """ + RBF Expansion on distances or angles to compute Gaussian distribution of embeddings + """ + + def __init__( + self, + vmin: float = 0, # default 0A + vmax: float = 8, # default 8A + bins: int = 40, # embedding dimension + lengthscale: Optional[float] = None, + ) -> None: + super().__init__() + + self.vmin = vmin + self.vmax = vmax + self.bins = bins + self.centers = torch.linspace(vmin, vmax, bins) + + if lengthscale is None: + lengthscale = torch.diff(self.centers).mean() + self.gamma = 1.0 / lengthscale + self.lengthscale = lengthscale + else: + self.lengthscale = lengthscale + self.gamma = 1.0 / (lengthscale**2) + + def forward(self, distance: torch.Tensor): + out = torch.exp( + -self.gamma * (distance - self.centers.to(distance.device)) ** 2 + ) + return out diff --git a/matdeeplearn/models/alignn_graphite.py b/matdeeplearn/models/alignn_graphite.py new file mode 100644 index 00000000..2bd73abf --- /dev/null +++ b/matdeeplearn/models/alignn_graphite.py @@ -0,0 +1,155 @@ +from functools import partial + +import torch +from torch.nn import Embedding, LayerNorm, Linear, ModuleList, Sequential, Sigmoid, SiLU +from torch_geometric.data import Data +from torch_geometric.nn import MessagePassing +from torch_scatter import scatter + 
+from matdeeplearn.common.registry import registry +from matdeeplearn.models.base_model import BaseModel + + +@registry.register_model("ALIGNN_GRAPHITE") +class ALIGNN_GRAPHITE(BaseModel): + """ALIGNN model that uses auxiliary line graph to explicitly represent and encode bond angles. + Reference: https://www.nature.com/articles/s41524-021-00650-1. + """ + + def __init__(self, dim=64, num_interactions=4, num_species=3, cutoff=3.0): + super().__init__() + + self.dim = dim + self.num_interactions = num_interactions + self.cutoff = cutoff + + self.embed_atm = Embedding(num_species, dim) + self.embed_bnd = partial(bessel, start=0, end=cutoff, num_basis=dim) + + self.atm_bnd_interactions = ModuleList() + self.bnd_ang_interactions = ModuleList() + for _ in range(num_interactions): + self.atm_bnd_interactions.append(EGConv(dim, dim)) + self.bnd_ang_interactions.append(EGConv(dim, dim)) + + self.head = Sequential( + Linear(dim, dim), + SiLU(), + ) + + self.out = Sequential( + Linear(dim, 1), + ) + + self.reset_parameters() + + @property + def target_attr(self): + return "y" + + def reset_parameters(self): + self.embed_atm.reset_parameters() + for i in range(self.num_interactions): + self.atm_bnd_interactions[i].reset_parameters() + self.bnd_ang_interactions[i].reset_parameters() + + def embed_ang(self, x_ang): + cos_ang = torch.cos(x_ang) + return gaussian(cos_ang, start=-1, end=1, num_basis=self.dim) + + def forward(self, data: Data): + edge_index_G = data.edge_index + edge_index_A = data.edge_index_lg + h_atm = self.embed_atm(data.x.type(torch.long)) + h_bnd = self.embed_bnd(data.edge_attr) + h_ang = self.embed_ang(data.edge_attr_lg) + + for i in range(self.num_interactions): + h_bnd, h_ang = self.bnd_ang_interactions[i](h_bnd, edge_index_A, h_ang) + h_atm, h_bnd = self.atm_bnd_interactions[i](h_atm, edge_index_G, h_bnd) + + h = scatter(h_atm, data.batch, dim=0, reduce="add") + h = self.head(h) + return self.out(h) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"dim={self.dim}, " + f"num_interactions={self.num_interactions}, " + f"cutoff={self.cutoff})" + ) + + +class EGConv(MessagePassing): + """Edge-gated convolution. + This version is closer to the original formulation (without the concatenation). 
+ * https://arxiv.org/abs/2003.00982 + """ + + def __init__(self, node_dim, edge_dim, epsilon=1e-5): + super().__init__(aggr="add") + self.W_src = Linear(node_dim, node_dim) + self.W_dst = Linear(node_dim, node_dim) + self.W_A = Linear(node_dim, edge_dim) + self.W_B = Linear(node_dim, edge_dim) + self.W_C = Linear(edge_dim, edge_dim) + self.act = SiLU() + self.sigma = Sigmoid() + self.norm_x = LayerNorm([node_dim]) + self.norm_e = LayerNorm([edge_dim]) + self.eps = epsilon + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.xavier_uniform_(self.W_src.weight) + self.W_src.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.W_dst.weight) + self.W_dst.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.W_A.weight) + self.W_A.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.W_B.weight) + self.W_B.bias.data.fill_(0) + torch.nn.init.xavier_uniform_(self.W_C.weight) + self.W_C.bias.data.fill_(0) + + def forward(self, x, edge_index, edge_attr): + i, j = edge_index + + # Calculate gated edges + sigma_e = self.sigma(edge_attr) + e_sum = scatter(src=sigma_e, index=i, dim=0) + e_gated = sigma_e / (e_sum[i] + self.eps) + + # Update the nodes (this utilizes the gated edges) + out = self.propagate(edge_index, x=x, e_gated=e_gated) + out = self.W_src(x) + out + out = x + self.act(self.norm_x(out)) + + # Update the edges + edge_attr = edge_attr + self.act( + self.norm_e(self.W_A(x[i]) + self.W_B(x[j]) + self.W_C(edge_attr)) + ) + + return out, edge_attr + + def message(self, x_j, e_gated): + return e_gated * self.W_dst(x_j) + + +def bessel(x, start=0.0, end=1.0, num_basis=8, eps=1e-5): + """Expand scalar features into (radial) Bessel basis function values.""" + x = x[..., None] - start + eps + c = end - start + n = torch.arange(1, num_basis + 1, dtype=x.dtype, device=x.device) + return ((2 / c) ** 0.5) * torch.sin(n * torch.pi * x / c) / x + + +def gaussian(x, start=0.0, end=1.0, num_basis=8): + """Expand scalar features into Gaussian basis function values.""" + mu = torch.linspace(start, end, num_basis, dtype=x.dtype, device=x.device) + step = mu[1] - mu[0] + diff = (x[..., None] - mu) / step + # division by 1.12 so that sum of square is roughly 1 + return diff.pow(2).neg().exp().div(1.12) diff --git a/matdeeplearn/models/cgcnn.py b/matdeeplearn/models/cgcnn.py index b9fa8de9..dcc67ebe 100644 --- a/matdeeplearn/models/cgcnn.py +++ b/matdeeplearn/models/cgcnn.py @@ -63,7 +63,7 @@ def __init__( if data[0][self.target_attr].ndim == 0: self.output_dim = 1 else: - self.output_dim = len(data[0][self.target_attr][0]) + self.output_dim = len(data[0][self.target_attr]) # setup layers self.pre_lin_list = self._setup_pre_gnn_layers() diff --git a/matdeeplearn/models/dos_predict.py b/matdeeplearn/models/dos_predict.py index 306a079a..55e51993 100644 --- a/matdeeplearn/models/dos_predict.py +++ b/matdeeplearn/models/dos_predict.py @@ -49,7 +49,7 @@ def __init__( if data[0][self.target_attr].ndim == 0: self.output_dim = 1 else: - self.output_dim = len(data[0][self.target_attr][0]) + self.output_dim = len(data[0][self.target_attr]) # setup layers self.pre_lin_list = self._setup_pre_gnn_layers() diff --git a/matdeeplearn/modules/loss.py b/matdeeplearn/modules/loss.py index 27a564c5..24c828e7 100644 --- a/matdeeplearn/modules/loss.py +++ b/matdeeplearn/modules/loss.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import torch import torch.nn.functional as F diff --git a/matdeeplearn/preprocessor/datasets.py b/matdeeplearn/preprocessor/datasets.py index 
899534c4..475e85ce 100644 --- a/matdeeplearn/preprocessor/datasets.py +++ b/matdeeplearn/preprocessor/datasets.py @@ -1,34 +1,36 @@ -import torch, os +import os +import torch from torch_geometric.data import InMemoryDataset + class StructureDataset(InMemoryDataset): def __init__( self, - root, - processed_data_path, - transform=None, - pre_transform=None, + root, + processed_data_path, + transform=None, + pre_transform=None, pre_filter=None, - device=None + device=None, ): self.root = root self.processed_data_path = processed_data_path - super(StructureDataset, self).__init__(root, transform, pre_transform, pre_filter) + super(StructureDataset, self).__init__( + root, transform, pre_transform, pre_filter + ) if not torch.cuda.is_available() or device == "cpu": self.data, self.slices = torch.load( - self.processed_paths[0], - map_location=torch.device('cpu') + self.processed_paths[0], map_location=torch.device("cpu") ) else: self.data, self.slices = torch.load(self.processed_paths[0]) - - + @property def raw_file_names(self): """ - The name of the files in the self.raw_dir folder + The name of the files in the self.raw_dir folder that must be present in order to skip downloading. """ return [] @@ -46,10 +48,11 @@ def processed_dir(self): @property def processed_file_names(self): """ - The name of the files in the self.processed_dir + The name of the files in the self.processed_dir folder that must be present in order to skip processing. """ return ["data.pt"] + class LargeStructureDataset(InMemoryDataset): pass diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index 35dc4ac0..5e55f604 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -1,14 +1,28 @@ -import numpy as np -import ase -from ase import io -import torch +import contextlib import itertools +import logging from pathlib import Path +import numpy as np import torch import torch.nn.functional as F -from torch_geometric.utils import dense_to_sparse, degree, add_self_loops +from torch.profiler import ProfilerActivity, profile from torch_geometric.data.data import Data +from torch_geometric.utils import add_self_loops, degree +from torch_sparse import SparseTensor + + +@contextlib.contextmanager +def prof_ctx(): + """Primitive debug tool which allows profiling of PyTorch code""" + with profile( + activities=[ProfilerActivity.CUDA], record_shapes=True, profile_memory=True + ) as prof: + + yield + + logging.debug(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10)) + def threshold_sort(all_distances, r, n_neighbors): # A = all_distances.clone().detach() @@ -23,14 +37,15 @@ def threshold_sort(all_distances, r, n_neighbors): _, indices = torch.topk(A, N) A = torch.scatter( A, - 1, indices, torch.zeros(len(A), len(A), - device=all_distances.device, - dtype=torch.float) + 1, + indices, + torch.zeros(len(A), len(A), device=all_distances.device, dtype=torch.float), ) A[A > r] = 0 return A + def one_hot_degree(data, max_degree, in_degree=False, cat=True): idx, x = data.edge_index[1 if in_degree else 0], data.x deg = degree(idx, data.num_nodes, dtype=torch.long) @@ -49,7 +64,10 @@ class GaussianSmearing(torch.nn.Module): """ slightly edited version from pytorch geometric to create edge from gaussian basis """ - def __init__(self, start=0.0, stop=5.0, resolution=50, width=0.05, device="cpu", **kwargs): + + def __init__( + self, start=0.0, stop=5.0, resolution=50, width=0.05, device="cpu", **kwargs + ): super(GaussianSmearing, self).__init__() offset 
= torch.linspace(start, stop, resolution, device=device) # self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2 @@ -60,6 +78,7 @@ def forward(self, dist): dist = dist.unsqueeze(-1) - self.offset.view(1, -1) return torch.exp(self.coeff * torch.pow(dist, 2)) + def normalize_edge(dataset, descriptor_label): mean, std, feature_min, feature_max = get_ranges(dataset, descriptor_label) @@ -68,9 +87,13 @@ def normalize_edge(dataset, descriptor_label): data.edge_descriptor[descriptor_label] - feature_min ) / (feature_max - feature_min) + def normalize_edge_cutoff(dataset, descriptor_label, r): for data in dataset: - data.edge_descriptor[descriptor_label] = data.edge_descriptor[descriptor_label] / r + data.edge_descriptor[descriptor_label] = ( + data.edge_descriptor[descriptor_label] / r + ) + def get_ranges(dataset, descriptor_label): mean = 0.0 @@ -91,21 +114,23 @@ def get_ranges(dataset, descriptor_label): std = std / len(dataset) return mean, std, feature_min, feature_max + def clean_up(data_list, attr_list): if not attr_list: return - + # check which attributes in the list are removable removable_attrs = [t for t in attr_list if t in data_list[0].to_dict()] for data in data_list: for attr in removable_attrs: delattr(data, attr) + def get_distances( positions: torch.Tensor, offsets: torch.Tensor, device: str = "cpu", - mic: bool = True + mic: bool = True, ): """ Get pairwise atomic distances @@ -113,17 +138,17 @@ def get_distances( Parameters positions: torch.Tensor positions of atoms in a unit cell - + offsets: torch.Tensor offsets for the unit cell - + device: str torch device type - + mic: bool minimum image convention """ - + # convert numpy array to torch tensors n_atoms = len(positions) n_cells = len(offsets) @@ -139,7 +164,7 @@ def get_distances( # set diagonal of the (0,0,0) unit cell to infinity # this allows us to get the minimum self-loop distance # of an atom to itself in all other images - origin_unit_cell_idx = 13 + # origin_unit_cell_idx = 13 # atomic_distances[:,:,origin_unit_cell_idx].fill_diagonal_(float("inf")) # get minimum @@ -147,7 +172,9 @@ def get_distances( expanded_min_indices = min_indices.clone().detach() atom_rij = pos1 - pos2 - expanded_min_indices = expanded_min_indices[..., None, None].expand(-1, -1, 1, atom_rij.size(3)) + expanded_min_indices = expanded_min_indices[..., None, None].expand( + -1, -1, 1, atom_rij.size(3) + ) atom_rij = torch.gather(atom_rij, dim=2, index=expanded_min_indices).squeeze() return min_atomic_distances, min_indices @@ -156,7 +183,7 @@ def get_distances( def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"): """ Get the periodic boundary condition (PBC) offsets for a unit cell - + Parameters cell: torch.Tensor unit cell vectors of ase.cell.Cell @@ -167,22 +194,25 @@ def get_pbc_cells(cell: torch.Tensor, offset_number: int, device: str = "cpu"): if == 1: 27-cell offsets (3x3x3) """ - _range = np.arange(-offset_number, offset_number+1) + _range = np.arange(-offset_number, offset_number + 1) offsets = [list(x) for x in itertools.product(_range, _range, _range)] offsets = torch.tensor(offsets, device=device, dtype=torch.float) return offsets @ cell, offsets -def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop, offset_number=1): + +def get_cutoff_distance_matrix( + pos, cell, r, n_neighbors, device, image_selfloop, offset_number=1 +): """ get the distance matrix TODO: need to tune this for elongated structures Parameters ---------- - pos: np.ndarray + pos: np.ndarray positions of atoms 
in a unit cell get from crystal.get_positions() - + cell: np.ndarray unit cell of a ase Atoms object @@ -214,12 +244,15 @@ def get_cutoff_distance_matrix(pos, cell, r, n_neighbors, device, image_selfloop # thus initialize a zero matrix of (M+N, 3) for cell offsets n_edges = torch.count_nonzero(cutoff_distance_matrix).item() cell_offsets = torch.zeros(n_edges + len(pos), 3, dtype=torch.float) - # get cells for edges except for self loops + # get cells for edges except for self loops cell_offsets[:n_edges, :] = all_cell_offsets[cutoff_distance_matrix != 0] return cutoff_distance_matrix, cell_offsets -def add_selfloop(num_nodes, edge_indices, edge_weights, cutoff_distance_matrix, self_loop=True): + +def add_selfloop( + num_nodes, edge_indices, edge_weights, cutoff_distance_matrix, self_loop=True +): """ add self loop (i, i) to graph structure @@ -239,16 +272,15 @@ def add_selfloop(num_nodes, edge_indices, edge_weights, cutoff_distance_matrix, distance_matrix_masked = (cutoff_distance_matrix.fill_diagonal_(1) != 0).int() return edge_indices, edge_weights, distance_matrix_masked + def load_node_representation(node_representation="onehot"): node_rep_path = Path(__file__).parent - default_reps = { - "onehot": str(node_rep_path / "./node_representations/onehot.csv") - } + default_reps = {"onehot": str(node_rep_path / "./node_representations/onehot.csv")} rep_file_path = node_representation if node_representation in default_reps: rep_file_path = default_reps[node_representation] - + file_type = rep_file_path.split(".")[-1] loaded_rep = None @@ -263,21 +295,23 @@ def load_node_representation(node_representation="onehot"): return loaded_rep + def generate_node_features(input_data, n_neighbors, device): node_reps = load_node_representation() node_reps = torch.from_numpy(node_reps).to(device) n_elements, n_features = node_reps.shape - + if isinstance(input_data, Data): - input_data.x = node_reps[input_data.z-1].view(-1,n_features) - return one_hot_degree(input_data, n_neighbors+1) + input_data.x = node_reps[input_data.z - 1].view(-1, n_features) + return one_hot_degree(input_data, n_neighbors + 1) for i, data in enumerate(input_data): # minus 1 as the reps are 0-indexed but atomic number starts from 1 - data.x = node_reps[data.z-1].view(-1,n_features) + data.x = node_reps[data.z - 1].view(-1, n_features) for i, data in enumerate(input_data): - input_data[i] = one_hot_degree(data, n_neighbors+1) + input_data[i] = one_hot_degree(data, n_neighbors + 1) + def generate_edge_features(input_data, edge_steps, r, device): distance_gaussian = GaussianSmearing(0, 1, edge_steps, 0.2, device=device) @@ -287,4 +321,74 @@ def generate_edge_features(input_data, edge_steps, r, device): normalize_edge_cutoff(input_data, "distance", r) for i, data in enumerate(input_data): - input_data[i].edge_attr = distance_gaussian(input_data[i].edge_descriptor["distance"]) + input_data[i].edge_attr = distance_gaussian( + input_data[i].edge_descriptor["distance"] + ) + + +def triplets(edge_index, cell_offsets, num_nodes): + """ + Taken from the DimeNet implementation on OCP + """ + + row, col = edge_index # j->i + + value = torch.arange(row.size(0), device=row.device) + adj_t = SparseTensor( + row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes) + ) + adj_t_row = adj_t[row] + num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) + + # Node indices (k->j->i) for triplets. 
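+ # adj_t[i, j] stores the id of edge j->i, so adj_t_row = adj_t[row] collects, + # for each edge j->i, the ids of all edges k->j entering its source node; + # repeat_interleave below aligns the node indices with those incoming edges.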
+ idx_i = col.repeat_interleave(num_triplets) + idx_j = row.repeat_interleave(num_triplets) + idx_k = adj_t_row.storage.col() + + # Edge indices (k->j, j->i) for triplets. + idx_kj = adj_t_row.storage.value() + idx_ji = adj_t_row.storage.row() + + # Remove self-loop triplets d->b->d + # Check atom as well as cell offset + cell_offset_kji = cell_offsets[idx_kj] + cell_offsets[idx_ji] + mask = (idx_i != idx_k) | torch.any(cell_offset_kji != 0, dim=-1).to( + device=idx_i.device + ) + + idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] + idx_kj, idx_ji = idx_kj[mask], idx_ji[mask] + + return idx_i, idx_j, idx_k, idx_kj, idx_ji + + +def compute_bond_angles( + pos: torch.Tensor, offsets: torch.Tensor, edge_index: torch.Tensor, num_nodes: int +) -> torch.Tensor: + """ + Compute angle between bonds to compute node embeddings for L(g) + Taken from the DimeNet implementation on OCP + """ + + # Calculate triplets + idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets( + edge_index, offsets.to(device=edge_index.device), num_nodes + ) + + # Calculate angles. + pos_i = pos[idx_i] + pos_j = pos[idx_j] + + offsets = offsets.to(pos.device) + + pos_ji, pos_kj = ( + pos[idx_j] - pos_i + offsets[idx_ji], + pos[idx_k] - pos_j + offsets[idx_kj], + ) + + a = (pos_ji * pos_kj).sum(dim=-1) + b = torch.cross(pos_ji, pos_kj).norm(dim=-1) + + angle = torch.atan2(b, a) + + return angle, idx_kj, idx_ji diff --git a/matdeeplearn/preprocessor/processor.py b/matdeeplearn/preprocessor/processor.py index 32bb5e54..ea90d5dc 100644 --- a/matdeeplearn/preprocessor/processor.py +++ b/matdeeplearn/preprocessor/processor.py @@ -2,15 +2,16 @@ import logging import os -import ase import numpy as np import pandas as pd import torch from ase import io from torch_geometric.data import Data, InMemoryDataset +from torch_geometric.transforms import Compose from torch_geometric.utils import dense_to_sparse from tqdm import tqdm +from matdeeplearn.common.registry import registry from matdeeplearn.preprocessor.helpers import ( clean_up, generate_edge_features, @@ -41,13 +42,14 @@ def process_data(dataset_config): r=cutoff_radius, n_neighbors=n_neighbors, edge_steps=edge_steps, + transforms=dataset_config.get("transforms", []), data_format=data_format, image_selfloop=image_selfloop, self_loop=self_loop, node_representation=node_representation, additional_attributes=additional_attributes, verbose=verbose, - device=device + device=device, ) processor.process() @@ -61,6 +63,7 @@ def __init__( r: float, n_neighbors: int, edge_steps: int, + transforms: list = [], data_format: str = "json", image_selfloop: bool = True, self_loop: bool = True, @@ -94,6 +97,9 @@ def __init__( step size for creating Gaussian basis for edges used in torch.linspace + transforms: list + default []. List of transforms to apply to the data. + data_format: str format of the raw data file @@ -131,7 +137,7 @@ def __init__( self.additional_attributes = additional_attributes self.verbose = verbose self.device = device - + self.transforms = transforms self.disable_tqdm = logging.root.level > logging.INFO def src_check(self): @@ -156,7 +162,7 @@ def ase_wrap(self): logging.info("Converting data to standardized form for downstream processing.") for i, structure_id in enumerate(file_names): p = os.path.join(self.root_path, str(structure_id) + "." 
+ self.data_format) - ase_structures.append(ase.io.read(p)) + ase_structures.append(io.read(p)) for i, s in enumerate(tqdm(ase_structures, disable=self.disable_tqdm)): d = {} @@ -212,7 +218,11 @@ def json_wrap(self): dict_structures = [] y = [] - y_dim = len(original_structures[0]["y"]) if isinstance(original_structures[0]["y"], list) else 1 + y_dim = ( + len(original_structures[0]["y"]) + if isinstance(original_structures[0]["y"], list) + else 1 + ) logging.info("Converting data to standardized form for downstream processing.") for i, s in enumerate(tqdm(original_structures, disable=self.disable_tqdm)): @@ -268,6 +278,7 @@ def get_data_list(self, dict_structures, y): data_list = [Data() for _ in range(n_structures)] logging.info("Getting torch_geometric.data.Data() objects.") + for i, sdict in enumerate(tqdm(dict_structures, disable=self.disable_tqdm)): target_val = y[i] data = data_list[i] @@ -314,6 +325,30 @@ def get_data_list(self, dict_structures, y): logging.info("Generating edge features...") generate_edge_features(data_list, self.edge_steps, self.r, device=self.device) + # compile non-otf transforms + logging.debug("Applying transforms.") + + # Ensure GetY exists to prevent downstream model errors + assert "GetY" in [ + tf["name"] for tf in self.transforms + ], "The target transform GetY is required in config." + + transforms_list = [] + for transform in self.transforms: + if not transform.get("otf", False): + transforms_list.append( + registry.get_transform_class( + transform["name"], + **({} if transform["args"] is None else transform["args"]) + ) + ) + + composition = Compose(transforms_list) + + # apply transforms + for data in data_list: + composition(data) + clean_up(data_list, ["edge_descriptor"]) return data_list diff --git a/matdeeplearn/preprocessor/transforms.py b/matdeeplearn/preprocessor/transforms.py index 111f4bc7..60e7e4dc 100644 --- a/matdeeplearn/preprocessor/transforms.py +++ b/matdeeplearn/preprocessor/transforms.py @@ -1,4 +1,8 @@ -import os +import torch +from torch_sparse import coalesce + +from matdeeplearn.common.registry import registry +from matdeeplearn.preprocessor.helpers import compute_bond_angles """ here resides the transform classes needed for data processing @@ -10,7 +14,10 @@ """ +@registry.register_transform("GetY") class GetY(object): + """Get the target from the data object.""" + def __init__(self, index=0): self.index = index @@ -19,3 +26,61 @@ def __call__(self, data): if self.index != -1: data.y = data.y[0][self.index] return data + + +@registry.register_transform("NumNodeTransform") +class NumNodeTransform(object): + """ + Adds the number of nodes to the data object + """ + + def __call__(self, data): + data.num_nodes = data.x.shape[0] + return data + + +@registry.register_transform("LineGraphMod") +class LineGraphMod(object): + """ + Adds line graph attributes to the data object + """ + + def __call__(self, data): + # CODE FROM PYG LINEGRAPH TRANSFORM (DIRECTED) + N = data.num_nodes + edge_index, edge_attr = data.edge_index, data.edge_attr + _, edge_attr = coalesce(edge_index, edge_attr, N, N) + + # compute bond angles + angles, idx_kj, idx_ji = compute_bond_angles( + data.pos, data.cell_offsets, data.edge_index, data.num_nodes + ) + triplet_pairs = torch.stack([idx_kj, idx_ji], dim=0) + + data.edge_index_lg = triplet_pairs + data.x_lg = data.edge_attr + data.num_nodes_lg = edge_index.size(1) + + # assign bond angles to edge attributes + data.edge_attr_lg = angles.reshape(-1, 1) + + return data + + +@registry.register_transform("ToFloat") 
+class ToFloat(object): + """ + Convert non-int attributes to float + """ + + def __call__(self, data): + data.x = data.x.float() + data.x_lg = data.x_lg.float() + + data.distances = data.distances.float() + data.pos = data.pos.float() + + data.edge_attr = data.edge_attr.float() + data.edge_attr_lg = data.edge_attr_lg.float() + + return data diff --git a/matdeeplearn/tasks/task.py b/matdeeplearn/tasks/task.py index 09bbda2c..58af20a5 100644 --- a/matdeeplearn/tasks/task.py +++ b/matdeeplearn/tasks/task.py @@ -1,5 +1,4 @@ import logging -import os from matdeeplearn.common.registry import registry @@ -14,9 +13,11 @@ def __init__(self, config): def setup(self, trainer): self.trainer = trainer - checkpoint = self.config.get("checkpoint", None) - if checkpoint is not None: - self.trainer.load_checkpoint(self.config["checkpoint"]) + use_checkpoint = self.config["model"].get("load_model", False) + if use_checkpoint: + logging.info("Attempting to load most recent checkpoint...") + self.trainer.load_checkpoint() + logging.info("Recent checkpoint loaded successfully.") # save checkpoint path to runner state for slurm resubmissions # self.chkpt_path = os.path.join( diff --git a/matdeeplearn/trainers/base_trainer.py b/matdeeplearn/trainers/base_trainer.py index de61656b..6c35f0db 100644 --- a/matdeeplearn/trainers/base_trainer.py +++ b/matdeeplearn/trainers/base_trainer.py @@ -1,5 +1,6 @@ import copy import csv +import glob import logging import os from abc import ABC, abstractmethod @@ -38,8 +39,11 @@ def __init__( test_loader: DataLoader, loss: nn.Module, max_epochs: int, + max_checkpoint_epochs: int = None, identifier: str = None, verbosity: int = None, + save_dir: str = None, + checkpoint_dir: str = None, ): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = model.to(self.device) @@ -54,6 +58,7 @@ def __init__( self.scheduler = scheduler self.loss_fn = loss self.max_epochs = max_epochs + self.max_checkpoint_epochs = max_checkpoint_epochs self.train_verbosity = verbosity self.epoch = 0 @@ -63,9 +68,10 @@ def __init__( self.best_val_metric = 1e10 self.best_model_state = None - self.evaluator = Evaluator() + self.save_dir = save_dir if save_dir else os.getcwd() + self.checkpoint_dir = checkpoint_dir - self.run_dir = os.getcwd() + self.evaluator = Evaluator() timestamp = torch.tensor(datetime.now().timestamp()).to(self.device) self.timestamp_id = datetime.fromtimestamp(timestamp.int()).strftime( @@ -95,9 +101,6 @@ def from_config(cls, config): scheduler dataset """ - # TODO: figure out what configs are passed in and how they're structured - # (one overall config, or individual components) - dataset = cls._load_dataset(config["dataset"]) model = cls._load_model(config["model"], dataset) optimizer = cls._load_optimizer(config["optim"], model) @@ -107,10 +110,13 @@ def from_config(cls, config): ) scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer) loss = cls._load_loss(config["optim"]["loss"]) - max_epochs = config["optim"]["max_epochs"] + max_checkpoint_epochs = config["optim"].get("max_checkpoint_epochs", None) identifier = config["task"].get("identifier", None) verbosity = config["task"].get("verbosity", None) + # pass in custom results home dir and load in prev checkpoint dir + save_dir = config["task"].get("save_dir", None) + checkpoint_dir = config["task"].get("checkpoint_dir", None) return cls( model=model, @@ -123,17 +129,22 @@ def from_config(cls, config): test_loader=test_loader, loss=loss, max_epochs=max_epochs, + 
max_checkpoint_epochs=max_checkpoint_epochs, identifier=identifier, verbosity=verbosity, + save_dir=save_dir, + checkpoint_dir=checkpoint_dir, ) @staticmethod def _load_dataset(dataset_config): """Loads the dataset if from a config file.""" dataset_path = dataset_config["pt_path"] - target_index = dataset_config.get("target_index", 0) - dataset = get_dataset(dataset_path, target_index) + dataset = get_dataset( + dataset_path, + transform_list=dataset_config.get("transforms", []), + ) return dataset @@ -249,17 +260,17 @@ def save_model(self, checkpoint_file, val_metrics=None, training_state=True): else: state = {"state_dict": self.model.state_dict(), "val_metrics": val_metrics} - checkpoint_dir = os.path.join( - self.run_dir, "results", self.timestamp_id, "checkpoint" + curr_checkpt_dir = os.path.join( + self.save_dir, "results", self.timestamp_id, "checkpoint" ) - os.makedirs(checkpoint_dir, exist_ok=True) - filename = os.path.join(checkpoint_dir, checkpoint_file) + os.makedirs(curr_checkpt_dir, exist_ok=True) + filename = os.path.join(curr_checkpt_dir, checkpoint_file) torch.save(state, filename) return filename def save_results(self, output, filename, node_level_predictions=False): - results_path = os.path.join(self.run_dir, "results", self.timestamp_id) + results_path = os.path.join(self.save_dir, "results", self.timestamp_id) os.makedirs(results_path, exist_ok=True) filename = os.path.join(results_path, filename) shape = output.shape @@ -279,7 +290,22 @@ def save_results(self, output, filename, node_level_predictions=False): csvwriter.writerow(output[i - 1, :]) return filename + # TODO: streamline this from PR #12 def load_checkpoint(self): """Loads the model from a checkpoint.pt file""" - # TODO: implement this method - pass + + if not self.checkpoint_dir: + raise ValueError("No checkpoint directory specified in config.") + + # glob returns a list of run directories; pick the most recently modified one + checkpoint_dir = max( + glob.glob(os.path.join(self.checkpoint_dir, "results", "*")), + key=os.path.getmtime, + ) + checkpoint_file = os.path.join(checkpoint_dir, "checkpoint", "checkpoint.pt") + + # Load params from checkpoint + checkpoint = torch.load(checkpoint_file) + + self.model.load_state_dict(checkpoint["state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer"]) + self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"]) + self.epoch = checkpoint["epoch"] + self.step = checkpoint["step"] + self.best_val_metric = checkpoint["best_val_metric"] diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py index be510ee6..9583932d 100644 --- a/matdeeplearn/trainers/property_trainer.py +++ b/matdeeplearn/trainers/property_trainer.py @@ -23,8 +23,11 @@ def __init__( test_loader, loss, max_epochs, + max_checkpoint_epochs, identifier, verbosity, + save_dir, + checkpoint_dir, ): super().__init__( model, @@ -37,22 +40,32 @@ def __init__( test_loader, loss, max_epochs, + max_checkpoint_epochs, identifier, verbosity, + save_dir, + checkpoint_dir, ) def train(self): + # Start training over epochs loop + # Calculate start_epoch from step instead of loading the epoch number + # to prevent inconsistencies due to different batch size in checkpoint. 
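+ # e.g., a checkpoint saved at step 10000 with 400 batches per epoch under the + # current batch size resumes at start_epoch = 10000 // 400 = 25, regardless of + # the batch size that was in effect when the checkpoint was written.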
+ start_epoch = self.step // len(self.train_loader) + + end_epoch = ( + self.max_checkpoint_epochs + start_epoch + if self.max_checkpoint_epochs + else self.max_epochs + ) + if self.train_verbosity: logging.info("Starting regular training") logging.info( - f"running for {self.max_epochs} epochs on {type(self.model).__name__} model" + f"running for {end_epoch - start_epoch} epochs on {type(self.model).__name__} model" ) - # Start training over epochs loop - # Calculate start_epoch from step instead of loading the epoch number - # to prevent inconsistencies due to different batch size in checkpoint. - start_epoch = self.step // len(self.train_loader) - for epoch in range(start_epoch, self.max_epochs): + for epoch in range(start_epoch, end_epoch): epoch_start_time = time.time() if self.train_sampler: self.train_sampler.set_epoch(epoch) diff --git a/scripts/main.py b/scripts/main.py index 25a506a9..8efa0324 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,4 +1,3 @@ -import copy import logging import pprint diff --git a/tutorials/MatDeepLearn_Tutorial.ipynb b/tutorials/MatDeepLearn_Tutorial.ipynb new file mode 100644 index 00000000..fd477a97 --- /dev/null +++ b/tutorials/MatDeepLearn_Tutorial.ipynb @@ -0,0 +1,238 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "gpuClass": "standard" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "## **Installation**" + ], + "metadata": { + "id": "8wHTK_7IjIEg" + } + }, + { + "cell_type": "markdown", + "source": [ + "Go to [Google Colab](https://colab.research.google.com/), and under the Github tab, search for [https://github.com/Fung-Lab/MatDeepLearn_dev/](https://github.com/Fung-Lab/MatDeepLearn_dev/). Open up the tutorial in a new tab and follow the steps below." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "### Runtime" + ], + "metadata": { + "id": "F8tZmlwFULDJ" + } + }, + { + "cell_type": "markdown", + "source": [ + "For the purpose of the demo, we recommend you use GPU on Google Colab, although it is not required.\n", + "\n", + "Google Colab provides access to 1 GPU (Runtime -> Change runtime type -> select GPU)." 
+ ], + "metadata": { + "id": "uhMUkc1GUQ47" + } + }, + { + "cell_type": "markdown", + "source": [ + "### MatDeepLearn Set-Up" + ], + "metadata": { + "id": "byczXMtRQ7nj" + } + }, + { + "cell_type": "markdown", + "source": [ + "Install the required dependencies (ignore torchvision, torchtext, and torchaudio errors - these packages are not used)" + ], + "metadata": { + "id": "x61oTV8oRbAA" + } + }, + { + "cell_type": "code", + "source": [ + "%%bash\n", + "pip install torch==1.12.1+cu113 -f https://download.pytorch.org/whl/torch_stable.html \n", + "pip install ase==3.21.*\n", + "pip install torch-scatter torch-sparse torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html" + ], + "metadata": { + "id": "W260O-Twlmjw", + "pycharm": { + "is_executing": true + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "If you would like to verify that the packages were correctly installed, run this command:" + ], + "metadata": { + "id": "GC2zas3y3T6D" + } + }, + { + "cell_type": "code", + "source": [ + "!pip list -v" + ], + "metadata": { + "id": "TUhglzPDK7Cl", + "pycharm": { + "is_executing": true + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Verify PyTorch and CUDA are working" + ], + "metadata": { + "id": "diG5bHFXJ7ks" + } + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "torch.cuda.is_available()" + ], + "metadata": { + "id": "2QU9TxvPJ_Qe", + "pycharm": { + "is_executing": true + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Clone MatDeepLearn package" + ], + "metadata": { + "id": "lfdcDD-xmT9g" + } + }, + { + "cell_type": "code", + "source": [ + "!git clone https://github.com/Fung-Lab/MatDeepLearn_dev\n", + "%cd MatDeepLearn_dev" + ], + "metadata": { + "id": "-BTGUX8rJyWJ" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "Install MatDeepLearn package in environment" + ], + "metadata": { + "id": "slpMHtdomRyp" + } + }, + { + "cell_type": "code", + "source": [ + "!pip install -e ." + ], + "metadata": { + "id": "M_w-B43wmP8L", + "pycharm": { + "is_executing": true + } + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "## **Training**" + ], + "metadata": { + "id": "G1Nji_7TAQp_" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Run model training with test data\n" + ], + "metadata": { + "id": "c0IoRVFsmbmT" + } + }, + { + "cell_type": "markdown", + "source": [ + "Modify the `dataset` configuration in configs/config.yml so that the dataset paths point to the sample data under the data directory.\n", + "\n", + "```\n", + " src: \"data/test_data/raw/\"\n", + " target_path: \"data/test_data/targets.csv\"\n", + " pt_path: \"data/test_data/processed/\"\n", + "\n", + "```\n", + "\n", + "If you would like to skip the data processing steps when training, then under `dataset`, modify:\n", + "```\n", + "processed: False \n", + "```\n", + "Because there already is a `data.pt` file under `data/test_data/processed`, this file will automatically be used during training.\n" + ], + "metadata": { + "id": "6LbI-0Fvy1ex" + } + }, + { + "cell_type": "code", + "source": [ + "!python scripts/main.py --run_mode=train --config_path=configs/config.yml" + ], + "metadata": { + "id": "QZ1gMkhLmiFV", + "pycharm": { + "is_executing": true + } + }, + "execution_count": null, + "outputs": [] + } + ] +}