86 changes: 68 additions & 18 deletions source/train/transform.py
@@ -1,6 +1,21 @@
from deepmd.env import tf
import re
import numpy as np

def convertNumber(number):
    # Decode a 16-bit IEEE 754 (binary16) bit pattern, stored as a plain
    # integer in a TensorProto's half_val field, into a Python float.
    binary = bin(number).replace("0b", "").zfill(16)
    sign = int(binary[0]) * (-2) + 1   # bit 15: sign
    exp = int(binary[1:6], 2)          # bits 14-10: biased exponent
    frac = int(binary[6:], 2)          # bits 9-0: fraction
    if exp == 0:
        # Zero and subnormals: no implicit leading 1, exponent fixed at -14.
        return sign * frac * 2 ** -24
    return sign * (2 ** (exp - 15)) * ((frac + 2 ** 10) * (2 ** -10))


def convertMatrix(matrix, shape):
    # Decode a flat array of binary16 bit patterns and restore the tensor shape.
    matrix = matrix.flatten()
    tmp = np.array([convertNumber(v) for v in matrix])
    return tmp.reshape(shape)


def transform(args):
    raw_graph = load_graph(args.raw_model)
    old_graph = load_graph(args.old_model)
@@ -34,31 +49,66 @@ def transform_graph(raw_graph,old_graph):

    for node in raw_graph_def.node:
        if node.name in raw_graph_node.keys():
            """
            if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
                raise RuntimeError("float16 conversions not currently supported")
            """

            check_dim(raw_graph_node, old_graph_node, node.name)
            old_graph_dtype = precision_dict[old_graph_node[node.name].dtype]
            raw_graph_dtype = precision_dict[raw_graph_node[node.name].dtype]
            print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name, old_graph_dtype[1], raw_graph_dtype[1]))

            if raw_graph_dtype[1] == "float16":
                if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
                    # Weight matrices live in tensor_content; cast them to float16 and
                    # rebuild the tensor proto with the shape taken from the raw graph.
                    if re.fullmatch(r"final_layer_type_\d+/bias", node.name) is None:
                        tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content, dtype=old_graph_dtype[0])
                        tensor_value = tensor_value.astype(np.float16)
                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
                        node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_shape)))

                    # Final-layer biases are stored in the scalar *_val fields instead.
                    else:
                        if old_graph_dtype[1] == "float64":
                            tensor_value = np.array(old_graph_node[node.name].double_val).astype(np.float16)
                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, [1])))

                        elif old_graph_dtype[1] == "float32":
                            tensor_value = np.array(old_graph_node[node.name].float_val).astype(np.float16)
                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, [1])))

                elif old_graph_dtype[1] == "float16":
                    # half_val holds raw binary16 bit patterns; decode them first.
                    tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
                    tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
                    node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_value.shape)))

            elif raw_graph_dtype[1] == "float64" or raw_graph_dtype[1] == "float32":
                if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
                    if re.fullmatch(r"final_layer_type_\d+/bias", node.name) is None:
                        tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content, dtype=old_graph_dtype[0])
                        tensor_value = tensor_value.astype(dtype=raw_graph_dtype[0])
                        node.attr["value"].tensor.tensor_content = tensor_value.tobytes()

                    else:
                        if old_graph_dtype[1] == "float64":
                            tensor_value = np.array(old_graph_node[node.name].double_val).astype(raw_graph_dtype[0])
                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, raw_graph_dtype[0], [1])))

                        elif old_graph_dtype[1] == "float32":
                            tensor_value = np.array(old_graph_node[node.name].float_val).astype(raw_graph_dtype[0])
                            node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, raw_graph_dtype[0], [1])))

                elif old_graph_dtype[1] == "float16":
                    # Decode the binary16 bit patterns, then upcast to the raw graph's dtype.
                    if re.fullmatch(r"final_layer_type_\d+/bias", node.name) is None:
                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
                        tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
                        tensor_value = tensor_value.astype(raw_graph_dtype[0])
                        node.attr["value"].tensor.tensor_content = tensor_value.tobytes()
                    else:
                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
                        tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
                        tensor_value = tensor_value.astype(raw_graph_dtype[0])
                        node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, raw_graph_dtype[0], tensor_value.shape)))

    return raw_graph_def

def check_dim(raw_graph_node, old_graph_node, node_name):
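A note on the decoding above: a TensorProto stores float16 tensors either as raw bytes in tensor_content or as integers in the repeated half_val field, and half_val holds the binary16 bit patterns themselves, which is why they must be re-decoded rather than merely cast. A minimal sanity check of convertNumber against NumPy's own binary16 interpretation might look like the sketch below; it is illustrative only, not part of the PR, and assumes this transform.py is importable (e.g. source/train is on sys.path):

import numpy as np
from transform import convertNumber  # assumption: source/train is on sys.path

# Hand-checked binary16 bit patterns:
# 0x3C00 -> 1.0, 0xC000 -> -2.0, 0x4248 -> 3.140625, 0x7BFF -> 65504.0 (largest normal)
for bits in (0x3C00, 0xC000, 0x4248, 0x7BFF):
    reference = float(np.array(bits, dtype=np.uint16).view(np.float16))
    assert convertNumber(bits) == reference, hex(bits)

A vectorized alternative to convertMatrix would be np.asarray(half_val, dtype=np.uint16).view(np.float16).reshape(shape), which reinterprets all bit patterns in one step; the pure-Python decoder used in the PR makes the sign/exponent/fraction arithmetic explicit. Note that neither handles inf/NaN patterns (exponent field of 31), which ordinary trained weights should not contain.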