From d9c8f8675ca052d6ebe4e50be21d36a21cb7ed55 Mon Sep 17 00:00:00 2001 From: Ian Slagle Date: Tue, 12 Sep 2023 14:33:26 -0400 Subject: [PATCH 1/5] Generate node representations on the fly --- matdeeplearn/preprocessor/helpers.py | 36 +++++++--------------------- 1 file changed, 8 insertions(+), 28 deletions(-) diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index a79acc63..01f0ba16 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -439,43 +439,23 @@ def add_selfloop( return edge_indices, edge_weights, distance_matrix_masked -def load_node_representation(node_representation="onehot"): - node_rep_path = Path(__file__).parent - default_reps = {"onehot": str(node_rep_path / "./node_representations/onehot.csv")} +def one_hot_node_rep(Z, device): + to_return = torch.zeros(Z.size(dim = 0), 100, dtype = torch.long, device = device) + to_return[torch.arange(Z.size(dim = 0)), Z - 1] = 1 + return to_return.view(-1, 100) - rep_file_path = node_representation - if node_representation in default_reps: - rep_file_path = default_reps[node_representation] - - file_type = rep_file_path.split(".")[-1] - loaded_rep = None - - if file_type == "csv": - loaded_rep = np.genfromtxt(rep_file_path, delimiter=",") - # TODO: need to check if typecasting to integer is needed - loaded_rep = loaded_rep.astype(int) - - elif file_type == "json": - # TODO - pass - - return loaded_rep - - -def generate_node_features(input_data, n_neighbors, device, use_degree=False): - node_reps = load_node_representation() - node_reps = torch.from_numpy(node_reps).to(device) - n_elements, n_features = node_reps.shape +def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = one_hot_node_rep): + node_reps = node_rep_func(input_data.z, device = device) if isinstance(input_data, Data): - input_data.x = node_reps[input_data.z - 1].view(-1, n_features) + input_data.x = node_reps if use_degree: return one_hot_degree(input_data, n_neighbors) return input_data for i, data in enumerate(input_data): # minus 1 as the reps are 0-indexed but atomic number starts from 1 - data.x = node_reps[data.z - 1].view(-1, n_features).float() + data.x = node_reps.float() #for i, data in enumerate(input_data): #input_data[i] = one_hot_degree(data, n_neighbors) From 20885e13864460efb3e4a6e8f10d2b9a45dc2c13 Mon Sep 17 00:00:00 2001 From: Ian Slagle Date: Tue, 12 Sep 2023 14:41:49 -0400 Subject: [PATCH 2/5] Don't try to call at the beginning of the function --- matdeeplearn/preprocessor/helpers.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index 01f0ba16..f03ce2e9 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -445,17 +445,15 @@ def one_hot_node_rep(Z, device): return to_return.view(-1, 100) def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = one_hot_node_rep): - node_reps = node_rep_func(input_data.z, device = device) - if isinstance(input_data, Data): - input_data.x = node_reps + input_data.x = node_rep_func(input_data.z, device = device) if use_degree: return one_hot_degree(input_data, n_neighbors) return input_data for i, data in enumerate(input_data): # minus 1 as the reps are 0-indexed but atomic number starts from 1 - data.x = node_reps.float() + data.x = node_rep_func(input_data.z, device = device).float() #for i, data in enumerate(input_data): #input_data[i] = one_hot_degree(data, n_neighbors) From 5e2261389d630edf3b112af7d3bb695a66099bfc Mon Sep 17 00:00:00 2001 From: Ian Slagle Date: Tue, 12 Sep 2023 14:46:40 -0400 Subject: [PATCH 3/5] Fix typo --- matdeeplearn/preprocessor/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index f03ce2e9..d67d8691 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -453,7 +453,7 @@ def generate_node_features(input_data, n_neighbors, device, use_degree=False, no for i, data in enumerate(input_data): # minus 1 as the reps are 0-indexed but atomic number starts from 1 - data.x = node_rep_func(input_data.z, device = device).float() + data.x = node_rep_func(data.z, device = device).float() #for i, data in enumerate(input_data): #input_data[i] = one_hot_degree(data, n_neighbors) From 0d70b77b2a38022c482151c71ea4e8971553ec7a Mon Sep 17 00:00:00 2001 From: Ian Slagle Date: Tue, 12 Sep 2023 15:52:11 -0400 Subject: [PATCH 4/5] Use built-in torch one hot generation --- matdeeplearn/preprocessor/helpers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index d67d8691..308fbf92 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -440,9 +440,7 @@ def add_selfloop( def one_hot_node_rep(Z, device): - to_return = torch.zeros(Z.size(dim = 0), 100, dtype = torch.long, device = device) - to_return[torch.arange(Z.size(dim = 0)), Z - 1] = 1 - return to_return.view(-1, 100) + return F.one_hot(Z - 1, num_classes = 100).to(device) def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = one_hot_node_rep): if isinstance(input_data, Data): From 5263e61453b4bd7e2490b636744c5533c72d836c Mon Sep 17 00:00:00 2001 From: Ian Slagle Date: Wed, 13 Sep 2023 20:07:03 -0400 Subject: [PATCH 5/5] Shouldn't need to send to the same device --- matdeeplearn/preprocessor/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/matdeeplearn/preprocessor/helpers.py b/matdeeplearn/preprocessor/helpers.py index 308fbf92..0dfd0181 100644 --- a/matdeeplearn/preprocessor/helpers.py +++ b/matdeeplearn/preprocessor/helpers.py @@ -440,7 +440,7 @@ def add_selfloop( def one_hot_node_rep(Z, device): - return F.one_hot(Z - 1, num_classes = 100).to(device) + return F.one_hot(Z - 1, num_classes = 100) def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = one_hot_node_rep): if isinstance(input_data, Data):