diff --git a/source/train/Fitting.py b/source/train/Fitting.py
index 49ca40b2c7..1effa8d7c2 100644
--- a/source/train/Fitting.py
+++ b/source/train/Fitting.py
@@ -194,12 +194,14 @@ def build (self,
             if self.numb_fparam > 0 :
                 ext_fparam = tf.tile(fparam, [1, natoms[2+type_i]])
                 ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
+                ext_fparam = tf.cast(ext_fparam,self.fitting_precision)
                 layer = tf.concat([layer, ext_fparam], axis = 1)
             if self.numb_aparam > 0 :
                 ext_aparam = tf.slice(aparam,
                                       [ 0, start_index      * self.numb_aparam],
                                       [-1, natoms[2+type_i] * self.numb_aparam])
                 ext_aparam = tf.reshape(ext_aparam, [-1, self.numb_aparam])
+                ext_aparam = tf.cast(ext_aparam,self.fitting_precision)
                 layer = tf.concat([layer, ext_aparam], axis = 1)
             start_index += natoms[2+type_i]
diff --git a/source/train/transform.py b/source/train/transform.py
index 850f492d6e..1d587ae531 100644
--- a/source/train/transform.py
+++ b/source/train/transform.py
@@ -49,22 +49,18 @@ def transform_graph(raw_graph,old_graph):
     for node in raw_graph_def.node:
         if node.name in raw_graph_node.keys():
-            """
-            if precision_dict[old_graph_node[node.name].dtype][1] == "float16" or precision_dict[raw_graph_node[node.name].dtype][1] == "float16":
-                raise RuntimeError("float16 conversions not currently supported")
-            """
             check_dim(raw_graph_node, old_graph_node, node.name)
+            tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
             old_graph_dtype = precision_dict[old_graph_node[node.name].dtype]
             raw_graph_dtype = precision_dict[raw_graph_node[node.name].dtype]
             print("%s is passed from old graph(%s) to raw graph(%s)" % (node.name, old_graph_dtype[1],raw_graph_dtype[1]))
 
             if raw_graph_dtype[1] == "float16":
                 if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
-                    if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
+                    if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
                         tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content, dtype=old_graph_dtype[0])
                         tensor_value = tensor_value.astype(np.float16)
-                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
                         node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_shape)))
                     else:
@@ -77,13 +73,12 @@
                         node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,tf.float16, [1])))
 
                 elif old_graph_dtype[1] == "float16":
-                    tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
                     tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
                     node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value, tf.float16, tensor_value.shape)))
 
             elif raw_graph_dtype[1] == "float64" or raw_graph_dtype[1] == "float32":
                 if old_graph_dtype[1] == "float64" or old_graph_dtype[1] == "float32":
-                    if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
+                    if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
                         tensor_value = np.frombuffer(old_graph_node[node.name].tensor_content,dtype = old_graph_dtype[0])
                         tensor_value = tensor_value.astype(dtype=raw_graph_dtype[0])
                         node.attr["value"].tensor.tensor_content = tensor_value.tostring()
@@ -98,13 +93,11 @@
                         node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], [1])))
 
                 elif old_graph_dtype[1] == "float16":
-                    if re.fullmatch("final_layer_type_\d+/bias", node.name) == None:
-                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
+                    if (len(tensor_shape) != 1) or (tensor_shape[0] != 1):
                         tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
                         tensor_value = tensor_value.astype(raw_graph_dtype[0])
                         node.attr["value"].tensor.tensor_content = tensor_value.tostring()
                     else:
-                        tensor_shape = [dim.size for dim in raw_graph_node[node.name].tensor_shape.dim]
                         tensor_value = convertMatrix(np.array(old_graph_node[node.name].half_val), tensor_shape)
                         tensor_value = tensor_value.astype(raw_graph_dtype[0])
                         node.attr["value"].CopyFrom(tf.AttrValue(tensor=tf.make_tensor_proto(tensor_value,raw_graph_dtype[0], tensor_value.shape)))
@@ -127,8 +120,16 @@ def load_transform_node(graph):
 layer_\d+_type_\d+/matrix|\
 layer_\d+_type_\d+/bias|\
 layer_\d+_type_\d+/idt|\
+final_layer_type_\d+/matrix|\
+descrpt_attr/t_avg|\
+descrpt_attr/t_std|\
 final_layer_type_\d+/bias|\
-final_layer_type_\d+/matrix\
+fitting_attr/t_fparam_avg|\
+fitting_attr/t_fparam_istd|\
+fitting_attr/t_aparam_avg|\
+fitting_attr/t_aparam_istd|\
+model_attr/t_tab_info|\
+model_attr/t_tab_data|\
 "
     for node in graph.node:
         if re.fullmatch(transform_node_pattern,node.name) != None:
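Note on the shape test introduced above: the old code special-cased the final layer's bias by name, apparently because it was the only variable stored with shape [1]; as the visible branches show, such [1]-shaped nodes are rebuilt via make_tensor_proto (their value sits in a typed field such as half_val), while every other tensor has its raw bytes rewritten in tensor_content. Testing the shape instead of the node name extends the same handling to any scalar variable, including the descrpt_attr/fitting_attr/model_attr nodes newly added to transform_node_pattern. A minimal sketch of the distinction, with made-up shapes for illustration (not a deepmd-kit API):

    # Sketch only: mirrors the condition used in transform_graph.
    def stored_in_tensor_content(tensor_shape):
        # Shape-[1] scalars keep their value in a typed proto field
        # (e.g. half_val); everything else is raw bytes in tensor_content.
        return len(tensor_shape) != 1 or tensor_shape[0] != 1

    assert stored_in_tensor_content([240])        # a bias vector
    assert stored_in_tensor_content([100, 240])   # a weight matrix
    assert not stored_in_tensor_content([1])      # a scalar, e.g. the final bias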