Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
a6ecd4b
Merge pull request #11 from deepmodeling/devel
iProzd Jun 1, 2021
aab1931
add gpu op unittest in source/tests
iProzd Jun 3, 2021
d5dca99
fix bug in #713
iProzd Jun 14, 2021
91b68f5
Merge pull request #12 from deepmodeling/devel
iProzd Jun 17, 2021
9f9fdae
Fix bug of empty input in gelu.cu
iProzd Jun 26, 2021
a958bcb
Merge pull request #13 from deepmodeling/devel
iProzd Jun 26, 2021
5144a53
Merge pull request #14 from deepmodeling/devel
iProzd Jul 1, 2021
46ab4e8
Merge pull request #15 from deepmodeling/devel
iProzd Jul 16, 2021
7372db6
Merge pull request #16 from deepmodeling/devel
iProzd Jul 26, 2021
7072344
Merge pull request #17 from deepmodeling/devel
iProzd Aug 9, 2021
978b37c
Merge pull request #18 from deepmodeling/devel
iProzd Aug 29, 2021
1cdf4c1
Merge pull request #19 from deepmodeling/devel
iProzd Sep 8, 2021
42ce5b3
Merge pull request #21 from deepmodeling/devel
iProzd Jan 30, 2022
b3ad9a5
Merge pull request #22 from deepmodeling/devel
iProzd May 28, 2022
d807db9
Merge pull request #23 from deepmodeling/devel
iProzd Jul 27, 2022
cb035d4
Upload attention based model
iProzd Aug 19, 2022
c31b283
Fix bugs in DPA-1 when trying to compress other model
iProzd Aug 19, 2022
c5d9d68
Add layernorm support for tf early version
iProzd Aug 20, 2022
b66c2ff
Upload the model image
iProzd Aug 20, 2022
fbfa5fe
Create train-se-atten.md
iProzd Aug 20, 2022
7854dd4
Update the download links of DPA-1 example
iProzd Aug 20, 2022
055401e
Update train-se-atten.md
iProzd Aug 20, 2022
3961e0d
Update train-se-atten.md
iProzd Aug 20, 2022
f04567f
Update train-se-atten.md
iProzd Aug 20, 2022
c0247c1
Update train-se-atten.md
iProzd Aug 20, 2022
087b464
Deal with the required changes
iProzd Aug 23, 2022
9203f02
Fix typo in data_system.py
iProzd Aug 23, 2022
a54524f
Add docs in toc
iProzd Aug 23, 2022
56e0386
Add docs to ntype and nmask
iProzd Aug 24, 2022
6a77184
Fix duplicated period in each doc_activation_function
iProzd Aug 24, 2022
4bbccfe
Optimized mixed_type format
iProzd Aug 24, 2022
798d864
Update common.py
iProzd Aug 24, 2022
3912df5
Git reset common.py of format editing.
iProzd Aug 25, 2022
a6f8258
Match default args of tebd in se_atten
iProzd Aug 25, 2022
becf0cb
Change default activation_function back to 'tanh' in tebd.
iProzd Aug 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ A full [document](doc/train/train-input-auto.rst) on options in the training inp
- [Descriptor `"se_e2_a"`](doc/model/train-se-e2-a.md)
- [Descriptor `"se_e2_r"`](doc/model/train-se-e2-r.md)
- [Descriptor `"se_e3"`](doc/model/train-se-e3.md)
- [Descriptor `"se_atten"`](doc/model/train-se-atten.md)
- [Descriptor `"hybrid"`](doc/model/train-hybrid.md)
- [Descriptor `sel`](doc/model/sel.md)
- [Fit energy](doc/model/train-energy.md)
Expand Down
6 changes: 4 additions & 2 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ def j_loader(filename: Union[str, Path]) -> Dict[str, Any]:


def get_activation_func(
activation_fn: "_ACTIVATION",
) -> Callable[[tf.Tensor], tf.Tensor]:
activation_fn: Union["_ACTIVATION", None],
) -> Union[Callable[[tf.Tensor], tf.Tensor], None]:
"""Get activation function callable based on string name.

Parameters
Expand All @@ -424,6 +424,8 @@ def get_activation_func(
RuntimeError
if unknown activation function is specified
"""
if activation_fn is None or activation_fn in ['none', 'None']:
return None
if activation_fn not in ACTIVATION_FN_DICT:
raise RuntimeError(f"{activation_fn} is not a valid activation function")
return ACTIVATION_FN_DICT[activation_fn]
Expand Down
1 change: 1 addition & 0 deletions deepmd/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .se_a_ef import DescrptSeAEf
from .se_a_ef import DescrptSeAEfLower
from .loc_frame import DescrptLocFrame
from .se_atten import DescrptSeAtten
777 changes: 777 additions & 0 deletions deepmd/descriptor/se_atten.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,10 @@ def train(

jdata = normalize(jdata)

if not is_compress and not skip_neighbor_stat:
if jdata['model']['descriptor']['type'] in ['se_atten'] and isinstance(jdata['model']['descriptor']['sel'], list):
jdata['model']['descriptor']['sel'] = sum(jdata['model']['descriptor']['sel'])

if not is_compress and not skip_neighbor_stat and jdata['model']['descriptor']['type'] not in ['se_atten']:
jdata = update_sel(jdata)

with open(output, "w") as fp:
Expand Down
53 changes: 39 additions & 14 deletions deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def get_numb_aparam(self) -> int:
return self.numb_fparam

def compute_output_stats(self,
all_stat: dict
all_stat: dict,
mixed_type: bool = False
) -> None:
"""
Compute the ouput statistics
Expand All @@ -179,10 +180,14 @@ def compute_output_stats(self,
must have the following components:
all_stat['energy'] of shape n_sys x n_batch x n_frame
can be prepared by model.make_stat_input
mixed_type
Whether to perform the mixed_type mode.
If True, the input data has the mixed_type format (see doc/model/train_se_atten.md),
in which frames in a system may have different natoms_vec(s), with the same nloc.
"""
self.bias_atom_e = self._compute_output_stats(all_stat, rcond = self.rcond)
self.bias_atom_e = self._compute_output_stats(all_stat, rcond=self.rcond, mixed_type=mixed_type)

def _compute_output_stats(self, all_stat, rcond = 1e-3):
def _compute_output_stats(self, all_stat, rcond=1e-3, mixed_type=False):
data = all_stat['energy']
# data[sys_idx][batch_idx][frame_idx]
sys_ener = np.array([])
Expand All @@ -193,11 +198,22 @@ def _compute_output_stats(self, all_stat, rcond = 1e-3):
sys_data.append(data[ss][ii][jj])
sys_data = np.concatenate(sys_data)
sys_ener = np.append(sys_ener, np.average(sys_data))
data = all_stat['natoms_vec']
sys_tynatom = np.array([])
nsys = len(data)
for ss in range(len(data)):
sys_tynatom = np.append(sys_tynatom, data[ss][0].astype(np.float64))
if mixed_type:
data = all_stat['real_natoms_vec']
nsys = len(data)
for ss in range(len(data)):
tmp_tynatom = []
for ii in range(len(data[ss])):
for jj in range(len(data[ss][ii])):
tmp_tynatom.append(data[ss][ii][jj].astype(np.float64))
tmp_tynatom = np.average(np.array(tmp_tynatom), axis=0)
sys_tynatom = np.append(sys_tynatom, tmp_tynatom)
else:
data = all_stat['natoms_vec']
nsys = len(data)
for ss in range(len(data)):
sys_tynatom = np.append(sys_tynatom, data[ss][0].astype(np.float64))
sys_tynatom = np.reshape(sys_tynatom, [nsys,-1])
sys_tynatom = sys_tynatom[:,2:]
if len(self.atom_ener) > 0:
Expand Down Expand Up @@ -402,6 +418,11 @@ def build (self,
t_daparam = tf.constant(self.numb_aparam,
name = 'daparam',
dtype = tf.int32)
self.t_bias_atom_e = tf.get_variable('t_bias_atom_e',
self.bias_atom_e.shape,
dtype=GLOBAL_TF_FLOAT_PRECISION,
trainable=False,
initializer=tf.constant_initializer(self.bias_atom_e))
if self.numb_fparam > 0:
t_fparam_avg = tf.get_variable('t_fparam_avg',
self.numb_fparam,
Expand Down Expand Up @@ -452,12 +473,16 @@ def build (self,
aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

type_embedding = input_dict.get('type_embedding', None)
atype = input_dict.get('atype', None)
if type_embedding is not None:
atype_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
atype_embed = tf.tile(atype_embed,[tf.shape(inputs)[0],1])
atype_nall = tf.reshape(atype, [-1, natoms[1]])
self.atype_nloc = tf.reshape(tf.slice(atype_nall, [0, 0], [-1, natoms[0]]), [-1]) ## lammps will make error
atype_embed = tf.nn.embedding_lookup(type_embedding, self.atype_nloc)
else:
atype_embed = None

self.atype_embed = atype_embed

if atype_embed is None:
start_index = 0
outs_list = []
Expand Down Expand Up @@ -503,11 +528,11 @@ def build (self,
bias_atom_e=0.0, suffix=suffix, reuse=reuse
)
outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]])
# add atom energy bias; TF will broadcast to all batches
# tf.repeat is avaiable in TF>=2.1 or TF 1.15
_TF_VERSION = Version(TF_VERSION)
if (Version('1.15') <= _TF_VERSION < Version('2') or _TF_VERSION >= Version('2.1')) and self.bias_atom_e is not None:
outs += tf.repeat(tf.Variable(self.bias_atom_e, dtype=self.fitting_precision, trainable=False, name="bias_atom_ei"), natoms[2:])
# add bias
self.atom_ener_before = outs
self.add_type = tf.reshape(tf.nn.embedding_lookup(self.t_bias_atom_e, self.atype_nloc), [tf.shape(inputs)[0], natoms[0]])
outs = outs + self.add_type
self.atom_ener_after = outs

if self.tot_ener_zero:
force_tot_ener = 0.0
Expand Down
41 changes: 28 additions & 13 deletions deepmd/model/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,36 @@ def get_type_map (self) :
def data_stat(self, data):
all_stat = make_stat_input(data, self.data_stat_nbatch, merge_sys = False)
m_all_stat = merge_sys_stat(all_stat)
self._compute_input_stat(m_all_stat, protection = self.data_stat_protect)
self._compute_output_stat(all_stat)
self._compute_input_stat(m_all_stat, protection=self.data_stat_protect, mixed_type=data.mixed_type)
self._compute_output_stat(all_stat, mixed_type=data.mixed_type)
# self.bias_atom_e = data.compute_energy_shift(self.rcond)

def _compute_input_stat (self, all_stat, protection = 1e-2) :
self.descrpt.compute_input_stats(all_stat['coord'],
all_stat['box'],
all_stat['type'],
all_stat['natoms_vec'],
all_stat['default_mesh'],
all_stat)
self.fitting.compute_input_stats(all_stat, protection = protection)
def _compute_input_stat (self, all_stat, protection=1e-2, mixed_type=False):
if mixed_type:
self.descrpt.compute_input_stats(all_stat['coord'],
all_stat['box'],
all_stat['type'],
all_stat['natoms_vec'],
all_stat['default_mesh'],
all_stat,
mixed_type,
all_stat['real_natoms_vec'])
else:
self.descrpt.compute_input_stats(all_stat['coord'],
all_stat['box'],
all_stat['type'],
all_stat['natoms_vec'],
all_stat['default_mesh'],
all_stat)
self.fitting.compute_input_stats(all_stat, protection=protection)

def _compute_output_stat (self, all_stat, mixed_type=False):
if mixed_type:
self.fitting.compute_output_stats(all_stat, mixed_type=mixed_type)
else:
self.fitting.compute_output_stats(all_stat)

def _compute_output_stat (self, all_stat) :
self.fitting.compute_output_stats(all_stat)


def build (self,
coord_,
atype_,
Expand Down Expand Up @@ -158,6 +171,7 @@ def build (self,
suffix = suffix,
)
input_dict['type_embedding'] = type_embedding
input_dict['atype'] = atype_

if frz_model == None:
dout \
Expand Down Expand Up @@ -195,6 +209,7 @@ def build (self,
input_dict,
reuse = reuse,
suffix = suffix)
self.atom_ener = atom_ener

if self.srtab is not None :
sw_lambda, sw_deriv \
Expand Down
30 changes: 27 additions & 3 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from deepmd.utils.sess import run_sess
from deepmd.utils.type_embed import TypeEmbedNet
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
from deepmd.utils.argcheck import type_embedding_args

from tensorflow.python.client import timeline
from deepmd.env import op_module, TF_VERSION
Expand Down Expand Up @@ -69,17 +70,20 @@ def _init_param(self, jdata):
# nvnmd
self.nvnmd_param = jdata.get('nvnmd', {})
nvnmd_cfg.init_from_jdata(self.nvnmd_param)
nvnmd_cfg.init_from_deepmd_input(model_param)
if nvnmd_cfg.enable:
nvnmd_cfg.init_from_deepmd_input(model_param)
nvnmd_cfg.disp_message()
nvnmd_cfg.save()

# descriptor
try:
descrpt_type = descrpt_param['type']
self.descrpt_type = descrpt_type
except KeyError:
raise KeyError('the type of descriptor should be set by `type`')

if descrpt_param['type'] in ['se_atten']:
descrpt_param['ntypes'] = len(model_param['type_map'])
self.descrpt = Descriptor(**descrpt_param)

# fitting net
Expand Down Expand Up @@ -112,14 +116,30 @@ def _init_param(self, jdata):
raise RuntimeError('unknow fitting type ' + fitting_type)

# type embedding
padding = False
if descrpt_type == 'se_atten':
padding = True
if typeebd_param is not None:
self.typeebd = TypeEmbedNet(
neuron=typeebd_param['neuron'],
resnet_dt=typeebd_param['resnet_dt'],
activation_function=typeebd_param['activation_function'],
precision=typeebd_param['precision'],
trainable=typeebd_param['trainable'],
seed=typeebd_param['seed']
seed=typeebd_param['seed'],
padding=padding
)
elif descrpt_type == 'se_atten':
default_args = type_embedding_args()
default_args_dict = {i.name: i.default for i in default_args}
self.typeebd = TypeEmbedNet(
neuron=default_args_dict['neuron'],
resnet_dt=default_args_dict['resnet_dt'],
activation_function=None,
precision=default_args_dict['precision'],
trainable=default_args_dict['trainable'],
seed=default_args_dict['seed'],
padding=padding
)
else:
self.typeebd = None
Expand Down Expand Up @@ -272,6 +292,10 @@ def build (self,
self.ntypes = self.model.get_ntypes()
self.stop_batch = stop_batch

if not self.is_compress and data.mixed_type:
assert self.descrpt_type in ['se_atten'], 'Data in mixed_type format must use attention descriptor!'
assert self.fitting_type in ['ener'], 'Data in mixed_type format must use ener fitting!'

if self.numb_fparam > 0 :
log.info("training with %d frame parameter(s)" % self.numb_fparam)
else:
Expand Down Expand Up @@ -585,7 +609,7 @@ def save_checkpoint(self, cur_batch: int):
def get_feed_dict(self, batch, is_training):
feed_dict = {}
for kk in batch.keys():
if kk == 'find_type' or kk == 'type':
if kk == 'find_type' or kk == 'type' or kk == 'real_natoms_vec':
continue
if 'find_' in kk:
feed_dict[self.place_holders[kk]] = batch[kk]
Expand Down
Loading