Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 53 additions & 78 deletions deepmd/entrypoints/test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Test trained DeePMD model."""
import logging
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Dict, Optional, Tuple

import numpy as np
from deepmd import DeepPotential
from deepmd.common import expand_sys_str
from deepmd.utils.data import DeepmdData
from deepmd.utils.weight_avg import weighted_average

if TYPE_CHECKING:
from deepmd.infer import DeepDipole, DeepPolar, DeepPot, DeepWFC
Expand Down Expand Up @@ -77,7 +78,7 @@ def test(
data = DeepmdData(system, set_prefix, shuffle_test=shuffle_test, type_map=tmap)

if dp.model_type == "ener":
err, siz = test_ener(
err = test_ener(
dp,
data,
system,
Expand All @@ -87,18 +88,15 @@ def test(
append_detail=(cc != 0),
)
elif dp.model_type == "dipole":
err, siz = test_dipole(dp, data, numb_test, detail_file, atomic)
err = test_dipole(dp, data, numb_test, detail_file, atomic)
elif dp.model_type == "polar":
err, siz = test_polar(dp, data, numb_test, detail_file, global_polar=False)
err = test_polar(dp, data, numb_test, detail_file, global_polar=False)
elif dp.model_type == "global_polar":
err, siz = test_polar(dp, data, numb_test, detail_file, global_polar=True)
elif dp.model_type == "wfc":
err, siz = test_wfc(dp, data, numb_test, detail_file)
err = test_polar(dp, data, numb_test, detail_file, global_polar=True)
log.info("# ----------------------------------------------- ")
err_coll.append(err)
siz_coll.append(siz)

avg_err = weighted_average(err_coll, siz_coll)
avg_err = weighted_average(err_coll)

if len(all_sys) != len(err_coll):
log.warning("Not all systems are tested! Check if the systems are valid")
Expand All @@ -119,8 +117,8 @@ def test(
log.info("# ----------------------------------------------- ")


def l2err(diff: np.ndarray) -> np.ndarray:
"""Calculate average l2 norm error.
def rmse(diff: np.ndarray) -> np.ndarray:
"""Calculate average root mean square error.

Parameters
----------
Expand All @@ -135,39 +133,6 @@ def l2err(diff: np.ndarray) -> np.ndarray:
return np.sqrt(np.average(diff * diff))


def weighted_average(
err_coll: List[List[np.ndarray]], siz_coll: List[List[int]]
) -> np.ndarray:
"""Compute wighted average of prediction errors for model.

Parameters
----------
err_coll : List[List[np.ndarray]]
each item in list represents erros for one model
siz_coll : List[List[int]]
weight for each model errors

Returns
-------
np.ndarray
weighted averages
"""
assert len(err_coll) == len(siz_coll)

nitems = len(err_coll[0])
sum_err = np.zeros(nitems)
sum_siz = np.zeros(nitems)
for sys_error, sys_size in zip(err_coll, siz_coll):
for ii in range(nitems):
ee = sys_error[ii]
ss = sys_size[ii]
sum_err[ii] += ee * ee * ss
sum_siz[ii] += ss
for ii in range(nitems):
sum_err[ii] = np.sqrt(sum_err[ii] / sum_siz[ii])
return sum_err


def save_txt_file(
fname: Path, data: np.ndarray, header: str = "", append: bool = False
):
Expand Down Expand Up @@ -280,25 +245,25 @@ def test_ener(
ae = ae.reshape([numb_test, -1])
av = av.reshape([numb_test, -1])

l2e = l2err(energy - test_data["energy"][:numb_test].reshape([-1, 1]))
l2f = l2err(force - test_data["force"][:numb_test])
l2v = l2err(virial - test_data["virial"][:numb_test])
l2ea = l2e / natoms
l2va = l2v / natoms
rmse_e = rmse(energy - test_data["energy"][:numb_test].reshape([-1, 1]))
rmse_f = rmse(force - test_data["force"][:numb_test])
rmse_v = rmse(virial - test_data["virial"][:numb_test])
rmse_ea = rmse_e / natoms
rmse_va = rmse_v / natoms
if has_atom_ener:
l2ae = l2err(
rmse_ae = rmse(
test_data["atom_ener"][:numb_test].reshape([-1]) - ae.reshape([-1])
)

# print ("# energies: %s" % energy)
log.info(f"# number of test data : {numb_test:d} ")
log.info(f"Energy RMSE : {l2e:e} eV")
log.info(f"Energy RMSE/Natoms : {l2ea:e} eV")
log.info(f"Force RMSE : {l2f:e} eV/A")
log.info(f"Virial RMSE : {l2v:e} eV")
log.info(f"Virial RMSE/Natoms : {l2va:e} eV")
log.info(f"Energy RMSE : {rmse_e:e} eV")
log.info(f"Energy RMSE/Natoms : {rmse_ea:e} eV")
log.info(f"Force RMSE : {rmse_f:e} eV/A")
log.info(f"Virial RMSE : {rmse_v:e} eV")
log.info(f"Virial RMSE/Natoms : {rmse_va:e} eV")
if has_atom_ener:
log.info(f"Atomic ener RMSE : {l2ae:e} eV")
log.info(f"Atomic ener RMSE : {rmse_ae:e} eV")

if detail_file is not None:
detail_path = Path(detail_file)
Expand Down Expand Up @@ -344,20 +309,24 @@ def test_ener(
"pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz",
append=append_detail,
)
return [l2ea, l2f, l2va], [energy.size, force.size, virial.size]
return {
"rmse_ea" : (rmse_ea, energy.size),
"rmse_f" : (rmse_f, force.size),
"rmse_va" : (rmse_va, virial.size),
}


def print_ener_sys_avg(avg: np.ndarray):
def print_ener_sys_avg(avg: Dict[str,float]):
"""Print errors summary for energy type potential.

Parameters
----------
avg : np.ndarray
array with summaries
"""
log.info(f"Energy RMSE/Natoms : {avg[0]:e} eV")
log.info(f"Force RMSE : {avg[1]:e} eV/A")
log.info(f"Virial RMSE/Natoms : {avg[2]:e} eV")
log.info(f"Energy RMSE/Natoms : {avg['rmse_ea']:e} eV")
log.info(f"Force RMSE : {avg['rmse_f']:e} eV/A")
log.info(f"Virial RMSE/Natoms : {avg['rmse_va']:e} eV")


def run_test(dp: "DeepTensor", test_data: dict, numb_test: int):
Expand Down Expand Up @@ -417,10 +386,10 @@ def test_wfc(
)
test_data = data.get_test()
wfc, numb_test, _ = run_test(dp, test_data, numb_test)
l2f = l2err(wfc - test_data["wfc"][:numb_test])
rmse_f = rmse(wfc - test_data["wfc"][:numb_test])

log.info("# number of test data : {numb_test:d} ")
log.info("WFC RMSE : {l2f:e} eV/A")
log.info("WFC RMSE : {rmse_f:e} eV/A")

if detail_file is not None:
detail_path = Path(detail_file)
Expand All @@ -436,7 +405,9 @@ def test_wfc(
pe,
header="ref_wfc(12 dofs) predicted_wfc(12 dofs)",
)
return [l2f], [wfc.size]
return {
'rmse' : (rmse_f, wfc.size)
}


def print_wfc_sys_avg(avg):
Expand All @@ -447,7 +418,7 @@ def print_wfc_sys_avg(avg):
avg : np.ndarray
array with summaries
"""
log.info(f"WFC RMSE : {avg[0]:e} eV/A")
log.info(f"WFC RMSE : {avg['rmse']:e} eV/A")


def test_polar(
Expand Down Expand Up @@ -504,15 +475,15 @@ def test_polar(
for ii in sel_type:
sel_natoms += sum(atype == ii)

l2f = l2err(polar - test_data["polarizability"][:numb_test])
l2fs = l2f / np.sqrt(sel_natoms)
l2fa = l2f / sel_natoms
rmse_f = rmse(polar - test_data["polarizability"][:numb_test])
rmse_fs = rmse_f / np.sqrt(sel_natoms)
rmse_fa = rmse_f / sel_natoms

log.info(f"# number of test data : {numb_test:d} ")
log.info(f"Polarizability RMSE : {l2f:e} eV/A")
log.info(f"Polarizability RMSE : {rmse_f:e} eV/A")
if global_polar:
log.info(f"Polarizability RMSE/sqrtN : {l2fs:e} eV/A")
log.info(f"Polarizability RMSE/N : {l2fa:e} eV/A")
log.info(f"Polarizability RMSE/sqrtN : {rmse_fs:e} eV/A")
log.info(f"Polarizability RMSE/N : {rmse_fa:e} eV/A")

if detail_file is not None:
detail_path = Path(detail_file)
Expand All @@ -531,7 +502,9 @@ def test_polar(
"data_pzy data_pzz pred_pxx pred_pxy pred_pxz pred_pyx pred_pyy pred_pyz "
"pred_pzx pred_pzy pred_pzz",
)
return [l2f], [polar.size]
return {
"rmse" : (rmse_f, polar.size)
}


def print_polar_sys_avg(avg):
Expand All @@ -542,7 +515,7 @@ def print_polar_sys_avg(avg):
avg : np.ndarray
array with summaries
"""
log.info(f"Polarizability RMSE : {avg[0]:e} eV/A")
log.info(f"Polarizability RMSE : {avg['rmse']:e} eV/A")


def test_dipole(
Expand Down Expand Up @@ -584,13 +557,13 @@ def test_dipole(
atoms = dipole.shape[1]
dipole = np.sum(dipole,axis=1)

l2f = l2err(dipole - test_data["dipole"][:numb_test])
rmse_f = rmse(dipole - test_data["dipole"][:numb_test])

if has_atom_dipole == False:
l2f = l2f / atoms
rmse_f = rmse_f / atoms

log.info(f"# number of test data : {numb_test:d}")
log.info(f"Dipole RMSE : {l2f:e} eV/A")
log.info(f"Dipole RMSE : {rmse_f:e} eV/A")

if detail_file is not None:
detail_path = Path(detail_file)
Expand All @@ -607,7 +580,9 @@ def test_dipole(
pe,
header="data_x data_y data_z pred_x pred_y pred_z",
)
return [l2f], [dipole.size]
return {
'rmse' : (rmse_f, dipole.size)
}


def print_dipole_sys_avg(avg):
Expand All @@ -618,4 +593,4 @@ def print_dipole_sys_avg(avg):
avg : np.ndarray
array with summaries
"""
log.info(f"Dipole RMSE : {avg[0]:e} eV/A")
log.info(f"Dipole RMSE : {avg['rmse']:e} eV/A")
34 changes: 34 additions & 0 deletions deepmd/utils/weight_avg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import TYPE_CHECKING, List, Dict, Optional, Tuple
import numpy as np


def weighted_average(
errors: List[Dict[str, Tuple[float, float]]]
) -> Dict:
"""Compute wighted average of prediction errors for model.

Parameters
----------
errors : List[Dict[str, Tuple[float, float]]]
List: the error of systems
Dict: the error of quantities, name given by the key
Tuple: (error, weight)

Returns
-------
Dict
weighted averages
"""
sum_err = {}
sum_siz = {}
for err in errors:
for kk, (ee, ss) in err.items():
if kk in sum_err:
sum_err[kk] += ee * ee * ss
sum_siz[kk] += ss
else :
sum_err[kk] = ee * ee * ss
sum_siz[kk] = ss
for kk in sum_err.keys():
sum_err[kk] = np.sqrt(sum_err[kk] / sum_siz[kk])
return sum_err