Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions source/train/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ def test (args):
dp = DeepWFC(args.model)
else :
raise RuntimeError('unknow model type '+de.model_type)
for ii in all_sys:
for cc,ii in enumerate(all_sys):
args.system = ii
print ("# ---------------output of dp test--------------- ")
print ("# testing system : " + ii)
if de.model_type == 'ener':
err, siz = test_ener(dp, args)
err, siz = test_ener(dp, args, append_detail = (cc!=0))
elif de.model_type == 'dipole':
err, siz = test_dipole(dp, args)
elif de.model_type == 'polar':
Expand Down Expand Up @@ -89,7 +89,14 @@ def weighted_average(err_coll, siz_coll):
return sum_err


def test_ener (dp, args) :
def save_txt_file(fname, data, header = "", append = False):
fp = fname
if append : fp = open(fp, 'ab')
np.savetxt(fp, data, header = header)
if append : fp.close()


def test_ener (dp, args, append_detail = False) :
if args.rand_seed is not None :
np.random.seed(args.rand_seed % (2**32))

Expand Down Expand Up @@ -122,10 +129,7 @@ def test_ener (dp, args) :
else :
aparam = None
detail_file = args.detail_file
if detail_file is not None:
atomic = True
else:
atomic = False
atomic = False

ret = dp.eval(coord, box, atype, fparam = fparam, aparam = aparam, atomic = atomic)
energy = ret[0]
Expand Down Expand Up @@ -158,18 +162,21 @@ def test_ener (dp, args) :
pe = np.concatenate((np.reshape(test_data["energy"][:numb_test], [-1,1]),
np.reshape(energy, [-1,1])),
axis = 1)
np.savetxt(detail_file+".e.out", pe,
header = 'data_e pred_e')
save_txt_file(detail_file+".e.out", pe,
header = '%s: data_e pred_e' % args.system,
append = append_detail)
pf = np.concatenate((np.reshape(test_data["force"] [:numb_test], [-1,3]),
np.reshape(force, [-1,3])),
axis = 1)
np.savetxt(detail_file+".f.out", pf,
header = 'data_fx data_fy data_fz pred_fx pred_fy pred_fz')
save_txt_file(detail_file+".f.out", pf,
header = '%s: data_fx data_fy data_fz pred_fx pred_fy pred_fz' % args.system,
append = append_detail)
pv = np.concatenate((np.reshape(test_data["virial"][:numb_test], [-1,9]),
np.reshape(virial, [-1,9])),
axis = 1)
np.savetxt(detail_file+".v.out", pv,
header = 'data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz')
save_txt_file(detail_file+".v.out", pv,
header = '%s: data_vxx data_vxy data_vxz data_vyx data_vyy data_vyz data_vzx data_vzy data_vzz pred_vxx pred_vxy pred_vxz pred_vyx pred_vyy pred_vyz pred_vzx pred_vzy pred_vzz' % args.system,
append = append_detail)
return [l2ea, l2f, l2va], [energy.size, force.size, virial.size]


Expand Down