Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions source/tests/test_data_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ def _setUp(self):
batch_size,
test_size,
rcut,
set_prefix=set_pfx,
run_opt=run_opt)
set_prefix=set_pfx)
data.add_dict(data_requirement)

# clear the default graph
Expand Down
3 changes: 1 addition & 2 deletions source/tests/test_data_modifier_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def _setUp(self):
batch_size,
test_size,
rcut,
set_prefix=set_pfx,
run_opt=run_opt)
set_prefix=set_pfx)
data.add_dict(data_requirement)

# clear the default graph
Expand Down
32 changes: 32 additions & 0 deletions source/tests/test_deepmd_data_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,38 @@ def test_get_batch(self):
), 0.0)



def test_prob_sys_size_1(self):
    """Check _prob_sys_size_ext with blocks 0:2 and 2:4 weighted 2:8.

    The block weights normalize to 0.2 and 0.8, and within each block the
    per-system probability must be proportional to its batch count.
    """
    ds = DeepmdDataSystem(self.sys_name, 1, 1, 2.0)
    prob = ds._prob_sys_size_ext("prob_sys_size; 0:2:2; 2:4:8")
    self.assertAlmostEqual(np.sum(prob), 1)
    self.assertAlmostEqual(np.sum(prob[0:2]), 0.2)
    self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)
    # the number of training sets is self.nset - 1; shift is the total
    # per-set size offset used when building the fixture data
    ntrain = self.nset - 1
    shift = np.sum(np.arange(ntrain))

    def nbatch(ii):
        # effective number of batches of system ii in the fixture
        return float(self.nframes[ii] * ntrain + shift)

    # within a block, probabilities scale with the batch counts
    self.assertAlmostEqual(prob[1] / prob[0], nbatch(1) / nbatch(0))
    self.assertAlmostEqual(prob[3] / prob[2], nbatch(3) / nbatch(2))


def test_prob_sys_size_2(self):
    """Check _prob_sys_size_ext with blocks 1:2 and 2:4 weighted 0.4:1.6.

    System 0 belongs to no block and must receive zero probability.

    NOTE: this method was previously also named ``test_prob_sys_size_1``,
    which silently shadowed the earlier test of the same name so only one
    of the two ever ran; renamed so both are discovered by the runner.
    """
    batch_size = 1
    test_size = 1
    ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
    prob = ds._prob_sys_size_ext("prob_sys_size; 1:2:0.4; 2:4:1.6")
    self.assertAlmostEqual(np.sum(prob), 1)
    self.assertAlmostEqual(np.sum(prob[1:2]), 0.2)
    self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)
    # the number of training sets is self.nset - 1; shift is the total
    # per-set size offset used when building the fixture data
    shift = np.sum(np.arange(self.nset - 1))
    self.assertAlmostEqual(prob[0], 0.0)
    self.assertAlmostEqual(prob[1], 0.2)
    self.assertAlmostEqual(
        prob[3] / prob[2],
        float(self.nframes[3] * (self.nset - 1) + shift)
        / float(self.nframes[2] * (self.nset - 1) + shift))


def _idx_map(self, target, idx_map, ndof):
natoms = len(idx_map)
target = target.reshape([-1, natoms, ndof])
Expand Down
110 changes: 82 additions & 28 deletions source/train/DataSystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__ (self,
rcut,
set_prefix = 'set',
shuffle_test = True,
run_opt = None,
type_map = None,
modifier = None) :
# init data
Expand Down Expand Up @@ -81,10 +80,6 @@ def __init__ (self,
warnings.warn("system %s required test size is larger than the size of the dataset %s (%d > %d)" % \
(self.system_dirs[ii], chk_ret[0], test_size, chk_ret[1]))

# print summary
if run_opt is not None:
self.print_summary(run_opt)


def _load_test(self, ntests = -1):
self.test_data = collections.defaultdict(list)
Expand Down Expand Up @@ -155,24 +150,57 @@ def reduce(self,
def get_data_dict(self) :
return self.data_systems[0].get_data_dict()


def _get_sys_probs(self,
sys_probs,
auto_prob_style) :
if sys_probs is None :
if auto_prob_style == "prob_uniform" :
prob = None
elif auto_prob_style == "prob_sys_size" :
prob = self.prob_nbatches
elif auto_prob_style[:14] == "prob_sys_size;" :
prob = self._prob_sys_size_ext(auto_prob_style)
else :
raise RuntimeError("unkown style " + auto_prob_style )
else :
prob = self._process_sys_probs(sys_probs)
return prob


def get_batch (self,
sys_idx = None,
sys_weights = None,
style = "prob_sys_size") :
sys_probs = None,
auto_prob_style = "prob_sys_size") :
"""
Get a batch of data from the data system

Parameters
----------
sys_idx: int
The index of system from which the batch is get.
If sys_idx is not None, `sys_probs` and `auto_prob_style` are ignored
If sys_idx is None, automatically determine the system according to `sys_probs` or `auto_prob_style`, see the following.
sys_probs: list of float
The probabilities of systems to get the batch.
Summation of positive elements of this list should be no greater than 1.
Element of this list can be negative, the probability of the corresponding system is determined automatically by the number of batches in the system.
auto_prob_style: str
Determine the probability of systems automatically. The method is assigned by this key and can be
- "prob_uniform" : the probabilities of all the systems are equal, namely 1.0/self.get_nsystems()
- "prob_sys_size" : the probability of a system is proportional to the number of batches in the system
- "prob_sys_size;stt_idx:end_idx:weight;stt_idx:end_idx:weight;..." :
the list of systems is divided into blocks. A block is specified by `stt_idx:end_idx:weight`,
where `stt_idx` is the starting index of the system, `end_idx` is the ending (not including) index of the system,
the probabilities of the systems in this block sum up to `weight`, and the relative probabilities within this block are proportional
to the number of batches in the system.
"""
if not hasattr(self, 'default_mesh') :
self._make_default_mesh()
if sys_idx is not None :
self.pick_idx = sys_idx
else :
if sys_weights is None :
if style == "prob_sys_size" :
prob = self.prob_nbatches
elif style == "prob_uniform" :
prob = None
else :
raise RuntimeError("unkown get_batch style")
else :
prob = self.process_sys_weights(sys_weights)
prob = self._get_sys_probs(sys_probs, auto_prob_style)
self.pick_idx = np.random.choice(np.arange(self.nsystems), p = prob)
b_data = self.data_systems[self.pick_idx].get_batch(self.batch_size[self.pick_idx])
b_data["natoms_vec"] = self.natoms_vec[self.pick_idx]
Expand Down Expand Up @@ -224,21 +252,26 @@ def _format_name_length(self, name, width) :
name = '-- ' + name
return name

def print_summary(self, run_opt) :
def print_summary(self,
run_opt,
sys_probs = None,
auto_prob_style = "prob_sys_size") :
prob = self._get_sys_probs(sys_probs, auto_prob_style)
tmp_msg = ""
# width 65
sys_width = 42
tmp_msg += "---Summary of DataSystem-----------------------------------------\n"
tmp_msg += "---Summary of DataSystem------------------------------------------------\n"
tmp_msg += "find %d system(s):\n" % self.nsystems
tmp_msg += "%s " % self._format_name_length('system', sys_width)
tmp_msg += "%s %s %s\n" % ('natoms', 'bch_sz', 'n_bch')
tmp_msg += "%s %s %s %5s\n" % ('natoms', 'bch_sz', 'n_bch', 'prob')
for ii in range(self.nsystems) :
tmp_msg += ("%s %6d %6d %5d\n" %
tmp_msg += ("%s %6d %6d %5d %5.3f\n" %
(self._format_name_length(self.system_dirs[ii], sys_width),
self.natoms[ii],
self.batch_size[ii],
self.nbatches[ii]) )
tmp_msg += "-----------------------------------------------------------------\n"
self.nbatches[ii],
prob[ii]) )
tmp_msg += "------------------------------------------------------------------------\n"
run_opt.message(tmp_msg)


Expand All @@ -264,18 +297,39 @@ def _check_type_map_consistency(self, type_map_list):
ret = ii
return ret

def _process_sys_weights(self, sys_weights) :
sys_weights = np.array(sys_weights)
type_filter = sys_weights >= 0
assigned_sum_prob = np.sum(type_filter * sys_weights)
def _process_sys_probs(self, sys_probs) :
sys_probs = np.array(sys_probs)
type_filter = sys_probs >= 0
assigned_sum_prob = np.sum(type_filter * sys_probs)
assert assigned_sum_prob <= 1, "the sum of assigned probability should be less than 1"
rest_sum_prob = 1. - assigned_sum_prob
rest_nbatch = (1 - type_filter) * self.nbatches
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
ret_prob = rest_prob + type_filter * sys_weights
ret_prob = rest_prob + type_filter * sys_probs
assert np.sum(ret_prob) == 1, "sum of probs should be 1"
return ret_prob


def _prob_sys_size_ext(self, keywords):
block_str = keywords.split(';')[1:]
block_stt = []
block_end = []
block_weights = []
for ii in block_str:
stt = int(ii.split(':')[0])
end = int(ii.split(':')[1])
weight = float(ii.split(':')[2])
assert(weight >= 0), "the weight of a block should be no less than 0"
block_stt.append(stt)
block_end.append(end)
block_weights.append(weight)
nblocks = len(block_str)
block_probs = np.array(block_weights) / np.sum(block_weights)
sys_probs = np.zeros([self.get_nsystems()])
for ii in range(nblocks):
nbatch_block = self.nbatches[block_stt[ii]:block_end[ii]]
tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
sys_probs[block_stt[ii]:block_end[ii]] = tmp_prob * block_probs[ii]
return sys_probs



Expand Down
10 changes: 7 additions & 3 deletions source/train/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def _init_param(self, jdata):
.add('timing_in_training', bool, default = True)\
.add('profiling', bool, default = False)\
.add('profiling_file',str, default = 'timeline.json')\
.add('sys_weights', list )
.add('sys_probs', list )\
.add('auto_prob_style', str, default = "prob_sys_size")
tr_data = tr_args.parse(training_param)
self.numb_test = tr_data['numb_test']
self.disp_file = tr_data['disp_file']
Expand All @@ -189,7 +190,8 @@ def _init_param(self, jdata):
self.timing_in_training = tr_data['timing_in_training']
self.profiling = tr_data['profiling']
self.profiling_file = tr_data['profiling_file']
self.sys_weights = tr_data['sys_weights']
self.sys_probs = tr_data['sys_probs']
self.auto_prob_style = tr_data['auto_prob_style']
self.useBN = False
if fitting_type == 'ener' and self.fitting.get_numb_fparam() > 0 :
self.numb_fparam = self.fitting.get_numb_fparam()
Expand Down Expand Up @@ -391,7 +393,9 @@ def train (self,

train_time = 0
while cur_batch < stop_batch :
batch_data = data.get_batch (sys_weights = self.sys_weights)
batch_data = data.get_batch (sys_probs = self.sys_probs,
auto_prob_style = self.auto_prob_style
)
feed_dict_batch = {}
for kk in batch_data.keys():
if kk == 'find_type' or kk == 'type' :
Expand Down
6 changes: 5 additions & 1 deletion source/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def _do_work(jdata, run_opt):
batch_size = j_must_have(jdata['training'], 'batch_size')
test_size = j_must_have(jdata['training'], 'numb_test')
stop_batch = j_must_have(jdata['training'], 'stop_batch')
sys_probs = jdata['training'].get('sys_probs')
auto_prob_style = jdata['training'].get('auto_prob_style', 'prob_sys_size')
if len(type_map) == 0:
# empty type_map
ipt_type_map = None
Expand All @@ -129,9 +131,11 @@ def _do_work(jdata, run_opt):
test_size,
rcut,
set_prefix=set_pfx,
run_opt=run_opt,
type_map = ipt_type_map,
modifier = modifier)
data.print_summary(run_opt,
sys_probs = sys_probs,
auto_prob_style = auto_prob_style)
data.add_dict(data_requirement)
# build the model with stats from the first system
model.build (data, stop_batch)
Expand Down