From 34f6456e6f4aa88090c9c022ada6890b617e3c1b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 00:55:05 -0500 Subject: [PATCH 01/10] fix filter network precision --- deepmd/descriptor/se_a.py | 6 +++--- deepmd/descriptor/se_r.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 3b8790911f..c7893473cf 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -736,7 +736,7 @@ def _filter_lower( if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift else: # we can safely return the final xyz_scatter filled with zero directly - return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) + return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), self.filter_precision) # natom x nei_type_i x out_size xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below @@ -774,8 +774,8 @@ def _filter( # result: natom x outputs_size x outputs_size_2 # qmat: natom x outputs_size x 3 natom = tf.shape(inputs)[0] - result = tf.cast(tf.fill((natom, outputs_size_2, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) - qmat = tf.cast(tf.fill((natom, outputs_size[-1], 3), 0.), GLOBAL_TF_FLOAT_PRECISION) + result = tf.cast(tf.fill((natom, outputs_size_2, outputs_size[-1]), 0.), self.filter_precision) + qmat = tf.cast(tf.fill((natom, outputs_size[-1], 3), 0.), self.filter_precision) return result, qmat with tf.variable_scope(name, reuse=reuse): diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index 429590c07f..9f91d664ca 100644 --- a/deepmd/descriptor/se_r.py +++ b/deepmd/descriptor/se_r.py @@ -495,7 +495,7 @@ def _filter_r(self, xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1])) else: natom = tf.shape(inputs)[0] - xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), 
GLOBAL_TF_FLOAT_PRECISION) + xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), self.filter_precision) xyz_scatter_total.append(xyz_scatter) # natom x nei x outputs_size From 5aab37843a13e0a6b0bfd810dfe8543c227eea75 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 01:10:20 -0500 Subject: [PATCH 02/10] fix fitting energy precision --- deepmd/fit/ener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index f82b68f41b..319b0a5776 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -402,7 +402,7 @@ def build (self, inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) if len(self.atom_ener): # only for atom_ener - inputs_zero = tf.zeros_like(inputs, dtype=GLOBAL_TF_FLOAT_PRECISION) + inputs_zero = tf.zeros_like(inputs, dtype=self.fitting_precision) if bias_atom_e is not None : From 3f1199ed24aa4ac4b48849fd6c4d4ac6bb0323c2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 01:15:07 -0500 Subject: [PATCH 03/10] fix one more precision --- deepmd/fit/ener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 319b0a5776..6820036c91 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -132,7 +132,7 @@ def __init__ (self, self.atom_ener = [] for at, ae in enumerate(atom_ener): if ae is not None: - self.atom_ener.append(tf.constant(ae, GLOBAL_TF_FLOAT_PRECISION, name = "atom_%d_ener" % at)) + self.atom_ener.append(tf.constant(ae, self.fitting_precision, name = "atom_%d_ener" % at)) else: self.atom_ener.append(None) self.useBN = False From 8bed5cee88fdda368d279bb8c890f9e139cce95b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 19:46:48 -0500 Subject: [PATCH 04/10] cast back precision --- deepmd/descriptor/se_a.py | 6 ++++-- deepmd/descriptor/se_r.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git 
a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index c7893473cf..7e121f3391 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -774,8 +774,8 @@ def _filter( # result: natom x outputs_size x outputs_size_2 # qmat: natom x outputs_size x 3 natom = tf.shape(inputs)[0] - result = tf.cast(tf.fill((natom, outputs_size_2, outputs_size[-1]), 0.), self.filter_precision) - qmat = tf.cast(tf.fill((natom, outputs_size[-1], 3), 0.), self.filter_precision) + result = tf.cast(tf.fill((natom, outputs_size_2, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) + qmat = tf.cast(tf.fill((natom, outputs_size[-1], 3), 0.), GLOBAL_TF_FLOAT_PRECISION) return result, qmat with tf.variable_scope(name, reuse=reuse): @@ -835,5 +835,7 @@ def _filter( result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a = True) # natom x (outputs_size x outputs_size_2) result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) + result = tf.cast(result, GLOBAL_TF_FLOAT_PRECISION) + qmat = tf.cast(qmat, GLOBAL_TF_FLOAT_PRECISION) return result, qmat diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index 9f91d664ca..eb56e9529e 100644 --- a/deepmd/descriptor/se_r.py +++ b/deepmd/descriptor/se_r.py @@ -504,5 +504,6 @@ def _filter_r(self, # res_rescale = 1./5. 
result = tf.reduce_mean(xyz_scatter, axis = 1) * res_rescale + result = tf.cast(result, GLOBAL_TF_FLOAT_PRECISION) return result From c926a053be45458f006508350974f0e3d3aff874 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 21:59:14 -0500 Subject: [PATCH 05/10] add a decorator to cast precision --- deepmd/common.py | 72 +++++++++++++++++++++++++++++++++++++++ deepmd/descriptor/se.py | 5 +++ deepmd/descriptor/se_a.py | 12 +++---- deepmd/descriptor/se_r.py | 9 +++-- deepmd/descriptor/se_t.py | 6 ++-- deepmd/fit/dipole.py | 10 +++--- deepmd/fit/ener.py | 15 ++++---- deepmd/fit/fitting.py | 8 +++++ deepmd/fit/polar.py | 10 +++--- 9 files changed, 115 insertions(+), 32 deletions(-) create mode 100644 deepmd/fit/fitting.py diff --git a/deepmd/common.py b/deepmd/common.py index 9968cff39c..b44455676f 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -487,3 +487,75 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype: return np.float64 else: raise RuntimeError(f"{precision} is not a valid precision") + + +def cast_tensor(input: tf.Tensor, + from_precision: tf.DTypes, + to_precision: tf.DTypes) -> tf.Tensor: + """Convert a Tensor from a precision to another precision. + + Parameters + ---------- + input: tf.Tensor + input tensor + precision : tf.DType + Tensor data type that casts to + + Returns + ------- + tf.Tensor + casted Tensor + + Notes + ----- + If input is not a Tensor or without the specific precision, the method will not + cast it. + """ + if tf.is_tensor(input) and input.dtype == from_precision: + return tf.cast(input, to_precision) + return input + + +def cast_precision(func: Callable) -> Callable: + """A decorator that casts and casts back the input + and output tensor of a method. 
+ + Parameters + ---------- + precision : tf.DType + Tensor data type that casts to + + Returns + ------- + Callable + a decorator that casts and casts back the input and + output tensor of a method + + Notes + ----- + The decorator should only be used in a classmethod where + the class has the property `precision`. The decorator will + only cast Tensors with global precision. + + Examples + -------- + >>> class A: + ... @property + ... def precision(self): + ... return tf.float32 + ... + ... @cast_precision + ... def f(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: + ... return x ** 2 + y + """ + def wrapper(self, *args, **kwargs): + # only convert tensors + returned_tensor = func( + *[cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for vv in args], + **{kk: cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for kk, vv in kwargs.items()}, + ) + if isinstance(returned_tensor, tuple): + return tuple((cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION) for vv in returned_tensor)) + else: + return cast_tensor(returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION) + return wrapper diff --git a/deepmd/descriptor/se.py b/deepmd/descriptor/se.py index c7470568b6..c42a7fb46d 100644 --- a/deepmd/descriptor/se.py +++ b/deepmd/descriptor/se.py @@ -106,3 +106,8 @@ def init_variables(self, The suffix of the scope """ self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix) + + @property + def precision(self) -> tf.DType: + """Precision of filter network.""" + return self.filter_precision diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 7e121f3391..db2980b882 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -3,7 +3,7 @@ from typing import Tuple, List, Dict, Any from deepmd.env import tf -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import get_activation_func, get_precision, 
ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_NP_FLOAT_PRECISION @@ -558,7 +558,7 @@ def _pass_filter(self, [ 0, start_index* self.ndescrpt], [-1, natoms[2+type_i]* self.ndescrpt] ) inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) + layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3]) output.append(layer) @@ -568,7 +568,7 @@ def _pass_filter(self, inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) + layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3]) output.append(layer) @@ -704,7 +704,6 @@ def _filter_lower( # with (natom x nei_type_i) x 1 xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1]) if type_embedding is not None: - type_embedding = tf.cast(type_embedding, self.filter_precision) xyz_scatter = 
self._concat_type_embedding( xyz_scatter, nframes, natoms, type_embedding) if self.compress: @@ -747,6 +746,7 @@ def _filter_lower( return tf.matmul(tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True) + @cast_precision def _filter( self, inputs, @@ -759,8 +759,6 @@ def _filter( name='linear', reuse=None, trainable = True): - if self.mixed_prec is not None: - inputs = tf.cast(inputs, get_precision(self.mixed_prec['compute_prec'])) nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0] # natom x (nei x 4) shape = inputs.get_shape().as_list() @@ -835,7 +833,5 @@ def _filter( result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a = True) # natom x (outputs_size x outputs_size_2) result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) - result = tf.cast(result, GLOBAL_TF_FLOAT_PRECISION) - qmat = tf.cast(qmat, GLOBAL_TF_FLOAT_PRECISION) return result, qmat diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index eb56e9529e..79ca850c83 100644 --- a/deepmd/descriptor/se_r.py +++ b/deepmd/descriptor/se_r.py @@ -2,7 +2,7 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_NP_FLOAT_PRECISION @@ -392,7 +392,7 @@ def _pass_filter(self, [ 0, start_index* self.ndescrpt], [-1, natoms[2+type_i]* self.ndescrpt] ) inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) - layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) + layer = self._filter_r(inputs_i, 
type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) output.append(layer) start_index += natoms[2+type_i] @@ -400,7 +400,7 @@ def _pass_filter(self, inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 - layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) + layer = self._filter_r(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) output.append(layer) output = tf.concat(output, axis = 1) @@ -450,7 +450,7 @@ def _compute_std (self,sumv2, sumv, sumn) : val = 1e-2 return val - + @cast_precision def _filter_r(self, inputs, type_input, @@ -504,6 +504,5 @@ def _filter_r(self, # res_rescale = 1./5. 
result = tf.reduce_mean(xyz_scatter, axis = 1) * res_rescale - result = tf.cast(result, GLOBAL_TF_FLOAT_PRECISION) return result diff --git a/deepmd/descriptor/se_t.py b/deepmd/descriptor/se_t.py index 84ce07eddf..d7dd845cb6 100644 --- a/deepmd/descriptor/se_t.py +++ b/deepmd/descriptor/se_t.py @@ -2,7 +2,7 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_NP_FLOAT_PRECISION @@ -448,7 +448,7 @@ def _pass_filter(self, inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 - layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) + layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) # qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3]) output.append(layer) @@ -509,7 +509,7 @@ def _compute_std (self,sumv2, sumv, sumn) : val = 1e-2 return val - + @cast_precision def _filter(self, inputs, type_input, diff --git a/deepmd/fit/dipole.py b/deepmd/fit/dipole.py index 56dc777631..44e84dac1f 100644 --- a/deepmd/fit/dipole.py +++ b/deepmd/fit/dipole.py @@ -3,16 +3,17 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import 
add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.utils.network import one_layer, one_layer_rand_seed_shift from deepmd.utils.graph import get_fitting_net_variables from deepmd.descriptor import DescrptSeA +from deepmd.fit.fitting import Fitting from deepmd.env import global_cvt_2_tf_float from deepmd.env import GLOBAL_TF_FLOAT_PRECISION -class DipoleFittingSeA () : +class DipoleFittingSeA (Fitting) : """ Fit the atomic dipole with descriptor se_a @@ -90,6 +91,7 @@ def get_out_size(self) -> int: """ return 3 + @cast_precision def build (self, input_d : tf.Tensor, rot_mat : tf.Tensor, @@ -121,7 +123,7 @@ def build (self, The atomic dipole. """ start_index = 0 - inputs = tf.cast(tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]) rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]]) count = 0 @@ -163,7 +165,7 @@ def build (self, count += 1 tf.summary.histogram('fitting_net_output', outs) - return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) + return tf.reshape(outs, [-1]) # return tf.reshape(outs, [tf.shape(inputs)[0] * natoms[0] * 3 // 3]) def init_variables(self, diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 6820036c91..dca4e0d353 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -3,18 +3,17 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import ClassArg, add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision from deepmd.utils.argcheck import list_to_doc from deepmd.utils.network import one_layer, one_layer_rand_seed_shift -from 
deepmd.descriptor import DescrptLocFrame -from deepmd.descriptor import DescrptSeA from deepmd.utils.type_embed import embed_atom_type from deepmd.utils.graph import get_fitting_net_variables, load_graph_def, get_tensor_by_name_from_graph +from deepmd.fit.fitting import Fitting from deepmd.env import global_cvt_2_tf_float from deepmd.env import GLOBAL_TF_FLOAT_PRECISION -class EnerFitting (): +class EnerFitting (Fitting): r"""Fitting the energy of the system. The force and the virial can also be trained. The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`: @@ -329,7 +328,7 @@ def _build_lower( return final_layer - + @cast_precision def build (self, inputs : tf.Tensor, natoms : tf.Tensor, @@ -399,7 +398,7 @@ def build (self, trainable = False, initializer = tf.constant_initializer(self.aparam_inv_std)) - inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]) if len(self.atom_ener): # only for atom_ener inputs_zero = tf.zeros_like(inputs, dtype=self.fitting_precision) @@ -468,7 +467,7 @@ def build (self, axis=1 ) self.dim_descrpt = self.dim_descrpt + type_shape[1] - inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]) final_layer = self._build_lower( 0, natoms[0], inputs, fparam, aparam, @@ -485,7 +484,7 @@ def build (self, outs = tf.reshape(outs, [-1]) tf.summary.histogram('fitting_net_output', outs) - return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) + return tf.reshape(outs, [-1]) def init_variables(self, diff --git a/deepmd/fit/fitting.py b/deepmd/fit/fitting.py new file mode 100644 index 0000000000..2b10705273 --- /dev/null +++ b/deepmd/fit/fitting.py @@ -0,0 +1,8 @@ +from deepmd.env import tf +from deepmd.utils import Plugin, PluginVariant + +class Fitting: + @property + def 
precision(self) -> tf.DType: + """Precision of fitting network.""" + return self.fitting_precision \ No newline at end of file diff --git a/deepmd/fit/polar.py b/deepmd/fit/polar.py index 5f6ddd7525..e5632fadbb 100644 --- a/deepmd/fit/polar.py +++ b/deepmd/fit/polar.py @@ -3,12 +3,13 @@ from typing import Tuple, List from deepmd.env import tf -from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter +from deepmd.common import add_data_requirement, cast_precision, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter from deepmd.utils.argcheck import list_to_doc from deepmd.utils.network import one_layer, one_layer_rand_seed_shift from deepmd.utils.graph import get_fitting_net_variables from deepmd.descriptor import DescrptLocFrame from deepmd.descriptor import DescrptSeA +from deepmd.fit.fitting import Fitting from deepmd.env import global_cvt_2_tf_float from deepmd.env import GLOBAL_TF_FLOAT_PRECISION @@ -102,7 +103,7 @@ def build (self, return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) -class PolarFittingSeA () : +class PolarFittingSeA (Fitting) : """ Fit the atomic polarizability with descriptor se_a """ @@ -274,6 +275,7 @@ def compute_input_stats(self, for itype in range(len(self.sel_type)): self.constant_matrix[itype] = np.mean(np.diagonal(atom_polar[itype].reshape((3,3)))) + @cast_precision def build (self, input_d : tf.Tensor, rot_mat : tf.Tensor, @@ -305,7 +307,7 @@ def build (self, The atomic polarizability """ start_index = 0 - inputs = tf.cast(tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision) + inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]) rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]]) count = 0 @@ -372,7 +374,7 @@ def build (self, count += 1 tf.summary.histogram('fitting_net_output', outs) - return tf.cast(tf.reshape(outs, [-1]), 
GLOBAL_TF_FLOAT_PRECISION) + return tf.reshape(outs, [-1]) def init_variables(self, model_file: str From 9dadabb3a8d1613ff6f66cd1b6ca9d88617e3eac Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 22:04:39 -0500 Subject: [PATCH 06/10] fix typo --- deepmd/common.py | 4 ++-- deepmd/fit/fitting.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index b44455676f..ae988d2c88 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -490,8 +490,8 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype: def cast_tensor(input: tf.Tensor, - from_precision: tf.DTypes, - to_precision: tf.DTypes) -> tf.Tensor: + from_precision: tf.DType, + to_precision: tf.DType) -> tf.Tensor: """Convert a Tensor from a precision to another precision. Parameters diff --git a/deepmd/fit/fitting.py b/deepmd/fit/fitting.py index 2b10705273..69c2c96e52 100644 --- a/deepmd/fit/fitting.py +++ b/deepmd/fit/fitting.py @@ -5,4 +5,4 @@ class Fitting: @property def precision(self) -> tf.DType: """Precision of fitting network.""" - return self.fitting_precision \ No newline at end of file + return self.fitting_precision From e84a588fdf0a1344d44b8e88da36f16b005bb317 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 22:27:56 -0500 Subject: [PATCH 07/10] fix compatibility issue with TF 1.8 --- deepmd/common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/common.py b/deepmd/common.py index ae988d2c88..704e8e2404 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -20,6 +20,7 @@ import yaml from deepmd.env import op_module, tf +from tensorflow.python.framework import tensor_util from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION from deepmd.utils.sess import run_sess from deepmd.utils.errors import GraphWithoutTensorError @@ -511,7 +512,7 @@ def cast_tensor(input: tf.Tensor, If input is not a Tensor or without the specific precision, the method will not cast it. 
""" - if tf.is_tensor(input) and input.dtype == from_precision: + if tensor_util.is_tensor(input) and input.dtype == from_precision: return tf.cast(input, to_precision) return input From 1d4029e775d06c8349046d6d5e79c795dcaec23f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 22:42:04 -0500 Subject: [PATCH 08/10] fix missing self --- deepmd/common.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/common.py b/deepmd/common.py index 704e8e2404..e9f2dbbc08 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -552,6 +552,7 @@ def cast_precision(func: Callable) -> Callable: def wrapper(self, *args, **kwargs): # only convert tensors returned_tensor = func( + self, *[cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for vv in args], **{kk: cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for kk, vv in kwargs.items()}, ) From 9e3ff6ecef077768e2e35567c3fdc11564484adb Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 23:04:58 -0500 Subject: [PATCH 09/10] improve docstrings --- deepmd/common.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index e9f2dbbc08..b5eed11312 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -490,10 +490,13 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype: raise RuntimeError(f"{precision} is not a valid precision") -def cast_tensor(input: tf.Tensor, +def safe_cast_tensor(input: tf.Tensor, from_precision: tf.DType, to_precision: tf.DType) -> tf.Tensor: """Convert a Tensor from a precision to another precision. + + If input is not a Tensor or without the specific precision, the method will not + cast it. Parameters ---------- @@ -506,11 +509,6 @@ def cast_tensor(input: tf.Tensor, ------- tf.Tensor casted Tensor - - Notes - ----- - If input is not a Tensor or without the specific precision, the method will not - cast it. 
""" if tensor_util.is_tensor(input) and input.dtype == from_precision: return tf.cast(input, to_precision) @@ -520,6 +518,19 @@ def cast_tensor(input: tf.Tensor, def cast_precision(func: Callable) -> Callable: """A decorator that casts and casts back the input and output tensor of a method. + + The decorator should be used in a classmethod. + + The decorator will do the following thing: + (1) It casts input Tensors from `GLOBAL_TF_FLOAT_PRECISION` + to precision defined by property `precision`. + (2) It casts output Tensors from `precision` to + `GLOBAL_TF_FLOAT_PRECISION`. + (3) It checks inputs and outputs and only casts when + input or output is a Tensor and its dtype matches + `GLOBAL_TF_FLOAT_PRECISION` and `precision`, respectively. + If it does not match (e.g. it is an integer), the decorator + will do nothing on it. Parameters ---------- @@ -532,12 +543,6 @@ def cast_precision(func: Callable) -> Callable: a decorator that casts and casts back the input and output tensor of a method - Notes - ----- - The decorator should only be used in a classmethod where - the class has the property `precision`. The decorator will - only cast Tensors with global precision. 
- Examples -------- >>> class A: @@ -553,11 +558,11 @@ def wrapper(self, *args, **kwargs): # only convert tensors returned_tensor = func( self, - *[cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for vv in args], - **{kk: cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for kk, vv in kwargs.items()}, + *[safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for vv in args], + **{kk: safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision) for kk, vv in kwargs.items()}, ) if isinstance(returned_tensor, tuple): - return tuple((cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION) for vv in returned_tensor)) + return tuple((safe_cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION) for vv in returned_tensor)) else: - return cast_tensor(returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION) + return safe_cast_tensor(returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION) return wrapper From 506d003c59bd1845ca01afb0100c7a1d9099a5c9 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 30 Dec 2021 23:12:30 -0500 Subject: [PATCH 10/10] fix lint warnings --- deepmd/common.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index b5eed11312..695fee1a93 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -491,20 +491,20 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype: def safe_cast_tensor(input: tf.Tensor, - from_precision: tf.DType, - to_precision: tf.DType) -> tf.Tensor: + from_precision: tf.DType, + to_precision: tf.DType) -> tf.Tensor: """Convert a Tensor from a precision to another precision. If input is not a Tensor or without the specific precision, the method will not cast it. 
- + Parameters ---------- input: tf.Tensor input tensor precision : tf.DType Tensor data type that casts to - + Returns ------- tf.Tensor @@ -531,25 +531,25 @@ def cast_precision(func: Callable) -> Callable: `GLOBAL_TF_FLOAT_PRECISION` and `precision`, respectively. If it does not match (e.g. it is an integer), the decorator will do nothing on it. - + Parameters ---------- precision : tf.DType Tensor data type that casts to - + Returns ------- Callable a decorator that casts and casts back the input and output tensor of a method - + Examples -------- >>> class A: ... @property ... def precision(self): ... return tf.float32 - ... + ... ... @cast_precision ... def f(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: ... return x ** 2 + y