diff --git a/configs/config.yml b/configs/config.yml
index e6ec15fc..2d536d7b 100644
--- a/configs/config.yml
+++ b/configs/config.yml
@@ -45,11 +45,13 @@ model:
   otf_edge_attr: False
   # Compute node attributes on the fly in the model forward
   otf_node_attr: False
+  # 1 indicates normal single-model behavior; larger values set the number of ensemble models to train
+  model_ensemble: 5
   # compute gradients w.r.t. positions and cell, requires otf_edge_attr=True
   gradient: False
 optim:
-  max_epochs: 40
+  max_epochs: 20
   max_checkpoint_epochs: 0
   lr: 0.001
   # Either custom or from the torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
@@ -130,4 +132,4 @@ dataset:
   # Ratios for train/val/test split out of a total of less than 1 (0.8 corresponds to 80% of the data)
   train_ratio: 0.8
   val_ratio: 0.05
-  test_ratio: 0.15
\ No newline at end of file
+  test_ratio: 1
diff --git a/data/data.pt b/data/data.pt
new file mode 100644
index 00000000..1675b93d
Binary files /dev/null and b/data/data.pt differ
diff --git a/main.py b/main.py
new file mode 100644
index 00000000..72e3516d
--- /dev/null
+++ b/main.py
@@ -0,0 +1,84 @@
+import logging
+import pprint
+import os
+import sys
+import shutil
+from datetime import datetime
+from torch import distributed as dist
+from matdeeplearn.common.config.build_config import build_config
+from matdeeplearn.common.config.flags import flags
+from matdeeplearn.common.trainer_context import new_trainer_context
+from matdeeplearn.preprocessor.processor import process_data
+
+# import submitit
+
+# from matdeeplearn.common.utils import setup_logging
+
+
+class Runner:  # submitit.helpers.Checkpointable
+    def __init__(self):
+        self.config = None
+
+    def __call__(self, config):
+        with new_trainer_context(args=args, config=config) as ctx:
+            self.config = ctx.config
+            self.task = ctx.task
+            self.trainer = ctx.trainer
+
+            self.task.setup(self.trainer)
+
+            # Print settings for job
+            logging.debug("Settings: ")
+            logging.debug(pprint.pformat(self.config))
+
+            self.task.run()
+
+        shutil.move(
+            'log_' + config["task"]["log_id"] + '.txt',
+            os.path.join(self.trainer.save_dir, "results", self.trainer.timestamp_id, "log.txt"),
+        )
+
+    def checkpoint(self, *args, **kwargs):
+        # new_runner = Runner()
+        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
+        self.config["checkpoint"] = self.task.chkpt_path
+        self.config["timestamp_id"] = self.trainer.timestamp_id
+        if self.trainer.logger is not None:
+            self.trainer.logger.mark_preempting()
+        # return submitit.helpers.DelayedSubmission(new_runner, self.config)
+
+
+if __name__ == "__main__":
+    # setup_logging()
+    # Create the timestamp id on every rank so config["task"]["log_id"] is always defined
+    timestamp = datetime.now().timestamp()
+    timestamp_id = datetime.fromtimestamp(timestamp).strftime(
+        "%Y-%m-%d-%H-%M-%S-%f"
+    )[:-3]
+
+    local_rank = os.environ.get('LOCAL_RANK', None)
+    if local_rank is None or int(local_rank) == 0:
+        root_logger = logging.getLogger()
+        root_logger.setLevel(logging.DEBUG)
+
+        fh = logging.FileHandler('log_' + timestamp_id + '.txt', 'w+')
+        fh.setLevel(logging.DEBUG)
+        root_logger.addHandler(fh)
+
+        sh = logging.StreamHandler(sys.stdout)
+        sh.setLevel(logging.DEBUG)
+        root_logger.addHandler(sh)
+
+    parser = flags.get_parser()
+    args, override_args = parser.parse_known_args()
+    config = build_config(args, override_args)
+    config["task"]["log_id"] = timestamp_id
+
+    if not config["dataset"]["processed"]:
+        process_data(config["dataset"])
+
+    if args.submit:  # Run on cluster
+        # TODO: add setup to submit to cluster
+        pass
+    else:  # Run locally
+        Runner()(config)
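The new model_ensemble option turns the single-model pipeline into a deep ensemble: several identically configured models are trained from different random initializations, and prediction reports their mean alongside the standard deviation as an uncertainty estimate. A minimal sketch of that aggregation, using stand-in linear models rather than anything from this patch:

    import torch

    # Five stand-in ensemble members; the real code builds MatDeepLearn models.
    members = [torch.nn.Linear(4, 1) for _ in range(5)]

    batch = torch.randn(8, 4)  # 8 structures, 4 features each
    with torch.no_grad():
        stacked = torch.stack([m(batch) for m in members])  # shape (5, 8, 1)
    mean = stacked.mean(dim=0)  # ensemble prediction
    std = stacked.std(dim=0)    # per-structure uncertainty
    print(mean.shape, std.shape)  # torch.Size([8, 1]) torch.Size([8, 1])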
diff --git a/matdeeplearn/tasks/task.py b/matdeeplearn/tasks/task.py
index 4b32794f..6a48172e 100644
--- a/matdeeplearn/tasks/task.py
+++ b/matdeeplearn/tasks/task.py
@@ -29,7 +29,7 @@ def _process_error(self, e: RuntimeError):
         )
 
     def setup(self, trainer):
-        self.trainer = trainer 
+        self.trainer = trainer
         use_checkpoint = self.config["task"].get("continue_job", False)
         if use_checkpoint:
             logging.info("Attempting to load checkpoint...")
@@ -62,14 +62,24 @@ def setup(self, trainer):
             logging.info("Recent checkpoint loaded successfully.")
 
     def run(self):
+        # if isinstance(self.trainer.data_loader, list):
         assert (
-            self.trainer.data_loader.get("predict_loader") is not None
-        ), "Predict dataset is required for making predictions"
+            self.trainer.data_loader[0].get("predict_loader") is not None
+        ), "Predict dataset is required for making predictions"
+        # else:
+        #     assert (
+        #         self.trainer.data_loader.get("predict_loader") is not None
+        #     ), "Predict dataset is required for making predictions"
         results_dir = f"predictions/{self.config['dataset']['name']}"
         try:
+            # if isinstance(self.trainer.data_loader, list):
+            # The trainer's predict() expects a single DataLoader and runs every
+            # ensemble member on each batch, so pass the first predict loader
             self.trainer.predict(
-                loader=self.trainer.data_loader["predict_loader"], split="predict", results_dir=results_dir, labels=self.config["task"]["labels"],
+                loader=self.trainer.data_loader[0]["predict_loader"], split="predict", results_dir=results_dir, labels=self.config["task"]["labels"],
             )
+            # else:
+            #     self.trainer.predict(
+            #         loader=self.trainer.data_loader["predict_loader"], split="predict", results_dir=results_dir, labels=self.config["task"]["labels"],
+            #     )
         except RuntimeError as e:
             logging.warning("Errors in predict task")
             raise e
@@ -93,4 +103,3 @@ def run(self):
         except RuntimeError as e:
             self._process_error(e)
             raise e
-
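The base_trainer.py changes below build one model per ensemble member, drawing a fresh random seed for each so that weight initializations differ across members. A small sketch of that effect, with an illustrative stand-in model:

    import random

    import numpy as np
    import torch

    def seeded_member(seed: int) -> torch.nn.Module:
        # Mirrors the per-member seeding in _load_model: each member gets its
        # own seed, so its initial weights differ from its siblings'.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        return torch.nn.Linear(4, 1)

    a, b = seeded_member(1), seeded_member(2)
    print(torch.allclose(a.weight, b.weight))  # False: distinct initializations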
diff --git a/matdeeplearn/trainers/base_trainer.py b/matdeeplearn/trainers/base_trainer.py
index 4a5cb6a5..318d31b1 100644
--- a/matdeeplearn/trainers/base_trainer.py
+++ b/matdeeplearn/trainers/base_trainer.py
@@ -68,10 +68,11 @@ def __init__(
         self.epoch = 0
         self.step = 0
-        self.metrics = {}
         self.epoch_time = None
-        self.best_metric = 1e10
-        self.best_model_state = None
+
+        self.metrics = [{} for _ in range(len(self.model))]
+        self.best_metric = [1e10 for _ in range(len(self.model))]
+        self.best_model_state = [None for _ in range(len(self.model))]
 
         self.save_dir = save_dir if save_dir else os.getcwd()
         self.checkpoint_path = checkpoint_path
@@ -113,9 +114,9 @@ def __init__(
             logging.debug(self.dataset[list(self.dataset.keys())[0]][0].y[0])
 
         if str(self.rank) not in ("cpu", "cuda"):
-            logging.debug(self.model.module)
+            logging.debug(self.model[0].module)
         else:
-            logging.debug(self.model)
+            logging.debug(self.model[0])
 
     @classmethod
     def from_config(cls, config):
@@ -153,6 +154,7 @@ def from_config(cls, config):
             dataset,
             sampler,
             config["task"]["run_mode"],
+            config["model"]
         )
         scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer)
@@ -276,49 +278,64 @@ def _load_model(model_config, graph_config, dataset, world_size, rank):
         if isinstance(dataset, torch.utils.data.Subset):
             dataset = dataset.dataset
+        model_list = []
         # Obtain node, edge, and output dimensions for model initialization
-
-        if graph_config["node_dim"]:
-            node_dim = graph_config["node_dim"]
-        else:
-            node_dim = dataset.num_features
-        edge_dim = graph_config["edge_dim"]
-        if dataset[0]["y"].ndim == 0:
-            output_dim = 1
-        else:
-            output_dim = dataset[0]["y"].shape[1]
+        for mod in range(model_config["model_ensemble"]):
+            # Re-seed before building each member so initializations differ
+            rand_seed = random.randint(1, 10000)
+            random.seed(rand_seed)
+            np.random.seed(rand_seed)
+            torch.manual_seed(rand_seed)
+            torch.cuda.manual_seed_all(rand_seed)
+            # torch.autograd.set_detect_anomaly(True)
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+            if graph_config["node_dim"]:
+                node_dim = graph_config["node_dim"]
+            else:
+                node_dim = dataset.num_features
+            edge_dim = graph_config["edge_dim"]
+            if dataset[0]["y"].ndim == 0:
+                output_dim = 1
+            else:
+                output_dim = dataset[0]["y"].shape[1]
 
-        # Determine if this is a node or graph level model
-        if dataset[0]["y"].shape[0] == dataset[0]["z"].shape[0]:
-            model_config["prediction_level"] = "node"
-        elif dataset[0]["y"].shape[0] == 1:
-            model_config["prediction_level"] = "graph"
-        else:
-            raise ValueError(
-                "Target labels do not have the correct dimensions for node or graph-level prediction."
-            )
+            # Determine if this is a node- or graph-level model
+            if dataset[0]["y"].shape[0] == dataset[0]["z"].shape[0]:
+                model_config["prediction_level"] = "node"
+            elif dataset[0]["y"].shape[0] == 1:
+                model_config["prediction_level"] = "graph"
+            else:
+                raise ValueError(
+                    "Target labels do not have the correct dimensions for node or graph-level prediction."
+                )
 
-        model_cls = registry.get_model_class(model_config["name"])
-        model = model_cls(
-            node_dim=node_dim,
-            edge_dim=edge_dim,
-            output_dim=output_dim,
-            cutoff_radius=graph_config["cutoff_radius"],
-            n_neighbors=graph_config["n_neighbors"],
-            graph_method=graph_config["edge_calc_method"],
-            num_offsets=graph_config["num_offsets"],
-            **model_config
-        )
-        model = model.to(rank)
-        # model = torch_geometric.compile(model)
-        # if model_config["load_model"] == True:
-        #     checkpoint = torch.load(model_config["model_path"])
-        #     model.load_state_dict(checkpoint["state_dict"])
-        if world_size > 1:
-            model = DistributedDataParallel(
-                model, device_ids=[rank], find_unused_parameters=False
-            )
-        return model
+            model_cls = registry.get_model_class(model_config["name"])
+            model = model_cls(
+                node_dim=node_dim,
+                edge_dim=edge_dim,
+                output_dim=output_dim,
+                cutoff_radius=graph_config["cutoff_radius"],
+                n_neighbors=graph_config["n_neighbors"],
+                graph_method=graph_config["edge_calc_method"],
+                num_offsets=graph_config["num_offsets"],
+                **model_config
+            )
+            model = model.to(rank)
+
+            # model = torch_geometric.compile(model)
+            # if model_config["load_model"] == True:
+            #     checkpoint = torch.load(model_config["model_path"])
+            #     model.load_state_dict(checkpoint["state_dict"])
+
+            if world_size > 1:
+                model = DistributedDataParallel(
+                    model, device_ids=[rank], find_unused_parameters=False
+                )
+
+            model_list.append(model)
+
+        return model_list
 
     @staticmethod
     def _load_optimizer(optim_config, model, world_size):
@@ -328,15 +345,19 @@ def _load_optimizer(optim_config, model, world_size):
         # Some discussions here:
         # https://github.com/Lightning-AI/lightning/discussions/3706
         # https://discuss.pytorch.org/t/should-we-split-batch-size-according-to-ngpu-per-node-when-distributeddataparallel/72769/15
+        # Scale the lr once, before the per-member loop, so the shared config is not rescaled per member
         if world_size > 1:
             optim_config["lr"] = optim_config["lr"] * world_size
 
-        optimizer = getattr(optim, optim_config["optimizer"]["optimizer_type"])(
-            model.parameters(),
-            lr=optim_config["lr"],
-            **optim_config["optimizer"].get("optimizer_args", {}),
-        )
-        return optimizer
+        optim_list = []
+        for i in range(len(model)):
+            optimizer = getattr(optim, optim_config["optimizer"]["optimizer_type"])(
+                model[i].parameters(),
+                lr=optim_config["lr"],
+                **optim_config["optimizer"].get("optimizer_args", {}),
+            )
+            optim_list.append(optimizer)
+
+        return optim_list
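The distributed learning-rate scaling has to happen once, outside the per-member loop: optim_config is a shared dict, so rescaling it inside the loop would compound the factor once per ensemble member. A quick arithmetic check with illustrative values:

    # lr scaling with world_size = 4 and a 5-member ensemble
    lr, world_size, n_members = 0.001, 4, 5
    inside_loop = lr * world_size ** n_members  # 1.024  (compounded 5 times)
    outside_loop = lr * world_size              # 0.004  (intended scaling)
    print(inside_loop, outside_loop)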
 
     @staticmethod
     def _load_sampler(optim_config, dataset, world_size, rank):
@@ -357,24 +378,27 @@ def _load_sampler(optim_config, dataset, world_size, rank):
         return sampler
 
     @staticmethod
-    def _load_dataloader(optim_config, dataset_config, dataset, sampler, run_mode):
-        data_loader = {}
+    def _load_dataloader(optim_config, dataset_config, dataset, sampler, run_mode, model_config):
+        data_loader = [{} for _ in range(model_config["model_ensemble"])]
+
         batch_size = optim_config.get("batch_size")
-        if dataset.get("train"):
-            data_loader["train_loader"] = get_dataloader(
-                dataset["train"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=sampler
-            )
-        if dataset.get("val"):
-            data_loader["val_loader"] = get_dataloader(
-                dataset["val"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None
-            )
-        if dataset.get("test"):
-            data_loader["test_loader"] = get_dataloader(
-                dataset["test"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None
-            )
-        if run_mode == "predict" and dataset.get("predict"):
-            data_loader["predict_loader"] = get_dataloader(
-                dataset["predict"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None
-            )
+
+        for i in range(model_config["model_ensemble"]):
+            if dataset.get("train"):
+                data_loader[i]["train_loader"] = get_dataloader(
+                    dataset["train"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=sampler
+                )
+            if dataset.get("val"):
+                data_loader[i]["val_loader"] = get_dataloader(
+                    dataset["val"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None
+                )
+            if dataset.get("test"):
+                data_loader[i]["test_loader"] = get_dataloader(
+                    dataset["test"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None
+                )
+            if run_mode == "predict" and dataset.get("predict"):
+                data_loader[i]["predict_loader"] = get_dataloader(
+                    dataset["predict"], batch_size=batch_size, num_workers=dataset_config.get("num_workers", 0), sampler=None
+                )
 
         return data_loader
 
@@ -383,7 +407,10 @@ def _load_dataloader(optim_config, dataset_config, dataset, sampler, run_mode):
     def _load_scheduler(scheduler_config, optimizer):
         scheduler_type = scheduler_config["scheduler_type"]
         scheduler_args = scheduler_config["scheduler_args"]
-        scheduler = LRScheduler(optimizer, scheduler_type, scheduler_args)
+        scheduler = []
+        for i in range(len(optimizer)):
+            scheduler.append(LRScheduler(optimizer[i], scheduler_type, scheduler_args))
+
         return scheduler
 
     @staticmethod
@@ -412,72 +439,121 @@ def validate(self):
     def predict(self):
         """Implemented by derived classes."""
 
-    def update_best_model(self, metric, write_model=False, write_csv=False):
+    def update_best_model(self, metric, index=None, write_model=False, write_csv=False):
         """Updates the best val metric and model, saves the best model, and saves the best model predictions"""
-        self.best_metric = metric[type(self.loss_fn).__name__]["metric"]
+        self.best_metric[index] = metric[type(self.loss_fn).__name__]["metric"]
+
         if str(self.rank) not in ("cpu", "cuda"):
-            self.best_model_state = copy.deepcopy(self.model.module.state_dict())
+            self.best_model_state[index] = copy.deepcopy(self.model[index].module.state_dict())
         else:
-            self.best_model_state = copy.deepcopy(self.model.state_dict())
+            self.best_model_state[index] = copy.deepcopy(self.model[index].state_dict())
+
         if write_model == True:
-            self.save_model("best_checkpoint.pt", metric, True)
+            self.save_model("best_checkpoint.pt", index, metric, True)
         if write_csv == True:
             logging.debug(
                 f"Saving prediction results for epoch {self.epoch} to: /results/{self.timestamp_id}/train_results/"
             )
+
         if "train" in self.write_output:
-            self.predict(self.data_loader["train_loader"], "train")
-        if "val" in self.write_output and self.data_loader.get("val_loader"):
-            self.predict(self.data_loader["val_loader"], "val")
-        if "test" in self.write_output and self.data_loader.get("test_loader"):
-            self.predict(self.data_loader["test_loader"], "test")
+            self.predict(self.data_loader[index]["train_loader"], "train")
+        if "val" in self.write_output and self.data_loader[index].get("val_loader"):
+            self.predict(self.data_loader[index]["val_loader"], "val")
+        if "test" in self.write_output and self.data_loader[index].get("test_loader"):
+            self.predict(self.data_loader[index]["test_loader"], "test")
 
-    def save_model(self, checkpoint_file, metric=None, training_state=True):
+    def save_model(self, checkpoint_file, index=None, metric=None, training_state=True):
         """Saves the model state dict"""
 
-        if str(self.rank) not in ("cpu", "cuda"):
-            if training_state:
-                state = {
-                    "epoch": self.epoch,
-                    "step": self.step,
-                    "state_dict": self.model.module.state_dict(),
-                    "optimizer": self.optimizer.state_dict(),
-                    "scheduler": self.scheduler.scheduler.state_dict(),
-                    "scaler": self.scaler.state_dict(),
-                    "best_metric": self.best_metric,
-                    "identifier": self.timestamp_id,
-                    "seed": torch.random.initial_seed(),
-                }
-            else:
-                state = {"state_dict": self.model.module.state_dict(), "metric": metric}
-        else:
-            if training_state:
-                state = {
-                    "epoch": self.epoch,
-                    "step": self.step,
-                    "state_dict": self.model.state_dict(),
-                    "optimizer": self.optimizer.state_dict(),
-                    "scheduler": self.scheduler.scheduler.state_dict(),
-                    "scaler": self.scaler.state_dict(),
-                    "best_metric": self.best_metric,
-                    "identifier": self.timestamp_id,
-                    "seed": torch.random.initial_seed(),
-                }
-            else:
-                state = {"state_dict": self.model.state_dict(), "metric": metric}
-
-        curr_checkpt_dir = os.path.join(
-            self.save_dir, "results", self.timestamp_id, "checkpoint"
-        )
-        os.makedirs(curr_checkpt_dir, exist_ok=True)
-        filename = os.path.join(curr_checkpt_dir, checkpoint_file)
-
-        torch.save(state, filename)
-        del state
+        if index != None:
+            if str(self.rank) not in ("cpu", "cuda"):
+                if training_state:
+                    state = {
+                        "epoch": self.epoch,
+                        "step": self.step,
+                        "state_dict": self.model[index].module.state_dict(),
+                        "optimizer": self.optimizer[index].state_dict(),
+                        "scheduler": self.scheduler[index].scheduler.state_dict(),
+                        "scaler": self.scaler.state_dict(),
+                        "best_metric": self.best_metric[index],
+                        "identifier": self.timestamp_id,
+                        "seed": torch.random.initial_seed(),
+                    }
+                else:
+                    state = {"state_dict": self.model[index].module.state_dict(), "metric": metric}
+            else:
+                if training_state:
+                    state = {
+                        "epoch": self.epoch,
+                        "step": self.step,
+                        "state_dict": self.model[index].state_dict(),
+                        "optimizer": self.optimizer[index].state_dict(),
+                        "scheduler": self.scheduler[index].scheduler.state_dict(),
+                        "scaler": self.scaler.state_dict(),
+                        "best_metric": self.best_metric[index],
+                        "identifier": self.timestamp_id,
+                        "seed": torch.random.initial_seed(),
+                    }
+                else:
+                    state = {"state_dict": self.model[index].state_dict(), "metric": metric}
+
+            num = str(index)
+            curr_checkpt_dir = os.path.join(
+                self.save_dir, "results", self.timestamp_id, f"checkpoint_{num}"
+            )
+            os.makedirs(curr_checkpt_dir, exist_ok=True)
+            filename = os.path.join(curr_checkpt_dir, checkpoint_file)
+
+            torch.save(state, filename)
+            del state
+        else:
+            state = []
+            for i in range(len(self.model)):
+                if str(self.rank) not in ("cpu", "cuda"):
+                    if training_state:
+                        state.append({
+                            "epoch": self.epoch,
+                            "step": self.step,
+                            "state_dict": self.model[i].module.state_dict(),
+                            "optimizer": self.optimizer[i].state_dict(),
+                            "scheduler": self.scheduler[i].scheduler.state_dict(),
+                            "scaler": self.scaler.state_dict(),
+                            "best_metric": self.best_metric[i],
+                            "identifier": self.timestamp_id,
+                            "seed": torch.random.initial_seed(),
+                        })
+                    else:
+                        state.append({"state_dict": self.model[i].module.state_dict(), "metric": metric[i]})
+                else:
+                    if training_state:
+                        state.append({
+                            "epoch": self.epoch,
+                            "step": self.step,
+                            "state_dict": self.model[i].state_dict(),
+                            "optimizer": self.optimizer[i].state_dict(),
+                            "scheduler": self.scheduler[i].scheduler.state_dict(),
+                            "scaler": self.scaler.state_dict(),
+                            "best_metric": self.best_metric[i],
+                            "identifier": self.timestamp_id,
+                            "seed": torch.random.initial_seed(),
+                        })
+                    else:
+                        state.append({"state_dict": self.model[i].state_dict(), "metric": metric})
+
+            for x in range(len(self.model)):
+                num = str(x)
+                curr_checkpt_dir = os.path.join(
+                    self.save_dir, "results", self.timestamp_id, f"checkpoint_{num}"
+                )
+                os.makedirs(curr_checkpt_dir, exist_ok=True)
+                filename = os.path.join(curr_checkpt_dir, checkpoint_file)
+
+                torch.save(state[x], filename)
+            del state
 
         return filename
 
-    def save_results(self, output, results_dir, filename, node_level_predictions=False, labels=True):
+    def save_results(self, output, results_dir, filename, node_level_predictions=False, labels=True, std=False):
         results_path = os.path.join(
             self.save_dir, "results", self.timestamp_id, results_dir
         )
@@ -490,12 +566,20 @@ def save_results(self, output, results_dir, filename, node_level_predictions=False, labels=True):
             id_headers += ["node_id"]
 
         if labels == True:
-            num_cols = (shape[1] - len(id_headers)) // 2
-            headers = id_headers + ["target"] * num_cols + ["prediction"] * num_cols
+            if std == True:
+                num_cols = (shape[1] - len(id_headers)) // 3
+                headers = id_headers + ["target"] * num_cols + ["prediction"] * num_cols + ["std"] * num_cols
+            else:
+                num_cols = (shape[1] - len(id_headers)) // 2
+                headers = id_headers + ["target"] * num_cols + ["prediction"] * num_cols
         else:
-            num_cols = (shape[1] - len(id_headers))
-            headers = id_headers + ["prediction"] * num_cols
-
+            if std == True:
+                num_cols = (shape[1] - len(id_headers)) // 2
+                headers = id_headers + ["prediction"] * num_cols + ["std"] * num_cols
+            else:
+                num_cols = (shape[1] - len(id_headers))
+                headers = id_headers + ["prediction"] * num_cols
+
         with open(filename, "w") as f:
             csvwriter = csv.writer(f)
             for i in range(0, len(output)+1):
@@ -516,34 +600,35 @@ def load_checkpoint(self, load_training_state=True):
         # checkpoint_file = os.path.join(checkpoint_path, "checkpoint", "checkpoint.pt")
 
         # Load params from checkpoint
-        checkpoint = torch.load(self.checkpoint_path)
+        # For an ensemble, checkpoint_path is a comma-separated list of files, one per member
+        self.checkpoint_path = self.checkpoint_path.split(",")
+        checkpoint = [torch.load(i) for i in self.checkpoint_path]
 
-        if str(self.rank) not in ("cpu", "cuda"):
-            self.model.module.load_state_dict(checkpoint["state_dict"])
-            self.best_model_state = copy.deepcopy(self.model.module.state_dict())
-        else:
-            self.model.load_state_dict(checkpoint["state_dict"])
-            self.best_model_state = copy.deepcopy(self.model.state_dict())
-
-        if load_training_state == True:
-            if checkpoint.get("optimizer"):
-                self.optimizer.load_state_dict(checkpoint["optimizer"])
-            if checkpoint.get("scheduler"):
-                self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])
-                self.scheduler.update_lr()
-            if checkpoint.get("epoch"):
-                self.epoch = checkpoint["epoch"]
-            if checkpoint.get("step"):
-                self.step = checkpoint["step"]
-            if checkpoint.get("best_metric"):
-                self.best_metric = checkpoint["best_metric"]
-            if checkpoint.get("seed"):
-                seed = checkpoint["seed"]
-                self.set_seed(seed)
-            if checkpoint.get("scaler"):
-                self.scaler.load_state_dict(checkpoint["scaler"])
-        #todo: load dataset to recreate the same split as the prior run
-        #self._load_dataset(dataset_config, task)
+        for n, check in enumerate(checkpoint):
+            if str(self.rank) not in ("cpu", "cuda"):
+                self.model[n].module.load_state_dict(checkpoint[n]["state_dict"])
+                self.best_model_state[n] = copy.deepcopy(self.model[n].module.state_dict())
+            else:
+                self.model[n].load_state_dict(checkpoint[n]["state_dict"])
+                self.best_model_state[n] = copy.deepcopy(self.model[n].state_dict())
+            if load_training_state == True:
+                if checkpoint[n].get("optimizer"):
+                    self.optimizer[n].load_state_dict(checkpoint[n]["optimizer"])
+                if checkpoint[n].get("scheduler"):
+                    self.scheduler[n].scheduler.load_state_dict(checkpoint[n]["scheduler"])
+                    self.scheduler[n].update_lr()
+                if checkpoint[n].get("epoch"):
+                    self.epoch = checkpoint[n]["epoch"]
+                if checkpoint[n].get("step"):
+                    self.step = checkpoint[n]["step"]
+                if checkpoint[n].get("best_metric"):
+                    self.best_metric[n] = checkpoint[n]["best_metric"]
+                if checkpoint[n].get("seed"):
+                    seed = checkpoint[n]["seed"]
+                    self.set_seed(seed)
+                if checkpoint[n].get("scaler"):
+                    self.scaler.load_state_dict(checkpoint[n]["scaler"])
+        #todo: load dataset to recreate the same split as the prior run
+        #self._load_dataset(dataset_config, task)
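load_checkpoint (and load_pre_trained_weights below) now treat the configured checkpoint path as a comma-separated list, one file per ensemble member. A small illustration, with made-up paths:

    checkpoint_path = "results/run/checkpoint_0/best_checkpoint.pt,results/run/checkpoint_1/best_checkpoint.pt"
    paths = checkpoint_path.split(",")
    print(paths)
    # ['results/run/checkpoint_0/best_checkpoint.pt', 'results/run/checkpoint_1/best_checkpoint.pt']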
 
     # Loads portion of model dict into a new model for fine tuning
     def load_pre_trained_weights(self, load_training_state=False):
@@ -552,43 +637,42 @@ def load_pre_trained_weights(self, load_training_state=False):
         if not self.checkpoint_path:
             raise ValueError("No checkpoint directory specified in config.")
         checkpoints_folder = os.path.join(self.fine_tune_from, 'checkpoint')
-
-        load_model = torch.load(self.checkpoint_path, map_location=self.device)
-        load_state = load_model["state_dict"]
-
-        model_state = self.model.state_dict()
-
-        print(model_state.keys())
-        for name, param in load_state.items():
-            #if name not in model_state or name.split('.')[0] in "post_lin_list":
-            if name not in model_state:
-                logging.debug('NOT loaded: %s', name)
-                continue
-            else:
-                logging.debug('loaded: %s', name)
-            if isinstance(param, torch.nn.parameter.Parameter):
-                # backwards compatibility for serialized parameters
-                param = param.data
-            model_state[name].copy_(param)
-        logging.info("Loaded pre-trained model with success.")
-
-        if load_training_state == True:
-            if checkpoint.get("optimizer"):
-                self.optimizer.load_state_dict(checkpoint["optimizer"])
-            if checkpoint.get("scheduler"):
-                self.scheduler.scheduler.load_state_dict(checkpoint["scheduler"])
-                self.scheduler.update_lr()
-            #if checkpoint.get("epoch"):
-            #    self.epoch = checkpoint["epoch"]
-            #if checkpoint.get("step"):
-            #    self.step = checkpoint["step"]
-            #if checkpoint.get("best_metric"):
-            #    self.best_metric = checkpoint["best_metric"]
-            if checkpoint.get("seed"):
-                seed = checkpoint["seed"]
-                self.set_seed(seed)
-            if checkpoint.get("scaler"):
-                self.scaler.load_state_dict(checkpoint["scaler"])
+
+        self.checkpoint_path = self.checkpoint_path.split(",")
+        load_model = [torch.load(i, map_location=self.device) for i in self.checkpoint_path]
+        load_state = [i["state_dict"] for i in load_model]
+        model_state = [i.state_dict() for i in self.model]
+
+        for x in range(len(self.model)):
+            for name, param in load_state[x].items():
+                #if name not in model_state or name.split('.')[0] in "post_lin_list":
+                if name not in model_state[x]:
+                    logging.debug('NOT loaded: %s', name)
+                    continue
+                else:
+                    logging.debug('loaded: %s', name)
+                if isinstance(param, torch.nn.parameter.Parameter):
+                    # backwards compatibility for serialized parameters
+                    param = param.data
+                model_state[x][name].copy_(param)
+            logging.info("Loaded pre-trained model with success.")
+            if load_training_state == True:
+                # Training state for each member is read from its own loaded checkpoint dict
+                if load_model[x].get("optimizer"):
+                    self.optimizer[x].load_state_dict(load_model[x]["optimizer"])
+                if load_model[x].get("scheduler"):
+                    self.scheduler[x].scheduler.load_state_dict(load_model[x]["scheduler"])
+                    self.scheduler[x].update_lr()
+                #if load_model[x].get("epoch"):
+                #    self.epoch = load_model[x]["epoch"]
+                #if load_model[x].get("step"):
+                #    self.step = load_model[x]["step"]
+                #if load_model[x].get("best_metric"):
+                #    self.best_metric = load_model[x]["best_metric"]
+                if load_model[x].get("seed"):
+                    seed = load_model[x]["seed"]
+                    self.set_seed(seed)
+                if load_model[x].get("scaler"):
+                    self.scaler.load_state_dict(load_model[x]["scaler"])
 
     @staticmethod
     def set_seed(seed):
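In property_trainer.py below, train() keeps one iterator per ensemble member, so each model draws its own independently shuffled batches in lockstep. A minimal sketch of that pattern with a stand-in dataset:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Two-member "ensemble" of loaders over the same data, shuffled independently.
    data = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
    loaders = [DataLoader(data, batch_size=4, shuffle=True) for _ in range(2)]

    iters = [iter(dl) for dl in loaders]
    for step in range(len(loaders[0])):
        batches = [next(it) for it in iters]  # one batch per member, per step
        # ...forward/backward each member on its own batch here...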
diff --git a/matdeeplearn/trainers/property_trainer.py b/matdeeplearn/trainers/property_trainer.py
index 3fd71839..e2e14d44 100644
--- a/matdeeplearn/trainers/property_trainer.py
+++ b/matdeeplearn/trainers/property_trainer.py
@@ -68,7 +68,7 @@ def train(self):
 
             if str(self.rank) not in ("cpu", "cuda"):
                 dist.barrier()
-
+
         end_epoch = (
             self.max_checkpoint_epochs + start_epoch
             if self.max_checkpoint_epochs
@@ -79,11 +79,11 @@ def train(self):
             logging.info("Starting regular training")
             if str(self.rank) not in ("cpu", "cuda"):
                 logging.info(
-                    f"running for {end_epoch - start_epoch} epochs on {type(self.model.module).__name__} model"
+                    f"Running for {end_epoch - start_epoch} epochs on {type(self.model[0].module).__name__} model"
                 )
             else:
                 logging.info(
-                    f"running for {end_epoch - start_epoch} epochs on {type(self.model).__name__} model"
+                    f"Running for {end_epoch - start_epoch} epochs on {type(self.model[0]).__name__} model"
                 )
 
         for epoch in range(start_epoch, end_epoch):
@@ -91,32 +91,40 @@ def train(self):
             if self.train_sampler:
                 self.train_sampler.set_epoch(epoch)
             # skip_steps = self.step % len(self.train_loader)
-            train_loader_iter = iter(self.data_loader["train_loader"])
+            train_loader_iter = []
+            for i in range(len(self.model)):
+                train_loader_iter.append(iter(self.data_loader[i]["train_loader"]))
             # metrics for every epoch
-            _metrics = {}
+            _metrics = [{} for _ in range(len(self.model))]
 
             #for i in range(skip_steps, len(self.train_loader)):
-            pbar = tqdm(range(0, len(self.data_loader["train_loader"])), disable=not self.batch_tqdm)
+            pbar = tqdm(range(0, len(self.data_loader[0]["train_loader"])), disable=not self.batch_tqdm)
             for i in pbar:
                 #self.epoch = epoch + (i + 1) / len(self.train_loader)
                 #self.step = epoch * len(self.train_loader) + i + 1
                 #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024))
-                self.model.train()
                 # Get a batch of train data
-                batch = next(train_loader_iter).to(self.rank)
-                #print(epoch, i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024), torch.sum(batch.n_atoms))
+                batch = []
+                for n, mod in enumerate(self.model):
+                    mod.train()
+                    batch.append(next(train_loader_iter[n]).to(self.rank))
+                # batch = next(train_loader_iter).to(self.rank)
+                # print(epoch, i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024), torch.sum(batch.n_atoms))
 
                 # Compute forward, loss, backward
                 with autocast(enabled=self.use_amp):
-                    out = self._forward(batch)
-                    loss = self._compute_loss(out, batch)
+                    out_list = self._forward(batch)
+                    loss = self._compute_loss(out_list, batch)
                 #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024))
-                grad_norm = self._backward(loss)
-                pbar.set_description("Batch Loss {:.4f}, grad norm {:.4f}".format(loss.item(), grad_norm.item()))
+                grad_norm = []
+                for n in range(len(self.model)):
+                    grad_norm.append(self._backward(loss[n], n))
+                pbar.set_description("Batch Loss {:.4f}, grad norm {:.4f}".format(torch.mean(torch.stack(loss)).item(), torch.mean(torch.stack(grad_norm)).item()))
 
                 # Compute metrics
                 # TODO: revert _metrics to be empty per batch, so metrics are logged per batch, not per epoch
                 # keep option to log metrics per epoch
-                _metrics = self._compute_metrics(out, batch, _metrics)
-                self.metrics = self.evaluator.update("loss", loss.item(), out["output"].shape[0], _metrics)
+                for n in range(len(self.model)):
+                    _metrics[n] = self._compute_metrics(out_list[n], batch[n], _metrics[n])
+                    self.metrics[n] = self.evaluator.update("loss", loss[n].item(), out_list[n]["output"].shape[0], _metrics[n])
 
             self.epoch = epoch + 1
 
@@ -132,7 +140,7 @@ def train(self):
                 self.save_model(checkpoint_file="checkpoint.pt", training_state=True)
 
             # Evaluate on validation set if it exists
-            if self.data_loader.get("val_loader"):
+            if self.data_loader[0].get("val_loader"):
                 metric = self.validate("val")
             else:
                 metric = self.metrics
@@ -141,72 +149,83 @@ def train(self):
             self.epoch_time = time.time() - epoch_start_time
             # Log metrics
             if epoch % self.train_verbosity == 0:
-                if self.data_loader.get("val_loader"):
+                if self.data_loader[0].get("val_loader"):
                     self._log_metrics(metric)
                 else:
                     self._log_metrics()
 
             # Update best val metric and model, and save best model and predicted outputs
-            if metric[type(self.loss_fn).__name__]["metric"] < self.best_metric:
-                if self.output_frequency == 0:
-                    if self.model_save_frequency == 1:
-                        self.update_best_model(metric, write_model=True, write_csv=False)
-                    else:
-                        self.update_best_model(metric, write_model=False, write_csv=False)
-                elif self.output_frequency == 1:
-                    if self.model_save_frequency == 1:
-                        self.update_best_model(metric, write_model=True, write_csv=True)
-                    else:
-                        self.update_best_model(metric, write_model=False, write_csv=True)
-            # step scheduler, using validation error
+            for i in range(len(self.model)):
+                if metric[i][type(self.loss_fn).__name__]["metric"] < self.best_metric[i]:
+                    if self.output_frequency == 0:
+                        if self.model_save_frequency == 1:
+                            self.update_best_model(metric[i], i, write_model=True, write_csv=False)
+                        else:
+                            self.update_best_model(metric[i], i, write_model=False, write_csv=False)
+                    elif self.output_frequency == 1:
+                        if self.model_save_frequency == 1:
+                            self.update_best_model(metric[i], i, write_model=True, write_csv=True)
+                        else:
+                            self.update_best_model(metric[i], i, write_model=False, write_csv=True)
+
+            # step scheduler, using validation error
             self._scheduler_step()
+            torch.cuda.empty_cache()
 
         if self.best_model_state:
-            if str(self.rank) in "0":
-                self.model.module.load_state_dict(self.best_model_state)
-            elif str(self.rank) in ("cpu", "cuda"):
-                self.model.load_state_dict(self.best_model_state)
+            for i in range(len(self.model)):
+                if str(self.rank) in "0":
+                    self.model[i].module.load_state_dict(self.best_model_state[i])
+                elif str(self.rank) in ("cpu", "cuda"):
+                    self.model[i].load_state_dict(self.best_model_state[i])
 
             #if self.data_loader.get("test_loader"):
             #    metric = self.validate("test")
             #    test_loss = metric[type(self.loss_fn).__name__]["metric"]
             #else:
             #    test_loss = "N/A"
             if self.model_save_frequency != -1:
-                self.save_model("best_checkpoint.pt", metric, True)
+                self.save_model("best_checkpoint.pt", index=None, metric=metric, training_state=True)
 
             logging.info("Final Losses: ")
             if "train" in self.write_output:
-                self.predict(self.data_loader["train_loader"], "train")
-            if "val" in self.write_output and self.data_loader.get("val_loader"):
-                self.predict(self.data_loader["val_loader"], "val")
-            if "test" in self.write_output and self.data_loader.get("test_loader"):
-                self.predict(self.data_loader["test_loader"], "test")
+                self.predict(self.data_loader[0]["train_loader"], "train")
+            if "val" in self.write_output and self.data_loader[0].get("val_loader"):
+                self.predict(self.data_loader[0]["val_loader"], "val")
+            if "test" in self.write_output and self.data_loader[0].get("test_loader"):
+                self.predict(self.data_loader[0]["test_loader"], "test")
 
         return self.best_model_state
 
     @torch.no_grad()
     def validate(self, split="val"):
-        self.model.eval()
-        evaluator, metrics = Evaluator(), {}
-
-        if split == "val":
-            loader_iter = iter(self.data_loader["val_loader"])
-        elif split == "test":
-            loader_iter = iter(self.data_loader["test_loader"])
-        elif split == "train":
-            loader_iter = iter(self.data_loader["train_loader"])
-
-        for i in range(0, len(loader_iter)):
+        for i in range(len(self.model)):
+            self.model[i].eval()
+
+        evaluator, metrics = Evaluator(), [{} for _ in range(len(self.model))]
+
+        loader_iter = []
+        for i in range(len(self.model)):
+            if split == "val":
+                loader_iter.append(iter(self.data_loader[i]["val_loader"]))
+            elif split == "test":
+                loader_iter.append(iter(self.data_loader[i]["test_loader"]))
+            elif split == "train":
+                loader_iter.append(iter(self.data_loader[i]["train_loader"]))
+
+        for i in range(0, len(loader_iter[0])):
             #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024))
-            batch = next(loader_iter).to(self.rank)
-            out = self._forward(batch.to(self.rank))
-            loss = self._compute_loss(out, batch)
+            batch = []
+            for m in range(len(self.model)):
+                batch.append(next(loader_iter[m]).to(self.rank))
+
+            out_list = self._forward(batch)
+            loss = self._compute_loss(out_list, batch)
             # Compute metrics
             #print(i, torch.cuda.memory_allocated() / (1024 * 1024), torch.cuda.memory_cached() / (1024 * 1024))
-            metrics = self._compute_metrics(out, batch, metrics)
-            metrics = evaluator.update("loss", loss.item(), out["output"].shape[0], metrics)
-            del loss, batch, out
+            for n in range(len(self.model)):
+                metrics[n] = self._compute_metrics(out_list[n], batch[n], metrics[n])
+                metrics[n] = evaluator.update("loss", loss[n].item(), out_list[n]["output"].shape[0], metrics[n])
+            del loss, batch, out_list
 
             torch.cuda.empty_cache()
 
@@ -214,10 +233,12 @@ def validate(self, split="val"):
 
     @torch.no_grad()
     def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True):
-        self.model.eval()
+        for mod in self.model:
+            mod.eval()
 
-        assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
+        # assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
+        # TODO: make this compatible with model ensemble
("cpu", "cuda"): loader = get_dataloader( loader.dataset, batch_size=loader.batch_size, sampler=None @@ -230,40 +251,54 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, target_pos_grad = None ids_cell_grad = [] target_cell_grad = None - node_level = False - loader_iter = iter(loader) + node_level = False + + loader_iter = iter(loader) for i in range(0, len(loader_iter)): - batch = next(loader_iter).to(self.rank) - out = self._forward(batch.to(self.rank)) - batch_p = out["output"].data.cpu().numpy() - batch_ids = batch.structure_id + batch = next(loader_iter).to(self.rank) + out_list = self._forward([batch]) + out = {} + out_stack={} + for key in out_list[0].keys(): + temp = [o[key] for o in out_list] + if temp[0] is not None: + out_stack[key] = torch.stack(temp) + out[key] = torch.mean(out_stack[key], dim=0) + out[key+"_std"] = torch.std(out_stack[key], dim=0) + else: + out[key] = None + out[key+"_std"] = None + + + batch_p = [o["output"].data.cpu().numpy() for o in out_list] + batch_p_mean = out["output"].cpu().numpy() + batch_ids = batch.structure_id + batch_stds = out["output_std"].cpu().numpy() + if labels == True: loss = self._compute_loss(out, batch) metrics = self._compute_metrics(out, batch, metrics) metrics = evaluator.update( "loss", loss.item(), out["output"].shape[0], metrics - ) - if str(self.rank) not in ("cpu", "cuda"): - batch_t = batch[self.model.module.target_attr].cpu().numpy() + ) + if str(self.rank) not in ("cpu", "cuda"): + batch_t = batch[self.model[0].module.target_attr].cpu().numpy() else: - batch_t = batch[self.model.target_attr].cpu().numpy() - #batch_ids = np.array( - # [item for sublist in batch.structure_id for item in sublist] - #) + batch_t = batch[self.model[0].target_attr].cpu().numpy() # Node level prediction - if batch_p.shape[0] > loader.batch_size: - + if batch_p[0].shape[0] > loader.batch_size: node_level = True node_ids = batch.z.cpu().numpy() structure_ids = np.repeat( batch.structure_id, batch.n_atoms.cpu().numpy(), axis=0 ) batch_ids = np.column_stack((structure_ids, node_ids)) - + if out.get("pos_grad") != None: batch_p_pos_grad = out["pos_grad"].data.cpu().numpy() + batch_p_pos_grad_std = out["pos_grad_std"].data.cpu().numpy() node_ids_pos_grad = batch.z.cpu().numpy() structure_ids_pos_grad = np.repeat( batch.structure_id, batch.n_atoms.cpu().numpy(), axis=0 @@ -271,21 +306,30 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, batch_ids_pos_grad = np.column_stack((structure_ids_pos_grad, node_ids_pos_grad)) ids_pos_grad = batch_ids_pos_grad if i == 0 else np.row_stack((ids_pos_grad, batch_ids_pos_grad)) predict_pos_grad = batch_p_pos_grad if i == 0 else np.concatenate((predict_pos_grad, batch_p_pos_grad), axis=0) + predict_pos_grad_std = batch_p_pos_grad_std if i == 0 else np.concatenate((predict_pos_grad_std, batch_p_pos_grad_std), axis=0) if "forces" in batch: batch_t_pos_grad = batch["forces"].cpu().numpy() target_pos_grad = batch_t_pos_grad if i == 0 else np.concatenate((target_pos_grad, batch_t_pos_grad), axis=0) if out.get("cell_grad") != None: batch_p_cell_grad = out["cell_grad"].data.view(out["cell_grad"].data.size(0), -1).cpu().numpy() + batch_p_cell_grad_std = out["cell_grad_std"].data.view(out["cell_grad"].data.size(0), -1).cpu().numpy() batch_ids_cell_grad = batch.structure_id ids_cell_grad = batch_ids_cell_grad if i == 0 else np.row_stack((ids_cell_grad, batch_ids_cell_grad)) predict_cell_grad = batch_p_cell_grad if i == 0 else np.concatenate((predict_cell_grad, 
                 predict_cell_grad = batch_p_cell_grad if i == 0 else np.concatenate((predict_cell_grad, batch_p_cell_grad), axis=0)
+                predict_cell_grad_std = batch_p_cell_grad_std if i == 0 else np.concatenate((predict_cell_grad_std, batch_p_cell_grad_std), axis=0)
                 if "stress" in batch:
                     batch_t_cell_grad = batch["stress"].view(out["cell_grad"].data.size(0), -1).cpu().numpy()
                     target_cell_grad = batch_t_cell_grad if i == 0 else np.concatenate((target_cell_grad, batch_t_cell_grad), axis=0)
 
-            ids = batch_ids if i == 0 else np.row_stack((ids, batch_ids))
-            predict = batch_p if i == 0 else np.concatenate((predict, batch_p), axis=0)
+            ids = batch_ids if i == 0 else np.row_stack((ids, batch_ids))
+            predict_mean = batch_p_mean if i == 0 else np.concatenate((predict_mean, batch_p_mean), axis=0)
+            stds = batch_stds if i == 0 else np.row_stack((stds, batch_stds))
+            if i == 0:
+                predict = [0 for _ in range(len(self.model))]
+            for x in range(len(self.model)):
+                predict[x] = batch_p[x] if i == 0 else np.concatenate((predict[x], batch_p[x]), axis=0)
+
             if labels == True:
                 target = batch_t if i == 0 else np.concatenate((target, batch_t), axis=0)
 
@@ -293,52 +337,79 @@ def predict(self, loader, split, results_dir="train_results", write_output=True, labels=True):
                 del loss, batch, out
             else:
                 del batch, out
 
         if write_output == True:
             if labels == True:
-                self.save_results(
-                    np.column_stack((ids, target, predict)), results_dir, f"{split}_predictions.csv", node_level
-                )
+                if len(self.model) > 1:
+                    self.save_results(
+                        np.column_stack((ids, target, predict_mean, stds)), results_dir, f"{split}_predictions.csv", node_level, std=True,
+                    )
+                    for x in range(len(self.model)):
+                        mod = str(x)
+                        self.save_results(
+                            np.column_stack((ids, target, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, std=False,
+                        )
+                else:
+                    self.save_results(
+                        np.column_stack((ids, target, predict_mean)), results_dir, f"{split}_predictions.csv", node_level, std=False,
+                    )
             else:
-                self.save_results(
-                    np.column_stack((ids, predict)), results_dir, f"{split}_predictions.csv", node_level
-                )
+                if len(self.model) > 1:
+                    self.save_results(
+                        np.column_stack((ids, predict_mean, stds)), results_dir, f"{split}_predictions.csv", node_level, std=True,
+                    )
+                    for x in range(len(self.model)):
+                        mod = str(x)
+                        self.save_results(
+                            np.column_stack((ids, predict[x])), results_dir, f"{split}_predictions_{mod}.csv", node_level, std=False,
+                        )
+                else:
+                    self.save_results(
+                        np.column_stack((ids, predict_mean)), results_dir, f"{split}_predictions.csv", node_level, std=False,
+                    )
 
             #if out.get("pos_grad") != None:
             if len(ids_pos_grad) > 0:
                 if isinstance(target_pos_grad, np.ndarray):
                     self.save_results(
-                        np.column_stack((ids_pos_grad, target_pos_grad, predict_pos_grad)), results_dir, f"{split}_predictions_pos_grad.csv", True, True
+                        np.column_stack((ids_pos_grad, target_pos_grad, predict_pos_grad, predict_pos_grad_std)), results_dir, f"{split}_predictions_pos_grad.csv", True, True
                     )
                 else:
                     self.save_results(
-                        np.column_stack((ids_pos_grad, predict_pos_grad)), results_dir, f"{split}_predictions_pos_grad.csv", True, False
+                        np.column_stack((ids_pos_grad, predict_pos_grad, predict_pos_grad_std)), results_dir, f"{split}_predictions_pos_grad.csv", True, False
                     )
 
             #if out.get("cell_grad") != None:
             if len(ids_cell_grad) > 0:
                 if isinstance(target_cell_grad, np.ndarray):
                     self.save_results(
-                        np.column_stack((ids_cell_grad, target_cell_grad, predict_cell_grad)), results_dir, f"{split}_predictions_cell_grad.csv", False, True
+                        np.column_stack((ids_cell_grad, target_cell_grad, predict_cell_grad, predict_cell_grad_std)), results_dir, f"{split}_predictions_cell_grad.csv", False, True
                     )
                 else:
                     self.save_results(
-                        np.column_stack((ids_cell_grad, predict_cell_grad)), results_dir, f"{split}_predictions_cell_grad.csv", False, False
+                        np.column_stack((ids_cell_grad, predict_cell_grad, predict_cell_grad_std)), results_dir, f"{split}_predictions_cell_grad.csv", False, False
                     )
 
         if labels == True:
             predict_loss = metrics[type(self.loss_fn).__name__]["metric"]
             logging.info("Saved {:s} error: {:.5f}".format(split, predict_loss))
-            predictions = {"ids":ids, "predict":predict, "target":target}
+            if len(self.model) > 1:
+                predictions = {"ids": ids, "predict": predict_mean, "target": target, "std": stds}
+            else:
+                predictions = {"ids": ids, "predict": predict_mean, "target": target}
         else:
-            predictions = {"ids":ids, "predict":predict}
+            if len(self.model) > 1:
+                predictions = {"ids": ids, "predict": predict_mean, "std": stds}
+            else:
+                predictions = {"ids": ids, "predict": predict_mean}
 
         torch.cuda.empty_cache()
 
         return predictions
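The aggregation inside predict() stacks each output key across the ensemble members, keeps the mean as the reported prediction, and keeps the standard deviation as the uncertainty. In isolation the pattern looks like this (toy tensors, not patch code):

    import torch

    out_list = [
        {"output": torch.tensor([1.0, 2.0]), "pos_grad": None},
        {"output": torch.tensor([1.2, 1.8]), "pos_grad": None},
    ]
    out = {}
    for key in out_list[0]:
        temp = [o[key] for o in out_list]
        if temp[0] is not None:
            stacked = torch.stack(temp)
            out[key] = stacked.mean(dim=0)
            out[key + "_std"] = stacked.std(dim=0)
        else:
            out[key] = None
            out[key + "_std"] = None
    print(out["output"], out["output_std"])  # tensor([1.1000, 1.9000]) tensor([0.1414, 0.1414])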
 
     def predict_by_calculator(self, loader):
-        self.model.eval()
-
+        for mod in self.model:
+            mod.eval()
+
         assert isinstance(loader, torch.utils.data.dataloader.DataLoader)
         assert len(loader) == 1, f"Predicting by calculator only allows one structure at a time, but got {len(loader)} structures."
@@ -351,78 +422,111 @@ def predict_by_calculator(self, loader):
         loader_iter = iter(loader)
         for i in range(0, len(loader_iter)):
             batch = next(loader_iter).to(self.rank)
-            out = self._forward(batch.to(self.rank))
-
+            # _forward expects a list of batches; a one-element list runs every member on this batch
+            out_list = self._forward([batch])
+            out = {}
+            out_stack = {}
+            for key in out_list[0].keys():
+                temp = [o[key] for o in out_list]
+                if temp[0] is not None:
+                    out_stack[key] = torch.stack(temp)
+                    out[key] = torch.mean(out_stack[key], dim=0)
+                else:
+                    out[key] = None
+
             energy = None if out.get('output') is None else out.get('output').data.cpu().numpy()
             stress = None if out.get('cell_grad') is None else out.get('cell_grad').view(-1, 3).data.cpu().numpy()
             forces = None if out.get('pos_grad') is None else out.get('pos_grad').data.cpu().numpy()
 
             results = {'energy': energy, 'stress': stress, 'forces': forces}
+
             return results
 
     def _forward(self, batch_data):
-        output = self.model(batch_data)
+        # One batch per member during training/validation; a single shared
+        # batch (a one-element list) during prediction
+        if len(batch_data) > 1:
+            output = []
+            for i in range(len(self.model)):
+                output.append(self.model[i](batch_data[i]))
+        else:
+            output = []
+            for i in range(len(self.model)):
+                output.append(self.model[i](batch_data[0]))
        return output
 
     def _compute_loss(self, out, batch_data):
-        loss = self.loss_fn(out, batch_data)
+        if isinstance(out, list):
+            loss = []
+            for i in range(len(out)):
+                loss.append(self.loss_fn(out[i], batch_data[i]))
+        else:
+            loss = self.loss_fn(out, batch_data)
         return loss
 
-    def _backward(self, loss):
-        self.optimizer.zero_grad(set_to_none=True)
+    def _backward(self, loss, index=None):
+        self.optimizer[index].zero_grad(set_to_none=True)
         self.scaler.scale(loss).backward()
         if self.clip_grad_norm:
             grad_norm = torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(),
+                self.model[index].parameters(),
                 max_norm=self.clip_grad_norm,
             )
-        self.scaler.step(self.optimizer)
-        self.scaler.update()
+        self.scaler.step(self.optimizer[index])
+        self.scaler.update()
+
         return grad_norm
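_backward runs a separate optimizer step per ensemble member while the trainer holds a single shared GradScaler. A CPU-safe sketch of that pattern with stand-in models and data:

    import torch

    models = [torch.nn.Linear(4, 1) for _ in range(2)]
    optims = [torch.optim.AdamW(m.parameters(), lr=1e-3) for m in models]
    scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled -> plain fp32 steps

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for model, opt in zip(models, optims):
        opt.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(model(x), y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()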
 
     def _compute_metrics(self, out, batch_data, metrics):
         # TODO: finish this method
-        property_target = batch_data.to(self.rank)
+        try:
+            property_target = batch_data.to(self.rank)
+        except Exception:
+            # batch_data may already be an aggregated structure without a .to() method
+            property_target = batch_data
         metrics = self.evaluator.eval(
             out, property_target, self.loss_fn, prev_metrics=metrics
         )
         return metrics
 
     def _log_metrics(self, val_metrics=None):
-        train_loss = self.metrics[type(self.loss_fn).__name__]["metric"]
+        # Average the per-member losses for logging
+        train_loss = [torch.tensor(i[type(self.loss_fn).__name__]["metric"]) for i in self.metrics]
+        train_loss = torch.mean(torch.stack(train_loss)).item()
+        lr = self.scheduler[0].lr
         if not val_metrics:
             val_loss = "N/A"
             logging.info(
                 "Epoch: {:04d}, Learning Rate: {:.6f}, Training Error: {:.5f}, Val Error: {}, Time per epoch (s): {:.5f}".format(
                     int(self.epoch - 1),
-                    self.scheduler.lr,
+                    lr,
                     train_loss,
                     val_loss,
                     self.epoch_time,
                 )
             )
         else:
-            val_loss = val_metrics[type(self.loss_fn).__name__]["metric"]
+            val_loss = [torch.tensor(i[type(self.loss_fn).__name__]["metric"]) for i in val_metrics]
+            val_loss = torch.mean(torch.stack(val_loss)).item()
             logging.info(
                 "Epoch: {:04d}, Learning Rate: {:.6f}, Training Error: {:.5f}, Val Error: {:.5f}, Time per epoch (s): {:.5f}".format(
                     int(self.epoch - 1),
-                    self.scheduler.lr,
+                    lr,
                     train_loss,
                     val_loss,
                     self.epoch_time,
                 )
             )
 
     def _load_task(self):
         """Initializes task-specific info. Implemented by derived classes."""
         pass
 
     def _scheduler_step(self):
-        if self.scheduler.scheduler_type == "ReduceLROnPlateau":
-            self.scheduler.step(
-                metrics=self.metrics[type(self.loss_fn).__name__]["metric"]
-            )
-        else:
-            self.scheduler.step()
+        for i in range(len(self.model)):
+            if self.scheduler[i].scheduler_type == "ReduceLROnPlateau":
+                self.scheduler[i].step(
+                    metrics=self.metrics[i][type(self.loss_fn).__name__]["metric"]
+                )
+            else:
+                self.scheduler[i].step()