Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 45 additions & 56 deletions deepmd/utils/data_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,28 +353,15 @@ def set_sys_probs(self, sys_probs=None, auto_prob_style: str = "prob_sys_size"):
elif auto_prob_style == "prob_sys_size":
probs = self.prob_nbatches
elif auto_prob_style[:14] == "prob_sys_size;":
probs = self._prob_sys_size_ext(auto_prob_style)
probs = prob_sys_size_ext(
auto_prob_style, self.get_nsystems(), self.nbatches
)
else:
raise RuntimeError("Unknown auto prob style: " + auto_prob_style)
else:
probs = self._process_sys_probs(sys_probs)
probs = process_sys_probs(sys_probs, self.nbatches)
self.sys_probs = probs

def _get_sys_probs(self, sys_probs, auto_prob_style): # depreciated
if sys_probs is None:
if auto_prob_style == "prob_uniform":
prob_v = 1.0 / float(self.nsystems)
prob = [prob_v for ii in range(self.nsystems)]
elif auto_prob_style == "prob_sys_size":
prob = self.prob_nbatches
elif auto_prob_style[:14] == "prob_sys_size;":
prob = self._prob_sys_size_ext(auto_prob_style)
else:
raise RuntimeError("unknown style " + auto_prob_style)
else:
prob = self._process_sys_probs(sys_probs)
return prob

def get_batch(self, sys_idx: Optional[int] = None) -> dict:
# batch generation style altered by Ziyao Li:
# one should specify the "sys_prob" and "auto_prob_style" params
Expand Down Expand Up @@ -623,42 +610,44 @@ def _check_type_map_consistency(self, type_map_list):
ret = ii
return ret

def _process_sys_probs(self, sys_probs):
sys_probs = np.array(sys_probs)
type_filter = sys_probs >= 0
assigned_sum_prob = np.sum(type_filter * sys_probs)
# 1e-8 is to handle floating point error; See #1917
assert (
assigned_sum_prob <= 1.0 + 1e-8
), "the sum of assigned probability should be less than 1"
rest_sum_prob = 1.0 - assigned_sum_prob
if not np.isclose(rest_sum_prob, 0):
rest_nbatch = (1 - type_filter) * self.nbatches
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
ret_prob = rest_prob + type_filter * sys_probs
else:
ret_prob = sys_probs
assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
return ret_prob

def _prob_sys_size_ext(self, keywords):
block_str = keywords.split(";")[1:]
block_stt = []
block_end = []
block_weights = []
for ii in block_str:
stt = int(ii.split(":")[0])
end = int(ii.split(":")[1])
weight = float(ii.split(":")[2])
assert weight >= 0, "the weight of a block should be no less than 0"
block_stt.append(stt)
block_end.append(end)
block_weights.append(weight)
nblocks = len(block_str)
block_probs = np.array(block_weights) / np.sum(block_weights)
sys_probs = np.zeros([self.get_nsystems()])
for ii in range(nblocks):
nbatch_block = self.nbatches[block_stt[ii] : block_end[ii]]
tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii]
return sys_probs

def process_sys_probs(sys_probs, nbatch):
sys_probs = np.array(sys_probs)
type_filter = sys_probs >= 0
assigned_sum_prob = np.sum(type_filter * sys_probs)
# 1e-8 is to handle floating point error; See #1917
assert (
assigned_sum_prob <= 1.0 + 1e-8
), "the sum of assigned probability should be less than 1"
rest_sum_prob = 1.0 - assigned_sum_prob
if not np.isclose(rest_sum_prob, 0):
rest_nbatch = (1 - type_filter) * nbatch
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
ret_prob = rest_prob + type_filter * sys_probs
else:
ret_prob = sys_probs
assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
return ret_prob


def prob_sys_size_ext(keywords, nsystems, nbatch):
block_str = keywords.split(";")[1:]
block_stt = []
block_end = []
block_weights = []
for ii in block_str:
stt = int(ii.split(":")[0])
end = int(ii.split(":")[1])
weight = float(ii.split(":")[2])
assert weight >= 0, "the weight of a block should be no less than 0"
block_stt.append(stt)
block_end.append(end)
block_weights.append(weight)
nblocks = len(block_str)
block_probs = np.array(block_weights) / np.sum(block_weights)
sys_probs = np.zeros([nsystems])
for ii in range(nblocks):
nbatch_block = nbatch[block_stt[ii] : block_end[ii]]
tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii]
return sys_probs
9 changes: 7 additions & 2 deletions source/tests/test_deepmd_data_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
prob_sys_size_ext,
)

if GLOBAL_NP_FLOAT_PRECISION == np.float32:
Expand Down Expand Up @@ -310,7 +311,9 @@ def test_prob_sys_size_1(self):
batch_size = 1
test_size = 1
ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
prob = ds._prob_sys_size_ext("prob_sys_size; 0:2:2; 2:4:8")
prob = prob_sys_size_ext(
"prob_sys_size; 0:2:2; 2:4:8", ds.get_nsystems(), ds.get_nbatches()
)
self.assertAlmostEqual(np.sum(prob), 1)
self.assertAlmostEqual(np.sum(prob[0:2]), 0.2)
self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)
Expand All @@ -332,7 +335,9 @@ def test_prob_sys_size_2(self):
batch_size = 1
test_size = 1
ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
prob = ds._prob_sys_size_ext("prob_sys_size; 1:2:0.4; 2:4:1.6")
prob = prob_sys_size_ext(
"prob_sys_size; 1:2:0.4; 2:4:1.6", ds.get_nsystems(), ds.get_nbatches()
)
self.assertAlmostEqual(np.sum(prob), 1)
self.assertAlmostEqual(np.sum(prob[1:2]), 0.2)
self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)
Expand Down