diff --git a/configs/config.yml b/configs/config.yml index 26ed87b4..e6ec15fc 100644 --- a/configs/config.yml +++ b/configs/config.yml @@ -14,10 +14,15 @@ task: load_training_state: False # Path to the checkpoint.pt file checkpoint_path: - # E.g. ["train", "val", "test"] + # Whether to write predictions to csv file. E.g. ["train", "val", "test"] write_output: [train, val, test] + # Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time + output_frequency: 0 + # Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint + model_save_frequency: 0 # Specify if labels are provided for the predict task # labels: True + # Use amp mixed precision use_amp: True model: @@ -34,9 +39,13 @@ model: batch_track_stats: True act: relu dropout_rate: 0.0 - # Compute edge features on the fly - otf_edge: False - # compute gradients w.r.t to positions and cell, requires otf_edge=True + # Compute edge indices on the fly in the model forward + otf_edge_index: False + # Compute edge attributes on the fly in the model forward + otf_edge_attr: False + # Compute node attributes on the fly in the model forward + otf_node_attr: False + # compute gradients w.r.t to positions and cell, requires otf_edge_attr=True gradient: False optim: @@ -47,8 +56,8 @@ optim: loss: loss_type: TorchLossWrapper loss_args: {loss_fn: l1_loss} - clip_grad_norm: 10 - + # gradient clipping value + clip_grad_norm: 10 batch_size: 100 optimizer: optimizer_type: AdamW @@ -63,6 +72,7 @@ optim: dataset: name: test_data + # Whether the data has already been processed and a data.pt file is present from a previous run processed: False # Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path} src: data/test_data/data_graph_scalar.json @@ -71,7 +81,7 @@ 
dataset: target_path: # Path to save processed data.pt file pt_path: data/ - # Either "node" or "graph" + # Either "node" or "graph" level prediction_level: graph transforms: @@ -81,12 +91,11 @@ dataset: # For example, an index: 0 (default) will use the first entry in the target vector # if all values are to be predicted simultaneously, then specify index: -1 index: -1 - otf: True # Optional parameter, default is False + otf_transform: True # Optional parameter, default is True # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) data_format: json - # E.g. additional_attributes: [forces, stress] + # specify if additional attributes to be loaded into the dataset from the .json file; e.g. additional_attributes: [forces, stress] additional_attributes: - #additional_attributes: # Print out processing info verbose: True # Index of target column in targets.csv @@ -99,16 +108,26 @@ dataset: # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly preprocess_edge_features: True # determine if node attributes are computed during processing, if false, then they need to be computed on the fly - preprocess_nodes: True + preprocess_node_features: True + # distance cutoff to determine if two atoms are connected by an edge cutoff_radius : 8.0 + # maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors) n_neighbors : 250 + # number of pbc offsets to consider when determining neighbors (usually not changed) num_offsets: 2 - edge_steps : 50 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 50 + # whether or not to add self-loops self_loop: True # Method of obtaining atom dictionary: available: (onehot) node_representation: onehot - all_neighbors: True - # Ratios for train/val/test split out of a total of less than 1 + # Number of workers for dataloader, see 
https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cpu + # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data) train_ratio: 0.8 val_ratio: 0.05 - test_ratio: 0.15 + test_ratio: 0.15 \ No newline at end of file diff --git a/configs/config_calculator.yml b/configs/config_calculator.yml new file mode 100644 index 00000000..64ba8b7e --- /dev/null +++ b/configs/config_calculator.yml @@ -0,0 +1,130 @@ +trainer: matdeeplearn.trainers.PropertyTrainer + +task: + run_mode: train + identifier: my_train_job + parallel: False + # If seed is not set, then it will be random every time + seed: 12345678 + # Defaults to run directory if not specified + save_dir: + # continue from a previous job + continue_job: False + # spefcify if the training state is loaded: epochs, learning rate, etc + load_training_state: False + # Path to the checkpoint.pt file. The model used in the calculator will load parameters from this file. + checkpoint_path: results/2023-09-20-16-22-38-738-my_train_job/checkpoint/best_checkpoint.pt + # E.g. ["train", "val", "test"] + write_output: [train, val, test] + # Specify if labels are provided for the predict task + # labels: True + use_amp: True + +model: + name: CGCNN + # model attributes + dim1: 100 + dim2: 150 + pre_fc_count: 1 + gc_count: 4 + post_fc_count: 3 + pool: global_add_pool + pool_order: early + batch_norm: False + batch_track_stats: True + act: silu + dropout_rate: 0.0 + # Compute edge indices on the fly in the model forward + otf_edge_index: True + # Compute edge attributes on the fly in the model forward + otf_edge_attr: True + # Compute node attributes on the fly in the model forward + otf_node_attr: True + # compute gradients w.r.t to positions and cell, requires otf_edge_attr=True + gradient: True + +optim: + max_epochs: 40 + max_checkpoint_epochs: 0 + lr: 0.002 + # Either custom or from torch.nn.functional library. 
If from torch, loss_type is TorchLossWrapper + loss: + loss_type: TorchLossWrapper + loss_args: {loss_fn: l1_loss} + # gradient clipping value + clip_grad_norm: 10 + batch_size: 100 + optimizer: + optimizer_type: AdamW + optimizer_args: {} + scheduler: + scheduler_type: ReduceLROnPlateau + scheduler_args: {mode: min, factor: 0.8, patience: 10, min_lr: 0.00001, threshold: 0.0002} + #Training print out frequency (print per n number of epochs) + verbosity: 5 + # tqdm progress bar per batch in the epoch + batch_tqdm: False + +dataset: + name: test_data + # Whether the data has already been processed and a data.pt file is present from a previous run + processed: False + # Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path} + src: data/force_data/data.json + # Path to target file within data_path - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path} or left blank when the dataset is a single json file + # Example: target_path: "data/raw_graph_scalar/targets.csv" + target_path: + # Path to save processed data.pt file + pt_path: data/force_data/ + # Either "node" or "graph" level + prediction_level: graph + + transforms: + - name: GetY + args: + # index specifies the index of a target vector to predict, which is useful when there are multiple property labels for a single dataset + # For example, an index: 0 (default) will use the first entry in the target vector + # if all values are to be predicted simultaneously, then specify index: -1 + index: -1 + otf_transform: True # Optional parameter, default is True + # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) + data_format: json + # specify if additional attributes to be loaded into the dataset from the .json file; e.g.
additional_attributes: [forces, stress] + additional_attributes: + # Print out processing info + verbose: True + # Index of target column in targets.csv + # graph specific settings + preprocess_params: + # one of mdl (minimum image convention), ocp (all neighbors included) + edge_calc_method: ocp + # determine if edges are computed, if false, then they need to be computed on the fly + preprocess_edges: False + # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_edge_features: False + # determine if node attributes are computed during processing, if false, then they need to be computed on the fly + preprocess_node_features: False + # distance cutoff to determine if two atoms are connected by an edge + cutoff_radius : 8.0 + # maximum number of neighbors to consider (usually an arbitrarily high number to consider all neighbors) + n_neighbors : 250 + # number of pbc offsets to consider when determining neighbors (usually not changed) + num_offsets: 2 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 50 + # whether or not to add self-loops + self_loop: True + # Method of obtaining atom dictionary: available: (onehot) + node_representation: onehot + all_neighbors: True + + # Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cpu + # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data) + train_ratio: 0.9 + val_ratio: 0.05 + test_ratio: 0.05 diff --git a/configs/config_forces.yml b/configs/config_forces.yml index 761a963e..1d8bf283 100644 --- a/configs/config_forces.yml +++ b/configs/config_forces.yml @@ -16,6 +16,10 @@ task: checkpoint_path: # E.g. 
[train, val, test] write_output: [val, test] + # Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time + output_frequency: 1 + # Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint + model_save_frequency: 1 # Specify if labels are provided for the predict task # labels: True use_amp: False @@ -34,15 +38,19 @@ model: batch_track_stats: True act: silu dropout_rate: 0.0 - # Compute edge features on the fly - otf_edge: True - # compute gradients w.r.t to positions and cell, requires otf_edge=True + # Compute edge indices on the fly in the model forward + otf_edge_index: True + # Compute edge attributes on the fly in the model forward + otf_edge_attr: True + # Compute node attributes on the fly in the model forward + otf_node_attr: False + # compute gradients w.r.t to positions and cell, requires otf_edge_attr=True gradient: True optim: - max_epochs: 40 + max_epochs: 400 max_checkpoint_epochs: 0 - lr: 0.002 + lr: 0.001 # Either custom or from torch.nn.functional library. 
If from torch, loss_type is TorchLossWrapper loss: #loss_type: "TorchLossWrapper" @@ -69,12 +77,12 @@ dataset: name: test_data processed: False # Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path} - src: /global/cfs/projectdirs/m3641/Shared/Materials_datasets/MP_data_forces/raw/data.json + src: data/force_data/data.json # Path to target file within data_path - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path} or left blank when the dataset is a single json file # Example: target_path: "data/test_data/raw_graph_scalar/targets.csv" target_path: # Path to save processed data.pt file - pt_path: data/ + pt_path: data/force_data/ # Either "node" or "graph" prediction_level: graph @@ -103,18 +111,24 @@ dataset: # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly preprocess_edge_features: False # determine if node attributes are computed during processing, if false, then they need to be computed on the fly - preprocess_nodes: True + preprocess_node_features: True cutoff_radius : 8.0 n_neighbors : 250 num_offsets: 2 - edge_steps : 50 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 50 self_loop: True # Method of obtaining atom dictionary: available: (onehot) node_representation: onehot all_neighbors: True + + # Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cpu # Ratios for train/val/test split out of a total of less than 1 - train_ratio: 0.8 + train_ratio: 0.9 val_ratio: 0.05 - test_ratio: 0.015 - - + test_ratio: 0.05 \ No newline at end of file diff --git a/matdeeplearn/common/ase_utils.py b/matdeeplearn/common/ase_utils.py new 
file mode 100644 index 00000000..c25047c3 --- /dev/null +++ b/matdeeplearn/common/ase_utils.py @@ -0,0 +1,122 @@ +import torch +import numpy as np +import yaml +from ase import Atoms +from ase.geometry import Cell +from ase.calculators.calculator import Calculator +from matdeeplearn.preprocessor.helpers import generate_node_features +from torch_geometric.data.data import Data +from torch_geometric.loader import DataLoader +import logging +from typing import List +from matdeeplearn.common.registry import registry + + +logging.basicConfig(level=logging.INFO) + + +class MDLCalculator(Calculator): + implemented_properties = ["energy", "forces", "stress"] + + def __init__(self, config): + """ + Initialize the MDLCalculator instance. + + Args: + config (str or dict): Configuration settings for the MDLCalculator. + + Raises: + AssertionError: If the trainer name is not in the correct format or if the trainer class is not found. + """ + Calculator.__init__(self) + if isinstance(config, str): + with open(config, "r") as yaml_file: + config = yaml.safe_load(yaml_file) + + gradient = config["model"].get("gradient", False) + otf_edge_index = config["model"].get("otf_edge_index", False) + otf_edge_attr = config["model"].get("otf_edge_attr", False) + self.otf_node_attr = config["model"].get("otf_node_attr", False) + assert otf_edge_index and otf_edge_attr and gradient, "To use this calculator to calculate forces and stress, you should set otf_edge_index, otf_edge_attr and gradient to True."
+ + trainer_name = config.get("trainer", "matdeeplearn.trainers.PropertyTrainer") + assert trainer_name.count(".") >= 1, "Trainer name should be in format {module}.{trainer_name}, like matdeeplearn.trainers.PropertyTrainer" + + trainer_cls = registry.get_trainer_class(trainer_name) + load_state = config['task'].get('checkpoint_path', None) + assert trainer_cls is not None, "Trainer not found" + self.trainer = trainer_cls.from_config(config) + + try: + self.trainer.load_checkpoint() + except ValueError: + logging.warning("No checkpoint.pt file is found, and an untrained model is used for prediction.") + + self.n_neighbors = config['dataset']['preprocess_params'].get('n_neighbors', 250) + self.device = 'cpu' + + def calculate(self, atoms: Atoms, properties=implemented_properties, system_changes=None): + """ + Calculate energy, forces, and stress for a given ase.Atoms object. + + Args: + atoms (ase.Atoms): The atomic structure for which calculations are to be performed. + properties (list): List of properties to calculate. Defaults to ['energy', 'forces', 'stress']. + system_changes: Not supported in the current implementation. + + Returns: + None: The results are stored in the instance variable 'self.results'. + + Note: + This method performs energy, forces, and stress calculations using a neural network-based calculator. + The results are stored in the instance variable 'self.results' as 'energy', 'forces', and 'stress'. 
+ """ + Calculator.calculate(self, atoms, properties, system_changes) + + cell = torch.tensor(atoms.cell.array, dtype=torch.float32) + pos = torch.tensor(atoms.positions, dtype=torch.float32) + atomic_numbers = torch.LongTensor(atoms.get_atomic_numbers()) + + data = Data(n_atoms=len(atomic_numbers), pos=pos, cell=cell.unsqueeze(dim=0), + z=atomic_numbers, structure_id=atoms.info.get('structure_id', None)) + + # Generate node features + if not self.otf_node_attr: + generate_node_features(data, self.n_neighbors, device=self.device) + data.x = data.x.to(torch.float32) + + data_list = [data] + loader = DataLoader(data_list, batch_size=1) + + out = self.trainer.predict_by_calculator(loader) + self.results['energy'] = out['energy'] + self.results['forces'] = out['forces'] + self.results['stress'] = out['stress'] + + @staticmethod + def data_to_atoms_list(data: Data) -> List[Atoms]: + """ + This helper method takes a 'torch_geometric.data.Data' object containing information about atomic structures + and converts it into a list of 'ase.Atoms' objects. Each 'Atoms' object represents an atomic structure + with its associated properties such as positions and cell. + + Args: + data (Data): A data object containing information about atomic structures. + + Returns: + List[Atoms]: A list of 'ase.Atoms' objects, each representing an atomic structure + with positions and associated properties. 
+ """ + cells = data.cell.numpy() + + split_indices = np.cumsum(data.n_atoms)[:-1] + positions_per_structure = np.split(data.pos.numpy(), split_indices) + symbols_per_structure = np.split(data.z.numpy(), split_indices) + + atoms_list = [Atoms( + symbols=symbols_per_structure[i], + positions=positions_per_structure[i], + cell=Cell(cells[i])) for i in range(len(data.structure_id))] + for i in range(len(data.structure_id)): + atoms_list[i].structure_id = data.structure_id[i][0] + return atoms_list diff --git a/matdeeplearn/common/data.py b/matdeeplearn/common/data.py index 025718b4..313d6198 100644 --- a/matdeeplearn/common/data.py +++ b/matdeeplearn/common/data.py @@ -70,7 +70,7 @@ def get_otf_transforms(transform_list: List[dict]): transforms = [] # set transform method for transform in transform_list: - if transform.get("otf", False): + if transform.get("otf_transform", False): transforms.append( registry.get_transform_class( transform["name"], @@ -85,6 +85,7 @@ def get_dataset( processed_file_name, transform_list: List[dict] = [], large_dataset=False, + dataset_device=None, ): """ get dataset according to data_path @@ -111,7 +112,7 @@ def get_dataset( composition = Compose(otf_transforms) if len(otf_transforms) >= 1 else None - dataset = Dataset(data_path, processed_data_path="", processed_file_name=processed_file_name, transform=composition) + dataset = Dataset(data_path, processed_data_path="", processed_file_name=processed_file_name, transform=composition, device=dataset_device) return dataset @@ -136,13 +137,27 @@ def get_dataloader( """ # load data - loader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=(sampler is None), - num_workers=num_workers, - pin_memory=True, - sampler=sampler, - ) - + try: + device = str(dataset.dataset[0].pos.device) + except: + device = str(dataset[0].pos.device) + + if device == "cuda:0" or device == "cuda": + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=(sampler is None), + num_workers=0, + 
pin_memory=False, + sampler=sampler, + ) + else: + loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=(sampler is None), + num_workers=num_workers, + pin_memory=True, + sampler=sampler, + ) return loader diff --git a/matdeeplearn/common/trainer_context.py b/matdeeplearn/common/trainer_context.py index 1909d39d..1dc2e54f 100644 --- a/matdeeplearn/common/trainer_context.py +++ b/matdeeplearn/common/trainer_context.py @@ -115,8 +115,10 @@ def setup_imports(): import_keys = ["trainers", "models", "tasks"] for key in import_keys: - for f in (project_root / "matdeeplearn" / key).rglob("*.py"): - _import_local_file(f, project_root=project_root) + dir_list = (project_root / "matdeeplearn" / key).rglob("*.py") + for f in dir_list: + if "old" not in str(f) and "in_progress" not in str(f): + _import_local_file(f, project_root=project_root) finally: registry.register("imports_setup", True) diff --git a/matdeeplearn/models/__init__.py b/matdeeplearn/models/__init__.py index f8a7fd69..613fbe71 100644 --- a/matdeeplearn/models/__init__.py +++ b/matdeeplearn/models/__init__.py @@ -1,5 +1,7 @@ -__all__ = ["BaseModel", "CGCNN", "DOSPredict"] +__all__ = ["BaseModel", "CGCNN", "MPNN", "SchNet", "TorchMD_ET"] from .base_model import BaseModel from .cgcnn import CGCNN -from .dos_predict import DOSPredict +from .mpnn import MPNN +from .schnet import SchNet +from .torchmd_etEarly import TorchMD_ET diff --git a/matdeeplearn/models/base_model.py b/matdeeplearn/models/base_model.py index d8157717..4ddd3ab9 100644 --- a/matdeeplearn/models/base_model.py +++ b/matdeeplearn/models/base_model.py @@ -23,23 +23,27 @@ class BaseModel(nn.Module, metaclass=ABCMeta): def __init__(self, prediction_level="graph", - otf_edge=False, + otf_edge_index=False, + otf_edge_attr=False, + otf_node_attr=False, graph_method="ocp", gradient=False, cutoff_radius=8, n_neighbors=None, - edge_steps=50, + edge_dim=50, num_offsets=1, **kwargs ) -> None: super(BaseModel, self).__init__() 
self.prediction_level = prediction_level - self.otf_edge = otf_edge + self.otf_edge_index = otf_edge_index + self.otf_edge_attr = otf_edge_attr + self.otf_node_attr = otf_node_attr self.gradient = gradient self.cutoff_radius = cutoff_radius self.n_neighbors = n_neighbors - self.edge_steps = edge_steps + self.edge_dim = edge_dim self.graph_method = graph_method self.num_offsets = num_offsets @@ -99,10 +103,6 @@ def generate_graph(self, data, cutoff_radius, n_neighbors): n_neighbors: int max number of neighbors - - otf: bool - otf == on-the-fly - if True, this function will be called """ #For calculation of stress, see https://github.com/mir-group/nequip/blob/main/nequip/nn/_grad_output.py @@ -117,7 +117,8 @@ def generate_graph(self, data, cutoff_radius, n_neighbors): if torch.sum(data.cell) == 0: self.graph_method = "mdl" - + + #Can differ from non-otf if amp=True for a very small percentage of edges ~0.01% if self.graph_method == "ocp": edge_index, cell_offsets, neighbors = radius_graph_pbc( cutoff_radius, @@ -126,6 +127,7 @@ def generate_graph(self, data, cutoff_radius, n_neighbors): data.cell, data.n_atoms, [True, True, True], + self.num_offsets, ) edge_gen_out = get_pbc_distances( @@ -143,12 +145,14 @@ def generate_graph(self, data, cutoff_radius, n_neighbors): edge_vec = edge_gen_out["distance_vec"] if(edge_vec.dim() > 2): edge_vec = edge_vec[edge_indices[0], edge_indices[1]] - + + elif self.graph_method == "mdl": edge_index_list = [] edge_weights_list = [] edge_vec_list = [] cell_offsets_list = [] + count = 0 for i in range(0, len(data)): cutoff_distance_matrix, cell_offsets, edge_vec = get_cutoff_distance_matrix( @@ -158,12 +162,15 @@ def generate_graph(self, data, cutoff_radius, n_neighbors): n_neighbors, self.num_offsets, ) - + edge_index, edge_weights = dense_to_sparse(cutoff_distance_matrix) # get into correct shape for model stage edge_vec = edge_vec[edge_index[0], edge_index[1]] - + + edge_index = edge_index + count + count = count + 
data[i].pos.shape[0] + edge_index_list.append(edge_index) edge_weights_list.append(edge_weights) edge_vec_list.append(edge_vec) @@ -195,7 +202,7 @@ def generate_graph(self, data, cutoff_radius, n_neighbors): generate_node_features(data, n_neighbors) # TODO # check if edge features that is normalized over the entire dataset can be skipped - generate_edge_features(data, self.edge_steps) + generate_edge_features(data, self.edge_dim) ''' return ( edge_index, diff --git a/matdeeplearn/models/cgcnn.py b/matdeeplearn/models/cgcnn.py index 8de2aa85..d33e35de 100644 --- a/matdeeplearn/models/cgcnn.py +++ b/matdeeplearn/models/cgcnn.py @@ -15,7 +15,7 @@ from matdeeplearn.common.registry import registry from matdeeplearn.models.base_model import BaseModel, conditional_grad -from matdeeplearn.preprocessor.helpers import GaussianSmearing +from matdeeplearn.preprocessor.helpers import GaussianSmearing, node_rep_one_hot @registry.register_model("CGCNN") class CGCNN(BaseModel): @@ -55,7 +55,7 @@ def __init__( self.output_dim = output_dim self.dropout_rate = dropout_rate - self.distance_expansion = GaussianSmearing(0.0, self.cutoff_radius, self.edge_steps) + self.distance_expansion = GaussianSmearing(0.0, self.cutoff_radius, self.edge_dim, 0.2) # Determine gc dimension and post_fc dimension assert gc_count > 0, "Need at least 1 GC layer" @@ -142,10 +142,20 @@ def _setup_post_gnn_layers(self): @conditional_grad(torch.enable_grad()) def _forward(self, data): - if self.otf_edge == True: + if self.otf_edge_index == True: #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) - data.edge_attr = self.distance_expansion(data.edge_weight) + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + else: + logging.warning("Edge attributes should be 
re-computed for otf edge indices.") + + if self.otf_edge_index == False: + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() # Pre-GNN dense layers for i in range(0, len(self.pre_lin_list)): diff --git a/matdeeplearn/models/alignn.py b/matdeeplearn/models/in_progress/alignn.py similarity index 100% rename from matdeeplearn/models/alignn.py rename to matdeeplearn/models/in_progress/alignn.py diff --git a/matdeeplearn/models/alignn_graphite.py b/matdeeplearn/models/in_progress/alignn_graphite.py similarity index 100% rename from matdeeplearn/models/alignn_graphite.py rename to matdeeplearn/models/in_progress/alignn_graphite.py diff --git a/matdeeplearn/models/dimenet_plus_plus.py b/matdeeplearn/models/in_progress/dimenet_plus_plus.py similarity index 100% rename from matdeeplearn/models/dimenet_plus_plus.py rename to matdeeplearn/models/in_progress/dimenet_plus_plus.py diff --git a/matdeeplearn/models/dimenet_plus_plusEarly.py b/matdeeplearn/models/in_progress/dimenet_plus_plusEarly.py similarity index 100% rename from matdeeplearn/models/dimenet_plus_plusEarly.py rename to matdeeplearn/models/in_progress/dimenet_plus_plusEarly.py diff --git a/matdeeplearn/models/in_progress/dimenet_plus_plusEarly_unfinished.py b/matdeeplearn/models/in_progress/dimenet_plus_plusEarly_unfinished.py new file mode 100644 index 00000000..0b4ad39b --- /dev/null +++ b/matdeeplearn/models/in_progress/dimenet_plus_plusEarly_unfinished.py @@ -0,0 +1,495 @@ + +import torch +from torch import nn +import torch.nn.functional as F +import torch_geometric.nn +from torch_geometric.nn import radius_graph +from torch_geometric.nn.inits import glorot_orthogonal + +from torch_geometric.nn.models.dimenet import ( + BesselBasisLayer, + EmbeddingBlock, + Envelope, + ResidualLayer, + SphericalBasisLayer, +) +from torch_geometric.nn.resolver import activation_resolver +from torch_scatter 
import scatter +from torch_sparse import SparseTensor +from matdeeplearn.preprocessor.helpers import triplets, triplets_pbc +from matdeeplearn.common.registry import registry +from matdeeplearn.models.utils import ( + conditional_grad, +) +from matdeeplearn.models.base_model import BaseModel + +try: + import sympy as sym +except ImportError: + sym = None + + +class InteractionPPBlock(torch.nn.Module): + def __init__( + self, + hidden_channels, + int_emb_size, + basis_emb_size, + num_spherical, + num_radial, + num_before_skip, + num_after_skip, + act="silu", + ): + act = activation_resolver(act) + super(InteractionPPBlock, self).__init__() + self.act = act + + # Transformations of Bessel and spherical basis representations. + self.lin_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False) + self.lin_rbf2 = nn.Linear(basis_emb_size, hidden_channels, bias=False) + self.lin_sbf1 = nn.Linear( + num_spherical * num_radial, basis_emb_size, bias=False + ) + self.lin_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False) + + # Dense transformations of input messages. + self.lin_kj = nn.Linear(hidden_channels, hidden_channels) + self.lin_ji = nn.Linear(hidden_channels, hidden_channels) + + # Embedding projections for interaction triplets. + self.lin_down = nn.Linear(hidden_channels, int_emb_size, bias=False) + self.lin_up = nn.Linear(int_emb_size, hidden_channels, bias=False) + + # Residual layers before and after skip connection. 
+ self.layers_before_skip = torch.nn.ModuleList( + [ + ResidualLayer(hidden_channels, act) + for _ in range(num_before_skip) + ] + ) + self.lin = nn.Linear(hidden_channels, hidden_channels) + self.layers_after_skip = torch.nn.ModuleList( + [ + ResidualLayer(hidden_channels, act) + for _ in range(num_after_skip) + ] + ) + + self.reset_parameters() + + def reset_parameters(self): + glorot_orthogonal(self.lin_rbf1.weight, scale=2.0) + glorot_orthogonal(self.lin_rbf2.weight, scale=2.0) + glorot_orthogonal(self.lin_sbf1.weight, scale=2.0) + glorot_orthogonal(self.lin_sbf2.weight, scale=2.0) + + glorot_orthogonal(self.lin_kj.weight, scale=2.0) + self.lin_kj.bias.data.fill_(0) + glorot_orthogonal(self.lin_ji.weight, scale=2.0) + self.lin_ji.bias.data.fill_(0) + + glorot_orthogonal(self.lin_down.weight, scale=2.0) + glorot_orthogonal(self.lin_up.weight, scale=2.0) + + for res_layer in self.layers_before_skip: + res_layer.reset_parameters() + glorot_orthogonal(self.lin.weight, scale=2.0) + self.lin.bias.data.fill_(0) + for res_layer in self.layers_after_skip: + res_layer.reset_parameters() + + def forward(self, x, rbf, sbf, idx_kj, idx_ji): + # Initial transformations. + x_ji = self.act(self.lin_ji(x)) + x_kj = self.act(self.lin_kj(x)) + + # Transformation via Bessel basis. + rbf = self.lin_rbf1(rbf) + rbf = self.lin_rbf2(rbf) + x_kj = x_kj * rbf + + # Down-project embeddings and generate interaction triplet embeddings. + x_kj = self.act(self.lin_down(x_kj)) + + # Transform via 2D spherical basis. + sbf = self.lin_sbf1(sbf) + sbf = self.lin_sbf2(sbf) + x_kj = x_kj[idx_kj] * sbf + + # Aggregate interactions and up-project embeddings. 
+ x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0)) + x_kj = self.act(self.lin_up(x_kj)) + + h = x_ji + x_kj + for layer in self.layers_before_skip: + h = layer(h) + h = self.act(self.lin(h)) + x + for layer in self.layers_after_skip: + h = layer(h) + + return h + + +class OutputPPBlock(torch.nn.Module): + def __init__( + self, + num_radial, + hidden_channels, + out_emb_channels, + out_channels, + num_layers, + act="silu", + ): + act = activation_resolver(act) + super(OutputPPBlock, self).__init__() + self.act = act + + self.lin_rbf = nn.Linear(num_radial, hidden_channels, bias=False) + self.lin_up = nn.Linear(hidden_channels, out_emb_channels, bias=True) + self.lins = torch.nn.ModuleList() + for _ in range(num_layers): + self.lins.append(nn.Linear(out_emb_channels, out_emb_channels)) + self.lin = nn.Linear(out_emb_channels, out_channels, bias=False) + + self.reset_parameters() + + def reset_parameters(self): + glorot_orthogonal(self.lin_rbf.weight, scale=2.0) + glorot_orthogonal(self.lin_up.weight, scale=2.0) + for lin in self.lins: + glorot_orthogonal(lin.weight, scale=2.0) + lin.bias.data.fill_(0) + self.lin.weight.data.fill_(0) + + def forward(self, x, rbf, i, num_nodes=None): + x = self.lin_rbf(rbf) * x + x = scatter(x, i, dim=0, dim_size=num_nodes) + x = self.lin_up(x) + for lin in self.lins: + x = self.act(lin(x)) + return self.lin(x) + +class DimeNetPlusPlus(BaseModel): + r"""DimeNet++ implementation based on https://github.com/klicperajo/dimenet. + Args: + hidden_channels (int): Hidden embedding size. + out_channels (int): Size of each output sample. + num_blocks (int): Number of building blocks. + int_emb_size (int): Embedding size used for interaction triplets + basis_emb_size (int): Embedding size used in the basis transformation + out_emb_channels(int): Embedding size used for atoms in the output block + num_spherical (int): Number of spherical harmonics. + num_radial (int): Number of radial basis functions. 
+ cutoff: (float, optional): Cutoff distance for interatomic + interactions. (default: :obj:`5.0`) + envelope_exponent (int, optional): Shape of the smooth cutoff. + (default: :obj:`5`) + num_before_skip: (int, optional): Number of residual layers in the + interaction blocks before the skip connection. (default: :obj:`1`) + num_after_skip: (int, optional): Number of residual layers in the + interaction blocks after the skip connection. (default: :obj:`2`) + num_output_layers: (int, optional): Number of linear layers for the + output blocks. (default: :obj:`3`) + act: (function, optional): The activation funtion. + (default: :obj:`silu`) + """ + + url = "https://github.com/klicperajo/dimenet/raw/master/pretrained" + + def __init__( + self, + hidden_channels, + out_channels, + num_blocks, + int_emb_size, + basis_emb_size, + out_emb_channels, + num_spherical, + num_radial, + cutoff=5.0, + envelope_exponent=5, + num_before_skip=1, + num_after_skip=2, + num_output_layers=3, + act="silu", + ): + + act = activation_resolver(act) + + super(DimeNetPlusPlus, self).__init__() + + self.cutoff = cutoff + + if sym is None: + raise ImportError("Package `sympy` could not be found.") + + self.num_blocks = num_blocks + + self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent) + self.sbf = SphericalBasisLayer( + num_spherical, num_radial, cutoff, envelope_exponent + ) + + self.emb = EmbeddingBlock(num_radial, hidden_channels, act) + + self.output_blocks = torch.nn.ModuleList( + [ + OutputPPBlock( + num_radial, + hidden_channels, + out_emb_channels, + out_channels, + num_output_layers, + act, + ) + for _ in range(num_blocks + 1) + ] + ) + + self.interaction_blocks = torch.nn.ModuleList( + [ + InteractionPPBlock( + hidden_channels, + int_emb_size, + basis_emb_size, + num_spherical, + num_radial, + num_before_skip, + num_after_skip, + act, + ) + for _ in range(num_blocks) + ] + ) + + self.reset_parameters() + + def reset_parameters(self): + self.rbf.reset_parameters() + 
self.emb.reset_parameters() + for out in self.output_blocks: + out.reset_parameters() + for interaction in self.interaction_blocks: + interaction.reset_parameters() + + def triplets(self, edge_index, cell_offsets, num_nodes): + """ + Taken from the DimeNet implementation on OCP + """ + + row, col = edge_index # j->i + + value = torch.arange(row.size(0), device=row.device) + adj_t = SparseTensor( + row=col, col=row, value=value, sparse_sizes=(num_nodes, num_nodes) + ) + adj_t_row = adj_t[row] + num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long) + + # Node indices (k->j->i) for triplets. + idx_i = col.repeat_interleave(num_triplets) + idx_j = row.repeat_interleave(num_triplets) + idx_k = adj_t_row.storage.col() + + # Edge indices (k->j, j->i) for triplets. + idx_kj = adj_t_row.storage.value() + idx_ji = adj_t_row.storage.row() + + # Remove self-loop triplets d->b->d + # Check atom as well as cell offset + cell_offset_kji = cell_offsets[idx_kj] + cell_offsets[idx_ji] + mask = (idx_i != idx_k) | torch.any(cell_offset_kji != 0, dim=-1).to( + device=idx_i.device + ) + + idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask] + idx_kj, idx_ji = idx_kj[mask], idx_ji[mask] + + return idx_i, idx_j, idx_k, idx_kj, idx_ji + + def forward(self, z, pos, batch=None): + """ """ + raise NotImplementedError + +@registry.register_model("dimenetplusplusEarly") +class DimeNetPlusPlusWrap(DimeNetPlusPlus): + def __init__( + self, + #num_atoms, + #bond_feat_dim, # not used + node_dim, + edge_dim, + output_dim, + num_targets=1, + use_pbc=True, + regress_forces=False, + hidden_channels=128, + num_blocks=4, + int_emb_size=64, + basis_emb_size=8, + out_emb_channels=256, + num_spherical=7, + num_radial=6, + otf_graph=False, + cutoff=10.0, + envelope_exponent=5, + num_before_skip=1, + num_after_skip=2, + num_output_layers=3, + num_post_layers=1, + post_hidden_channels=64, + pool="global_mean_pool", + activation="relu", + pool_order="early", + **kwargs, + ): + 
self.num_targets = num_targets + self.regress_forces = regress_forces + self.use_pbc = use_pbc + self.cutoff = cutoff + self.otf_graph = otf_graph + self.max_neighbors = 50 + + super(DimeNetPlusPlusWrap, self).__init__( + hidden_channels=hidden_channels, + out_channels=output_dim, + num_blocks=num_blocks, + int_emb_size=int_emb_size, + basis_emb_size=basis_emb_size, + out_emb_channels=out_emb_channels, + num_spherical=num_spherical, + num_radial=num_radial, + cutoff=cutoff, + envelope_exponent=envelope_exponent, + num_before_skip=num_before_skip, + num_after_skip=num_after_skip, + num_output_layers=num_output_layers, + ) + self.num_post_layers = num_post_layers + self.post_hidden_channels = post_hidden_channels + self.post_lin_list = nn.ModuleList() + self.pool = pool + self.activation = activation + self.pool_order = pool_order + if self.pool_order == "early": + for i in range(self.num_post_layers): + if i == 0: + self.post_lin_list.append(nn.Linear(hidden_channels, post_hidden_channels)) + else: + self.post_lin_list.append(nn.Linear(post_hidden_channels, post_hidden_channels)) + self.post_lin_list.append(nn.Linear(post_hidden_channels, output_dim)) + + @conditional_grad(torch.enable_grad()) + def _forward(self, data): + pos = data.pos + batch = data.batch + + if self.otf_edge_index == True: + #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, _, data.cell_offsets, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + #( + # edge_index, + # dist, + # _, + # cell_offsets, + # offsets, + # neighbors, + #) = self.generate_graph(data) + + #data.edge_index = edge_index + #data.cell_offsets = cell_offsets + #data.neighbors = neighbors + j, i = data.edge_index + dist = data.edge_weight + + if torch.sum(data.cell) == 0: + _, _, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(data.edge_index, num_nodes=data.z.size(0)) 
+ else: + idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets_pbc( + data.edge_index, + data.cell_offsets, + num_nodes=data.z.size(0),) + offsets = data.cell_offsets + + # Calculate angles. + pos_i = pos[idx_i].detach() + pos_j = pos[idx_j].detach() + if self.use_pbc: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i + offsets[idx_ji], + pos[idx_k].detach() - pos_j + offsets[idx_kj], + ) + else: + pos_ji, pos_kj = ( + pos[idx_j].detach() - pos_i, + pos[idx_k].detach() - pos_j, + ) + + a = (pos_ji * pos_kj).sum(dim=-1) + b = torch.cross(pos_ji, pos_kj).norm(dim=-1) + angle = torch.atan2(b, a) + + rbf = self.rbf(dist) + sbf = self.sbf(dist, angle, idx_kj) + print(rbf.shape) + + # Embedding block. + x = self.emb(data.z.long(), rbf, i, j) + P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0)) + + # Interaction blocks. + for interaction_block, output_block in zip( + self.interaction_blocks, self.output_blocks[1:] + ): + x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) + P += output_block(x, rbf, i, num_nodes=pos.size(0)) + if self.prediction_level == "graph" and self.pool_order == "early": + energy = getattr(torch_geometric.nn, self.pool)(P, data.batch) + x = energy + elif self.prediction_level == "graph": + energy = P.sum(dim=0) if batch is None else scatter(P, batch, dim=0) + x = energy + else: + x = P + + if self.pool_order == "early": + for i in range(0, len(self.post_lin_list) - 1): + x = self.post_lin_list[i](x) + x = getattr(F, self.activation)(x) + x = self.post_lin_list[-1](x) + + return x + + def forward(self, data): + + output = {} + out = self._forward(data) + output["output"] = out + + if self.gradient == True and out.requires_grad == True: + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + 
stress = stress / volume.view(-1, 1, 1) + + output["pos_grad"] = forces + output["cell_grad"] = stress + else: + output["pos_grad"] = None + output["cell_grad"] = None + + return output + + @property + def target_attr(self): + return "y" + + @property + def num_params(self): + return sum(p.numel() for p in self.parameters()) \ No newline at end of file diff --git a/matdeeplearn/models/dos_predict.py b/matdeeplearn/models/in_progress/dos_predict.py similarity index 100% rename from matdeeplearn/models/dos_predict.py rename to matdeeplearn/models/in_progress/dos_predict.py diff --git a/matdeeplearn/models/escn/Jd.pt b/matdeeplearn/models/in_progress/escn/Jd.pt similarity index 100% rename from matdeeplearn/models/escn/Jd.pt rename to matdeeplearn/models/in_progress/escn/Jd.pt diff --git a/matdeeplearn/models/escn/escn.py b/matdeeplearn/models/in_progress/escn/escn.py similarity index 100% rename from matdeeplearn/models/escn/escn.py rename to matdeeplearn/models/in_progress/escn/escn.py diff --git a/matdeeplearn/models/escn/so3.py b/matdeeplearn/models/in_progress/escn/so3.py similarity index 100% rename from matdeeplearn/models/escn/so3.py rename to matdeeplearn/models/in_progress/escn/so3.py diff --git a/matdeeplearn/models/gemnet/gemnet.py b/matdeeplearn/models/in_progress/gemnet/gemnet.py similarity index 100% rename from matdeeplearn/models/gemnet/gemnet.py rename to matdeeplearn/models/in_progress/gemnet/gemnet.py diff --git a/matdeeplearn/models/gemnet/initializers.py b/matdeeplearn/models/in_progress/gemnet/initializers.py similarity index 100% rename from matdeeplearn/models/gemnet/initializers.py rename to matdeeplearn/models/in_progress/gemnet/initializers.py diff --git a/matdeeplearn/models/gemnet/layers/atom_update_block.py b/matdeeplearn/models/in_progress/gemnet/layers/atom_update_block.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/atom_update_block.py rename to matdeeplearn/models/in_progress/gemnet/layers/atom_update_block.py 
diff --git a/matdeeplearn/models/gemnet/layers/base_layers.py b/matdeeplearn/models/in_progress/gemnet/layers/base_layers.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/base_layers.py rename to matdeeplearn/models/in_progress/gemnet/layers/base_layers.py diff --git a/matdeeplearn/models/gemnet/layers/basis_utils.py b/matdeeplearn/models/in_progress/gemnet/layers/basis_utils.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/basis_utils.py rename to matdeeplearn/models/in_progress/gemnet/layers/basis_utils.py diff --git a/matdeeplearn/models/gemnet/layers/compat.py b/matdeeplearn/models/in_progress/gemnet/layers/compat.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/compat.py rename to matdeeplearn/models/in_progress/gemnet/layers/compat.py diff --git a/matdeeplearn/models/gemnet/layers/efficient.py b/matdeeplearn/models/in_progress/gemnet/layers/efficient.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/efficient.py rename to matdeeplearn/models/in_progress/gemnet/layers/efficient.py diff --git a/matdeeplearn/models/gemnet/layers/embedding_block.py b/matdeeplearn/models/in_progress/gemnet/layers/embedding_block.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/embedding_block.py rename to matdeeplearn/models/in_progress/gemnet/layers/embedding_block.py diff --git a/matdeeplearn/models/gemnet/layers/interaction_block.py b/matdeeplearn/models/in_progress/gemnet/layers/interaction_block.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/interaction_block.py rename to matdeeplearn/models/in_progress/gemnet/layers/interaction_block.py diff --git a/matdeeplearn/models/gemnet/layers/radial_basis.py b/matdeeplearn/models/in_progress/gemnet/layers/radial_basis.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/radial_basis.py rename to matdeeplearn/models/in_progress/gemnet/layers/radial_basis.py diff --git 
a/matdeeplearn/models/gemnet/layers/scale_factor.py b/matdeeplearn/models/in_progress/gemnet/layers/scale_factor.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/scale_factor.py rename to matdeeplearn/models/in_progress/gemnet/layers/scale_factor.py diff --git a/matdeeplearn/models/gemnet/layers/spherical_basis.py b/matdeeplearn/models/in_progress/gemnet/layers/spherical_basis.py similarity index 100% rename from matdeeplearn/models/gemnet/layers/spherical_basis.py rename to matdeeplearn/models/in_progress/gemnet/layers/spherical_basis.py diff --git a/matdeeplearn/models/gemnet/utils.py b/matdeeplearn/models/in_progress/gemnet/utils.py similarity index 100% rename from matdeeplearn/models/gemnet/utils.py rename to matdeeplearn/models/in_progress/gemnet/utils.py diff --git a/matdeeplearn/models/gemnet_oc/README.md b/matdeeplearn/models/in_progress/gemnet_oc/README.md similarity index 100% rename from matdeeplearn/models/gemnet_oc/README.md rename to matdeeplearn/models/in_progress/gemnet_oc/README.md diff --git a/matdeeplearn/models/gemnet_oc/gemnet_oc.py b/matdeeplearn/models/in_progress/gemnet_oc/gemnet_oc.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/gemnet_oc.py rename to matdeeplearn/models/in_progress/gemnet_oc/gemnet_oc.py diff --git a/matdeeplearn/models/gemnet_oc/gemnet_ocAll.py b/matdeeplearn/models/in_progress/gemnet_oc/gemnet_ocAll.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/gemnet_ocAll.py rename to matdeeplearn/models/in_progress/gemnet_oc/gemnet_ocAll.py diff --git a/matdeeplearn/models/gemnet_oc/gemnet_ocEarly.py b/matdeeplearn/models/in_progress/gemnet_oc/gemnet_ocEarly.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/gemnet_ocEarly.py rename to matdeeplearn/models/in_progress/gemnet_oc/gemnet_ocEarly.py diff --git a/matdeeplearn/models/gemnet_oc/gemnet_ocEarlyAll.py b/matdeeplearn/models/in_progress/gemnet_oc/gemnet_ocEarlyAll.py similarity index 100% rename from 
matdeeplearn/models/gemnet_oc/gemnet_ocEarlyAll.py rename to matdeeplearn/models/in_progress/gemnet_oc/gemnet_ocEarlyAll.py diff --git a/matdeeplearn/models/gemnet_oc/initializers.py b/matdeeplearn/models/in_progress/gemnet_oc/initializers.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/initializers.py rename to matdeeplearn/models/in_progress/gemnet_oc/initializers.py diff --git a/matdeeplearn/models/gemnet_oc/interaction_indices.py b/matdeeplearn/models/in_progress/gemnet_oc/interaction_indices.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/interaction_indices.py rename to matdeeplearn/models/in_progress/gemnet_oc/interaction_indices.py diff --git a/matdeeplearn/models/gemnet_oc/layers/atom_update_block.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/atom_update_block.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/atom_update_block.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/atom_update_block.py diff --git a/matdeeplearn/models/gemnet_oc/layers/base_layers.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/base_layers.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/base_layers.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/base_layers.py diff --git a/matdeeplearn/models/gemnet_oc/layers/basis_utils.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/basis_utils.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/basis_utils.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/basis_utils.py diff --git a/matdeeplearn/models/gemnet_oc/layers/efficient.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/efficient.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/efficient.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/efficient.py diff --git a/matdeeplearn/models/gemnet_oc/layers/embedding_block.py 
b/matdeeplearn/models/in_progress/gemnet_oc/layers/embedding_block.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/embedding_block.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/embedding_block.py diff --git a/matdeeplearn/models/gemnet_oc/layers/force_scaler.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/force_scaler.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/force_scaler.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/force_scaler.py diff --git a/matdeeplearn/models/gemnet_oc/layers/interaction_block.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/interaction_block.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/interaction_block.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/interaction_block.py diff --git a/matdeeplearn/models/gemnet_oc/layers/radial_basis.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/radial_basis.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/radial_basis.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/radial_basis.py diff --git a/matdeeplearn/models/gemnet_oc/layers/spherical_basis.py b/matdeeplearn/models/in_progress/gemnet_oc/layers/spherical_basis.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/layers/spherical_basis.py rename to matdeeplearn/models/in_progress/gemnet_oc/layers/spherical_basis.py diff --git a/matdeeplearn/models/gemnet_oc/utils.py b/matdeeplearn/models/in_progress/gemnet_oc/utils.py similarity index 100% rename from matdeeplearn/models/gemnet_oc/utils.py rename to matdeeplearn/models/in_progress/gemnet_oc/utils.py diff --git a/matdeeplearn/models/matformer/__init__.py b/matdeeplearn/models/in_progress/matformer/__init__.py similarity index 100% rename from matdeeplearn/models/matformer/__init__.py rename to matdeeplearn/models/in_progress/matformer/__init__.py diff --git a/matdeeplearn/models/matformer/bn_utils.py 
b/matdeeplearn/models/in_progress/matformer/bn_utils.py similarity index 100% rename from matdeeplearn/models/matformer/bn_utils.py rename to matdeeplearn/models/in_progress/matformer/bn_utils.py diff --git a/matdeeplearn/models/matformer/pyg_att.py b/matdeeplearn/models/in_progress/matformer/pyg_att.py similarity index 100% rename from matdeeplearn/models/matformer/pyg_att.py rename to matdeeplearn/models/in_progress/matformer/pyg_att.py diff --git a/matdeeplearn/models/matformer/transformer.py b/matdeeplearn/models/in_progress/matformer/transformer.py similarity index 100% rename from matdeeplearn/models/matformer/transformer.py rename to matdeeplearn/models/in_progress/matformer/transformer.py diff --git a/matdeeplearn/models/matformer/utils.py b/matdeeplearn/models/in_progress/matformer/utils.py similarity index 100% rename from matdeeplearn/models/matformer/utils.py rename to matdeeplearn/models/in_progress/matformer/utils.py diff --git a/matdeeplearn/models/megnet.py b/matdeeplearn/models/in_progress/megnet.py similarity index 82% rename from matdeeplearn/models/megnet.py rename to matdeeplearn/models/in_progress/megnet.py index ab4629d7..5db94d0d 100644 --- a/matdeeplearn/models/megnet.py +++ b/matdeeplearn/models/in_progress/megnet.py @@ -13,13 +13,17 @@ from torch_scatter import scatter_mean, scatter_add, scatter_max, scatter from matdeeplearn.common.registry import registry +from matdeeplearn.models.base_model import BaseModel, conditional_grad +from matdeeplearn.preprocessor.helpers import GaussianSmearing, node_rep_one_hot @registry.register_model("MEGNet") # Megnet -class MEGNet(torch.nn.Module): +class MEGNet(BaseModel): def __init__( self, - data, + node_dim, + edge_dim, + output_dim, dim1=64, dim2=64, dim3=64, @@ -29,18 +33,15 @@ def __init__( post_fc_count=1, pool="global_mean_pool", pool_order="early", - batch_norm="True", - batch_track_stats="True", + batch_norm=True, + batch_track_stats=True, act="relu", dropout_rate=0.0, **kwargs ): 
super(MEGNet, self).__init__() - if batch_track_stats == "False": - self.batch_track_stats = False - else: - self.batch_track_stats = True + self.batch_track_stats = batch_track_stats self.batch_norm = batch_norm self.pool = pool if pool == "global_mean_pool": @@ -51,35 +52,33 @@ def __init__( self.pool_reduce="sum" self.act = act self.pool_order = pool_order + self.node_dim = node_dim + self.edge_dim = edge_dim + self.output_dim = output_dim self.dropout_rate = dropout_rate + self.pre_fc_count = pre_fc_count + self.dim1 = dim1 + self.dim2 = dim2 + self.dim3 = dim3 + self.gc_count = gc_count + self.post_fc_count = post_fc_count + + self.distance_expansion = GaussianSmearing(0.0, self.cutoff_radius, self.edge_dim, 0.2) ##Determine gc dimension dimension assert gc_count > 0, "Need at least 1 GC layer" if pre_fc_count == 0: - gc_dim = data.num_features + self.gc_dim = self.node_dim else: - gc_dim = dim1 + self.gc_dim = dim1 ##Determine post_fc dimension - post_fc_dim = dim3 - ##Determine output dimension length - if data[0][self.target_attr].ndim == 0: - self.output_dim = 1 - else: - self.output_dim = len(data[0][self.target_attr]) + self.post_fc_dim = dim3 + # setup layers + self.pre_lin_list = self._setup_pre_gnn_layers() + self.conv_list, self.bn_list = self._setup_gnn_layers() + self.post_lin_list, self.lin_out = self._setup_post_gnn_layers() - ##Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer) - if pre_fc_count > 0: - self.pre_lin_list = torch.nn.ModuleList() - for i in range(pre_fc_count): - if i == 0: - lin = torch.nn.Linear(data.num_features, dim1) - self.pre_lin_list.append(lin) - else: - lin = torch.nn.Linear(dim1, dim1) - self.pre_lin_list.append(lin) - elif pre_fc_count == 0: - self.pre_lin_list = torch.nn.ModuleList() ##Set up GNN layers self.e_embed_list = torch.nn.ModuleList() @@ -162,6 +161,47 @@ def __init__( @property def target_attr(self): return "y" + + ## Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 
layer) + def _setup_pre_gnn_layers(self): + pre_lin_list = torch.nn.ModuleList() + if self.pre_fc_count > 0: + pre_lin_list = torch.nn.ModuleList() + for i in range(self.pre_fc_count): + if i == 0: + lin = torch.nn.Linear(self.node_dim, self.dim1) + + else: + lin = torch.nn.Linear(self.dim1, self.dim1) + pre_lin_list.append(lin) + + return pre_lin_list + + def _setup_post_gnn_layers(self): + """Sets up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero).""" + post_lin_list = torch.nn.ModuleList() + if self.post_fc_count > 0: + for i in range(self.post_fc_count): + if i == 0: + # Set2set pooling has doubled dimension + if self.pool_order == "early" and self.pool == "set2set": + lin = torch.nn.Linear(self.post_fc_dim * 5, self.dim2) + else: + lin = torch.nn.Linear(self.post_fc_dim * 3, self.dim2) + else: + lin = torch.nn.Linear(self.dim2, self.dim2) + post_lin_list.append(lin) + lin_out = torch.nn.Linear(self.dim2, self.output_dim) + # Set up set2set pooling (if used) + + # else post_fc_count is 0 + else: + if self.pool_order == "early" and self.pool == "set2set": + lin_out = torch.nn.Linear(self.post_fc_dim * 2, self.output_dim) + else: + lin_out = torch.nn.Linear(self.post_fc_dim, self.output_dim) + + return post_lin_list, lin_out def forward(self, data): diff --git a/matdeeplearn/models/ocpbase.py b/matdeeplearn/models/in_progress/ocpbase.py similarity index 100% rename from matdeeplearn/models/ocpbase.py rename to matdeeplearn/models/in_progress/ocpbase.py diff --git a/matdeeplearn/models/painn/painn.py b/matdeeplearn/models/in_progress/painn/painn.py similarity index 100% rename from matdeeplearn/models/painn/painn.py rename to matdeeplearn/models/in_progress/painn/painn.py diff --git a/matdeeplearn/models/painn/painnAllNeigbors.py b/matdeeplearn/models/in_progress/painn/painnAllNeigbors.py similarity index 100% rename from 
matdeeplearn/models/painn/painnAllNeigbors.py rename to matdeeplearn/models/in_progress/painn/painnAllNeigbors.py diff --git a/matdeeplearn/models/painn/painnEarly.py b/matdeeplearn/models/in_progress/painn/painnEarly.py similarity index 100% rename from matdeeplearn/models/painn/painnEarly.py rename to matdeeplearn/models/in_progress/painn/painnEarly.py diff --git a/matdeeplearn/models/painn/painnEarlyAll.py b/matdeeplearn/models/in_progress/painn/painnEarlyAll.py similarity index 100% rename from matdeeplearn/models/painn/painnEarlyAll.py rename to matdeeplearn/models/in_progress/painn/painnEarlyAll.py diff --git a/matdeeplearn/models/painn/utils.py b/matdeeplearn/models/in_progress/painn/utils.py similarity index 100% rename from matdeeplearn/models/painn/utils.py rename to matdeeplearn/models/in_progress/painn/utils.py diff --git a/matdeeplearn/models/scn/Jd.pt b/matdeeplearn/models/in_progress/scn/Jd.pt similarity index 100% rename from matdeeplearn/models/scn/Jd.pt rename to matdeeplearn/models/in_progress/scn/Jd.pt diff --git a/matdeeplearn/models/scn/README.md b/matdeeplearn/models/in_progress/scn/README.md similarity index 100% rename from matdeeplearn/models/scn/README.md rename to matdeeplearn/models/in_progress/scn/README.md diff --git a/matdeeplearn/models/scn/sampling.py b/matdeeplearn/models/in_progress/scn/sampling.py similarity index 100% rename from matdeeplearn/models/scn/sampling.py rename to matdeeplearn/models/in_progress/scn/sampling.py diff --git a/matdeeplearn/models/scn/scn.py b/matdeeplearn/models/in_progress/scn/scn.py similarity index 100% rename from matdeeplearn/models/scn/scn.py rename to matdeeplearn/models/in_progress/scn/scn.py diff --git a/matdeeplearn/models/scn/smearing.py b/matdeeplearn/models/in_progress/scn/smearing.py similarity index 100% rename from matdeeplearn/models/scn/smearing.py rename to matdeeplearn/models/in_progress/scn/smearing.py diff --git a/matdeeplearn/models/scn/spherical_harmonics.py 
b/matdeeplearn/models/in_progress/scn/spherical_harmonics.py similarity index 100% rename from matdeeplearn/models/scn/spherical_harmonics.py rename to matdeeplearn/models/in_progress/scn/spherical_harmonics.py diff --git a/matdeeplearn/models/spinconv.py b/matdeeplearn/models/in_progress/spinconv.py similarity index 100% rename from matdeeplearn/models/spinconv.py rename to matdeeplearn/models/in_progress/spinconv.py diff --git a/matdeeplearn/models/spinconvEarly.py b/matdeeplearn/models/in_progress/spinconvEarly.py similarity index 100% rename from matdeeplearn/models/spinconvEarly.py rename to matdeeplearn/models/in_progress/spinconvEarly.py diff --git a/matdeeplearn/models/torchmd_et.py b/matdeeplearn/models/in_progress/torchmd_et.py similarity index 100% rename from matdeeplearn/models/torchmd_et.py rename to matdeeplearn/models/in_progress/torchmd_et.py diff --git a/matdeeplearn/models/torchmd_gn.py b/matdeeplearn/models/in_progress/torchmd_gn.py similarity index 100% rename from matdeeplearn/models/torchmd_gn.py rename to matdeeplearn/models/in_progress/torchmd_gn.py diff --git a/matdeeplearn/models/torchmd_t.py b/matdeeplearn/models/in_progress/torchmd_t.py similarity index 100% rename from matdeeplearn/models/torchmd_t.py rename to matdeeplearn/models/in_progress/torchmd_t.py diff --git a/matdeeplearn/models/mpnn.py b/matdeeplearn/models/mpnn.py index f6b7c01c..82bf5a94 100644 --- a/matdeeplearn/models/mpnn.py +++ b/matdeeplearn/models/mpnn.py @@ -13,7 +13,8 @@ from torch_scatter import scatter, scatter_add, scatter_max, scatter_mean from matdeeplearn.common.registry import registry -from matdeeplearn.models.base_model import BaseModel +from matdeeplearn.models.base_model import BaseModel, conditional_grad +from matdeeplearn.preprocessor.helpers import GaussianSmearing, node_rep_one_hot # CGCNN @@ -21,9 +22,9 @@ class MPNN(BaseModel): def __init__( self, - edge_steps, - self_loop, - data, + node_dim, + edge_dim, + output_dim, dim1=64, dim2=64, dim3=64, @@ 
-38,13 +39,16 @@ def __init__( dropout_rate=0.0, **kwargs ): - super(MPNN, self).__init__(edge_steps, self_loop) + super(MPNN, self).__init__(**kwargs) self.batch_track_stats = batch_track_stats self.batch_norm = batch_norm self.pool = pool self.act = act self.pool_order = pool_order + self.node_dim = node_dim + self.edge_dim = edge_dim + self.output_dim = output_dim self.dropout_rate = dropout_rate self.pre_fc_count = pre_fc_count self.dim1 = dim1 @@ -52,27 +56,22 @@ def __init__( self.dim3 = dim3 self.gc_count = gc_count self.post_fc_count = post_fc_count - self.num_features = data.num_features - self.num_edge_features = data.num_edge_features + + self.distance_expansion = GaussianSmearing(0.0, self.cutoff_radius, self.edge_dim, 0.2) ## Determine gc dimension dimension and post_fc dimension assert gc_count > 0, "Need at least 1 GC layer" if pre_fc_count == 0: - self.gc_dim, self.post_fc_dim = self.data.num_features, data.num_features + self.gc_dim, self.post_fc_dim = self.node_dim, self.node_dim else: self.gc_dim, self.post_fc_dim = dim1, dim1 - ## Determine output dimension length - if data[0][self.target_attr].ndim == 0: - self.output_dim = 1 - else: - self.output_dim = len(data[0][self.target_attr]) - # setup layers self.pre_lin_list = self._setup_pre_gnn_layers() self.conv_list, self.bn_list, self.gru_list = self._setup_gnn_layers() self.post_lin_list, self.lin_out = self._setup_post_gnn_layers() - + + # set up output layer if self.pool_order == "early" and self.pool == "set2set": self.set2set = Set2Set(self.post_fc_dim, processing_steps=3) elif self.pool_order == "late" and self.pool == "set2set": @@ -84,14 +83,14 @@ def __init__( def target_attr(self): return "y" - ## Set up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer) def _setup_pre_gnn_layers(self): + """Sets up pre-GNN dense layers (NOTE: in v0.1 this is always set to 1 layer).""" pre_lin_list = torch.nn.ModuleList() if self.pre_fc_count > 0: pre_lin_list = torch.nn.ModuleList() for 
i in range(self.pre_fc_count): if i == 0: - lin = torch.nn.Linear(self.num_features, self.dim1) + lin = torch.nn.Linear(self.node_dim, self.dim1) else: lin = torch.nn.Linear(self.dim1, self.dim1) @@ -100,13 +99,13 @@ def _setup_pre_gnn_layers(self): return pre_lin_list def _setup_gnn_layers(self): - ## Set up GNN layers + """Sets up GNN layers.""" conv_list = torch.nn.ModuleList() gru_list = torch.nn.ModuleList() bn_list = torch.nn.ModuleList() for i in range(self.gc_count): nn = Sequential( - Linear(self.num_edge_features, self.dim3), + Linear(self.edge_dim, self.dim3), ReLU(), Linear(self.dim3, self.gc_dim * self.gc_dim), ) @@ -124,9 +123,9 @@ def _setup_gnn_layers(self): return conv_list, bn_list, gru_list - ## Set up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. In the current version, the minimum is zero) def _setup_post_gnn_layers(self): + """Sets up post-GNN dense layers (NOTE: in v0.1 there was a minimum of 2 dense layers, and fc_count(now post_fc_count) added to this number. 
In the current version, the minimum is zero).""" post_lin_list = torch.nn.ModuleList() if self.post_fc_count > 0: for i in range(self.post_fc_count): @@ -149,9 +148,25 @@ def _setup_post_gnn_layers(self): else: lin_out = torch.nn.Linear(self.post_fc_dim, self.output_dim) return post_lin_list, lin_out + + @conditional_grad(torch.enable_grad()) + def _forward(self, data): - def forward(self, data): - + if self.otf_edge_index == True: + #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + else: + logging.warning("Edge attributes should be re-computed for otf edge indices.") + + if self.otf_edge_index == False: + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() + ## Pre-GNN dense layers for i in range(0, len(self.pre_lin_list)): if i == 0: @@ -185,28 +200,57 @@ def forward(self, data): out = out.squeeze(0) ## Post-GNN dense layers - if self.pool_order == "early": - if self.pool == "set2set": - out = self.set2set(out, data.batch) - else: - out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + if self.prediction_level == "graph": + if self.pool_order == "early": + if self.pool == "set2set": + out = self.set2set(out, data.batch) + else: + out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + for i in range(0, len(self.post_lin_list)): + out = self.post_lin_list[i](out) + out = getattr(F, self.act)(out) + out = self.lin_out(out) + + elif self.pool_order == "late": + for i in range(0, len(self.post_lin_list)): + out = self.post_lin_list[i](out) + out = getattr(F, self.act)(out) + out = self.lin_out(out) + if self.pool == 
"set2set": + out = self.set2set(out, data.batch) + out = self.lin_out_2(out) + else: + out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + + elif self.prediction_level == "node": for i in range(0, len(self.post_lin_list)): out = self.post_lin_list[i](out) out = getattr(F, self.act)(out) - out = self.lin_out(out) + out = self.lin_out(out) + + return out - elif self.pool_order == "late": - for i in range(0, len(self.post_lin_list)): - out = self.post_lin_list[i](out) - out = getattr(F, self.act)(out) - out = self.lin_out(out) - if self.pool == "set2set": - out = self.set2set(out, data.batch) - out = self.lin_out_2(out) - else: - out = getattr(torch_geometric.nn, self.pool)(out, data.batch) + def forward(self, data): + + output = {} + out = self._forward(data) + output["output"] = out + + if self.gradient == True and out.requires_grad == True: + volume = torch.einsum("zi,zi->z", data.cell[:, 0, :], torch.cross(data.cell[:, 1, :], data.cell[:, 2, :], dim=1)).unsqueeze(-1) + grad = torch.autograd.grad( + out, + [data.pos, data.displacement], + grad_outputs=torch.ones_like(out), + create_graph=self.training) + forces = -1 * grad[0] + stress = grad[1] + stress = stress / volume.view(-1, 1, 1) - if out.shape[1] == 1: - return out.view(-1) + output["pos_grad"] = forces + output["cell_grad"] = stress else: - return out + output["pos_grad"] = None + output["cell_grad"] = None + + return output \ No newline at end of file diff --git a/matdeeplearn/models/schnet.py b/matdeeplearn/models/schnet.py index 08046ff7..acf86a15 100644 --- a/matdeeplearn/models/schnet.py +++ b/matdeeplearn/models/schnet.py @@ -15,7 +15,7 @@ from matdeeplearn.common.registry import registry from matdeeplearn.models.base_model import BaseModel, conditional_grad -from matdeeplearn.preprocessor.helpers import GaussianSmearing +from matdeeplearn.preprocessor.helpers import GaussianSmearing, node_rep_one_hot @registry.register_model("SchNet") class SchNet(BaseModel): @@ -57,7 +57,7 @@ def 
__init__( self.output_dim = output_dim self.dropout_rate = dropout_rate - self.distance_expansion = GaussianSmearing(0.0, self.cutoff_radius, self.edge_steps) + self.distance_expansion = GaussianSmearing(0.0, self.cutoff_radius, self.edge_dim, 0.2) # Determine gc dimension and post_fc dimension assert gc_count > 0, "Need at least 1 GC layer" @@ -141,10 +141,21 @@ def _setup_post_gnn_layers(self): @conditional_grad(torch.enable_grad()) def _forward(self, data): - if self.otf_edge == True: + + if self.otf_edge_index == True: #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) - data.edge_attr = self.distance_expansion(data.edge_weight) + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + else: + logging.warning("Edge attributes should be re-computed for otf edge indices.") + + if self.otf_edge_index == False: + if self.otf_edge_attr == True: + data.edge_attr = self.distance_expansion(data.edge_weight) + + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() # Pre-GNN dense layers for i in range(0, len(self.pre_lin_list)): @@ -235,4 +246,4 @@ def forward(self, data): output["pos_grad"] = None output["cell_grad"] = None - return output \ No newline at end of file + return output diff --git a/matdeeplearn/models/torchmd_etEarly.py b/matdeeplearn/models/torchmd_etEarly.py index 4292dfef..ca9685a2 100644 --- a/matdeeplearn/models/torchmd_etEarly.py +++ b/matdeeplearn/models/torchmd_etEarly.py @@ -13,8 +13,9 @@ act_class_mapping, ) from matdeeplearn.models.base_model import BaseModel, conditional_grad -from matdeeplearn.models.output_modules import EquivariantScalar +from matdeeplearn.models.torchmd_output_modules import Scalar, EquivariantScalar from matdeeplearn.common.registry import registry 
+from matdeeplearn.preprocessor.helpers import node_rep_one_hot @registry.register_model("torchmd_etEarly") @@ -180,14 +181,18 @@ def _forward(self, data): #assert ( # edge_vec is not None #), "Distance module did not return directional information" - if self.otf_edge == True: - data.edge_index, data.edge_weight, data.edge_vec, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + if self.otf_edge_index == True: + #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) data.edge_attr = self.distance_expansion(data.edge_weight) - + #mask = data.edge_index[0] != data.edge_index[1] #data.edge_vec[mask] = data.edge_vec[mask] / torch.norm(data.edge_vec[mask], dim=1).unsqueeze(1) data.edge_vec = data.edge_vec / torch.norm(data.edge_vec, dim=1).unsqueeze(1) + if self.otf_node_attr == True: + data.x = node_rep_one_hot(data.z).float() + if self.neighbor_embedding is not None: x = self.neighbor_embedding(data.z, x, data.edge_index, data.edge_weight, data.edge_attr) @@ -407,4 +412,4 @@ def aggregate( def update( self, inputs: Tuple[torch.Tensor, torch.Tensor] ) -> Tuple[torch.Tensor, torch.Tensor]: - return inputs \ No newline at end of file + return inputs diff --git a/matdeeplearn/models/output_modules.py b/matdeeplearn/models/torchmd_output_modules.py similarity index 100% rename from matdeeplearn/models/output_modules.py rename to matdeeplearn/models/torchmd_output_modules.py diff --git a/matdeeplearn/preprocessor/datasets.py b/matdeeplearn/preprocessor/datasets.py index 1d432536..99006b04 100644 --- a/matdeeplearn/preprocessor/datasets.py +++ b/matdeeplearn/preprocessor/datasets.py @@ -21,13 +21,12 @@ def __init__( super(StructureDataset, self).__init__( root, transform, pre_transform, pre_filter ) - if not torch.cuda.is_available() or device 
== "cpu": self.data, self.slices = torch.load( self.processed_paths[0], map_location=torch.device("cpu") ) else: - self.data, self.slices = torch.load(self.processed_paths[0]) + self.data, self.slices = torch.load(self.processed_paths[0], map_location=device) @property def raw_file_names(self): diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index a79acc63..b3cf157c 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -19,7 +19,6 @@ def calculate_edges_master( method: Literal["ase", "ocp", "mdl"], - all_neighbors: bool, r: float, n_neighbors: int, offset_number: int, @@ -33,9 +32,6 @@ def calculate_edges_master( ) -> dict[str, torch.Tensor]: """Generates edges using one of three methods (ASE, OCP, or MDL implementations) due to limitations of each method. Args: - all_neighbors (bool): Whether or not to use all neighbors (ASE method) - or only the n_neighbors closest neighbors. - OCP based on all_neighbors and MDL based on original without considering all. 
r (float): cutoff radius n_neighbors (int): number of neighbors to consider structure_id (str): structure id @@ -43,12 +39,6 @@ def calculate_edges_master( pos (torch.Tensor): positions of atom in unit cell """ - if method == "ase" or method == "ocp": - assert (method == "ase" and all_neighbors) or ( - method == "ocp" and all_neighbors - ), "OCP and ASE methods only support all_neighbors=True" - #if method == "ase": - # raise Warning("ASE does not take into account n_neighbors") out = dict() neighbors = torch.empty(0) @@ -72,10 +62,10 @@ def calculate_edges_master( # get into correct shape for model stage edge_vec = edge_vec[edge_index[0], edge_index[1]] - elif method == "ase": - edge_index, cell_offsets, edge_weights, edge_vec = calculate_edges_ase( - all_neighbors, r, n_neighbors, structure_id, cell.squeeze(0), pos - ) + #elif method == "ase": + # edge_index, cell_offsets, edge_weights, edge_vec = calculate_edges_ase( + # all_neighbors, r, n_neighbors, structure_id, cell.squeeze(0), pos + # ) elif method == "ocp": # OCP requires a different format for the cell @@ -439,7 +429,10 @@ def add_selfloop( return edge_indices, edge_weights, distance_matrix_masked -def load_node_representation(node_representation="onehot"): +def node_rep_one_hot(Z): + return F.one_hot(Z - 1, num_classes = 100) + +def node_rep_from_file(node_representation="onehot"): node_rep_path = Path(__file__).parent default_reps = {"onehot": str(node_rep_path / "./node_representations/onehot.csv")} @@ -461,21 +454,16 @@ def load_node_representation(node_representation="onehot"): return loaded_rep - -def generate_node_features(input_data, n_neighbors, device, use_degree=False): - node_reps = load_node_representation() - node_reps = torch.from_numpy(node_reps).to(device) - n_elements, n_features = node_reps.shape - +def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = node_rep_one_hot): if isinstance(input_data, Data): - input_data.x = node_reps[input_data.z - 
1].view(-1, n_features) + input_data.x = node_rep_func(input_data.z) if use_degree: return one_hot_degree(input_data, n_neighbors) return input_data for i, data in enumerate(input_data): # minus 1 as the reps are 0-indexed but atomic number starts from 1 - data.x = node_reps[data.z - 1].view(-1, n_features).float() + data.x = node_rep_func(data.z).float() #for i, data in enumerate(input_data): #input_data[i] = one_hot_degree(data, n_neighbors) @@ -492,7 +480,7 @@ def generate_edge_features(input_data, edge_steps, r, device): input_data[i].edge_attr = distance_gaussian( input_data[i].edge_descriptor["distance"] ) -def tripletsOld( +def triplets( edge_index, num_nodes, ): @@ -517,7 +505,7 @@ def tripletsOld( return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji -def triplets(edge_index, cell_offsets, num_nodes): +def triplets_pbc(edge_index, cell_offsets, num_nodes): """ Taken from the DimeNet implementation on OCP """ diff --git a/matdeeplearn/preprocessor/transformsNew.py b/matdeeplearn/preprocessor/in_progress/transformsNew.py similarity index 100% rename from matdeeplearn/preprocessor/transformsNew.py rename to matdeeplearn/preprocessor/in_progress/transformsNew.py diff --git a/matdeeplearn/preprocessor/deprecated.py b/matdeeplearn/preprocessor/old/deprecated.py similarity index 100% rename from matdeeplearn/preprocessor/deprecated.py rename to matdeeplearn/preprocessor/old/deprecated.py diff --git a/matdeeplearn/preprocessor/helpersOld.py b/matdeeplearn/preprocessor/old/helpersOld.py similarity index 100% rename from matdeeplearn/preprocessor/helpersOld.py rename to matdeeplearn/preprocessor/old/helpersOld.py diff --git a/matdeeplearn/preprocessor/processorOld.py b/matdeeplearn/preprocessor/old/processorOld.py similarity index 100% rename from matdeeplearn/preprocessor/processorOld.py rename to matdeeplearn/preprocessor/old/processorOld.py diff --git a/matdeeplearn/preprocessor/processor.py b/matdeeplearn/preprocessor/processor.py index 06f709f4..4e0c0712 100644 
--- a/matdeeplearn/preprocessor/processor.py +++ b/matdeeplearn/preprocessor/processor.py @@ -29,18 +29,17 @@ def from_config(dataset_config): prediction_level = dataset_config.get("prediction_level", "graph") preprocess_edges = dataset_config["preprocess_params"]["preprocess_edges"] preprocess_edge_features = dataset_config["preprocess_params"]["preprocess_edge_features"] - preprocess_nodes = dataset_config["preprocess_params"]["preprocess_nodes"] + preprocess_node_features = dataset_config["preprocess_params"]["preprocess_node_features"] cutoff_radius = dataset_config["preprocess_params"]["cutoff_radius"] n_neighbors = dataset_config["preprocess_params"]["n_neighbors"] num_offsets = dataset_config["preprocess_params"]["num_offsets"] - edge_steps = dataset_config["preprocess_params"]["edge_steps"] + edge_dim = dataset_config["preprocess_params"]["edge_dim"] data_format = dataset_config.get("data_format", "json") image_selfloop = dataset_config.get("image_selfloop", True) self_loop = dataset_config.get("self_loop", True) node_representation = dataset_config["preprocess_params"].get("node_representation", "onehot") additional_attributes = dataset_config.get("additional_attributes", []) verbose: bool = dataset_config.get("verbose", True) - all_neighbors = dataset_config["preprocess_params"]["all_neighbors"] edge_calc_method = dataset_config["preprocess_params"].get("edge_calc_method", "mdl") device: str = dataset_config.get("device", "cpu") @@ -51,11 +50,11 @@ def from_config(dataset_config): prediction_level=prediction_level, preprocess_edges=preprocess_edges, preprocess_edge_features=preprocess_edge_features, - preprocess_nodes=preprocess_nodes, + preprocess_node_features=preprocess_node_features, r=cutoff_radius, n_neighbors=n_neighbors, num_offsets=num_offsets, - edge_steps=edge_steps, + edge_dim=edge_dim, transforms=dataset_config.get("transforms", []), data_format=data_format, image_selfloop=image_selfloop, @@ -63,7 +62,6 @@ def from_config(dataset_config): 
node_representation=node_representation, additional_attributes=additional_attributes, verbose=verbose, - all_neighbors=all_neighbors, edge_calc_method=edge_calc_method, device=device, ) @@ -87,11 +85,11 @@ def __init__( prediction_level: str, preprocess_edges, preprocess_edge_features, - preprocess_nodes, + preprocess_node_features, r: float, n_neighbors: int, num_offsets: int, - edge_steps: int, + edge_dim: int, transforms: list = [], data_format: str = "json", image_selfloop: bool = True, @@ -99,7 +97,6 @@ def __init__( node_representation: str = "onehot", additional_attributes: list = [], verbose: bool = True, - all_neighbors: bool = False, edge_calc_method: str = "mdl", device: str = "cpu", ) -> None: @@ -124,7 +121,7 @@ def __init__( max number of neighbors to be considered => closest n neighbors will be kept - edge_steps: int + edge_dim: int step size for creating Gaussian basis for edges used in torch.linspace @@ -162,17 +159,16 @@ def __init__( self.prediction_level = prediction_level self.preprocess_edges = preprocess_edges self.preprocess_edge_features = preprocess_edge_features - self.preprocess_nodes = preprocess_nodes + self.preprocess_node_features = preprocess_node_features self.n_neighbors = n_neighbors self.num_offsets = num_offsets - self.edge_steps = edge_steps + self.edge_dim = edge_dim self.data_format = data_format self.image_selfloop = image_selfloop self.self_loop = self_loop self.node_representation = node_representation self.additional_attributes = additional_attributes self.verbose = verbose - self.all_neighbors = all_neighbors self.edge_calc_method = edge_calc_method self.device = device self.transforms = transforms @@ -451,7 +447,6 @@ def get_data_list(self, dict_structures): if self.preprocess_edges == True: edge_gen_out = calculate_edges_master( self.edge_calc_method, - self.all_neighbors, self.r, self.n_neighbors, self.num_offsets, @@ -476,7 +471,7 @@ def get_data_list(self, dict_structures): data.edge_descriptor = {} # 
data.edge_descriptor["mask"] = cd_matrix_masked data.edge_descriptor["distance"] = edge_weights - data.distances = edge_weights + # data.distances = edge_weights # add additional attributes @@ -484,13 +479,13 @@ def get_data_list(self, dict_structures): for attr in self.additional_attributes: data.__setattr__(attr, sdict[attr]) - if self.preprocess_nodes == True: + if self.preprocess_node_features == True: logging.info("Generating node features...") generate_node_features(data_list, self.n_neighbors, device=self.device) if self.preprocess_edge_features == True: logging.info("Generating edge features...") - generate_edge_features(data_list, self.edge_steps, self.r, device=self.device) + generate_edge_features(data_list, self.edge_dim, self.r, device=self.device) # compile non-otf transforms logging.debug("Applying transforms.") @@ -502,14 +497,13 @@ def get_data_list(self, dict_structures): transforms_list = [] for transform in self.transforms: - if not transform.get("otf", False): + if not transform.get("otf_transform", False): transforms_list.append( registry.get_transform_class( transform["name"], **({} if transform["args"] is None else transform["args"]) ) ) - composition = Compose(transforms_list) # apply transforms diff --git a/matdeeplearn/trainers/base_trainer.py b/matdeeplearn/trainers/base_trainer.py index a1561404..4a5cb6a5 100644 --- a/matdeeplearn/trainers/base_trainer.py +++ b/matdeeplearn/trainers/base_trainer.py @@ -34,15 +34,17 @@ def __init__( optimizer: Optimizer, sampler: DistributedSampler, scheduler: LRScheduler, - data_loader: DataLoader, + data_loader: DataLoader, loss: nn.Module, max_epochs: int, - clip_grad_norm: int = None, + clip_grad_norm: float = None, max_checkpoint_epochs: int = None, identifier: str = None, verbosity: int = None, - batch_tqdm: bool = False, + batch_tqdm: bool = False, write_output: list = ["train", "val", "test"], + output_frequency: int = 1, + model_save_frequency: int = 1, save_dir: str = None, checkpoint_path: str 
= None, use_amp: bool = False, @@ -61,6 +63,8 @@ def __init__( self.train_verbosity = verbosity self.batch_tqdm = batch_tqdm self.write_output = write_output + self.output_frequency = output_frequency + self.model_save_frequency = model_save_frequency self.epoch = 0 self.step = 0 @@ -85,10 +89,10 @@ def __init__( else: self.rank = self.train_sampler.rank - timestamp = torch.tensor(datetime.now().timestamp()).to(self.device) - self.timestamp_id = datetime.fromtimestamp(timestamp.int()).strftime( - "%Y-%m-%d-%H-%M-%S" - ) + timestamp = datetime.now().timestamp() + self.timestamp_id = datetime.fromtimestamp(timestamp).strftime( + "%Y-%m-%d-%H-%M-%S-%f" + )[:-3] if identifier: self.timestamp_id = f"{self.timestamp_id}-{identifier}" @@ -101,7 +105,7 @@ def __init__( logging.info(f"Dataset length: {key, len(self.dataset[key])}") if self.dataset.get("train"): logging.debug(self.dataset["train"][0]) - logging.debug(self.dataset["train"][0].x[0]) + logging.debug(self.dataset["train"][0].z[0]) logging.debug(self.dataset["train"][0].y[0]) else: logging.debug(self.dataset[list(self.dataset.keys())[0]][0]) @@ -158,6 +162,8 @@ def from_config(cls, config): verbosity = config["optim"].get("verbosity", None) batch_tqdm = config["optim"].get("batch_tqdm", False) write_output = config["task"].get("write_output", []) + output_frequency = config["task"].get("output_frequency", 0) + model_save_frequency = config["task"].get("model_save_frequency", 0) max_checkpoint_epochs = config["optim"].get("max_checkpoint_epochs", None) identifier = config["task"].get("identifier", None) @@ -167,7 +173,7 @@ def from_config(cls, config): if local_world_size > 1: dist.barrier() - + return cls( model=model, dataset=dataset, @@ -183,6 +189,8 @@ def from_config(cls, config): verbosity=verbosity, batch_tqdm=batch_tqdm, write_output=write_output, + output_frequency=output_frequency, + model_save_frequency=model_save_frequency, save_dir=save_dir, checkpoint_path=checkpoint_path, 
use_amp=config["task"].get("use_amp", False), @@ -191,7 +199,11 @@ def from_config(cls, config): @staticmethod def _load_dataset(dataset_config, task): """Loads the dataset if from a config file.""" - + if dataset_config.get("dataset_device", "cpu"): + logging.info("Loading dataset to "+dataset_config.get("dataset_device", "cpu")) + else: + logging.info("Loading dataset to default device") + dataset_path = dataset_config["pt_path"] dataset = {} if isinstance(dataset_config["src"], dict): @@ -200,24 +212,28 @@ def _load_dataset(dataset_config, task): dataset_path, processed_file_name="data_train.pt", transform_list=dataset_config.get("transforms", []), + dataset_device=dataset_config.get("dataset_device", "cpu"), ) if dataset_config["src"].get("val"): dataset["val"] = get_dataset( dataset_path, processed_file_name="data_val.pt", transform_list=dataset_config.get("transforms", []), + dataset_device=dataset_config.get("dataset_device", "cpu"), ) if dataset_config["src"].get("test"): dataset["test"] = get_dataset( dataset_path, processed_file_name="data_test.pt", transform_list=dataset_config.get("transforms", []), - ) + dataset_device=dataset_config.get("dataset_device", "cpu"), + ) if dataset_config["src"].get("predict"): dataset["predict"] = get_dataset( dataset_path, processed_file_name="data_predict.pt", transform_list=dataset_config.get("transforms", []), + dataset_device=dataset_config.get("dataset_device", "cpu"), ) else: @@ -226,6 +242,7 @@ def _load_dataset(dataset_config, task): dataset_path, processed_file_name="data.pt", transform_list=dataset_config.get("transforms", []), + dataset_device=dataset_config.get("dataset_device", "cpu"), ) train_ratio = dataset_config["train_ratio"] val_ratio = dataset_config["val_ratio"] @@ -242,6 +259,7 @@ def _load_dataset(dataset_config, task): dataset_path, processed_file_name="data.pt", transform_list=dataset_config.get("transforms", []), + dataset_device=dataset_config.get("dataset_device", "cpu"), ) return dataset @@ 
-258,16 +276,20 @@ def _load_model(model_config, graph_config, dataset, world_size, rank): if isinstance(dataset, torch.utils.data.Subset): dataset = dataset.dataset - # Obtain node, edge, and output dimensions for model initialization - node_dim = dataset.num_features - edge_dim = graph_config["edge_steps"] + # Obtain node, edge, and output dimensions for model initialization + + if graph_config["node_dim"]: + node_dim = graph_config["node_dim"] + else: + node_dim = dataset.num_features + edge_dim = graph_config["edge_dim"] if dataset[0]["y"].ndim == 0: output_dim = 1 else: output_dim = dataset[0]["y"].shape[1] # Determine if this is a node or graph level model - if dataset[0]["y"].shape[0] == dataset[0]["x"].shape[0]: + if dataset[0]["y"].shape[0] == dataset[0]["z"].shape[0]: model_config["prediction_level"] = "node" elif dataset[0]["y"].shape[0] == 1: model_config["prediction_level"] = "graph" @@ -283,7 +305,6 @@ def _load_model(model_config, graph_config, dataset, world_size, rank): output_dim=output_dim, cutoff_radius=graph_config["cutoff_radius"], n_neighbors=graph_config["n_neighbors"], - edge_steps=graph_config["edge_steps"], graph_method=graph_config["edge_calc_method"], num_offsets=graph_config["num_offsets"], **model_config @@ -341,19 +362,19 @@ def _load_dataloader(optim_config, dataset_config, dataset, sampler, run_mode): batch_size = optim_config.get("batch_size") if dataset.get("train"): data_loader["train_loader"] = get_dataloader( - dataset["train"], batch_size=batch_size, sampler=sampler + dataset["train"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=sampler ) if dataset.get("val"): data_loader["val_loader"] = get_dataloader( - dataset["val"], batch_size=batch_size, sampler=None + dataset["val"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None ) if dataset.get("test"): data_loader["test_loader"] = get_dataloader( - dataset["test"], batch_size=batch_size, sampler=None + 
dataset["test"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None ) if run_mode == "predict" and dataset.get("predict"): data_loader["predict_loader"] = get_dataloader( - dataset["predict"], batch_size=batch_size, sampler=None + dataset["predict"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None ) return data_loader @@ -391,24 +412,26 @@ def validate(self): def predict(self): """Implemented by derived classes.""" - def update_best_model(self, metric): + def update_best_model(self, metric, write_model=False, write_csv=False): """Updates the best val metric and model, saves the best model, and saves the best model predictions""" self.best_metric = metric[type(self.loss_fn).__name__]["metric"] if str(self.rank) not in ("cpu", "cuda"): self.best_model_state = copy.deepcopy(self.model.module.state_dict()) else: self.best_model_state = copy.deepcopy(self.model.state_dict()) - self.save_model("best_checkpoint.pt", metric, True) - - logging.debug( - f"Saving prediction results for epoch {self.epoch} to: /results/{self.timestamp_id}/train_results/" - ) - if "train" in self.write_output: - self.predict(self.data_loader["train_loader"], "train") - if "val" in self.write_output and self.data_loader.get("val_loader"): - self.predict(self.data_loader["val_loader"], "val") - if "test" in self.write_output and self.data_loader.get("test_loader"): - self.predict(self.data_loader["test_loader"], "test") + if write_model == True: + self.save_model("best_checkpoint.pt", metric, True) + + if write_csv == True: + logging.debug( + f"Saving prediction results for epoch {self.epoch} to: /results/{self.timestamp_id}/train_results/" + ) + if "train" in self.write_output: + self.predict(self.data_loader["train_loader"], "train") + if "val" in self.write_output and self.data_loader.get("val_loader"): + self.predict(self.data_loader["val_loader"], "val") + if "test" in self.write_output and 
self.data_loader.get("test_loader"): + self.predict(self.data_loader["test_loader"], "test") def save_model(self, checkpoint_file, metric=None, training_state=True): """Saves the model state dict""" diff --git a/matdeeplearn/trainers/base_trainer.py.old b/matdeeplearn/trainers/old/base_trainer.py.old similarity index 100% rename from matdeeplearn/trainers/base_trainer.py.old rename to matdeeplearn/trainers/old/base_trainer.py.old diff --git a/matdeeplearn/trainers/property_trainer.py.old b/matdeeplearn/trainers/old/property_trainer.py.old similarity index 100% rename from matdeeplearn/trainers/property_trainer.py.old rename to matdeeplearn/trainers/old/property_trainer.py.old diff --git a/matdeeplearn/trainers/property_trainer.py.old2 b/matdeeplearn/trainers/old/property_trainer.py.old2 similarity index 100% rename from matdeeplearn/trainers/property_trainer.py.old2 rename to matdeeplearn/trainers/old/property_trainer.py.old2 diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py index 92439015..3fd71839 100644 --- a/matdeeplearn/trainers/property_trainer.py +++ b/matdeeplearn/trainers/property_trainer.py @@ -31,6 +31,8 @@ def __init__( verbosity, batch_tqdm, write_output, + output_frequency, + model_save_frequency, save_dir, checkpoint_path, use_amp, @@ -50,6 +52,8 @@ def __init__( verbosity, batch_tqdm, write_output, + output_frequency, + model_save_frequency, save_dir, checkpoint_path, use_amp, @@ -124,7 +128,8 @@ def train(self): # Save current model torch.cuda.empty_cache() if str(self.rank) in ("0", "cpu", "cuda"): - self.save_model(checkpoint_file="checkpoint.pt", training_state=True) + if self.model_save_frequency == 1: + self.save_model(checkpoint_file="checkpoint.pt", training_state=True) # Evaluate on validation set if it exists if self.data_loader.get("val_loader"): @@ -143,8 +148,16 @@ def train(self): # Update best val metric and model, and save best model and predicted outputs if 
metric[type(self.loss_fn).__name__]["metric"] < self.best_metric: - self.update_best_model(metric) - + if self.output_frequency == 0: + if self.model_save_frequency == 1: + self.update_best_model(metric, write_model=True, write_csv=False) + else: + self.update_best_model(metric, write_model=False, write_csv=False) + elif self.output_frequency == 1: + if self.model_save_frequency == 1: + self.update_best_model(metric, write_model=True, write_csv=True) + else: + self.update_best_model(metric, write_model=False, write_csv=True) # step scheduler, using validation error self._scheduler_step() @@ -155,14 +168,21 @@ def train(self): self.model.module.load_state_dict(self.best_model_state) elif str(self.rank) in ("cpu", "cuda"): self.model.load_state_dict(self.best_model_state) - - if self.data_loader.get("test_loader"): - metric = self.validate("test") - test_loss = metric[type(self.loss_fn).__name__]["metric"] - else: - test_loss = "N/A" - logging.info("Test loss: " + str(test_loss)) - + #if self.data_loader.get("test_loader"): + # metric = self.validate("test") + # test_loss = metric[type(self.loss_fn).__name__]["metric"] + #else: + # test_loss = "N/A" + if self.model_save_frequency != -1: + self.save_model("best_checkpoint.pt", metric, True) + logging.info("Final Losses: ") + if "train" in self.write_output: + self.predict(self.data_loader["train_loader"], "train") + if "val" in self.write_output and self.data_loader.get("val_loader"): + self.predict(self.data_loader["val_loader"], "val") + if "test" in self.write_output and self.data_loader.get("test_loader"): + self.predict(self.data_loader["test_loader"], "test") + return self.best_model_state @torch.no_grad() @@ -314,7 +334,31 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, torch.cuda.empty_cache() - return predictions + return predictions + + def predict_by_calculator(self, loader): + self.model.eval() + + assert isinstance(loader, torch.utils.data.dataloader.DataLoader) + 
assert len(loader) == 1, f"Predicting by calculator only allows one structure at a time, but got {len(loader)} structures." + + if str(self.rank) not in ("cpu", "cuda"): + loader = get_dataloader( + loader.dataset, batch_size=loader.batch_size, sampler=None + ) + + results = [] + loader_iter = iter(loader) + for i in range(0, len(loader_iter)): + batch = next(loader_iter).to(self.rank) + out = self._forward(batch.to(self.rank)) + + energy = None if out.get('output') is None else out.get('output').data.cpu().numpy() + stress = None if out.get('cell_grad') is None else out.get('cell_grad').view(-1, 3).data.cpu().numpy() + forces = None if out.get('pos_grad') is None else out.get('pos_grad').data.cpu().numpy() + + results = {'energy': energy, 'stress': stress, 'forces': forces} + return results def _forward(self, batch_data): output = self.model(batch_data) diff --git a/scripts/main.py b/scripts/main.py index 768df8cb..72e3516d 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -1,6 +1,9 @@ import logging import pprint import os +import sys +import shutil +from datetime import datetime from torch import distributed as dist from matdeeplearn.common.config.build_config import build_config from matdeeplearn.common.config.flags import flags @@ -17,10 +20,11 @@ def __init__(self): self.config = None def __call__(self, config): + with new_trainer_context(args=args, config=config) as ctx: self.config = ctx.config self.task = ctx.task - self.trainer = ctx.trainer + self.trainer = ctx.trainer self.task.setup(self.trainer) @@ -29,6 +33,8 @@ def __call__(self, config): logging.debug(pprint.pformat(self.config)) self.task.run() + + shutil.move('log_'+config["task"]["log_id"]+'.txt', os.path.join(self.trainer.save_dir, "results", self.trainer.timestamp_id, "log.txt")) def checkpoint(self, *args, **kwargs): # new_runner = Runner() @@ -49,9 +55,22 @@ def checkpoint(self, *args, **kwargs): root_logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) + timestamp = 
datetime.now().timestamp() + timestamp_id = datetime.fromtimestamp(timestamp).strftime( + "%Y-%m-%d-%H-%M-%S-%f" + )[:-3] + fh = logging.FileHandler('log_'+timestamp_id+'.txt', 'w+') + fh.setLevel(logging.DEBUG) + root_logger.addHandler(fh) + + sh = logging.StreamHandler(sys.stdout) + sh.setLevel(logging.DEBUG) + root_logger.addHandler(sh) + parser = flags.get_parser() args, override_args = parser.parse_known_args() config = build_config(args, override_args) + config["task"]["log_id"] = timestamp_id if not config["dataset"]["processed"]: process_data(config["dataset"]) @@ -62,3 +81,4 @@ def checkpoint(self, *args, **kwargs): else: # Run locally Runner()(config) + diff --git a/test/configs/cpu/test_predict.yml b/test/configs/cpu/test_predict.yml index 2c1b1467..f6efe6f7 100644 --- a/test/configs/cpu/test_predict.yml +++ b/test/configs/cpu/test_predict.yml @@ -16,8 +16,13 @@ task: checkpoint_path: test/saved_models/cgcnn.pt # E.g. ["train", "val", "test"] write_output: [train, val, test] + # Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time + output_frequency: 0 + # Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint + model_save_frequency: 0 # Specify if labels are provided for the predict task # labels: True + # Use amp mixed precision use_amp: True model: @@ -33,9 +38,12 @@ model: batch_norm: True batch_track_stats: True act: relu - dropout_rate: 0.0 - # Compute edge features on the fly - otf_edge: False + # Compute edge indices on the fly in the model forward + otf_edge_index: False + # Compute edge attributes on the fly in the model forward + otf_edge_attr: False + # Compute node attributes on the fly in the model forward + otf_node_attr: False # compute gradients w.r.t to positions and cell, requires otf_edge=True gradient: False @@ -47,8 +55,8 @@ optim: loss: loss_type: TorchLossWrapper loss_args: 
{loss_fn: l1_loss} - clip_grad_norm: 10 - + # gradient clipping value + clip_grad_norm: 10 batch_size: 100 optimizer: optimizer_type: AdamW @@ -81,7 +89,7 @@ dataset: # For example, an index: 0 (default) will use the first entry in the target vector # if all values are to be predicted simultaneously, then specify index: -1 index: -1 - otf: True # Optional parameter, default is False + otf_transform: True # Optional parameter, default is False # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) data_format: json # E.g. additional_attributes: [forces, stress] @@ -99,15 +107,21 @@ dataset: # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly preprocess_edge_features: True # determine if node attributes are computed during processing, if false, then they need to be computed on the fly - preprocess_nodes: True + preprocess_node_features: True cutoff_radius : 8.0 n_neighbors : 250 num_offsets: 2 - edge_steps : 50 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 50 self_loop: True # Method of obtaining atom dictionary: available: (onehot) node_representation: onehot - all_neighbors: True + # Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cuda # Ratios for train/val/test split out of a total of less than 1 train_ratio: 0.8 val_ratio: 0.05 diff --git a/test/configs/cpu/test_training.yml b/test/configs/cpu/test_training.yml index 60dac406..a4d38b74 100644 --- a/test/configs/cpu/test_training.yml +++ b/test/configs/cpu/test_training.yml @@ -16,8 +16,13 @@ task: checkpoint_path: # E.g. 
["train", "val", "test"] write_output: [train, val, test] + # Frequency of writing to file; 0 denotes writing only at the end, 1 denotes writing every time + output_frequency: 0 + # Frequency of saving model .pt file; 0 denotes saving only at the end, 1 denotes saving every time, -1 denotes never saving; this controls both checkpoint and best_checkpoint + model_save_frequency: 0 # Specify if labels are provided for the predict task # labels: True + # Use amp mixed precision use_amp: True model: @@ -33,9 +38,12 @@ model: batch_norm: True batch_track_stats: True act: relu - dropout_rate: 0.0 - # Compute edge features on the fly - otf_edge: False + # Compute edge indices on the fly in the model forward + otf_edge_index: False + # Compute edge attributes on the fly in the model forward + otf_edge_attr: False + # Compute node attributes on the fly in the model forward + otf_node_attr: False # compute gradients w.r.t to positions and cell, requires otf_edge=True gradient: False @@ -47,8 +55,8 @@ optim: loss: loss_type: TorchLossWrapper loss_args: {loss_fn: l1_loss} - clip_grad_norm: 10 - + # gradient clipping value + clip_grad_norm: 10 batch_size: 100 optimizer: optimizer_type: AdamW @@ -81,7 +89,7 @@ dataset: # For example, an index: 0 (default) will use the first entry in the target vector # if all values are to be predicted simultaneously, then specify index: -1 index: -1 - otf: True # Optional parameter, default is False + otf_transform: True # Optional parameter, default is False # Format of data files (limit to those supported by ASE: https://wiki.fysik.dtu.dk/ase/ase/io/io.html) data_format: json # E.g. 
additional_attributes: [forces, stress] @@ -99,15 +107,21 @@ dataset: # determine if edge attributes are computed during processing, if false, then they need to be computed on the fly preprocess_edge_features: True # determine if node attributes are computed during processing, if false, then they need to be computed on the fly - preprocess_nodes: True + preprocess_node_features: True cutoff_radius : 8.0 n_neighbors : 250 num_offsets: 2 - edge_steps : 50 + # dimension of node attributes + node_dim : 100 + # dimension of edge attributes + edge_dim : 50 self_loop: True # Method of obtaining atom dictionary: available: (onehot) node_representation: onehot - all_neighbors: True + # Number of workers for dataloader, see https://pytorch.org/docs/stable/data.html + num_workers: 0 + # Where the dataset is loaded; either "cpu" or "cuda" + dataset_device: cuda # Ratios for train/val/test split out of a total of less than 1 train_ratio: 0.8 val_ratio: 0.05