diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9e2cbf607..765db654c 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -22,7 +22,7 @@ jobs:
     - name: Install rdkit
       run: python -m pip install rdkit openbabel-wheel
     - name: Install dependencies
-      run: python -m pip install .[amber,ase,pymatgen] coverage
+      run: python -m pip install .[amber,ase,pymatgen] coverage ./tests/plugin
     - name: Test
       run: cd tests && coverage run --source=../dpdata -m unittest && cd .. && coverage combine tests/.coverage && coverage report
     - name: Run codecov
diff --git a/dpdata/data_type.py b/dpdata/data_type.py
new file mode 100644
index 000000000..c0bd944fb
--- /dev/null
+++ b/dpdata/data_type.py
@@ -0,0 +1,145 @@
+from enum import Enum, unique
+from typing import TYPE_CHECKING, Tuple
+
+import numpy as np
+
+from dpdata.plugin import Plugin
+
+if TYPE_CHECKING:
+    from dpdata.system import System
+
+
+@unique
+class Axis(Enum):
+    """Data axis."""
+
+    NFRAMES = "nframes"
+    NATOMS = "natoms"
+    NTYPES = "ntypes"
+    NBONDS = "nbonds"
+
+
+class DataError(Exception):
+    """Data is not correct."""
+
+
+class DataType:
+    """DataType represents a type of data, like coordinates, energies, etc.
+
+    Parameters
+    ----------
+    name : str
+        name of data
+    dtype : type or tuple[type]
+        data type, e.g. np.ndarray
+    shape : tuple[int], optional
+        shape of data. Used when data is list or np.ndarray. Use Axis to
+        represents numbers
+    required : bool, default=True
+        whether this data is required
+    """
+
+    def __init__(
+        self,
+        name: str,
+        dtype: type,
+        shape: Tuple[int, Axis] = None,
+        required: bool = True,
+    ) -> None:
+        self.name = name
+        self.dtype = dtype
+        self.shape = shape
+        self.required = required
+
+    def real_shape(self, system: "System") -> Tuple[int]:
+        """Returns expected real shape of a system."""
+        shape = []
+        for ii in self.shape:
+            if ii is Axis.NFRAMES:
+                shape.append(system.get_nframes())
+            elif ii is Axis.NTYPES:
+                shape.append(system.get_ntypes())
+            elif ii is Axis.NATOMS:
+                shape.append(system.get_natoms())
+            elif ii is Axis.NBONDS:
+                # BondOrderSystem
+                shape.append(system.get_nbonds())
+            elif isinstance(ii, int):
+                shape.append(ii)
+            else:
+                raise RuntimeError("Shape is not an int!")
+        return tuple(shape)
+
+    def check(self, system: "System"):
+        """Check if a system has correct data of this type.
+
+        Parameters
+        ----------
+        system : System
+            checked system
+
+        Raises
+        ------
+        DataError
+            type or shape of data is not correct
+        """
+        # check if exists
+        if self.name in system.data:
+            data = system.data[self.name]
+            # check dtype
+            # allow list for empty np.ndarray
+            if isinstance(data, list) and not len(data):
+                pass
+            elif not isinstance(data, self.dtype):
+                raise DataError(
+                    f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}"
+                )
+            # check shape
+            if self.shape is not None:
+                shape = self.real_shape(system)
+                # skip checking empty list of np.ndarray
+                if isinstance(data, np.ndarray):
+                    if data.size and shape != data.shape:
+                        raise DataError(
+                            f"Shape of {self.name} is {data.shape}, but expected {shape}"
+                        )
+                elif isinstance(data, list):
+                    if len(shape) and shape[0] != len(data):
+                        raise DataError(
+                            "Length of %s is %d, but expected %d"
+                            % (self.name, len(data), shape[0])
+                        )
+                else:
+                    raise RuntimeError("Unsupported type to check shape")
+        elif self.required:
+            raise DataError("%s not found in data" % self.name)
+
+
+__system_data_type_plugin = Plugin()
+__labeled_system_data_type_plugin = Plugin()
+
+
+def register_data_type(data_type: DataType, labeled: bool):
+    """Register a data type.
+
+    Parameters
+    ----------
+    data_type : DataType
+        data type to be registered
+    labeled : bool
+        whether this data type is for LabeledSystem
+    """
+    plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin
+    plugin.register(data_type.name)(data_type)
+
+
+def get_data_types(labeled: bool):
+    """Get all registered data types.
+
+    Parameters
+    ----------
+    labeled : bool
+        whether this data type is for LabeledSystem
+    """
+    plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin
+    return tuple(plugin.plugins.values())
diff --git a/dpdata/system.py b/dpdata/system.py
index bfec97c6b..7b68a9e10 100644
--- a/dpdata/system.py
+++ b/dpdata/system.py
@@ -2,7 +2,6 @@
 import glob
 import os
 from copy import deepcopy
-from enum import Enum, unique
 from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
@@ -15,6 +14,7 @@
 # ensure all plugins are loaded!
 import dpdata.plugins
 from dpdata.amber.mask import load_param_file, pick_by_amber_mask
+from dpdata.data_type import Axis, DataError, DataType, get_data_types
 from dpdata.driver import Driver, Minimizer
 from dpdata.format import Format
 from dpdata.plugin import Plugin
@@ -33,112 +33,6 @@ def load_format(fmt):
     )
 
 
-@unique
-class Axis(Enum):
-    """Data axis."""
-
-    NFRAMES = "nframes"
-    NATOMS = "natoms"
-    NTYPES = "ntypes"
-    NBONDS = "nbonds"
-
-
-class DataError(Exception):
-    """Data is not correct."""
-
-
-class DataType:
-    """DataType represents a type of data, like coordinates, energies, etc.
-
-    Parameters
-    ----------
-    name : str
-        name of data
-    dtype : type or tuple[type]
-        data type, e.g. np.ndarray
-    shape : tuple[int], optional
-        shape of data. Used when data is list or np.ndarray. Use Axis to
-        represents numbers
-    required : bool, default=True
-        whether this data is required
-    """
-
-    def __init__(
-        self,
-        name: str,
-        dtype: type,
-        shape: Tuple[int, Axis] = None,
-        required: bool = True,
-    ) -> None:
-        self.name = name
-        self.dtype = dtype
-        self.shape = shape
-        self.required = required
-
-    def real_shape(self, system: "System") -> Tuple[int]:
-        """Returns expected real shape of a system."""
-        shape = []
-        for ii in self.shape:
-            if ii is Axis.NFRAMES:
-                shape.append(system.get_nframes())
-            elif ii is Axis.NTYPES:
-                shape.append(system.get_ntypes())
-            elif ii is Axis.NATOMS:
-                shape.append(system.get_natoms())
-            elif ii is Axis.NBONDS:
-                # BondOrderSystem
-                shape.append(system.get_nbonds())
-            elif isinstance(ii, int):
-                shape.append(ii)
-            else:
-                raise RuntimeError("Shape is not an int!")
-        return tuple(shape)
-
-    def check(self, system: "System"):
-        """Check if a system has correct data of this type.
-
-        Parameters
-        ----------
-        system : System
-            checked system
-
-        Raises
-        ------
-        DataError
-            type or shape of data is not correct
-        """
-        # check if exists
-        if self.name in system.data:
-            data = system.data[self.name]
-            # check dtype
-            # allow list for empty np.ndarray
-            if isinstance(data, list) and not len(data):
-                pass
-            elif not isinstance(data, self.dtype):
-                raise DataError(
-                    f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}"
-                )
-            # check shape
-            if self.shape is not None:
-                shape = self.real_shape(system)
-                # skip checking empty list of np.ndarray
-                if isinstance(data, np.ndarray):
-                    if data.size and shape != data.shape:
-                        raise DataError(
-                            f"Shape of {self.name} is {data.shape}, but expected {shape}"
-                        )
-                elif isinstance(data, list):
-                    if len(shape) and shape[0] != len(data):
-                        raise DataError(
-                            "Length of %s is %d, but expected %d"
-                            % (self.name, len(data), shape[0])
-                        )
-                else:
-                    raise RuntimeError("Unsupported type to check shape")
-        elif self.required:
-            raise DataError("%s not found in data" % self.name)
-
-
 class System(MSONable):
     """The data System.
 
@@ -1657,7 +1551,8 @@ def get_cls_name(cls: object) -> str:
 
 
 def add_format_methods():
-    """Add format methods to System, LabeledSystem, and MultiSystems.
+    """Add format methods to System, LabeledSystem, and MultiSystems; add data types
+    to System and LabeledSystem.
 
     Notes
     -----
@@ -1701,5 +1596,10 @@ def to_format(self, *args, **kwargs):
         setattr(LabeledSystem, method, get_func(formatcls))
         setattr(MultiSystems, method, get_func(formatcls))
 
+    # at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized
+    System.DTYPES = System.DTYPES + get_data_types(labeled=False)
+    LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False)
+    LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True)
+
 
 add_format_methods()
diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py
new file mode 100644
index 000000000..29911b5f1
--- /dev/null
+++ b/tests/plugin/dpdata_plugin_test/__init__.py
@@ -0,0 +1,11 @@
+import numpy as np
+
+from dpdata.data_type import Axis, DataType, register_data_type
+
+# test data type
+
+register_data_type(
+    DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True
+)
+
+ep = None
diff --git a/tests/plugin/pyproject.toml b/tests/plugin/pyproject.toml
new file mode 100644
index 000000000..7ce1f854a
--- /dev/null
+++ b/tests/plugin/pyproject.toml
@@ -0,0 +1,17 @@
+[build-system]
+requires = ["setuptools>=61"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "dpdata_plugin_test"
+version = "0.0.0"
+description = "A test for dpdata plugin"
+dependencies = [
+    'numpy',
+    'dpdata',
+]
+readme = "README.md"
+requires-python = ">=3.7"
+
+[project.entry-points.'dpdata.plugins']
+random = "dpdata_plugin_test:ep"
diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py
index 58b9f5e49..49c99586f 100644
--- a/tests/test_custom_data_type.py
+++ b/tests/test_custom_data_type.py
@@ -4,23 +4,15 @@
 import numpy as np
 
 import dpdata
-from dpdata.system import Axis, DataType
 
 
 class TestDeepmdLoadDumpComp(unittest.TestCase):
     def setUp(self):
-        self.backup = dpdata.system.LabeledSystem.DTYPES
-        dpdata.system.LabeledSystem.DTYPES = dpdata.system.LabeledSystem.DTYPES + (
-            DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False),
-        )
         self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar")
         self.foo = np.ones((len(self.system), 2, 4))
         self.system.data["foo"] = self.foo
         self.system.check_data()
 
-    def tearDown(self) -> None:
-        dpdata.system.LabeledSystem.DTYPES = self.backup
-
     def test_to_deepmd_raw(self):
         self.system.to_deepmd_raw("data_foo")
         foo = np.loadtxt("data_foo/foo.raw")