diff --git a/deepmd/utils/tabulate.py b/deepmd/utils/tabulate.py index ee1088bd3c..1aa108c185 100644 --- a/deepmd/utils/tabulate.py +++ b/deepmd/utils/tabulate.py @@ -84,26 +84,17 @@ def __init__(self, self.sub_sess = tf.Session(graph = self.sub_graph) if isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): - try: - self.sel_a = self.graph.get_operation_by_name('ProdEnvMatR').get_attr('sel') - self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatR') - except KeyError: - self.sel_a = self.graph.get_operation_by_name('DescrptSeR').get_attr('sel') - self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeR') + self.sel_a = self.descrpt.sel_r + self.rcut = self.descrpt.rcut + self.rcut_smth = self.descrpt.rcut_smth elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeA): - try: - self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatA') - except KeyError: - self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeA') + self.sel_a = self.descrpt.sel_a + self.rcut = self.descrpt.rcut_r + self.rcut_smth = self.descrpt.rcut_r_smth elif isinstance(self.descrpt, deepmd.descriptor.DescrptSeT): - try: - self.sel_a = self.graph.get_operation_by_name('ProdEnvMatA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('ProdEnvMatA') - except KeyError: - self.sel_a = self.graph.get_operation_by_name('DescrptSeA').get_attr('sel_a') - self.prod_env_mat_op = self.graph.get_operation_by_name ('DescrptSeA') + self.sel_a = self.descrpt.sel_a + self.rcut = self.descrpt.rcut_r + self.rcut_smth = self.descrpt.rcut_r_smth else: raise RuntimeError("Unsupported descriptor") @@ -111,13 +102,6 @@ def __init__(self, self.dstd = get_tensor_by_name_from_graph(self.graph, f'descrpt_attr{self.suffix}/t_std') self.ntypes = get_tensor_by_name_from_graph(self.graph, 'descrpt_attr/ntypes') - if isinstance(self.descrpt, deepmd.descriptor.DescrptSeR): - self.rcut = self.prod_env_mat_op.get_attr('rcut') - self.rcut_smth = self.prod_env_mat_op.get_attr('rcut_smth') - else: - self.rcut = self.prod_env_mat_op.get_attr('rcut_r') - self.rcut_smth = self.prod_env_mat_op.get_attr('rcut_r_smth') - self.embedding_net_nodes = get_embedding_net_nodes_from_graph_def(self.graph_def, suffix=self.suffix) # move it to the descriptor class