From e72e27b88a60d56dad9f2dba167bead861936069 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 9 Aug 2021 21:41:02 -0400 Subject: [PATCH 1/3] do some small optimization to ops 1. avoid concat or add in loops. Instead, append tensors to a list, and concat or accumulate_n after loops 2. remove a duplicated reshape --- deepmd/descriptor/se_a.py | 14 ++++++++------ deepmd/fit/dipole.py | 7 +++---- deepmd/fit/ener.py | 10 +++++----- deepmd/fit/polar.py | 14 ++++++-------- deepmd/fit/wfc.py | 7 +++---- 5 files changed, 25 insertions(+), 27 deletions(-) diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 60de701886..f5cb803054 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -624,7 +624,7 @@ def _filter_lower( # we can safely return the final xyz_scatter filled with zero directly return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), GLOBAL_TF_FLOAT_PRECISION) # natom x nei_type_i x out_size - xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) + xyz_scatter = tf.reshape(xyz_scatter, (natom, shape_i[1]//4, outputs_size[-1])) # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below # [588 24] -> [588 6 4] correct # but if sel is zero @@ -667,6 +667,7 @@ def _filter( type_i = 0 # natom x 4 x outputs_size if type_embedding is None: + rets = [] for type_i in range(self.ntypes): ret = self._filter_lower( type_i, type_input, @@ -681,12 +682,12 @@ def _filter( bavg = bavg, trainable = trainable, suffix = "_"+str(type_i)) - if type_i == 0: - xyz_scatter_1 = ret - elif (type_input, type_i) not in self.exclude_types: + if (type_input, type_i) not in self.exclude_types: # add zero is meaningless; skip - xyz_scatter_1+= ret + rets.append(ret) start_index += self.sel_a[type_i] + # maybe faster to use accumulate_n than multiple add + xyz_scatter_1 = tf.accumulate_n(rets) else : xyz_scatter_1 = self._filter_lower( type_i, type_input, @@ -718,6 +719,7 @@ def _filter( # natom x outputs_size x outputs_size_2 result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a = True) # natom x (outputs_size x outputs_size_2) - result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) + # we have reshaped in _pass_filter method + #result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) return result, qmat diff --git a/deepmd/fit/dipole.py b/deepmd/fit/dipole.py index 73562951dc..b9193e8176 100644 --- a/deepmd/fit/dipole.py +++ b/deepmd/fit/dipole.py @@ -123,6 +123,7 @@ def build (self, rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]]) count = 0 + outs_list = [] for type_i in range(self.ntypes): # cut-out inputs inputs_i = tf.slice (inputs, @@ -154,11 +155,9 @@ def build (self, final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3]) # concat the results - if count == 0: - outs = final_layer - else: - outs = tf.concat([outs, final_layer], axis = 1) + outs_list.append(final_layer) count += 1 + outs = tf.concat(outs_list, axis = 1) tf.summary.histogram('fitting_net_output', outs) return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index 03145076cb..897a10c476 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -389,6 +389,7 @@ def build (self, if atype_embed is None: start_index = 0 + outs_list = [] for type_i in range(self.ntypes): if bias_atom_e is None : type_bias_ae = 0.0 @@ -408,12 +409,11 @@ def build (self, ) final_layer += self.atom_ener[type_i] - zero_layer final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i]]) - # concat the results - if type_i == 0: - outs = final_layer - else: - outs = tf.concat([outs, final_layer], axis = 1) + outs_list.append(final_layer) start_index += natoms[2+type_i] + # concat the results + # concat once may be faster than multiple concat + outs = tf.concat(outs_list, axis = 1) # with type embedding else: if len(self.atom_ener) > 0: diff --git a/deepmd/fit/polar.py b/deepmd/fit/polar.py index 33c5be378a..24385cbe89 100644 --- a/deepmd/fit/polar.py +++ b/deepmd/fit/polar.py @@ -55,6 +55,7 @@ def build (self, rot_mat = tf.reshape(rot_mat, [-1, 9 * natoms[0]]) count = 0 + outs_list = [] for type_i in range(self.ntypes): # cut-out inputs inputs_i = tf.slice (inputs, @@ -88,11 +89,9 @@ def build (self, final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3, 3]) # concat the results - if count == 0: - outs = final_layer - else: - outs = tf.concat([outs, final_layer], axis = 1) + outs_list.append(final_layer) count += 1 + outs = tf.concat(outs_list, axis = 1) tf.summary.histogram('fitting_net_output', outs) return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) @@ -302,6 +301,7 @@ def build (self, rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]]) count = 0 + outs_list = [] for type_i in range(self.ntypes): # cut-out inputs inputs_i = tf.slice (inputs, @@ -358,11 +358,9 @@ def build (self, final_layer = final_layer + self.constant_matrix[sel_type_idx] * tf.eye(3, batch_shape=[tf.shape(inputs)[0], natoms[2+type_i]], dtype = GLOBAL_TF_FLOAT_PRECISION) # concat the results - if count == 0: - outs = final_layer - else: - outs = tf.concat([outs, final_layer], axis = 1) + outs_list.append(final_layer) count += 1 + outs = tf.concat(outs_list, axis = 1) tf.summary.histogram('fitting_net_output', outs) return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) diff --git a/deepmd/fit/wfc.py b/deepmd/fit/wfc.py index fce82bd04b..fb3dec654e 100644 --- a/deepmd/fit/wfc.py +++ b/deepmd/fit/wfc.py @@ -62,6 +62,7 @@ def build (self, rot_mat = tf.reshape(rot_mat, [-1, 9 * natoms[0]]) count = 0 + outs_list = [] for type_i in range(self.ntypes): # cut-out inputs inputs_i = tf.slice (inputs, @@ -93,11 +94,9 @@ def build (self, final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], self.wfc_numb, 3]) # concat the results - if count == 0: - outs = final_layer - else: - outs = tf.concat([outs, final_layer], axis = 1) + outs_list.append(final_layer) count += 1 + outs = tf.concat(outs_list, axis = 1) tf.summary.histogram('fitting_net_output', outs) return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION) From 0385c47a524b548d6df1aa86550d643912dab158 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 14 Jan 2022 01:56:42 -0500 Subject: [PATCH 2/3] revert unnecessary changes --- deepmd/descriptor/se_a.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index fb11171fb3..cf218309bd 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -738,7 +738,7 @@ def _filter_lower( # we can safely return the final xyz_scatter filled with zero directly return tf.cast(tf.fill((natom, 4, outputs_size[-1]), 0.), self.filter_precision) # natom x nei_type_i x out_size - xyz_scatter = tf.reshape(xyz_scatter, (natom, shape_i[1]//4, outputs_size[-1])) + xyz_scatter = tf.reshape(xyz_scatter, (-1, shape_i[1]//4, outputs_size[-1])) # When using tf.reshape(inputs_i, [-1, shape_i[1]//4, 4]) below # [588 24] -> [588 6 4] correct # but if sel is zero @@ -802,7 +802,7 @@ def _filter( # add zero is meaningless; skip rets.append(ret) start_index += self.sel_a[type_i] - # maybe faster to use accumulate_n than multiple add + # faster to use accumulate_n than multiple add xyz_scatter_1 = tf.accumulate_n(rets) else : xyz_scatter_1 = self._filter_lower( @@ -835,7 +835,6 @@ def _filter( # natom x outputs_size x outputs_size_2 result = tf.matmul(xyz_scatter_1, xyz_scatter_2, transpose_a = True) # natom x (outputs_size x outputs_size_2) - # we have reshaped in _pass_filter method - #result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) + result = tf.reshape(result, [-1, outputs_size_2 * outputs_size[-1]]) return result, qmat From 1f05364f8accbda05f391cd99756c06c7139596b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 14 Jan 2022 01:57:28 -0500 Subject: [PATCH 3/3] revert wfc.py as it has been decrepated --- deepmd/fit/wfc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/deepmd/fit/wfc.py b/deepmd/fit/wfc.py index 09dd619285..9b6b217432 100644 --- a/deepmd/fit/wfc.py +++ b/deepmd/fit/wfc.py @@ -65,7 +65,6 @@ def build (self, rot_mat = tf.reshape(rot_mat, [-1, 9 * natoms[0]]) count = 0 - outs_list = [] for type_i in range(self.ntypes): # cut-out inputs inputs_i = tf.slice (inputs, @@ -97,9 +96,11 @@ def build (self, final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], self.wfc_numb, 3]) # concat the results - outs_list.append(final_layer) + if count == 0: + outs = final_layer + else: + outs = tf.concat([outs, final_layer], axis = 1) count += 1 - outs = tf.concat(outs_list, axis = 1) tf.summary.histogram('fitting_net_output', outs) return tf.cast(tf.reshape(outs, [-1]), GLOBAL_TF_FLOAT_PRECISION)