diff --git a/source/train/DescrptSeA.py b/source/train/DescrptSeA.py index 890fe24672..e46265231f 100644 --- a/source/train/DescrptSeA.py +++ b/source/train/DescrptSeA.py @@ -17,6 +17,7 @@ def __init__ (self, jdata): .add('resnet_dt',bool, default = False) \ .add('trainable',bool, default = True) \ .add('seed', int) \ + .add('type_one_side', bool, default = False) \ .add('exclude_types', list, default = []) \ .add('set_davg_zero', bool, default = False) \ .add('activation_function', str, default = 'tanh') \ @@ -39,6 +40,9 @@ def __init__ (self, jdata): self.exclude_types.add((tt[0], tt[1])) self.exclude_types.add((tt[1], tt[0])) self.set_davg_zero = class_data['set_davg_zero'] + self.type_one_side = class_data['type_one_side'] + if self.type_one_side and len(exclude_types) != 0: + raise RuntimeError('"type_one_side" is not compatible with "exclude_types"') # descrpt config self.sel_r = [ 0 for ii in range(len(self.sel_a)) ] @@ -244,17 +248,27 @@ def _pass_filter(self, inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]]) output = [] output_qmat = [] - for type_i in range(self.ntypes): - inputs_i = tf.slice (inputs, - [ 0, start_index* self.ndescrpt], - [-1, natoms[2+type_i]* self.ndescrpt] ) + if not self.type_one_side: + for type_i in range(self.ntypes): + inputs_i = tf.slice (inputs, + [ 0, start_index* self.ndescrpt], + [-1, natoms[2+type_i]* self.ndescrpt] ) + inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) + layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn) + layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) + qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3]) + output.append(layer) + output_qmat.append(qmat) + start_index += natoms[2+type_i] + else : + inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn) - layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) - qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3]) + type_i = -1 + layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn) + layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) + qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3]) output.append(layer) output_qmat.append(qmat) - start_index += natoms[2+type_i] output = tf.concat(output, axis = 1) output_qmat = tf.concat(output_qmat, axis = 1) return output, output_qmat