diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index bbb1e55bd2..a91d0aa3cf 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Any from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have -from deepmd.env import tf, reset_default_tf_session_config +from deepmd.env import tf, reset_default_tf_session_config, GLOBAL_TF_FLOAT_PRECISION from deepmd.infer.data_modifier import DipoleChargeModifier from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions from deepmd.train.trainer import DPTrainer @@ -262,6 +262,16 @@ def get_nbor_stat(jdata, rcut): neistat = NeighborStat(ntypes, rcut) min_nbor_dist, max_nbor_size = neistat.get_stat(train_data) + + # moved from traier.py as duplicated + # TODO: this is a simple fix but we should have a clear + # architecture to call neighbor stat + tf.constant(min_nbor_dist, + name = 'train_attr/min_nbor_dist', + dtype = GLOBAL_TF_FLOAT_PRECISION) + tf.constant(max_nbor_size, + name = 'train_attr/max_nbor_size', + dtype = GLOBAL_TF_FLOAT_PRECISION) return min_nbor_dist, max_nbor_size def get_sel(jdata, rcut): diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 0fa08e77a2..5b07c0439e 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -314,16 +314,9 @@ def build (self, if self.run_opt.init_mode == 'init_from_frz_model': self._init_from_frz_model() - self.neighbor_stat \ - = NeighborStat(self.ntypes, self.descrpt.get_rcut()) - self.min_nbor_dist, self.max_nbor_size \ - = self.neighbor_stat.get_stat(data) - tf.constant(self.min_nbor_dist, - name = 'train_attr/min_nbor_dist', - dtype = GLOBAL_TF_FLOAT_PRECISION) - tf.constant(self.max_nbor_size, - name = 'train_attr/max_nbor_size', - dtype = GLOBAL_TF_FLOAT_PRECISION) + # neighbor_stat is moved to train.py as duplicated + # TODO: this is a simple fix but we should have a clear + # architecture to call neighbor stat else : self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3]) self.fitting.init_variables(get_fitting_net_variables(self.model_param['compress']['model_file']))