diff --git a/deepmd/model/pairwise_dprc.py b/deepmd/model/pairwise_dprc.py index 80aea92bb1..6983a31cfd 100644 --- a/deepmd/model/pairwise_dprc.py +++ b/deepmd/model/pairwise_dprc.py @@ -32,10 +32,6 @@ TypeEmbedNet, ) -from .ener import ( - EnerModel, -) - class PairwiseDPRc(Model): """Pairwise Deep Potential - Range Correction.""" @@ -87,13 +83,13 @@ def __init__( padding=True, ) - self.qm_model = EnerModel( + self.qm_model = Model( **qm_model, type_map=type_map, type_embedding=self.typeebd, compress=compress, ) - self.qmmm_model = EnerModel( + self.qmmm_model = Model( **qmmm_model, type_map=type_map, type_embedding=self.typeebd, @@ -187,6 +183,14 @@ def build( mesh_mixed_type = make_default_mesh(False, True) + # allow loading a frozen QM model that has only QM types + # Note: here we don't map the type between models, so + # the type of the frozen model must be the same as + # the first Ntypes of the current model + if self.get_ntypes() > self.qm_model.get_ntypes(): + natoms_qm = tf.slice(natoms_qm, [0], [self.qm_model.get_ntypes() + 2]) + assert self.get_ntypes() == self.qmmm_model.get_ntypes() + qm_dict = self.qm_model.build( coord_qm, atype_qm, @@ -301,7 +305,7 @@ def get_rcut(self): return max(self.qm_model.get_rcut(), self.qmmm_model.get_rcut()) def get_ntypes(self) -> int: - return self.qm_model.get_ntypes() + return self.ntypes def data_stat(self, data): self.qm_model.data_stat(data)