Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions deepmd/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ def model_version(self) -> str:
if not self._model_version:
try:
t_mt = self._get_tensor("model_attr/model_version:0")
sess = tf.Session(graph=self.graph, config=default_tf_session_config)
[mt] = run_sess(sess, [t_mt], feed_dict={})
self._model_version = mt.decode("utf-8")
except KeyError:
# For deepmd-kit version 0.x - 1.x, set model version to 0.0
self._model_version = "0.0"
else:
sess = tf.Session(graph=self.graph, config=default_tf_session_config)
[mt] = run_sess(sess, [t_mt], feed_dict={})
self._model_version = mt.decode("utf-8")
return self._model_version

def _graph_compatable(
Expand Down
5 changes: 3 additions & 2 deletions deepmd/infer/deep_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def __init__(
# then put those into self.attrs
for attr_name, tensor_name in optional_tensors.items():
self._get_tensor(tensor_name, attr_name)
self.tensors.update(optional_tensors)
self._support_gfv = True
except KeyError:
self._support_gfv = False
else:
self.tensors.update(optional_tensors)
self._support_gfv = True

# start a tf session associated to the graph
self.sess = tf.Session(graph=self.graph, config=default_tf_session_config)
Expand Down
6 changes: 3 additions & 3 deletions deepmd/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ class TensorLoss () :
Loss function for tensorial properties.
"""
def __init__ (self, jdata, **kwarg) :
try:
model = kwarg['model']
model = kwarg.get('model', None)
if model is not None:
self.type_sel = model.get_sel_type()
except :
else:
self.type_sel = None
self.tensor_name = kwarg['tensor_name']
self.tensor_size = kwarg['tensor_size']
Expand Down
18 changes: 4 additions & 14 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ def _init_param(self, jdata):
self.descrpt = DescrptHybrid(descrpt_list)

# fitting net
try:
fitting_type = fitting_param['type']
except:
fitting_type = 'ener'
fitting_type = fitting_param.get('type', 'ener')
fitting_param.pop('type', None)
fitting_param['descrpt'] = self.descrpt
if fitting_type == 'ener':
Expand Down Expand Up @@ -198,10 +195,7 @@ def _init_param(self, jdata):

# learning rate
lr_param = j_must_have(jdata, 'learning_rate')
try:
lr_type = lr_param['type']
except:
lr_type = 'exp'
lr_type = lr_param.get('type', 'exp')
if lr_type == 'exp':
self.lr = LearningRateExp(lr_param['start_lr'],
lr_param['stop_lr'],
Expand All @@ -211,12 +205,8 @@ def _init_param(self, jdata):

# loss
# infer loss type by fitting_type
try :
loss_param = jdata['loss']
loss_type = loss_param.get('type', 'ener')
except:
loss_param = None
loss_type = 'ener'
loss_param = jdata.get('loss', None)
loss_type = loss_param.get('type', 'ener')

if fitting_type == 'ener':
loss_param.pop('type', None)
Expand Down