-
Notifications
You must be signed in to change notification settings - Fork 599
fix network precision under specific situation #1391
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
34f6456
5aab378
3f1199e
8bed5ce
c926a05
9dadabb
e84a588
1d4029e
9e3ff6e
506d003
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,7 @@ | |
| from typing import Tuple, List, Dict, Any | ||
|
|
||
| from deepmd.env import tf | ||
| from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter | ||
| from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision | ||
| from deepmd.utils.argcheck import list_to_doc | ||
| from deepmd.env import GLOBAL_TF_FLOAT_PRECISION | ||
| from deepmd.env import GLOBAL_NP_FLOAT_PRECISION | ||
|
|
@@ -558,7 +558,7 @@ def _pass_filter(self, | |
| [ 0, start_index* self.ndescrpt], | ||
| [-1, natoms[2+type_i]* self.ndescrpt] ) | ||
| inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) | ||
| layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) | ||
| layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) | ||
| layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) | ||
| qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3]) | ||
| output.append(layer) | ||
|
|
@@ -568,7 +568,7 @@ def _pass_filter(self, | |
| inputs_i = inputs | ||
| inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) | ||
| type_i = -1 | ||
| layer, qmat = self._filter(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) | ||
| layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) | ||
| layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) | ||
| qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0] * self.get_dim_rot_mat_1() * 3]) | ||
| output.append(layer) | ||
|
|
@@ -704,7 +704,6 @@ def _filter_lower( | |
| # with (natom x nei_type_i) x 1 | ||
| xyz_scatter = tf.reshape(tf.slice(inputs_reshape, [0,0],[-1,1]),[-1,1]) | ||
| if type_embedding is not None: | ||
| type_embedding = tf.cast(type_embedding, self.filter_precision) | ||
| xyz_scatter = self._concat_type_embedding( | ||
| xyz_scatter, nframes, natoms, type_embedding) | ||
| if self.compress: | ||
|
|
@@ -736,7 +735,7 @@ def _filter_lower( | |
| if (not self.uniform_seed) and (self.seed is not None): self.seed += self.seed_shift | ||
| else: | ||
| # we can safely return the final xyz_scatter filled with zero directly | ||
| return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) | ||
| return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), self.filter_precision) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We cast the result back to global precision before return
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see any
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there something wrong?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Line 747 should be cast back... Why do we need to cast at line 722?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree with you
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not write a decorator for doing that? It casts the inputs of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea! |
||
| # natom x nei_type_i x out_size | ||
| xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) | ||
| # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below | ||
|
|
@@ -747,6 +746,7 @@ def _filter_lower( | |
| return tf.matmul(tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), xyz_scatter, transpose_a = True) | ||
|
|
||
|
|
||
| @cast_precision | ||
| def _filter( | ||
| self, | ||
| inputs, | ||
|
|
@@ -759,8 +759,6 @@ def _filter( | |
| name='linear', | ||
| reuse=None, | ||
| trainable = True): | ||
| if self.mixed_prec is not None: | ||
| inputs = tf.cast(inputs, get_precision(self.mixed_prec['compute_prec'])) | ||
| nframes = tf.shape(tf.reshape(inputs, [-1, natoms[0], self.ndescrpt]))[0] | ||
| # natom x (nei x 4) | ||
| shape = inputs.get_shape().as_list() | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,7 @@ | |
| from typing import Tuple, List | ||
|
|
||
| from deepmd.env import tf | ||
| from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter | ||
| from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, cast_precision | ||
| from deepmd.utils.argcheck import list_to_doc | ||
| from deepmd.env import GLOBAL_TF_FLOAT_PRECISION | ||
| from deepmd.env import GLOBAL_NP_FLOAT_PRECISION | ||
|
|
@@ -392,15 +392,15 @@ def _pass_filter(self, | |
| [ 0, start_index* self.ndescrpt], | ||
| [-1, natoms[2+type_i]* self.ndescrpt] ) | ||
| inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) | ||
| layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) | ||
| layer = self._filter_r(self.filter_precision, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The first variable should be "inputs_i" instead of “self.filter_precision”. It couldn't pass UT but it was merged into devel.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, this line is not covered by UT.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it's covered by source/tests/test_model_compression_se_r.py, a UT just added in PR #1361. |
||
| layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()]) | ||
| output.append(layer) | ||
| start_index += natoms[2+type_i] | ||
| else : | ||
| inputs_i = inputs | ||
| inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) | ||
| type_i = -1 | ||
| layer = self._filter_r(tf.cast(inputs_i, self.filter_precision), type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) | ||
| layer = self._filter_r(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn) | ||
| layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0] * self.get_dim_out()]) | ||
| output.append(layer) | ||
| output = tf.concat(output, axis = 1) | ||
|
|
@@ -450,7 +450,7 @@ def _compute_std (self,sumv2, sumv, sumn) : | |
| val = 1e-2 | ||
| return val | ||
|
|
||
|
|
||
| @cast_precision | ||
| def _filter_r(self, | ||
| inputs, | ||
| type_input, | ||
|
|
@@ -495,7 +495,7 @@ def _filter_r(self, | |
| xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1], outputs_size[-1])) | ||
| else: | ||
| natom = tf.shape(inputs)[0] | ||
| xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) | ||
| xyz_scatter = tf.cast(tf.fill((natom, shape_i[1], outputs_size[-1]), 0.), self.filter_precision) | ||
| xyz_scatter_total.append(xyz_scatter) | ||
|
|
||
| # natom x nei x outputs_size | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| from deepmd.env import tf | ||
| from deepmd.utils import Plugin, PluginVariant | ||
|
|
||
| class Fitting: | ||
| @property | ||
| def precision(self) -> tf.DType: | ||
| """Precision of fitting network.""" | ||
| return self.fitting_precision |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please write more on the logic behind `cast_precision`? It casts the input arguments from `GLOBAL_TF_FLOAT_PRECISION` to the precision defined by the property `precision`, and casts the returned tensors from `precision` back to `GLOBAL_TF_FLOAT_PRECISION`; arguments and returned tensors that are not in `GLOBAL_TF_FLOAT_PRECISION` or `precision`, respectively, are left unchanged.