Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions deepmd/entrypoints/convert.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from deepmd.utils.convert import convert_20_to_21, convert_13_to_21, convert_12_to_21
from deepmd.utils.convert import convert_10_to_21, convert_20_to_21, convert_13_to_21, convert_12_to_21

def convert(
*,
Expand All @@ -7,7 +7,9 @@ def convert(
output_model: str,
**kwargs,
):
if FROM in ['1.1', '1.2']:
if FROM == '1.0':
convert_10_to_21(input_model, output_model)
elif FROM in ['1.1', '1.2']:
# no difference between 1.1 and 1.2
convert_12_to_21(input_model, output_model)
elif FROM == '1.3':
Expand Down
2 changes: 1 addition & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def parse_args(args: Optional[List[str]] = None):
parser_transform.add_argument(
'FROM',
type = str,
choices = ['1.1', '1.2', '1.3', '2.0'],
choices = ['1.0', '1.1', '1.2', '1.3', '2.0'],
help="The original model compatibility",
)
parser_transform.add_argument(
Expand Down
139 changes: 138 additions & 1 deletion deepmd/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,36 @@
from google.protobuf import text_format
from tensorflow.python.platform import gfile


def convert_13_to_21(input_model: str, output_model: str):
"""Convert DP 1.3 graph to 2.1 graph.

Parameters
----------
input_model : str
filename of the input graph
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt')
convert_dp13_to_dp20('frozen_model.pbtxt')
convert_dp20_to_dp21('frozen_model.pbtxt')
convert_pbtxt_to_pb('frozen_model.pbtxt', output_model)
if os.path.isfile('frozen_model.pbtxt'):
os.remove('frozen_model.pbtxt')
print("the converted output model (2.1 support) is saved in %s" % output_model)


def convert_13_to_21(input_model: str, output_model: str):
"""Convert DP 1.3 graph to 2.1 graph.

Parameters
----------
input_model : str
filename of the input graph
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt')
convert_dp13_to_dp20('frozen_model.pbtxt')
convert_dp20_to_dp21('frozen_model.pbtxt')
Expand All @@ -12,8 +41,39 @@ def convert_13_to_21(input_model: str, output_model: str):
os.remove('frozen_model.pbtxt')
print("the converted output model (2.1 support) is saved in %s" % output_model)


def convert_12_to_21(input_model: str, output_model: str):
"""Convert DP 1.2 graph to 2.1 graph.

Parameters
----------
input_model : str
filename of the input graph
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt')
convert_dp12_to_dp13('frozen_model.pbtxt')
convert_dp13_to_dp20('frozen_model.pbtxt')
convert_dp20_to_dp21('frozen_model.pbtxt')
convert_pbtxt_to_pb('frozen_model.pbtxt', output_model)
if os.path.isfile('frozen_model.pbtxt'):
os.remove('frozen_model.pbtxt')
print("the converted output model (2.1 support) is saved in %s" % output_model)


def convert_10_to_21(input_model: str, output_model: str):
"""Convert DP 1.0 graph to 2.1 graph.

Parameters
----------
input_model : str
filename of the input graph
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt')
convert_dp10_to_dp11('frozen_model.pbtxt')
convert_dp12_to_dp13('frozen_model.pbtxt')
convert_dp13_to_dp20('frozen_model.pbtxt')
convert_dp20_to_dp21('frozen_model.pbtxt')
Expand All @@ -22,7 +82,17 @@ def convert_12_to_21(input_model: str, output_model: str):
os.remove('frozen_model.pbtxt')
print("the converted output model (2.1 support) is saved in %s" % output_model)


def convert_20_to_21(input_model: str, output_model: str):
"""Convert DP 2.0 graph to 2.1 graph.

Parameters
----------
input_model : str
filename of the input graph
output_model : str
filename of the output graph
"""
convert_pb_to_pbtxt(input_model, 'frozen_model.pbtxt')
convert_dp20_to_dp21('frozen_model.pbtxt')
convert_pbtxt_to_pb('frozen_model.pbtxt', output_model)
Expand All @@ -31,21 +101,80 @@ def convert_20_to_21(input_model: str, output_model: str):
print("the converted output model (2.1 support) is saved in %s" % output_model)

def convert_pb_to_pbtxt(pbfile: str, pbtxtfile: str):
"""Convert DP graph to graph text.

Parameters
----------
pbfile : str
filename of the input graph
pbtxtfile : str
filename of the output graph text
"""
with gfile.FastGFile(pbfile, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
tf.train.write_graph(graph_def, './', pbtxtfile, as_text=True)

def convert_pbtxt_to_pb(pbtxtfile: str, pbfile: str):
"""Convert DP graph text to graph.

Parameters
----------
pbtxtfile : str
filename of the input graph text
pbfile : str
filename of the output graph
"""
with tf.gfile.FastGFile(pbtxtfile, 'r') as f:
graph_def = tf.GraphDef()
file_content = f.read()
# Merges the human-readable string in `file_content` into `graph_def`.
text_format.Merge(file_content, graph_def)
tf.train.write_graph(graph_def, './', pbfile, as_text=False)

def convert_dp12_to_dp13(file):

def convert_dp10_to_dp11(file: str):
"""Convert DP 1.0 graph text to 1.1 graph text.

Parameters
----------
file : str
filename of the graph text
"""
with open(file, 'a') as f:
f.write("""
node {
name: "fitting_attr/daparam"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
} }
}
""")


def convert_dp12_to_dp13(file: str):
"""Convert DP 1.2 graph text to 1.3 graph text.

Parameters
----------
file : str
filename of the graph text
"""
file_data = ""
with open(file, "r", encoding="utf-8") as f:
ii = 0
Expand All @@ -67,7 +196,15 @@ def convert_dp12_to_dp13(file):
with open(file, "w", encoding="utf-8") as f:
f.write(file_data)


def convert_dp13_to_dp20(fname: str):
"""Convert DP 1.3 graph text to 2.0 graph text.

Parameters
----------
file : str
filename of the graph text
"""
with open(fname) as fp:
file_content = fp.read()
file_content += """
Expand Down
2 changes: 1 addition & 1 deletion deepmd/utils/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_plugin(self, key) -> object:

Parameters
----------
key str
key : str
key of the plugin

Returns
Expand Down
2 changes: 1 addition & 1 deletion doc/troubleshooting/model-compatability.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ One can execute `dp convert-from` to convert an old model to a new one.

| Model version | v0.12 | v1.0 | v1.1 | v1.2 | v1.3 | v2.0 | v2.1 |
|:-:|:-----------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| Compatibility | 😢 | 😢 | 😊 | 😊 | 😊 | 😄 | 😄 |
| Compatibility | 😢 | 😊 | 😊 | 😊 | 😊 | 😄 | 😄 |

**Legend**:
- 😄: The model is compatible with the DeePMD-kit package.
Expand Down