diff --git a/deepmd/entrypoints/compress.py b/deepmd/entrypoints/compress.py index b0f0f42729..7755156601 100644 --- a/deepmd/entrypoints/compress.py +++ b/deepmd/entrypoints/compress.py @@ -1,5 +1,6 @@ """Compress a model, which including tabulating the embedding-net.""" +import os import json import logging from typing import Optional @@ -11,7 +12,7 @@ from deepmd.utils.errors import GraphTooLargeError, GraphWithoutTensorError from .freeze import freeze -from .train import train +from .train import train, get_rcut, get_min_nbor_dist from .transfer import transfer __all__ = ["compress"] @@ -27,6 +28,7 @@ def compress( step: float, frequency: str, checkpoint_folder: str, + training_script: Optional[str], mpi_log: str, log_path: Optional[str], log_level: int, @@ -54,6 +56,8 @@ def compress( frequency of tabulation overflow check checkpoint_folder : str trining checkpoint folder for freezing + training_script : Optional[str] + training script of the input frozen model mpi_log : str mpi logging mode for training log_path : Optional[str] @@ -64,16 +68,27 @@ def compress( try: t_jdata = get_tensor_by_name(input, 'train_attr/training_script') t_min_nbor_dist = get_tensor_by_name(input, 'train_attr/min_nbor_dist') + jdata = json.loads(t_jdata) except GraphWithoutTensorError as e: - raise RuntimeError( - "The input frozen model: %s has no training script or min_nbor_dist information," - "which is not supported by the model compression program." - "Please consider using the dp convert-from interface to upgrade the model" % input - ) from e + if training_script is None: + raise RuntimeError( + "The input frozen model: %s has no training script or min_nbor_dist information, " + "which is not supported by the model compression interface. " + "Please consider using the --training-script command within the model compression interface to provide the training script of the input frozen model. " + "Note that the input training script must contain the correct path to the training data." 
% input + ) from e + elif not os.path.exists(training_script): + raise RuntimeError( + "The input training script %s does not exist! Please check the path of the training script. " % (training_script + "(" + os.path.abspath(training_script) + ")") + ) from e + else: + log.info("stage 0: compute the min_nbor_dist") + jdata = j_loader(training_script) + t_min_nbor_dist = get_min_nbor_dist(jdata, get_rcut(jdata)) + tf.constant(t_min_nbor_dist, name = 'train_attr/min_nbor_dist', dtype = GLOBAL_TF_FLOAT_PRECISION) - jdata = json.loads(t_jdata) jdata["model"]["compress"] = {} jdata["model"]["compress"]["type"] = 'se_e2_a' jdata["model"]["compress"]["compress"] = True diff --git a/deepmd/entrypoints/main.py b/deepmd/entrypoints/main.py index d28022ac1f..52e4e1c61e 100644 --- a/deepmd/entrypoints/main.py +++ b/deepmd/entrypoints/main.py @@ -306,6 +306,13 @@ def parse_args(args: Optional[List[str]] = None): default=".", help="path to checkpoint folder", ) + parser_compress.add_argument( + "-t", + "--training-script", + type=str, + default=None, + help="The training script of the input frozen model", + ) # * print docs script ************************************************************** parsers_doc = subparsers.add_parser( diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index ba20f5f7e8..3fe0466eae 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -240,7 +240,7 @@ def get_type_map(jdata): return jdata['model'].get('type_map', None) -def get_sel(jdata, rcut): +def get_nbor_stat(jdata, rcut): max_rcut = get_rcut(jdata) type_map = get_type_map(jdata) @@ -258,9 +258,15 @@ def get_sel(jdata, rcut): neistat = NeighborStat(ntypes, rcut) min_nbor_dist, max_nbor_size = neistat.get_stat(train_data) + return min_nbor_dist, max_nbor_size +def get_sel(jdata, rcut): + _, max_nbor_size = get_nbor_stat(jdata, rcut) return max_nbor_size +def get_min_nbor_dist(jdata, rcut): + min_nbor_dist, _ = get_nbor_stat(jdata, rcut) + return min_nbor_dist def
parse_auto_sel(sel): if type(sel) is not str: diff --git a/doc/getting-started.md b/doc/getting-started.md index 0965b72661..607e15ed07 100644 --- a/doc/getting-started.md +++ b/doc/getting-started.md @@ -336,6 +336,9 @@ optional arguments: (default: -1) -c CHECKPOINT_FOLDER, --checkpoint-folder CHECKPOINT_FOLDER path to checkpoint folder (default: .) + -t TRAINING_SCRIPT, --training-script TRAINING_SCRIPT + The training script of the input frozen model + (default: None) ``` **Parameter explanation**