Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 97 additions & 25 deletions deepmd/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,18 @@ def _pass_filter(
inputs_i = inputs
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
if len(self.exclude_types):
mask = self.build_type_exclude_mask(
self.exclude_types,
self.ntypes,
self.sel_a,
self.ndescrpt,
atype,
tf.shape(inputs_i)[0],
self.nei_type_vec, # extra input for atten
)
inputs_i *= mask

layer, qmat = self._filter(
inputs_i,
type_i,
Expand Down Expand Up @@ -854,10 +866,7 @@ def _filter_lower(
)
# xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1] // 4, outputs_size[-1]))
else:
# we can safely return the final xyz_scatter filled with zero directly
return tf.cast(
tf.fill((natom, 4, outputs_size[-1]), 0.0), self.filter_precision
)
raise RuntimeError("this should not be touched")
# When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below
# [588 24] -> [588 6 4] correct
# but if sel is zero
Expand Down Expand Up @@ -890,27 +899,6 @@ def _filter(
shape = inputs.get_shape().as_list()
outputs_size = [1] + self.filter_neuron
outputs_size_2 = self.n_axis_neuron
all_excluded = all(
[
(type_input, type_i) in self.exclude_types
for type_i in range(self.ntypes)
]
)
if all_excluded:
# all types are excluded so result and qmat should be zeros
# we can safaly return a zero matrix...
# See also https://stackoverflow.com/a/34725458/9567349
# result: natom x outputs_size x outputs_size_2
# qmat: natom x outputs_size x 3
natom = tf.shape(inputs)[0]
result = tf.cast(
tf.fill((natom, outputs_size_2, outputs_size[-1]), 0.0),
GLOBAL_TF_FLOAT_PRECISION,
)
qmat = tf.cast(
tf.fill((natom, outputs_size[-1], 3), 0.0), GLOBAL_TF_FLOAT_PRECISION
)
return result, qmat

start_index = 0
type_i = 0
Expand Down Expand Up @@ -1007,3 +995,87 @@ def init_variables(
i, suffix, i
)
]

def build_type_exclude_mask(
self,
exclude_types: List[Tuple[int, int]],
ntypes: int,
sel: List[int],
ndescrpt: int,
atype: tf.Tensor,
shape0: tf.Tensor,
nei_type_vec: tf.Tensor,
) -> tf.Tensor:
r"""Build the type exclude mask for the attention descriptor.

Notes
-----
This method has the similiar way to build the type exclude mask as
:meth:`deepmd.descriptor.descriptor.Descriptor.build_type_exclude_mask`.
The mathmatical expression has been explained in that method.
The difference is that the attention descriptor has provided the type of
the neighbors (idx_j) that is not in order, so we use it from an extra
input.

Parameters
----------
exclude_types : List[Tuple[int, int]]
The list of excluded types, e.g. [(0, 1), (1, 0)] means the interaction
between type 0 and type 1 is excluded.
ntypes : int
The number of types.
sel : List[int]
The list of the number of selected neighbors for each type.
ndescrpt : int
The number of descriptors for each atom.
atype : tf.Tensor
The type of atoms, with the size of shape0.
shape0 : tf.Tensor
The shape of the first dimension of the inputs, which is equal to
nsamples * natoms.
nei_type_vec : tf.Tensor
The type of neighbors, with the size of (shape0, nnei).

Returns
-------
tf.Tensor
The type exclude mask, with the shape of (shape0, ndescrpt), and the
precision of GLOBAL_TF_FLOAT_PRECISION. The mask has the value of 1 if the
interaction between two types is not excluded, and 0 otherwise.

See Also
--------
deepmd.descriptor.descriptor.Descriptor.build_type_exclude_mask
"""
# generate a mask
# op returns ntypes when the neighbor doesn't exist, so we need to add 1
type_mask = np.array(
[
[
1 if (tt_i, tt_j) not in exclude_types else 0
for tt_i in range(ntypes + 1)
]
for tt_j in range(ntypes)
],
dtype=bool,
)
type_mask = tf.convert_to_tensor(type_mask, dtype=GLOBAL_TF_FLOAT_PRECISION)
type_mask = tf.reshape(type_mask, [-1])

# (nsamples * natoms, 1)
atype_expand = tf.reshape(atype, [-1, 1])
# (nsamples * natoms, ndescrpt)
idx_i = tf.tile(atype_expand * (ntypes + 1), (1, ndescrpt))
# idx_j has been provided by atten op
# (nsamples * natoms, nnei, 1)
idx_j = tf.reshape(nei_type_vec, [shape0, sel[0], 1])
# (nsamples * natoms, nnei, ndescrpt // nnei)
idx_j = tf.tile(idx_j, (1, 1, ndescrpt // sel[0]))
# (nsamples * natoms, ndescrpt)
idx_j = tf.reshape(idx_j, [shape0, ndescrpt])
idx = idx_i + idx_j
idx = tf.reshape(idx, [-1])
mask = tf.nn.embedding_lookup(type_mask, idx)
# same as inputs_i, (nsamples * natoms, ndescrpt)
mask = tf.reshape(mask, [-1, ndescrpt])
return mask
75 changes: 75 additions & 0 deletions source/tests/test_model_se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,78 @@ def test_model(self):
np.testing.assert_almost_equal(e, refe, places)
np.testing.assert_almost_equal(f, reff, places)
np.testing.assert_almost_equal(v, refv, places)

def test_exclude_types(self):
"""In this test, we make type 0 has no interaction with type 0 and type 1,
so the descriptor should be zero for type 0 atoms.
"""
jfile = "water_se_atten.json"
jdata = j_loader(jfile)

systems = j_must_have(jdata, "systems")
set_pfx = j_must_have(jdata, "set_prefix")
batch_size = j_must_have(jdata, "batch_size")
test_size = j_must_have(jdata, "numb_test")
batch_size = 1
test_size = 1
rcut = j_must_have(jdata["model"]["descriptor"], "rcut")
ntypes = 2

data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt=None)

test_data = data.get_test()
numb_test = 1

# set parameters
jdata["model"]["descriptor"]["exclude_types"] = [[0, 0], [0, 1]]

t_prop_c = tf.placeholder(tf.float32, [5], name="t_prop_c")
t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name="i_coord")
t_type = tf.placeholder(tf.int32, [None], name="i_type")
t_natoms = tf.placeholder(tf.int32, [ntypes + 2], name="i_natoms")
t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name="i_box")
t_mesh = tf.placeholder(tf.int32, [None], name="i_mesh")
is_training = tf.placeholder(tf.bool)

# successful
descrpt = DescrptSeAtten(ntypes=ntypes, **jdata["model"]["descriptor"])
typeebd_param = jdata["model"]["type_embedding"]
typeebd = TypeEmbedNet(
neuron=typeebd_param["neuron"],
activation_function=None,
resnet_dt=typeebd_param["resnet_dt"],
seed=typeebd_param["seed"],
uniform_seed=True,
padding=True,
)
type_embedding = typeebd.build(
ntypes,
)
dout = descrpt.build(
t_coord,
t_type,
t_natoms,
t_box,
t_mesh,
{"type_embedding": type_embedding},
reuse=False,
suffix="_se_atten_exclude_types",
)

feed_dict_test1 = {
t_prop_c: test_data["prop_c"],
t_coord: np.reshape(test_data["coord"][:numb_test, :], [-1]),
t_box: test_data["box"][:numb_test, :],
t_type: np.reshape(test_data["type"][:numb_test, :], [-1]),
t_natoms: test_data["natoms_vec"],
t_mesh: test_data["default_mesh"],
is_training: False,
}

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
[des] = sess.run([dout], feed_dict=feed_dict_test1)

np.testing.assert_almost_equal(des[:, 0:2], 0.0, 10)
with self.assertRaises(AssertionError):
np.testing.assert_almost_equal(des[:, 2:6], 0.0, 10)