From eac7e04fc0c6d051cb398cdeb6f1a9d12e17a922 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 14 May 2021 14:37:05 +0800 Subject: [PATCH 1/3] add support for converting models to 2.0 compatibility --- source/train/CMakeLists.txt | 2 +- .../train/{convert_to_13.py => convert_to.py} | 51 +++++++++++++++++-- source/train/main.py | 20 +++++--- 3 files changed, 63 insertions(+), 10 deletions(-) rename source/train/{convert_to_13.py => convert_to.py} (60%) diff --git a/source/train/CMakeLists.txt b/source/train/CMakeLists.txt index 52b1442745..7d7e2b8831 100644 --- a/source/train/CMakeLists.txt +++ b/source/train/CMakeLists.txt @@ -2,7 +2,7 @@ configure_file("RunOptions.py.in" "${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py" @ONLY) -file(GLOB LIB_PY main.py common.py env.py compat.py calculator.py Network.py Deep*.py Data.py DataSystem.py Model*.py Descrpt*.py Fitting.py Loss.py LearningRate.py Trainer.py TabInter.py EwaldRecp.py DataModifier.py ${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py transform.py convert_to_13.py) +file(GLOB LIB_PY main.py common.py env.py compat.py calculator.py Network.py Deep*.py Data.py DataSystem.py Model*.py Descrpt*.py Fitting.py Loss.py LearningRate.py Trainer.py TabInter.py EwaldRecp.py DataModifier.py ${CMAKE_CURRENT_BINARY_DIR}/RunOptions.py transform.py convert_to.py) file(GLOB CLS_PY Local.py Slurm.py) diff --git a/source/train/convert_to_13.py b/source/train/convert_to.py similarity index 60% rename from source/train/convert_to_13.py rename to source/train/convert_to.py index 97aad18400..c2b05e6630 100644 --- a/source/train/convert_to_13.py +++ b/source/train/convert_to.py @@ -1,15 +1,27 @@ +import os from deepmd.env import tf from google.protobuf import text_format from tensorflow.python.platform import gfile from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import graph_util -def convert_to_13(args): +def convert_12_to_13(args): convert_pb_to_pbtxt(args.input_model, 'frozen_model.pbtxt') - convert_to_dp13('frozen_model.pbtxt') + convert_dp12_to_dp13('frozen_model.pbtxt') convert_pbtxt_to_pb('frozen_model.pbtxt', args.output_model) + if os.path.isfile('frozen_model.pbtxt'): + os.remove('frozen_model.pbtxt') print("the converted output model(1.3 support) is saved in %s" % args.output_model) +def convert_12_to_20(args): + convert_pb_to_pbtxt(args.input_model, 'frozen_model.pbtxt') + convert_dp12_to_dp13('frozen_model.pbtxt') + convert_dp13_to_dp20('frozen_model.pbtxt') + convert_pbtxt_to_pb('frozen_model.pbtxt', args.output_model) + # if os.path.isfile('frozen_model.pbtxt'): + # os.remove('frozen_model.pbtxt') + print("the converted output model(2.0 support) is saved in %s" % args.output_model) + def convert_pb_to_pbtxt(pbfile, pbtxtfile): with gfile.FastGFile(pbfile, 'rb') as f: graph_def = tf.GraphDef() @@ -25,7 +37,7 @@ def convert_pbtxt_to_pb(pbtxtfile, pbfile): text_format.Merge(file_content, graph_def) tf.train.write_graph(graph_def, './', pbfile, as_text=False) -def convert_to_dp13(file): +def convert_dp12_to_dp13(file): file_data = "" with open(file, "r", encoding="utf-8") as f: ii = 0 @@ -46,3 +58,36 @@ def convert_to_dp13(file): file_data += ' }\n' with open(file, "w", encoding="utf-8") as f: f.write(file_data) + + +def convert_dp13_to_dp20(fname): + with open(fname) as fp: + file_content = fp.read() + file_content += """ +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} +""" + file_content = file_content\ + .replace('DescrptSeA', 'ProdEnvMatA')\ + .replace('DescrptSeR', 'ProdEnvMatR') + with open(fname, 'w') as fp: + fp.write(file_content) diff --git a/source/train/main.py b/source/train/main.py index ac69540114..a81ebd5dfb 100644 --- a/source/train/main.py +++ b/source/train/main.py @@ -5,7 +5,7 @@ from .config import config from .test import test from .transform import transform -from .convert_to_13 import convert_to_13 +from .convert_to import convert_12_to_13, convert_12_to_20 def main () : parser = argparse.ArgumentParser( @@ -58,11 +58,14 @@ def main () : parser_tst.add_argument("-d", "--detail-file", type=str, help="The file containing details of energy force and virial accuracy") - parser_transform = subparsers.add_parser('convert-to-1.3', help='convert dp-1.2 model to dp-1.3 model') + parser_transform = subparsers.add_parser('convert-to', help='convert dp-1.2 model to higher model compatibility') + parser_transform.add_argument('TO', type = str, + choices = ['1.3', '2.0'], + help="The target model compatibility") parser_transform.add_argument('-i', "--input-model", default = "frozen_model.pb", type=str, help = "the input dp-1.2 model") - parser_transform.add_argument("-o","--output-model", default = "frozen_model_1.3.pb", type=str, - help='the converted dp-1.3 model') + parser_transform.add_argument("-o","--output-model", default = "convert_out.pb", type=str, + help='the output model') args = parser.parse_args() if args.command is None : @@ -78,7 +81,12 @@ def main () : test(args) elif args.command == 'transform' : transform(args) - elif args.command == 'convert-to-1.3' : - convert_to_13(args) + elif args.command == 'convert-to' : + if args.TO == '1.3': + convert_12_to_13(args) + elif args.TO == '2.0': + convert_12_to_20(args) + else: + raise RuntimeError('unsupported model compatibility ' + args.TO) else : raise RuntimeError('unknown command ' + args.command) From 592d214b3581437ae0340c08f522ff784eca52e8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 14 May 2021 15:30:28 +0800 Subject: [PATCH 2/3] rm pbtxt file when the conversion finishes --- source/train/convert_to.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/train/convert_to.py b/source/train/convert_to.py index c2b05e6630..fc07c14448 100644 --- a/source/train/convert_to.py +++ b/source/train/convert_to.py @@ -18,8 +18,8 @@ def convert_12_to_20(args): convert_dp12_to_dp13('frozen_model.pbtxt') convert_dp13_to_dp20('frozen_model.pbtxt') convert_pbtxt_to_pb('frozen_model.pbtxt', args.output_model) - # if os.path.isfile('frozen_model.pbtxt'): - # os.remove('frozen_model.pbtxt') + if os.path.isfile('frozen_model.pbtxt'): + os.remove('frozen_model.pbtxt') print("the converted output model(2.0 support) is saved in %s" % args.output_model) def convert_pb_to_pbtxt(pbfile, pbtxtfile): From 9bc2ae5824e7af614be8ca4e1ae2467b719f4d4f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 27 Aug 2021 11:39:12 +0800 Subject: [PATCH 3/3] add doc for model conversion --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 5520daa57f..a9b5b53f47 100644 --- a/README.md +++ b/README.md @@ -616,6 +616,11 @@ When the version of DeePMD-kit used to training model is different from the that DeePMD-kit guarantees that the codes with the same major and minor revisions are compatible. That is to say v0.12.5 is compatible to v0.12.0, but is not compatible to v0.11.0 nor v1.0.0. +One can convert the model trained with DeePMD-kit v1.2 to v2 compatible by using the command +```shell +dp convert-to 2.0 -i frozen_model.pb -o frozen_model_2.0.pb +``` + ## Installation: inadequate versions of gcc/g++ Sometimes you may use a gcc/g++ of version <4.9. If you have a gcc/g++ of version > 4.9, say, 7.2.0, you may choose to use it by doing ```bash