@@ -20,7 +20,7 @@ class TestModel(unittest.TestCase):
2020 def setUp (self ) :
2121 gen_data ()
2222
23- def test_model (self ):
23+ def test_descriptor_two_sides (self ):
2424 jfile = 'water_se_a_type.json'
2525 jdata = j_loader (jfile )
2626
@@ -82,7 +82,7 @@ def test_model(self):
8282
8383 type_embedding = typeebd .build (
8484 ntypes ,
85- suffix = "_se_a_type_des_ebd "
85+ suffix = "_se_a_type_des_ebd_2sdies "
8686 )
8787
8888 dout \
@@ -94,7 +94,7 @@ def test_model(self):
9494 t_mesh ,
9595 {'type_embedding' : type_embedding },
9696 reuse = False ,
97- suffix = "_se_a_type_des "
97+ suffix = "_se_a_type_des_2sides "
9898 )
9999
100100 feed_dict_test = {t_prop_c : test_data ['prop_c' ],
@@ -126,6 +126,115 @@ def test_model(self):
126126 for ii in range (model_dout .size ) :
127127 self .assertAlmostEqual (model_dout [ii ], ref_dout [ii ], places = places )
128128
129+
130+ def test_descriptor_one_side (self ):
131+ jfile = 'water_se_a_type.json'
132+ jdata = j_loader (jfile )
133+
134+ systems = j_must_have (jdata , 'systems' )
135+ set_pfx = j_must_have (jdata , 'set_prefix' )
136+ batch_size = j_must_have (jdata , 'batch_size' )
137+ test_size = j_must_have (jdata , 'numb_test' )
138+ batch_size = 1
139+ test_size = 1
140+ stop_batch = j_must_have (jdata , 'stop_batch' )
141+ rcut = j_must_have (jdata ['model' ]['descriptor' ], 'rcut' )
142+ sel = j_must_have (jdata ['model' ]['descriptor' ], 'sel' )
143+ ntypes = len (sel )
144+
145+ data = DataSystem (systems , set_pfx , batch_size , test_size , rcut , run_opt = None )
146+
147+ test_data = data .get_test ()
148+ numb_test = 1
149+
150+ # set parameters
151+ jdata ['model' ]['descriptor' ]['neuron' ] = [5 , 5 , 5 ]
152+ jdata ['model' ]['descriptor' ]['axis_neuron' ] = 2
153+ jdata ['model' ]['descriptor' ]['type_one_side' ] = True
154+ typeebd_param = {'neuron' : [5 , 5 , 5 ],
155+ 'resnet_dt' : False ,
156+ 'seed' : 1 ,
157+ }
158+
159+ # init models
160+ typeebd = TypeEmbedNet (
161+ neuron = typeebd_param ['neuron' ],
162+ resnet_dt = typeebd_param ['resnet_dt' ],
163+ seed = typeebd_param ['seed' ],
164+ )
165+
166+ jdata ['model' ]['descriptor' ].pop ('type' , None )
167+ descrpt = DescrptSeA (** jdata ['model' ]['descriptor' ])
168+
169+ # model._compute_dstats([test_data['coord']], [test_data['box']], [test_data['type']], [test_data['natoms_vec']], [test_data['default_mesh']])
170+ input_data = {'coord' : [test_data ['coord' ]],
171+ 'box' : [test_data ['box' ]],
172+ 'type' : [test_data ['type' ]],
173+ 'natoms_vec' : [test_data ['natoms_vec' ]],
174+ 'default_mesh' : [test_data ['default_mesh' ]]
175+ }
176+ descrpt .bias_atom_e = data .compute_energy_shift ()
177+
178+ t_prop_c = tf .placeholder (tf .float32 , [5 ], name = 't_prop_c' )
179+ t_energy = tf .placeholder (GLOBAL_ENER_FLOAT_PRECISION , [None ], name = 't_energy' )
180+ t_force = tf .placeholder (GLOBAL_TF_FLOAT_PRECISION , [None ], name = 't_force' )
181+ t_virial = tf .placeholder (GLOBAL_TF_FLOAT_PRECISION , [None ], name = 't_virial' )
182+ t_atom_ener = tf .placeholder (GLOBAL_TF_FLOAT_PRECISION , [None ], name = 't_atom_ener' )
183+ t_coord = tf .placeholder (GLOBAL_TF_FLOAT_PRECISION , [None ], name = 'i_coord' )
184+ t_type = tf .placeholder (tf .int32 , [None ], name = 'i_type' )
185+ t_natoms = tf .placeholder (tf .int32 , [ntypes + 2 ], name = 'i_natoms' )
186+ t_box = tf .placeholder (GLOBAL_TF_FLOAT_PRECISION , [None , 9 ], name = 'i_box' )
187+ t_mesh = tf .placeholder (tf .int32 , [None ], name = 'i_mesh' )
188+ is_training = tf .placeholder (tf .bool )
189+ t_fparam = None
190+
191+ type_embedding = typeebd .build (
192+ ntypes ,
193+ suffix = "_se_a_type_des_ebd_1side"
194+ )
195+
196+ dout \
197+ = descrpt .build (
198+ t_coord ,
199+ t_type ,
200+ t_natoms ,
201+ t_box ,
202+ t_mesh ,
203+ {'type_embedding' : type_embedding },
204+ reuse = False ,
205+ suffix = "_se_a_type_des_1side"
206+ )
207+
208+ feed_dict_test = {t_prop_c : test_data ['prop_c' ],
209+ t_energy : test_data ['energy' ] [:numb_test ],
210+ t_force : np .reshape (test_data ['force' ] [:numb_test , :], [- 1 ]),
211+ t_virial : np .reshape (test_data ['virial' ] [:numb_test , :], [- 1 ]),
212+ t_atom_ener : np .reshape (test_data ['atom_ener' ][:numb_test , :], [- 1 ]),
213+ t_coord : np .reshape (test_data ['coord' ] [:numb_test , :], [- 1 ]),
214+ t_box : test_data ['box' ] [:numb_test , :],
215+ t_type : np .reshape (test_data ['type' ] [:numb_test , :], [- 1 ]),
216+ t_natoms : test_data ['natoms_vec' ],
217+ t_mesh : test_data ['default_mesh' ],
218+ is_training : False }
219+
220+ sess = tf .Session ()
221+ sess .run (tf .global_variables_initializer ())
222+ [model_dout ] = sess .run ([dout ],
223+ feed_dict = feed_dict_test )
224+ model_dout = model_dout .reshape ([- 1 ])
225+
226+ ref_dout = [0.0009704469114440277 ,0.0007136310372560243 ,0.0007136310372560243 ,0.000524968274824758 ,- 0.0019790100690810016 ,- 0.0014556100390424947 ,- 0.001318691223889266 ,- 0.0009698525512440269 ,0.001937780602605409 ,
227+ 0.0014251755182315322 ,0.0008158935519461114 ,0.0005943870925895051 ,0.0005943870925895051 ,0.0004340263490412088 ,- 0.0016539827195947239 ,- 0.0012066241021841376 ,- 0.0011042186455562336 ,- 0.0008051343572505189 ,
228+ 0.0016229491738044255 ,0.0011833923257801077 ,0.0006020440527161554 ,0.00047526899287409847 ,0.00047526899287409847 ,0.00037538142786805136 ,- 0.0012811397377036637 ,- 0.0010116898098710776 ,- 0.0008465095301785942 ,
229+ - 0.0006683577463042215 ,0.0012459039620461505 ,0.0009836962283627838 ,0.00077088529431722 ,0.0006105807630364827 ,0.0006105807630364827 ,0.00048361458700877996 ,- 0.0016444700616024337 ,- 0.001302510079662288 ,
230+ - 0.0010856603485807576 ,- 0.0008598975276238373 ,0.00159730642327918 ,0.001265146946434076 ,0.0008495806081447204 ,0.000671787466824433 ,0.000671787466824433 ,0.0005312928157964384 ,- 0.0018105890543181475 ,
231+ - 0.001431844407277983 ,- 0.0011956722392735362 ,- 0.000945544277375045 ,0.0017590147511761475 ,0.0013910348287283414 ,0.0007393644735054756 ,0.0005850536182149991 ,0.0005850536182149991 ,0.0004631887654949332 ,
232+ - 0.0015760302086346792 ,- 0.0012475134925387294 ,- 0.001041074331192672 ,- 0.0008239586048523492 ,0.0015319673563669856 ,0.0012124704278707746 ]
233+
234+ places = 10
235+ for ii in range (model_dout .size ) :
236+ self .assertAlmostEqual (model_dout [ii ], ref_dout [ii ], places = places )
237+
129238
130239
131240
0 commit comments