Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/train/Fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def build (self,
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
else :
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[ii])
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn, precision = self.fitting_precision, trainable = self.trainable[ii])
final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, precision = self.fitting_precision, trainable = self.trainable[-1])

if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
Expand Down
10 changes: 10 additions & 0 deletions source/train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .freeze import freeze
from .config import config
from .test import test
from .transform import transform

def main () :
parser = argparse.ArgumentParser(
Expand All @@ -15,6 +16,13 @@ def main () :
# help="the output json file")

default_num_inter_threads = 0
parser_transform = subparsers.add_parser('transform', help='pass parameters to another model')
parser_transform.add_argument('-r', "--raw-model", default = "raw_frozen_model.pb", type=str,
help = "the model receiving parameters")
parser_transform.add_argument("-o","--old-model", default = "old_frozen_model.pb", type=str,
help='the model providing parameters')
parser_transform.add_argument("-n", "--output", default = "frozen_model.pb", type=str,
help = "the model after passing parameters")
parser_train = subparsers.add_parser('train', help='train a model')
parser_train.add_argument('INPUT',
help='the input parameter file in json format')
Expand Down Expand Up @@ -62,5 +70,7 @@ def main () :
config(args)
elif args.command == 'test' :
test(args)
elif args.command == 'transform' :
transform(args)
else :
raise RuntimeError('unknown command ' + args.command)
63 changes: 63 additions & 0 deletions source/train/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from deepmd.env import tf
import re
def transform(args):
    """Copy fitting/filter network parameters from an old frozen model
    into a freshly frozen ("raw") model and save the result.

    Parameters
    ----------
    args : argparse.Namespace
        Must carry ``raw_model`` (model receiving parameters),
        ``old_model`` (model providing parameters) and ``output``
        (path for the combined frozen model).
    """
    raw_graph_def = load_graph(args.raw_model)
    old_graph_def = load_graph(args.old_model)
    print("%d ops in the raw graph\n%d ops in the old graph"
          % (len(raw_graph_def.node), len(old_graph_def.node)))
    # Nodes (from the old graph) whose tensors should overwrite the raw graph.
    nodes_to_copy = load_data(raw_graph_def, old_graph_def)
    for raw_node in raw_graph_def.node:
        if raw_node.name not in nodes_to_copy:
            continue
        print("%s is passed from old graph to raw graph" % raw_node.name)
        source_tensor = nodes_to_copy[raw_node.name].attr["value"].tensor
        raw_node.attr["value"].tensor.CopyFrom(source_tensor)
    with tf.gfile.GFile(args.output, mode='wb') as f:
        f.write(raw_graph_def.SerializeToString())
    print("the output model is saved in %s" % args.output)

def load_graph(graphName):
    """Read a frozen TensorFlow model file and return its GraphDef.

    Parameters
    ----------
    graphName : str
        Path to a serialized frozen graph (``.pb`` file).

    Returns
    -------
    tf.GraphDef
        The parsed graph definition.
    """
    with open(graphName, "rb") as model_file:
        serialized = model_file.read()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    # Import into a scratch graph; presumably done only to validate the
    # proto, since the GraphDef itself is what gets returned — TODO confirm.
    with tf.Graph().as_default():
        tf.import_graph_def(graph_def, name="")
    return graph_def

def load_data(new_graph, old_graph):
    """Collect the transferable nodes of both graphs and verify they agree.

    Both graphs must expose the same set of fitting/filter nodes with
    matching tensor shapes and dtypes; otherwise a RuntimeError is raised.

    Parameters
    ----------
    new_graph : tf.GraphDef
        The raw graph that will receive parameters.
    old_graph : tf.GraphDef
        The graph that provides parameters.

    Returns
    -------
    dict
        Node name -> node, taken from the old graph.
    """
    destination_nodes = load_transform_node(new_graph)
    source_nodes = load_transform_node(old_graph)
    if len(destination_nodes) != len(source_nodes):
        raise RuntimeError("New graph and original graph has different network structure\n")
    for name in source_nodes:
        check_dim(destination_nodes, source_nodes, name)
        check_precision(destination_nodes, source_nodes, name)
    return source_nodes


def check_precision(new_graph_node, old_graph_node, nodeName):
    """Ensure node ``nodeName`` stores tensors of the same dtype in both graphs.

    Parameters
    ----------
    new_graph_node : dict
        Node name -> node, from the raw graph.
    old_graph_node : dict
        Node name -> node, from the old graph.
    nodeName : str
        Name of the node to compare.

    Raises
    ------
    RuntimeError
        If the tensor dtypes of the two nodes differ.
    """
    new_graph_precision = new_graph_node[nodeName].attr["value"].tensor.dtype
    old_graph_precision = old_graph_node[nodeName].attr["value"].tensor.dtype
    if new_graph_precision != old_graph_precision:
        # Fixed message: the original concatenated "different"+nodeName
        # without a space, producing e.g. "differentlayer_0_type_0/matrix".
        raise RuntimeError("New graph and original graph have different " + nodeName + " precision\n")

def check_dim(new_graph_node, old_graph_node, nodeName):
    """Ensure node ``nodeName`` stores tensors of the same shape in both graphs.

    Parameters
    ----------
    new_graph_node : dict
        Node name -> node, from the raw graph.
    old_graph_node : dict
        Node name -> node, from the old graph.
    nodeName : str
        Name of the node to compare.

    Raises
    ------
    RuntimeError
        If the tensor shapes of the two nodes differ.
    """
    new_graph_dim = new_graph_node[nodeName].attr["value"].tensor.tensor_shape
    old_graph_dim = old_graph_node[nodeName].attr["value"].tensor.tensor_shape
    if new_graph_dim != old_graph_dim:
        # Fixed message: the original concatenated "different"+nodeName
        # without a space, producing e.g. "differentlayer_0_type_0/matrix".
        raise RuntimeError("New graph and original graph have different " + nodeName + " dim\n")


def load_transform_node(graph):
    """Collect the fitting/filter parameter nodes of a frozen graph.

    Scans every node of ``graph`` and keeps those whose full name matches
    one of the known trainable-parameter naming patterns (embedding-net
    matrix/bias/idt, fitting-net layer matrix/bias/idt, final layer
    matrix/bias).

    Parameters
    ----------
    graph : tf.GraphDef
        The graph definition to scan.

    Returns
    -------
    dict
        Node name -> node for every matching node.
    """
    # BUGFIX: the original pattern was built with backslash line
    # continuations and was missing the '|' between
    # 'filter_type_\d+/idt_\d+_\d+' and 'layer_\d+_type_\d+/matrix',
    # fusing them into one bogus alternative so neither ever matched —
    # those parameters were silently never transferred.
    transform_node_pattern = re.compile(
        r"filter_type_\d+/matrix_\d+_\d+|"
        r"filter_type_\d+/bias_\d+_\d+|"
        r"filter_type_\d+/idt_\d+_\d+|"
        r"layer_\d+_type_\d+/matrix|"
        r"layer_\d+_type_\d+/bias|"
        r"layer_\d+_type_\d+/idt|"
        r"final_layer_type_\d+/bias|"
        r"final_layer_type_\d+/matrix"
    )
    transform_node = {}
    for node in graph.node:
        if transform_node_pattern.fullmatch(node.name) is not None:
            transform_node[node.name] = node
    return transform_node