From 6abfac1cdb8e318d0d55fe48f2d7cc1735080094 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 9 Aug 2023 15:01:44 +0800 Subject: [PATCH 1/3] fix bug in deepmd.infer.deep_pot.DeepPot - fix bug in checking t_efield op in the graph - delete arg atomic in `_prepare_feed_dict` (otherwise, efield cannot be imported correctly when using this method) --- deepmd/infer/deep_pot.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 122dfd7442..3d19b28884 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import os from typing import ( TYPE_CHECKING, Callable, @@ -114,15 +115,15 @@ def __init__( operations = [op.name for op in self.graph.get_operations()] # check if the graph has these operations: # if yes add them - if "t_efield" in operations: - self._get_tensor("t_efield:0", "t_efield") + if os.path.join(load_prefix, "t_efield") in operations: + self.tensors.update({"t_efield": "t_efield:0"}) self.has_efield = True else: log.debug("Could not get tensor 't_efield:0'") self.t_efield = None self.has_efield = False - if "load/t_fparam" in operations: + if os.path.join(load_prefix, "t_fparam") in operations: self.tensors.update({"t_fparam": "t_fparam:0"}) self.has_fparam = True else: @@ -130,7 +131,7 @@ def __init__( self.t_fparam = None self.has_fparam = False - if "load/t_aparam" in operations: + if os.path.join(load_prefix, "t_aparam") in operations: self.tensors.update({"t_aparam": "t_aparam:0"}) self.has_aparam = True else: @@ -138,7 +139,7 @@ def __init__( self.t_aparam = None self.has_aparam = False - if "load/spin_attr/ntypes_spin" in operations: + if os.path.join(load_prefix, "spin_attr/ntypes_spin") in operations: self.tensors.update({"t_ntypes_spin": "spin_attr/ntypes_spin:0"}) self.has_spin = True else: @@ -399,7 +400,6 @@ def _prepare_feed_dict( atom_types, fparam=None, aparam=None, - atomic=False, efield=None, mixed_type=False, ): From c73a4fe89eb4b118c6d9471b465713a300fbdcf7 Mon Sep 17 00:00:00 2001 From: Jia-Xin Zhu Date: Wed, 9 Aug 2023 21:01:58 +0800 Subject: [PATCH 2/3] replace os.path.join with --- deepmd/infer/deep_pot.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 3d19b28884..396a9753a4 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -115,7 +115,8 @@ def __init__( operations = [op.name for op in self.graph.get_operations()] # check if the graph has these operations: # if yes add them - if os.path.join(load_prefix, "t_efield") in operations: + + if ("%s/t_efield" % load_prefix) in operations: self.tensors.update({"t_efield": "t_efield:0"}) self.has_efield = True else: @@ -123,7 +124,7 @@ def __init__( self.t_efield = None self.has_efield = False - if os.path.join(load_prefix, "t_fparam") in operations: + if ("%s/t_fparam" % load_prefix) in operations: self.tensors.update({"t_fparam": "t_fparam:0"}) self.has_fparam = True else: @@ -131,7 +132,7 @@ def __init__( self.t_fparam = None self.has_fparam = False - if os.path.join(load_prefix, "t_aparam") in operations: + if ("%s/t_aparam" % load_prefix) in operations: self.tensors.update({"t_aparam": "t_aparam:0"}) self.has_aparam = True else: @@ -139,7 +140,7 @@ def __init__( self.t_aparam = None self.has_aparam = False - if os.path.join(load_prefix, "spin_attr/ntypes_spin") in operations: + if ("%s/spin_attr/ntypes_spin" % load_prefix) in operations: self.tensors.update({"t_ntypes_spin": "spin_attr/ntypes_spin:0"}) self.has_spin = True else: From 5d526c95bae0e257efd62131899302d5202086dd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 9 Aug 2023 13:02:33 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/infer/deep_pot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepmd/infer/deep_pot.py b/deepmd/infer/deep_pot.py index 396a9753a4..b3e9be1e67 100644 --- a/deepmd/infer/deep_pot.py +++ b/deepmd/infer/deep_pot.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import os from typing import ( TYPE_CHECKING, Callable, @@ -115,7 +114,7 @@ def __init__( operations = [op.name for op in self.graph.get_operations()] # check if the graph has these operations: # if yes add them - + if ("%s/t_efield" % load_prefix) in operations: self.tensors.update({"t_efield": "t_efield:0"}) self.has_efield = True