diff --git a/matdeeplearn/models/torchmd_etEarly.py b/matdeeplearn/models/torchmd_etEarly.py index ca9685a2..7691a386 100644 --- a/matdeeplearn/models/torchmd_etEarly.py +++ b/matdeeplearn/models/torchmd_etEarly.py @@ -183,7 +183,7 @@ def _forward(self, data): #), "Distance module did not return directional information" if self.otf_edge_index == True: #data.edge_index, edge_weight, data.edge_vec, cell_offsets, offset_distance, neighbors = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) - data.edge_index, data.edge_weight, _, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) + data.edge_index, data.edge_weight, data.edge_vec, _, _, _ = self.generate_graph(data, self.cutoff_radius, self.n_neighbors) data.edge_attr = self.distance_expansion(data.edge_weight) #mask = data.edge_index[0] != data.edge_index[1]