Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 56 additions & 33 deletions source/train/transform.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from deepmd.env import tf
import re
import numpy as np
def transform(args):
    """Pass the trained tensors of an old model into a raw model and save it.

    Parameters
    ----------
    args
        Object exposing ``raw_model``, ``old_model`` and ``output`` attributes
        (paths of the raw frozen model, the old frozen model and the file the
        resulting model is written to).
    """
    raw_graph = load_graph(args.raw_model)
    old_graph = load_graph(args.old_model)
    print("%d ops in the raw graph\n%d ops in the old graph" %(len(raw_graph.as_graph_def().node),len(old_graph.as_graph_def().node)))
    # transform_graph returns the raw graph's GraphDef with the old values
    # copied in; it is serialized and written exactly once.
    new_graph_def = transform_graph(raw_graph,old_graph)
    with tf.gfile.GFile(args.output, mode='wb') as f:
        f.write(new_graph_def.SerializeToString())
    print("the output model is saved in %s" % args.output)

def load_graph(graphName):
Expand All @@ -19,30 +16,56 @@ def load_graph(graphName):
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def,name = "")
return graph_def

def load_data(new_graph,old_graph):
    """Collect the transferable nodes of both graphs and validate them.

    Parameters
    ----------
    new_graph, old_graph
        Graph definitions accepted by ``load_transform_node``.

    Returns
    -------
    dict
        The transferable nodes of ``old_graph``, keyed by node name.

    Raises
    ------
    RuntimeError
        If the two graphs do not expose the same set of transferable node
        names, or if ``check_dim`` / ``check_precision`` reject a node.
    """
    new_graph_node = load_transform_node(new_graph)
    old_graph_node = load_transform_node(old_graph)
    # Compare the name sets rather than only their sizes: two graphs with the
    # same number of differently named nodes would otherwise fail below with
    # an uninformative KeyError.
    if new_graph_node.keys() != old_graph_node.keys():
        raise RuntimeError("New graph and original graph has different network structure\n")
    for nodeName in old_graph_node:
        check_dim(new_graph_node, old_graph_node, nodeName)
        check_precision(new_graph_node, old_graph_node, nodeName)
    return old_graph_node


def check_precision(new_graph_node, old_graph_node, nodeName):
    """Raise if node ``nodeName`` has a different tensor dtype in the two graphs.

    Parameters
    ----------
    new_graph_node, old_graph_node
        Mappings from node name to a node whose ``attr["value"].tensor.dtype``
        is compared.
    nodeName
        Name of the node to compare.

    Raises
    ------
    RuntimeError
        If the two dtypes differ.
    """
    new_graph_precision = new_graph_node[nodeName].attr["value"].tensor.dtype
    old_graph_precision = old_graph_node[nodeName].attr["value"].tensor.dtype
    if new_graph_precision != old_graph_precision:
        # A space before the node name keeps the message readable.
        raise RuntimeError("New graph and original graph has different "+nodeName+" precision\n")

def check_dim(new_graph_node, old_graph_node, nodeName):
    """Raise if node ``nodeName`` has a different tensor shape in the two graphs.

    Parameters
    ----------
    new_graph_node, old_graph_node
        Mappings from node name to a node whose
        ``attr["value"].tensor.tensor_shape`` is compared.
    nodeName
        Name of the node to compare.

    Raises
    ------
    RuntimeError
        If the two shapes differ.
    """
    new_graph_dim = new_graph_node[nodeName].attr["value"].tensor.tensor_shape
    old_graph_dim = old_graph_node[nodeName].attr["value"].tensor.tensor_shape
    if new_graph_dim != old_graph_dim:
        # A space before the node name keeps the message readable.
        raise RuntimeError("New graph and original graph has different "+nodeName+" dim\n")
return graph

def transform_graph(raw_graph,old_graph):
    """Copy the transferable tensors of ``old_graph`` into ``raw_graph``.

    Every node selected by ``load_transform_node`` is overwritten in the raw
    graph definition with the value stored in the old graph, cast to the
    floating-point type the raw graph uses for that node.

    Parameters
    ----------
    raw_graph, old_graph
        Objects providing ``as_graph_def()``.

    Returns
    -------
    The graph definition of ``raw_graph`` with the old values filled in.

    Raises
    ------
    RuntimeError
        If the two graphs do not expose the same set of transferable node
        names, if either side of a node is float16, or if ``check_dim``
        rejects a node.
    KeyError
        If a tensor dtype code is not one of those listed in
        ``precision_dict``.
    """
    # Tensor dtype code -> (NumPy type, printable name).
    precision_dict = {
        1: (np.float32, "float32"),
        2: (np.float64, "float64"),
        19: (np.float16, "float16"),
    }
    old_graph_def = old_graph.as_graph_def()
    raw_graph_def = raw_graph.as_graph_def()
    raw_graph_node = load_transform_node(raw_graph_def)
    old_graph_node = load_transform_node(old_graph_def)

    # Compare the name sets rather than only their sizes, so a mismatch is
    # reported here instead of surfacing later as a bare KeyError.
    if raw_graph_node.keys() != old_graph_node.keys():
        raise RuntimeError("raw graph and old graph has different network structure")

    # Raw string: "\d" inside a normal literal is an invalid escape sequence.
    final_bias_pattern = re.compile(r"final_layer_type_\d+/bias")

    for node in raw_graph_def.node:
        if node.name not in raw_graph_node:
            continue

        old_tensor = old_graph_node[node.name]
        old_type, old_name = precision_dict[old_tensor.dtype]
        # Read the target precision before the node is modified.
        raw_type, raw_name = precision_dict[raw_graph_node[node.name].dtype]

        if old_name == "float16" or raw_name == "float16":
            raise RuntimeError("float16 conversions not currently supported")

        check_dim(raw_graph_node, old_graph_node, node.name)

        if final_bias_pattern.fullmatch(node.name) is None:
            # Regular tensors: reinterpret the serialized bytes, cast, and
            # write the bytes back.
            tensor_value = np.frombuffer(old_tensor.tensor_content, dtype=old_type)
            tensor_value = tensor_value.astype(dtype=raw_type)
            # tobytes() replaces the deprecated ndarray.tostring() alias.
            node.attr["value"].tensor.tensor_content = tensor_value.tobytes()
        else:
            # Final-layer biases are read from the typed value fields instead
            # of tensor_content (presumably because they hold a single
            # element -- TODO confirm). float16 was rejected above, so only
            # the float64 and float32 fields can be needed here.
            if old_name == "float64":
                stored_values = old_tensor.double_val
            else:
                stored_values = old_tensor.float_val
            tensor_value = np.array(stored_values).astype(raw_type)
            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, raw_type, [1])))

        print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name,old_name,raw_name))

    return raw_graph_def

def check_dim(raw_graph_node, old_graph_node, node_name):
    """Raise if tensor ``node_name`` has a different shape in the two graphs.

    Parameters
    ----------
    raw_graph_node, old_graph_node
        Mappings from node name to an object whose ``tensor_shape`` is
        compared.
    node_name
        Name of the tensor to compare.

    Raises
    ------
    RuntimeError
        If the two shapes differ.
    """
    raw_graph_dim = raw_graph_node[node_name].tensor_shape
    old_graph_dim = old_graph_node[node_name].tensor_shape
    if raw_graph_dim != old_graph_dim:
        # A space before the node name keeps the message readable.
        raise RuntimeError("old graph and raw graph has different "+node_name+" dim")


def load_transform_node(graph):
Expand All @@ -59,5 +82,5 @@ def load_transform_node(graph):
"
for node in graph.node:
if re.fullmatch(transform_node_pattern,node.name) != None:
transform_node[node.name] = node
transform_node[node.name] = node.attr["value"].tensor
return transform_node