From e20888d58686920872ea6f6efb7507c580f02748 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 9 Jun 2021 20:36:31 -0400 Subject: [PATCH 01/10] add v1.3 compatibility --- .gitignore | 2 + deepmd/infer/deep_eval.py | 4 ++ source/api_cc/include/common.h | 4 ++ source/api_cc/src/DeepPot.cc | 5 +++ source/api_cc/src/common.cc | 5 ++- source/op/prod_env_mat_multi_device.cc | 53 +++++++++++++++++++++++++- 6 files changed, 71 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index f5c87dde5d..41eb111e4e 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,5 @@ venv* _build _templates API_CC +dp/** +build_lammps/** \ No newline at end of file diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 92a45e3cd4..41fa9ebd2f 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -77,6 +77,10 @@ def _graph_compatable( model_version_minor = int(self.model_version.split('.')[1]) MODEL_VERSION_MAJOR = int(MODEL_VERSION.split('.')[0]) MODEL_VERSION_MINOR = int(MODEL_VERSION.split('.')[1]) + if model_version_major == 0: + # We plan to support model generated from v1.3 + # We have no way to distinguish versions earlier than v1.3 + return True if (model_version_major != MODEL_VERSION_MAJOR) or \ (model_version_minor > MODEL_VERSION_MINOR) : return False diff --git a/source/api_cc/include/common.h b/source/api_cc/include/common.h index d59878693e..75fd61a6f7 100644 --- a/source/api_cc/include/common.h +++ b/source/api_cc/include/common.h @@ -87,6 +87,10 @@ void get_env_nthreads(int & num_intra_nthreads, int & num_inter_nthreads); +struct +tf_exception: public std::exception { +}; + /** * @brief Check TensorFlow status. Exit if not OK. * @param[in] status TensorFlow status. diff --git a/source/api_cc/src/DeepPot.cc b/source/api_cc/src/DeepPot.cc index 9400b47691..c862bb84fd 100644 --- a/source/api_cc/src/DeepPot.cc +++ b/source/api_cc/src/DeepPot.cc @@ -254,7 +254,12 @@ init (const std::string & model, const int & gpu_rank, const std::string & file_ if (dfparam < 0) dfparam = 0; if (daparam < 0) daparam = 0; model_type = get_scalar("model_attr/model_type"); + try{ model_version = get_scalar("model_attr/model_version"); + } catch (deepmd::tf_exception& e){ + // no model version defined in old models + model_version = "0.0"; + } if(! model_compatable(model_version)){ throw std::runtime_error( "incompatable model: version " + model_version diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 74c317529e..412ee55656 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -36,6 +36,9 @@ model_compatable( int model_version_minor = atoi(words_mv[1].c_str()); int MODEL_VERSION_MAJOR = atoi(words_gmv[0].c_str()); int MODEL_VERSION_MINOR = atoi(words_gmv[1].c_str()); + // we plan to support model generated from v1.3, + // but have no way to distinguish versions earlier than v1.3 + if(model_version_major == 0) return true; if(model_version_major != MODEL_VERSION_MAJOR || model_version_minor > MODEL_VERSION_MINOR){ return false; @@ -201,7 +204,7 @@ deepmd:: check_status(const tensorflow::Status& status) { if (!status.ok()) { std::cout << status.ToString() << std::endl; - exit(1); + throw deepmd::tf_exception(); } } diff --git a/source/op/prod_env_mat_multi_device.cc b/source/op/prod_env_mat_multi_device.cc index e4e12cac2b..6320f1f501 100644 --- a/source/op/prod_env_mat_multi_device.cc +++ b/source/op/prod_env_mat_multi_device.cc @@ -25,6 +25,26 @@ REGISTER_OP("ProdEnvMatA") .Output("nlist: int32"); // only sel_a and rcut_r uesd. +// an alias of ProdEnvMatA -- Compatible with v1.3 +REGISTER_OP("DescrptSeA") + .Attr("T: {float, double}") + .Input("coord: T") + .Input("type: int32") + .Input("natoms: int32") + .Input("box : T") + .Input("mesh : int32") + .Input("davg: T") + .Input("dstd: T") + .Attr("rcut_a: float") + .Attr("rcut_r: float") + .Attr("rcut_r_smth: float") + .Attr("sel_a: list(int)") + .Attr("sel_r: list(int)") + .Output("descrpt: T") + .Output("descrpt_deriv: T") + .Output("rij: T") + .Output("nlist: int32"); + REGISTER_OP("ProdEnvMatR") .Attr("T: {float, double}") .Input("coord: T") @@ -42,6 +62,23 @@ REGISTER_OP("ProdEnvMatR") .Output("rij: T") .Output("nlist: int32"); +// an alias of ProdEnvMatR -- Compatible with v1.3 +REGISTER_OP("DescrptSeR") + .Attr("T: {float, double}") + .Input("coord: T") + .Input("type: int32") + .Input("natoms: int32") + .Input("box: T") + .Input("mesh: int32") + .Input("davg: T") + .Input("dstd: T") + .Attr("rcut: float") + .Attr("rcut_smth: float") + .Attr("sel: list(int)") + .Output("descrpt: T") + .Output("descrpt_deriv: T") + .Output("rij: T") + .Output("nlist: int32"); template static int @@ -1364,17 +1401,25 @@ _prepare_coord_nlist_gpu_rocm( // Register the CPU kernels. +// Compatible with v1.3 #define REGISTER_CPU(T) \ REGISTER_KERNEL_BUILDER( \ Name("ProdEnvMatA").Device(DEVICE_CPU).TypeConstraint("T"), \ ProdEnvMatAOp); \ REGISTER_KERNEL_BUILDER( \ Name("ProdEnvMatR").Device(DEVICE_CPU).TypeConstraint("T"), \ - ProdEnvMatROp); + ProdEnvMatROp); \ +REGISTER_KERNEL_BUILDER( \ + Name("DescrptSeA").Device(DEVICE_CPU).TypeConstraint("T"), \ + ProdEnvMatAOp); \ +REGISTER_KERNEL_BUILDER( \ + Name("DescrptSeR").Device(DEVICE_CPU).TypeConstraint("T"), \ + ProdEnvMatROp); REGISTER_CPU(float); REGISTER_CPU(double); // Register the GPU kernels. +// Compatible with v1.3 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM #define REGISTER_GPU(T) \ REGISTER_KERNEL_BUILDER( \ @@ -1382,6 +1427,12 @@ REGISTER_KERNEL_BUILDER( ProdEnvMatAOp); \ REGISTER_KERNEL_BUILDER( \ Name("ProdEnvMatR").Device(DEVICE_GPU).TypeConstraint("T").HostMemory("natoms").HostMemory("box"), \ + ProdEnvMatROp); \ +REGISTER_KERNEL_BUILDER( \ + Name("DescrptSeA").Device(DEVICE_GPU).TypeConstraint("T").HostMemory("natoms").HostMemory("box"), \ + ProdEnvMatAOp); \ +REGISTER_KERNEL_BUILDER( \ + Name("DescrptSeR").Device(DEVICE_GPU).TypeConstraint("T").HostMemory("natoms").HostMemory("box"), \ ProdEnvMatROp); REGISTER_GPU(float); REGISTER_GPU(double); From 40dd8073b268bf2a557cb4de4e50ae6fcb25f85b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 10 Jun 2021 07:48:21 -0400 Subject: [PATCH 02/10] remove TestModelMajorCompatability as compatibility was added By the way: Compatability should be compatibility --- source/tests/test_deeppot_a.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 3726299fb5..7ce97261b2 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -35,12 +35,12 @@ def tearDown(self): os.remove(self.version_pbtxt) os.remove(self.version_pb) - def test(self): - with self.assertRaises(RuntimeError) as context: - DeepPot(str(self.version_pb)) - self.assertTrue('incompatible' in str(context.exception)) - self.assertTrue(MODEL_VERSION in str(context.exception)) - self.assertTrue('0.0' in str(context.exception)) + #def test(self): + # with self.assertRaises(RuntimeError) as context: + # DeepPot(str(self.version_pb)) + # self.assertTrue('incompatible' in str(context.exception)) + # self.assertTrue(MODEL_VERSION in str(context.exception)) + # self.assertTrue('0.0' in str(context.exception)) class TestModelMinorCompatability(unittest.TestCase) : From 11fdd5c67f7b467bdaa4fb04d6280841caa90ffb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 10 Jun 2021 10:15:50 -0400 Subject: [PATCH 03/10] Also remove TestModelMinorCompatability --- source/tests/test_deeppot_a.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 7ce97261b2..16a37a5f88 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -66,11 +66,11 @@ def tearDown(self): os.remove(self.version_pb) def test(self): - with self.assertRaises(RuntimeError) as context: - DeepPot(self.version_pb) - self.assertTrue('incompatible' in str(context.exception)) - self.assertTrue(MODEL_VERSION in str(context.exception)) - self.assertTrue('0.1000000' in str(context.exception)) + #with self.assertRaises(RuntimeError) as context: + # DeepPot(self.version_pb) + #self.assertTrue('incompatible' in str(context.exception)) + #self.assertTrue(MODEL_VERSION in str(context.exception)) + #self.assertTrue('0.1000000' in str(context.exception)) class TestDeepPotAPBC(unittest.TestCase) : From a03b5ee62107cfdd4f6fb621db4565c8531c4cd6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 10 Jun 2021 10:25:24 -0400 Subject: [PATCH 04/10] Update test_deeppot_a.py --- source/tests/test_deeppot_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 16a37a5f88..82afb0e105 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -65,7 +65,7 @@ def tearDown(self): os.remove(self.version_pbtxt) os.remove(self.version_pb) - def test(self): + #def test(self): #with self.assertRaises(RuntimeError) as context: # DeepPot(self.version_pb) #self.assertTrue('incompatible' in str(context.exception)) From 1accc15d427f1db78d322f3eddc2e93ef1f7791e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Jun 2021 17:11:51 -0400 Subject: [PATCH 05/10] Revert "Update test_deeppot_a.py" This reverts commit a03b5ee62107cfdd4f6fb621db4565c8531c4cd6. --- source/tests/test_deeppot_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 82afb0e105..16a37a5f88 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -65,7 +65,7 @@ def tearDown(self): os.remove(self.version_pbtxt) os.remove(self.version_pb) - #def test(self): + def test(self): #with self.assertRaises(RuntimeError) as context: # DeepPot(self.version_pb) #self.assertTrue('incompatible' in str(context.exception)) From df3a94e4193854b6d89eed41341193e8aaf0ed7d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Jun 2021 17:11:57 -0400 Subject: [PATCH 06/10] Revert "Also remove TestModelMinorCompatability" This reverts commit 11fdd5c67f7b467bdaa4fb04d6280841caa90ffb. --- source/tests/test_deeppot_a.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 16a37a5f88..7ce97261b2 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -66,11 +66,11 @@ def tearDown(self): os.remove(self.version_pb) def test(self): - #with self.assertRaises(RuntimeError) as context: - # DeepPot(self.version_pb) - #self.assertTrue('incompatible' in str(context.exception)) - #self.assertTrue(MODEL_VERSION in str(context.exception)) - #self.assertTrue('0.1000000' in str(context.exception)) + with self.assertRaises(RuntimeError) as context: + DeepPot(self.version_pb) + self.assertTrue('incompatible' in str(context.exception)) + self.assertTrue(MODEL_VERSION in str(context.exception)) + self.assertTrue('0.1000000' in str(context.exception)) class TestDeepPotAPBC(unittest.TestCase) : From 94f1d66ef401bccc37f1dcad6e960d0c855870a9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Jun 2021 17:11:59 -0400 Subject: [PATCH 07/10] Revert "remove TestModelMajorCompatability as compatibility was added" This reverts commit 40dd8073b268bf2a557cb4de4e50ae6fcb25f85b. --- source/tests/test_deeppot_a.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/source/tests/test_deeppot_a.py b/source/tests/test_deeppot_a.py index 7ce97261b2..3726299fb5 100644 --- a/source/tests/test_deeppot_a.py +++ b/source/tests/test_deeppot_a.py @@ -35,12 +35,12 @@ def tearDown(self): os.remove(self.version_pbtxt) os.remove(self.version_pb) - #def test(self): - # with self.assertRaises(RuntimeError) as context: - # DeepPot(str(self.version_pb)) - # self.assertTrue('incompatible' in str(context.exception)) - # self.assertTrue(MODEL_VERSION in str(context.exception)) - # self.assertTrue('0.0' in str(context.exception)) + def test(self): + with self.assertRaises(RuntimeError) as context: + DeepPot(str(self.version_pb)) + self.assertTrue('incompatible' in str(context.exception)) + self.assertTrue(MODEL_VERSION in str(context.exception)) + self.assertTrue('0.0' in str(context.exception)) class TestModelMinorCompatability(unittest.TestCase) : From c8f9bf885847a6c21672ac6b68e02b646d0e6a56 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Jun 2021 17:15:30 -0400 Subject: [PATCH 08/10] revert allowing 0.0 model --- deepmd/infer/deep_eval.py | 4 ---- source/api_cc/src/common.cc | 3 --- 2 files changed, 7 deletions(-) diff --git a/deepmd/infer/deep_eval.py b/deepmd/infer/deep_eval.py index 41fa9ebd2f..92a45e3cd4 100644 --- a/deepmd/infer/deep_eval.py +++ b/deepmd/infer/deep_eval.py @@ -77,10 +77,6 @@ def _graph_compatable( model_version_minor = int(self.model_version.split('.')[1]) MODEL_VERSION_MAJOR = int(MODEL_VERSION.split('.')[0]) MODEL_VERSION_MINOR = int(MODEL_VERSION.split('.')[1]) - if model_version_major == 0: - # We plan to support model generated from v1.3 - # We have no way to distinguish versions earlier than v1.3 - return True if (model_version_major != MODEL_VERSION_MAJOR) or \ (model_version_minor > MODEL_VERSION_MINOR) : return False diff --git a/source/api_cc/src/common.cc b/source/api_cc/src/common.cc index 412ee55656..579216cb2c 100644 --- a/source/api_cc/src/common.cc +++ b/source/api_cc/src/common.cc @@ -36,9 +36,6 @@ model_compatable( int model_version_minor = atoi(words_mv[1].c_str()); int MODEL_VERSION_MAJOR = atoi(words_gmv[0].c_str()); int MODEL_VERSION_MINOR = atoi(words_gmv[1].c_str()); - // we plan to support model generated from v1.3, - // but have no way to distinguish versions earlier than v1.3 - if(model_version_major == 0) return true; if(model_version_major != MODEL_VERSION_MAJOR || model_version_minor > MODEL_VERSION_MINOR){ return false; From 139c264a2c930b2f875ead761b720bf2795a782c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Jun 2021 17:42:30 -0400 Subject: [PATCH 09/10] convert from model 1.3 to 2.0 --- deepmd/entrypoints/__init__.py | 4 ++- deepmd/entrypoints/convert.py | 13 ++++++++ deepmd/entrypoints/main.py | 33 ++++++++++++++++++- deepmd/utils/convert.py | 59 ++++++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 2 deletions(-) create mode 100644 deepmd/entrypoints/convert.py create mode 100644 deepmd/utils/convert.py diff --git a/deepmd/entrypoints/__init__.py b/deepmd/entrypoints/__init__.py index 3beceace3a..4a02b995f3 100644 --- a/deepmd/entrypoints/__init__.py +++ b/deepmd/entrypoints/__init__.py @@ -8,6 +8,7 @@ from .train import train from .transfer import transfer from ..infer.model_devi import make_model_devi +from .convert import convert __all__ = [ "config", @@ -18,5 +19,6 @@ "transfer", "compress", "doc_train_input", - "make_model_devi" + "make_model_devi", + "convert", ] diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py new file mode 100644 index 0000000000..4bf514fe51 --- /dev/null +++ b/deepmd/entrypoints/convert.py @@ -0,0 +1,13 @@ +from deepmd.utils.convert import convert_13_to_20 + +def convert( + *, + FROM: str, + input_model: str, + output_model: str, + **kwargs, +): + if FROM == '1.3': + convert_13_to_20(input_model, output_model) + else: + raise RuntimeError('unsupported model version ' + FROM) diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index b245053053..04dc245271 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -3,7 +3,7 @@ import argparse import logging from pathlib import Path -from typing import List, Optional +from typing import Dict, List, Optional from deepmd.entrypoints import ( compress, @@ -14,6 +14,7 @@ train, transfer, make_model_devi, + convert, ) from deepmd.loggers import set_log_handles @@ -359,6 +360,34 @@ def parse_args(args: Optional[List[str]] = None): help="The trajectory frequency of the system" ) + # * convert models + # supported: 1.3->2.0 + parser_transform = subparsers.add_parser( + 'convert-from', + parents=[parser_log], + help='convert lower model version to supported version', + ) + parser_transform.add_argument( + 'FROM', + type = str, + choices = ['1.3'], + help="The original model compatibility", + ) + parser_transform.add_argument( + '-i', + "--input-model", + default = "frozen_model.pb", + type=str, + help = "the input model", + ) + parser_transform.add_argument( + "-o", + "--output-model", + default = "convert_out.pb", + type=str, + help='the output model', + ) + parsed_args = parser.parse_args(args=args) if parsed_args.command is None: parser.print_help() @@ -402,6 +431,8 @@ def main(): doc_train_input() elif args.command == "model-devi": make_model_devi(**dict_args) + elif args.command == "convert-from": + convert(**dict_args) elif args.command is None: pass else: diff --git a/deepmd/utils/convert.py b/deepmd/utils/convert.py new file mode 100644 index 0000000000..0d9c39df88 --- /dev/null +++ b/deepmd/utils/convert.py @@ -0,0 +1,59 @@ +import os +from deepmd.env import tf +from google.protobuf import text_format +from tensorflow.python.platform import gfile + +def convert_13_to_20(input_model: str, output_model: str): + convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') + convert_dp13_to_dp20('frozen_model.pbtxt') + convert_pbtxt_to_pb('frozen_model.pbtxt', output_model) + if os.path.isfile('frozen_model.pbtxt'): + os.remove('frozen_model.pbtxt') + print("the converted output model (2.0 support) is saved in %s" % output_model) + +def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str): + with gfile.FastGFile(pbfile, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True) + +def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str): + with tf.gfile.FastGFile(pbtxtfile, 'r') as f: + graph_def = tf.GraphDef() + file_content = f.read() + # Merges the human-readable string in `file_content` into `graph_def`. + text_format.Merge(file_content, graph_def) + tf.train.write_graph(graph_def, './', pbfile, as_text=False) + +def convert_dp13_to_dp20(fname: str): + with open(fname) as fp: + file_content = fp.read() + file_content += """ +node { + name: "model_attr/model_version" + op: "Const" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + } + string_val: "1.0" + } + } + } +} +""" + file_content = file_content\ + .replace('DescrptSeA', 'ProdEnvMatA')\ + .replace('DescrptSeR', 'ProdEnvMatR') + with open(fname, 'w') as fp: + fp.write(file_content) From fc8b7f64b6df47d8cdec052122696a915ca1ba4e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Jun 2021 17:44:17 -0400 Subject: [PATCH 10/10] fix .gitignore --- .gitignore | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.gitignore b/.gitignore index b7556f269e..d56c01dbdd 100644 --- a/.gitignore +++ b/.gitignore @@ -27,10 +27,5 @@ venv* _build _templates API_CC -<<<<<<< HEAD -dp/** -build_lammps/** -======= dp/ build_lammps/ ->>>>>>> devel