Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions deepmd/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import yaml

from deepmd.env import op_module, tf
from tensorflow.python.framework import tensor_util
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION, GLOBAL_NP_FLOAT_PRECISION
from deepmd.utils.sess import run_sess
from deepmd.utils.errors import GraphWithoutTensorError
Expand Down Expand Up @@ -487,3 +488,81 @@ def get_np_precision(precision: "_PRECISION") -> np.dtype:
return np.float64
else:
raise RuntimeError(f"{precision} is not a valid precision")


def safe_cast_tensor(input: tf.Tensor,
                     from_precision: tf.DType,
                     to_precision: tf.DType) -> tf.Tensor:
    """Convert a Tensor from one precision to another precision.

    If input is not a Tensor, or its dtype is not `from_precision`, the
    method will not cast it and returns it unchanged.

    Parameters
    ----------
    input : tf.Tensor
        input tensor
    from_precision : tf.DType
        Tensor data type that is casted from
    to_precision : tf.DType
        Tensor data type that casts to

    Returns
    -------
    tf.Tensor
        casted Tensor
    """
    if tensor_util.is_tensor(input) and input.dtype == from_precision:
        return tf.cast(input, to_precision)
    return input


def cast_precision(func: Callable) -> Callable:
    """A decorator that casts and casts back the input
    and output tensor of a method.

    The decorator should be used on an instance method whose class
    (or instance) provides a ``precision`` property of type `tf.DType`.

    The decorator will do the following thing:
    (1) It casts input Tensors from `GLOBAL_TF_FLOAT_PRECISION`
    to precision defined by property `precision`.
    (2) It casts output Tensors from `precision` to
    `GLOBAL_TF_FLOAT_PRECISION`.
    (3) It checks inputs and outputs and only casts when
    input or output is a Tensor and its dtype matches
    `GLOBAL_TF_FLOAT_PRECISION` and `precision`, respectively.
    If it does not match (e.g. it is an integer), the decorator
    will do nothing on it.

    Parameters
    ----------
    func : Callable
        the method to decorate; it must be defined in a class that
        has a ``precision`` property

    Returns
    -------
    Callable
        a method that casts and casts back the input and
        output tensor of the decorated method

    Examples
    --------
    >>> class A:
    ...     @property
    ...     def precision(self):
    ...         return tf.float32
    ...
    ...     @cast_precision
    ...     def f(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
    ...         return x ** 2 + y
    """
    import functools

    @functools.wraps(func)  # preserve the wrapped method's name/docstring
    def wrapper(self, *args, **kwargs):
        # Only Tensors with dtype GLOBAL_TF_FLOAT_PRECISION are cast;
        # non-Tensor arguments (ints, None, ...) pass through unchanged.
        returned_tensor = func(
            self,
            *[safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision)
              for vv in args],
            **{kk: safe_cast_tensor(vv, GLOBAL_TF_FLOAT_PRECISION, self.precision)
               for kk, vv in kwargs.items()},
        )
        # Cast outputs back to the global precision; a tuple return value
        # is handled element-wise.
        if isinstance(returned_tensor, tuple):
            return tuple(
                safe_cast_tensor(vv, self.precision, GLOBAL_TF_FLOAT_PRECISION)
                for vv in returned_tensor
            )
        else:
            return safe_cast_tensor(
                returned_tensor, self.precision, GLOBAL_TF_FLOAT_PRECISION)
    return wrapper
5 changes: 5 additions & 0 deletions deepmd/descriptor/se.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,8 @@ def init_variables(self,
The suffix of the scope
"""
self.embedding_net_variables = get_embedding_net_variables(model_file, suffix = suffix)

@property
def precision(self) -> tf.DType:
    """Precision of filter network.

    Returns
    -------
    tf.DType
        the value of ``self.filter_precision`` (set by the descriptor's
        constructor — presumably from the ``precision`` argument; confirm
        against the full class, which is not visible here).
    """
    return self.filter_precision
12 changes: 5 additions & 7 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Tuple, List, Dict, Any

from deepmd.env import tf
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision
from deepmd.utils.argcheck import list_to_doc
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
Expand Down Expand Up @@ -558,7 +558,7 @@ def _pass_filter(self,
[ 0, start_index* self.ndescrpt],
[-1, natoms[2+type_i]* self.ndescrpt] )
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
output.append(layer)
Expand All @@ -568,7 +568,7 @@ def _pass_filter(self,
inputs_i = inputs
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding)
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
output.append(layer)
Expand Down Expand Up @@ -704,7 +704,6 @@ def _filter_lower(
# with (natom x nei_type_i) x 1
xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1])
if type_embedding is not None:
type_embedding = tf.cast(type_embedding, self.filter_precision)
xyz_scatter = self._concat_type_embedding(
xyz_scatter, nframes, natoms, type_embedding)
if self.compress:
Expand Down Expand Up @@ -736,7 +735,7 @@ def _filter_lower(
if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift
else:
# we can safely return the final xyz_scatter filled with zero directly
return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION)
return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), self.filter_precision)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cast the result back to global precision before return

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any cast in line 722 or line 747

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there something wrong?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any cast in line 722 or line 747

747 should be cast back...

Why do we need to cast at L722?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of _filter_lower, I think we should cast back at line 839 before _filter returns.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with you

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not writing a decorator for doing that? It casts the inputs of _filter to filter_precision and casts back to global precision when _filter returns

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea!

# natom x nei_type_i x out_size
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1]))
# When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below
Expand All @@ -747,6 +746,7 @@ def _filter_lower(
return tf.matmul(tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True)


@cast_precision
def _filter(
self,
inputs,
Expand All @@ -759,8 +759,6 @@ def _filter(
name='linear',
reuse=None,
trainable = True):
if self.mixed_prec is not None:
inputs = tf.cast(inputs, get_precision(self.mixed_prec['compute_prec']))
nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0]
# natom x (nei x 4)
shape = inputs.get_shape().as_list()
Expand Down
10 changes: 5 additions & 5 deletions deepmd/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Tuple, List

from deepmd.env import tf
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision
from deepmd.utils.argcheck import list_to_doc
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
Expand Down Expand Up @@ -392,15 +392,15 @@ def _pass_filter(self,
[ 0, start_index* self.ndescrpt],
[-1, natoms[2+type_i]* self.ndescrpt] )
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = self._filter_r(self.filter_precision, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The first variable should be "inputs_i" instead of “self.filter_precision”. It couldn't pass UT but it was merged into devel.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, this line is not covered by UT.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's covered by source/tests/test_model_compression_se_r.py, a UT just added in PR #1361.

layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
output.append(layer)
start_index += natoms[2+type_i]
else :
inputs_i = inputs
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = self._filter_r(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
output.append(layer)
output = tf.concat(output, axis = 1)
Expand Down Expand Up @@ -450,7 +450,7 @@ def _compute_std (self,sumv2, sumv, sumn) :
val = 1e-2
return val


@cast_precision
def _filter_r(self,
inputs,
type_input,
Expand Down Expand Up @@ -495,7 +495,7 @@ def _filter_r(self,
xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1]))
else:
natom = tf.shape(inputs)[0]
xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION)
xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), self.filter_precision)
xyz_scatter_total.append(xyz_scatter)

# natom x nei x outputs_size
Expand Down
6 changes: 3 additions & 3 deletions deepmd/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Tuple, List

from deepmd.env import tf
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision
from deepmd.utils.argcheck import list_to_doc
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_NP_FLOAT_PRECISION
Expand Down Expand Up @@ -448,7 +448,7 @@ def _pass_filter(self,
inputs_i = inputs
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
type_i = -1
layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn)
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()])
# qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3])
output.append(layer)
Expand Down Expand Up @@ -509,7 +509,7 @@ def _compute_std (self,sumv2, sumv, sumn) :
val = 1e-2
return val


@cast_precision
def _filter(self,
inputs,
type_input,
Expand Down
10 changes: 6 additions & 4 deletions deepmd/fit/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
from typing import Tuple, List

from deepmd.env import tf
from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision
from deepmd.utils.argcheck import list_to_doc
from deepmd.utils.network import one_layer, one_layer_rand_seed_shift
from deepmd.utils.graph import get_fitting_net_variables
from deepmd.descriptor import DescrptSeA
from deepmd.fit.fitting import Fitting

from deepmd.env import global_cvt_2_tf_float
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION

class DipoleFittingSeA () :
class DipoleFittingSeA (Fitting) :
"""
Fit the atomic dipole with descriptor se_a

Expand Down Expand Up @@ -90,6 +91,7 @@ def get_out_size(self) -> int:
"""
return 3

@cast_precision
def build (self,
input_d : tf.Tensor,
rot_mat : tf.Tensor,
Expand Down Expand Up @@ -121,7 +123,7 @@ def build (self,
The atomic dipole.
"""
start_index = 0
inputs = tf.cast(tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision)
inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])

count = 0
Expand Down Expand Up @@ -163,7 +165,7 @@ def build (self,
count += 1

tf.summary.histogram('fitting_net_output', outs)
return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION)
return tf.reshape(outs, [-1])
# return tf.reshape(outs, [tf.shape(inputs)[0] * natoms[0] * 3 // 3])

def init_variables(self,
Expand Down
19 changes: 9 additions & 10 deletions deepmd/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,17 @@
from typing import Tuple, List

from deepmd.env import tf
from deepmd.common import ClassArg, add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter
from deepmd.common import add_data_requirement, get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision
from deepmd.utils.argcheck import list_to_doc
from deepmd.utils.network import one_layer, one_layer_rand_seed_shift
from deepmd.descriptor import DescrptLocFrame
from deepmd.descriptor import DescrptSeA
from deepmd.utils.type_embed import embed_atom_type
from deepmd.utils.graph import get_fitting_net_variables, load_graph_def, get_tensor_by_name_from_graph
from deepmd.fit.fitting import Fitting

from deepmd.env import global_cvt_2_tf_float
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION

class EnerFitting ():
class EnerFitting (Fitting):
r"""Fitting the energy of the system. The force and the virial can also be trained.

The potential energy :math:`E` is a fitting network function of the descriptor :math:`\mathcal{D}`:
Expand Down Expand Up @@ -132,7 +131,7 @@ def __init__ (self,
self.atom_ener = []
for at, ae in enumerate(atom_ener):
if ae is not None:
self.atom_ener.append(tf.constant(ae, GLOBAL_TF_FLOAT_PRECISION, name = "atom_%d_ener" % at))
self.atom_ener.append(tf.constant(ae, self.fitting_precision, name = "atom_%d_ener" % at))
else:
self.atom_ener.append(None)
self.useBN = False
Expand Down Expand Up @@ -329,7 +328,7 @@ def _build_lower(
return final_layer



@cast_precision
def build (self,
inputs : tf.Tensor,
natoms : tf.Tensor,
Expand Down Expand Up @@ -399,10 +398,10 @@ def build (self,
trainable = False,
initializer = tf.constant_initializer(self.aparam_inv_std))

inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision)
inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
if len(self.atom_ener):
# only for atom_ener
inputs_zero = tf.zeros_like(inputs, dtype=GLOBAL_TF_FLOAT_PRECISION)
inputs_zero = tf.zeros_like(inputs, dtype=self.fitting_precision)


if bias_atom_e is not None :
Expand Down Expand Up @@ -468,7 +467,7 @@ def build (self,
axis=1
)
self.dim_descrpt = self.dim_descrpt + type_shape[1]
inputs = tf.cast(tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]), self.fitting_precision)
inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
final_layer = self._build_lower(
0, natoms[0],
inputs, fparam, aparam,
Expand All @@ -485,7 +484,7 @@ def build (self,
outs = tf.reshape(outs, [-1])

tf.summary.histogram('fitting_net_output', outs)
return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION)
return tf.reshape(outs, [-1])


def init_variables(self,
Expand Down
8 changes: 8 additions & 0 deletions deepmd/fit/fitting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from deepmd.env import tf
from deepmd.utils import Plugin, PluginVariant

class Fitting:
    """Shared base class for fitting networks.

    Subclasses are expected to define ``self.fitting_precision``
    (a ``tf.DType``), which this base class exposes through the
    ``precision`` property.
    """

    @property
    def precision(self) -> tf.DType:
        """Precision of fitting network.

        Returns
        -------
        tf.DType
            the value of ``self.fitting_precision``.
        """
        fitting_dtype = self.fitting_precision
        return fitting_dtype
Loading