Skip to content

Commit 73482ac

Browse files
author
Han Wang
committed
add UT for type_one_side
1 parent 4bf3545 commit 73482ac

File tree

2 files changed

+113
-4
lines changed

2 files changed

+113
-4
lines changed

source/tests/test_descrpt_se_a_type.py

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class TestModel(unittest.TestCase):
2020
def setUp(self) :
2121
gen_data()
2222

23-
def test_model(self):
23+
def test_descriptor_two_sides(self):
2424
jfile = 'water_se_a_type.json'
2525
jdata = j_loader(jfile)
2626

@@ -82,7 +82,7 @@ def test_model(self):
8282

8383
type_embedding = typeebd.build(
8484
ntypes,
85-
suffix = "_se_a_type_des_ebd"
85+
suffix = "_se_a_type_des_ebd_2sdies"
8686
)
8787

8888
dout \
@@ -94,7 +94,7 @@ def test_model(self):
9494
t_mesh,
9595
{'type_embedding' : type_embedding},
9696
reuse = False,
97-
suffix = "_se_a_type_des"
97+
suffix = "_se_a_type_des_2sides"
9898
)
9999

100100
feed_dict_test = {t_prop_c: test_data['prop_c'],
@@ -126,6 +126,115 @@ def test_model(self):
126126
for ii in range(model_dout.size) :
127127
self.assertAlmostEqual(model_dout[ii], ref_dout[ii], places = places)
128128

129+
130+
def test_descriptor_one_side(self):
131+
jfile = 'water_se_a_type.json'
132+
jdata = j_loader(jfile)
133+
134+
systems = j_must_have(jdata, 'systems')
135+
set_pfx = j_must_have(jdata, 'set_prefix')
136+
batch_size = j_must_have(jdata, 'batch_size')
137+
test_size = j_must_have(jdata, 'numb_test')
138+
batch_size = 1
139+
test_size = 1
140+
stop_batch = j_must_have(jdata, 'stop_batch')
141+
rcut = j_must_have (jdata['model']['descriptor'], 'rcut')
142+
sel = j_must_have (jdata['model']['descriptor'], 'sel')
143+
ntypes=len(sel)
144+
145+
data = DataSystem(systems, set_pfx, batch_size, test_size, rcut, run_opt = None)
146+
147+
test_data = data.get_test ()
148+
numb_test = 1
149+
150+
# set parameters
151+
jdata['model']['descriptor']['neuron'] = [5, 5, 5]
152+
jdata['model']['descriptor']['axis_neuron'] = 2
153+
jdata['model']['descriptor']['type_one_side'] = True
154+
typeebd_param = {'neuron' : [5, 5, 5],
155+
'resnet_dt': False,
156+
'seed': 1,
157+
}
158+
159+
# init models
160+
typeebd = TypeEmbedNet(
161+
neuron = typeebd_param['neuron'],
162+
resnet_dt = typeebd_param['resnet_dt'],
163+
seed = typeebd_param['seed'],
164+
)
165+
166+
jdata['model']['descriptor'].pop('type', None)
167+
descrpt = DescrptSeA(**jdata['model']['descriptor'])
168+
169+
# model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']])
170+
input_data = {'coord' : [test_data['coord']],
171+
'box': [test_data['box']],
172+
'type': [test_data['type']],
173+
'natoms_vec' : [test_data['natoms_vec']],
174+
'default_mesh' : [test_data['default_mesh']]
175+
}
176+
descrpt.bias_atom_e = data.compute_energy_shift()
177+
178+
t_prop_c = tf.placeholder(tf.float32, [5], name='t_prop_c')
179+
t_energy = tf.placeholder(GLOBAL_ENER_FLOAT_PRECISION, [None], name='t_energy')
180+
t_force = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name='t_force')
181+
t_virial = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name='t_virial')
182+
t_atom_ener = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name='t_atom_ener')
183+
t_coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None], name='i_coord')
184+
t_type = tf.placeholder(tf.int32, [None], name='i_type')
185+
t_natoms = tf.placeholder(tf.int32, [ntypes+2], name='i_natoms')
186+
t_box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name='i_box')
187+
t_mesh = tf.placeholder(tf.int32, [None], name='i_mesh')
188+
is_training = tf.placeholder(tf.bool)
189+
t_fparam = None
190+
191+
type_embedding = typeebd.build(
192+
ntypes,
193+
suffix = "_se_a_type_des_ebd_1side"
194+
)
195+
196+
dout \
197+
= descrpt.build(
198+
t_coord,
199+
t_type,
200+
t_natoms,
201+
t_box,
202+
t_mesh,
203+
{'type_embedding' : type_embedding},
204+
reuse = False,
205+
suffix = "_se_a_type_des_1side"
206+
)
207+
208+
feed_dict_test = {t_prop_c: test_data['prop_c'],
209+
t_energy: test_data['energy'] [:numb_test],
210+
t_force: np.reshape(test_data['force'] [:numb_test, :], [-1]),
211+
t_virial: np.reshape(test_data['virial'] [:numb_test, :], [-1]),
212+
t_atom_ener: np.reshape(test_data['atom_ener'][:numb_test, :], [-1]),
213+
t_coord: np.reshape(test_data['coord'] [:numb_test, :], [-1]),
214+
t_box: test_data['box'] [:numb_test, :],
215+
t_type: np.reshape(test_data['type'] [:numb_test, :], [-1]),
216+
t_natoms: test_data['natoms_vec'],
217+
t_mesh: test_data['default_mesh'],
218+
is_training: False}
219+
220+
sess = tf.Session()
221+
sess.run(tf.global_variables_initializer())
222+
[model_dout] = sess.run([dout],
223+
feed_dict = feed_dict_test)
224+
model_dout = model_dout.reshape([-1])
225+
226+
ref_dout = [0.0009704469114440277,0.0007136310372560243,0.0007136310372560243,0.000524968274824758,-0.0019790100690810016,-0.0014556100390424947,-0.001318691223889266,-0.0009698525512440269,0.001937780602605409,
227+
0.0014251755182315322,0.0008158935519461114,0.0005943870925895051,0.0005943870925895051,0.0004340263490412088,-0.0016539827195947239,-0.0012066241021841376,-0.0011042186455562336,-0.0008051343572505189,
228+
0.0016229491738044255,0.0011833923257801077,0.0006020440527161554,0.00047526899287409847,0.00047526899287409847,0.00037538142786805136,-0.0012811397377036637,-0.0010116898098710776,-0.0008465095301785942,
229+
-0.0006683577463042215,0.0012459039620461505,0.0009836962283627838,0.00077088529431722,0.0006105807630364827,0.0006105807630364827,0.00048361458700877996,-0.0016444700616024337,-0.001302510079662288,
230+
-0.0010856603485807576,-0.0008598975276238373,0.00159730642327918,0.001265146946434076,0.0008495806081447204,0.000671787466824433,0.000671787466824433,0.0005312928157964384,-0.0018105890543181475,
231+
-0.001431844407277983,-0.0011956722392735362,-0.000945544277375045,0.0017590147511761475,0.0013910348287283414,0.0007393644735054756,0.0005850536182149991,0.0005850536182149991,0.0004631887654949332,
232+
-0.0015760302086346792,-0.0012475134925387294,-0.001041074331192672,-0.0008239586048523492,0.0015319673563669856,0.0012124704278707746]
233+
234+
places = 10
235+
for ii in range(model_dout.size) :
236+
self.assertAlmostEqual(model_dout[ii], ref_dout[ii], places = places)
237+
129238

130239

131240

source/tests/test_fitting_ener_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TestModel(unittest.TestCase):
1919
def setUp(self) :
2020
gen_data()
2121

22-
def test_model(self):
22+
def test_fitting(self):
2323
jfile = 'water_se_a_type.json'
2424
jdata = j_loader(jfile)
2525

0 commit comments

Comments
 (0)