Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 147 additions & 61 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from deepmd.env import default_tf_session_config
from deepmd.utils.network import embedding_net
from deepmd.utils.tabulate import DeepTabulate

from deepmd.utils.type_embed import embed_atom_type

class DescrptSeA ():
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
Expand Down Expand Up @@ -101,6 +101,11 @@ def __init__ (self,
self.davg = None
self.compress = False
self.place_holders = {}
nei_type = np.array([])
for ii in range(self.ntypes):
nei_type = np.append(nei_type, ii * np.ones(self.sel_a[ii])) # like a mask
self.nei_type = tf.constant(nei_type, dtype = tf.int32)

avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(GLOBAL_NP_FLOAT_PRECISION)
sub_graph = tf.Graph()
Expand Down Expand Up @@ -214,7 +219,7 @@ def compute_input_stats (self,
sumr2 = np.sum(sumr2, axis = 0)
suma2 = np.sum(suma2, axis = 0)
for type_i in range(self.ntypes) :
davgunit = [sumr[type_i]/sumn[type_i], 0, 0, 0]
davgunit = [sumr[type_i]/(sumn[type_i]+1e-15), 0, 0, 0]
dstdunit = [self._compute_std(sumr2[type_i], sumr[type_i], sumn[type_i]),
self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
self._compute_std(suma2[type_i], suma[type_i], sumn[type_i]),
Expand Down Expand Up @@ -440,11 +445,15 @@ def _pass_filter(self,
reuse = None,
suffix = '',
trainable = True) :
if input_dict is not None:
type_embedding = input_dict.get('type_embedding', None)
else:
type_embedding = None
start_index = 0
inputs = tf.reshape(inputs, [-1, self.ndescrpt * natoms[0]])
output = []
output_qmat = []
if not self.type_one_side:
if not self.type_one_side and type_embedding is None:
for type_i in range(self.ntypes):
inputs_i = tf.slice (inputs,
[ 0, start_index* self.ndescrpt],
Expand All @@ -460,7 +469,7 @@ def _pass_filter(self,
inputs_i = inputs
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
output.append(layer)
Expand Down Expand Up @@ -516,75 +525,152 @@ def _compute_dstats_sys_smth (self,


def _compute_std (self,sumv2, sumv, sumn) :
if sumn == 0:
return 1e-2
val = np.sqrt(sumv2/sumn - np.multiply(sumv/sumn, sumv/sumn))
if np.abs(val) < 1e-2:
val = 1e-2
return val


def _filter(self,
inputs,
type_input,
natoms,
activation_fn=tf.nn.tanh,
stddev=1.0,
bavg=0.0,
name='linear',
reuse=None,
seed=None,
trainable = True):
def _concat_type_embedding(
self,
xyz_scatter,
nframes,
natoms,
type_embedding,
):
te_out_dim = type_embedding.get_shape().as_list()[-1]
nei_embed = tf.nn.embedding_lookup(type_embedding,tf.cast(self.nei_type,dtype=tf.int32)) #nnei*nchnl
nei_embed = tf.tile(nei_embed,(nframes*natoms[0],1))
nei_embed = tf.reshape(nei_embed,[-1,te_out_dim])
embedding_input = tf.concat([xyz_scatter,nei_embed],1)
if not self.type_one_side:
atm_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
atm_embed = tf.tile(atm_embed,(1,self.nnei))
atm_embed = tf.reshape(atm_embed,[-1,te_out_dim])
embedding_input = tf.concat([embedding_input,atm_embed],1)
return embedding_input


def _filter_lower(
self,
start_index,
incrs_index,
inputs,
nframes,
natoms,
type_embedding=None,
is_exclude = False,
activation_fn = None,
bavg = 0.0,
stddev = 1.0,
seed = None,
trainable = True,
suffix = '',
):
"""
input env matrix, returns R.G
"""
outputs_size = [1] + self.filter_neuron
# cut-out inputs
# with natom x (nei_type_i x 4)
inputs_i = tf.slice (inputs,
[ 0, start_index* 4],
[-1, incrs_index* 4] )
shape_i = inputs_i.get_shape().as_list()
# with (natom x nei_type_i) x 4
inputs_reshape = tf.reshape(inputs_i, [-1, 4])
# with (natom x nei_type_i) x 1
xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1])
if type_embedding is not None:
type_embedding = tf.cast(type_embedding, self.filter_precision)
xyz_scatter = self._concat_type_embedding(
xyz_scatter, nframes, natoms, type_embedding)
if self.compress:
raise RuntimeError('compression of type embedded descriptor is not supported at the moment')
# with (natom x nei_type_i) x out_size
if self.compress and (not is_exclude):
info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]]
if self.type_one_side:
net = 'filter_-1_net_' + str(type_i)
else:
net = 'filter_' + str(type_input) + '_net_' + str(type_i)
return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
else:
if (not is_exclude):
xyz_scatter = embedding_net(
xyz_scatter,
self.filter_neuron,
self.filter_precision,
activation_fn = activation_fn,
resnet_dt = self.filter_resnet_dt,
name_suffix = suffix,
stddev = stddev,
bavg = bavg,
seed = seed,
trainable = trainable)
else:
w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=GLOBAL_TF_FLOAT_PRECISION)
xyz_scatter = tf.matmul(xyz_scatter, w)
# natom x nei_type_i x out_size
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1]))
return tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True)


def _filter(
self,
inputs,
type_input,
natoms,
type_embedding = None,
activation_fn=tf.nn.tanh,
stddev=1.0,
bavg=0.0,
name='linear',
reuse=None,
seed=None,
trainable = True):
nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0]
# natom x (nei x 4)
shape = inputs.get_shape().as_list()
outputs_size = [1] + self.filter_neuron
outputs_size_2 = self.n_axis_neuron
with tf.variable_scope(name, reuse=reuse):
start_index = 0
xyz_scatter_total = []
for type_i in range(self.ntypes):
# cut-out inputs
# with natom x (nei_type_i x 4)
inputs_i = tf.slice (inputs,
[ 0, start_index* 4],
[-1, self.sel_a[type_i]* 4] )
start_index += self.sel_a[type_i]
shape_i = inputs_i.get_shape().as_list()
# with (natom x nei_type_i) x 4
inputs_reshape = tf.reshape(inputs_i, [-1, 4])
# with (natom x nei_type_i) x 1
xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1])
# with (natom x nei_type_i) x out_size
if self.compress and (type_input, type_i) not in self.exclude_types:
info = [self.lower, self.upper, self.upper * self.table_config[0], self.table_config[1], self.table_config[2], self.table_config[3]]
if self.type_one_side:
net = 'filter_-1_net_' + str(type_i)
else:
net = 'filter_' + str(type_input) + '_net_' + str(type_i)
if type_i == 0:
xyz_scatter_1 = op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
else:
xyz_scatter_1 += op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
else:
if (type_input, type_i) not in self.exclude_types:
xyz_scatter = embedding_net(xyz_scatter,
self.filter_neuron,
self.filter_precision,
activation_fn = activation_fn,
resnet_dt = self.filter_resnet_dt,
name_suffix = "_"+str(type_i),
stddev = stddev,
bavg = bavg,
seed = seed,
trainable = trainable)
else:
w = tf.zeros((outputs_size[0], outputs_size[-1]), dtype=GLOBAL_TF_FLOAT_PRECISION)
xyz_scatter = tf.matmul(xyz_scatter, w)
# natom x nei_type_i x out_size
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1]))
# xyz_scatter_total.append(xyz_scatter)
if type_i == 0 :
xyz_scatter_1 = tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True)
else :
xyz_scatter_1 += tf.matmul(tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True)
if type_embedding is None:
for type_i in range(self.ntypes):
ret = self._filter_lower(
start_index, self.sel_a[type_i],
inputs,
nframes,
natoms,
type_embedding = type_embedding,
is_exclude = (type_input, type_i) in self.exclude_types,
activation_fn = activation_fn,
stddev = stddev,
bavg = bavg,
seed = seed,
trainable = trainable,
suffix = "_"+str(type_i))
if type_i == 0:
xyz_scatter_1 = ret
else:
xyz_scatter_1+= ret
start_index += self.sel_a[type_i]
else :
xyz_scatter_1 = self._filter_lower(
start_index, np.cumsum(self.sel_a)[-1],
inputs,
nframes,
natoms,
type_embedding = type_embedding,
is_exclude = False,
activation_fn = activation_fn,
stddev = stddev,
bavg = bavg,
seed = seed,
trainable = trainable)
# natom x nei x outputs_size
# xyz_scatter = tf.concat(xyz_scatter_total, axis=1)
# natom x nei x 4
Expand Down
Loading