diff --git a/source/train/Fitting.py b/source/train/Fitting.py index 960ff64e12..49ca40b2c7 100644 --- a/source/train/Fitting.py +++ b/source/train/Fitting.py @@ -212,7 +212,7 @@ def build (self, if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] : layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii]) else : - layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[ii]) + layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii]) final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1]) if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None: diff --git a/source/train/main.py b/source/train/main.py index 49aea59084..1e35d3bf17 100644 --- a/source/train/main.py +++ b/source/train/main.py @@ -4,6 +4,7 @@ from .freeze import freeze from .config import config from .test import test +from .transform import transform def main () : parser = argparse.ArgumentParser( @@ -15,6 +16,13 @@ def main () : # help="the output json file") default_num_inter_threads = 0 + parser_transform = subparsers.add_parser('transform', help='pass parameters to another model') + parser_transform.add_argument('-r', "--raw-model", default = "raw_frozen_model.pb", type=str, + help = "the model receiving parameters") + parser_transform.add_argument("-o","--old-model", default = "old_frozen_model.pb", 
type=str, + help='the model providing parameters') + parser_transform.add_argument("-n", "--output", default = "frozen_model.pb", type=str, + help = "the model after passing parameters") parser_train = subparsers.add_parser('train', help='train a model') parser_train.add_argument('INPUT', help='the input parameter file in json format') @@ -62,5 +70,7 @@ def main () : config(args) elif args.command == 'test' : test(args) + elif args.command == 'transform' : + transform(args) else : raise RuntimeError('unknown command ' + args.command) diff --git a/source/train/transform.py b/source/train/transform.py new file mode 100644 index 0000000000..66ff871aad --- /dev/null +++ b/source/train/transform.py @@ -0,0 +1,63 @@ +from deepmd.env import tf +import re +def transform(args): + new_graph = load_graph(args.raw_model) + old_graph = load_graph(args.old_model) + print("%d ops in the raw graph\n%d ops in the old graph" %(len(new_graph.node),len(old_graph.node))) + transform_node = load_data(new_graph,old_graph) + for node in new_graph.node: + if node.name in transform_node: + print("%s is passed from old graph to raw graph" % node.name) + node.attr["value"].tensor.CopyFrom(transform_node[node.name].attr["value"].tensor) + with tf.gfile.GFile(args.output, mode='wb') as f: + f.write(new_graph.SerializeToString()) + print("the output model is saved in %s" % args.output) + +def load_graph(graphName): + graph_def = tf.GraphDef() + with open(graphName,"rb") as f: + graph_def.ParseFromString(f.read()) + with tf.Graph().as_default() as graph: + tf.import_graph_def(graph_def,name = "") + return graph_def + +def load_data(new_graph,old_graph): + new_graph_node = load_transform_node(new_graph) + old_graph_node = load_transform_node(old_graph) + if len(new_graph_node) != len(old_graph_node): + raise RuntimeError("New graph and original graph has different network structure\n") + for nodeName in old_graph_node.keys(): + check_dim(new_graph_node, old_graph_node, nodeName) + 
check_precision(new_graph_node, old_graph_node, nodeName) + return old_graph_node + + +def check_precision(new_graph_node, old_graph_node, nodeName): + new_graph_precision = new_graph_node[nodeName].attr["value"].tensor.dtype + old_graph_precision = old_graph_node[nodeName].attr["value"].tensor.dtype + if new_graph_precision != old_graph_precision: + raise RuntimeError("New graph and original graph have different "+nodeName+" precision\n") + +def check_dim(new_graph_node, old_graph_node, nodeName): + new_graph_dim = new_graph_node[nodeName].attr["value"].tensor.tensor_shape + old_graph_dim = old_graph_node[nodeName].attr["value"].tensor.tensor_shape + if new_graph_dim != old_graph_dim: + raise RuntimeError("New graph and original graph have different "+nodeName+" dim\n") + + +def load_transform_node(graph): + transform_node = {} + transform_node_pattern = "\ +filter_type_\d+/matrix_\d+_\d+|\ +filter_type_\d+/bias_\d+_\d+|\ +filter_type_\d+/idt_\d+_\d+|\ +layer_\d+_type_\d+/matrix|\ +layer_\d+_type_\d+/bias|\ +layer_\d+_type_\d+/idt|\ +final_layer_type_\d+/bias|\ +final_layer_type_\d+/matrix\ +" + for node in graph.node: + if re.fullmatch(transform_node_pattern,node.name) != None: + transform_node[node.name] = node + return transform_node