diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index c44a6c1d6e..a024783ec4 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -2,7 +2,7 @@ from typing import Optional, Any, Dict, List, Tuple import numpy as np -from deepmd.env import tf +from deepmd.env import tf, GLOBAL_TF_FLOAT_PRECISION from deepmd.utils import Plugin, PluginVariant @@ -409,3 +409,92 @@ def pass_tensors_from_frz_model(self, :meth:`get_tensor_names`. """ raise NotImplementedError("Descriptor %s doesn't support this method!" % type(self).__name__) + + def build_type_exclude_mask(self, + exclude_types: List[Tuple[int, int]], + ntypes: int, + sel: List[int], + ndescrpt: int, + atype: tf.Tensor, + shape0: tf.Tensor) -> tf.Tensor: + r"""Build the type exclude mask for the descriptor. + + Notes + ----- + To exclude the interaction between two types, the derivative of energy with + respect to distances (or angles) between two atoms should be zero[1]_, i.e. + + .. math:: + \forall i \in \text{type 1}, j \in \text{type 2}, + \frac{\partial{E}}{\partial{r_{ij}}} = 0 + + When embedding networks between every two types are built, we can just remove + that network. But when `type_one_side` is enabled, a network may be built for + multiple pairs of types. In this case, we need to build a mask to exclude the + interaction between two types. + + The mask assumes the descriptors are sorted by neighbor type with the fixed + number of given `sel` and each neighbor has the same number of descriptors + (for example 4). + + Parameters + ---------- + exclude_types : List[Tuple[int, int]] + The list of excluded types, e.g. [(0, 1), (1, 0)] means the interaction + between type 0 and type 1 is excluded. + ntypes : int + The number of types. + sel : List[int] + The list of the number of selected neighbors for each type. + ndescrpt : int + The number of descriptors for each atom. + atype : tf.Tensor + The type of atoms, with the size of shape0. 
+ shape0 : tf.Tensor + The shape of the first dimension of the inputs, which is equal to + nsamples * natoms. + + Returns + ------- + tf.Tensor + The type exclude mask, with the shape of (shape0, ndescrpt), and the + precision of GLOBAL_TF_FLOAT_PRECISION. The mask has the value of 1 if the + interaction between two types is not excluded, and 0 otherwise. + + References + ---------- + .. [1] Jinzhe Zeng, Timothy J. Giese, Şölen Ekesan, Darrin M. York, + Development of Range-Corrected Deep Learning Potentials for Fast, + Accurate Quantum Mechanical/molecular Mechanical Simulations of + Chemical Reactions in Solution, J. Chem. Theory Comput., 2021, + 17 (11), 6993-7009. + """ + # generate a mask + type_mask = np.array([ + [1 if (tt_i, tt_j) not in exclude_types else 0 + for tt_i in range(ntypes)] + for tt_j in range(ntypes) + ], dtype = bool) + type_mask = tf.convert_to_tensor(type_mask, dtype = GLOBAL_TF_FLOAT_PRECISION) + type_mask = tf.reshape(type_mask, [-1]) + + # (nsamples * natoms, 1) + atype_expand = tf.reshape(atype, [-1, 1]) + # (nsamples * natoms, ndescrpt) + idx_i = tf.tile(atype_expand * ntypes, (1, ndescrpt)) + ndescrpt_per_neighbor = ndescrpt // np.sum(sel) + # assume the number of neighbors for each type is the same + assert ndescrpt_per_neighbor * np.sum(sel) == ndescrpt + atype_descrpt = np.repeat(np.arange(ntypes), np.array(sel) * ndescrpt_per_neighbor) + atype_descrpt = tf.convert_to_tensor(atype_descrpt, dtype = tf.int32) + # (1, ndescrpt) + atype_descrpt = tf.reshape(atype_descrpt, (1, ndescrpt)) + # (nsamples * natoms, ndescrpt) + idx_j = tf.tile(atype_descrpt, (shape0, 1)) + # the index to mask (row index * ntypes + col index) + idx = idx_i + idx_j + idx = tf.reshape(idx, [-1]) + mask = tf.nn.embedding_lookup(type_mask, idx) + # same as inputs_i, (nsamples * natoms, ndescrpt) + mask = tf.reshape(mask, [-1, ndescrpt]) + return mask diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index dbc9b41603..1f92030926 100644 --- 
a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -591,18 +591,13 @@ def _pass_filter(self, inputs = tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]) output = [] output_qmat = [] - if not (self.type_one_side and len(self.exclude_types) == 0) and type_embedding is None: + if not self.type_one_side and type_embedding is None: for type_i in range(self.ntypes): inputs_i = tf.slice (inputs, [ 0, start_index, 0], [-1, natoms[2+type_i], -1] ) inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - if self.type_one_side: - # reuse NN parameters for all types to support type_one_side along with exclude_types - reuse = tf.AUTO_REUSE - filter_name = 'filter_type_all'+suffix - else: - filter_name = 'filter_type_'+str(type_i)+suffix + filter_name = 'filter_type_'+str(type_i)+suffix layer, qmat = self._filter(inputs_i, type_i, name=filter_name, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i], self.get_dim_out()]) qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i], self.get_dim_rot_mat_1() * 3]) @@ -615,6 +610,17 @@ def _pass_filter(self, type_i = -1 if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: inputs_i = descrpt2r4(inputs_i, natoms) + if len(self.exclude_types): + mask = self.build_type_exclude_mask( + self.exclude_types, + self.ntypes, + self.sel_a, + self.ndescrpt, + atype, + tf.shape(inputs_i)[0], + ) + inputs_i *= mask + layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0], self.get_dim_out()]) qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0], self.get_dim_rot_mat_1() * 3]) diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index 163fe92f00..5773c47202 100644 --- a/deepmd/descriptor/se_r.py +++ 
b/deepmd/descriptor/se_r.py @@ -393,7 +393,7 @@ def build (self, tf.summary.histogram('rij', self.rij) tf.summary.histogram('nlist', self.nlist) - self.dout = self._pass_filter(self.descrpt_reshape, natoms, suffix = suffix, reuse = reuse, trainable = self.trainable) + self.dout = self._pass_filter(self.descrpt_reshape, atype, natoms, suffix = suffix, reuse = reuse, trainable = self.trainable) tf.summary.histogram('embedding_net_output', self.dout) return self.dout @@ -448,6 +448,7 @@ def prod_force_virial(self, def _pass_filter(self, inputs, + atype, natoms, reuse = None, suffix = '', @@ -455,18 +456,13 @@ def _pass_filter(self, start_index = 0 inputs = tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]) output = [] - if not (self.type_one_side and len(self.exclude_types) == 0): + if not self.type_one_side: for type_i in range(self.ntypes): inputs_i = tf.slice (inputs, [ 0, start_index, 0], [-1, natoms[2+type_i], -1] ) inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - if self.type_one_side: - # reuse NN parameters for all types to support type_one_side along with exclude_types - reuse = tf.AUTO_REUSE - filter_name = 'filter_type_all'+suffix - else: - filter_name = 'filter_type_'+str(type_i)+suffix + filter_name = 'filter_type_'+str(type_i)+suffix layer = self._filter_r(inputs_i, type_i, name=filter_name, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i], self.get_dim_out()]) output.append(layer) @@ -475,6 +471,16 @@ def _pass_filter(self, inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 + if len(self.exclude_types): + mask = self.build_type_exclude_mask( + self.exclude_types, + self.ntypes, + self.sel_r, + self.ndescrpt, + atype, + tf.shape(inputs_i)[0], + ) + inputs_i *= mask layer = self._filter_r(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = 
self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0], self.get_dim_out()]) output.append(layer) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 97f21b7858..ebac8dbb07 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -518,14 +518,13 @@ def build (self, outs = tf.concat(outs_list, axis = 1) # with type embedding else: - if len(self.atom_ener) > 0: - raise RuntimeError("setting atom_ener is not supported by type embedding") atype_embed = tf.cast(atype_embed, self.fitting_precision) type_shape = atype_embed.get_shape().as_list() inputs = tf.concat( [tf.reshape(inputs,[-1,self.dim_descrpt]),atype_embed], axis=1 ) + original_dim_descrpt = self.dim_descrpt self.dim_descrpt = self.dim_descrpt + type_shape[1] inputs = tf.reshape(inputs, [-1, natoms[0], self.dim_descrpt]) final_layer = self._build_lower( @@ -533,6 +532,20 @@ def build (self, inputs, fparam, aparam, bias_atom_e=0.0, suffix=suffix, reuse=reuse ) + if len(self.atom_ener): + # remove contribution in vacuum + inputs_zero = tf.concat( + [tf.reshape(inputs_zero, [-1, original_dim_descrpt]), atype_embed], + axis=1 + ) + inputs_zero = tf.reshape(inputs_zero, [-1, natoms[0], self.dim_descrpt]) + zero_layer = self._build_lower( + 0, natoms[0], + inputs_zero, fparam, aparam, + bias_atom_e=0.0, suffix=suffix, reuse=True, + ) + # atomic energy will be stored in `self.t_bias_atom_e` which is not trainable + final_layer -= zero_layer outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]]) # add bias self.atom_ener_before = outs diff --git a/source/tests/test_model_se_a.py b/source/tests/test_model_se_a.py index 8180a0c4bd..d7227bd54e 100644 --- a/source/tests/test_model_se_a.py +++ b/source/tests/test_model_se_a.py @@ -9,6 +9,7 @@ from deepmd.fit import EnerFitting from deepmd.model import EnerModel from deepmd.common import j_must_have +from deepmd.utils.type_embed import TypeEmbedNet GLOBAL_ENER_FLOAT_PRECISION = tf.float64 GLOBAL_TF_FLOAT_PRECISION = 
tf.float64 @@ -208,3 +209,102 @@ def test_model(self): np.testing.assert_almost_equal(e, refe, places) np.testing.assert_almost_equal(f, reff, places) np.testing.assert_almost_equal(v, refv, places) + + def test_model_atom_ener_type_embedding(self): + """Test atom ener with type embedding""" + jfile = 'water_se_a.json' + jdata = j_loader(jfile) + set_atom_ener = [0.02, 0.01] + jdata['model']['fitting_net']['atom_ener'] = set_atom_ener + jdata['model']['type_embeding'] = {"neuron": [2]} + + sys = dpdata.LabeledSystem() + sys.data['atom_names'] = ['foo', 'bar'] + sys.data['coords'] = np.array([0, 0, 0, 0, 0, 0]) + sys.data['atom_types'] = [0] + sys.data['cells'] = np.array([np.eye(3) * 30, np.eye(3) * 30]) + nframes = 2 + natoms = 1 + sys.data['coords'] = sys.data['coords'].reshape([nframes,natoms,3]) + sys.data['cells'] = sys.data['cells'].reshape([nframes,3,3]) + sys.data['energies'] = np.zeros([nframes,1]) + sys.data['forces'] = np.zeros([nframes,natoms,3]) + sys.to_deepmd_npy('system', prec=np.float64) + + systems = j_must_have(jdata, 'systems') + set_pfx = j_must_have(jdata, 'set_prefix') + batch_size = j_must_have(jdata, 'batch_size') + test_size = j_must_have(jdata, 'numb_test') + batch_size = 1 + test_size = 1 + stop_batch = j_must_have(jdata, 'stop_batch') + rcut = j_must_have (jdata['model']['descriptor'], 'rcut') + + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt = None) + test_data = data.get_test () + numb_test = 1 + + typeebd = TypeEmbedNet(**jdata['model']['type_embeding']) + jdata['model']['descriptor'].pop('type', None) + descrpt = DescrptSeA(**jdata['model']['descriptor'], uniform_seed=True) + jdata['model']['fitting_net']['descrpt'] = descrpt + fitting = EnerFitting(**jdata['model']['fitting_net'], uniform_seed=True) + model = EnerModel(descrpt, fitting, typeebd=typeebd) + + test_data['natoms_vec'] = [1, 1, 1, 0] + + input_data = {'coord' : [test_data['coord']], + 'box': [test_data['box']], + 'type': [test_data['type']], + 
'natoms_vec' : [test_data['natoms_vec']], + 'default_mesh' : [test_data['default_mesh']] + } + model._compute_input_stat(input_data) + model.fitting.bias_atom_e = np.array(set_atom_ener) + + t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c') + t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name='t_energy') + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name='i_coord') + t_type = tf.placeholder(tf.int32, [None], name='i_type') + t_natoms = tf.placeholder(tf.int32, [model.ntypes+2], name='i_natoms') + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name='i_box') + t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh') + is_training = tf.placeholder(tf.bool) + t_fparam = None + + model_pred \ + = model.build (t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + t_fparam, + suffix = "se_a_atom_ener_type_embbed_0", + reuse = False) + energy = model_pred['energy'] + force = model_pred['force'] + virial = model_pred['virial'] + + feed_dict_test = {t_prop_c: test_data['prop_c'], + t_energy: test_data['energy'] [:numb_test], + t_coord: np.reshape(test_data['coord'] [:numb_test, :], [-1]), + t_box: test_data['box'] [:numb_test, :], + t_type: np.reshape([0], [-1]), + t_natoms: [1, 1, 1, 0], + t_mesh: test_data['default_mesh'], + is_training: False + } + sess = self.test_session().__enter__() + sess.run(tf.global_variables_initializer()) + [e, f, v] = sess.run([energy, force, virial], + feed_dict = feed_dict_test) + self.assertAlmostEqual(e[0], set_atom_ener[0], places = 10) + + feed_dict_test[t_type] = np.reshape([1], [-1]) + feed_dict_test[t_natoms] = [1, 1, 0, 1] + [e, f, v] = sess.run([energy, force, virial], + feed_dict = feed_dict_test) + self.assertAlmostEqual(e[0], set_atom_ener[1], places = 10) + + + diff --git a/source/tests/test_type_one_side.py b/source/tests/test_type_one_side.py index b3dc874a63..c457b5dbeb 100644 --- a/source/tests/test_type_one_side.py +++ b/source/tests/test_type_one_side.py @@ -113,7 
+113,115 @@ def test_descriptor_one_side_exclude_types(self): feed_dict_test1[t_natoms] = new_natoms1 feed_dict_test2[t_type] = np.reshape(new_type2[:numb_test, :], [-1]) feed_dict_test2[t_natoms] = new_natoms2 - print(feed_dict_test1,feed_dict_test2) + + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + [model_dout1] = sess.run([dout], + feed_dict = feed_dict_test1) + [model_dout2] = sess.run([dout], + feed_dict = feed_dict_test2) + [model_dout1_failed] = sess.run([dout_failed], + feed_dict = feed_dict_test1) + [model_dout2_failed] = sess.run([dout_failed], + feed_dict = feed_dict_test2) + model_dout1 = model_dout1.reshape([6, -1]) + model_dout2 = model_dout2.reshape([6, -1]) + model_dout1_failed = model_dout1_failed.reshape([6, -1]) + model_dout2_failed = model_dout2_failed.reshape([6, -1]) + + np.testing.assert_almost_equal(model_dout1[0], model_dout2[0], 10) + with self.assertRaises(AssertionError): + np.testing.assert_almost_equal(model_dout1_failed[0], model_dout2_failed[0], 10) + + + def test_se_r_one_side_exclude_types(self): + """se_r + """ + jfile = 'water_se_r.json' + jdata = j_loader(jfile) + + systems = j_must_have(jdata, 'systems') + set_pfx = j_must_have(jdata, 'set_prefix') + batch_size = j_must_have(jdata, 'batch_size') + test_size = j_must_have(jdata, 'numb_test') + batch_size = 1 + test_size = 1 + rcut = j_must_have (jdata['model']['descriptor'], 'rcut') + sel = j_must_have (jdata['model']['descriptor'], 'sel') + ntypes=len(sel) + + data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt = None) + + test_data = data.get_test () + numb_test = 1 + + # set parameters + jdata['model']['descriptor']['neuron'] = [5, 5, 5] + jdata['model']['descriptor']['type_one_side'] = True + jdata['model']['descriptor']['exclude_types'] = [[0, 0]] + + t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c') + t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name='i_coord') + t_type = 
tf.placeholder(tf.int32, [None], name='i_type') + t_natoms = tf.placeholder(tf.int32, [ntypes+2], name='i_natoms') + t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name='i_box') + t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh') + is_training = tf.placeholder(tf.bool) + + + # successful + descrpt = Descriptor(**jdata['model']['descriptor']) + dout \ + = descrpt.build( + t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + {}, + reuse = False, + suffix = "_se_r_1side_exclude_types" + ) + # failed + descrpt_failed = Descriptor(**{**jdata['model']['descriptor'], "type_one_side": False}) + dout_failed \ + = descrpt_failed.build( + t_coord, + t_type, + t_natoms, + t_box, + t_mesh, + {}, + reuse = False, + suffix = "_se_r_1side_exclude_types_failed" + ) + + feed_dict_test1 = {t_prop_c: test_data['prop_c'], + t_coord: np.reshape(test_data['coord'] [:numb_test, :], [-1]), + t_box: test_data['box'] [:numb_test, :], + t_type: np.reshape(test_data['type'] [:numb_test, :], [-1]), + t_natoms: test_data['natoms_vec'], + t_mesh: test_data['default_mesh'], + is_training: False} + feed_dict_test2 = feed_dict_test1.copy() + # original type: 0 0 1 1 1 1 + # current: 0 1 1 1 1 1 + # current: 1 1 1 1 1 1 + new_natoms1 = test_data['natoms_vec'].copy() + new_natoms1[2] = 1 + new_natoms1[3] = 5 + new_type1 = test_data['type'].copy() + new_type1[:numb_test, 0] = 0 + new_type1[:numb_test, 1:6] = 1 + new_natoms2 = test_data['natoms_vec'].copy() + new_natoms2[2] = 0 + new_natoms2[3] = 6 + new_type2 = test_data['type'].copy() + new_type2[:numb_test] = 1 + feed_dict_test1[t_type] = np.reshape(new_type1[:numb_test, :], [-1]) + feed_dict_test1[t_natoms] = new_natoms1 + feed_dict_test2[t_type] = np.reshape(new_type2[:numb_test, :], [-1]) + feed_dict_test2[t_natoms] = new_natoms2 with self.test_session() as sess: sess.run(tf.global_variables_initializer())