From fec85d48e79125bed797a691d571cf11a932e587 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 17 Apr 2021 20:39:49 +0800 Subject: [PATCH 1/3] printing the training error with exactly the batch used to train the model --- deepmd/entrypoints/train.py | 2 +- deepmd/train/trainer.py | 36 ++++++++++++++++++++++-------------- deepmd/utils/data_system.py | 5 ++--- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 31ee881d83..49fa3dd6c4 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -250,7 +250,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions): # init data train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier) train_data.print_summary("training") - if jdata["training"]["validation_data"] is not None: + if jdata["training"].get("validation_data", None) is not None: valid_data = get_data(jdata["training"]["validation_data"], rcut, ipt_type_map, modifier) valid_data.print_summary("validation") else: diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 91a7b3668f..a2510a4f49 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -258,7 +258,7 @@ def _init_param(self, jdata): self.valid_numb_batch = tr_data["validation_data"].get("numb_btch", 1) else: self.valid_numb_batch = 1 - + def build (self, data, @@ -477,12 +477,14 @@ def train (self, train_data, valid_data=None) : while cur_batch < stop_batch : # first round validation: + train_batch = train_data.get_batch() if self.display_in_training and is_first_step: - self.valid_on_the_fly(fp, train_data, valid_data, print_header=True) + valid_batch = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None + self.valid_on_the_fly(fp, [train_batch], valid_batch, print_header=True) is_first_step = False if self.timing_in_training: tic = time.time() - train_feed_dict = self.get_feed_dict(train_data.get_batch(), is_training=True) + train_feed_dict = self.get_feed_dict(train_batch, is_training=True) # use tensorboard to visualize the training of deepmd-kit # it will takes some extra execution time to generate the tensorboard data if self.tensorboard : @@ -501,7 +503,8 @@ def train (self, train_data, valid_data=None) : if self.display_in_training and (cur_batch % self.disp_freq == 0): if self.timing_in_training: tic = time.time() - self.valid_on_the_fly(fp, train_data, valid_data) + valid_batch = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None + self.valid_on_the_fly(fp, [train_batch], valid_batch) if self.timing_in_training: toc = time.time() test_time = toc - tic @@ -550,11 +553,11 @@ def get_global_step(self): def valid_on_the_fly(self, fp, - train_data, - valid_data, + train_batch, + valid_batch, print_header=False): - train_results = self.get_evaluation_results(train_data, self.valid_numb_batch) - valid_results = self.get_evaluation_results(valid_data, self.valid_numb_batch) + train_results = self.get_evaluation_results(train_batch) + valid_results = self.get_evaluation_results(valid_batch) cur_batch = self.cur_batch current_lr = self.sess.run(self.learning_rate) @@ -595,17 +598,22 @@ def print_on_training(fp, train_results, valid_results, cur_batch, cur_lr): fp.write(print_str) fp.flush() - def get_evaluation_results(self, data, numb_batch): - if data is None: return None + def get_evaluation_results(self, batch_list): + if batch_list is None: return None + numb_batch = len(batch_list) sum_results = {} # sum of losses on all atoms + sum_natoms = 0 for i in range(numb_batch): - batch = data.get_batch() + batch = batch_list[i] natoms = batch["natoms_vec"] feed_dict = self.get_feed_dict(batch, is_training=False) results = self.loss.eval(self.sess, feed_dict, natoms) - + for k, v in results.items(): - sum_results[k] = sum_results.get(k, 0.) + v * results["natoms"] - avg_results = {k: v / sum_results["natoms"] for k, v in sum_results.items() if not k == "natoms"} + if k == "natoms": + sum_natoms += v + else: + sum_results[k] = sum_results.get(k, 0.) + v * results["natoms"] + avg_results = {k: v / sum_natoms for k, v in sum_results.items() if not k == "natoms"} return avg_results diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 4cdb9fb50f..8b84319eb2 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -426,15 +426,14 @@ def print_summary(self, name) : log.info(f"---Summary of DataSystem: {name:13s}-----------------------------------------------") log.info("found %d system(s):" % self.nsystems) log.info(("%s " % self._format_name_length('system', sys_width)) + - ("%6s %6s %6s %6s %5s %3s" % ('natoms', 'bch_sz', 'n_bch', "n_test", 'prob', 'pbc'))) + ("%6s %6s %6s %5s %3s" % ('natoms', 'bch_sz', 'n_bch', 'prob', 'pbc'))) for ii in range(self.nsystems) : - log.info("%s %6d %6d %6d %6d %5.3f %3s" % + log.info("%s %6d %6d %6d %5.3f %3s" % (self._format_name_length(self.system_dirs[ii], sys_width), self.natoms[ii], # TODO batch size * nbatches = number of structures self.batch_size[ii], self.nbatches[ii], - self.test_size[ii], self.sys_probs[ii], "T" if self.data_systems[ii].pbc else "F" ) ) From a1d21b9735843af571b4fafd9e44a8ef0b1a4d41 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 17 Apr 2021 23:01:40 +0800 Subject: [PATCH 2/3] fix 2 spacing issues --- deepmd/train/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index a2510a4f49..6b04d45d23 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -258,7 +258,7 @@ def _init_param(self, jdata): self.valid_numb_batch = tr_data["validation_data"].get("numb_btch", 1) else: self.valid_numb_batch = 1 - + def build (self, data, @@ -609,7 +609,7 @@ def get_evaluation_results(self, batch_list): natoms = batch["natoms_vec"] feed_dict = self.get_feed_dict(batch, is_training=False) results = self.loss.eval(self.sess, feed_dict, natoms) - + for k, v in results.items(): if k == "natoms": sum_natoms += v From 8fc9bed592b790a201c92658ad3d376e2c669a2e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 17 Apr 2021 23:19:56 +0800 Subject: [PATCH 3/3] more understandable variable names --- deepmd/train/trainer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 6b04d45d23..5327094a16 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -479,8 +479,8 @@ def train (self, train_data, valid_data=None) : # first round validation: train_batch = train_data.get_batch() if self.display_in_training and is_first_step: - valid_batch = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None - self.valid_on_the_fly(fp, [train_batch], valid_batch, print_header=True) + valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None + self.valid_on_the_fly(fp, [train_batch], valid_batches, print_header=True) is_first_step = False if self.timing_in_training: tic = time.time() @@ -503,8 +503,8 @@ def train (self, train_data, valid_data=None) : if self.display_in_training and (cur_batch % self.disp_freq == 0): if self.timing_in_training: tic = time.time() - valid_batch = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None - self.valid_on_the_fly(fp, [train_batch], valid_batch) + valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None + self.valid_on_the_fly(fp, [train_batch], valid_batches) if self.timing_in_training: toc = time.time() test_time = toc - tic @@ -553,11 +553,11 @@ def get_global_step(self): def valid_on_the_fly(self, fp, - train_batch, - valid_batch, + train_batches, + valid_batches, print_header=False): - train_results = self.get_evaluation_results(train_batch) - valid_results = self.get_evaluation_results(valid_batch) + train_results = self.get_evaluation_results(train_batches) + valid_results = self.get_evaluation_results(valid_batches) cur_batch = self.cur_batch current_lr = self.sess.run(self.learning_rate)