Skip to content

Commit e8a9101

Browse files
authored
add model compression support for models without training script (#948)
* add model compression support for models without training script
* fix line changes
* update doc for model compression
* make the error message more reasonable
* update doc for model compression
1 parent c4856ef commit e8a9101

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

deepmd/entrypoints/compress.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Compress a model, which including tabulating the embedding-net."""
22

3+
import os
34
import json
45
import logging
56
from typing import Optional
@@ -11,7 +12,7 @@
1112
from deepmd.utils.errors import GraphTooLargeError, GraphWithoutTensorError
1213

1314
from .freeze import freeze
14-
from .train import train
15+
from .train import train, get_rcut, get_min_nbor_dist
1516
from .transfer import transfer
1617

1718
__all__ = ["compress"]
@@ -27,6 +28,7 @@ def compress(
2728
step: float,
2829
frequency: str,
2930
checkpoint_folder: str,
31+
training_script: str,
3032
mpi_log: str,
3133
log_path: Optional[str],
3234
log_level: int,
@@ -54,6 +56,8 @@ def compress(
5456
frequency of tabulation overflow check
5557
checkpoint_folder : str
5658
training checkpoint folder for freezing
59+
training_script : str
60+
training script of the input frozen model
5761
mpi_log : str
5862
mpi logging mode for training
5963
log_path : Optional[str]
@@ -64,16 +68,27 @@ def compress(
6468
try:
6569
t_jdata = get_tensor_by_name(input, 'train_attr/training_script')
6670
t_min_nbor_dist = get_tensor_by_name(input, 'train_attr/min_nbor_dist')
71+
jdata = json.loads(t_jdata)
6772
except GraphWithoutTensorError as e:
68-
raise RuntimeError(
69-
"The input frozen model: %s has no training script or min_nbor_dist information,"
70-
"which is not supported by the model compression program."
71-
"Please consider using the dp convert-from interface to upgrade the model" % input
72-
) from e
73+
if training_script == None:
74+
raise RuntimeError(
75+
"The input frozen model: %s has no training script or min_nbor_dist information, "
76+
"which is not supported by the model compression interface. "
77+
"Please consider using the --training-script command within the model compression interface to provide the training script of the input frozen model. "
78+
"Note that the input training script must contain the correct path to the training data." % input
79+
) from e
80+
elif os.path.exists(training_script) == False:
81+
raise RuntimeError(
82+
"The input training script %s does not exist! Please check the path of the training script. " % (input + "(" + os.path.abspath(input) + ")")
83+
) from e
84+
else:
85+
log.info("stage 0: compute the min_nbor_dist")
86+
jdata = j_loader(training_script)
87+
t_min_nbor_dist = get_min_nbor_dist(jdata, get_rcut(jdata))
88+
7389
tf.constant(t_min_nbor_dist,
7490
name = 'train_attr/min_nbor_dist',
7591
dtype = GLOBAL_TF_FLOAT_PRECISION)
76-
jdata = json.loads(t_jdata)
7792
jdata["model"]["compress"] = {}
7893
jdata["model"]["compress"]["type"] = 'se_e2_a'
7994
jdata["model"]["compress"]["compress"] = True

deepmd/entrypoints/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,13 @@ def parse_args(args: Optional[List[str]] = None):
306306
default=".",
307307
help="path to checkpoint folder",
308308
)
309+
parser_compress.add_argument(
310+
"-t",
311+
"--training-script",
312+
type=str,
313+
default=None,
314+
help="The training script of the input frozen model",
315+
)
309316

310317
# * print docs script **************************************************************
311318
parsers_doc = subparsers.add_parser(

deepmd/entrypoints/train.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def get_type_map(jdata):
240240
return jdata['model'].get('type_map', None)
241241

242242

243-
def get_sel(jdata, rcut):
243+
def get_nbor_stat(jdata, rcut):
244244
max_rcut = get_rcut(jdata)
245245
type_map = get_type_map(jdata)
246246

@@ -258,9 +258,15 @@ def get_sel(jdata, rcut):
258258
neistat = NeighborStat(ntypes, rcut)
259259

260260
min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)
261+
return min_nbor_dist, max_nbor_size
261262

263+
def get_sel(jdata, rcut):
264+
_, max_nbor_size = get_nbor_stat(jdata, rcut)
262265
return max_nbor_size
263266

267+
def get_min_nbor_dist(jdata, rcut):
268+
min_nbor_dist, _ = get_nbor_stat(jdata, rcut)
269+
return min_nbor_dist
264270

265271
def parse_auto_sel(sel):
266272
if type(sel) is not str:

doc/getting-started.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ optional arguments:
336336
(default: -1)
337337
-c CHECKPOINT_FOLDER, --checkpoint-folder CHECKPOINT_FOLDER
338338
path to checkpoint folder (default: .)
339+
-t TRAINING_SCRIPT, --training-script TRAINING_SCRIPT
340+
The training script of the input frozen model
341+
(default: None)
339342
```
340343
**Parameter explanation**
341344

0 commit comments

Comments
 (0)