diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index f241f6ee15..893493a201 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -280,26 +280,65 @@ def get_feed_dict(self, feed_dict : dict[str, tf.Tensor] The output feed_dict of current descriptor """ - # TODO: currently only SeA has this method, but I think the method can be - # moved here as it doesn't contain anything related to a specific descriptor - raise NotImplementedError + feed_dict = { + 't_coord:0' :coord_, + 't_type:0' :atype_, + 't_natoms:0' :natoms, + 't_box:0' :box, + 't_mesh:0' :mesh + } + return feed_dict def init_variables(self, - embedding_net_variables: dict - ) -> None: + model_file: str, + suffix : str = "", + ) -> None: """ Init the embedding net variables with the given dict Parameters ---------- - embedding_net_variables - The input dict which stores the embedding net variables + model_file : str + The input model file + suffix : str, optional + The suffix of the scope Notes ----- This method is called by others when the descriptor supported initialization from the given variables. """ - # TODO: currently only SeA has this method, but I think the method can be - # moved here as it doesn't contain anything related to a specific descriptor raise NotImplementedError( "Descriptor %s doesn't support initialization from the given variables!" % type(self).__name__) + + def get_tensor_names(self, suffix : str = "") -> Tuple[str]: + """Get names of tensors. + + Parameters + ---------- + suffix : str + The suffix of the scope + + Returns + ------- + Tuple[str] + Names of tensors + """ + raise NotImplementedError("Descriptor %s doesn't support this property!" % type(self).__name__) + + def pass_tensors_from_frz_model(self, + *tensors : tf.Tensor, + ) -> None: + """ + Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def + + Parameters + ---------- + *tensors : tf.Tensor + passed tensors + + Notes + ----- + The number of parameters in the method must be equal to the numbers of returns in + :meth:`get_tensor_names`. + """ + raise NotImplementedError("Descriptor %s doesn't support this method!" % type(self).__name__) diff --git a/deepmd/descriptor/hybrid.py b/deepmd/descriptor/hybrid.py index bff59518a1..37b9578b4e 100644 --- a/deepmd/descriptor/hybrid.py +++ b/deepmd/descriptor/hybrid.py @@ -253,3 +253,55 @@ def enable_compression(self, """ for idx, ii in enumerate(self.descrpt_list): ii.enable_compression(min_nbor_dist, model_file, table_extrapolate, table_stride_1, table_stride_2, check_frequency, suffix=f"{suffix}_{idx}") + + def init_variables(self, + model_file : str, + suffix : str = "", + ) -> None: + """ + Init the embedding net variables with the given dict + + Parameters + ---------- + model_file : str + The input frozen model file + suffix : str, optional + The suffix of the scope + """ + for idx, ii in enumerate(self.descrpt_list): + ii.init_variables(model_file, suffix=f"{suffix}_{idx}") + + def get_tensor_names(self, suffix : str = "") -> Tuple[str]: + """Get names of tensors. + + Parameters + ---------- + suffix : str + The suffix of the scope + + Returns + ------- + Tuple[str] + Names of tensors + """ + tensor_names = [] + for idx, ii in enumerate(self.descrpt_list): + tensor_names.extend(ii.get_tensor_names(suffix=f"{suffix}_{idx}")) + return tuple(tensor_names) + + def pass_tensors_from_frz_model(self, + *tensors : tf.Tensor, + ) -> None: + """ + Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def + + Parameters + ---------- + *tensors : tf.Tensor + passed tensors + """ + jj = 0 + for ii in self.descrpt_list: + n_tensors = len(ii.get_tensor_names()) + ii.pass_tensors_from_frz_model(*tensors[jj:jj+n_tensors]) + jj += n_tensors diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 47e28522a9..39485463a9 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -13,7 +13,7 @@ from deepmd.utils.tabulate import DPTabulate from deepmd.utils.type_embed import embed_atom_type from deepmd.utils.sess import run_sess -from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph +from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables from .descriptor import Descriptor class DescrptSeA (Descriptor): @@ -433,10 +433,10 @@ def build (self, tf.summary.histogram('nlist', self.nlist) self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt]) - self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat') - self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv') - self.rij = tf.identity(self.rij, name = 'o_rij') - self.nlist = tf.identity(self.nlist, name = 'o_nlist') + self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat' + suffix) + self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv' + suffix) + self.rij = tf.identity(self.rij, name = 'o_rij' + suffix) + self.nlist = tf.identity(self.nlist, name = 'o_nlist' + suffix) self.dout, self.qmat = self._pass_filter(self.descrpt_reshape, atype, @@ -456,6 +456,21 @@ def get_rot_mat(self) -> tf.Tensor: """ return self.qmat + def get_tensor_names(self, suffix : str = "") -> Tuple[str]: + """Get names of tensors. + + Parameters + ---------- + suffix : str + The suffix of the scope + + Returns + ------- + Tuple[str] + Names of tensors + """ + return (f'o_rmat{suffix}:0', f'o_rmat_deriv{suffix}:0', f'o_rij{suffix}:0', f'o_nlist{suffix}:0') + def pass_tensors_from_frz_model(self, descrpt_reshape : tf.Tensor, descrpt_deriv : tf.Tensor, @@ -481,60 +496,21 @@ def pass_tensors_from_frz_model(self, self.descrpt_deriv = descrpt_deriv self.descrpt_reshape = descrpt_reshape - def get_feed_dict(self, - coord_, - atype_, - natoms, - box, - mesh): - """ - generate the deed_dict for current descriptor - - Parameters - ---------- - coord_ - The coordinate of atoms - atype_ - The type of atoms - natoms - The number of atoms. This tensor has the length of Ntypes + 2 - natoms[0]: number of local atoms - natoms[1]: total number of atoms held by this processor - natoms[i]: 2 <= i < Ntypes+2, number of type i atoms - box - The box. Can be generated by deepmd.model.make_stat_input - mesh - For historical reasons, only the length of the Tensor matters. - if size of mesh == 6, pbc is assumed. - if size of mesh == 0, no-pbc is assumed. - - Returns - ------- - feed_dict - The output feed_dict of current descriptor - """ - feed_dict = { - 't_coord:0' :coord_, - 't_type:0' :atype_, - 't_natoms:0' :natoms, - 't_box:0' :box, - 't_mesh:0' :mesh - } - return feed_dict - - def init_variables(self, - embedding_net_variables: dict + model_file : str, + suffix : str = "", ) -> None: """ Init the embedding net variables with the given dict Parameters ---------- - embedding_net_variables - The input dict which stores the embedding net variables + model_file : str + The input frozen model file + suffix : str, optional + The suffix of the scope """ - self.embedding_net_variables = embedding_net_variables + self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix) def prod_force_virial(self, diff --git a/deepmd/model/ener.py b/deepmd/model/ener.py index dec6ca66f4..441709caff 100644 --- a/deepmd/model/ener.py +++ b/deepmd/model/ener.py @@ -173,10 +173,11 @@ def build (self, name = 'descrpt_attr/ntypes', dtype = tf.int32) feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh) - return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0'] - descrpt_reshape, descrpt_deriv, rij, nlist, dout \ + return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0'] + imported_tensors \ = self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements) - self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist) + dout = imported_tensors[-1] + self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1]) if self.srtab is not None : diff --git a/deepmd/model/tensor.py b/deepmd/model/tensor.py index 5c996ec38d..2a63eda4d6 100644 --- a/deepmd/model/tensor.py +++ b/deepmd/model/tensor.py @@ -3,7 +3,7 @@ from deepmd.env import tf from deepmd.common import ClassArg -from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION +from deepmd.env import global_cvt_2_ener_float, MODEL_VERSION, GLOBAL_TF_FLOAT_PRECISION from deepmd.env import op_module from deepmd.utils.graph import load_graph_def from .model_stat import make_stat_input, merge_sys_stat @@ -138,10 +138,11 @@ def build (self, name = 'descrpt_attr/ntypes', dtype = tf.int32) feed_dict = self.descrpt.get_feed_dict(coord_, atype_, natoms, box, mesh) - return_elements = ['o_rmat:0', 'o_rmat_deriv:0', 'o_rij:0', 'o_nlist:0', 'o_descriptor:0'] - descrpt_reshape, descrpt_deriv, rij, nlist, dout \ + return_elements = [*self.descrpt.get_tensor_names(), 'o_descriptor:0'] + imported_tensors \ = self._import_graph_def_from_frz_model(frz_model, feed_dict, return_elements) - self.descrpt.pass_tensors_from_frz_model(descrpt_reshape, descrpt_deriv, rij, nlist) + dout = imported_tensors[-1] + self.descrpt.pass_tensors_from_frz_model(*imported_tensors[:-1]) rot_mat = self.descrpt.get_rot_mat() rot_mat = tf.identity(rot_mat, name = 'o_rot_mat'+suffix) diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 4701ed0848..dc888ad3e0 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -691,7 +691,7 @@ def _init_from_frz_model(self): # self.frz_model will control the self.model to import the descriptor from the given frozen model instead of building from scratch... # initialize fitting net with the given compressed frozen model if self.model_type == 'original_model': - self.descrpt.init_variables(get_embedding_net_variables(self.run_opt.init_frz_model)) + self.descrpt.init_variables(self.run_opt.init_frz_model) self.fitting.init_variables(get_fitting_net_variables(self.run_opt.init_frz_model)) tf.constant("original_model", name = 'model_type', dtype = tf.string) elif self.model_type == 'compressed_model': diff --git a/deepmd/utils/graph.py b/deepmd/utils/graph.py index ed10108dbb..53fd05cc67 100644 --- a/deepmd/utils/graph.py +++ b/deepmd/utils/graph.py @@ -108,7 +108,7 @@ def get_tensor_by_type(node, elif data_type == np.float32: tensor = np.array(node.float_val) else: - raise RunTimeError('model compression does not support the half precision') + raise RuntimeError('model compression does not support the half precision') return tensor @@ -139,7 +139,7 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str = return embedding_net_nodes -def get_embedding_net_nodes(model_file: str) -> Dict: +def get_embedding_net_nodes(model_file: str, suffix: str = "") -> Dict: """ Get the embedding net nodes with the given frozen model(model_file) @@ -147,6 +147,8 @@ def get_embedding_net_nodes(model_file: str) -> Dict: ---------- model_file The input frozen model path + suffix : str, optional + The suffix of the scope Returns ---------- @@ -154,10 +156,10 @@ def get_embedding_net_nodes(model_file: str) -> Dict: The embedding net nodes with the given frozen model """ _, graph_def = load_graph_def(model_file) - return get_embedding_net_nodes_from_graph_def(graph_def) + return get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix) -def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict: +def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef, suffix: str = "") -> Dict: """ Get the embedding net variables with the given tf.GraphDef object @@ -165,6 +167,8 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict: ---------- graph_def The input tf.GraphDef object + suffix : str, optional + The suffix of the scope Returns ---------- @@ -172,7 +176,7 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict: The embedding net variables within the given tf.GraphDef object """ embedding_net_variables = {} - embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def) + embedding_net_nodes = get_embedding_net_nodes_from_graph_def(graph_def, suffix=suffix) for item in embedding_net_nodes: node = embedding_net_nodes[item] dtype = tf.as_dtype(node.dtype).as_numpy_dtype @@ -184,7 +188,7 @@ def get_embedding_net_variables_from_graph_def(graph_def : tf.GraphDef) -> Dict: embedding_net_variables[item] = np.reshape(tensor_value, tensor_shape) return embedding_net_variables -def get_embedding_net_variables(model_file : str) -> Dict: +def get_embedding_net_variables(model_file : str, suffix: str = "") -> Dict: """ Get the embedding net variables with the given frozen model(model_file) @@ -192,6 +196,8 @@ def get_embedding_net_variables(model_file : str) -> Dict: ---------- model_file The input frozen model path + suffix : str, optional + The suffix of the scope Returns ---------- @@ -199,7 +205,7 @@ def get_embedding_net_variables(model_file : str) -> Dict: The embedding net variables within the given frozen model """ _, graph_def = load_graph_def(model_file) - return get_embedding_net_variables_from_graph_def(graph_def) + return get_embedding_net_variables_from_graph_def(graph_def, suffix=suffix) def get_fitting_net_nodes_from_graph_def(graph_def: tf.GraphDef) -> Dict: