diff --git a/source/train/Data.py b/source/train/Data.py index f3ecaeeb62..149c76bbe8 100644 --- a/source/train/Data.py +++ b/source/train/Data.py @@ -97,6 +97,8 @@ def check_batch_size (self, batch_size) : tmpe = np.load(os.path.join(ii, "coord.npy")).astype(global_ener_float_precision) else: tmpe = np.load(os.path.join(ii, "coord.npy")).astype(global_np_float_precision) + if tmpe.ndim == 1: + tmpe = tmpe.reshape([1,-1]) if tmpe.shape[0] < batch_size : return ii, tmpe.shape[0] return None @@ -106,6 +108,8 @@ def check_test_size (self, test_size) : tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(global_ener_float_precision) else: tmpe = np.load(os.path.join(self.test_dir, "coord.npy")).astype(global_np_float_precision) + if tmpe.ndim == 1: + tmpe = tmpe.reshape([1,-1]) if tmpe.shape[0] < test_size : return self.test_dir, tmpe.shape[0] else : @@ -271,6 +275,8 @@ def _load_set(self, set_name) : coord = np.load(path).astype(global_ener_float_precision) else: coord = np.load(path).astype(global_np_float_precision) + if coord.ndim == 1: + coord = coord.reshape([1,-1]) nframes = coord.shape[0] assert(coord.shape[1] == self.data_dict['coord']['ndof'] * self.natoms) # load keys