From 84278fc4fe9ef722f9d6b13d60186b0e0b58ca6e Mon Sep 17 00:00:00 2001 From: ZhengdQin <46387172+ZhengdQin@users.noreply.github.com> Date: Tue, 26 Oct 2021 22:29:06 +0800 Subject: [PATCH] Update transfer.py fix the np.frombuffer in dp transfer --- deepmd/entrypoints/transfer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/deepmd/entrypoints/transfer.py b/deepmd/entrypoints/transfer.py index 47509551a8..85664ac83c 100644 --- a/deepmd/entrypoints/transfer.py +++ b/deepmd/entrypoints/transfer.py @@ -130,8 +130,9 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph: if raw_graph_dtype == np.float16: if old_graph_dtype == np.float64 or old_graph_dtype == np.float32: if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor = np.frombuffer(old_node.tensor_content, dtype = raw_graph_dtype) - cp_attr.from_array(tensor, tf.float16, shape = tensor_shape) + tensor = np.frombuffer(old_node.tensor_content, dtype = old_graph_dtype) + tensor = tensor.astype(raw_graph_dtype) + cp_attr.from_str(tensor) else: tensor = load_tensor(old_node, old_graph_dtype, raw_graph_dtype) cp_attr.from_array(tensor, tf.float16, [1]) @@ -143,7 +144,8 @@ def transform_graph(raw_graph: tf.Graph, old_graph: tf.Graph) -> tf.Graph: elif raw_graph_dtype == np.float64 or raw_graph_dtype == np.float32: if old_graph_dtype == np.float64 or old_graph_dtype == np.float32: if (len(tensor_shape) != 1) or (tensor_shape[0] != 1): - tensor = np.frombuffer(old_node.tensor_content, dtype = raw_graph_dtype) + tensor = np.frombuffer(old_node.tensor_content, dtype = old_graph_dtype) + tensor = tensor.astype(raw_graph_dtype) cp_attr.from_str(tensor) else: tensor = load_tensor(old_node, old_graph_dtype, raw_graph_dtype)