diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 8881ac3a14..ce8af71379 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -3,7 +3,7 @@ from typing import Tuple, List, Dict, Any from deepmd.env import tf -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, get_np_precision +from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter from deepmd.utils.argcheck import list_to_doc from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_NP_FLOAT_PRECISION @@ -13,7 +13,7 @@ from deepmd.utils.tabulate import DPTabulate from deepmd.utils.type_embed import embed_atom_type from deepmd.utils.sess import run_sess -from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables +from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph from .descriptor import Descriptor from .se import DescrptSe @@ -133,7 +133,6 @@ def __init__ (self, self.compress_activation_fn = get_activation_func(activation_function) self.filter_activation_fn = get_activation_func(activation_function) self.filter_precision = get_precision(precision) - self.filter_np_precision = get_np_precision(precision) self.exclude_types = set() for tt in exclude_types: assert(len(tt) == 2) @@ -687,7 +686,7 @@ def _filter_lower( net = 'filter_-1_net_' + str(type_i) else: net = 'filter_' + str(type_input) + '_net_' + str(type_i) - return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) + return op_module.tabulate_fusion(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) else: if (not is_exclude): xyz_scatter = embedding_net(