Review notes (free text ahead of the first "diff --git" header, as in mailed patches):
  * test(): the "wfc" dispatch branch is kept, so `err` is bound for wfc models.
  * test_wfc(): the two log messages are f-strings now, so placeholders are filled in.
  * print_*_sys_avg / weighted_average: docstrings describe the dict-based summaries.
  * Whitespace runs inside the reviewed copy were collapsed; if a hunk does not
    apply cleanly, regenerate the patch against the real tree.

diff --git a/deepmd/entrypoints/test.py b/deepmd/entrypoints/test.py
index 62687866a8..218d2974fe 100644
--- a/deepmd/entrypoints/test.py
+++ b/deepmd/entrypoints/test.py
@@ -1,12 +1,13 @@
 """Test trained DeePMD model."""
 import logging
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Dict, Optional, Tuple
 
 import numpy as np
 from deepmd import DeepPotential
 from deepmd.common import expand_sys_str
 from deepmd.utils.data import DeepmdData
+from deepmd.utils.weight_avg import weighted_average
 
 if TYPE_CHECKING:
     from deepmd.infer import DeepDipole, DeepPolar, DeepPot, DeepWFC
@@ -77,7 +78,7 @@ def test(
         data = DeepmdData(system, set_prefix, shuffle_test=shuffle_test, type_map=tmap)
 
         if dp.model_type == "ener":
-            err, siz = test_ener(
+            err = test_ener(
                 dp,
                 data,
                 system,
@@ -87,18 +88,17 @@ def test(
                 append_detail=(cc != 0),
             )
         elif dp.model_type == "dipole":
-            err, siz = test_dipole(dp, data, numb_test, detail_file, atomic)
+            err = test_dipole(dp, data, numb_test, detail_file, atomic)
         elif dp.model_type == "polar":
-            err, siz = test_polar(dp, data, numb_test, detail_file, global_polar=False)
+            err = test_polar(dp, data, numb_test, detail_file, global_polar=False)
         elif dp.model_type == "global_polar":
-            err, siz = test_polar(dp, data, numb_test, detail_file, global_polar=True)
+            err = test_polar(dp, data, numb_test, detail_file, global_polar=True)
         elif dp.model_type == "wfc":
-            err, siz = test_wfc(dp, data, numb_test, detail_file)
+            err = test_wfc(dp, data, numb_test, detail_file)
         log.info("# ----------------------------------------------- ")
         err_coll.append(err)
-        siz_coll.append(siz)
 
-    avg_err = weighted_average(err_coll, siz_coll)
+    avg_err = weighted_average(err_coll)
 
     if len(all_sys) != len(err_coll):
         log.warning("Not all systems are tested! Check if the systems are valid")
@@ -119,8 +119,8 @@ def test(
         log.info("# ----------------------------------------------- ")
 
 
-def l2err(diff: np.ndarray) -> np.ndarray:
-    """Calculate average l2 norm error.
+def rmse(diff: np.ndarray) -> np.ndarray:
+    """Calculate average root mean square error.
 
     Parameters
     ----------
@@ -135,39 +135,6 @@ def l2err(diff: np.ndarray) -> np.ndarray:
     return np.sqrt(np.average(diff * diff))
 
 
-def weighted_average(
-    err_coll: List[List[np.ndarray]], siz_coll: List[List[int]]
-) -> np.ndarray:
-    """Compute wighted average of prediction errors for model.
-
-    Parameters
-    ----------
-    err_coll : List[List[np.ndarray]]
-        each item in list represents erros for one model
-    siz_coll : List[List[int]]
-        weight for each model errors
-
-    Returns
-    -------
-    np.ndarray
-        weighted averages
-    """
-    assert len(err_coll) == len(siz_coll)
-
-    nitems = len(err_coll[0])
-    sum_err = np.zeros(nitems)
-    sum_siz = np.zeros(nitems)
-    for sys_error, sys_size in zip(err_coll, siz_coll):
-        for ii in range(nitems):
-            ee = sys_error[ii]
-            ss = sys_size[ii]
-            sum_err[ii] += ee * ee * ss
-            sum_siz[ii] += ss
-    for ii in range(nitems):
-        sum_err[ii] = np.sqrt(sum_err[ii] / sum_siz[ii])
-    return sum_err
-
-
 def save_txt_file(
     fname: Path, data: np.ndarray, header: str = "", append: bool = False
 ):
@@ -280,25 +247,25 @@ def test_ener(
         ae = ae.reshape([numb_test, -1])
         av = av.reshape([numb_test, -1])
 
-    l2e = l2err(energy - test_data["energy"][:numb_test].reshape([-1, 1]))
-    l2f = l2err(force - test_data["force"][:numb_test])
-    l2v = l2err(virial - test_data["virial"][:numb_test])
-    l2ea = l2e / natoms
-    l2va = l2v / natoms
+    rmse_e = rmse(energy - test_data["energy"][:numb_test].reshape([-1, 1]))
+    rmse_f = rmse(force - test_data["force"][:numb_test])
+    rmse_v = rmse(virial - test_data["virial"][:numb_test])
+    rmse_ea = rmse_e / natoms
+    rmse_va = rmse_v / natoms
     if has_atom_ener:
-        l2ae = l2err(
+        rmse_ae = rmse(
             test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1])
         )
 
     # print ("# energies: %s" % energy)
     log.info(f"# number of test data : {numb_test:d} ")
-    log.info(f"Energy RMSE : {l2e:e} eV")
-    log.info(f"Energy RMSE/Natoms : {l2ea:e} eV")
-    log.info(f"Force RMSE : {l2f:e} eV/A")
-    log.info(f"Virial RMSE : {l2v:e} eV")
-    log.info(f"Virial RMSE/Natoms : {l2va:e} eV")
+    log.info(f"Energy RMSE : {rmse_e:e} eV")
+    log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV")
+    log.info(f"Force RMSE : {rmse_f:e} eV/A")
+    log.info(f"Virial RMSE : {rmse_v:e} eV")
+    log.info(f"Virial RMSE/Natoms : {rmse_va:e} eV")
     if has_atom_ener:
-        log.info(f"Atomic ener RMSE : {l2ae:e} eV")
+        log.info(f"Atomic ener RMSE : {rmse_ae:e} eV")
 
     if detail_file is not None:
         detail_path = Path(detail_file)
@@ -344,10 +311,14 @@ def test_ener(
             "pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz",
             append=append_detail,
         )
-    return [l2ea, l2f, l2va], [energy.size, force.size, virial.size]
+    return {
+        "rmse_ea": (rmse_ea, energy.size),
+        "rmse_f": (rmse_f, force.size),
+        "rmse_va": (rmse_va, virial.size),
+    }
 
 
-def print_ener_sys_avg(avg: np.ndarray):
+def print_ener_sys_avg(avg: Dict[str, float]):
     """Print errors summary for energy type potential.
 
     Parameters
@@ -355,9 +326,9 @@ def print_ener_sys_avg(avg: np.ndarray):
-    avg : np.ndarray
-        array with summaries
+    avg : Dict[str, float]
+        averaged errors; the "rmse_ea", "rmse_f" and "rmse_va" entries are printed
     """
-    log.info(f"Energy RMSE/Natoms : {avg[0]:e} eV")
-    log.info(f"Force RMSE : {avg[1]:e} eV/A")
-    log.info(f"Virial RMSE/Natoms : {avg[2]:e} eV")
+    log.info(f"Energy RMSE/Natoms : {avg['rmse_ea']:e} eV")
+    log.info(f"Force RMSE : {avg['rmse_f']:e} eV/A")
+    log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV")
 
 
 def run_test(dp: "DeepTensor", test_data: dict, numb_test: int):
@@ -417,10 +388,10 @@ def test_wfc(
     )
     test_data = data.get_test()
     wfc, numb_test, _ = run_test(dp, test_data, numb_test)
-    l2f = l2err(wfc - test_data["wfc"][:numb_test])
+    rmse_f = rmse(wfc - test_data["wfc"][:numb_test])
 
-    log.info("# number of test data : {numb_test:d} ")
-    log.info("WFC RMSE : {l2f:e} eV/A")
+    log.info(f"# number of test data : {numb_test:d} ")
+    log.info(f"WFC RMSE : {rmse_f:e} eV/A")
 
     if detail_file is not None:
         detail_path = Path(detail_file)
@@ -436,7 +407,9 @@ def test_wfc(
             pe,
             header="ref_wfc(12 dofs) predicted_wfc(12 dofs)",
         )
-    return [l2f], [wfc.size]
+    return {
+        "rmse": (rmse_f, wfc.size)
+    }
 
 
 def print_wfc_sys_avg(avg):
@@ -447,7 +420,7 @@ def print_wfc_sys_avg(avg):
-    avg : np.ndarray
-        array with summaries
+    avg : Dict[str, float]
+        averaged errors; the "rmse" entry is printed
     """
-    log.info(f"WFC RMSE : {avg[0]:e} eV/A")
+    log.info(f"WFC RMSE : {avg['rmse']:e} eV/A")
 
 
 def test_polar(
@@ -504,15 +477,15 @@ def test_polar(
     for ii in sel_type:
         sel_natoms += sum(atype == ii)
 
-    l2f = l2err(polar - test_data["polarizability"][:numb_test])
-    l2fs = l2f / np.sqrt(sel_natoms)
-    l2fa = l2f / sel_natoms
+    rmse_f = rmse(polar - test_data["polarizability"][:numb_test])
+    rmse_fs = rmse_f / np.sqrt(sel_natoms)
+    rmse_fa = rmse_f / sel_natoms
 
     log.info(f"# number of test data : {numb_test:d} ")
-    log.info(f"Polarizability RMSE : {l2f:e} eV/A")
+    log.info(f"Polarizability RMSE : {rmse_f:e} eV/A")
     if global_polar:
-        log.info(f"Polarizability RMSE/sqrtN : {l2fs:e} eV/A")
-        log.info(f"Polarizability RMSE/N : {l2fa:e} eV/A")
+        log.info(f"Polarizability RMSE/sqrtN : {rmse_fs:e} eV/A")
+        log.info(f"Polarizability RMSE/N : {rmse_fa:e} eV/A")
 
     if detail_file is not None:
         detail_path = Path(detail_file)
@@ -531,7 +504,9 @@ def test_polar(
             "data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz "
             "pred_pzx pred_pzy pred_pzz",
         )
-    return [l2f], [polar.size]
+    return {
+        "rmse": (rmse_f, polar.size)
+    }
 
 
 def print_polar_sys_avg(avg):
@@ -542,7 +517,7 @@ def print_polar_sys_avg(avg):
-    avg : np.ndarray
-        array with summaries
+    avg : Dict[str, float]
+        averaged errors; the "rmse" entry is printed
     """
-    log.info(f"Polarizability RMSE : {avg[0]:e} eV/A")
+    log.info(f"Polarizability RMSE : {avg['rmse']:e} eV/A")
 
 
 def test_dipole(
@@ -584,13 +559,13 @@ def test_dipole(
         atoms = dipole.shape[1]
         dipole = np.sum(dipole,axis=1)
 
-    l2f = l2err(dipole - test_data["dipole"][:numb_test])
+    rmse_f = rmse(dipole - test_data["dipole"][:numb_test])
 
     if has_atom_dipole == False:
-        l2f = l2f / atoms
+        rmse_f = rmse_f / atoms
 
     log.info(f"# number of test data : {numb_test:d}")
-    log.info(f"Dipole RMSE : {l2f:e} eV/A")
+    log.info(f"Dipole RMSE : {rmse_f:e} eV/A")
 
     if detail_file is not None:
         detail_path = Path(detail_file)
@@ -607,7 +582,9 @@ def test_dipole(
             pe,
             header="data_x data_y data_z pred_x pred_y pred_z",
         )
-    return [l2f], [dipole.size]
+    return {
+        "rmse": (rmse_f, dipole.size)
+    }
 
 
 def print_dipole_sys_avg(avg):
@@ -618,4 +595,4 @@ def print_dipole_sys_avg(avg):
-    avg : np.ndarray
-        array with summaries
+    avg : Dict[str, float]
+        averaged errors; the "rmse" entry is printed
     """
-    log.info(f"Dipole RMSE : {avg[0]:e} eV/A")
+    log.info(f"Dipole RMSE : {avg['rmse']:e} eV/A")
diff --git a/deepmd/utils/weight_avg.py b/deepmd/utils/weight_avg.py
new file mode 100644
index 0000000000..aec5026ae4
--- /dev/null
+++ b/deepmd/utils/weight_avg.py
@@ -0,0 +1,37 @@
+from typing import TYPE_CHECKING, List, Dict, Optional, Tuple
+import numpy as np
+
+
+def weighted_average(
+    errors: List[Dict[str, Tuple[float, float]]]
+) -> Dict[str, float]:
+    """Compute weighted average of prediction errors for model.
+
+    For every quantity the squared errors are accumulated with their weights
+    and the square root of the weighted mean square is returned.
+
+    Parameters
+    ----------
+    errors : List[Dict[str, Tuple[float, float]]]
+        List: the error of systems
+        Dict: the error of quantities, name given by the key
+        Tuple: (error, weight)
+
+    Returns
+    -------
+    Dict[str, float]
+        weighted averages
+    """
+    sum_err = {}
+    sum_siz = {}
+    for err in errors:
+        for kk, (ee, ss) in err.items():
+            if kk in sum_err:
+                sum_err[kk] += ee * ee * ss
+                sum_siz[kk] += ss
+            else:
+                sum_err[kk] = ee * ee * ss
+                sum_siz[kk] = ss
+    for kk in sum_err:
+        sum_err[kk] = np.sqrt(sum_err[kk] / sum_siz[kk])
+    return sum_err