diff --git a/matdeeplearn/trainers/base_trainer.py b/matdeeplearn/trainers/base_trainer.py index 318d31b1..ba837e46 100644 --- a/matdeeplearn/trainers/base_trainer.py +++ b/matdeeplearn/trainers/base_trainer.py @@ -101,17 +101,18 @@ def __init__( logging.info( f"GPU is available: {torch.cuda.is_available()}, Quantity: {os.environ.get('LOCAL_WORLD_SIZE', None)}" ) - logging.info("Dataset(s) used:") - for key in self.dataset: - logging.info(f"Dataset length: {key, len(self.dataset[key])}") - if self.dataset.get("train"): - logging.debug(self.dataset["train"][0]) - logging.debug(self.dataset["train"][0].z[0]) - logging.debug(self.dataset["train"][0].y[0]) - else: - logging.debug(self.dataset[list(self.dataset.keys())[0]][0]) - logging.debug(self.dataset[list(self.dataset.keys())[0]][0].x[0]) - logging.debug(self.dataset[list(self.dataset.keys())[0]][0].y[0]) + if self.dataset is not None: + logging.info("Dataset(s) used:") + for key in self.dataset: + logging.info(f"Dataset length: {key, len(self.dataset[key])}") + if self.dataset.get("train"): + logging.debug(self.dataset["train"][0]) + logging.debug(self.dataset["train"][0].z[0]) + logging.debug(self.dataset["train"][0].y[0]) + else: + logging.debug(self.dataset[list(self.dataset.keys())[0]][0]) + logging.debug(self.dataset[list(self.dataset.keys())[0]][0].x[0]) + logging.debug(self.dataset[list(self.dataset.keys())[0]][0].y[0]) if str(self.rank) not in ("cpu", "cuda"): logging.debug(self.model[0].module) @@ -144,10 +145,10 @@ def from_config(cls, config): else: rank = torch.device("cuda" if torch.cuda.is_available() else "cpu") local_world_size = 1 - dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) + dataset = cls._load_dataset(config["dataset"], config["task"]["run_mode"]) if config["dataset"].get("src") else None model = cls._load_model(config["model"], config["dataset"]["preprocess_params"], dataset, local_world_size, rank) optimizer = cls._load_optimizer(config["optim"], model, 
local_world_size) - sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) + sampler = cls._load_sampler(config["optim"], dataset, local_world_size, rank) if config["dataset"].get("src") else None data_loader = cls._load_dataloader( config["optim"], config["dataset"], @@ -155,7 +156,7 @@ def from_config(cls, config): sampler, config["task"]["run_mode"], config["model"] - ) + ) if config["dataset"].get("src") else None scheduler = cls._load_scheduler(config["optim"]["scheduler"], optimizer) loss = cls._load_loss(config["optim"]["loss"]) @@ -270,10 +271,11 @@ def _load_dataset(dataset_config, task): def _load_model(model_config, graph_config, dataset, world_size, rank): """Loads the model if from a config file.""" - if dataset.get("train"): - dataset = dataset["train"] - else: - dataset = dataset[list(dataset.keys())[0]] + if dataset is not None: + if dataset.get("train"): + dataset = dataset["train"] + else: + dataset = dataset[list(dataset.keys())[0]] if isinstance(dataset, torch.utils.data.Subset): dataset = dataset.dataset @@ -293,22 +295,28 @@ def _load_model(model_config, graph_config, dataset, world_size, rank): if graph_config["node_dim"]: node_dim = graph_config["node_dim"] else: - node_dim = dataset.num_features - edge_dim = graph_config["edge_dim"] - if dataset[0]["y"].ndim == 0: - output_dim = 1 + node_dim = dataset.num_features + edge_dim = graph_config["edge_dim"] + if dataset is not None: + if dataset[0]["y"].ndim == 0: + output_dim = 1 + else: + output_dim = dataset[0]["y"].shape[1] else: - output_dim = dataset[0]["y"].shape[1] + output_dim = graph_config["output_dim"] # Determine if this is a node or graph level model - if dataset[0]["y"].shape[0] == dataset[0]["z"].shape[0]: - model_config["prediction_level"] = "node" - elif dataset[0]["y"].shape[0] == 1: - model_config["prediction_level"] = "graph" + if dataset is not None: + if dataset[0]["y"].shape[0] == dataset[0]["z"].shape[0]: + 
model_config["prediction_level"] = "node" + elif dataset[0]["y"].shape[0] == 1: + model_config["prediction_level"] = "graph" + else: + raise ValueError( + "Target labels do not have the correct dimensions for node or graph-level prediction." + ) else: - raise ValueError( - "Target labels do not have the correct dimensions for node or graph-level prediction." - ) + model_config["prediction_level"] = graph_config["prediction_level"] model_cls = registry.get_model_class(model_config["name"]) model = model_cls(