Skip to content
37 changes: 37 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from deepmd.env import op_module, tf
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION
from deepmd.utils.sess import run_sess
from deepmd.utils.errors import GraphWithoutTensorError

if TYPE_CHECKING:
_DICT_VAL = TypeVar("_DICT_VAL")
Expand Down Expand Up @@ -483,3 +485,38 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype:
return np.float64
else:
raise RuntimeError(f"{precision} is not a valid precision")


def get_tensor_by_name(model_file: str,
                       tensor_name: str) -> tf.Tensor:
    """Load a tensor's value from a frozen model file.

    Parameters
    ----------
    model_file : str
        Path to the input frozen model (a serialized GraphDef).
    tensor_name : str
        Name of the tensor to load from the frozen model; the ":0"
        output-index suffix is appended automatically.

    Returns
    -------
    tf.Tensor
        The evaluated value of the named tensor.

    Raises
    ------
    GraphWithoutTensorError
        If the tensor_name is not present in the frozen model.
    """
    graph_def = tf.GraphDef()
    # Parse the serialized GraphDef from the frozen model file.
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        # name="" keeps the original node names (no import-scope prefix),
        # so tensor_name can be looked up unchanged.
        tf.import_graph_def(graph_def, name="")
        try:
            tensor = graph.get_tensor_by_name(tensor_name + ":0")
        except KeyError as e:
            raise GraphWithoutTensorError() from e
        # Evaluate the tensor in a fresh session to get its concrete value.
        with tf.Session(graph=graph) as sess:
            tensor = run_sess(sess, tensor)
        return tensor
30 changes: 19 additions & 11 deletions deepmd/entrypoints/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import logging
from typing import Optional

from deepmd.common import j_loader
from deepmd.env import tf
from deepmd.common import j_loader, get_tensor_by_name, GLOBAL_TF_FLOAT_PRECISION
from deepmd.utils.argcheck import normalize
from deepmd.utils.compat import updata_deepmd_input
from deepmd.utils.errors import GraphTooLargeError
from deepmd.utils.errors import GraphTooLargeError, GraphWithoutTensorError

from .freeze import freeze
from .train import train
Expand All @@ -20,7 +21,6 @@

def compress(
*,
INPUT: str,
input: str,
output: str,
extrapolate: int,
Expand All @@ -42,8 +42,6 @@ def compress(

Parameters
----------
INPUT : str
input json/yaml control file
input : str
frozen model file to compress
output : str
Expand All @@ -63,21 +61,30 @@ def compress(
log_level : int
logging level
"""
jdata = j_loader(INPUT)
if "model" not in jdata.keys():
jdata = updata_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
try:
t_jdata = get_tensor_by_name(input, 'train_attr/training_script')
t_min_nbor_dist = get_tensor_by_name(input, 'train_attr/min_nbor_dist')
except GraphWithoutTensorError as e:
raise RuntimeError(
"The input frozen model: %s has no training script or min_nbor_dist information,"
"which is not supported by the model compression program."
"Please consider using the dp convert-from interface to upgrade the model" % input
) from e
tf.constant(t_min_nbor_dist,
name = 'train_attr/min_nbor_dist',
dtype = GLOBAL_TF_FLOAT_PRECISION)
jdata = json.loads(t_jdata)
jdata["model"]["compress"] = {}
jdata["model"]["compress"]["type"] = 'se_e2_a'
jdata["model"]["compress"]["compress"] = True
jdata["model"]["compress"]["model_file"] = input
jdata["model"]["compress"]["min_nbor_dist"] = t_min_nbor_dist
jdata["model"]["compress"]["table_config"] = [
extrapolate,
step,
10 * step,
int(frequency),
]
# be careful here, if one want to refine the model
jdata["training"]["numb_steps"] = jdata["training"]["save_freq"]
jdata = normalize(jdata)

# check the descriptor info of the input file
Expand All @@ -90,7 +97,7 @@ def compress(

# stage 1: training or refining the model with tabulation
log.info("\n\n")
log.info("stage 1: train or refine the model with tabulation")
log.info("stage 1: compress the model")
control_file = "compress.json"
with open(control_file, "w") as fp:
json.dump(jdata, fp, indent=4)
Expand All @@ -103,6 +110,7 @@ def compress(
mpi_log=mpi_log,
log_level=log_level,
log_path=log_path,
is_compress=True,
)
except GraphTooLargeError as e:
raise RuntimeError(
Expand Down
2 changes: 2 additions & 0 deletions deepmd/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def _make_node_names(model_type: str, modifier_type: Optional[str] = None) -> Li
"model_attr/tmap",
"model_attr/model_type",
"model_attr/model_version",
"train_attr/min_nbor_dist",
"train_attr/training_script",
]

if model_type == "ener":
Expand Down
5 changes: 0 additions & 5 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,6 @@ def parse_args(args: Optional[List[str]] = None):
help="compress a model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser_compress.add_argument(
"INPUT",
help="The input parameter file in json or yaml format, which should be "
"consistent with the original model parameter file",
)
parser_compress.add_argument(
"-i",
"--input",
Expand Down
53 changes: 34 additions & 19 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import numpy as np
from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have
from deepmd.env import reset_default_tf_session_config
from deepmd.env import tf, reset_default_tf_session_config
from deepmd.infer.data_modifier import DipoleChargeModifier
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
from deepmd.train.trainer import DPTrainer
Expand All @@ -35,6 +35,7 @@ def train(
mpi_log: str,
log_level: int,
log_path: Optional[str],
is_compress: bool = False,
**kwargs,
):
"""Run DeePMD model training.
Expand All @@ -55,6 +56,8 @@ def train(
logging level defined by int 0-3
log_path : Optional[str]
logging file path or None if logs are to be output only to stdout
is_compress: Bool
indicates whether in the model compress mode

Raises
------
Expand All @@ -68,11 +71,15 @@ def train(

jdata = normalize(jdata)

jdata = update_sel(jdata)
if is_compress == False:
jdata = update_sel(jdata)

with open(output, "w") as fp:
json.dump(jdata, fp, indent=4)

# save the training script into the graph
tf.constant(json.dumps(jdata), name='train_attr/training_script', dtype=tf.string)

# run options
run_opt = RunOptions(
init_model=init_model,
Expand All @@ -86,10 +93,10 @@ def train(
log.info(message)

run_opt.print_resource_summary()
_do_work(jdata, run_opt)
_do_work(jdata, run_opt, is_compress)


def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = False):
"""Run serial model training.

Parameters
Expand All @@ -98,6 +105,8 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
arguments read form json/yaml control file
run_opt : RunOptions
object with run configuration
is_compress : Bool
indicates whether in model compress mode

Raises
------
Expand All @@ -112,7 +121,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
reset_default_tf_session_config(cpu_only=True)

# init the model
model = DPTrainer(jdata, run_opt=run_opt)
model = DPTrainer(jdata, run_opt=run_opt, is_compress = is_compress)
rcut = model.model.get_rcut()
type_map = model.model.get_type_map()
if len(type_map) == 0:
Expand All @@ -129,25 +138,31 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
# setup data modifier
modifier = get_modifier(jdata["model"].get("modifier", None))

# init data
train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier)
train_data.print_summary("training")
if jdata["training"].get("validation_data", None) is not None:
valid_data = get_data(jdata["training"]["validation_data"], rcut, ipt_type_map, modifier)
valid_data.print_summary("validation")
else:
valid_data = None
# decouple the training data from the model compress process
train_data = None
valid_data = None
if is_compress == False:
# init data
train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier)
train_data.print_summary("training")
if jdata["training"].get("validation_data", None) is not None:
valid_data = get_data(jdata["training"]["validation_data"], rcut, ipt_type_map, modifier)
valid_data.print_summary("validation")

# get training info
stop_batch = j_must_have(jdata["training"], "numb_steps")
model.build(train_data, stop_batch)

# train the model with the provided systems in a cyclic way
start_time = time.time()
model.train(train_data, valid_data)
end_time = time.time()
log.info("finished training")
log.info(f"wall time: {(end_time - start_time):.3f} s")
if is_compress == False:
# train the model with the provided systems in a cyclic way
start_time = time.time()
model.train(train_data, valid_data)
end_time = time.time()
log.info("finished training")
log.info(f"wall time: {(end_time - start_time):.3f} s")
else:
model.save_compressed()
log.info("finished compressing")


def get_data(jdata: Dict[str, Any], rcut, type_map, modifier):
Expand Down
Loading