6 changes: 4 additions & 2 deletions configs/config.yml
@@ -45,11 +45,13 @@ model:
  otf_edge_attr: False
  # Compute node attributes on the fly in the model forward
  otf_node_attr: False
+  # 1 indicates normal behavior, larger numbers indicate the number of models to be used
+  model_ensemble: 5
  # compute gradients w.r.t to positions and cell, requires otf_edge_attr=True
  gradient: False

optim:
-  max_epochs: 40
+  max_epochs: 20
  max_checkpoint_epochs: 0
  lr: 0.001
  # Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
@@ -130,4 +132,4 @@ dataset:
  # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data)
  train_ratio: 0.8
  val_ratio: 0.05
-  test_ratio: 0.15
+  test_ratio: 1
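
The new `model_ensemble` key sets how many models are trained and used together. A minimal sketch of how such a count might be consumed at prediction time, averaging over independently trained members; `models` and `batch` are illustrative placeholders, not objects defined in this PR:

import torch

def ensemble_predict(models, batch):
    # Stack the per-model outputs and average over the ensemble dimension
    with torch.no_grad():
        preds = torch.stack([model(batch) for model in models], dim=0)
    return preds.mean(dim=0)
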
Binary file added data/data.pt
Binary file not shown.
84 changes: 84 additions & 0 deletions main.py
@@ -0,0 +1,84 @@
import logging
import pprint
import os
import sys
import shutil
from datetime import datetime
from torch import distributed as dist
from matdeeplearn.common.config.build_config import build_config
from matdeeplearn.common.config.flags import flags
from matdeeplearn.common.trainer_context import new_trainer_context
from matdeeplearn.preprocessor.processor import process_data

# import submitit

# from matdeeplearn.common.utils import setup_logging


class Runner:  # submitit.helpers.Checkpointable):
    def __init__(self):
        self.config = None

    def __call__(self, config):
        # `args` is the module-level namespace parsed in the __main__ block below
        with new_trainer_context(args=args, config=config) as ctx:
            self.config = ctx.config
            self.task = ctx.task
            self.trainer = ctx.trainer

            self.task.setup(self.trainer)

            # Print settings for job
            logging.debug("Settings: ")
            logging.debug(pprint.pformat(self.config))

            self.task.run()

        # Move the per-run log file into this run's results directory
        shutil.move(
            'log_' + config["task"]["log_id"] + '.txt',
            os.path.join(self.trainer.save_dir, "results", self.trainer.timestamp_id, "log.txt"),
        )

    def checkpoint(self, *args, **kwargs):
        # Called on preemption: persist trainer state so the job can be resubmitted
        # new_runner = Runner()
        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = self.trainer.timestamp_id
        if self.trainer.logger is not None:
            self.trainer.logger.mark_preempting()
        # return submitit.helpers.DelayedSubmission(new_runner, self.config)


if __name__ == "__main__":

    # setup_logging()
    # timestamp_id is written into the config below on every rank, so it is
    # computed outside the rank-0-only logging setup
    timestamp = datetime.now().timestamp()
    timestamp_id = datetime.fromtimestamp(timestamp).strftime(
        "%Y-%m-%d-%H-%M-%S-%f"
    )[:-3]

    local_rank = os.environ.get('LOCAL_RANK', None)
    if local_rank is None or int(local_rank) == 0:
        root_logger = logging.getLogger()
        root_logger.setLevel(logging.DEBUG)

        fh = logging.FileHandler('log_' + timestamp_id + '.txt', 'w+')
        fh.setLevel(logging.DEBUG)
        root_logger.addHandler(fh)

        sh = logging.StreamHandler(sys.stdout)
        sh.setLevel(logging.DEBUG)
        root_logger.addHandler(sh)

    parser = flags.get_parser()
    args, override_args = parser.parse_known_args()
    config = build_config(args, override_args)
    config["task"]["log_id"] = timestamp_id

    if not config["dataset"]["processed"]:
        process_data(config["dataset"])

    if args.submit:  # Run on cluster
        # TODO: add setup to submit to cluster
        pass

    else:  # Run locally
        Runner()(config)
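
The commented-out submitit hooks above (`submitit.helpers.Checkpointable` and `DelayedSubmission`) indicate how the empty `args.submit` branch could be completed. A minimal sketch, assuming submitit is installed and `Runner` actually inherits from `submitit.helpers.Checkpointable` so the scheduler calls `checkpoint()` on preemption; the folder and partition names are placeholders, not values from this PR:

import submitit

executor = submitit.AutoExecutor(folder="submitit_logs")
executor.update_parameters(timeout_min=60, slurm_partition="gpu")  # placeholder cluster settings
job = executor.submit(Runner(), config)  # on preemption, submitit calls Runner.checkpoint()
print(job.job_id)

For local runs the script is invoked directly, e.g. `python main.py --config_path configs/config.yml` (the exact flag names are defined in matdeeplearn/common/config/flags.py).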

19 changes: 14 additions & 5 deletions matdeeplearn/tasks/task.py
@@ -29,7 +29,7 @@ def _process_error(self, e: RuntimeError):
        )

    def setup(self, trainer):
-        self.trainer = trainer
+        self.trainer = trainer
        use_checkpoint = self.config["task"].get("continue_job", False)
        if use_checkpoint:
            logging.info("Attempting to load checkpoint...")
@@ -62,14 +62,24 @@ def setup(self, trainer):
        logging.info("Recent checkpoint loaded successfully.")

    def run(self):
+        # if isinstance(self.trainer.data_loader, list):
        assert (
-            self.trainer.data_loader.get("predict_loader") is not None
-        ), "Predict dataset is required for making predictions"
+            self.trainer.data_loader[0].get("predict_loader") is not None
+        ), "Predict dataset is required for making predictions"
+        # else:
+        #     assert (
+        #         self.trainer.data_loader.get("predict_loader") is not None
+        #     ), "Predict dataset is required for making predictions"
        results_dir = f"predictions/{self.config['dataset']['name']}"
        try:
+            # if isinstance(self.trainer.data_loader, list):
            self.trainer.predict(
-                loader=self.trainer.data_loader["predict_loader"], split="predict", results_dir=results_dir, labels=self.config["task"]["labels"],
+                loader=self.trainer.data_loader, split="predict", results_dir=results_dir, labels=self.config["task"]["labels"],
            )
+            # else:
+            #     self.trainer.predict(
+            #         loader=self.trainer.data_loader["predict_loader"], split="predict", results_dir=results_dir, labels=self.config["task"]["labels"],
+            #     )
        except RuntimeError as e:
            logging.warning("Errors in predict task")
            raise e
@@ -93,4 +103,3 @@ def run(self):
        except RuntimeError as e:
            self._process_error(e)
            raise e
-
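
The switch from `self.trainer.data_loader.get("predict_loader")` to `self.trainer.data_loader[0].get("predict_loader")` implies that `data_loader` is now a list, presumably one entry per ensemble member. A minimal sketch of the implied shape, with illustrative names only (Ellipsis stands in for real DataLoader objects):

num_models = 5  # matches model_ensemble in configs/config.yml

data_loader = [
    {"train_loader": ..., "val_loader": ..., "predict_loader": ...}
    for _ in range(num_models)
]

# The assertion in run() inspects only the first member
assert data_loader[0].get("predict_loader") is not None

Note that the predict call now passes the whole `self.trainer.data_loader` list rather than a single loader, so the trainer's `predict` is evidently expected to iterate over the members itself.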