diff --git a/source/train/Trainer.py b/source/train/Trainer.py index b7112370ae..0023e4ed12 100644 --- a/source/train/Trainer.py +++ b/source/train/Trainer.py @@ -368,7 +368,8 @@ def _init_sess_distrib(self): # save_checkpoint_steps = self.save_freq) def train (self, - data) : + data, + valid_data) : stop_batch = self.stop_batch if self.run_opt.is_distrib : self._init_sess_distrib() @@ -417,7 +418,7 @@ def train (self, feed_dict_batch[self.place_holders['is_training']] = True if self.display_in_training and is_first_step : - self.test_on_the_fly(fp, data, feed_dict_batch) + self.test_on_the_fly(fp, valid_data, feed_dict_batch) is_first_step = False if self.timing_in_training : tic = time.time() self.sess.run([self.train_op], feed_dict = feed_dict_batch, options=prf_options, run_metadata=prf_run_metadata) @@ -428,7 +429,7 @@ def train (self, if self.display_in_training and (cur_batch % self.disp_freq == 0) : tic = time.time() - self.test_on_the_fly(fp, data, feed_dict_batch) + self.test_on_the_fly(fp, valid_data, feed_dict_batch) toc = time.time() test_time = toc - tic if self.timing_in_training : @@ -461,12 +462,12 @@ def print_head (self) : def test_on_the_fly (self, fp, - data, + valid_data, feed_dict_batch) : # ! altered by Marián Rynik # Do not need to pass numb_test here as data object already knows it. # Both DeepmdDataSystem and ClassArg parse the same json file - test_data = data.get_test(n_test=data.get_sys_ntest()) + test_data = valid_data.get_test(n_test=valid_data.get_sys_ntest()) feed_dict_test = {} for kk in test_data.keys(): if kk == 'find_type' or kk == 'type' : diff --git a/source/train/train.py b/source/train/train.py index e7978361b2..55357fa670 100755 --- a/source/train/train.py +++ b/source/train/train.py @@ -136,11 +136,30 @@ def _do_work(jdata, run_opt): sys_probs = sys_probs, auto_prob_style = auto_prob_style) data.add_dict(data_requirement) + + ### START modified by ziyao + + # init valid data + valid_data = data # use train data if no validation is specified + if 'valid' in jdata.keys(): + valid_systems = j_must_have(jdata['valid'], 'systems') + valid_data = DeepmdDataSystem(valid_systems, + 1, + test_size, + rcut, + set_prefix=set_pfx, + type_map=ipt_type_map, + modifier=modifier) + valid_data.print_summary(run_opt, + sys_probs=sys_probs, + auto_prob_style=auto_prob_style) + valid_data.add_dict(data_requirement) + # build the model with stats from the first system model.build (data, stop_batch) # train the model with the provided systems in a cyclic way start_time = time.time() - model.train (data) + model.train (data, valid_data) end_time = time.time() run_opt.message("finished training\nwall time: %.3f s" % (end_time-start_time))