From c13bc1dfdc00d08a3abb28661a32a731de016b05 Mon Sep 17 00:00:00 2001 From: Shaochen Shi Date: Thu, 19 Aug 2021 15:26:08 +0800 Subject: [PATCH] Add the argument `tensorboard_freq` to control sampling ratio during training. --- deepmd/train/trainer.py | 5 +++-- deepmd/utils/argcheck.py | 2 ++ doc/train-input-auto.rst | 7 +++++++ doc/train/tensorboard.md | 3 ++- 4 files changed, 14 insertions(+), 3 deletions(-) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index d000becece..62545281ba 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -257,6 +257,7 @@ def _init_param(self, jdata): self.profiling_file = tr_data.get('profiling_file', 'timeline.json') self.tensorboard = self.run_opt.is_chief and tr_data.get('tensorboard', False) self.tensorboard_log_dir = tr_data.get('tensorboard_log_dir', 'log') + self.tensorboard_freq = tr_data.get('tensorboard_freq', 1) # self.sys_probs = tr_data['sys_probs'] # self.auto_prob_style = tr_data['auto_prob'] self.useBN = False @@ -475,11 +476,11 @@ def train (self, train_data = None, valid_data=None) : train_feed_dict = self.get_feed_dict(train_batch, is_training=True) # use tensorboard to visualize the training of deepmd-kit # it will takes some extra execution time to generate the tensorboard data - if self.tensorboard : + if self.tensorboard and (cur_batch % self.tensorboard_freq == 0): summary, _ = run_sess(self.sess, [summary_merged_op, self.train_op], feed_dict=train_feed_dict, options=prf_options, run_metadata=prf_run_metadata) tb_train_writer.add_summary(summary, cur_batch) - else : + else: run_sess(self.sess, [self.train_op], feed_dict=train_feed_dict, options=prf_options, run_metadata=prf_run_metadata) if self.timing_in_training: toc = time.time() diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 8b59438003..6e69481305 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -571,6 +571,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. doc_profiling_file = 'Output file for profiling.' doc_tensorboard = 'Enable tensorboard' doc_tensorboard_log_dir = 'The log directory of tensorboard outputs' + doc_tensorboard_freq = 'The frequency of writing tensorboard events.' arg_training_data = training_data_args() arg_validation_data = validation_data_args() @@ -591,6 +592,7 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. Argument("profiling_file", str, optional=True, default='timeline.json', doc=doc_profiling_file), Argument("tensorboard", bool, optional=True, default=False, doc=doc_tensorboard), Argument("tensorboard_log_dir", str, optional=True, default='log', doc=doc_tensorboard_log_dir), + Argument("tensorboard_freq", int, optional=True, default=1, doc=doc_tensorboard_freq), ] doc_training = 'The training options.' diff --git a/doc/train-input-auto.rst b/doc/train-input-auto.rst index cf9063d7ca..0add17d992 100644 --- a/doc/train-input-auto.rst +++ b/doc/train-input-auto.rst @@ -1485,3 +1485,10 @@ training: The log directory of tensorboard outputs + .. _`training/tensorboard_freq`: + + tensorboard_freq: + | type: ``int``, optional, default: ``1`` + | argument path: ``training/tensorboard_freq`` + + The frequency of writing tensorboard events. diff --git a/doc/train/tensorboard.md b/doc/train/tensorboard.md index 64f7cb344f..aa92bfaaab 100644 --- a/doc/train/tensorboard.md +++ b/doc/train/tensorboard.md @@ -43,6 +43,7 @@ subsection will enable the tensorboard data analysis. eg. **water_se_a.json**. "time_training":true, "tensorboard": true, "tensorboard_log_dir":"log", + "tensorboard_freq": 1000, "profiling": false, "profiling_file":"timeline.json", "_comment": "that's all" @@ -53,7 +54,7 @@ Once you have event files, run TensorBoard and provide the log directory. This should print that TensorBoard has started. Next, connect to http://tensorboard_server_ip:6006. TensorBoard requires a logdir to read logs from. For info on configuring TensorBoard, run tensorboard --help. -One can easily change the log name with "tensorboard_log_dir". +One can easily change the log name with "tensorboard_log_dir" and the sampling frequency with "tensorboard_freq". ```bash tensorboard --logdir path/to/logs