diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 57bea00fac..423745cddf 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -265,8 +265,6 @@ def get_batch(self, batch_size: int) -> dict: self._load_batch_set(self.train_dirs[self.set_count % self.get_numb_set()]) self.set_count += 1 set_size = self.batch_set["coord"].shape[0] - if self.modifier is not None: - self.modifier.modify_data(self.batch_set, self) iterator_1 = self.iterator + batch_size if iterator_1 >= set_size: iterator_1 = set_size @@ -410,6 +408,8 @@ def _get_subdata(self, data, idx=None): def _load_batch_set(self, set_name: DPPath): if not hasattr(self, "batch_set") or self.get_numb_set() > 1: self.batch_set = self._load_set(set_name) + if self.modifier is not None: + self.modifier.modify_data(self.batch_set, self) self.batch_set, _ = self._shuffle_data(self.batch_set) self.reset_get_batch()