diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index a79acc63..0dfd0181 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -439,43 +439,19 @@ def add_selfloop( return edge_indices, edge_weights, distance_matrix_masked -def load_node_representation(node_representation="onehot"): - node_rep_path = Path(__file__).parent - default_reps = {"onehot": str(node_rep_path / "./node_representations/onehot.csv")} +def one_hot_node_rep(Z, device): + return F.one_hot(Z - 1, num_classes = 100) - rep_file_path = node_representation - if node_representation in default_reps: - rep_file_path = default_reps[node_representation] - - file_type = rep_file_path.split(".")[-1] - loaded_rep = None - - if file_type == "csv": - loaded_rep = np.genfromtxt(rep_file_path, delimiter=",") - # TODO: need to check if typecasting to integer is needed - loaded_rep = loaded_rep.astype(int) - - elif file_type == "json": - # TODO - pass - - return loaded_rep - - -def generate_node_features(input_data, n_neighbors, device, use_degree=False): - node_reps = load_node_representation() - node_reps = torch.from_numpy(node_reps).to(device) - n_elements, n_features = node_reps.shape - +def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = one_hot_node_rep): if isinstance(input_data, Data): - input_data.x = node_reps[input_data.z - 1].view(-1, n_features) + input_data.x = node_rep_func(input_data.z, device = device) if use_degree: return one_hot_degree(input_data, n_neighbors) return input_data for i, data in enumerate(input_data): # minus 1 as the reps are 0-indexed but atomic number starts from 1 - data.x = node_reps[data.z - 1].view(-1, n_features).float() + data.x = node_rep_func(data.z, device = device).float() #for i, data in enumerate(input_data): #input_data[i] = one_hot_degree(data, n_neighbors)