diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 3b5d1d0922..60de701886 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -586,6 +586,7 @@ def _filter_lower( [ 0, start_index* 4], [-1, incrs_index* 4] ) shape_i = inputs_i.get_shape().as_list() + natom = tf.shape(inputs_i)[0] # with (natom x nei_type_i) x 4 inputs_reshape = tf.reshape(inputs_i, [-1, 4]) # with (natom x nei_type_i) x 1 @@ -603,7 +604,7 @@ def _filter_lower( net = 'filter_-1_net_' + str(type_i) else: net = 'filter_' + str(type_input) + '_net_' + str(type_i) - return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) + return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) else: if (not is_exclude): xyz_scatter = embedding_net( @@ -620,11 +621,16 @@ def _filter_lower( uniform_seed = self.uniform_seed) if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift else: - w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=GLOBAL_TF_FLOAT_PRECISION) - xyz_scatter = tf.matmul(xyz_scatter, w) + # we can safely return the final xyz_scatter filled with zero directly + return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) # natom x nei_type_i x out_size xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) - return tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True) + # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below + # [588 24] -> [588 6 4] correct + # but if sel is zero + # [588 0] -> [147 0 4] incorrect; the correct one is [588 0 4] + # So we need to explicitly assign the shape to tf.shape(inputs_i)[0] instead of -1 + return tf.matmul(tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True) def _filter( @@ -644,6 +650,18 @@ def _filter( shape = inputs.get_shape().as_list() outputs_size = [1] + self.filter_neuron outputs_size_2 = self.n_axis_neuron + all_excluded = all([(type_input, type_i) in self.exclude_types for type_i in range(self.ntypes)]) + if all_excluded: + # all types are excluded so result and qmat should be zeros + # we can safaly return a zero matrix... + # See also https://stackoverflow.com/a/34725458/9567349 + # result: natom x outputs_size x outputs_size_2 + # qmat: natom x outputs_size x 3 + natom = tf.shape(inputs)[0] + result = tf.cast(tf.fill((natom, outputs_size_2, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) + qmat = tf.cast(tf.fill((natom, outputs_size[-1], 3), 0.), GLOBAL_TF_FLOAT_PRECISION) + return result, qmat + with tf.variable_scope(name, reuse=reuse): start_index = 0 type_i = 0 @@ -665,7 +683,8 @@ def _filter( suffix = "_"+str(type_i)) if type_i == 0: xyz_scatter_1 = ret - else: + elif (type_input, type_i) not in self.exclude_types: + # add zero is meaningless; skip xyz_scatter_1+= ret start_index += self.sel_a[type_i] else : diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index 40bbc21593..f362302f11 100644 --- a/deepmd/descriptor/se_r.py +++ b/deepmd/descriptor/se_r.py @@ -478,11 +478,11 @@ def _filter_r(self, trainable = trainable, uniform_seed = self.uniform_seed) if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift + # natom x nei_type_i x out_size + xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1])) else: - w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=GLOBAL_TF_FLOAT_PRECISION) - xyz_scatter = tf.matmul(xyz_scatter, w) - # natom x nei_type_i x out_size - xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1])) + natom = tf.shape(inputs)[0] + xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) xyz_scatter_total.append(xyz_scatter) # natom x nei x outputs_size