Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 5 additions & 29 deletions matdeeplearn/preprocessor/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,43 +439,19 @@ def add_selfloop(
return edge_indices, edge_weights, distance_matrix_masked


def load_node_representation(node_representation="onehot"):
node_rep_path = Path(__file__).parent
default_reps = {"onehot": str(node_rep_path / "./node_representations/onehot.csv")}
def one_hot_node_rep(Z, device):
return F.one_hot(Z - 1, num_classes = 100)

rep_file_path = node_representation
if node_representation in default_reps:
rep_file_path = default_reps[node_representation]

file_type = rep_file_path.split(".")[-1]
loaded_rep = None

if file_type == "csv":
loaded_rep = np.genfromtxt(rep_file_path, delimiter=",")
# TODO: need to check if typecasting to integer is needed
loaded_rep = loaded_rep.astype(int)

elif file_type == "json":
# TODO
pass

return loaded_rep


def generate_node_features(input_data, n_neighbors, device, use_degree=False):
node_reps = load_node_representation()
node_reps = torch.from_numpy(node_reps).to(device)
n_elements, n_features = node_reps.shape

def generate_node_features(input_data, n_neighbors, device, use_degree=False, node_rep_func = one_hot_node_rep):
if isinstance(input_data, Data):
input_data.x = node_reps[input_data.z - 1].view(-1, n_features)
input_data.x = node_rep_func(input_data.z, device = device)
if use_degree:
return one_hot_degree(input_data, n_neighbors)
return input_data

for i, data in enumerate(input_data):
# minus 1 as the reps are 0-indexed but atomic number starts from 1
data.x = node_reps[data.z - 1].view(-1, n_features).float()
data.x = node_rep_func(data.z, device = device).float()

#for i, data in enumerate(input_data):
#input_data[i] = one_hot_degree(data, n_neighbors)
Expand Down