Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from deepmd.utils.tabulate import DPTabulate
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph

class DescrptSeA ():
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
Expand Down Expand Up @@ -275,6 +276,11 @@ def enable_compression(self,
table_extrapolate,
table_stride_1,
table_stride_2)

graph, _ = load_graph_def(model_file)
self.davg = get_tensor_by_name_from_graph(graph, 'descrpt_attr/t_avg')
self.dstd = get_tensor_by_name_from_graph(graph, 'descrpt_attr/t_std')



def build (self,
Expand Down
5 changes: 0 additions & 5 deletions deepmd/entrypoints/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,3 @@ def compress(
log.info("\n\n")
log.info("stage 2: freeze the model")
freeze(checkpoint_folder=checkpoint_folder, output=output, node_names=None)

# stage 3: transfer the model
log.info("\n\n")
log.info("stage 3: transfer the model")
transfer(old_model=input, raw_model=output, output=output)
1 change: 1 addition & 0 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def build (self,
else :
assert 'rcut' in self.descrpt_param, "Error: descriptor must have attr rcut!"
self.descrpt.enable_compression(self.model_param['compress']["min_nbor_dist"], self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3])
self.fitting.init_variables(get_fitting_net_variables(self.model_param['compress']['model_file']))

if self.is_compress or self.model_type == 'compressed_model':
tf.constant("compressed_model", name = 'model_type', dtype = tf.string)
Expand Down