From 4f9d64805579f25d8247c62b1122e440ad34720a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 27 Apr 2022 21:10:02 -0400 Subject: [PATCH 1/2] fix rcut in hybrid model compression The current way `self.graph.get_operation_by_name('ProdEnvMatA')` to get `rcut` is incorrect for hybrid models. There may be several ProdEnvMatA ops in a graph. (cherry picked from commit 34a4e9b04f815a86af5301fd2536ec261b1d2a91) --- deepmd/utils/tabulate.py | 34 +++++++++------------------------- 1 file changed, 9 insertions(+), 25 deletions(-) diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py index ee1088bd3c..af19ab84ce 100644 --- a/deepmd/utils/tabulate.py +++ b/deepmd/utils/tabulate.py @@ -84,26 +84,17 @@ def __init__(self, self.sub_sess = tf.Session(graph = self.sub_graph) if isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): - try: - self.sel_a = self.graph.get_operation_by_name('ProdEnvMatR').get_attr('sel') - self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatR') - except KeyError: - self.sel_a = self.graph.get_operation_by_name('DescrptSeR').get_attr('sel') - self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeR') + self.sel_a = self.descrpt.sel_r + self.rcut = self.descrpt.rcut_r + self.rcut_smth = self.descrpt.rcut_r_smth elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeA): - try: - self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatA') - except KeyError: - self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeA') + self.sel_a = self.descrpt.sel_a + self.rcut = self.descrpt.rcut_r + self.rcut_smth = self.descrpt.rcut_r_smth elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): - try: - self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatA') - except KeyError: - self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeA') + self.sel_a = self.descrpt.sel_a + self.rcut = self.descrpt.rcut_r + self.rcut_smth = self.descrpt.rcut_r_smth else: raise RuntimeError("Unsupported descriptor") @@ -111,13 +102,6 @@ def __init__(self, self.dstd = get_tensor_by_name_from_graph(self.graph, f'descrpt_attr{self.suffix}/t_std') self.ntypes = get_tensor_by_name_from_graph(self.graph, 'descrpt_attr/ntypes') - if isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): - self.rcut = self.prod_env_mat_op.get_attr('rcut') - self.rcut_smth = self.prod_env_mat_op.get_attr('rcut_smth') - else: - self.rcut = self.prod_env_mat_op.get_attr('rcut_r') - self.rcut_smth = self.prod_env_mat_op.get_attr('rcut_r_smth') - self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(self.graph_def, suffix=self.suffix) # move it to the descriptor class From 9bc964d1da4842db8c465a1c7f370f8a37805551 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 27 Apr 2022 21:25:39 -0400 Subject: [PATCH 2/2] fix se_r attr --- deepmd/utils/tabulate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py index af19ab84ce..1aa108c185 100644 --- a/deepmd/utils/tabulate.py +++ b/deepmd/utils/tabulate.py @@ -85,8 +85,8 @@ def __init__(self, if isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): self.sel_a = self.descrpt.sel_r - self.rcut = self.descrpt.rcut_r - self.rcut_smth = self.descrpt.rcut_r_smth + self.rcut = self.descrpt.rcut + self.rcut_smth = self.descrpt.rcut_smth elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeA): self.sel_a = self.descrpt.sel_a self.rcut = self.descrpt.rcut_r