diff --git a/deepmd/utils/graph.py b/deepmd/utils/graph.py index 031454e4b9..e6ff47b21f 100644 --- a/deepmd/utils/graph.py +++ b/deepmd/utils/graph.py @@ -153,7 +153,7 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str = The embedding net nodes within the given tf.GraphDef object """ # embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+" - if suffix is not "": + if suffix != "": embedding_net_pattern = EMBEDDING_NET_PATTERN\ .replace('/idt', suffix + '/idt')\ .replace('/bias', suffix + '/bias')\