From 0eade8bd206d4836e611323d7a551682f29d5f4e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 6 Sep 2021 01:16:52 -0400 Subject: [PATCH] merge duplicated NeighborStat.get_stat Note: this is a simple fix to resolve #1088, but I think we should design a clear architecture to call neighbor stat. This should cut the time roughly in half, but it may still be too long. We could consider a better algorithm to calculate neighbor stat (like KDtree?) for further optimization. --- deepmd/entrypoints/train.py | 12 +++++++++++- deepmd/train/trainer.py | 13 +++---------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index bbb1e55bd2..a91d0aa3cf 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Any from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have -from deepmd.env import tf, reset_default_tf_session_config +from deepmd.env import tf, reset_default_tf_session_config, GLOBAL_TF_FLOAT_PRECISION from deepmd.infer.data_modifier import DipoleChargeModifier from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions from deepmd.train.trainer import DPTrainer @@ -262,6 +262,16 @@ def get_nbor_stat(jdata, rcut): neistat = NeighborStat(ntypes, rcut) min_nbor_dist, max_nbor_size = neistat.get_stat(train_data) + + # moved from trainer.py as duplicated + # TODO: this is a simple fix but we should have a clear + # architecture to call neighbor stat + tf.constant(min_nbor_dist, + name = 'train_attr/min_nbor_dist', + dtype = GLOBAL_TF_FLOAT_PRECISION) + tf.constant(max_nbor_size, + name = 'train_attr/max_nbor_size', + dtype = GLOBAL_TF_FLOAT_PRECISION) return min_nbor_dist, max_nbor_size def get_sel(jdata, rcut): diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 0fa08e77a2..5b07c0439e 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -314,16 +314,9 
@@ def build (self, if self.run_opt.init_mode == 'init_from_frz_model': self._init_from_frz_model() - self.neighbor_stat \ - = NeighborStat(self.ntypes, self.descrpt.get_rcut()) - self.min_nbor_dist, self.max_nbor_size \ - = self.neighbor_stat.get_stat(data) - tf.constant(self.min_nbor_dist, - name = 'train_attr/min_nbor_dist', - dtype = GLOBAL_TF_FLOAT_PRECISION) - tf.constant(self.max_nbor_size, - name = 'train_attr/max_nbor_size', - dtype = GLOBAL_TF_FLOAT_PRECISION) + # neighbor_stat is moved to train.py as duplicated + # TODO: this is a simple fix but we should have a clear + # architecture to call neighbor stat else : self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3]) self.fitting.init_variables(get_fitting_net_variables(self.model_param['compress']['model_file']))