diff --git a/deepmd/entrypoints/convert.py b/deepmd/entrypoints/convert.py index 782bb89241..aa602dbed4 100644 --- a/deepmd/entrypoints/convert.py +++ b/deepmd/entrypoints/convert.py @@ -1,4 +1,4 @@ -from deepmd.utils.convert import convert_20_to_21, convert_13_to_21, convert_12_to_21 +from deepmd.utils.convert import convert_10_to_21, convert_20_to_21, convert_13_to_21, convert_12_to_21 def convert( *, @@ -7,7 +7,9 @@ def convert( output_model: str, **kwargs, ): - if FROM in ['1.1', '1.2']: + if FROM == '1.0': + convert_10_to_21(input_model, output_model) + elif FROM in ['1.1', '1.2']: # no difference between 1.1 and 1.2 convert_12_to_21(input_model, output_model) elif FROM == '1.3': diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index 043e6523df..46bdad05de 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -392,7 +392,7 @@ def parse_args(args: Optional[List[str]] = None): parser_transform.add_argument( 'FROM', type = str, - choices = ['1.1', '1.2', '1.3', '2.0'], + choices = ['1.0', '1.1', '1.2', '1.3', '2.0'], help="The original model compatibility", ) parser_transform.add_argument( diff --git a/deepmd/utils/convert.py b/deepmd/utils/convert.py index b17178c761..2c9a653002 100644 --- a/deepmd/utils/convert.py +++ b/deepmd/utils/convert.py @@ -3,7 +3,36 @@ from google.protobuf import text_format from tensorflow.python.platform import gfile + +def convert_13_to_21(input_model: str, output_model: str): + """Convert DP 1.3 graph to 2.1 graph. + + Parameters + ---------- + input_model : str + filename of the input graph + output_model : str + filename of the output graph + """ + convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') + convert_dp13_to_dp20('frozen_model.pbtxt') + convert_dp20_to_dp21('frozen_model.pbtxt') + convert_pbtxt_to_pb('frozen_model.pbtxt', output_model) + if os.path.isfile('frozen_model.pbtxt'): + os.remove('frozen_model.pbtxt') + print("the converted output model (2.1 support) is saved in %s" % output_model) + + def convert_13_to_21(input_model: str, output_model: str): + """Convert DP 1.3 graph to 2.1 graph. + + Parameters + ---------- + input_model : str + filename of the input graph + output_model : str + filename of the output graph + """ convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') convert_dp13_to_dp20('frozen_model.pbtxt') convert_dp20_to_dp21('frozen_model.pbtxt') @@ -12,8 +41,39 @@ def convert_13_to_21(input_model: str, output_model: str): os.remove('frozen_model.pbtxt') print("the converted output model (2.1 support) is saved in %s" % output_model) + def convert_12_to_21(input_model: str, output_model: str): + """Convert DP 1.2 graph to 2.1 graph. + + Parameters + ---------- + input_model : str + filename of the input graph + output_model : str + filename of the output graph + """ + convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') + convert_dp12_to_dp13('frozen_model.pbtxt') + convert_dp13_to_dp20('frozen_model.pbtxt') + convert_dp20_to_dp21('frozen_model.pbtxt') + convert_pbtxt_to_pb('frozen_model.pbtxt', output_model) + if os.path.isfile('frozen_model.pbtxt'): + os.remove('frozen_model.pbtxt') + print("the converted output model (2.1 support) is saved in %s" % output_model) + + +def convert_10_to_21(input_model: str, output_model: str): + """Convert DP 1.0 graph to 2.1 graph. + + Parameters + ---------- + input_model : str + filename of the input graph + output_model : str + filename of the output graph + """ convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') + convert_dp10_to_dp11('frozen_model.pbtxt') convert_dp12_to_dp13('frozen_model.pbtxt') convert_dp13_to_dp20('frozen_model.pbtxt') convert_dp20_to_dp21('frozen_model.pbtxt') @@ -22,7 +82,17 @@ def convert_12_to_21(input_model: str, output_model: str): os.remove('frozen_model.pbtxt') print("the converted output model (2.1 support) is saved in %s" % output_model) + def convert_20_to_21(input_model: str, output_model: str): + """Convert DP 2.0 graph to 2.1 graph. + + Parameters + ---------- + input_model : str + filename of the input graph + output_model : str + filename of the output graph + """ convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt') convert_dp20_to_dp21('frozen_model.pbtxt') convert_pbtxt_to_pb('frozen_model.pbtxt', output_model) @@ -31,6 +101,15 @@ def convert_20_to_21(input_model: str, output_model: str): print("the converted output model (2.1 support) is saved in %s" % output_model) def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str): + """Convert DP graph to graph text. + + Parameters + ---------- + pbfile : str + filename of the input graph + pbtxtfile : str + filename of the output graph text + """ with gfile.FastGFile(pbfile, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) @@ -38,6 +117,15 @@ def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str): tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True) def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str): + """Convert DP graph text to graph. + + Parameters + ---------- + pbtxtfile : str + filename of the input graph text + pbfile : str + filename of the output graph + """ with tf.gfile.FastGFile(pbtxtfile, 'r') as f: graph_def = tf.GraphDef() file_content = f.read() @@ -45,7 +133,48 @@ def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str): text_format.Merge(file_content, graph_def) tf.train.write_graph(graph_def, './', pbfile, as_text=False) -def convert_dp12_to_dp13(file): + +def convert_dp10_to_dp11(file: str): + """Convert DP 1.0 graph text to 1.1 graph text. + + Parameters + ---------- + file : str + filename of the graph text + """ + with open(file, 'a') as f: + f.write(""" +node { + name: "fitting_attr/daparam" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } } +} +""") + + +def convert_dp12_to_dp13(file: str): + """Convert DP 1.2 graph text to 1.3 graph text. + + Parameters + ---------- + file : str + filename of the graph text + """ file_data = "" with open(file, "r", encoding="utf-8") as f: ii = 0 @@ -67,7 +196,15 @@ def convert_dp12_to_dp13(file): with open(file, "w", encoding="utf-8") as f: f.write(file_data) + def convert_dp13_to_dp20(fname: str): + """Convert DP 1.3 graph text to 2.0 graph text. + + Parameters + ---------- + file : str + filename of the graph text + """ with open(fname) as fp: file_content = fp.read() file_content += """ diff --git a/deepmd/utils/plugin.py b/deepmd/utils/plugin.py index f195b7808c..6a40e69fab 100644 --- a/deepmd/utils/plugin.py +++ b/deepmd/utils/plugin.py @@ -52,7 +52,7 @@ def get_plugin(self, key) -> object: Parameters ---------- - key : str + key : str key of the plugin Returns diff --git a/doc/troubleshooting/model-compatability.md b/doc/troubleshooting/model-compatability.md index 5c0aa11889..820a79210f 100644 --- a/doc/troubleshooting/model-compatability.md +++ b/doc/troubleshooting/model-compatability.md @@ -8,7 +8,7 @@ One can execute `dp convert-from` to convert an old model to a new one. | Model version | v0.12 | v1.0 | v1.1 | v1.2 | v1.3 | v2.0 | v2.1 | |:-:|:-----------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:| -| Compatibility | 😢 | 😢 | 😊 | 😊 | 😊 | 😄 | 😄 | +| Compatibility | 😢 | 😊 | 😊 | 😊 | 😊 | 😄 | 😄 | **Legend**: - 😄: The model is compatible with the DeePMD-kit package.