From 14a980c21b50afe04d9b022272074101a077bcb1 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 18 May 2023 16:56:13 -0400 Subject: [PATCH] make Fittings pluginable Signed-off-by: Jinzhe Zeng --- deepmd/descriptor/descriptor.py | 3 ++- deepmd/fit/__init__.py | 4 +++ deepmd/fit/dipole.py | 1 + deepmd/fit/dos.py | 1 + deepmd/fit/ener.py | 1 + deepmd/fit/fitting.py | 46 ++++++++++++++++++++++++++++++++- deepmd/fit/polar.py | 1 + deepmd/train/trainer.py | 35 +++++-------------------- 8 files changed, 62 insertions(+), 30 deletions(-) diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index 38b56a85be..3542918cb3 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -3,6 +3,7 @@ ) from typing import ( Any, + Callable, Dict, List, Optional, @@ -43,7 +44,7 @@ class Descriptor(PluginVariant): __plugins = Plugin() @staticmethod - def register(key: str) -> "Descriptor": + def register(key: str) -> Callable: """Register a descriptor plugin. Parameters diff --git a/deepmd/fit/__init__.py b/deepmd/fit/__init__.py index 174dbd443d..875c67cd5d 100644 --- a/deepmd/fit/__init__.py +++ b/deepmd/fit/__init__.py @@ -7,6 +7,9 @@ from .ener import ( EnerFitting, ) +from .fitting import ( + Fitting, +) from .polar import ( GlobalPolarFittingSeA, PolarFittingSeA, @@ -18,4 +21,5 @@ "DOSFitting", "GlobalPolarFittingSeA", "PolarFittingSeA", + "Fitting", ] diff --git a/deepmd/fit/dipole.py b/deepmd/fit/dipole.py index 1087bbf340..9f497584ae 100644 --- a/deepmd/fit/dipole.py +++ b/deepmd/fit/dipole.py @@ -25,6 +25,7 @@ ) +@Fitting.register("dipole") class DipoleFittingSeA(Fitting): r"""Fit the atomic dipole with descriptor se_a. diff --git a/deepmd/fit/dos.py b/deepmd/fit/dos.py index 43e31185d1..4bb1838af1 100644 --- a/deepmd/fit/dos.py +++ b/deepmd/fit/dos.py @@ -40,6 +40,7 @@ log = logging.getLogger(__name__) +@Fitting.register("dos") class DOSFitting(Fitting): r"""Fitting the density of states (DOS) of the system. The energy should be shifted by the fermi level. diff --git a/deepmd/fit/ener.py b/deepmd/fit/ener.py index b01a46e7a8..f482173495 100644 --- a/deepmd/fit/ener.py +++ b/deepmd/fit/ener.py @@ -47,6 +47,7 @@ log = logging.getLogger(__name__) +@Fitting.register("ener") class EnerFitting(Fitting): r"""Fitting the energy of the system. The force and the virial can also be trained. diff --git a/deepmd/fit/fitting.py b/deepmd/fit/fitting.py index 7ddb42640a..b0a9efe3be 100644 --- a/deepmd/fit/fitting.py +++ b/deepmd/fit/fitting.py @@ -1,9 +1,53 @@ +from typing import ( + Callable, +) + from deepmd.env import ( tf, ) +from deepmd.utils import ( + Plugin, + PluginVariant, +) + + +class Fitting(PluginVariant): + __plugins = Plugin() + + @staticmethod + def register(key: str) -> Callable: + """Register a Fitting plugin. + + Parameters + ---------- + key : str + the key of a Fitting + + Returns + ------- + Fitting + the registered Fitting + + Examples + -------- + >>> @Fitting.register("some_fitting") + class SomeFitting(Fitting): + pass + """ + return Fitting.__plugins.register(key) + def __new__(cls, *args, **kwargs): + if cls is Fitting: + try: + fitting_type = kwargs["type"] + except KeyError: + raise KeyError("the type of fitting should be set by `type`") + if fitting_type in Fitting.__plugins.plugins: + cls = Fitting.__plugins.plugins[fitting_type] + else: + raise RuntimeError("Unknown descriptor type: " + fitting_type) + return super().__new__(cls) -class Fitting: @property def precision(self) -> tf.DType: """Precision of fitting network.""" diff --git a/deepmd/fit/polar.py b/deepmd/fit/polar.py index 2a56334dff..42b8d155f5 100644 --- a/deepmd/fit/polar.py +++ b/deepmd/fit/polar.py @@ -29,6 +29,7 @@ ) +@Fitting.register("polar") class PolarFittingSeA(Fitting): r"""Fit the atomic polarizability with descriptor se_a. diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 2c5f668589..183f3b0bff 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -34,10 +34,7 @@ tfv2, ) from deepmd.fit import ( - DipoleFittingSeA, - DOSFitting, - EnerFitting, - PolarFittingSeA, + Fitting, ) from deepmd.loss import ( DOSLoss, @@ -157,30 +154,13 @@ def _init_param(self, jdata): self.descrpt = Descriptor(**descrpt_param) # fitting net - def fitting_net_init(fitting_type_, descrpt_type_, params): - if fitting_type_ == "ener": - params["spin"] = self.spin - return EnerFitting(**params) - elif fitting_type_ == "dos": - return DOSFitting(**params) - elif fitting_type_ == "dipole": - return DipoleFittingSeA(**params) - elif fitting_type_ == "polar": - return PolarFittingSeA(**params) - # elif fitting_type_ == 'global_polar': - # if descrpt_type_ == 'se_e2_a': - # return GlobalPolarFittingSeA(**params) - # else: - # raise RuntimeError('fitting global_polar only supports descrptors: loc_frame and se_e2_a') - else: - raise RuntimeError("unknown fitting type " + fitting_type_) - if not self.multi_task_mode: fitting_type = fitting_param.get("type", "ener") self.fitting_type = fitting_type - fitting_param.pop("type", None) fitting_param["descrpt"] = self.descrpt - self.fitting = fitting_net_init(fitting_type, descrpt_type, fitting_param) + if fitting_type == "ener": + fitting_param["spin"] = self.spin + self.fitting = Fitting(**fitting_param) else: self.fitting_dict = {} self.fitting_type_dict = {} @@ -189,11 +169,10 @@ def fitting_net_init(fitting_type_, descrpt_type_, params): item_fitting_param = fitting_param[item] item_fitting_type = item_fitting_param.get("type", "ener") self.fitting_type_dict[item] = item_fitting_type - item_fitting_param.pop("type", None) item_fitting_param["descrpt"] = self.descrpt - self.fitting_dict[item] = fitting_net_init( - item_fitting_type, descrpt_type, item_fitting_param - ) + if item_fitting_type == "ener": + item_fitting_param["spin"] = self.spin + self.fitting_dict[item] = Fitting(**item_fitting_param) # type embedding padding = False