Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 93 additions & 43 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def type_embedding_args():
doc_trainable = "If the parameters in the embedding net are trainable"

return [
Argument("neuron", list, optional=True, default=[8], doc=doc_neuron),
Argument("neuron", List[int], optional=True, default=[8], doc=doc_neuron),
Argument(
"activation_function",
str,
Expand All @@ -77,9 +77,9 @@ def spin_args():
doc_virtual_len = "The distance between virtual atom representing spin and its corresponding real atom for each atom type with spin"

return [
Argument("use_spin", list, doc=doc_use_spin),
Argument("spin_norm", list, doc=doc_spin_norm),
Argument("virtual_len", list, doc=doc_virtual_len),
Argument("use_spin", List[bool], doc=doc_use_spin),
Argument("spin_norm", List[float], doc=doc_spin_norm),
Argument("virtual_len", List[float], doc=doc_virtual_len),
]


Expand Down Expand Up @@ -159,10 +159,10 @@ def descrpt_local_frame_args():
- axis_rule[i*6+5]: index of the axis atom defining the second axis. Note that the neighbors with the same class and type are sorted according to their relative distance."

return [
Argument("sel_a", list, optional=False, doc=doc_sel_a),
Argument("sel_r", list, optional=False, doc=doc_sel_r),
Argument("sel_a", List[int], optional=False, doc=doc_sel_a),
Argument("sel_r", List[int], optional=False, doc=doc_sel_r),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("axis_rule", list, optional=False, doc=doc_axis_rule),
Argument("axis_rule", List[int], optional=False, doc=doc_axis_rule),
]


Expand All @@ -185,10 +185,12 @@ def descrpt_se_a_args():
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"axis_neuron",
int,
Expand All @@ -212,7 +214,11 @@ def descrpt_se_a_args():
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument(
"set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
Expand All @@ -236,10 +242,12 @@ def descrpt_se_t_args():
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"activation_function",
str,
Expand Down Expand Up @@ -289,10 +297,12 @@ def descrpt_se_r_args():
doc_set_davg_zero = "Set the normalization average to zero. This option should be set when `atom_ener` in the energy fitting is used"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"activation_function",
str,
Expand All @@ -308,7 +318,11 @@ def descrpt_se_r_args():
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument(
"set_davg_zero", bool, optional=True, default=False, doc=doc_set_davg_zero
Expand Down Expand Up @@ -356,10 +370,14 @@ def descrpt_se_atten_common_args():
doc_attn_mask = "Whether to do mask on the diagonal in the attention matrix"

return [
Argument("sel", [int, list, str], optional=True, default="auto", doc=doc_sel),
Argument(
"sel", [int, List[int], str], optional=True, default="auto", doc=doc_sel
),
Argument("rcut", float, optional=True, default=6.0, doc=doc_rcut),
Argument("rcut_smth", float, optional=True, default=0.5, doc=doc_rcut_smth),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"axis_neuron",
int,
Expand All @@ -383,7 +401,11 @@ def descrpt_se_atten_common_args():
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument("attn", int, optional=True, default=128, doc=doc_attn),
Argument("attn_layer", int, optional=True, default=2, doc=doc_attn_layer),
Expand Down Expand Up @@ -454,8 +476,10 @@ def descrpt_se_a_mask_args():
doc_seed = "Random seed for parameter initialization"

return [
Argument("sel", [list, str], optional=True, default="auto", doc=doc_sel),
Argument("neuron", list, optional=True, default=[10, 20, 40], doc=doc_neuron),
Argument("sel", [List[int], str], optional=True, default="auto", doc=doc_sel),
Argument(
"neuron", List[int], optional=True, default=[10, 20, 40], doc=doc_neuron
),
Argument(
"axis_neuron",
int,
Expand All @@ -476,7 +500,11 @@ def descrpt_se_a_mask_args():
"type_one_side", bool, optional=True, default=False, doc=doc_type_one_side
),
Argument(
"exclude_types", list, optional=True, default=[], doc=doc_exclude_types
"exclude_types",
List[List[int]],
optional=True,
default=[],
doc=doc_exclude_types,
),
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument("trainable", bool, optional=True, default=True, doc=doc_trainable),
Expand Down Expand Up @@ -525,7 +553,7 @@ def fitting_ener():
doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be\n\n\
- bool: True if all parameters of the fitting net are trainable, False otherwise.\n\n\
- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of tihs list should be equal to len(`neuron`)+1."
- list of bool: Specifies if each layer is trainable. Since the fitting net is composed by hidden layers followed by a output layer, the length of this list should be equal to len(`neuron`)+1."
doc_rcond = "The condition number used to determine the inital energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details."
doc_seed = "Random seed for parameter initialization of the fitting net"
doc_atom_ener = "Specify the atomic energy in vacuum for each type"
Expand All @@ -547,7 +575,7 @@ def fitting_ener():
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
Argument(
"neuron",
list,
List[int],
optional=True,
default=[120, 120, 120],
alias=["n_neuron"],
Expand All @@ -563,14 +591,24 @@ def fitting_ener():
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
Argument(
"trainable", [list, bool], optional=True, default=True, doc=doc_trainable
"trainable",
[List[bool], bool],
optional=True,
default=True,
doc=doc_trainable,
),
Argument(
"rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond
),
Argument("seed", [int, None], optional=True, doc=doc_seed),
Argument("atom_ener", list, optional=True, default=[], doc=doc_atom_ener),
Argument("layer_name", list, optional=True, doc=doc_layer_name),
Argument(
"atom_ener",
List[Optional[float]],
optional=True,
default=[],
doc=doc_atom_ener,
),
Argument("layer_name", List[str], optional=True, doc=doc_layer_name),
Argument(
"use_aparam_as_mask",
bool,
Expand Down Expand Up @@ -602,7 +640,7 @@ def fitting_dos():
Argument("numb_fparam", int, optional=True, default=0, doc=doc_numb_fparam),
Argument("numb_aparam", int, optional=True, default=0, doc=doc_numb_aparam),
Argument(
"neuron", list, optional=True, default=[120, 120, 120], doc=doc_neuron
"neuron", List[int], optional=True, default=[120, 120, 120], doc=doc_neuron
),
Argument(
"activation_function",
Expand All @@ -614,7 +652,11 @@ def fitting_dos():
Argument("precision", str, optional=True, default="float64", doc=doc_precision),
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
Argument(
"trainable", [list, bool], optional=True, default=True, doc=doc_trainable
"trainable",
[List[bool], bool],
optional=True,
default=True,
doc=doc_trainable,
),
Argument(
"rcond", [float, type(None)], optional=True, default=None, doc=doc_rcond
Expand Down Expand Up @@ -642,7 +684,7 @@ def fitting_polar():
return [
Argument(
"neuron",
list,
List[int],
optional=True,
default=[120, 120, 120],
alias=["n_neuron"],
Expand All @@ -658,12 +700,14 @@ def fitting_polar():
Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt),
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument("fit_diag", bool, optional=True, default=True, doc=doc_fit_diag),
Argument("scale", [list, float], optional=True, default=1.0, doc=doc_scale),
Argument(
"scale", [List[float], float], optional=True, default=1.0, doc=doc_scale
),
# Argument("diag_shift", [list,float], optional = True, default = 0.0, doc = doc_diag_shift),
Argument("shift_diag", bool, optional=True, default=True, doc=doc_shift_diag),
Argument(
"sel_type",
[list, int, None],
[List[int], int, None],
optional=True,
alias=["pol_type"],
doc=doc_sel_type,
Expand All @@ -687,7 +731,7 @@ def fitting_dipole():
return [
Argument(
"neuron",
list,
List[int],
optional=True,
default=[120, 120, 120],
alias=["n_neuron"],
Expand All @@ -704,7 +748,7 @@ def fitting_dipole():
Argument("precision", str, optional=True, default="default", doc=doc_precision),
Argument(
"sel_type",
[list, int, None],
[List[int], int, None],
optional=True,
alias=["dipole_type"],
doc=doc_sel_type,
Expand Down Expand Up @@ -740,8 +784,10 @@ def modifier_dipole_charge():

return [
Argument("model_name", str, optional=False, doc=doc_model_name),
Argument("model_charge_map", list, optional=False, doc=doc_model_charge_map),
Argument("sys_charge_map", list, optional=False, doc=doc_sys_charge_map),
Argument(
"model_charge_map", List[float], optional=False, doc=doc_model_charge_map
),
Argument("sys_charge_map", List[float], optional=False, doc=doc_sys_charge_map),
Argument("ewald_beta", float, optional=True, default=0.4, doc=doc_ewald_beta),
Argument("ewald_h", float, optional=True, default=1.0, doc=doc_ewald_h),
]
Expand Down Expand Up @@ -770,7 +816,7 @@ def model_compression():

return [
Argument("model_file", str, optional=False, doc=doc_model_file),
Argument("table_config", list, optional=False, doc=doc_table_config),
Argument("table_config", List[float], optional=False, doc=doc_table_config),
Argument("min_nbor_dist", float, optional=False, doc=doc_min_nbor_dist),
]

Expand Down Expand Up @@ -814,7 +860,7 @@ def model_args(exclude_hybrid=False):
"model",
dict,
[
Argument("type_map", list, optional=True, doc=doc_type_map),
Argument("type_map", List[str], optional=True, doc=doc_type_map),
Argument(
"data_stat_nbatch",
int,
Expand Down Expand Up @@ -1456,11 +1502,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
)

args = [
Argument("systems", [list, str], optional=False, default=".", doc=doc_systems),
Argument(
"systems", [List[str], str], optional=False, default=".", doc=doc_systems
),
Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix),
Argument(
"batch_size",
[list, int, str],
[List[int], int, str],
optional=True,
default="auto",
doc=doc_batch_size,
Expand All @@ -1477,7 +1525,7 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
),
Argument(
"sys_probs",
list,
List[float],
optional=True,
default=None,
doc=doc_sys_probs,
Expand Down Expand Up @@ -1521,11 +1569,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
doc_numb_btch = "An integer that specifies the number of batches to be sampled for each validation period."

args = [
Argument("systems", [list, str], optional=False, default=".", doc=doc_systems),
Argument(
"systems", [List[str], str], optional=False, default=".", doc=doc_systems
),
Argument("set_prefix", str, optional=True, default="set", doc=doc_set_prefix),
Argument(
"batch_size",
[list, int, str],
[List[int], int, str],
optional=True,
default="auto",
doc=doc_batch_size,
Expand All @@ -1542,7 +1592,7 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat
),
Argument(
"sys_probs",
list,
List[float],
optional=True,
default=None,
doc=doc_sys_probs,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ dependencies = [
'numpy',
'scipy',
'pyyaml',
'dargs >= 0.3.5',
'dargs >= 0.4.1',
'python-hostlist >= 1.21',
'typing_extensions; python_version < "3.8"',
'importlib_metadata>=1.4; python_version < "3.8"',
Expand Down