Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__ (self,
set_davg_zero: bool = False,
activation_function: str = 'tanh',
precision: str = 'default',
uniform_seed: bool = False
uniform_seed: bool = False,
name: str = None,
) -> None:
"""
Constructor
Expand Down Expand Up @@ -66,6 +67,8 @@ def __init__ (self,
The precision of the embedding net parameters. Supported options are {1}
uniform_seed
Only for the purpose of backward compatibility, retrieves the old behavior of using the random seed
name
Name used to identify the descriptor
"""
self.sel_a = sel
self.rcut_r = rcut
Expand All @@ -89,7 +92,7 @@ def __init__ (self,
self.type_one_side = type_one_side
if self.type_one_side and len(exclude_types) != 0:
raise RuntimeError('"type_one_side" is not compatible with "exclude_types"')

self.name = name
# descrpt config
self.sel_r = [ 0 for ii in range(len(self.sel_a)) ]
self.ntypes = len(self.sel_a)
Expand Down
1 change: 1 addition & 0 deletions deepmd/entrypoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# import `train` as `train_dp` to avoid the conflict of the
# module name `train` and the function name `train`
from .train import train as train_dp
from .train_mt import train_mt as train_dp_mt
from .transfer import transfer
from ..infer.model_devi import make_model_devi
from .convert import convert
Expand Down
1 change: 1 addition & 0 deletions deepmd/entrypoints/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def compress(
mpi_log=mpi_log,
log_level=log_level,
log_path=log_path,
multi_task = False,
)
except GraphTooLargeError as e:
raise RuntimeError(
Expand Down
13 changes: 12 additions & 1 deletion deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
freeze,
test,
train_dp,
train_dp_mt,
transfer,
make_model_devi,
convert,
Expand Down Expand Up @@ -162,6 +163,13 @@ def parse_args(args: Optional[List[str]] = None):
default="out.json",
help="The output file of the parameters used in training.",
)
parser_train.add_argument(
"-mt",
"--multi_task",
action = 'store_true',
help="Whether using multi-task.",
)


# * freeze script ******************************************************************
parser_frz = subparsers.add_parser(
Expand Down Expand Up @@ -422,7 +430,10 @@ def main():
dict_args = vars(args)

if args.command == "train":
train_dp(**dict_args)
if dict_args['multi_task']:
train_dp_mt(**dict_args)
else:
train_dp(**dict_args)
elif args.command == "freeze":
freeze(**dict_args)
elif args.command == "config":
Expand Down
19 changes: 14 additions & 5 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def train(
restart=restart,
log_path=log_path,
log_level=log_level,
mpi_log=mpi_log
mpi_log=mpi_log,
)

for message in WELCOME + CITATION + BUILD:
Expand Down Expand Up @@ -210,14 +210,19 @@ def get_modifier(modi_data=None):
return modifier


def get_rcut(jdata):
descrpt_data = jdata['model']['descriptor']
def parse_rcut(descrpt_data):
    """Collect every cut-off radius declared by one descriptor config.

    For a ``hybrid`` descriptor the ``rcut`` of each sub-descriptor in its
    ``list`` is gathered; for any other type the single ``rcut`` is returned.

    Parameters
    ----------
    descrpt_data : dict
        one descriptor section of the training input

    Returns
    -------
    list
        all cut-off radii found, in declaration order
    """
    if descrpt_data['type'] == 'hybrid':
        return [sub['rcut'] for sub in descrpt_data['list']]
    return [descrpt_data['rcut']]

def get_rcut(jdata):
    """Return the largest cut-off radius used by the model's descriptor.

    Parameters
    ----------
    jdata : dict
        the whole training input; the descriptor section is read from
        ``jdata['model']['descriptor']``

    Returns
    -------
    float
        maximum over all rcut values collected by ``parse_rcut``
    """
    descrpt_data = jdata['model']['descriptor']
    # parse_rcut already handles both plain and hybrid descriptors,
    # so no intermediate accumulator list is needed here.
    return max(parse_rcut(descrpt_data))


Expand Down Expand Up @@ -295,12 +300,16 @@ def update_one_sel(jdata, descriptor):
return descriptor


def update_sel(jdata):
descrpt_data = jdata['model']['descriptor']
def parse_auto_descrpt(jdata, descrpt_data):
    """Resolve automatic ``sel`` settings for one descriptor section.

    A ``hybrid`` descriptor has each entry of its ``list`` updated in place;
    any other descriptor is passed to ``update_one_sel`` directly.

    Parameters
    ----------
    jdata : dict
        the whole training input (needed by ``update_one_sel``)
    descrpt_data : dict
        the descriptor section to update

    Returns
    -------
    dict
        the updated descriptor section
    """
    if descrpt_data['type'] != 'hybrid':
        return update_one_sel(jdata, descrpt_data)
    sub_list = descrpt_data['list']
    for idx in range(len(sub_list)):
        sub_list[idx] = update_one_sel(jdata, sub_list[idx])
    return descrpt_data

def update_sel(jdata):
    """Replace any automatic ``sel`` in the model descriptor with concrete values.

    Parameters
    ----------
    jdata : dict
        the whole training input; updated and returned

    Returns
    -------
    dict
        ``jdata`` with ``jdata['model']['descriptor']`` resolved
    """
    jdata['model']['descriptor'] = parse_auto_descrpt(
        jdata, jdata['model']['descriptor']
    )
    return jdata
260 changes: 260 additions & 0 deletions deepmd/entrypoints/train_mt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
"""DeePMD training entrypoint script.

Can handle local or distributed training.
"""

import json
import logging
import time
import os
from typing import Dict, List, Optional, Any

import numpy as np
from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have
from deepmd.env import reset_default_tf_session_config
from deepmd.infer.data_modifier import DipoleChargeModifier
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
from deepmd.train.trainer import DPTrainer
from deepmd.train.trainer_mt import DPMultitaskTrainer
from deepmd.utils.argcheck import normalize
from deepmd.utils.argcheck_mt import normalize_mt
from deepmd.utils.compat import updata_deepmd_input
from deepmd.utils.data_system import DeepmdDataSystem
from deepmd.utils.data_docker import DeepmdDataDocker
from deepmd.utils.sess import run_sess
from deepmd.utils.neighbor_stat import NeighborStat
from deepmd.entrypoints.train import get_modifier, parse_rcut, get_type_map
from deepmd.entrypoints.train import parse_auto_sel, parse_auto_sel_ratio, wrap_up_4

# This module defines ``train_mt``; exporting "train" (left over from a
# copy of train.py) would make ``from ... import *`` raise AttributeError,
# since no ``train`` is defined here.
__all__ = ["train_mt"]

log = logging.getLogger(__name__)


def train_mt(
    *,
    INPUT: str,
    init_model: Optional[str],
    restart: Optional[str],
    output: str,
    mpi_log: str,
    log_level: int,
    log_path: Optional[str],
    **kwargs,
):
    """Run DeePMD multi-task model training.

    Parameters
    ----------
    INPUT : str
        json/yaml control file
    init_model : Optional[str]
        path to checkpoint folder or None
    restart : Optional[str]
        path to checkpoint folder or None
    output : str
        path for dump file with arguments
    mpi_log : str
        mpi logging mode
    log_level : int
        logging level defined by int 0-3
    log_path : Optional[str]
        logging file path or None if logs are to be output only to stdout

    Raises
    ------
    RuntimeError
        if distributed training job name is wrong
    """
    # load json database
    jdata = j_loader(INPUT)

    # upgrade v1-style input and validate against the multi-task arg schema
    jdata = updata_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
    jdata = normalize_mt(jdata)
    # resolve "auto" sel settings (module-level update_sel defined below)
    jdata = update_sel(jdata)

    # dump the fully-resolved input for reproducibility
    with open(output, "w") as fp:
        json.dump(jdata, fp, indent=4)

    # run options
    run_opt = RunOptions(
        init_model=init_model,
        restart=restart,
        log_path=log_path,
        log_level=log_level,
        mpi_log=mpi_log,
    )

    for message in WELCOME + CITATION + BUILD:
        log.info(message)

    run_opt.print_resource_summary()
    _do_work(jdata, run_opt)




def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
    """Run serial multi-task model training.

    Parameters
    ----------
    jdata : Dict[str, Any]
        arguments read form json/yaml control file
    run_opt : RunOptions
        object with run configuration

    Raises
    ------
    RuntimeError
        If a data modifier is configured (not supported in multi-task mode)
    """
    # make necessary checks
    assert "training" in jdata

    # avoid conflict of visible gpus among multiple tf sessions in one process
    if run_opt.is_distrib and len(run_opt.gpus or []) > 1:
        reset_default_tf_session_config(cpu_only=True)

    # init the model; collect the cut-off radius of every sub-model so the
    # shared data pipeline can use the largest one
    rcut_list = []
    model = DPMultitaskTrainer(jdata, run_opt=run_opt)
    for model_name in model.model_dict.keys():
        sub_model = model.model_dict[model_name]
        rcut_list.append(sub_model.get_rcut())
        # NOTE(review): type_map keeps only the LAST sub-model's value; this
        # presumably assumes all tasks share one type map — confirm. Also,
        # an empty model_dict would leave type_map unbound (NameError).
        type_map = sub_model.get_type_map()
    rcut = max(rcut_list)

    # an empty type map means "take the types from the data"
    if len(type_map) == 0:
        ipt_type_map = None
    else:
        ipt_type_map = type_map

    # init random seed (reduced mod 2**32 to fit numpy's legacy seed range)
    seed = jdata["training"].get("seed", None)
    if seed is not None:
        seed = seed % (2 ** 32)
    np.random.seed(seed)

    # setup data modifier
    modifier = get_modifier(jdata["model"].get("modifier", None))
    if modifier is not None:
        raise RuntimeError('modifier is not supported in multi-task training mode yet')

    # init data: one DeepmdDataDocker per split, built with the global max rcut
    train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier)
    train_data.print_summary("training")
    if jdata["training"].get("validation_data", None) is not None:
        valid_data = get_data(jdata["training"]["validation_data"], rcut, ipt_type_map, modifier)
        valid_data.print_summary("validation")
    else:
        valid_data = None

    # get training info
    stop_batch = j_must_have(jdata["training"], "numb_steps")
    model.build(train_data, stop_batch)

    # train the model with the provided systems in a cyclic way
    start_time = time.time()
    model.train(train_data, valid_data)
    end_time = time.time()
    log.info("finished training")
    log.info(f"wall time: {(end_time - start_time):.3f} s")

def get_data(jdata: Dict[str, Any], rcut, type_map, modifier):
    """Build a ``DeepmdDataDocker`` from one data section of the input.

    Parameters
    ----------
    jdata : Dict[str, Any]
        the ``training_data`` or ``validation_data`` section
    rcut : float
        cut-off radius used for neighbor searching
    type_map : Optional[list]
        global type map, or None to take types from the data
    modifier
        data modifier, or None

    Returns
    -------
    DeepmdDataDocker
        the assembled multi-task data container
    """
    return DeepmdDataDocker(
        data_systems=j_must_have(jdata, "systems"),
        batch_size=j_must_have(jdata, "batch_size"),
        rcut=rcut,
        # in the data docker this is the total type map over all tasks
        type_map=type_map,
        sys_probs=jdata.get("sys_probs", None),
        auto_prob_style=jdata.get("auto_prob", "prob_sys_size"),
        auto_prob_style_method=jdata.get("auto_prob_method", "prob_uniform"),
        modifier=modifier,
    )

def get_rcut(jdata):
    """Return the largest cut-off radius over all task descriptors.

    In multi-task mode ``jdata['model']['descriptor']`` is a list with one
    entry per task; each entry may itself be a hybrid descriptor, which
    ``parse_rcut`` flattens.

    Returns
    -------
    float
        the maximum rcut found
    """
    rcuts = []
    for sub_descrpt in jdata['model']['descriptor']:
        rcuts += parse_rcut(sub_descrpt)
    return max(rcuts)

def get_sel(jdata, rcut, data_sys_name=None):
    """Compute the maximum neighbor counts for one task's data system.

    Parameters
    ----------
    jdata : dict
        the whole training input
    rcut : float
        cut-off radius of the descriptor being resolved (used for the
        neighbor statistics, not for loading the data)
    data_sys_name : Optional[str]
        name of the data system (task) to analyze

    Returns
    -------
    max_nbor_size
        per-type maximum neighbor counts reported by NeighborStat
    """
    max_rcut = get_rcut(jdata)
    type_map = get_type_map(jdata)

    # An empty type map means "infer types from the data": pass None on.
    # (The previous condition `type_map and len(type_map) == 0` could never
    # be true — an empty list is falsy — so [] leaked through to get_data.)
    if type_map is not None and len(type_map) == 0:
        type_map = None
    train_data = get_data(jdata["training"]["training_data"], max_rcut, type_map, None)
    train_data = train_data.get_data_system(data_sys_name)

    train_data.get_batch()
    data_ntypes = train_data.get_ntypes()
    if type_map is not None:
        map_ntypes = len(type_map)
    else:
        map_ntypes = data_ntypes
    ntypes = max([map_ntypes, data_ntypes])

    # neighbor statistics use this descriptor's own rcut, not the global max
    neistat = NeighborStat(ntypes, rcut)

    min_nbor_dist, max_nbor_size = neistat.get_stat(train_data)

    return max_nbor_size



def update_one_sel(jdata, descriptor):
    """Resolve the ``sel`` field of one descriptor against its training data.

    When the descriptor carries a ``name``, the task that references that
    descriptor is looked up so the neighbor statistics run on the matching
    data system. An ``auto`` sel is replaced by measured neighbor counts
    (scaled and rounded up to a multiple of 4); a user-given sel only
    triggers a warning when the data needs more neighbors than allowed.

    Parameters
    ----------
    jdata : dict
        the whole training input
    descriptor : dict
        the descriptor section to update

    Returns
    -------
    dict
        the (possibly updated) descriptor
    """
    rcut = descriptor['rcut']
    data_sys_name = ''
    if 'name' in descriptor.keys():
        wanted = descriptor['name']
        # find the data system we want, which uses this specific descriptor
        for sub_task in jdata['training']['tasks']:
            if sub_task['descriptor'] == wanted:
                data_sys_name = sub_task['name']
                break
    tmp_sel = get_sel(jdata, rcut, data_sys_name)

    if not parse_auto_sel(descriptor['sel']):
        # sel is set by user
        for ii, (tt, dd) in enumerate(zip(tmp_sel, descriptor['sel'])):
            if dd and tt > dd:
                # we may skip warning for sel=0, where the user is likely
                # to exclude such type in the descriptor
                log.warning(
                    "sel of type %d is not enough! The expected value is "
                    "not less than %d, but you set it to %d. The accuracy"
                    " of your model may get worse." %(ii, tt, dd)
                )
    else:
        scale = parse_auto_sel_ratio(descriptor['sel'])
        descriptor['sel'] = [int(wrap_up_4(nn * scale)) for nn in tmp_sel]
    return descriptor


def parse_auto_descrpt(jdata, descrpt_data):
    """Resolve automatic ``sel`` settings for one descriptor section.

    Hybrid descriptors have each sub-descriptor in ``list`` updated in
    place; any other descriptor goes straight through ``update_one_sel``.

    Returns
    -------
    dict
        the updated descriptor section
    """
    if descrpt_data['type'] != 'hybrid':
        return update_one_sel(jdata, descrpt_data)
    sub_list = descrpt_data['list']
    for idx in range(len(sub_list)):
        sub_list[idx] = update_one_sel(jdata, sub_list[idx])
    return descrpt_data

def update_sel(jdata):
    """Resolve ``auto`` sel settings for every task's descriptor.

    In multi-task mode the model's descriptor section is a list with one
    entry per task; each entry is resolved via ``parse_auto_descrpt``.

    Returns
    -------
    dict
        ``jdata`` with the resolved descriptor list written back
    """
    jdata['model']['descriptor'] = [
        parse_auto_descrpt(jdata, sub_descrpt)
        for sub_descrpt in jdata['model']['descriptor']
    ]
    return jdata
Loading