diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index a690bbef4a..aa511bb41f 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -13,6 +13,7 @@ from deepmd.utils.tabulate import DPTabulate from deepmd.utils.type_embed import embed_atom_type from deepmd.utils.sess import run_sess +from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph class DescrptSeA (): @docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys())) @@ -275,6 +276,11 @@ def enable_compression(self, table_extrapolate, table_stride_1, table_stride_2) + + graph, _ = load_graph_def(model_file) + self.davg = get_tensor_by_name_from_graph(graph, 'descrpt_attr/t_avg') + self.dstd = get_tensor_by_name_from_graph(graph, 'descrpt_attr/t_std') + def build (self, diff --git a/deepmd/entrypoints/compress.py b/deepmd/entrypoints/compress.py index 5114d6410e..c689481b96 100644 --- a/deepmd/entrypoints/compress.py +++ b/deepmd/entrypoints/compress.py @@ -141,8 +141,3 @@ def compress( log.info("\n\n") log.info("stage 2: freeze the model") freeze(checkpoint_folder=checkpoint_folder, output=output, node_names=None) - - # stage 3: transfer the model - log.info("\n\n") - log.info("stage 3: transfer the model") - transfer(old_model=input, raw_model=output, output=output) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 49c6a5a537..8d3cfe2b42 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -317,6 +317,7 @@ def build (self, else : assert 'rcut' in self.descrpt_param, "Error: descriptor must have attr rcut!" self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3]) + self.fitting.init_variables(get_fitting_net_variables(self.model_param['compress']['model_file'])) if self.is_compress or self.model_type == 'compressed_model': tf.constant("compressed_model", name = 'model_type', dtype = tf.string)