diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py
index 2c5f668589..35317c3c42 100644
--- a/deepmd/train/trainer.py
+++ b/deepmd/train/trainer.py
@@ -5,6 +5,10 @@
 import platform
 import shutil
 import time
+from typing import (
+    Dict,
+    List,
+)
 
 import google.protobuf.message
 import numpy as np
@@ -57,6 +61,9 @@ from deepmd.utils.argcheck import (
     type_embedding_args,
 )
+from deepmd.utils.data_system import (
+    DeepmdDataSystem,
+)
 from deepmd.utils.errors import (
     GraphTooLargeError,
     GraphWithoutTensorError,
 )
@@ -903,19 +910,55 @@ def train(self, train_data=None, valid_data=None):
         train_time = 0
         total_train_time = 0.0
+        wall_time_tic = time.time()
+
+        next_batch_train_op = None
+        next_fitting_key = None
+        next_train_batch_list = None
+        next_datasetloader = None
+
+        # dataset loader op
+        if not self.multi_task_mode:
+            datasetloader = DatasetLoader(train_data)
+            data_op = datasetloader.build()
+        else:
+            datasetloader = {}
+            data_op = {}
+            for fitting_key in self.fitting_type_dict:
+                datasetloader[fitting_key] = DatasetLoader(train_data[fitting_key])
+                data_op[fitting_key] = datasetloader[fitting_key].build()
 
         while cur_batch < stop_batch:
 
             # first round validation:
+            if is_first_step:
+                if not self.multi_task_mode:
+                    train_batch = train_data.get_batch()
+                    batch_train_op = self.train_op
+                else:
+                    fitting_idx = dp_random.choice(
+                        np.arange(self.nfitting), p=np.array(self.fitting_prob)
+                    )
+                    fitting_key = self.fitting_key_list[fitting_idx]
+                    train_batch = train_data[fitting_key].get_batch()
+                    batch_train_op = self.train_op[fitting_key]
+            else:
+                train_batch = next_datasetloader.get_data_dict(next_train_batch_list)
+                batch_train_op = next_batch_train_op
+                fitting_key = next_fitting_key
+            # for next round
             if not self.multi_task_mode:
-                train_batch = train_data.get_batch()
-                batch_train_op = self.train_op
+                next_datasetloader = datasetloader
+                next_batch_train_op = self.train_op
+                next_train_batch_op = data_op
             else:
                 fitting_idx = dp_random.choice(
                     np.arange(self.nfitting), p=np.array(self.fitting_prob)
                 )
-                fitting_key = self.fitting_key_list[fitting_idx]
-                train_batch = train_data[fitting_key].get_batch()
-                batch_train_op = self.train_op[fitting_key]
+                next_fitting_key = self.fitting_key_list[fitting_idx]
+                next_datasetloader = datasetloader[next_fitting_key]
+                next_batch_train_op = self.train_op[next_fitting_key]
+                next_train_batch_op = data_op[next_fitting_key]
+
             if self.display_in_training and is_first_step:
                 if self.run_opt.is_chief:
                     if not self.multi_task_mode:
@@ -964,18 +1007,18 @@
             # use tensorboard to visualize the training of deepmd-kit
             # it will takes some extra execution time to generate the tensorboard data
             if self.tensorboard and (cur_batch % self.tensorboard_freq == 0):
-                summary, _ = run_sess(
+                summary, _, next_train_batch_list = run_sess(
                     self.sess,
-                    [summary_merged_op, batch_train_op],
+                    [summary_merged_op, batch_train_op, next_train_batch_op],
                     feed_dict=train_feed_dict,
                     options=prf_options,
                     run_metadata=prf_run_metadata,
                 )
                 tb_train_writer.add_summary(summary, cur_batch)
             else:
-                run_sess(
+                _, next_train_batch_list = run_sess(
                     self.sess,
-                    [batch_train_op],
+                    [batch_train_op, next_train_batch_op],
                     feed_dict=train_feed_dict,
                     options=prf_options,
                     run_metadata=prf_run_metadata,
                 )
@@ -1025,14 +1068,16 @@
                     if self.timing_in_training:
                         toc = time.time()
                         test_time = toc - tic
+                        wall_time = toc - wall_time_tic
                         log.info(
-                            "batch %7d training time %.2f s, testing time %.2f s"
-                            % (cur_batch, train_time, test_time)
+                            "batch %7d training time %.2f s, testing time %.2f s, total wall time %.2f s"
+                            % (cur_batch, train_time, test_time, wall_time)
                         )
                         # the first training time is not accurate
                         if cur_batch > self.disp_freq or stop_batch < 2 * self.disp_freq:
                             total_train_time += train_time
                         train_time = 0
+                        wall_time_tic = toc
                 if (
                     self.save_freq > 0
                     and cur_batch % self.save_freq == 0
@@ -1405,3 +1450,64 @@ def _change_energy_bias(
             bias_shift=bias_shift,
             ntest=self.model_param.get("data_bias_nsample", 10),
         )
+
+
+class DatasetLoader:
+    """Generate an OP that loads the training data from the given DeepmdDataSystem.
+
+    It can be used to load the training data during the training process, so there
+    is no waiting time between training steps.
+
+    Parameters
+    ----------
+    train_data : DeepmdDataSystem
+        The training data.
+
+    Examples
+    --------
+    >>> loader = DatasetLoader(train_data)
+    >>> data_op = loader.build()
+    >>> with tf.Session() as sess:
+    ...     data_list = sess.run(data_op)
+    >>> data_dict = loader.get_data_dict(data_list)
+    """
+
+    def __init__(self, train_data: DeepmdDataSystem):
+        self.train_data = train_data
+        # probe one batch to record the keys and dtypes of the data
+        batch_data = self.train_data.get_batch()
+        self.data_keys = batch_data.keys()
+        self.data_types = [tf.as_dtype(x.dtype) for x in batch_data.values()]
+
+    def build(self) -> List[tf.Tensor]:
+        """Build the OP that loads the training data.
+
+        Returns
+        -------
+        List[tf.Tensor]
+            Tensors of the loaded data.
+        """
+        train_data = self.train_data
+
+        def get_train_batch() -> List[np.ndarray]:
+            batch_data = train_data.get_batch()
+            # convert dict to a list of arrays, ordered by the recorded keys
+            batch_data = tuple([batch_data[kk] for kk in self.data_keys])
+            return batch_data
+
+        return tf.py_func(get_train_batch, [], self.data_types, name="train_data")
+
+    def get_data_dict(self, batch_list: List[np.ndarray]) -> Dict[str, np.ndarray]:
+        """Generate a dict of the loaded data.
+
+        Parameters
+        ----------
+        batch_list : List[np.ndarray]
+            The loaded data.
+
+        Returns
+        -------
+        Dict[str, np.ndarray]
+            The dict of the loaded data.
+        """
+        return {kk: vv for kk, vv in zip(self.data_keys, batch_list)}
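
Review note: the core pattern of this change is that each `sess.run` fetches the training OP for the current batch together with the `tf.py_func`-backed loader OP for the next batch, so the loader runs concurrently with the gradient step instead of blocking between steps. Below is a minimal, self-contained sketch of that pattern outside deepmd, assuming TF1-style graph mode; `toy_get_batch` and the linear model are illustrative stand-ins, not deepmd code.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def toy_get_batch():
    # Stand-in for DeepmdDataSystem.get_batch(): returns one batch of arrays.
    coord = np.random.rand(4, 3)
    energy = np.random.rand(4, 1)
    return coord, energy


# Wrap the Python loader in a graph OP, as DatasetLoader.build() does.
# py_func OPs are stateful, so every fetch re-executes the loader.
next_batch_op = tf.py_func(
    toy_get_batch, [], [tf.float64, tf.float64], name="toy_train_data"
)

# A toy linear model standing in for the deepmd training graph.
coord_ph = tf.placeholder(tf.float64, [None, 3])
energy_ph = tf.placeholder(tf.float64, [None, 1])
w = tf.Variable(tf.zeros([3, 1], dtype=tf.float64))
loss = tf.reduce_mean((tf.matmul(coord_ph, w) - energy_ph) ** 2)
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    coord, energy = toy_get_batch()  # first batch loaded eagerly (is_first_step)
    for step in range(10):
        # One sess.run applies the gradient step on the current batch and,
        # since the loader OP is independent of the training subgraph,
        # loads the next batch concurrently.
        _, (coord, energy) = sess.run(
            [train_op, next_batch_op],
            feed_dict={coord_ph: coord, energy_ph: energy},
        )
```

Because `tf.py_func` executes its Python callable on one of the session's worker threads, the overlap pays off to the extent that the loader spends its time outside the GIL (NumPy, file I/O). That matches the intent of the patch: it hides data-loading latency behind training rather than making the loader itself faster, which is why the log line gains a "total wall time" field to make the saved waiting time visible.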