Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions deepmd/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
"""

import logging
from deepmd.env import tf
from deepmd.env import op_module
from deepmd.env import tf, FITTING_NET_PATTERN
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import get_pattern_nodes_from_graph_def
from os.path import abspath

# load grad of force module
Expand All @@ -21,6 +21,36 @@

log = logging.getLogger(__name__)

def _transfer_fitting_net_trainable_variables(sess, old_graph_def, raw_graph_def):
old_pattern = FITTING_NET_PATTERN
raw_pattern = FITTING_NET_PATTERN\
.replace('idt', 'idt+_\d+')\
.replace('bias', 'bias+_\d+')\
.replace('matrix', 'matrix+_\d+')
old_graph_nodes = get_pattern_nodes_from_graph_def(
old_graph_def,
old_pattern
)
try :
raw_graph_def = tf.graph_util.convert_variables_to_constants(
sess, # The session is used to retrieve the weights
raw_graph_def, # The graph_def is used to retrieve the nodes
[n + '_1' for n in old_graph_nodes], # The output node names are used to select the usefull nodes
)
except AssertionError:
# if there's no additional nodes
return old_graph_def

raw_graph_nodes = get_pattern_nodes_from_graph_def(
raw_graph_def,
raw_pattern
)
for node in old_graph_def.node:
if node.name not in old_graph_nodes.keys():
continue
tensor = tf.make_ndarray(raw_graph_nodes[node.name + '_1'])
node.attr["value"].tensor.tensor_content = tensor.tostring()
return old_graph_def

def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> List[str]:
"""Get node names based on model type.
Expand Down Expand Up @@ -205,6 +235,13 @@ def freeze(
output_node_list, # The output node names are used to select the usefull nodes
)

# If we need to transfer the fitting net variables
output_graph_def = _transfer_fitting_net_trainable_variables(
sess,
output_graph_def,
input_graph_def
)

# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
Expand Down
21 changes: 2 additions & 19 deletions deepmd/entrypoints/transfer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module used for transfering parameters between models."""

from typing import Dict, Optional, Sequence, Tuple
from deepmd.env import tf
from deepmd.env import tf, TRANSFER_PATTERN
import re
import numpy as np
import logging
Expand Down Expand Up @@ -225,24 +225,7 @@ def load_transform_node(graph: tf.Graph) -> Dict[str, tf.Tensor]:
Dict[str, tf.Tensor]
mapping on graph node names and corresponding tensors
"""
transform_node_pattern = re.compile(
r"filter_type_\d+/matrix_\d+_\d+|"
r"filter_type_\d+/bias_\d+_\d+|"
r"filter_type_\d+/idt_\d+_\d+|"
r"layer_\d+_type_\d+/matrix|"
r"layer_\d+_type_\d+/bias|"
r"layer_\d+_type_\d+/idt|"
r"final_layer_type_\d+/matrix|"
r"descrpt_attr/t_avg|"
r"descrpt_attr/t_std|"
r"final_layer_type_\d+/bias|"
r"fitting_attr/t_fparam_avg|"
r"fitting_attr/t_fparam_istd|"
r"fitting_attr/t_aparam_avg|"
r"fitting_attr/t_aparam_istd|"
r"model_attr/t_tab_info|"
r"model_attr/t_tab_data|"
)
transform_node_pattern = re.compile(TRANSFER_PATTERN)

transform_node = {}
for node in graph.node:
Expand Down
36 changes: 36 additions & 0 deletions deepmd/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import os
import re
import platform
from configparser import ConfigParser
from imp import reload
Expand Down Expand Up @@ -35,10 +36,45 @@
"reset_default_tf_session_config",
"op_module",
"op_grads_module",
"TRANSFER_PATTERN",
"FITTING_NET_PATTERN",
"EMBEDDING_NET_PATTERN",
]

SHARED_LIB_MODULE = "op"

EMBEDDING_NET_PATTERN = str(
r"filter_type_\d+/matrix_\d+_\d+|"
r"filter_type_\d+/bias_\d+_\d+|"
r"filter_type_\d+/idt_\d+_\d+|"
r"filter_type_all/matrix_\d+_\d+|"
r"filter_type_all/matrix_\d+_\d+_\d+|"
r"filter_type_all/bias_\d+_\d+|"
r"filter_type_all/bias_\d+_\d+_\d+|"
r"filter_type_all/idt_\d+_\d+|"
)

FITTING_NET_PATTERN = str(
r"layer_\d+_type_\d+/matrix|"
r"layer_\d+_type_\d+/bias|"
r"layer_\d+_type_\d+/idt|"
r"final_layer_type_\d+/matrix|"
r"final_layer_type_\d+/bias|"
)

TRANSFER_PATTERN = \
EMBEDDING_NET_PATTERN + \
FITTING_NET_PATTERN + \
str(
r"descrpt_attr/t_avg|"
r"descrpt_attr/t_std|"
r"fitting_attr/t_fparam_avg|"
r"fitting_attr/t_fparam_istd|"
r"fitting_attr/t_aparam_avg|"
r"fitting_attr/t_aparam_istd|"
r"model_attr/t_tab_info|"
r"model_attr/t_tab_data|"
)

def set_env_if_empty(key: str, value: str, verbose: bool = True):
"""Set environment variable only if it is empty.
Expand Down
47 changes: 36 additions & 11 deletions deepmd/utils/graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
import numpy as np
from typing import Tuple, Dict
from deepmd.env import tf
from deepmd.env import tf, EMBEDDING_NET_PATTERN, FITTING_NET_PATTERN
from deepmd.utils.sess import run_sess
from deepmd.utils.errors import GraphWithoutTensorError

Expand Down Expand Up @@ -112,6 +112,30 @@ def get_tensor_by_type(node,
return tensor


def get_pattern_nodes_from_graph_def(graph_def: tf.GraphDef, pattern: str) -> Dict:
"""
Get the pattern nodes with the given tf.GraphDef object

Parameters
----------
graph_def
The input tf.GraphDef object
pattern
The node pattern within the graph_def

Returns
----------
Dict
The fitting net nodes within the given tf.GraphDef object
"""
nodes = {}
pattern = re.compile(pattern)
for node in graph_def.node:
if re.fullmatch(pattern, node.name) != None:
nodes[node.name] = node.attr["value"].tensor
return nodes


def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str = "") -> Dict:
"""
Get the embedding net nodes with the given tf.GraphDef object
Expand All @@ -128,11 +152,16 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str =
Dict
The embedding net nodes within the given tf.GraphDef object
"""
embedding_net_nodes = {}
embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+"
for node in graph_def.node:
if re.fullmatch(embedding_net_pattern, node.name) != None:
embedding_net_nodes[node.name] = node.attr["value"].tensor
# embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+"
if suffix is not "":
embedding_net_pattern = EMBEDDING_NET_PATTERN\
.replace('/idt', suffix + '/idt')\
.replace('/bias', suffix + '/bias')\
.replace('/matrix', suffix + '/matrix')
else:
embedding_net_pattern = EMBEDDING_NET_PATTERN

embedding_net_nodes = get_pattern_nodes_from_graph_def(graph_def, embedding_net_pattern)
for key in embedding_net_nodes.keys():
assert key.find('bias') > 0 or key.find(
'matrix') > 0, "currently, only support weight matrix and bias matrix at the tabulation op!"
Expand Down Expand Up @@ -222,11 +251,7 @@ def get_fitting_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict:
Dict
The fitting net nodes within the given tf.GraphDef object
"""
fitting_net_nodes = {}
fitting_net_pattern = "layer_\d+_type_\d+/matrix+|layer_\d+_type_\d+/bias+|layer_\d+_type_\d+/idt+|final_layer_type_\d+/matrix+|final_layer_type_\d+/bias"
for node in graph_def.node:
if re.fullmatch(fitting_net_pattern, node.name) != None:
fitting_net_nodes[node.name] = node.attr["value"].tensor
fitting_net_nodes = get_pattern_nodes_from_graph_def(graph_def, FITTING_NET_PATTERN)
for key in fitting_net_nodes.keys():
assert key.find('bias') > 0 or key.find('matrix') > 0 or key.find(
'idt') > 0, "currently, only support weight matrix, bias and idt at the model compression process!"
Expand Down