Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions deepmd/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Root of the deepmd package, exposes all public classes and submodules."""

try:
from importlib import metadata
except ImportError: # for Python<3.8
import importlib_metadata as metadata
import deepmd.utils.network as network

from . import cluster, descriptor, fit, loss, utils
Expand All @@ -14,6 +18,10 @@
except ImportError:
from .__about__ import __version__

# load third-party plugins
for ep in metadata.entry_points().get('deepmd', []):
ep.load()

__all__ = [
"descriptor",
"fit",
Expand Down
2 changes: 1 addition & 1 deletion deepmd/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .descriptor import Descriptor
from .hybrid import DescrptHybrid
from .se_a import DescrptSeA
from .se_r import DescrptSeR
from .se_ar import DescrptSeAR
from .se_t import DescrptSeT
from .se_a_ebd import DescrptSeAEbd
from .se_a_ef import DescrptSeAEf
Expand Down
45 changes: 44 additions & 1 deletion deepmd/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,64 @@

import numpy as np
from deepmd.env import tf
from deepmd.utils import Plugin, PluginVariant


class Descriptor(ABC):
class Descriptor(PluginVariant):
r"""The abstract class for descriptors. All specific descriptors should
be based on this class.

The descriptor :math:`\mathcal{D}` describes the environment of an atom,
which should be a function of coordinates and types of its neighbour atoms.

Examples
--------
>>> descript = Descriptor(type="se_e2_a", rcut=6., rcut_smth=0.5, sel=[50])
>>> type(descript)
<class 'deepmd.descriptor.se_a.DescrptSeA'>

Notes
-----
Only methods and attributes defined in this class are generally public,
that can be called by other classes.
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> "Descriptor":
"""Regiester a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
Descriptor
the regiestered descriptor

Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return Descriptor.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
try:
descrpt_type = kwargs['type']
except KeyError:
raise KeyError('the type of descriptor should be set by `type`')
if descrpt_type in Descriptor.__plugins.plugins:
cls = Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError('Unknown descriptor type: ' + descrpt_type)
return super().__new__(cls)

@abstractmethod
def get_rcut(self) -> float:
"""
Expand Down
12 changes: 10 additions & 2 deletions deepmd/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from .descriptor import Descriptor
from .se_a import DescrptSeA
from .se_r import DescrptSeR
from .se_ar import DescrptSeAR
from .se_t import DescrptSeT
from .se_a_ebd import DescrptSeAEbd
from .se_a_ef import DescrptSeAEf
from .loc_frame import DescrptLocFrame

@Descriptor.register("hybrid")
class DescrptHybrid (Descriptor):
"""Concate a list of descriptors to form a new descriptor.

Expand All @@ -37,11 +37,19 @@ def __init__ (self,
"""
if descrpt_list == [] or descrpt_list is None:
raise RuntimeError('cannot build descriptor from an empty list of descriptors.')
formatted_descript_list = []
for ii in descrpt_list:
if isinstance(ii, Descriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
formatted_descript_list.append(Descriptor(**ii))
else:
raise NotImplementedError
# args = ClassArg()\
# .add('list', list, must = True)
# class_data = args.parse(jdata)
# dict_list = class_data['list']
self.descrpt_list = descrpt_list
self.descrpt_list = formatted_descript_list
self.numb_descrpt = len(self.descrpt_list)
for ii in range(1, self.numb_descrpt):
assert(self.descrpt_list[ii].get_ntypes() ==
Expand Down
2 changes: 2 additions & 0 deletions deepmd/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables
from .descriptor import Descriptor

@Descriptor.register("se_e2_a")
@Descriptor.register("se_a")
class DescrptSeA (Descriptor):
r"""DeepPot-SE constructed from all information (both angular and radial) of
atomic configurations. The embedding takes the distance between atoms as input.
Expand Down
3 changes: 3 additions & 0 deletions deepmd/descriptor/se_a_ebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from deepmd.env import default_tf_session_config
from deepmd.utils.network import embedding_net
from .se_a import DescrptSeA
from .descriptor import Descriptor

@Descriptor.register("se_a_tpe")
@Descriptor.register("se_a_ebd")
class DescrptSeAEbd (DescrptSeA):
"""DeepPot-SE descriptor with type embedding approach.

Expand Down
1 change: 1 addition & 0 deletions deepmd/descriptor/se_a_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .se_a import DescrptSeA
from .descriptor import Descriptor

@Descriptor.register("se_a_ef")
class DescrptSeAEf (Descriptor):
"""

Expand Down
94 changes: 0 additions & 94 deletions deepmd/descriptor/se_ar.py

This file was deleted.

2 changes: 2 additions & 0 deletions deepmd/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor

@Descriptor.register("se_e2_r")
@Descriptor.register("se_r")
class DescrptSeR (Descriptor):
"""DeepPot-SE constructed from radial information of atomic configurations.

Expand Down
3 changes: 3 additions & 0 deletions deepmd/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from deepmd.utils.sess import run_sess
from .descriptor import Descriptor

@Descriptor.register("se_e3")
@Descriptor.register("se_at")
@Descriptor.register("se_a_3be")
class DescrptSeT (Descriptor):
"""DeepPot-SE constructed from all information (both angular and radial) of atomic
configurations.
Expand Down
48 changes: 3 additions & 45 deletions deepmd/train/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
from deepmd.descriptor.descriptor import Descriptor
import logging
import os
import glob
Expand All @@ -11,14 +12,7 @@
from deepmd.env import GLOBAL_TF_FLOAT_PRECISION
from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION
from deepmd.fit import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, GlobalPolarFittingSeA, DipoleFittingSeA
from deepmd.descriptor import DescrptLocFrame
from deepmd.descriptor import DescrptSeA
from deepmd.descriptor import DescrptSeT
from deepmd.descriptor import DescrptSeAEbd
from deepmd.descriptor import DescrptSeAEf
from deepmd.descriptor import DescrptSeR
from deepmd.descriptor import DescrptSeAR
from deepmd.descriptor import DescrptHybrid
from deepmd.descriptor import Descriptor
from deepmd.model import EnerModel, WFCModel, DipoleModel, PolarModel, GlobalPolarModel
from deepmd.loss import EnerStdLoss, EnerDipoleLoss, TensorLoss
from deepmd.utils.errors import GraphTooLargeError
Expand Down Expand Up @@ -47,36 +41,6 @@ def _is_subdir(path, directory):
return False
relative = os.path.relpath(path, directory) + os.sep
return not relative.startswith(os.pardir + os.sep)

def _generate_descrpt_from_param_dict(descrpt_param):
try:
descrpt_type = descrpt_param['type']
except KeyError:
raise KeyError('the type of descriptor should be set by `type`')
descrpt_param.pop('type', None)
to_pop = []
for kk in descrpt_param:
if kk[0] == '_':
to_pop.append(kk)
for kk in to_pop:
descrpt_param.pop(kk, None)
if descrpt_type == 'loc_frame':
descrpt = DescrptLocFrame(**descrpt_param)
elif descrpt_type == 'se_e2_a' or descrpt_type == 'se_a' :
descrpt = DescrptSeA(**descrpt_param)
elif descrpt_type == 'se_e2_r' or descrpt_type == 'se_r' :
descrpt = DescrptSeR(**descrpt_param)
elif descrpt_type == 'se_e3' or descrpt_type == 'se_at' or descrpt_type == 'se_a_3be' :
descrpt = DescrptSeT(**descrpt_param)
elif descrpt_type == 'se_a_tpe' or descrpt_type == 'se_a_ebd' :
descrpt = DescrptSeAEbd(**descrpt_param)
elif descrpt_type == 'se_a_ef' :
descrpt = DescrptSeAEf(**descrpt_param)
elif descrpt_type == 'se_ar' :
descrpt = DescrptSeAR(descrpt_param)
else :
raise RuntimeError('unknow model type ' + descrpt_type)
return descrpt


class DPTrainer (object):
Expand All @@ -103,13 +67,7 @@ def _init_param(self, jdata):
except KeyError:
raise KeyError('the type of descriptor should be set by `type`')

if descrpt_type != 'hybrid':
self.descrpt = _generate_descrpt_from_param_dict(descrpt_param)
else :
descrpt_list = []
for ii in descrpt_param.get('list', []):
descrpt_list.append(_generate_descrpt_from_param_dict(ii))
self.descrpt = DescrptHybrid(descrpt_list)
self.descrpt = Descriptor(**descrpt_param)

# fitting net
fitting_type = fitting_param.get('type', 'ener')
Expand Down
1 change: 1 addition & 0 deletions deepmd/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .data_system import DataSystem
from .pair_tab import PairTab
from .learning_rate import LearningRateExp
from .plugin import Plugin, PluginVariant
Loading