diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index eb88aca92e..52d8e97cf8 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -55,16 +55,16 @@ def _generate_descrpt_from_param_dict(descrpt_param): descrpt_param.pop(kk, None) if descrpt_type == 'loc_frame': descrpt = DescrptLocFrame(**descrpt_param) - elif descrpt_type == 'se_a' : + elif descrpt_type == 'se_e2_a' or descrpt_type == 'se_a' : descrpt = DescrptSeA(**descrpt_param) - elif descrpt_type == 'se_a_3be' or descrpt_type == 'se_at' or descrpt_type == 'se_t' : + elif descrpt_type == 'se_e2_r' or descrpt_type == 'se_r' : + descrpt = DescrptSeR(**descrpt_param) + elif descrpt_type == 'se_e3' or descrpt_type == 'se_at' or descrpt_type == 'se_a_3be' : descrpt = DescrptSeT(**descrpt_param) elif descrpt_type == 'se_a_tpe' or descrpt_type == 'se_a_ebd' : descrpt = DescrptSeAEbd(**descrpt_param) elif descrpt_type == 'se_a_ef' : descrpt = DescrptSeAEf(**descrpt_param) - elif descrpt_type == 'se_r' : - descrpt = DescrptSeR(**descrpt_param) elif descrpt_type == 'se_ar' : descrpt = DescrptSeAR(descrpt_param) else : diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index dca1508a03..121079bead 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -177,9 +177,9 @@ def descrpt_variant_type_args(): return Variant("type", [ Argument("loc_frame", dict, descrpt_local_frame_args()), - Argument("se_a", dict, descrpt_se_a_args()), - Argument("se_r", dict, descrpt_se_r_args()), - Argument("se_t", dict, descrpt_se_t_args(), alias = ['se_at', 'se_a_3be']), + Argument("se_e2_a", dict, descrpt_se_a_args(), alias = ['se_a']), + Argument("se_e2_r", dict, descrpt_se_r_args(), alias = ['se_r']), + Argument("se_e3", dict, descrpt_se_t_args(), alias = ['se_at', 'se_a_3be', 'se_t']), Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias = ['se_a_ebd']), Argument("hybrid", dict, descrpt_hybrid_args()), ], doc = doc_descrpt_type) @@ -553,7 +553,21 @@ def gen_doc(*, make_anchor=True, make_link=True, **kwargs): return "\n\n".join(ptr) +def normalize_hybrid_list(hy_list): + new_list = [] + base = Argument("base", dict, [], [descrpt_variant_type_args()], doc = "") + for ii in range(len(hy_list)): + data = base.normalize_value(hy_list[ii], trim_pattern="_*") + base.check_value(data, strict=True) + new_list.append(data) + return new_list + + def normalize(data): + if "hybrid" == data["model"]["descriptor"]["type"]: + data["model"]["descriptor"]["list"] \ + = normalize_hybrid_list(data["model"]["descriptor"]["list"]) + ma = model_args() lra = learning_rate_args() la = loss_args()