From 3a3f38d8568d87fb38b88f3a86f04e4a81c47607 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 4 Oct 2023 16:54:00 -0400 Subject: [PATCH 1/3] pairwise_dprc + frozen Signed-off-by: Jinzhe Zeng (cherry picked from commit c7f38fee421bf90f779cb84d1ecc7fb9edd69a97) --- deepmd/model/pairwise_dprc.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/deepmd/model/pairwise_dprc.py b/deepmd/model/pairwise_dprc.py index 80aea92bb1..9f09518782 100644 --- a/deepmd/model/pairwise_dprc.py +++ b/deepmd/model/pairwise_dprc.py @@ -32,10 +32,6 @@ TypeEmbedNet, ) -from .ener import ( - EnerModel, -) - class PairwiseDPRc(Model): """Pairwise Deep Potential - Range Correction.""" @@ -87,19 +83,19 @@ def __init__( padding=True, ) - self.qm_model = EnerModel( + self.qm_model = Model( **qm_model, type_map=type_map, type_embedding=self.typeebd, compress=compress, ) - self.qmmm_model = EnerModel( + self.qmmm_model = Model( **qmmm_model, type_map=type_map, type_embedding=self.typeebd, compress=compress, ) - add_data_requirement("aparam", 1, atomic=True, must=True, high_prec=False) + add_data_requirement("aparam", 1, atomic=True, must=False, high_prec=False) self.ntypes = len(type_map) self.rcut = max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut()) @@ -301,7 +297,7 @@ def get_rcut(self): return max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut()) def get_ntypes(self) -> int: - return self.qm_model.get_ntypes() + return self.qmmm_model.get_ntypes() def data_stat(self, data): self.qm_model.data_stat(data) From 42d841020d77e95429766e6d82296713c926b9e2 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 5 Oct 2023 22:49:19 -0400 Subject: [PATCH 2/3] allow QM models have fewer types Signed-off-by: Jinzhe Zeng --- deepmd/model/pairwise_dprc.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/deepmd/model/pairwise_dprc.py b/deepmd/model/pairwise_dprc.py index 9f09518782..42625af403 100644 --- a/deepmd/model/pairwise_dprc.py +++ b/deepmd/model/pairwise_dprc.py @@ -183,6 +183,14 @@ def build( mesh_mixed_type = make_default_mesh(False, True) + # allow loading a frozen QM model that has only QM types + # Note: here we don't map the type between models, so + # the type of the frozen model must be the same as + # the first Ntypes of the current model + if self.get_ntypes() > self.qm_model.get_ntypes(): + natoms_qm = tf.slice(natoms_qm, [0], [self.qm_model.get_ntypes() + 2]) + assert self.get_ntypes() == self.qmmm_model.get_ntypes() + qm_dict = self.qm_model.build( coord_qm, atype_qm, @@ -297,7 +305,7 @@ def get_rcut(self): return max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut()) def get_ntypes(self) -> int: - return self.qmmm_model.get_ntypes() + return self.ntypes def data_stat(self, data): self.qm_model.data_stat(data) From ed23f1af22816788fb12ea18b5805e45d8238099 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 5 Oct 2023 22:50:10 -0400 Subject: [PATCH 3/3] revert must Signed-off-by: Jinzhe Zeng --- deepmd/model/pairwise_dprc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/model/pairwise_dprc.py b/deepmd/model/pairwise_dprc.py index 42625af403..6983a31cfd 100644 --- a/deepmd/model/pairwise_dprc.py +++ b/deepmd/model/pairwise_dprc.py @@ -95,7 +95,7 @@ def __init__( type_embedding=self.typeebd, compress=compress, ) - add_data_requirement("aparam", 1, atomic=True, must=False, high_prec=False) + add_data_requirement("aparam", 1, atomic=True, must=True, high_prec=False) self.ntypes = len(type_map) self.rcut = max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut())