Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepmd/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from .se_a_ebd import (
DescrptSeAEbd,
)
from .se_a_ebd_v2 import (
DescrptSeAEbdV2,
)
from .se_a_ef import (
DescrptSeAEf,
DescrptSeAEfLower,
Expand All @@ -39,6 +42,7 @@
"DescrptHybrid",
"DescrptLocFrame",
"DescrptSeA",
"DescrptSeAEbdV2",
"DescrptSeAEbd",
"DescrptSeAEf",
"DescrptSeAEfLower",
Expand Down
210 changes: 186 additions & 24 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from deepmd.common import (
cast_precision,
get_activation_func,
get_np_precision,
get_precision,
)
from deepmd.env import (
Expand All @@ -30,10 +31,17 @@
from deepmd.nvnmd.utils.config import (
nvnmd_cfg,
)
from deepmd.utils.compress import (
get_extra_side_embedding_net_variable,
get_two_side_type_embedding,
get_type_embedding,
make_data,
)
from deepmd.utils.errors import (
GraphWithoutTensorError,
)
from deepmd.utils.graph import (
get_pattern_nodes_from_graph_def,
get_tensor_by_name_from_graph,
)
from deepmd.utils.network import (
Expand Down Expand Up @@ -165,6 +173,7 @@ def __init__(
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
stripped_type_embedding: bool = False,
**kwargs,
) -> None:
"""Constructor."""
Expand All @@ -185,6 +194,7 @@ def __init__(
self.compress_activation_fn = get_activation_func(activation_function)
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
self.exclude_types = set()
for tt in exclude_types:
assert len(tt) == 2
Expand All @@ -193,6 +203,9 @@ def __init__(
self.set_davg_zero = set_davg_zero
self.type_one_side = type_one_side
self.spin = spin
self.stripped_type_embedding = stripped_type_embedding
self.extra_embeeding_net_variables = None
self.layer_size = len(neuron)

# extend sel_a for spin system
if self.spin is not None:
Expand Down Expand Up @@ -463,6 +476,39 @@ def enable_compression(
"The size of the next layer of the neural network must be twice the size of the previous layer."
% ",".join([str(item) for item in self.filter_neuron])
)
if self.stripped_type_embedding:
ret_two_side = get_pattern_nodes_from_graph_def(
graph_def, f"filter_type_all{suffix}/.+_two_side_ebd"
)
ret_one_side = get_pattern_nodes_from_graph_def(
graph_def, f"filter_type_all{suffix}/.+_one_side_ebd"
)
if len(ret_two_side) == 0 and len(ret_one_side) == 0:
raise RuntimeError(
"can not find variables of embedding net from graph_def, maybe it is not a compressible model."
)
elif len(ret_one_side) != 0 and len(ret_two_side) != 0:
raise RuntimeError(
"both one side and two side embedding net varaibles are detected, it is a wrong model."
)
elif len(ret_two_side) != 0:
self.final_type_embedding = get_two_side_type_embedding(self, graph)
self.matrix = get_extra_side_embedding_net_variable(
self, graph_def, "two_side", "matrix", suffix
)
self.bias = get_extra_side_embedding_net_variable(
self, graph_def, "two_side", "bias", suffix
)
self.extra_embedding = make_data(self, self.final_type_embedding)
else:
self.final_type_embedding = get_type_embedding(self, graph)
self.matrix = get_extra_side_embedding_net_variable(
self, graph_def, "one_side", "matrix", suffix
)
self.bias = get_extra_side_embedding_net_variable(
self, graph_def, "one_side", "bias", suffix
)
self.extra_embedding = make_data(self, self.final_type_embedding)

self.compress = True
self.table = DPTabulate(
Expand Down Expand Up @@ -588,6 +634,7 @@ def build(
coord = tf.reshape(coord_, [-1, natoms[1] * 3])
box = tf.reshape(box_, [-1, 9])
atype = tf.reshape(atype_, [-1, natoms[1]])
self.atype = atype

op_descriptor = (
build_op_descriptor() if nvnmd_cfg.enable else op_module.prod_env_mat_a
Expand All @@ -606,6 +653,10 @@ def build(
sel_a=self.sel_a,
sel_r=self.sel_r,
)
nlist_t = tf.reshape(self.nlist + 1, [-1])
atype_t = tf.concat([[self.ntypes], tf.reshape(self.atype, [-1])], axis=0)
self.nei_type_vec = tf.nn.embedding_lookup(atype_t, nlist_t)

# only used when tensorboard was set as true
tf.summary.histogram("descrpt", self.descrpt)
tf.summary.histogram("rij", self.rij)
Expand Down Expand Up @@ -692,6 +743,8 @@ def _pass_filter(
type_embedding = input_dict.get("type_embedding", None)
else:
type_embedding = None
if self.stripped_type_embedding and type_embedding is None:
raise RuntimeError("type_embedding is required for se_a_tebd_v2 model.")
start_index = 0
inputs = tf.reshape(inputs, [-1, natoms[0], self.ndescrpt])
output = []
Expand Down Expand Up @@ -901,13 +954,89 @@ def _filter_lower(
# with (natom x nei_type_i) x 1
xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0, 0], [-1, 1]), [-1, 1])
if type_embedding is not None:
xyz_scatter = self._concat_type_embedding(
xyz_scatter, nframes, natoms, type_embedding
)
if self.compress:
raise RuntimeError(
"compression of type embedded descriptor is not supported at the moment"
if self.stripped_type_embedding:
if self.type_one_side:
extra_embedding_index = self.nei_type_vec
else:
padding_ntypes = type_embedding.shape[0]
atype_expand = tf.reshape(self.atype, [-1, 1])
idx_i = tf.tile(atype_expand * padding_ntypes, [1, self.nnei])
idx_j = tf.reshape(self.nei_type_vec, [-1, self.nnei])
idx = idx_i + idx_j
index_of_two_side = tf.reshape(idx, [-1])
extra_embedding_index = index_of_two_side

if not self.compress:
if self.type_one_side:
one_side_type_embedding_suffix = "_one_side_ebd"
net_output = embedding_net(
type_embedding,
self.filter_neuron,
self.filter_precision,
activation_fn=activation_fn,
resnet_dt=self.filter_resnet_dt,
name_suffix=one_side_type_embedding_suffix,
stddev=stddev,
bavg=bavg,
seed=self.seed,
trainable=trainable,
uniform_seed=self.uniform_seed,
initial_variables=self.extra_embeeding_net_variables,
mixed_prec=self.mixed_prec,
)
net_output = tf.nn.embedding_lookup(
net_output, self.nei_type_vec
)
else:
type_embedding_nei = tf.tile(
tf.reshape(type_embedding, [1, padding_ntypes, -1]),
[padding_ntypes, 1, 1],
) # (ntypes) * ntypes * Y
type_embedding_center = tf.tile(
tf.reshape(type_embedding, [padding_ntypes, 1, -1]),
[1, padding_ntypes, 1],
) # ntypes * (ntypes) * Y
two_side_type_embedding = tf.concat(
[type_embedding_nei, type_embedding_center], -1
) # ntypes * ntypes * (Y+Y)
two_side_type_embedding = tf.reshape(
two_side_type_embedding,
[-1, two_side_type_embedding.shape[-1]],
)

atype_expand = tf.reshape(self.atype, [-1, 1])
idx_i = tf.tile(atype_expand * padding_ntypes, [1, self.nnei])
idx_j = tf.reshape(self.nei_type_vec, [-1, self.nnei])
idx = idx_i + idx_j
index_of_two_side = tf.reshape(idx, [-1])
self.extra_embedding_index = index_of_two_side

two_side_type_embedding_suffix = "_two_side_ebd"
net_output = embedding_net(
two_side_type_embedding,
self.filter_neuron,
self.filter_precision,
activation_fn=activation_fn,
resnet_dt=self.filter_resnet_dt,
name_suffix=two_side_type_embedding_suffix,
stddev=stddev,
bavg=bavg,
seed=self.seed,
trainable=trainable,
uniform_seed=self.uniform_seed,
initial_variables=self.extra_embeeding_net_variables,
mixed_prec=self.mixed_prec,
)
net_output = tf.nn.embedding_lookup(net_output, idx)
net_output = tf.reshape(net_output, [-1, self.filter_neuron[-1]])
else:
xyz_scatter = self._concat_type_embedding(
xyz_scatter, nframes, natoms, type_embedding
)
if self.compress:
raise RuntimeError(
"compression of type embedded descriptor is not supported when stripped_type_embedding == False"
)
# natom x 4 x outputs_size
if nvnmd_cfg.enable:
return filter_lower_R42GR(
Expand All @@ -929,25 +1058,48 @@ def _filter_lower(
self.embedding_net_variables,
)
if self.compress and (not is_exclude):
if self.type_one_side:
net = "filter_-1_net_" + str(type_i)
if self.stripped_type_embedding:
net_output = tf.nn.embedding_lookup(
self.extra_embedding, extra_embedding_index
)
net = "filter_net"
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
return op_module.tabulate_fusion_se_atten(
tf.cast(self.table.data[net], self.filter_precision),
info,
xyz_scatter,
tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
net_output,
last_layer_size=outputs_size[-1],
is_sorted=False,
)
else:
net = "filter_" + str(type_input) + "_net_" + str(type_i)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
return op_module.tabulate_fusion_se_a(
tf.cast(self.table.data[net], self.filter_precision),
info,
xyz_scatter,
tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
last_layer_size=outputs_size[-1],
)
if self.type_one_side:
net = "filter_-1_net_" + str(type_i)
else:
net = "filter_" + str(type_input) + "_net_" + str(type_i)
info = [
self.lower[net],
self.upper[net],
self.upper[net] * self.table_config[0],
self.table_config[1],
self.table_config[2],
self.table_config[3],
]
return op_module.tabulate_fusion_se_a(
tf.cast(self.table.data[net], self.filter_precision),
info,
xyz_scatter,
tf.reshape(inputs_i, [natom, shape_i[1] // 4, 4]),
last_layer_size=outputs_size[-1],
)
else:
if not is_exclude:
# with (natom x nei_type_i) x out_size
Expand All @@ -966,6 +1118,9 @@ def _filter_lower(
initial_variables=self.embedding_net_variables,
mixed_prec=self.mixed_prec,
)

if self.stripped_type_embedding:
xyz_scatter = xyz_scatter * net_output + xyz_scatter
if (not self.uniform_seed) and (self.seed is not None):
self.seed += self.seed_shift
else:
Expand Down Expand Up @@ -1179,3 +1334,10 @@ def init_variables(
self.dstd = new_dstd
if self.original_sel is None:
self.original_sel = sel

@property
def explicit_ntypes(self) -> bool:
"""Explicit ntypes with type embedding."""
if self.stripped_type_embedding:
return True
return False
70 changes: 70 additions & 0 deletions deepmd/descriptor/se_a_ebd_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
from typing import (
List,
Optional,
)

from deepmd.utils.spin import (
Spin,
)

from .descriptor import (
Descriptor,
)
from .se_a import (
DescrptSeA,
)

log = logging.getLogger(__name__)


@Descriptor.register("se_a_tpe_v2")
@Descriptor.register("se_a_ebd_v2")
class DescrptSeAEbdV2(DescrptSeA):
r"""A compressible se_a_ebd model.

This model is a warpper for DescriptorSeA, which set stripped_type_embedding=True.
"""

def __init__(
self,
rcut: float,
rcut_smth: float,
sel: List[str],
neuron: List[int] = [24, 48, 96],
axis_neuron: int = 8,
resnet_dt: bool = False,
trainable: bool = True,
seed: Optional[int] = None,
type_one_side: bool = True,
exclude_types: List[List[int]] = [],
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = "default",
uniform_seed: bool = False,
multi_task: bool = False,
spin: Optional[Spin] = None,
**kwargs,
) -> None:
DescrptSeA.__init__(
self,
rcut,
rcut_smth,
sel,
neuron=neuron,
axis_neuron=axis_neuron,
resnet_dt=resnet_dt,
trainable=trainable,
seed=seed,
type_one_side=type_one_side,
exclude_types=exclude_types,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
uniform_seed=uniform_seed,
multi_task=multi_task,
spin=spin,
stripped_type_embedding=True,
**kwargs,
)
Loading