diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 69a6cbe112..09dcac2d8d 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -195,8 +195,7 @@ def __init__( assert isinstance(self.test_size, (list, np.ndarray)) assert len(self.test_size) == self.nsystems - # prob of batch, init pick idx - self.prob_nbatches = [float(i) for i in self.nbatches] / np.sum(self.nbatches) + # init pick idx self.pick_idx = 0 # derive system probabilities @@ -350,11 +349,13 @@ def set_sys_probs(self, sys_probs=None, auto_prob_style: str = "prob_sys_size"): if auto_prob_style == "prob_uniform": prob_v = 1.0 / float(self.nsystems) probs = [prob_v for ii in range(self.nsystems)] - elif auto_prob_style == "prob_sys_size": - probs = self.prob_nbatches - elif auto_prob_style[:14] == "prob_sys_size;": + elif auto_prob_style[:13] == "prob_sys_size": + if auto_prob_style == "prob_sys_size": + prob_style = f"prob_sys_size;0:{self.get_nsystems()}:1.0" + else: + prob_style = auto_prob_style probs = prob_sys_size_ext( - auto_prob_style, self.get_nsystems(), self.nbatches + prob_style, self.get_nsystems(), self.nbatches ) else: raise RuntimeError("Unknown auto prob style: " + auto_prob_style)