Fix compress training bug within the dp train --init-frz-model interface #1233
Codecov Report
@@ Coverage Diff @@
## devel #1233 +/- ##
==========================================
- Coverage 76.02% 75.99% -0.04%
==========================================
Files 91 91
Lines 7367 7389 +22
==========================================
+ Hits 5601 5615 +14
- Misses 1766 1774 +8
deepmd/entrypoints/freeze.py
Outdated
        raw_graph_def,  # The graph_def is used to retrieve the nodes
        [n + '_1' for n in old_graph_nodes],  # The output node names are used to select the useful nodes
    )
except Exception:
All fitting net variables get the _1 suffix added; we can check this with the tf.trainable_variables() function. I think this is TensorFlow's default node naming behavior: when a specific variable name is not available in the graph (due to the usage of tf.import_graph_def), TF automatically appends a number suffix to that name. And each fitting_net node name is unique within the original graph (with a suffix matrix, bias or idt), so we are fine to do so.
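To make the suffix behavior described above concrete, here is a toy model in plain Python. It is only a sketch: ToyGraph and unique_name are hypothetical names, and the class imitates the "append a number when the name is taken" rule rather than reproducing TensorFlow's implementation.

```python
class ToyGraph:
    """Toy name registry (hypothetical; not TensorFlow's implementation)."""

    def __init__(self):
        self._used = set()

    def unique_name(self, name):
        """Return `name` if it is free, otherwise `name_1`, `name_2`, ..."""
        if name not in self._used:
            self._used.add(name)
            return name
        i = 1
        while f"{name}_{i}" in self._used:
            i += 1
        unique = f"{name}_{i}"
        self._used.add(unique)
        return unique


graph = ToyGraph()
old_names = ["layer_0/matrix", "layer_0/bias"]

# Step 1: the old graph def is imported, so its names are taken first.
imported = [graph.unique_name(n) for n in old_names]

# Step 2: the training code asks for the same names again and gets suffixed ones.
rebuilt = [graph.unique_name(n) for n in old_names]
print(rebuilt)  # ['layer_0/matrix_1', 'layer_0/bias_1']
```

Because every original name is unique, each rebuilt name is exactly the original plus _1, which is what the n + '_1' expression in the diff relies on.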
I mean, could you catch a specific exception (such as RuntimeError) instead of the general Exception?
Sure. It's the AssertionError.
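A minimal sketch of what the narrower handler could look like. The helper names extract_nodes and extract_with_fallback are hypothetical, the graph is modeled as a plain dict, and falling back to the unsuffixed names is an assumption; none of this is taken from the PR's code.

```python
def extract_nodes(graph_nodes, wanted):
    """Hypothetical stand-in for a graph helper that asserts every wanted node exists."""
    for name in wanted:
        assert name in graph_nodes, f"{name} is not in graph"
    return {name: graph_nodes[name] for name in wanted}


def extract_with_fallback(graph_nodes, old_names):
    """Try the '_1'-suffixed names first, then fall back to the plain names."""
    try:
        return extract_nodes(graph_nodes, [n + "_1" for n in old_names])
    except AssertionError:
        # Only the expected "node is missing" failure is handled here.
        # Any other error (for example a TypeError) still propagates.
        return extract_nodes(graph_nodes, old_names)
```

Catching AssertionError alone keeps unrelated failures visible instead of silently routing them into the fallback branch. Note that assert statements are skipped when Python runs with -O, so this sketch assumes a normal interpreter run.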
deepmd/entrypoints/freeze.py
Outdated
log = logging.getLogger(__name__)

def _transfer_graph_def(sess, old_graph_def, raw_graph_def):
_transfer_graph_def is not a good name for this function. It should specify which variables are transferred.
…to compress-training fix pip CI problem
The compress training code uses the tf.import_graph_def function to load the tf.Tensor and tf.Operation objects from the old graph def into the current default graph. However, this can lead to a variable name conflict during the model freeze process, which is the cause of issue #1194. According to the TensorFlow doc:
In this PR, the following changes are adopted to address #1194:
EMBEDDING_NET_PATTERN, FITTING_NET_PATTERN as well as the TRANSFER_PATTERN, to the deepmd.env module.
Note that this PR does not use the prefix parameter of the tf.import_graph_def function to solve #1194. Although that would be easier, it would change the node names permanently. Instead, this PR does not affect the graph structure or the node names within the graph, which is very important for model maintenance.
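As a rough illustration of how a name pattern can relate suffixed node names back to their originals, here is a sketch in plain Python. The pattern and the original_name helper are hypothetical and written only for this example; the real FITTING_NET_PATTERN in deepmd.env may look quite different, and this is not the PR's code.

```python
import re

# Hypothetical pattern, for illustration only: a scope, then one of the
# fitting-net variable kinds, then an optional numeric duplicate suffix.
HYPOTHETICAL_FITTING_PATTERN = re.compile(
    r"^(?P<base>.+/(?:matrix|bias|idt))(?:_(?P<dup>\d+))?$"
)


def original_name(node_name):
    """Return the name without a trailing duplicate suffix, or None if it does not match."""
    match = HYPOTHETICAL_FITTING_PATTERN.match(node_name)
    return match.group("base") if match else None


print(original_name("layer_0/matrix_1"))  # layer_0/matrix
print(original_name("layer_0/bias"))      # layer_0/bias
print(original_name("descrpt/t_avg"))     # None
```

Keeping such patterns in one shared module means the freeze and transfer code can agree on which nodes count as fitting-net variables without renaming anything in the graph itself.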