Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions deepmd/entrypoints/compress.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Compress a model, which including tabulating the embedding-net."""

import os
import json
import logging
from typing import Optional
Expand All @@ -11,7 +12,7 @@
from deepmd.utils.errors import GraphTooLargeError, GraphWithoutTensorError

from .freeze import freeze
from .train import train
from .train import train, get_rcut, get_min_nbor_dist
from .transfer import transfer

__all__ = ["compress"]
Expand All @@ -27,6 +28,7 @@ def compress(
step: float,
frequency: str,
checkpoint_folder: str,
training_script: str,
mpi_log: str,
log_path: Optional[str],
log_level: int,
Expand Down Expand Up @@ -54,6 +56,8 @@ def compress(
frequency of tabulation overflow check
checkpoint_folder : str
trining checkpoint folder for freezing
training_script : str
training script of the input frozen model
mpi_log : str
mpi logging mode for training
log_path : Optional[str]
Expand All @@ -64,16 +68,27 @@ def compress(
try:
t_jdata = get_tensor_by_name(input, 'train_attr/training_script')
t_min_nbor_dist = get_tensor_by_name(input, 'train_attr/min_nbor_dist')
jdata = json.loads(t_jdata)
except GraphWithoutTensorError as e:
raise RuntimeError(
"The input frozen model: %s has no training script or min_nbor_dist information,"
"which is not supported by the model compression program."
"Please consider using the dp convert-from interface to upgrade the model" % input
) from e
if training_script == None:
raise RuntimeError(
"The input frozen model: %s has no training script or min_nbor_dist information, "
"which is not supported by the model compression interface. "
"Please consider using the --training-script command within the model compression interface to provide the training script of the input frozen model. "
"Note that the input training script must contain the correct path to the training data." % input
) from e
elif os.path.exists(training_script) == False:
raise RuntimeError(
"The input training script %s does not exist! Please check the path of the training script. " % (input + "(" + os.path.abspath(input) + ")")
) from e
else:
log.info("stage 0: compute the min_nbor_dist")
jdata = j_loader(training_script)
t_min_nbor_dist = get_min_nbor_dist(jdata, get_rcut(jdata))

tf.constant(t_min_nbor_dist,
name = 'train_attr/min_nbor_dist',
dtype = GLOBAL_TF_FLOAT_PRECISION)
jdata = json.loads(t_jdata)
jdata["model"]["compress"] = {}
jdata["model"]["compress"]["type"] = 'se_e2_a'
jdata["model"]["compress"]["compress"] = True
Expand Down
7 changes: 7 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,13 @@ def parse_args(args: Optional[List[str]] = None):
default=".",
help="path to checkpoint folder",
)
parser_compress.add_argument(
"-t",
"--training-script",
type=str,
default=None,
help="The training script of the input frozen model",
)

# * print docs script **************************************************************
parsers_doc = subparsers.add_parser(
Expand Down
8 changes: 7 additions & 1 deletion deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def get_type_map(jdata):
return jdata['model'].get('type_map', None)


def get_sel(jdata, rcut):
def get_nbor_stat(jdata, rcut):
max_rcut = get_rcut(jdata)
type_map = get_type_map(jdata)

Expand All @@ -258,9 +258,15 @@ def get_sel(jdata, rcut):
neistat = NeighborStat(ntypes, rcut)

min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)
return min_nbor_dist, max_nbor_size

def get_sel(jdata, rcut):
_, max_nbor_size = get_nbor_stat(jdata, rcut)
return max_nbor_size

def get_min_nbor_dist(jdata, rcut):
min_nbor_dist, _ = get_nbor_stat(jdata, rcut)
return min_nbor_dist

def parse_auto_sel(sel):
if type(sel) is not str:
Expand Down
3 changes: 3 additions & 0 deletions doc/getting-started.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,9 @@ optional arguments:
(default: -1)
-c CHECKPOINT_FOLDER, --checkpoint-folder CHECKPOINT_FOLDER
path to checkpoint folder (default: .)
-t TRAINING_SCRIPT, --training-script TRAINING_SCRIPT
The training script of the input frozen model
(default: None)
```
**Parameter explanation**

Expand Down