Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions source/train/Trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ def _init_sess_distrib(self):
# save_checkpoint_steps = self.save_freq)

def train (self,
data) :
data,
valid_data) :
stop_batch = self.stop_batch
if self.run_opt.is_distrib :
self._init_sess_distrib()
Expand Down Expand Up @@ -417,7 +418,7 @@ def train (self,
feed_dict_batch[self.place_holders['is_training']] = True

if self.display_in_training and is_first_step :
self.test_on_the_fly(fp, data, feed_dict_batch)
self.test_on_the_fly(fp, valid_data, feed_dict_batch)
is_first_step = False
if self.timing_in_training : tic = time.time()
self.sess.run([self.train_op], feed_dict = feed_dict_batch, options=prf_options, run_metadata=prf_run_metadata)
Expand All @@ -428,7 +429,7 @@ def train (self,

if self.display_in_training and (cur_batch % self.disp_freq == 0) :
tic = time.time()
self.test_on_the_fly(fp, data, feed_dict_batch)
self.test_on_the_fly(fp, valid_data, feed_dict_batch)
toc = time.time()
test_time = toc - tic
if self.timing_in_training :
Expand Down Expand Up @@ -461,12 +462,12 @@ def print_head (self) :

def test_on_the_fly (self,
fp,
data,
valid_data,
feed_dict_batch) :
# ! altered by Marián Rynik
# Do not need to pass numb_test here as data object already knows it.
# Both DeepmdDataSystem and ClassArg parse the same json file
test_data = data.get_test(n_test=data.get_sys_ntest())
test_data = valid_data.get_test(n_test=valid_data.get_sys_ntest())
feed_dict_test = {}
for kk in test_data.keys():
if kk == 'find_type' or kk == 'type' :
Expand Down
21 changes: 20 additions & 1 deletion source/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,30 @@ def _do_work(jdata, run_opt):
sys_probs = sys_probs,
auto_prob_style = auto_prob_style)
data.add_dict(data_requirement)

### START modified by ziyao

# init valid data
valid_data = data # use train data if no validation is specified
if 'valid' in jdata.keys():
valid_systems = j_must_have(jdata['valid'], 'systems')
valid_data = DeepmdDataSystem(valid_systems,
1,
test_size,
rcut,
set_prefix=set_pfx,
type_map=ipt_type_map,
modifier=modifier)
valid_data.print_summary(run_opt,
sys_probs=sys_probs,
auto_prob_style=auto_prob_style)
valid_data.add_dict(data_requirement)

# build the model with stats from the first system
model.build (data, stop_batch)
# train the model with the provided systems in a cyclic way
start_time = time.time()
model.train (data)
model.train (data, valid_data)
end_time = time.time()
run_opt.message("finished training\nwall time: %.3f s" % (end_time-start_time))