diff --git a/source/tests/test_activation_fn_gelu.py b/source/tests/test_activation_fn_gelu.py index 6ecbd0154f..b1c30eeefc 100644 --- a/source/tests/test_activation_fn_gelu.py +++ b/source/tests/test_activation_fn_gelu.py @@ -17,7 +17,7 @@ class TestGelu(tf.test.TestCase): def setUp(self): self.places = 6 - self.sess = self.test_session().__enter__() + self.sess = self.cached_session().__enter__() self.inputs = tf.reshape( tf.constant([0.0, 1.0, 2.0, 3.0], dtype=tf.float64), [-1, 1] ) diff --git a/source/tests/test_data_large_batch.py b/source/tests/test_data_large_batch.py index 3ae46e8cb9..5750f956f8 100644 --- a/source/tests/test_data_large_batch.py +++ b/source/tests/test_data_large_batch.py @@ -180,7 +180,7 @@ def test_data_mixed_type(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) # print(sess.run(model.type_embedding)) @@ -376,7 +376,7 @@ def test_stripped_data_mixed_type(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) # print(sess.run(model.type_embedding)) @@ -572,7 +572,7 @@ def test_compressible_data_mixed_type(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) # print(sess.run(model.type_embedding)) diff --git a/source/tests/test_data_modifier.py b/source/tests/test_data_modifier.py index dfc602fd92..368a60d68a 100644 --- a/source/tests/test_data_modifier.py +++ b/source/tests/test_data_modifier.py @@ -80,7 +80,7 @@ def _setUp(self): model.build(data) # freeze the graph - with self.test_session() as sess: + with self.cached_session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) graph = tf.get_default_graph() diff --git a/source/tests/test_data_modifier_shuffle.py b/source/tests/test_data_modifier_shuffle.py index 151caa9e16..9ddbb8ee29 100644 --- a/source/tests/test_data_modifier_shuffle.py +++ b/source/tests/test_data_modifier_shuffle.py @@ -81,7 +81,7 @@ def _setUp(self): model.build(data) # freeze the graph - with self.test_session() as sess: + with self.cached_session() as sess: init_op = tf.global_variables_initializer() sess.run(init_op) graph = tf.get_default_graph() diff --git a/source/tests/test_descrpt_hybrid.py b/source/tests/test_descrpt_hybrid.py index ed39c04307..317f6ea5a0 100644 --- a/source/tests/test_descrpt_hybrid.py +++ b/source/tests/test_descrpt_hybrid.py @@ -115,7 +115,7 @@ def test_descriptor_hybrid(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict=feed_dict_test) diff --git a/source/tests/test_descrpt_nonsmth.py b/source/tests/test_descrpt_nonsmth.py index 1d503e6c8c..fd3bb0b2f7 100644 --- a/source/tests/test_descrpt_nonsmth.py +++ b/source/tests/test_descrpt_nonsmth.py @@ -160,7 +160,7 @@ class TestNonSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data, sess=self.test_session().__enter__()) + Inter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, suffix="_se") @@ -180,8 +180,8 @@ def test_pbc(self): data = Data() inter0 = Inter() inter1 = Inter() - inter0.setUp(data, pbc=True, sess=self.test_session().__enter__()) - inter1.setUp(data, pbc=False, sess=self.test_session().__enter__()) + inter0.setUp(data, pbc=True, sess=self.cached_session().__enter__()) + inter1.setUp(data, pbc=False, sess=self.cached_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) @@ -233,8 +233,8 @@ def test_pbc_small_box(self): data1 = Data(box_scale=2) inter0 = Inter() inter1 = Inter() - inter0.setUp(data0, pbc=True, sess=self.test_session().__enter__()) - inter1.setUp(data1, pbc=False, sess=self.test_session().__enter__()) + inter0.setUp(data0, pbc=True, sess=self.cached_session().__enter__()) + inter1.setUp(data1, pbc=False, sess=self.cached_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) diff --git a/source/tests/test_descrpt_se_a_mask.py b/source/tests/test_descrpt_se_a_mask.py index 30c514a2cc..85cd1cc2a1 100644 --- a/source/tests/test_descrpt_se_a_mask.py +++ b/source/tests/test_descrpt_se_a_mask.py @@ -277,7 +277,7 @@ def test_descriptor_se_a_mask(self): t_aparam: test_data["aparam"][:numb_test, :], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [op_dout] = sess.run([dout], feed_dict=feed_dict_test) op_dout = op_dout.reshape([-1]) diff --git a/source/tests/test_descrpt_se_a_type.py b/source/tests/test_descrpt_se_a_type.py index b10920b1d4..aeab18f149 100644 --- a/source/tests/test_descrpt_se_a_type.py +++ b/source/tests/test_descrpt_se_a_type.py @@ -120,7 +120,7 @@ def test_descriptor_two_sides(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict=feed_dict_test) model_dout = model_dout.reshape([-1]) @@ -284,7 +284,7 @@ def test_descriptor_one_side(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict=feed_dict_test) model_dout = model_dout.reshape([-1]) diff --git a/source/tests/test_descrpt_se_atten.py b/source/tests/test_descrpt_se_atten.py index e49e6ab3e2..76df651a46 100644 --- a/source/tests/test_descrpt_se_atten.py +++ b/source/tests/test_descrpt_se_atten.py @@ -141,7 +141,7 @@ def test_descriptor_two_sides(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict=feed_dict_test) model_dout = model_dout.reshape([-1]) @@ -318,7 +318,7 @@ def test_descriptor_one_side(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict=feed_dict_test) model_dout = model_dout.reshape([-1]) @@ -488,7 +488,7 @@ def test_stripped_type_embedding_descriptor_two_sides(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict=feed_dict_test) model_dout = model_dout.reshape([-1]) @@ -666,7 +666,7 @@ def test_compressible_descriptor_two_sides(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [model_dout] = sess.run([dout], feed_dict=feed_dict_test) model_dout = model_dout.reshape([-1]) diff --git a/source/tests/test_descrpt_se_r.py b/source/tests/test_descrpt_se_r.py index c20515a5fa..779954a545 100644 --- a/source/tests/test_descrpt_se_r.py +++ b/source/tests/test_descrpt_se_r.py @@ -135,7 +135,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data, sess=self.test_session().__enter__()) + Inter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, suffix="_se_r") @@ -155,8 +155,8 @@ def test_pbc(self): data = Data() inter0 = Inter() inter1 = Inter() - inter0.setUp(data, pbc=True, sess=self.test_session().__enter__()) - inter1.setUp(data, pbc=False, sess=self.test_session().__enter__()) + inter0.setUp(data, pbc=True, sess=self.cached_session().__enter__()) + inter1.setUp(data, pbc=False, sess=self.cached_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) @@ -208,8 +208,8 @@ def test_pbc_small_box(self): data1 = Data(box_scale=2) inter0 = Inter() inter1 = Inter() - inter0.setUp(data0, pbc=True, sess=self.test_session().__enter__()) - inter1.setUp(data1, pbc=False, sess=self.test_session().__enter__()) + inter0.setUp(data0, pbc=True, sess=self.cached_session().__enter__()) + inter1.setUp(data1, pbc=False, sess=self.cached_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) diff --git a/source/tests/test_descrpt_sea_ef.py b/source/tests/test_descrpt_sea_ef.py index e39afec97e..efd86854c7 100644 --- a/source/tests/test_descrpt_sea_ef.py +++ b/source/tests/test_descrpt_sea_ef.py @@ -154,7 +154,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data, sess=self.test_session().__enter__()) + Inter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, suffix="_sea_ef") diff --git a/source/tests/test_descrpt_sea_ef_para.py b/source/tests/test_descrpt_sea_ef_para.py index 1ddcc4e196..1a109013cb 100644 --- a/source/tests/test_descrpt_sea_ef_para.py +++ b/source/tests/test_descrpt_sea_ef_para.py @@ -154,7 +154,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data, sess=self.test_session().__enter__()) + Inter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, suffix="_sea_ef_para") diff --git a/source/tests/test_descrpt_sea_ef_rot.py b/source/tests/test_descrpt_sea_ef_rot.py index 10553b878d..d94565af96 100644 --- a/source/tests/test_descrpt_sea_ef_rot.py +++ b/source/tests/test_descrpt_sea_ef_rot.py @@ -17,7 +17,7 @@ class TestEfRot(tf.test.TestCase): def setUp(self): - self.sess = self.test_session().__enter__() + self.sess = self.cached_session().__enter__() self.natoms = [5, 5, 2, 3] self.ntypes = 2 self.sel_a = [12, 24] diff --git a/source/tests/test_descrpt_sea_ef_vert.py b/source/tests/test_descrpt_sea_ef_vert.py index dcbc418720..77ffb3150c 100644 --- a/source/tests/test_descrpt_sea_ef_vert.py +++ b/source/tests/test_descrpt_sea_ef_vert.py @@ -154,7 +154,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data, sess=self.test_session().__enter__()) + Inter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, suffix="_sea_ef_vert") diff --git a/source/tests/test_descrpt_smooth.py b/source/tests/test_descrpt_smooth.py index aa0730cdea..59076e366e 100644 --- a/source/tests/test_descrpt_smooth.py +++ b/source/tests/test_descrpt_smooth.py @@ -153,7 +153,7 @@ class TestSmooth(Inter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - Inter.setUp(self, data, sess=self.test_session().__enter__()) + Inter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, suffix="_smth") @@ -173,8 +173,8 @@ def test_pbc(self): data = Data() inter0 = Inter() inter1 = Inter() - inter0.setUp(data, pbc=True, sess=self.test_session().__enter__()) - inter1.setUp(data, pbc=False, sess=self.test_session().__enter__()) + inter0.setUp(data, pbc=True, sess=self.cached_session().__enter__()) + inter1.setUp(data, pbc=False, sess=self.cached_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) @@ -226,8 +226,8 @@ def test_pbc_small_box(self): data1 = Data(box_scale=2) inter0 = Inter() inter1 = Inter() - inter0.setUp(data0, pbc=True, sess=self.test_session().__enter__()) - inter1.setUp(data1, pbc=False, sess=self.test_session().__enter__()) + inter0.setUp(data0, pbc=True, sess=self.cached_session().__enter__()) + inter1.setUp(data1, pbc=False, sess=self.cached_session().__enter__()) inter0.net_w_i = np.copy(np.ones(inter0.ndescrpt)) inter1.net_w_i = np.copy(np.ones(inter1.ndescrpt)) diff --git a/source/tests/test_dipole_se_a.py b/source/tests/test_dipole_se_a.py index 4e2fa9b30d..687e68c2be 100644 --- a/source/tests/test_dipole_se_a.py +++ b/source/tests/test_dipole_se_a.py @@ -111,7 +111,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([dipole, gdipole], feed_dict=feed_dict_test) diff --git a/source/tests/test_dipole_se_a_tebd.py b/source/tests/test_dipole_se_a_tebd.py index f848526735..4b2e6d0688 100644 --- a/source/tests/test_dipole_se_a_tebd.py +++ b/source/tests/test_dipole_se_a_tebd.py @@ -129,7 +129,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([dipole, gdipole], feed_dict=feed_dict_test) diff --git a/source/tests/test_embedding_net.py b/source/tests/test_embedding_net.py index f09ef74948..1b8c68c089 100644 --- a/source/tests/test_embedding_net.py +++ b/source/tests/test_embedding_net.py @@ -13,7 +13,7 @@ class Inter(tf.test.TestCase): def setUp(self): - self.sess = self.test_session().__enter__() + self.sess = self.cached_session().__enter__() self.inputs = tf.constant([0.0, 1.0, 2.0], dtype=tf.float64) self.ndata = 3 self.inputs = tf.reshape(self.inputs, [-1, 1]) diff --git a/source/tests/test_ewald.py b/source/tests/test_ewald.py index b6b925f801..ef2ace39a4 100644 --- a/source/tests/test_ewald.py +++ b/source/tests/test_ewald.py @@ -64,7 +64,7 @@ def setUp(self): def test_py_interface(self): hh = 1e-4 places = 4 - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() t_energy, t_force, t_virial = op_module.ewald_recp( self.coord, self.charge, @@ -91,7 +91,7 @@ def test_py_interface(self): def test_force(self): hh = 1e-4 places = 6 - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() t_energy, t_force, t_virial = op_module.ewald_recp( self.coord, self.charge, @@ -144,7 +144,7 @@ def test_force(self): def test_virial(self): hh = 1e-4 places = 6 - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() t_energy, t_force, t_virial = op_module.ewald_recp( self.coord, self.charge, diff --git a/source/tests/test_fitting_dos.py b/source/tests/test_fitting_dos.py index 95de81c32c..60a0ee4158 100644 --- a/source/tests/test_fitting_dos.py +++ b/source/tests/test_fitting_dos.py @@ -180,7 +180,7 @@ def test_fitting(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [pred_atom_dos] = sess.run([atom_dos], feed_dict=feed_dict_test) diff --git a/source/tests/test_fitting_ener_type.py b/source/tests/test_fitting_ener_type.py index 54621b634a..42190ef557 100644 --- a/source/tests/test_fitting_ener_type.py +++ b/source/tests/test_fitting_ener_type.py @@ -188,7 +188,7 @@ def test_fitting(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [pred_atom_ener] = sess.run([atom_ener], feed_dict=feed_dict_test) diff --git a/source/tests/test_layer_name.py b/source/tests/test_layer_name.py index 6de4a09736..c6a2f0b09c 100644 --- a/source/tests/test_layer_name.py +++ b/source/tests/test_layer_name.py @@ -137,7 +137,7 @@ def test_model(self): is_training: False, } - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) [e1, f1, v1, e2, f2, v2] = sess.run( [e_energy1, e_force1, e_virial1, e_energy2, e_force2, e_virial2], diff --git a/source/tests/test_linear_model.py b/source/tests/test_linear_model.py index 13a2bc4850..21f0f6efc8 100644 --- a/source/tests/test_linear_model.py +++ b/source/tests/test_linear_model.py @@ -94,7 +94,7 @@ def test_linear_ener_model(self): t_mesh: test_data["default_mesh"], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) e = np.reshape(e, [1, -1]) diff --git a/source/tests/test_model_dos.py b/source/tests/test_model_dos.py index 3562a5b9f9..c7160d4dda 100644 --- a/source/tests/test_model_dos.py +++ b/source/tests/test_model_dos.py @@ -116,7 +116,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [pred_dos, pred_atom_dos] = sess.run([dos, atom_dos], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_loc_frame.py b/source/tests/test_model_loc_frame.py index ed0fc3815a..c493013316 100644 --- a/source/tests/test_model_loc_frame.py +++ b/source/tests/test_model_loc_frame.py @@ -114,7 +114,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_multi.py b/source/tests/test_model_multi.py index 384f1e0553..9017da22e7 100644 --- a/source/tests/test_model_multi.py +++ b/source/tests/test_model_multi.py @@ -141,7 +141,7 @@ def test_model(self): t_mesh: test_data["default_mesh"], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() # test water energy sess.run(tf.global_variables_initializer()) diff --git a/source/tests/test_model_se_a.py b/source/tests/test_model_se_a.py index 65e42f43a0..d3b4323f0d 100644 --- a/source/tests/test_model_se_a.py +++ b/source/tests/test_model_se_a.py @@ -123,7 +123,7 @@ def test_model_atom_ener(self): t_mesh: test_data["default_mesh"], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) self.assertAlmostEqual(e[0], set_atom_ener[0], places=10) @@ -212,7 +212,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) @@ -347,7 +347,7 @@ def test_model_atom_ener_type_embedding(self): t_mesh: test_data["default_mesh"], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) self.assertAlmostEqual(e[0], set_atom_ener[0], places=10) diff --git a/source/tests/test_model_se_a_aparam.py b/source/tests/test_model_se_a_aparam.py index b236320d24..41111c57ee 100644 --- a/source/tests/test_model_se_a_aparam.py +++ b/source/tests/test_model_se_a_aparam.py @@ -115,7 +115,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_se_a_ebd.py b/source/tests/test_model_se_a_ebd.py index 96de277d2f..bf856b7bc5 100644 --- a/source/tests/test_model_se_a_ebd.py +++ b/source/tests/test_model_se_a_ebd.py @@ -115,7 +115,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_se_a_fparam.py b/source/tests/test_model_se_a_fparam.py index fad41947e2..cdb85157a4 100644 --- a/source/tests/test_model_se_a_fparam.py +++ b/source/tests/test_model_se_a_fparam.py @@ -116,7 +116,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_se_a_srtab.py b/source/tests/test_model_se_a_srtab.py index ff91af619b..98cab9e073 100644 --- a/source/tests/test_model_se_a_srtab.py +++ b/source/tests/test_model_se_a_srtab.py @@ -140,7 +140,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_se_a_type.py b/source/tests/test_model_se_a_type.py index 63d0ae279c..85e4a2916d 100644 --- a/source/tests/test_model_se_a_type.py +++ b/source/tests/test_model_se_a_type.py @@ -121,7 +121,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) # print(sess.run(model.type_embedding)) diff --git a/source/tests/test_model_se_atten.py b/source/tests/test_model_se_atten.py index 6e6e9928a6..445959ceb2 100644 --- a/source/tests/test_model_se_atten.py +++ b/source/tests/test_model_se_atten.py @@ -132,7 +132,7 @@ def test_model(self): t_mesh: test_data["default_mesh"], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) # print(sess.run(model.type_embedding)) @@ -258,7 +258,7 @@ def test_exclude_types(self): is_training: False, } - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) [des] = sess.run([dout], feed_dict=feed_dict_test1) @@ -357,7 +357,7 @@ def test_compressible_model(self): t_mesh: test_data["default_mesh"], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) # print(sess.run(model.type_embedding)) @@ -485,7 +485,7 @@ def test_compressible_exclude_types(self): is_training: False, } - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) [des] = sess.run([dout], feed_dict=feed_dict_test1) @@ -587,7 +587,7 @@ def test_stripped_type_embedding_model(self): t_mesh: test_data["default_mesh"], is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) # print(sess.run(model.type_embedding)) @@ -719,7 +719,7 @@ def test_stripped_type_embedding_exclude_types(self): is_training: False, } - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) [des] = sess.run([dout], feed_dict=feed_dict_test1) diff --git a/source/tests/test_model_se_r.py b/source/tests/test_model_se_r.py index 01151d8c30..94812308c6 100644 --- a/source/tests/test_model_se_r.py +++ b/source/tests/test_model_se_r.py @@ -111,7 +111,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_se_t.py b/source/tests/test_model_se_t.py index 300ad46a0a..1d67e852c7 100644 --- a/source/tests/test_model_se_t.py +++ b/source/tests/test_model_se_t.py @@ -109,7 +109,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_model_spin.py b/source/tests/test_model_spin.py index a264f38616..9bdf1d780a 100644 --- a/source/tests/test_model_spin.py +++ b/source/tests/test_model_spin.py @@ -122,7 +122,7 @@ def test_model_spin(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [out_ener, out_force, out_virial] = sess.run( [energy, force, virial], feed_dict=feed_dict_test diff --git a/source/tests/test_nvnmd_entrypoints.py b/source/tests/test_nvnmd_entrypoints.py index af0cd48146..3e721516f1 100644 --- a/source/tests/test_nvnmd_entrypoints.py +++ b/source/tests/test_nvnmd_entrypoints.py @@ -454,7 +454,7 @@ def test_model_qnn_v0(self): dic_ph["default_mesh"]: mesh_dat, } # - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) # get tensordic keys = "o_descriptor,o_rmat,o_energy".split(",") @@ -762,7 +762,7 @@ def test_model_qnn_v1(self): dic_ph["default_mesh"]: mesh_dat, } # - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) # get tensordic keys = "o_descriptor,o_rmat,o_energy".split(",") @@ -818,7 +818,7 @@ def test_model_qnn_v1(self): ref_dout = 60.73941362 np.testing.assert_almost_equal(pred, ref_dout, 8) # test freeze - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() weight_file1 = str(tests_path / "nvnmd" / "ref" / "weight_v1_cnn.npy") weight_file2 = str(tests_path / "nvnmd" / "out" / "weight_v1_qnn.npy") save_weight(sess, weight_file2) diff --git a/source/tests/test_nvnmd_op.py b/source/tests/test_nvnmd_op.py index 2b59b9ef94..3419b375e4 100644 --- a/source/tests/test_nvnmd_op.py +++ b/source/tests/test_nvnmd_op.py @@ -17,7 +17,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -110,7 +110,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -140,7 +140,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -166,7 +166,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -192,7 +192,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -238,7 +238,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -284,7 +284,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -330,7 +330,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -376,7 +376,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph @@ -402,7 +402,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() def test_op(self): # graph diff --git a/source/tests/test_pairwise_dprc.py b/source/tests/test_pairwise_dprc.py index 2ea5888b60..04aaa237b1 100644 --- a/source/tests/test_pairwise_dprc.py +++ b/source/tests/test_pairwise_dprc.py @@ -349,7 +349,7 @@ def test_model_ener(self): t_aparam: np.reshape(np.tile(test_data["aparam"], 5), [-1]), is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [e, f, v] = sess.run([energy, force, virial], feed_dict=feed_dict_test) diff --git a/source/tests/test_polar_se_a.py b/source/tests/test_polar_se_a.py index 1933816488..2564dc0656 100644 --- a/source/tests/test_polar_se_a.py +++ b/source/tests/test_polar_se_a.py @@ -110,7 +110,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([polar, gpolar], feed_dict=feed_dict_test) diff --git a/source/tests/test_polar_se_a_tebd.py b/source/tests/test_polar_se_a_tebd.py index 284cb46498..570c4261d9 100644 --- a/source/tests/test_polar_se_a_tebd.py +++ b/source/tests/test_polar_se_a_tebd.py @@ -128,7 +128,7 @@ def test_model(self): is_training: False, } - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) [p, gp] = sess.run([polar, gpolar], feed_dict=feed_dict_test) diff --git a/source/tests/test_prod_env_mat.py b/source/tests/test_prod_env_mat.py index cf0b9e9296..663b991831 100644 --- a/source/tests/test_prod_env_mat.py +++ b/source/tests/test_prod_env_mat.py @@ -11,7 +11,7 @@ class TestProdEnvMat(tf.test.TestCase): def setUp(self): - self.sess = self.test_session().__enter__() + self.sess = self.cached_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, diff --git a/source/tests/test_prod_force.py b/source/tests/test_prod_force.py index e0497d0b7e..83a44c0be9 100644 --- a/source/tests/test_prod_force.py +++ b/source/tests/test_prod_force.py @@ -18,7 +18,7 @@ def setUp(self): config.graph_options.rewrite_options.custom_optimizers.add().name = ( "dpparallel" ) - self.sess = self.test_session(config=config).__enter__() + self.sess = self.cached_session(config=config).__enter__() self.nframes = 2 self.dcoord = [ 12.83, diff --git a/source/tests/test_prod_force_grad.py b/source/tests/test_prod_force_grad.py index a7eaeb7511..012def217f 100644 --- a/source/tests/test_prod_force_grad.py +++ b/source/tests/test_prod_force_grad.py @@ -10,7 +10,7 @@ class TestProdForceGrad(tf.test.TestCase): def setUp(self): - self.sess = self.test_session().__enter__() + self.sess = self.cached_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, diff --git a/source/tests/test_prod_virial.py b/source/tests/test_prod_virial.py index 29f71daf68..2abcfcb1bf 100644 --- a/source/tests/test_prod_virial.py +++ b/source/tests/test_prod_virial.py @@ -10,7 +10,7 @@ class TestProdVirial(tf.test.TestCase): def setUp(self): - self.sess = self.test_session().__enter__() + self.sess = self.cached_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, diff --git a/source/tests/test_prod_virial_grad.py b/source/tests/test_prod_virial_grad.py index f7d6cfe92d..548b63a54b 100644 --- a/source/tests/test_prod_virial_grad.py +++ b/source/tests/test_prod_virial_grad.py @@ -10,7 +10,7 @@ class TestProdVirialGrad(tf.test.TestCase): def setUp(self): - self.sess = self.test_session().__enter__() + self.sess = self.cached_session().__enter__() self.nframes = 2 self.dcoord = [ 12.83, diff --git a/source/tests/test_tab_nonsmth.py b/source/tests/test_tab_nonsmth.py index d6df226478..9e3f9ff640 100644 --- a/source/tests/test_tab_nonsmth.py +++ b/source/tests/test_tab_nonsmth.py @@ -178,7 +178,7 @@ class TestTabNonSmooth(IntplInter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - IntplInter.setUp(self, data, sess=self.test_session().__enter__()) + IntplInter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, places=5, suffix="_tab") diff --git a/source/tests/test_tab_smooth.py b/source/tests/test_tab_smooth.py index 220ba4e3f3..49b18e14f3 100644 --- a/source/tests/test_tab_smooth.py +++ b/source/tests/test_tab_smooth.py @@ -175,7 +175,7 @@ class TestTabSmooth(IntplInter, tf.test.TestCase): def setUp(self): self.places = 5 data = Data() - IntplInter.setUp(self, data, sess=self.test_session().__enter__()) + IntplInter.setUp(self, data, sess=self.cached_session().__enter__()) def test_force(self): force_test(self, self, places=5, suffix="_tab_smth") diff --git a/source/tests/test_type_embed.py b/source/tests/test_type_embed.py index 47de16cbdc..3e79bad70b 100644 --- a/source/tests/test_type_embed.py +++ b/source/tests/test_type_embed.py @@ -23,14 +23,14 @@ def test_embed_atom_type(self): ) expected_out = [[1, 2, 3], [1, 2, 3], [1, 2, 3], [7, 7, 7], [7, 7, 7]] atom_embed = embed_atom_type(ntypes, natoms, type_embedding) - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() atom_embed = sess.run(atom_embed) np.testing.assert_almost_equal(atom_embed, expected_out, 10) def test_type_embed_net(self): ten = TypeEmbedNet([2, 4, 8], seed=1, uniform_seed=True) type_embedding = ten.build(2) - sess = self.test_session().__enter__() + sess = self.cached_session().__enter__() sess.run(tf.global_variables_initializer()) type_embedding = sess.run(type_embedding) diff --git a/source/tests/test_type_one_side.py b/source/tests/test_type_one_side.py index e16ecd2b12..8e7c173912 100644 --- a/source/tests/test_type_one_side.py +++ b/source/tests/test_type_one_side.py @@ -125,7 +125,7 @@ def test_descriptor_one_side_exclude_types(self): feed_dict_test2[t_type] = np.reshape(new_type2[:numb_test, :], [-1]) feed_dict_test2[t_natoms] = new_natoms2 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) [model_dout1] = sess.run([dout], feed_dict=feed_dict_test1) [model_dout2] = sess.run([dout], feed_dict=feed_dict_test2) @@ -231,7 +231,7 @@ def test_se_r_one_side_exclude_types(self): feed_dict_test2[t_type] = np.reshape(new_type2[:numb_test, :], [-1]) feed_dict_test2[t_natoms] = new_natoms2 - with self.test_session() as sess: + with self.cached_session() as sess: sess.run(tf.global_variables_initializer()) [model_dout1] = sess.run([dout], feed_dict=feed_dict_test1) [model_dout2] = sess.run([dout], feed_dict=feed_dict_test2)