diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 1eb2941ab0..31ee881d83 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -249,10 +249,10 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions): # init data train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier) - train_data.print_summary() - if "validation_data" in jdata["training"]: + train_data.print_summary("training") + if jdata["training"]["validation_data"] is not None: valid_data = get_data(jdata["training"]["validation_data"], rcut, ipt_type_map, modifier) - valid_data.print_summary() + valid_data.print_summary("validation") else: valid_data = None @@ -309,4 +309,4 @@ def get_modifier(modi_data=None): raise RuntimeError("unknown modifier type " + str(modi_data["type"])) else: modifier = None - return modifier \ No newline at end of file + return modifier diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 6912346525..91a7b3668f 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -235,35 +235,17 @@ def _init_param(self, jdata): raise RuntimeError('get unknown fitting type when building loss function') # training - training_param = j_must_have(jdata, 'training') - - # ! first .add() altered by Marián Rynik - tr_args = ClassArg()\ - .add('disp_file', str, default = 'lcurve.out')\ - .add('disp_freq', int, default = 100)\ - .add('save_freq', int, default = 1000)\ - .add('save_ckpt', str, default = 'model.ckpt')\ - .add('display_in_training', bool, default = True)\ - .add('timing_in_training', bool, default = True)\ - .add('profiling', bool, default = False)\ - .add('profiling_file',str, default = 'timeline.json')\ - .add('tensorboard', bool, default = False)\ - .add('tensorboard_log_dir',str, default = 'log') - # .add('sys_probs', list )\ - # .add('auto_prob', str, default = "prob_sys_size") - tr_data = tr_args.parse(training_param) - # not needed - # self.numb_test = tr_data['numb_test'] - self.disp_file = tr_data['disp_file'] - self.disp_freq = tr_data['disp_freq'] - self.save_freq = tr_data['save_freq'] - self.save_ckpt = tr_data['save_ckpt'] - self.display_in_training = tr_data['display_in_training'] - self.timing_in_training = tr_data['timing_in_training'] - self.profiling = tr_data['profiling'] - self.profiling_file = tr_data['profiling_file'] - self.tensorboard = tr_data['tensorboard'] - self.tensorboard_log_dir = tr_data['tensorboard_log_dir'] + tr_data = jdata['training'] + self.disp_file = tr_data.get('disp_file', 'lcurve.out') + self.disp_freq = tr_data.get('disp_freq', 1000) + self.save_freq = tr_data.get('save_freq', 1000) + self.save_ckpt = tr_data.get('save_ckpt', 'model.ckpt') + self.display_in_training = tr_data.get('disp_training', True) + self.timing_in_training = tr_data.get('time_training', True) + self.profiling = tr_data.get('profiling', False) + self.profiling_file = tr_data.get('profiling_file', 'timeline.json') + self.tensorboard = tr_data.get('tensorboard', False) + self.tensorboard_log_dir = tr_data.get('tensorboard_log_dir', 'log') # self.sys_probs = tr_data['sys_probs'] # self.auto_prob_style = tr_data['auto_prob'] self.useBN = False @@ -272,11 +254,12 @@ def _init_param(self, jdata): else : self.numb_fparam = 0 - if "validation_data" in tr_data.keys(): # if validation set specified + if tr_data.get("validation_data", None) is not None: self.valid_numb_batch = tr_data["validation_data"].get("numb_btch", 1) else: self.valid_numb_batch = 1 + def build (self, data, stop_batch = 0) : @@ -622,7 +605,7 @@ def get_evaluation_results(self, data, numb_batch): feed_dict = self.get_feed_dict(batch, is_training=False) results = self.loss.eval(self.sess, feed_dict, natoms) - for k, v in results: + for k, v in results.items(): sum_results[k] = sum_results.get(k, 0.) + v * results["natoms"] avg_results = {k: v / sum_results["natoms"] for k, v in sum_results.items() if not k == "natoms"} - return avg_results \ No newline at end of file + return avg_results diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 728568c1da..4cdb9fb50f 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -420,10 +420,10 @@ def _format_name_length(self, name, width) : name = '-- ' + name return name - def print_summary(self) : + def print_summary(self, name) : # width 65 sys_width = 42 - log.info("---Summary of DataSystem--------------------------------------------------------------") + log.info(f"---Summary of DataSystem: {name:13s}-----------------------------------------------") log.info("found %d system(s):" % self.nsystems) log.info(("%s " % self._format_name_length('system', sys_width)) + ("%6s %6s %6s %6s %5s %3s" % ('natoms', 'bch_sz', 'n_bch', "n_test", 'prob', 'pbc')))