Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions deepmd/descriptor/se.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Tuple, List

from deepmd.env import tf
from deepmd.utils.graph import get_embedding_net_variables
from .descriptor import Descriptor


class DescrptSe (Descriptor):
"""A base class for smooth version of descriptors.

Notes
-----
All of these descriptors have an environmental matrix and an
embedding network (:meth:`deepmd.utils.network.embedding_net`), so
they can share some similiar methods without defining them twice.

Attributes
----------
embedding_net_variables : dict
initial embedding network variables
descrpt_reshape : tf.Tensor
the reshaped descriptor
descrpt_deriv : tf.Tensor
the descriptor derivative
rij : tf.Tensor
distances between two atoms
nlist : tf.Tensor
the neighbor list

"""
def _identity_tensors(self, suffix : str = "") -> None:
"""Identify tensors which are expected to be stored and restored.

Notes
-----
These tensors will be indentitied:
self.descrpt_reshape : o_rmat
self.descrpt_deriv : o_rmat_deriv
self.rij : o_rij
self.nlist : o_nlist
Thus, this method should be called during building the descriptor and
after these tensors are initialized.

Parameters
----------
suffix : str
The suffix of the scope
"""
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat' + suffix)
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv' + suffix)
self.rij = tf.identity(self.rij, name = 'o_rij' + suffix)
self.nlist = tf.identity(self.nlist, name = 'o_nlist' + suffix)

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.

Parameters
----------
suffix : str
The suffix of the scope

Returns
-------
Tuple[str]
Names of tensors
"""
return (f'o_rmat{suffix}:0', f'o_rmat_deriv{suffix}:0', f'o_rij{suffix}:0', f'o_nlist{suffix}:0')

def pass_tensors_from_frz_model(self,
descrpt_reshape : tf.Tensor,
descrpt_deriv : tf.Tensor,
rij : tf.Tensor,
nlist : tf.Tensor
):
"""
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def

Parameters
----------
descrpt_reshape
The passed descrpt_reshape tensor
descrpt_deriv
The passed descrpt_deriv tensor
rij
The passed rij tensor
nlist
The passed nlist tensor
"""
self.rij = rij
self.nlist = nlist
self.descrpt_deriv = descrpt_deriv
self.descrpt_reshape = descrpt_reshape

def init_variables(self,
model_file : str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict

Parameters
----------
model_file : str
The input frozen model file
suffix : str, optional
The suffix of the scope
"""
self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix)
65 changes: 3 additions & 62 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
from .descriptor import Descriptor
from .se import DescrptSe

@Descriptor.register("se_e2_a")
@Descriptor.register("se_a")
class DescrptSeA (Descriptor):
class DescrptSeA (DescrptSe):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.

Expand Down Expand Up @@ -435,10 +436,7 @@ def build (self,
tf.summary.histogram('nlist', self.nlist)

self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat' + suffix)
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv' + suffix)
self.rij = tf.identity(self.rij, name = 'o_rij' + suffix)
self.nlist = tf.identity(self.nlist, name = 'o_nlist' + suffix)
self._identity_tensors(suffix=suffix)

self.dout, self.qmat = self._pass_filter(self.descrpt_reshape,
atype,
Expand All @@ -458,63 +456,6 @@ def get_rot_mat(self) -> tf.Tensor:
"""
return self.qmat

def get_tensor_names(self, suffix : str = "") -> Tuple[str]:
"""Get names of tensors.

Parameters
----------
suffix : str
The suffix of the scope

Returns
-------
Tuple[str]
Names of tensors
"""
return (f'o_rmat{suffix}:0', f'o_rmat_deriv{suffix}:0', f'o_rij{suffix}:0', f'o_nlist{suffix}:0')

def pass_tensors_from_frz_model(self,
descrpt_reshape : tf.Tensor,
descrpt_deriv : tf.Tensor,
rij : tf.Tensor,
nlist : tf.Tensor
):
"""
Pass the descrpt_reshape tensor as well as descrpt_deriv tensor from the frz graph_def

Parameters
----------
descrpt_reshape
The passed descrpt_reshape tensor
descrpt_deriv
The passed descrpt_deriv tensor
rij
The passed rij tensor
nlist
The passed nlist tensor
"""
self.rij = rij
self.nlist = nlist
self.descrpt_deriv = descrpt_deriv
self.descrpt_reshape = descrpt_reshape

def init_variables(self,
model_file : str,
suffix : str = "",
) -> None:
"""
Init the embedding net variables with the given dict

Parameters
----------
model_file : str
The input frozen model file
suffix : str, optional
The suffix of the scope
"""
self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix)


def prod_force_virial(self,
atom_ener : tf.Tensor,
natoms : tf.Tensor
Expand Down
14 changes: 8 additions & 6 deletions deepmd/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor
from .se import DescrptSe


@Descriptor.register("se_e2_r")
@Descriptor.register("se_r")
class DescrptSeR (Descriptor):
class DescrptSeR (DescrptSe):
"""DeepPot-SE constructed from radial information of atomic configurations.

The embedding takes the distance between atoms as input.
Expand Down Expand Up @@ -114,6 +116,7 @@ def __init__ (self,
self.useBN = False
self.davg = None
self.dstd = None
self.embedding_net_variables = None

self.place_holders = {}
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
Expand Down Expand Up @@ -312,10 +315,7 @@ def build (self,
sel = self.sel_r)

self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat')
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv')
self.rij = tf.identity(self.rij, name = 'o_rij')
self.nlist = tf.identity(self.nlist, name = 'o_nlist')
self._identity_tensors(suffix=suffix)

# only used when tensorboard was set as true
tf.summary.histogram('descrpt', self.descrpt)
Expand Down Expand Up @@ -485,7 +485,9 @@ def _filter_r(self,
bavg = bavg,
seed = self.seed,
trainable = trainable,
uniform_seed = self.uniform_seed)
uniform_seed = self.uniform_seed,
initial_variables = self.embedding_net_variables,
)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
# natom x nei_type_i x out_size
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1]))
Expand Down
13 changes: 7 additions & 6 deletions deepmd/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor
from .se import DescrptSe

@Descriptor.register("se_e3")
@Descriptor.register("se_at")
@Descriptor.register("se_a_3be")
class DescrptSeT (Descriptor):
class DescrptSeT (DescrptSe):
"""DeepPot-SE constructed from all information (both angular and radial) of atomic
configurations.

Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__ (self,
self.useBN = False
self.dstd = None
self.davg = None
self.embedding_net_variables = None

self.place_holders = {}
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
Expand Down Expand Up @@ -311,10 +313,7 @@ def build (self,
sel_r = self.sel_r)

self.descrpt_reshape = tf.reshape(self.descrpt, [-1, self.ndescrpt])
self.descrpt_reshape = tf.identity(self.descrpt_reshape, name = 'o_rmat')
self.descrpt_deriv = tf.identity(self.descrpt_deriv, name = 'o_rmat_deriv')
self.rij = tf.identity(self.rij, name = 'o_rij')
self.nlist = tf.identity(self.nlist, name = 'o_nlist')
self._identity_tensors(suffix=suffix)

self.dout, self.qmat = self._pass_filter(self.descrpt_reshape,
atype,
Expand Down Expand Up @@ -509,7 +508,9 @@ def _filter(self,
bavg = bavg,
seed = self.seed,
trainable = trainable,
uniform_seed = self.uniform_seed)
uniform_seed = self.uniform_seed,
initial_variables = self.embedding_net_variables,
)
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
# with natom x nei_type_i x nei_type_j x out_size
ebd_env_ij = tf.reshape(ebd_env_ij, [-1, nei_type_i, nei_type_j, outputs_size[-1]])
Expand Down
2 changes: 1 addition & 1 deletion deepmd/utils/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_embedding_net_nodes_from_graph_def(graph_def: tf.GraphDef, suffix: str =
The embedding net nodes within the given tf.GraphDef object
"""
embedding_net_nodes = {}
embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d"
embedding_net_pattern = f"filter_type_\d+{suffix}/matrix_\d+_\d+|filter_type_\d+{suffix}/bias_\d+_\d+|filter_type_\d+{suffix}/idt_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+|filter_type_all{suffix}/matrix_\d+_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+|filter_type_all{suffix}/bias_\d+_\d+_\d+|filter_type_all{suffix}/idt_\d+_\d+"
for node in graph_def.node:
if re.fullmatch(embedding_net_pattern, node.name) != None:
embedding_net_nodes[node.name] = node.attr["value"].tensor
Expand Down