diff --git a/deepmd/common.py b/deepmd/common.py index 9968cff39c..695fee1a93 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -20,6 +20,7 @@ import yaml from deepmd.env import op_module, tf +from tensorflow.python.framework import tensor_util from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION from deepmd.utils.sess import run_sess from deepmd.utils.errors import GraphWithoutTensorError @@ -487,3 +488,81 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype: return np.float64 else: raise RuntimeError(f"{precision} is not a valid precision") + + +def safe_cast_tensor(input: tf.Tensor, + from_precision: tf.DType, + to_precision: tf.DType) -> tf.Tensor: + """Convert a Tensor from a precision to another precision. + + If input is not a Tensor or its dtype is not `from_precision`, it is returned + unchanged. + + Parameters + ---------- + input : tf.Tensor + input tensor + from_precision, to_precision : tf.DType + Tensor data types to cast from and to cast to, respectively + + Returns + ------- + tf.Tensor + the cast Tensor, or `input` itself if it was not cast + """ + if tensor_util.is_tensor(input) and input.dtype == from_precision: + return tf.cast(input, to_precision) + return input + + +def cast_precision(func: Callable) -> Callable: + """A decorator that casts and casts back the input + and output tensor of a method. + + The decorator should be used on a method of a class that defines a `precision` property. + + The decorator will do the following things: + (1) It casts input Tensors from `GLOBAL_TF_FLOAT_PRECISION` + to precision defined by property `precision`. + (2) It casts output Tensors from `precision` to + `GLOBAL_TF_FLOAT_PRECISION`. + (3) It checks inputs and outputs and only casts when + input or output is a Tensor and its dtype matches + `GLOBAL_TF_FLOAT_PRECISION` and `precision`, respectively. + If it does not match (e.g. it is an integer), the decorator + will leave it unchanged.
+ + Parameters + ---------- + func : Callable + the method to decorate; it is called with `self` as its first argument + + Returns + ------- + Callable + the wrapped method, which casts its input tensors and + casts back its output tensors + + Examples + -------- + >>> class A: + ... @property + ... def precision(self): + ... return tf.float32 + ... + ... @cast_precision + ... def f(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: + ... return x ** 2 + y + """ + def wrapper(self, *args, **kwargs): + # only convert tensors + returned_tensor = func( + self, + *[safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for vv in args], + **{kk: safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for kk, vv in kwargs.items()}, + ) + if isinstance(returned_tensor, tuple): + return tuple((safe_cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION) for vv in returned_tensor)) + else: + return safe_cast_tensor(returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION) + return wrapper diff --git a/deepmd/descriptor/se.py b/deepmd/descriptor/se.py index c7470568b6..c42a7fb46d 100644 --- a/deepmd/descriptor/se.py +++ b/deepmd/descriptor/se.py @@ -106,3 +106,8 @@ def init_variables(self, The suffix of the scope """ self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix) + + @property + def precision(self) -> tf.DType: + """Precision of filter network.""" + return self.filter_precision diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 3b8790911f..db2980b882 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -3,7 +3,7 @@ from typing import Tuple, List, Dict, Any from deepmd.env import tf -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.env import
GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_NP_FLOAT_PRECISION @@ -558,7 +558,7 @@ def _pass_filter(self, [ 0, start_index* self.ndescrpt], [-1, natoms[2+type_i]* self.ndescrpt] ) inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) + layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3]) output.append(layer) @@ -568,7 +568,7 @@ def _pass_filter(self, inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) + layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3]) output.append(layer) @@ -704,7 +704,6 @@ def _filter_lower( # with (natom x nei_type_i) x 1 xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1]) if type_embedding is not None: - type_embedding = tf.cast(type_embedding, self.filter_precision) xyz_scatter = self._concat_type_embedding( xyz_scatter, nframes, natoms, type_embedding) if self.compress: @@ -736,7 +735,7 @@ def _filter_lower( if (not 
self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift else: # we can safely return the final xyz_scatter filled with zero directly - return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) + return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), self.filter_precision) # natom x nei_type_i x out_size xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below @@ -747,6 +746,7 @@ def _filter_lower( return tf.matmul(tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True) + @cast_precision def _filter( self, inputs, @@ -759,8 +759,6 @@ def _filter( name='linear', reuse=None, trainable = True): - if self.mixed_prec is not None: - inputs = tf.cast(inputs, get_precision(self.mixed_prec['compute_prec'])) nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0] # natom x (nei x 4) shape = inputs.get_shape().as_list() diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index 429590c07f..79ca850c83 100644 --- a/deepmd/descriptor/se_r.py +++ b/deepmd/descriptor/se_r.py @@ -2,7 +2,7 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_NP_FLOAT_PRECISION @@ -392,7 +392,7 @@ def _pass_filter(self, [ 0, start_index* self.ndescrpt], [-1, natoms[2+type_i]* self.ndescrpt] ) inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = 
self.filter_activation_fn) + layer = self._filter_r(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) output.append(layer) start_index += natoms[2+type_i] @@ -400,7 +400,7 @@ def _pass_filter(self, inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 - layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) + layer = self._filter_r(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) output.append(layer) output = tf.concat(output, axis = 1) @@ -450,7 +450,7 @@ def _compute_std (self,sumv2, sumv, sumn) : val = 1e-2 return val - + @cast_precision def _filter_r(self, inputs, type_input, @@ -495,7 +495,7 @@ def _filter_r(self, xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1])) else: natom = tf.shape(inputs)[0] - xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) + xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), self.filter_precision) xyz_scatter_total.append(xyz_scatter) # natom x nei x outputs_size diff --git a/deepmd/descriptor/se_t.py b/deepmd/descriptor/se_t.py index 84ce07eddf..d7dd845cb6 100644 --- a/deepmd/descriptor/se_t.py +++ b/deepmd/descriptor/se_t.py @@ -2,7 +2,7 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT,
PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_NP_FLOAT_PRECISION @@ -448,7 +448,7 @@ def _pass_filter(self, inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) + layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) # qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3]) output.append(layer) @@ -509,7 +509,7 @@ def _compute_std (self,sumv2, sumv, sumn) : val = 1e-2 return val - + @cast_precision def _filter(self, inputs, type_input, diff --git a/deepmd/fit/dipole.py b/deepmd/fit/dipole.py index 56dc777631..44e84dac1f 100644 --- a/deepmd/fit/dipole.py +++ b/deepmd/fit/dipole.py @@ -3,16 +3,17 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.utils.network import one_layer, one_layer_rand_seed_shift from deepmd.utils.graph import get_fitting_net_variables from deepmd.descriptor import DescrptSeA +from deepmd.fit.fitting import Fitting from deepmd.env import global_cvt_2_tf_float from deepmd.env import GLOBAL_TF_FLOAT_PRECISION -class DipoleFittingSeA () : +class DipoleFittingSeA (Fitting) : """ Fit the atomic dipole with 
descriptor se_a @@ -90,6 +91,7 @@ def get_out_size(self) -> int: """ return 3 + @cast_precision def build (self, input_d : tf.Tensor, rot_mat : tf.Tensor, @@ -121,7 +123,7 @@ def build (self, The atomic dipole. """ start_index = 0 - inputs = tf.cast(tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]) rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]]) count = 0 @@ -163,7 +165,7 @@ def build (self, count += 1 tf.summary.histogram('fitting_net_output', outs) - return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) + return tf.reshape(outs, [-1]) # return tf.reshape(outs, [tf.shape(inputs)[0] * natoms[0] * 3 // 3]) def init_variables(self, diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index f82b68f41b..dca4e0d353 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -3,18 +3,17 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import ClassArg, add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.utils.network import one_layer, one_layer_rand_seed_shift -from deepmd.descriptor import DescrptLocFrame -from deepmd.descriptor import DescrptSeA from deepmd.utils.type_embed import embed_atom_type from deepmd.utils.graph import get_fitting_net_variables, load_graph_def, get_tensor_by_name_from_graph +from deepmd.fit.fitting import Fitting from deepmd.env import global_cvt_2_tf_float from deepmd.env import GLOBAL_TF_FLOAT_PRECISION -class EnerFitting (): +class EnerFitting (Fitting): r"""Fitting the energy of the system. The force and the virial can also be trained. 
The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`: @@ -132,7 +131,7 @@ def __init__ (self, self.atom_ener = [] for at, ae in enumerate(atom_ener): if ae is not None: - self.atom_ener.append(tf.constant(ae, GLOBAL_TF_FLOAT_PRECISION, name = "atom_%d_ener" % at)) + self.atom_ener.append(tf.constant(ae, self.fitting_precision, name = "atom_%d_ener" % at)) else: self.atom_ener.append(None) self.useBN = False @@ -329,7 +328,7 @@ def _build_lower( return final_layer - + @cast_precision def build (self, inputs : tf.Tensor, natoms : tf.Tensor, @@ -399,10 +398,10 @@ def build (self, trainable = False, initializer = tf.constant_initializer(self.aparam_inv_std)) - inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]) if len(self.atom_ener): # only for atom_ener - inputs_zero = tf.zeros_like(inputs, dtype=GLOBAL_TF_FLOAT_PRECISION) + inputs_zero = tf.zeros_like(inputs, dtype=self.fitting_precision) if bias_atom_e is not None : @@ -468,7 +467,7 @@ def build (self, axis=1 ) self.dim_descrpt = self.dim_descrpt + type_shape[1] - inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]) final_layer = self._build_lower( 0, natoms[0], inputs, fparam, aparam, @@ -485,7 +484,7 @@ def build (self, outs = tf.reshape(outs, [-1]) tf.summary.histogram('fitting_net_output', outs) - return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) + return tf.reshape(outs, [-1]) def init_variables(self, diff --git a/deepmd/fit/fitting.py b/deepmd/fit/fitting.py new file mode 100644 index 0000000000..69c2c96e52 --- /dev/null +++ b/deepmd/fit/fitting.py @@ -0,0 +1,8 @@ +from deepmd.env import tf +from deepmd.utils import Plugin, PluginVariant + +class Fitting: + @property + def precision(self) -> tf.DType: + """Precision of fitting 
network.""" + return self.fitting_precision diff --git a/deepmd/fit/polar.py b/deepmd/fit/polar.py index 5f6ddd7525..e5632fadbb 100644 --- a/deepmd/fit/polar.py +++ b/deepmd/fit/polar.py @@ -3,12 +3,13 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import add_data_requirement, cast_precision, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter from deepmd.utils.argcheck import list_to_doc from deepmd.utils.network import one_layer, one_layer_rand_seed_shift from deepmd.utils.graph import get_fitting_net_variables from deepmd.descriptor import DescrptLocFrame from deepmd.descriptor import DescrptSeA +from deepmd.fit.fitting import Fitting from deepmd.env import global_cvt_2_tf_float from deepmd.env import GLOBAL_TF_FLOAT_PRECISION @@ -102,7 +103,7 @@ def build (self, return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) -class PolarFittingSeA () : +class PolarFittingSeA (Fitting) : """ Fit the atomic polarizability with descriptor se_a """ @@ -274,6 +275,7 @@ def compute_input_stats(self, for itype in range(len(self.sel_type)): self.constant_matrix[itype] = np.mean(np.diagonal(atom_polar[itype].reshape((3,3)))) + @cast_precision def build (self, input_d : tf.Tensor, rot_mat : tf.Tensor, @@ -305,7 +307,7 @@ def build (self, The atomic polarizability """ start_index = 0 - inputs = tf.cast(tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]) rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]]) count = 0 @@ -372,7 +374,7 @@ def build (self, count += 1 tf.summary.histogram('fitting_net_output', outs) - return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) + return tf.reshape(outs, [-1]) def init_variables(self, model_file: str