From aa022e8a46b36c56107dbb954000e434963c48a6 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 9 Sep 2021 06:22:29 -0400 Subject: [PATCH 1/6] implement descriptor plugin; pass params to `Descriptor` --- deepmd/__init__.py | 8 +++ deepmd/descriptor/__init__.py | 1 + deepmd/descriptor/descriptor.py | 46 +++++++++++++++- deepmd/descriptor/hybrid.py | 11 +++- deepmd/descriptor/se_a.py | 2 + deepmd/descriptor/se_a_ebd.py | 3 ++ deepmd/descriptor/se_a_ef.py | 1 + deepmd/descriptor/se_ar.py | 3 +- deepmd/descriptor/se_r.py | 2 + deepmd/descriptor/se_t.py | 3 ++ deepmd/train/trainer.py | 48 ++--------------- deepmd/utils/__init__.py | 1 + deepmd/utils/plugin.py | 84 ++++++++++++++++++++++++++++++ source/tests/test_descrpt_se_ar.py | 2 +- 14 files changed, 165 insertions(+), 50 deletions(-) create mode 100644 deepmd/utils/plugin.py diff --git a/deepmd/__init__.py b/deepmd/__init__.py index 3a295dcbef..a99f71b7eb 100644 --- a/deepmd/__init__.py +++ b/deepmd/__init__.py @@ -1,5 +1,9 @@ """Root of the deepmd package, exposes all public classes and submodules.""" +try: + from importlib import metadata +except ImportError: # for Python<3.8 + import importlib_metadata as metadata import deepmd.utils.network as network from . import cluster, descriptor, fit, loss, utils @@ -14,6 +18,10 @@ except ImportError: from .__about__ import __version__ +# load third-party plugins +for ep in metadata.entry_points().get('deepmd', []): + ep.load() + __all__ = [ "descriptor", "fit", diff --git a/deepmd/descriptor/__init__.py b/deepmd/descriptor/__init__.py index 7c3d910091..c229d586e0 100644 --- a/deepmd/descriptor/__init__.py +++ b/deepmd/descriptor/__init__.py @@ -1,3 +1,4 @@ +from .descriptor import Descriptor from .hybrid import DescrptHybrid from .se_a import DescrptSeA from .se_r import DescrptSeR diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index 893493a201..a2287316d1 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -3,21 +3,63 @@ import numpy as np from deepmd.env import tf +from deepmd.utils import Plugin, Variant - -class Descriptor(ABC): +class Descriptor(Variant): r"""The abstract class for descriptors. All specific descriptors should be based on this class. The descriptor :math:`\mathcal{D}` describes the environment of an atom, which should be a function of coordinates and types of its neighbour atoms. + Examples + -------- + >>> descript = Descriptor(type="se_e2_a", rcut=6., rcut_smth=0.5, sel=[50]) + >>> type(descript) + + Notes ----- Only methods and attributes defined in this class are generally public, that can be called by other classes. """ + __plugins = Plugin() + + @staticmethod + def register(key) -> "Descriptor": + """Regiester a descriptor plugin. + + Parameters + ---------- + key : str + the key of a descriptor + + Returns + ------- + Descriptor + the regiestered descriptor + + Examples + -------- + >>> @Descriptor.register("some_descrpt") + class SomeDescript(Descriptor): + pass + """ + return Descriptor.__plugins.register(key) + + def __new__(cls, *args, **kwargs): + if cls is Descriptor: + try: + descrpt_type = kwargs['type'] + except KeyError: + raise KeyError('the type of descriptor should be set by `type`') + if descrpt_type in Descriptor.__plugins.plugins: + cls = Descriptor.__plugins.plugins[descrpt_type] + else: + raise RuntimeError('Unknown descriptor type: ' + descrpt_type) + return super().__new__(cls) + @abstractmethod def get_rcut(self) -> float: """ diff --git a/deepmd/descriptor/hybrid.py b/deepmd/descriptor/hybrid.py index 37b9578b4e..d1ef49aac8 100644 --- a/deepmd/descriptor/hybrid.py +++ b/deepmd/descriptor/hybrid.py @@ -21,6 +21,7 @@ from .se_a_ef import DescrptSeAEf from .loc_frame import DescrptLocFrame +@Descriptor.register("hybrid") class DescrptHybrid (Descriptor): """Concate a list of descriptors to form a new descriptor. @@ -37,11 +38,19 @@ def __init__ (self, """ if descrpt_list == [] or descrpt_list is None: raise RuntimeError('cannot build descriptor from an empty list of descriptors.') + formatted_descript_list = [] + for ii in descrpt_list: + if isinstance(ii, Descriptor): + formatted_descript_list.append(ii) + elif isinstance(ii, dict): + formatted_descript_list.append(Descriptor(**ii)) + else: + raise NotImplementedError # args = ClassArg()\ # .add('list', list, must = True) # class_data = args.parse(jdata) # dict_list = class_data['list'] - self.descrpt_list = descrpt_list + self.descrpt_list = formatted_descript_list self.numb_descrpt = len(self.descrpt_list) for ii in range(1, self.numb_descrpt): assert(self.descrpt_list[ii].get_ntypes() == diff --git a/deepmd/descriptor/se_a.py b/deepmd/descriptor/se_a.py index 39485463a9..893c152382 100644 --- a/deepmd/descriptor/se_a.py +++ b/deepmd/descriptor/se_a.py @@ -16,6 +16,8 @@ from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables from .descriptor import Descriptor +@Descriptor.register("se_e2_a") +@Descriptor.register("se_a") class DescrptSeA (Descriptor): r"""DeepPot-SE constructed from all information (both angular and radial) of atomic configurations. The embedding takes the distance between atoms as input. diff --git a/deepmd/descriptor/se_a_ebd.py b/deepmd/descriptor/se_a_ebd.py index b30501963f..7a1f640153 100644 --- a/deepmd/descriptor/se_a_ebd.py +++ b/deepmd/descriptor/se_a_ebd.py @@ -10,7 +10,10 @@ from deepmd.env import default_tf_session_config from deepmd.utils.network import embedding_net from .se_a import DescrptSeA +from .descriptor import Descriptor +@Descriptor.register("se_a_tpe") +@Descriptor.register("se_a_ebd") class DescrptSeAEbd (DescrptSeA): """DeepPot-SE descriptor with type embedding approach. diff --git a/deepmd/descriptor/se_a_ef.py b/deepmd/descriptor/se_a_ef.py index e2424aeae4..b037475722 100644 --- a/deepmd/descriptor/se_a_ef.py +++ b/deepmd/descriptor/se_a_ef.py @@ -12,6 +12,7 @@ from .se_a import DescrptSeA from .descriptor import Descriptor +@Descriptor.register("se_a_ef") class DescrptSeAEf (Descriptor): """ diff --git a/deepmd/descriptor/se_ar.py b/deepmd/descriptor/se_ar.py index 518cf9315c..2b3a9784b5 100644 --- a/deepmd/descriptor/se_ar.py +++ b/deepmd/descriptor/se_ar.py @@ -7,8 +7,9 @@ from deepmd.env import op_module from .descriptor import Descriptor +@Descriptor.register("se_ar") class DescrptSeAR (Descriptor): - def __init__ (self, jdata): + def __init__ (self, **jdata): args = ClassArg()\ .add('a', dict, must = True) \ .add('r', dict, must = True) diff --git a/deepmd/descriptor/se_r.py b/deepmd/descriptor/se_r.py index b6de76be76..8b014a2501 100644 --- a/deepmd/descriptor/se_r.py +++ b/deepmd/descriptor/se_r.py @@ -12,6 +12,8 @@ from deepmd.utils.sess import run_sess from .descriptor import Descriptor +@Descriptor.register("se_e2_r") +@Descriptor.register("se_r") class DescrptSeR (Descriptor): """DeepPot-SE constructed from radial information of atomic configurations. diff --git a/deepmd/descriptor/se_t.py b/deepmd/descriptor/se_t.py index 2ab7a732be..5663a33644 100644 --- a/deepmd/descriptor/se_t.py +++ b/deepmd/descriptor/se_t.py @@ -12,6 +12,9 @@ from deepmd.utils.sess import run_sess from .descriptor import Descriptor +@Descriptor.register("se_e3") +@Descriptor.register("se_at") +@Descriptor.register("se_a_3be") class DescrptSeT (Descriptor): """DeepPot-SE constructed from all information (both angular and radial) of atomic configurations. diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index dc888ad3e0..7556cc1ebb 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from deepmd.descriptor.descriptor import Descriptor import logging import os import glob @@ -11,14 +12,7 @@ from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION from deepmd.fit import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, GlobalPolarFittingSeA, DipoleFittingSeA -from deepmd.descriptor import DescrptLocFrame -from deepmd.descriptor import DescrptSeA -from deepmd.descriptor import DescrptSeT -from deepmd.descriptor import DescrptSeAEbd -from deepmd.descriptor import DescrptSeAEf -from deepmd.descriptor import DescrptSeR -from deepmd.descriptor import DescrptSeAR -from deepmd.descriptor import DescrptHybrid +from deepmd.descriptor import Descriptor from deepmd.model import EnerModel, WFCModel, DipoleModel, PolarModel, GlobalPolarModel from deepmd.loss import EnerStdLoss, EnerDipoleLoss, TensorLoss from deepmd.utils.errors import GraphTooLargeError @@ -47,36 +41,6 @@ def _is_subdir(path, directory): return False relative = os.path.relpath(path, directory) + os.sep return not relative.startswith(os.pardir + os.sep) - -def _generate_descrpt_from_param_dict(descrpt_param): - try: - descrpt_type = descrpt_param['type'] - except KeyError: - raise KeyError('the type of descriptor should be set by `type`') - descrpt_param.pop('type', None) - to_pop = [] - for kk in descrpt_param: - if kk[0] == '_': - to_pop.append(kk) - for kk in to_pop: - descrpt_param.pop(kk, None) - if descrpt_type == 'loc_frame': - descrpt = DescrptLocFrame(**descrpt_param) - elif descrpt_type == 'se_e2_a' or descrpt_type == 'se_a' : - descrpt = DescrptSeA(**descrpt_param) - elif descrpt_type == 'se_e2_r' or descrpt_type == 'se_r' : - descrpt = DescrptSeR(**descrpt_param) - elif descrpt_type == 'se_e3' or descrpt_type == 'se_at' or descrpt_type == 'se_a_3be' : - descrpt = DescrptSeT(**descrpt_param) - elif descrpt_type == 'se_a_tpe' or descrpt_type == 'se_a_ebd' : - descrpt = DescrptSeAEbd(**descrpt_param) - elif descrpt_type == 'se_a_ef' : - descrpt = DescrptSeAEf(**descrpt_param) - elif descrpt_type == 'se_ar' : - descrpt = DescrptSeAR(descrpt_param) - else : - raise RuntimeError('unknow model type ' + descrpt_type) - return descrpt class DPTrainer (object): @@ -103,13 +67,7 @@ def _init_param(self, jdata): except KeyError: raise KeyError('the type of descriptor should be set by `type`') - if descrpt_type != 'hybrid': - self.descrpt = _generate_descrpt_from_param_dict(descrpt_param) - else : - descrpt_list = [] - for ii in descrpt_param.get('list', []): - descrpt_list.append(_generate_descrpt_from_param_dict(ii)) - self.descrpt = DescrptHybrid(descrpt_list) + self.descrpt = Descriptor(**descrpt_param) # fitting net fitting_type = fitting_param.get('type', 'ener') diff --git a/deepmd/utils/__init__.py b/deepmd/utils/__init__.py index a54f69b853..c2f06593b8 100644 --- a/deepmd/utils/__init__.py +++ b/deepmd/utils/__init__.py @@ -7,3 +7,4 @@ from .data_system import DataSystem from .pair_tab import PairTab from .learning_rate import LearningRateExp +from .plugin import Plugin, Variant diff --git a/deepmd/utils/plugin.py b/deepmd/utils/plugin.py new file mode 100644 index 0000000000..6d2f208b6c --- /dev/null +++ b/deepmd/utils/plugin.py @@ -0,0 +1,84 @@ + +"""Base of plugin systems.""" +# copied from https://github.com/deepmodeling/dpdata/blob/a3e76d75de53f6076254de82d18605a010dc3b00/dpdata/plugin.py + +from abc import ABCMeta +from typing import Callable + + +class Plugin: + """A class to register and restore plugins. + + Attributes + ---------- + plugins : Dict[str, object] + plugins + + Examples + -------- + >>> plugin = Plugin() + >>> @plugin.register("xx") + def xxx(): + pass + >>> print(plugin.plugins['xx']) + """ + def __init__(self): + self.plugins = {} + + def __add__(self, other) -> "Plugin": + self.plugins.update(other.plugins) + return self + + def register(self, key : str) -> Callable[[object], object]: + """Register a plugin. + + Parameter + --------- + key : str + key of the plugin + + Returns + ------- + Callable[[object], object] + decorator + """ + def decorator(object : object) -> object: + self.plugins[key] = object + return object + return decorator + + def get_plugin(self, key) -> object: + """Visit a plugin by key. + + Parameters + ---------- + key : str + key of the plugin + + Returns + ------- + object + the plugin + """ + return self.plugins[key] + +class VariantMeta: + def __call__(cls, *args, **kwargs): + """Remove `type` and keys that starts with underline.""" + obj = cls.__new__(cls, *args, **kwargs) + kwargs.pop('type', None) + to_pop = [] + for kk in kwargs: + if kk[0] == '_': + to_pop.append(kk) + for kk in to_pop: + kwargs.pop(kk, None) + obj.__init__(*args, **kwargs) + return obj + +class VariantABCMeta(VariantMeta, ABCMeta): + pass + +class Variant(metaclass=VariantABCMeta): + """A class to remove `type` from input arguments.""" + pass \ No newline at end of file diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py index 9ebb085bab..65da8ba1c1 100644 --- a/source/tests/test_descrpt_se_ar.py +++ b/source/tests/test_descrpt_se_ar.py @@ -43,7 +43,7 @@ def setUp (self, 'seed': 1, } param = {'a': param_a, 'r': param_r} - self.descrpt = DescrptSeAR(param) + self.descrpt = DescrptSeAR(**param) self.ndescrpt = self.descrpt.get_dim_out() # davg = np.zeros ([self.ntypes, self.ndescrpt]) # dstd = np.ones ([self.ntypes, self.ndescrpt]) From 12d3bf16fd342bb261ab5b77986e3156401332c0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 9 Sep 2021 14:55:31 -0400 Subject: [PATCH 2/6] avoid class name conflict --- deepmd/descriptor/descriptor.py | 4 ++-- deepmd/utils/__init__.py | 2 +- deepmd/utils/plugin.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index a2287316d1..4e60e72c6f 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -3,9 +3,9 @@ import numpy as np from deepmd.env import tf -from deepmd.utils import Plugin, Variant +from deepmd.utils import Plugin, PluginVariant -class Descriptor(Variant): +class Descriptor(PluginVariant): r"""The abstract class for descriptors. All specific descriptors should be based on this class. diff --git a/deepmd/utils/__init__.py b/deepmd/utils/__init__.py index c2f06593b8..e81b474095 100644 --- a/deepmd/utils/__init__.py +++ b/deepmd/utils/__init__.py @@ -7,4 +7,4 @@ from .data_system import DataSystem from .pair_tab import PairTab from .learning_rate import LearningRateExp -from .plugin import Plugin, Variant +from .plugin import Plugin, PluginVariant diff --git a/deepmd/utils/plugin.py b/deepmd/utils/plugin.py index 6d2f208b6c..f195b7808c 100644 --- a/deepmd/utils/plugin.py +++ b/deepmd/utils/plugin.py @@ -79,6 +79,6 @@ def __call__(cls, *args, **kwargs): class VariantABCMeta(VariantMeta, ABCMeta): pass -class Variant(metaclass=VariantABCMeta): +class PluginVariant(metaclass=VariantABCMeta): """A class to remove `type` from input arguments.""" pass \ No newline at end of file From f13c41bbf6f894682458b0f6de5da8543e56091b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 9 Sep 2021 15:24:21 -0400 Subject: [PATCH 3/6] add ArgsPlugin --- deepmd/descriptor/descriptor.py | 2 +- deepmd/utils/argcheck.py | 67 +++++++++++++++++++++++++++++---- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index 4e60e72c6f..09bc29c13d 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -27,7 +27,7 @@ class Descriptor(PluginVariant): __plugins = Plugin() @staticmethod - def register(key) -> "Descriptor": + def register(key : str) -> "Descriptor": """Regiester a descriptor plugin. Parameters diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 6279b3e088..a09bfecdec 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1,5 +1,9 @@ +from typing import List, Callable + from dargs import dargs, Argument, Variant, ArgumentEncoder +from deepmd import descriptor from deepmd.common import ACTIVATION_FN_DICT, PRECISION_DICT +from deepmd.utils.plugin import Plugin import json @@ -38,6 +42,55 @@ def type_embedding_args(): # --- Descriptor configurations: --- # + +class ArgsPlugin: + def __init__(self) -> None: + self.__plugin = Plugin() + + def register(self, name : str, alias : List[str] = None) -> Callable[[], List[Argument]]: + """Regiester a descriptor argument plugin. + + Parameters + ---------- + name : str + the name of a descriptor + alias : List[str], optional + the list of aliases of this descriptor + + Returns + ------- + Callable[[], List[Argument]] + the regiestered descriptor argument method + + Examples + -------- + >>> some_plugin = ArgsPlugin() + >>> @some_plugin.register("some_descrpt") + def descrpt_some_descrpt_args(): + return [] + """ + # convert alias to hashed item + if isinstance(alias, list): + alias = tuple(alias) + return self.__plugin.register((name, alias)) + + def get_all_argument(self) -> List[Argument]: + """Get all arguments. + + Returns + ------- + List[Argument] + all arguments + """ + arguments = [] + for (name, alias), metd in self.__plugin.plugins.items(): + arguments.append(Argument(name=name, dtype=dict, sub_fields=metd(), alias=alias)) + return arguments + + +descrpt_args_plugin = ArgsPlugin() + +@descrpt_args_plugin.register("local_frame") def descrpt_local_frame_args (): doc_sel_a = 'A list of integers. The length of the list should be the same as the number of atom types in the system. `sel_a[i]` gives the selected number of type-i neighbors. The full relative coordinates of the neighbors are used by the descriptor.' doc_sel_r = 'A list of integers. The length of the list should be the same as the number of atom types in the system. `sel_r[i]` gives the selected number of type-i neighbors. Only relative distance of the neighbors are used by the descriptor. sel_a[i] + sel_r[i] is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius.' @@ -58,6 +111,7 @@ def descrpt_local_frame_args (): ] +@descrpt_args_plugin.register("se_e2_a", alias=["se_a"]) def descrpt_se_a_args(): doc_sel = 'This parameter set the number of selected neighbors for each type of atom. It can be:\n\n\ - `List[int]`. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. `sel[i]` is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius. It is noted that the total sel value must be less than 4096 in a GPU environment.\n\n\ @@ -92,6 +146,7 @@ def descrpt_se_a_args(): ] +@descrpt_args_plugin.register("se_e3", alias=['se_at', 'se_a_3be', 'se_t']) def descrpt_se_t_args(): doc_sel = 'This parameter set the number of selected neighbors for each type of atom. It can be:\n\n\ - `List[int]`. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. `sel[i]` is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius. It is noted that the total sel value must be less than 4096 in a GPU environment.\n\n\ @@ -121,6 +176,7 @@ def descrpt_se_t_args(): +@descrpt_args_plugin.register("se_a_tpe", alias=['se_a_ebd']) def descrpt_se_a_tpe_args(): doc_type_nchanl = 'number of channels for type embedding' doc_type_nlayer = 'number of hidden layers of type embedding net' @@ -133,6 +189,7 @@ def descrpt_se_a_tpe_args(): ] +@descrpt_args_plugin.register("se_e2_r", alias=['se_r']) def descrpt_se_r_args(): doc_sel = 'This parameter set the number of selected neighbors for each type of atom. It can be:\n\n\ - `List[int]`. The length of the list should be the same as the number of atom types in the system. `sel[i]` gives the selected number of type-i neighbors. `sel[i]` is recommended to be larger than the maximally possible number of type-i neighbors in the cut-off radius. It is noted that the total sel value must be less than 4096 in a GPU environment.\n\n\ @@ -177,6 +234,7 @@ def descrpt_se_ar_args(): ] +@descrpt_args_plugin.register("hybrid") def descrpt_hybrid_args(): doc_list = f'A list of descriptor definitions' @@ -200,14 +258,7 @@ def descrpt_variant_type_args(): - `se_a_tpe`: Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor. Type embedding will be used by this descriptor.\n\n\ - `hybrid`: Concatenate of a list of descriptors as a new descriptor.' - return Variant("type", [ - Argument("loc_frame", dict, descrpt_local_frame_args()), - Argument("se_e2_a", dict, descrpt_se_a_args(), alias = ['se_a']), - Argument("se_e2_r", dict, descrpt_se_r_args(), alias = ['se_r']), - Argument("se_e3", dict, descrpt_se_t_args(), alias = ['se_at', 'se_a_3be', 'se_t']), - Argument("se_a_tpe", dict, descrpt_se_a_tpe_args(), alias = ['se_a_ebd']), - Argument("hybrid", dict, descrpt_hybrid_args()), - ], doc = doc_descrpt_type) + return Variant("type", descrpt_args_plugin.get_all_argument(), doc = doc_descrpt_type) # --- Fitting net configurations: --- # From 29ba2f07ae9ce82eb8c2032368eeb9bfe9f51f23 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 9 Sep 2021 15:40:39 -0400 Subject: [PATCH 4/6] fix lint warning --- deepmd/__init__.py | 2 +- deepmd/descriptor/descriptor.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/deepmd/__init__.py b/deepmd/__init__.py index a99f71b7eb..9c3af2a6c7 100644 --- a/deepmd/__init__.py +++ b/deepmd/__init__.py @@ -2,7 +2,7 @@ try: from importlib import metadata -except ImportError: # for Python<3.8 +except ImportError: # for Python<3.8 import importlib_metadata as metadata import deepmd.utils.network as network diff --git a/deepmd/descriptor/descriptor.py b/deepmd/descriptor/descriptor.py index 09bc29c13d..d179660a9d 100644 --- a/deepmd/descriptor/descriptor.py +++ b/deepmd/descriptor/descriptor.py @@ -5,6 +5,7 @@ from deepmd.env import tf from deepmd.utils import Plugin, PluginVariant + class Descriptor(PluginVariant): r"""The abstract class for descriptors. All specific descriptors should be based on this class. @@ -27,9 +28,9 @@ class Descriptor(PluginVariant): __plugins = Plugin() @staticmethod - def register(key : str) -> "Descriptor": + def register(key: str) -> "Descriptor": """Regiester a descriptor plugin. - + Parameters ---------- key : str @@ -39,7 +40,7 @@ def register(key : str) -> "Descriptor": ------- Descriptor the regiestered descriptor - + Examples -------- >>> @Descriptor.register("some_descrpt") @@ -47,7 +48,7 @@ class SomeDescript(Descriptor): pass """ return Descriptor.__plugins.register(key) - + def __new__(cls, *args, **kwargs): if cls is Descriptor: try: From c913e848931efc357f80e782cecf52dc491ea285 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 10 Sep 2021 17:06:34 -0400 Subject: [PATCH 5/6] revert changes to DescrptSeAR --- deepmd/descriptor/se_ar.py | 3 +-- source/tests/test_descrpt_se_ar.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/deepmd/descriptor/se_ar.py b/deepmd/descriptor/se_ar.py index 2b3a9784b5..518cf9315c 100644 --- a/deepmd/descriptor/se_ar.py +++ b/deepmd/descriptor/se_ar.py @@ -7,9 +7,8 @@ from deepmd.env import op_module from .descriptor import Descriptor -@Descriptor.register("se_ar") class DescrptSeAR (Descriptor): - def __init__ (self, **jdata): + def __init__ (self, jdata): args = ClassArg()\ .add('a', dict, must = True) \ .add('r', dict, must = True) diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py index 65da8ba1c1..9ebb085bab 100644 --- a/source/tests/test_descrpt_se_ar.py +++ b/source/tests/test_descrpt_se_ar.py @@ -43,7 +43,7 @@ def setUp (self, 'seed': 1, } param = {'a': param_a, 'r': param_r} - self.descrpt = DescrptSeAR(**param) + self.descrpt = DescrptSeAR(param) self.ndescrpt = self.descrpt.get_dim_out() # davg = np.zeros ([self.ntypes, self.ndescrpt]) # dstd = np.ones ([self.ntypes, self.ndescrpt]) From 1e7cbb0a82bdd0d6bcdcbaef7f22a30f909ea52d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 10 Sep 2021 21:50:52 -0400 Subject: [PATCH 6/6] remove descrpt_se_ar which is replaced by hybrid --- deepmd/descriptor/__init__.py | 1 - deepmd/descriptor/hybrid.py | 1 - deepmd/descriptor/se_ar.py | 94 ----------------------- deepmd/utils/argcheck.py | 12 --- source/tests/test_descrpt_se_ar.py | 118 ----------------------------- 5 files changed, 226 deletions(-) delete mode 100644 deepmd/descriptor/se_ar.py delete mode 100644 source/tests/test_descrpt_se_ar.py diff --git a/deepmd/descriptor/__init__.py b/deepmd/descriptor/__init__.py index c229d586e0..dd022a4e08 100644 --- a/deepmd/descriptor/__init__.py +++ b/deepmd/descriptor/__init__.py @@ -2,7 +2,6 @@ from .hybrid import DescrptHybrid from .se_a import DescrptSeA from .se_r import DescrptSeR -from .se_ar import DescrptSeAR from .se_t import DescrptSeT from .se_a_ebd import DescrptSeAEbd from .se_a_ef import DescrptSeAEf diff --git a/deepmd/descriptor/hybrid.py b/deepmd/descriptor/hybrid.py index d1ef49aac8..66a11be959 100644 --- a/deepmd/descriptor/hybrid.py +++ b/deepmd/descriptor/hybrid.py @@ -15,7 +15,6 @@ from .descriptor import Descriptor from .se_a import DescrptSeA from .se_r import DescrptSeR -from .se_ar import DescrptSeAR from .se_t import DescrptSeT from .se_a_ebd import DescrptSeAEbd from .se_a_ef import DescrptSeAEf diff --git a/deepmd/descriptor/se_ar.py b/deepmd/descriptor/se_ar.py deleted file mode 100644 index 518cf9315c..0000000000 --- a/deepmd/descriptor/se_ar.py +++ /dev/null @@ -1,94 +0,0 @@ -import numpy as np -from deepmd.env import tf -from deepmd.common import ClassArg - -from .se_a import DescrptSeA -from .se_r import DescrptSeR -from deepmd.env import op_module -from .descriptor import Descriptor - -class DescrptSeAR (Descriptor): - def __init__ (self, jdata): - args = ClassArg()\ - .add('a', dict, must = True) \ - .add('r', dict, must = True) - class_data = args.parse(jdata) - self.param_a = class_data['a'] - self.param_r = class_data['r'] - self.descrpt_a = DescrptSeA(**self.param_a) - self.descrpt_r = DescrptSeR(**self.param_r) - assert(self.descrpt_a.get_ntypes() == self.descrpt_r.get_ntypes()) - self.davg = None - self.dstd = None - - def get_rcut (self) : - return np.max([self.descrpt_a.get_rcut(), self.descrpt_r.get_rcut()]) - - def get_ntypes (self) : - return self.descrpt_r.get_ntypes() - - def get_dim_out (self) : - return (self.descrpt_a.get_dim_out() + self.descrpt_r.get_dim_out()) - - def get_nlist_a (self) : - return self.descrpt_a.nlist, self.descrpt_a.rij, self.descrpt_a.sel_a, self.descrpt_a.sel_r - - def get_nlist_r (self) : - return self.descrpt_r.nlist, self.descrpt_r.rij, self.descrpt_r.sel_a, self.descrpt_r.sel_r - - def compute_input_stats (self, - data_coord, - data_box, - data_atype, - natoms_vec, - mesh, - input_dict) : - self.descrpt_a.compute_input_stats(data_coord, data_box, data_atype, natoms_vec, mesh, input_dict) - self.descrpt_r.compute_input_stats(data_coord, data_box, data_atype, natoms_vec, mesh, input_dict) - self.davg = [self.descrpt_a.davg, self.descrpt_r.davg] - self.dstd = [self.descrpt_a.dstd, self.descrpt_r.dstd] - - - def build (self, - coord_, - atype_, - natoms, - box, - mesh, - input_dict, - suffix = '', - reuse = None): - davg = self.davg - dstd = self.dstd - if davg is None: - davg = [np.zeros([self.descrpt_a.ntypes, self.descrpt_a.ndescrpt]), - np.zeros([self.descrpt_r.ntypes, self.descrpt_r.ndescrpt])] - if dstd is None: - dstd = [np.ones ([self.descrpt_a.ntypes, self.descrpt_a.ndescrpt]), - np.ones ([self.descrpt_r.ntypes, self.descrpt_r.ndescrpt])] - # dout - self.dout_a = self.descrpt_a.build(coord_, atype_, natoms, box, mesh, input_dict, suffix=suffix+'_a', reuse=reuse) - self.dout_r = self.descrpt_r.build(coord_, atype_, natoms, box, mesh, input_dict, suffix=suffix , reuse=reuse) - self.dout_a = tf.reshape(self.dout_a, [-1, self.descrpt_a.get_dim_out()]) - self.dout_r = tf.reshape(self.dout_r, [-1, self.descrpt_r.get_dim_out()]) - self.dout = tf.concat([self.dout_a, self.dout_r], axis = 1) - self.dout = tf.reshape(self.dout, [-1, natoms[0] * self.get_dim_out()]) - - tf.summary.histogram('embedding_net_output', self.dout) - return self.dout - - - def prod_force_virial(self, atom_ener, natoms) : - f_a, v_a, av_a = self.descrpt_a.prod_force_virial(atom_ener, natoms) - f_r, v_r, av_r = self.descrpt_r.prod_force_virial(atom_ener, natoms) - force = f_a + f_r - virial = v_a + v_r - atom_virial = av_a + av_r - tf.summary.histogram('force', force) - tf.summary.histogram('virial', virial) - tf.summary.histogram('atom_virial', atom_virial) - return force, virial, atom_virial - - - - diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index a09bfecdec..e5c4d19a7b 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -222,18 +222,6 @@ def descrpt_se_r_args(): ] -def descrpt_se_ar_args(): - link = make_link('se_a', 'model/descriptor[se_a]') - doc_a = f'The parameters of descriptor {link}' - link = make_link('se_r', 'model/descriptor[se_r]') - doc_r = f'The parameters of descriptor {link}' - - return [ - Argument("a", dict, optional = False, doc = doc_a), - Argument("r", dict, optional = False, doc = doc_r), - ] - - @descrpt_args_plugin.register("hybrid") def descrpt_hybrid_args(): doc_list = f'A list of descriptor definitions' diff --git a/source/tests/test_descrpt_se_ar.py b/source/tests/test_descrpt_se_ar.py deleted file mode 100644 index 9ebb085bab..0000000000 --- a/source/tests/test_descrpt_se_ar.py +++ /dev/null @@ -1,118 +0,0 @@ -import os,sys -import numpy as np -import unittest - -from deepmd.env import tf -from tensorflow.python.framework import ops - -# load grad of force module -import deepmd.op - -from common import force_test -from common import virial_test -from common import force_dw_test -from common import virial_dw_test -from common import Data - -from deepmd.descriptor import DescrptSeAR - -from deepmd.env import GLOBAL_TF_FLOAT_PRECISION -from deepmd.env import GLOBAL_NP_FLOAT_PRECISION -from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION - -class Inter(): - def setUp (self, - data, - sess = None) : - self.sess = sess - self.data = data - self.natoms = self.data.get_natoms() - self.ntypes = self.data.get_ntypes() - param_a = { - 'sel' : [12,24], - 'rcut': 4, - 'rcut_smth' : 3.5, - 'neuron': [5, 10, 20], - 'seed': 1, - } - param_r = { - 'sel' : [20,40], - 'rcut': 6, - 'rcut_smth' : 6.5, - 'neuron': [10, 20, 40], - 'seed': 1, - } - param = {'a': param_a, 'r': param_r} - self.descrpt = DescrptSeAR(param) - self.ndescrpt = self.descrpt.get_dim_out() - # davg = np.zeros ([self.ntypes, self.ndescrpt]) - # dstd = np.ones ([self.ntypes, self.ndescrpt]) - # self.t_avg = tf.constant(davg.astype(np.float64)) - # self.t_std = tf.constant(dstd.astype(np.float64)) - avg_a = np.zeros([self.ntypes, self.descrpt.descrpt_a.ndescrpt]) - std_a = np.ones ([self.ntypes, self.descrpt.descrpt_a.ndescrpt]) - avg_r = np.zeros([self.ntypes, self.descrpt.descrpt_r.ndescrpt]) - std_r = np.ones ([self.ntypes, self.descrpt.descrpt_r.ndescrpt]) - self.avg = [avg_a, avg_r] - self.std = [std_a, std_r] - self.default_mesh = np.zeros (6, dtype = np.int32) - self.default_mesh[3] = 2 - self.default_mesh[4] = 2 - self.default_mesh[5] = 2 - # make place holder - self.coord = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, self.natoms[0] * 3], name='t_coord') - self.box = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, 9], name='t_box') - self.type = tf.placeholder(tf.int32, [None, self.natoms[0]], name = "t_type") - self.tnatoms = tf.placeholder(tf.int32, [None], name = "t_natoms") - self.efield = tf.placeholder(GLOBAL_TF_FLOAT_PRECISION, [None, self.natoms[0] * 3], name='t_efield') - - def _net (self, - inputs, - name, - reuse = False) : - with tf.variable_scope(name, reuse=reuse): - net_w = tf.get_variable ('net_w', - [self.descrpt.get_dim_out()], - GLOBAL_TF_FLOAT_PRECISION, - tf.constant_initializer (self.net_w_i)) - dot_v = tf.matmul (tf.reshape (inputs, [-1, self.descrpt.get_dim_out()]), - tf.reshape (net_w, [self.descrpt.get_dim_out(), 1])) - return tf.reshape (dot_v, [-1]) - - def comp_ef (self, - dcoord, - dbox, - dtype, - tnatoms, - name, - reuse = None) : - dout = self.descrpt.build(dcoord, dtype, tnatoms, dbox, self.default_mesh, {"efield": self.efield}, suffix=name, reuse=reuse) - inputs_reshape = tf.reshape (dout, [-1, self.descrpt.get_dim_out()]) - atom_ener = self._net (inputs_reshape, name, reuse = reuse) - atom_ener_reshape = tf.reshape(atom_ener, [-1, self.natoms[0]]) - energy = tf.reduce_sum (atom_ener_reshape, axis = 1) - force, virial, av = self.descrpt.prod_force_virial(atom_ener_reshape, tnatoms) - return energy, force, virial - - -class TestDescrptAR(Inter, tf.test.TestCase): - # def __init__ (self, *args, **kwargs): - # data = Data() - # Inter.__init__(self, data) - # tf.test.TestCase.__init__(self, *args, **kwargs) - # self.controller = object() - - def setUp(self): - self.places = 5 - data = Data() - Inter.setUp(self, data, sess=self.test_session().__enter__()) - - def test_force (self) : - force_test(self, self, suffix = '_se_ar') - - def test_virial (self) : - virial_test(self, self, suffix = '_se_ar') - - -if __name__ == '__main__': - unittest.main()