Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, List, Dict, Any

from deepmd.env import tf
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, get_np_precision
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.utils.argcheck import list_to_doc
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
Expand All @@ -13,7 +13,7 @@
from deepmd.utils.tabulate import DPTabulate
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.sess import run_sess
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph
from .descriptor import Descriptor
from .se import DescrptSe

Expand Down Expand Up @@ -133,7 +133,6 @@ def __init__ (self,
self.compress_activation_fn = get_activation_func(activation_function)
self.filter_activation_fn = get_activation_func(activation_function)
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
self.exclude_types = set()
for tt in exclude_types:
assert(len(tt) == 2)
Expand Down Expand Up @@ -687,7 +686,7 @@ def _filter_lower(
net = 'filter_-1_net_' + str(type_i)
else:
net = 'filter_' + str(type_input) + '_net_' + str(type_i)
return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
return op_module.tabulate_fusion(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1])
else:
if (not is_exclude):
xyz_scatter = embedding_net(
Expand Down