From a483609a3644d14d22f786c4ec00a51123118647 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 22 Jul 2023 23:10:03 -0400 Subject: [PATCH 1/5] add a public method to register new DataType in a plugin Signed-off-by: Jinzhe Zeng --- dpdata/data_type.py | 127 ++++++++++++++++++++++++++++++++++++++++++++ dpdata/system.py | 115 +++------------------------------------ 2 files changed, 134 insertions(+), 108 deletions(-) create mode 100644 dpdata/data_type.py diff --git a/dpdata/data_type.py b/dpdata/data_type.py new file mode 100644 index 000000000..768397ced --- /dev/null +++ b/dpdata/data_type.py @@ -0,0 +1,127 @@ +from enum import Enum, unique +from typing import Tuple, TYPE_CHECKING + +import numpy as np + +from dpdata.plugin import Plugin + +if TYPE_CHECKING: + from dpdata.system import System + + +@unique +class Axis(Enum): + """Data axis.""" + + NFRAMES = "nframes" + NATOMS = "natoms" + NTYPES = "ntypes" + NBONDS = "nbonds" + + +class DataError(Exception): + """Data is not correct.""" + + +class DataType: + """DataType represents a type of data, like coordinates, energies, etc. + + Parameters + ---------- + name : str + name of data + dtype : type or tuple[type] + data type, e.g. np.ndarray + shape : tuple[int], optional + shape of data. Used when data is list or np.ndarray. 
Use Axis to + represent numbers + required : bool, default=True + whether this data is required + """ + + def __init__( + self, + name: str, + dtype: type, + shape: Tuple[int, Axis] = None, + required: bool = True, + ) -> None: + self.name = name + self.dtype = dtype + self.shape = shape + self.required = required + + def real_shape(self, system: "System") -> Tuple[int]: + """Returns expected real shape of a system.""" + shape = [] + for ii in self.shape: + if ii is Axis.NFRAMES: + shape.append(system.get_nframes()) + elif ii is Axis.NTYPES: + shape.append(system.get_ntypes()) + elif ii is Axis.NATOMS: + shape.append(system.get_natoms()) + elif ii is Axis.NBONDS: + # BondOrderSystem + shape.append(system.get_nbonds()) + elif isinstance(ii, int): + shape.append(ii) + else: + raise RuntimeError("Shape is not an int!") + return tuple(shape) + + def check(self, system: "System"): + """Check if a system has correct data of this type. + + Parameters + ---------- + system : System + checked system + + Raises + ------ + DataError + type or shape of data is not correct + """ + # check if exists + if self.name in system.data: + data = system.data[self.name] + # check dtype + # allow list for empty np.ndarray + if isinstance(data, list) and not len(data): + pass + elif not isinstance(data, self.dtype): + raise DataError( + f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}" + ) + # check shape + if self.shape is not None: + shape = self.real_shape(system) + # skip checking empty list of np.ndarray + if isinstance(data, np.ndarray): + if data.size and shape != data.shape: + raise DataError( + f"Shape of {self.name} is {data.shape}, but expected {shape}" + ) + elif isinstance(data, list): + if len(shape) and shape[0] != len(data): + raise DataError( + "Length of %s is %d, but expected %d" + % (self.name, len(data), shape[0]) + ) + else: + raise RuntimeError("Unsupported type to check shape") + elif self.required: + raise DataError("%s not found 
in data" % self.name) + +__system_data_type_plugin = Plugin() +__labeled_system_data_type_plugin = Plugin() + +def register_data_type(data_type: DataType, labeled: bool): + plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin + plugin.register(data_type.name)(data_type) + + +def get_data_types(labeled: bool): + plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin + return tuple(plugin.plugins.values()) diff --git a/dpdata/system.py b/dpdata/system.py index bfec97c6b..257ef7ab0 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -19,7 +19,7 @@ from dpdata.format import Format from dpdata.plugin import Plugin from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names - +from dpdata.data_type import Axis, DataType, DataError, get_data_types def load_format(fmt): fmt = fmt.lower() @@ -33,112 +33,6 @@ def load_format(fmt): ) -@unique -class Axis(Enum): - """Data axis.""" - - NFRAMES = "nframes" - NATOMS = "natoms" - NTYPES = "ntypes" - NBONDS = "nbonds" - - -class DataError(Exception): - """Data is not correct.""" - - -class DataType: - """DataType represents a type of data, like coordinates, energies, etc. - - Parameters - ---------- - name : str - name of data - dtype : type or tuple[type] - data type, e.g. np.ndarray - shape : tuple[int], optional - shape of data. Used when data is list or np.ndarray. 
Use Axis to - represents numbers - required : bool, default=True - whether this data is required - """ - - def __init__( - self, - name: str, - dtype: type, - shape: Tuple[int, Axis] = None, - required: bool = True, - ) -> None: - self.name = name - self.dtype = dtype - self.shape = shape - self.required = required - - def real_shape(self, system: "System") -> Tuple[int]: - """Returns expected real shape of a system.""" - shape = [] - for ii in self.shape: - if ii is Axis.NFRAMES: - shape.append(system.get_nframes()) - elif ii is Axis.NTYPES: - shape.append(system.get_ntypes()) - elif ii is Axis.NATOMS: - shape.append(system.get_natoms()) - elif ii is Axis.NBONDS: - # BondOrderSystem - shape.append(system.get_nbonds()) - elif isinstance(ii, int): - shape.append(ii) - else: - raise RuntimeError("Shape is not an int!") - return tuple(shape) - - def check(self, system: "System"): - """Check if a system has correct data of this type. - - Parameters - ---------- - system : System - checked system - - Raises - ------ - DataError - type or shape of data is not correct - """ - # check if exists - if self.name in system.data: - data = system.data[self.name] - # check dtype - # allow list for empty np.ndarray - if isinstance(data, list) and not len(data): - pass - elif not isinstance(data, self.dtype): - raise DataError( - f"Type of {self.name} is {type(data).__name__}, but expected {self.dtype.__name__}" - ) - # check shape - if self.shape is not None: - shape = self.real_shape(system) - # skip checking empty list of np.ndarray - if isinstance(data, np.ndarray): - if data.size and shape != data.shape: - raise DataError( - f"Shape of {self.name} is {data.shape}, but expected {shape}" - ) - elif isinstance(data, list): - if len(shape) and shape[0] != len(data): - raise DataError( - "Length of %s is %d, but expected %d" - % (self.name, len(data), shape[0]) - ) - else: - raise RuntimeError("Unsupported type to check shape") - elif self.required: - raise DataError("%s not found 
in data" % self.name) - - class System(MSONable): """The data System. @@ -1657,7 +1551,8 @@ def get_cls_name(cls: object) -> str: def add_format_methods(): - """Add format methods to System, LabeledSystem, and MultiSystems. + """Add format methods to System, LabeledSystem, and MultiSystems; add data types + to System and LabeledSystem. Notes ----- @@ -1701,5 +1596,9 @@ def to_format(self, *args, **kwargs): setattr(LabeledSystem, method, get_func(formatcls)) setattr(MultiSystems, method, get_func(formatcls)) + # at this point, System.DTYPES and LabeledSystem.DTYPES have been initialized + System.DTYPES = System.DTYPES + get_data_types(labeled=False) + LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False) + LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True) add_format_methods() From a771c2ceb3a6e46088ba42a0b265bb8fdabbc34a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 23 Jul 2023 22:51:52 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/data_type.py | 6 ++++-- dpdata/system.py | 5 +++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 768397ced..4dee92336 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -1,5 +1,5 @@ from enum import Enum, unique -from typing import Tuple, TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple import numpy as np @@ -113,10 +113,12 @@ def check(self, system: "System"): raise RuntimeError("Unsupported type to check shape") elif self.required: raise DataError("%s not found in data" % self.name) - + + __system_data_type_plugin = Plugin() __labeled_system_data_type_plugin = Plugin() + def register_data_type(data_type: DataType, labeled: bool): plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin 
plugin.register(data_type.name)(data_type) diff --git a/dpdata/system.py b/dpdata/system.py index 257ef7ab0..7b68a9e10 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -2,7 +2,6 @@ import glob import os from copy import deepcopy -from enum import Enum, unique from typing import Any, Dict, Optional, Tuple, Union import numpy as np @@ -15,11 +14,12 @@ # ensure all plugins are loaded! import dpdata.plugins from dpdata.amber.mask import load_param_file, pick_by_amber_mask +from dpdata.data_type import Axis, DataError, DataType, get_data_types from dpdata.driver import Driver, Minimizer from dpdata.format import Format from dpdata.plugin import Plugin from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names -from dpdata.data_type import Axis, DataType, DataError, get_data_types + def load_format(fmt): fmt = fmt.lower() @@ -1601,4 +1601,5 @@ def to_format(self, *args, **kwargs): LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False) LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True) + add_format_methods() From db37dd7d5699e584d0c34e97536c4172a8411fbe Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 24 Jul 2023 14:47:12 -0400 Subject: [PATCH 3/5] add tests and docs Signed-off-by: Jinzhe Zeng --- .github/workflows/test.yml | 2 +- dpdata/data_type.py | 16 ++++++++++++++++ tests/plugin/dpdata_plugin_test/__init__.py | 10 ++++++++++ tests/plugin/pyproject.toml | 17 +++++++++++++++++ tests/test_custom_data_type.py | 7 ------- 5 files changed, 44 insertions(+), 8 deletions(-) create mode 100644 tests/plugin/dpdata_plugin_test/__init__.py create mode 100644 tests/plugin/pyproject.toml diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0ab9509f4..1bc8872aa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - name: Install rdkit run: python -m pip install rdkit openbabel-wheel - name: Install dependencies - run: python -m 
pip install .[amber,ase,pymatgen] coverage "pymatgen==2023.7.11;python_version>='3.8'" + run: python -m pip install .[amber,ase,pymatgen] coverage ./tests/dpdata_plugin_test "pymatgen==2023.7.11;python_version>='3.8'" - name: Test run: cd tests && coverage run --source=../dpdata -m unittest && cd .. && coverage combine tests/.coverage && coverage report - name: Run codecov diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 4dee92336..4db594f7c 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -120,10 +120,26 @@ def check(self, system: "System"): def register_data_type(data_type: DataType, labeled: bool): + """Register a data type. + + Parameters + ---------- + data_type : DataType + data type to be registered + labeled : bool + whether this data type is for LabeledSystem + """ plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin plugin.register(data_type.name)(data_type) def get_data_types(labeled: bool): + """Get all registered data types. 
+ + Parameters + ---------- + labeled : bool + whether this data type is for LabeledSystem + """ plugin = __labeled_system_data_type_plugin if labeled else __system_data_type_plugin return tuple(plugin.plugins.values()) diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py new file mode 100644 index 000000000..6bc59a042 --- /dev/null +++ b/tests/plugin/dpdata_plugin_test/__init__.py @@ -0,0 +1,10 @@ +import numpy as np + +from dpdata.data_type import DataType, Axis, register_data_type + + +# test data type + +register_data_type(DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True) + +ep = None \ No newline at end of file diff --git a/tests/plugin/pyproject.toml b/tests/plugin/pyproject.toml new file mode 100644 index 000000000..7ce1f854a --- /dev/null +++ b/tests/plugin/pyproject.toml @@ -0,0 +1,17 @@ +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" + +[project] +name = "dpdata_plugin_test" +version = "0.0.0" +description = "A test for dpdata plugin" +dependencies = [ + 'numpy', + 'dpdata', +] +readme = "README.md" +requires-python = ">=3.7" + +[project.entry-points.'dpdata.plugins'] +random = "dpdata_plugin_test:ep" diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 58b9f5e49..b2a236410 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -9,18 +9,11 @@ class TestDeepmdLoadDumpComp(unittest.TestCase): def setUp(self): - self.backup = dpdata.system.LabeledSystem.DTYPES - dpdata.system.LabeledSystem.DTYPES = dpdata.system.LabeledSystem.DTYPES + ( - DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), - ) self.system = dpdata.LabeledSystem("poscars/OUTCAR.h2o.md", fmt="vasp/outcar") self.foo = np.ones((len(self.system), 2, 4)) self.system.data["foo"] = self.foo self.system.check_data() - def tearDown(self) -> None: - dpdata.system.LabeledSystem.DTYPES = self.backup - def 
test_to_deepmd_raw(self): self.system.to_deepmd_raw("data_foo") foo = np.loadtxt("data_foo/foo.raw") From a30c4d92ea3d26fb877f99697485194b03f6a15f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jul 2023 18:48:14 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- dpdata/data_type.py | 4 ++-- tests/plugin/dpdata_plugin_test/__init__.py | 9 +++++---- tests/test_custom_data_type.py | 1 - 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 4db594f7c..c0bd944fb 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -121,7 +121,7 @@ def check(self, system: "System"): def register_data_type(data_type: DataType, labeled: bool): """Register a data type. - + Parameters ---------- data_type : DataType @@ -135,7 +135,7 @@ def register_data_type(data_type: DataType, labeled: bool): def get_data_types(labeled: bool): """Get all registered data types. 
- + Parameters ---------- labeled : bool diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py index 6bc59a042..29911b5f1 100644 --- a/tests/plugin/dpdata_plugin_test/__init__.py +++ b/tests/plugin/dpdata_plugin_test/__init__.py @@ -1,10 +1,11 @@ import numpy as np -from dpdata.data_type import DataType, Axis, register_data_type - +from dpdata.data_type import Axis, DataType, register_data_type # test data type -register_data_type(DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True) +register_data_type( + DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False), labeled=True +) -ep = None \ No newline at end of file +ep = None diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index b2a236410..49c99586f 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -4,7 +4,6 @@ import numpy as np import dpdata -from dpdata.system import Axis, DataType class TestDeepmdLoadDumpComp(unittest.TestCase): From 322793dd0bcbd0682fe2703c1f564808234d9959 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 24 Jul 2023 14:49:38 -0400 Subject: [PATCH 5/5] fix test cmd Signed-off-by: Jinzhe Zeng --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e63589c32..765db654c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -22,7 +22,7 @@ jobs: - name: Install rdkit run: python -m pip install rdkit openbabel-wheel - name: Install dependencies - run: python -m pip install .[amber,ase,pymatgen] coverage ./tests/dpdata_plugin_test + run: python -m pip install .[amber,ase,pymatgen] coverage ./tests/plugin - name: Test run: cd tests && coverage run --source=../dpdata -m unittest && cd .. && coverage combine tests/.coverage && coverage report - name: Run codecov