Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deepmd/descriptor/loc_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
from deepmd.env import op_module
from deepmd.env import default_tf_session_config
from deepmd.utils.sess import run_sess

class DescrptLocFrame () :
def __init__(self,
Expand Down Expand Up @@ -327,7 +328,7 @@ def _compute_dstats_sys_nonsmth (self,
natoms_vec,
mesh) :
dd_all \
= self.sub_sess.run(self.stat_descrpt,
= run_sess(self.sub_sess, self.stat_descrpt,
feed_dict = {
self.place_holders['coord']: data_coord,
self.place_holders['type']: data_atype,
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
from deepmd.utils.tabulate import DeepTabulate
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess

class DescrptSeA ():
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
Expand Down Expand Up @@ -491,7 +492,7 @@ def _compute_dstats_sys_smth (self,
natoms_vec,
mesh) :
dd_all \
= self.sub_sess.run(self.stat_descrpt,
= run_sess(self.sub_sess, self.stat_descrpt,
feed_dict = {
self.place_holders['coord']: data_coord,
self.place_holders['type']: data_atype,
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from deepmd.env import tf
from deepmd.common import add_data_requirement,get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.utils.argcheck import list_to_doc
from deepmd.utils.sess import run_sess
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
from deepmd.env import op_module
Expand Down Expand Up @@ -518,7 +519,7 @@ def _compute_dstats_sys_smth (self,
mesh,
data_efield) :
dd_all \
= self.sub_sess.run(self.stat_descrpt,
= run_sess(self.sub_sess, self.stat_descrpt,
feed_dict = {
self.place_holders['coord']: data_coord,
self.place_holders['type']: data_atype,
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from deepmd.env import op_module
from deepmd.env import default_tf_session_config
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
from deepmd.utils.sess import run_sess

class DescrptSeR ():
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
Expand Down Expand Up @@ -401,7 +402,7 @@ def _compute_dstats_sys_se_r (self,
natoms_vec,
mesh) :
dd_all \
= self.sub_sess.run(self.stat_descrpt,
= run_sess(self.sub_sess, self.stat_descrpt,
feed_dict = {
self.place_holders['coord']: data_coord,
self.place_holders['type']: data_atype,
Expand Down
3 changes: 2 additions & 1 deletion deepmd/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from deepmd.env import op_module
from deepmd.env import default_tf_session_config
from deepmd.utils.network import embedding_net, embedding_net_rand_seed_shift
from deepmd.utils.sess import run_sess

class DescrptSeT ():
@docstring_parameter(list_to_doc(ACTIVATION_FN_DICT.keys()), list_to_doc(PRECISION_DICT.keys()))
Expand Down Expand Up @@ -394,7 +395,7 @@ def _compute_dstats_sys_smth (self,
natoms_vec,
mesh) :
dd_all \
= self.sub_sess.run(self.stat_descrpt,
= run_sess(self.sub_sess, self.stat_descrpt,
feed_dict = {
self.place_holders['coord']: data_coord,
self.place_holders['type']: data_atype,
Expand Down
5 changes: 3 additions & 2 deletions deepmd/entrypoints/freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from deepmd.env import tf
from deepmd.env import op_module
from deepmd.utils.sess import run_sess
from os.path import abspath

# load grad of force module
Expand Down Expand Up @@ -154,9 +155,9 @@ def freeze(
# We start a session and restore the graph weights
with tf.Session() as sess:
saver.restore(sess, input_checkpoint)
model_type = sess.run("model_attr/model_type:0", feed_dict={}).decode("utf-8")
model_type = run_sess(sess, "model_attr/model_type:0", feed_dict={}).decode("utf-8")
if "modifier_attr/type" in nodes:
modifier_type = sess.run("modifier_attr/type:0", feed_dict={}).decode(
modifier_type = run_sess(sess, "modifier_attr/type:0", feed_dict={}).decode(
"utf-8"
)
else:
Expand Down
5 changes: 3 additions & 2 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from deepmd.utils.argcheck import normalize
from deepmd.utils.compat import updata_deepmd_input
from deepmd.utils.data_system import DeepmdDataSystem
from deepmd.utils.sess import run_sess

if TYPE_CHECKING:
from deepmd.run_options import TFServerV1
Expand Down Expand Up @@ -74,7 +75,7 @@ def wait_done_queue(
"""
with tf.Session(server.target) as sess:
for i in range(cluster_spec.num_tasks("worker")):
sess.run(queue.dequeue())
run_sess(sess, queue.dequeue())
log.debug(f"ps:{task_index:d} received done from worker:{i:d}")
log.debug(f"ps:{task_index:f} quitting")

Expand Down Expand Up @@ -127,7 +128,7 @@ def fill_done_queue(
"""
with tf.Session(server.target) as sess:
for i in range(cluster_spec.num_tasks("ps")):
sess.run(done_ops[i])
run_sess(sess, done_ops[i])
log.debug(f"worker:{task_index:d} sending done to ps:{i:d}")


Expand Down
7 changes: 4 additions & 3 deletions deepmd/infer/data_modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from deepmd.env import global_cvt_2_tf_float
from deepmd.env import global_cvt_2_ener_float
from deepmd.env import op_module
from deepmd.utils.sess import run_sess


class DipoleChargeModifier(DeepDipole):
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(self,
self.ext_dim = 3
self.t_ndesc = self.graph.get_tensor_by_name(os.path.join(self.modifier_prefix, 'descrpt_attr/ndescrpt:0'))
self.t_sela = self.graph.get_tensor_by_name(os.path.join(self.modifier_prefix, 'descrpt_attr/sel:0'))
[self.ndescrpt, self.sel_a] = self.sess.run([self.t_ndesc, self.t_sela])
[self.ndescrpt, self.sel_a] = run_sess(self.sess, [self.t_ndesc, self.t_sela])
self.sel_r = [ 0 for ii in range(len(self.sel_a)) ]
self.nnei_a = np.cumsum(self.sel_a)[-1]
self.nnei_r = np.cumsum(self.sel_r)[-1]
Expand Down Expand Up @@ -335,9 +336,9 @@ def _eval_fv(self, coords, cells, atom_types, ext_f) :
feed_dict_test[self.t_box ] = cells.reshape([-1])
feed_dict_test[self.t_mesh ] = default_mesh.reshape([-1])
feed_dict_test[self.t_ef ] = ext_f.reshape([-1])
# print(self.sess.run(tf.shape(self.t_tensor), feed_dict = feed_dict_test))
# print(run_sess(self.sess, tf.shape(self.t_tensor), feed_dict = feed_dict_test))
fout, vout, avout \
= self.sess.run([self.force, self.virial, self.av],
= run_sess(self.sess, [self.force, self.virial, self.av],
feed_dict = feed_dict_test)
# print('fout: ', fout.shape, fout)
fout = self.reverse_map(np.reshape(fout, [nframes,-1,3]), imap)
Expand Down
5 changes: 3 additions & 2 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from deepmd.common import make_default_mesh
from deepmd.env import default_tf_session_config, tf, MODEL_VERSION
from deepmd.utils.sess import run_sess

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -43,7 +44,7 @@ def model_type(self) -> str:
if not self._model_type:
t_mt = self._get_tensor("model_attr/model_type:0")
sess = tf.Session(graph=self.graph, config=default_tf_session_config)
[mt] = sess.run([t_mt], feed_dict={})
[mt] = run_sess(sess, [t_mt], feed_dict={})
self._model_type = mt.decode("utf-8")
return self._model_type

Expand All @@ -57,7 +58,7 @@ def model_version(self) -> str:
try:
t_mt = self._get_tensor("model_attr/model_version:0")
sess = tf.Session(graph=self.graph, config=default_tf_session_config)
[mt] = sess.run([t_mt], feed_dict={})
[mt] = run_sess(sess, [t_mt], feed_dict={})
self._model_version = mt.decode("utf-8")
except KeyError:
# For deepmd-kit version 0.x - 1.x, set model version to 0.0
Expand Down
7 changes: 4 additions & 3 deletions deepmd/infer/deep_pot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.env import default_tf_session_config, tf
from deepmd.infer.data_modifier import DipoleChargeModifier
from deepmd.infer.deep_eval import DeepEval
from deepmd.utils.sess import run_sess

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -113,7 +114,7 @@ def __init__(
# setup modifier
try:
t_modifier_type = self._get_tensor("modifier_attr/type:0")
self.modifier_type = self.sess.run(t_modifier_type).decode("UTF-8")
self.modifier_type = run_sess(self.sess, t_modifier_type).decode("UTF-8")
except (ValueError, KeyError):
self.modifier_type = None

Expand All @@ -123,13 +124,13 @@ def __init__(
t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0")
t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0")
t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0")
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = self.sess.run([t_mdl_name, t_mdl_charge_map, t_sys_charge_map, t_ewald_h, t_ewald_beta])
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(self.sess, [t_mdl_name, t_mdl_charge_map, t_sys_charge_map, t_ewald_h, t_ewald_beta])
mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()]
sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()]
self.dm = DipoleChargeModifier(mdl_name, mdl_charge_map, sys_charge_map, ewald_h = ewald_h, ewald_beta = ewald_beta)

def _run_default_sess(self):
[self.ntypes, self.rcut, self.dfparam, self.daparam, self.tmap] = self.sess.run(
[self.ntypes, self.rcut, self.dfparam, self.daparam, self.tmap] = run_sess(self.sess,
[self.t_ntypes, self.t_rcut, self.t_dfparam, self.t_daparam, self.t_tmap]
)

Expand Down
3 changes: 2 additions & 1 deletion deepmd/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from deepmd.common import make_default_mesh
from deepmd.env import default_tf_session_config, tf
from deepmd.infer.deep_eval import DeepEval
from deepmd.utils.sess import run_sess

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -63,7 +64,7 @@ def __init__(

def _run_default_sess(self):
[self.ntypes, self.rcut, self.tmap, self.tselt, self.output_dim] \
= self.sess.run(
= run_sess(self.sess,
[self.t_ntypes, self.t_rcut, self.t_tmap, self.t_sel_type, self.t_ouput_dim]
)

Expand Down
3 changes: 2 additions & 1 deletion deepmd/infer/ewald_recp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from deepmd.env import global_cvt_2_ener_float
from deepmd.env import op_module
from deepmd.env import default_tf_session_config
from deepmd.utils.sess import run_sess

class EwaldRecp () :
"""
Expand Down Expand Up @@ -79,7 +80,7 @@ def eval(self,
box = np.reshape(box, [nframes * 9])

[energy, force, virial] \
= self.sess.run([self.t_energy, self.t_force, self.t_virial],
= run_sess(self.sess, [self.t_energy, self.t_force, self.t_virial],
feed_dict = {
self.t_coord: coord,
self.t_charge: charge,
Expand Down
13 changes: 7 additions & 6 deletions deepmd/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from deepmd.env import global_cvt_2_tf_float
from deepmd.env import global_cvt_2_ener_float
from deepmd.utils.sess import run_sess

class EnerStdLoss () :
"""
Expand Down Expand Up @@ -136,7 +137,7 @@ def eval(self, sess, feed_dict, natoms):
self.l2_more['l2_atom_ener_loss'],
self.l2_more['l2_pref_force_loss']
]
error, error_e, error_f, error_v, error_ae, error_pf = sess.run(run_data, feed_dict=feed_dict)
error, error_e, error_f, error_v, error_ae, error_pf = run_sess(sess, run_data, feed_dict=feed_dict)
results = {"natoms": natoms[0], "rmse": np.sqrt(error)}
if self.has_e:
results["rmse_e"] = np.sqrt(error_e) / natoms[0]
Expand Down Expand Up @@ -184,7 +185,7 @@ def print_on_training(self,
]

# first train data
train_out = sess.run(run_data, feed_dict=feed_dict_batch)
train_out = run_sess(sess, run_data, feed_dict=feed_dict_batch)
error_train, error_e_train, error_f_train, error_v_train, error_ae_train, error_pf_train = train_out

# than test data, if tensorboard log writter is present, commpute summary
Expand All @@ -193,7 +194,7 @@ def print_on_training(self,
summary_merged_op = tf.summary.merge([self.l2_loss_summary, self.l2_loss_ener_summary, self.l2_loss_force_summary, self.l2_loss_virial_summary])
run_data.insert(0, summary_merged_op)

test_out = sess.run(run_data, feed_dict=feed_dict_test)
test_out = run_sess(sess, run_data, feed_dict=feed_dict_test)

if tb_writer:
summary = test_out.pop(0)
Expand Down Expand Up @@ -297,7 +298,7 @@ def eval(self, sess, feed_dict, natoms):
self.l2_more['l2_ener_loss'],
self.l2_more['l2_ener_dipole_loss']
]
error, error_e, error_ed = sess.run(run_data, feed_dict=feed_dict)
error, error_e, error_ed = run_sess(sess, run_data, feed_dict=feed_dict)
results = {
'natoms': natoms[0],
'rmse': np.sqrt(error),
Expand Down Expand Up @@ -330,7 +331,7 @@ def print_on_training(self,
]

# first train data
train_out = sess.run(run_data, feed_dict=feed_dict_batch)
train_out = run_sess(sess, run_data, feed_dict=feed_dict_batch)
error_train, error_e_train, error_ed_train = train_out

# than test data, if tensorboard log writter is present, commpute summary
Expand All @@ -343,7 +344,7 @@ def print_on_training(self,
])
run_data.insert(0, summary_merged_op)

test_out = sess.run(run_data, feed_dict=feed_dict_test)
test_out = run_sess(sess, run_data, feed_dict=feed_dict_test)

if tb_writer:
summary = test_out.pop(0)
Expand Down
7 changes: 4 additions & 3 deletions deepmd/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from deepmd.env import global_cvt_2_tf_float
from deepmd.env import global_cvt_2_ener_float
from deepmd.utils.sess import run_sess

class TensorLoss () :
"""
Expand Down Expand Up @@ -123,7 +124,7 @@ def eval(self, sess, feed_dict, natoms):
atoms = natoms[0]

run_data = [self.l2_l, self.l2_more['local_loss'], self.l2_more['global_loss']]
error, error_lc, error_gl = sess.run(run_data, feed_dict=feed_dict)
error, error_lc, error_gl = run_sess(sess, run_data, feed_dict=feed_dict)

results = {"natoms": atoms, "rmse": np.sqrt(error)}
if self.local_weight > 0.0:
Expand Down Expand Up @@ -166,7 +167,7 @@ def print_on_training(self,
summary_list.append(self.l2_loss_global_summary)

# first train data
error_train = sess.run(run_data, feed_dict=feed_dict_batch)
error_train = run_sess(sess, run_data, feed_dict=feed_dict_batch)

# than test data, if tensorboard log writter is present, commpute summary
# and write tensorboard logs
Expand All @@ -175,7 +176,7 @@ def print_on_training(self,
summary_merged_op = tf.summary.merge(summary_list)
run_data.insert(0, summary_merged_op)

test_out = sess.run(run_data, feed_dict=feed_dict_test)
test_out = run_sess(sess, run_data, feed_dict=feed_dict_test)

if tb_writer:
summary = test_out.pop(0)
Expand Down
Loading